diff --git a/.gitattributes b/.gitattributes index f3f46d83d773f725a9678a5cf514f84cde035809..2c029d9fd2330e6b128281245ab0169ce2753f2a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -112,3 +112,6 @@ phivenv/Lib/site-packages/torch/lib/cpuinfo.lib filter=lfs diff=lfs merge=lfs -t phivenv/Lib/site-packages/torch/lib/fbgemm.dll filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/lib/c10.dll filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/libittnotify.lib filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi b/phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f7429c5df2fede845357a319b6a03ff038ed3a7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_VariableFunctions.pyi @@ -0,0 +1,32949 @@ +# @generated by tools/pyi/gen_pyi.py from torch/_C/_VariableFunctions.pyi.in +# mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs +# ruff: noqa: F401,PYI054 + +from collections.abc import Sequence +from types import EllipsisType +from typing import Any, Callable, Literal, overload, TypeVar + +import torch +from torch import ( + contiguous_format, + Generator, + inf, + memory_format, + strided, + SymInt, + Tensor, +) +from torch._prims_common import DeviceLikeType +from torch.types import ( + _bool, + _complex, + _device, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + Device, + Number, +) + +__all__ = [ + "__and__", + "__lshift__", + "__or__", + "__rshift__", + "__xor__", + "_adaptive_avg_pool2d", + "_adaptive_avg_pool3d", + "_add_batch_dim", + "_add_relu", + "_add_relu_", + "_addmm_activation", + "_aminmax", + "_amp_foreach_non_finite_check_and_unscale_", + "_amp_update_scale_", + "_assert_async", + "_assert_scalar", + "_assert_tensor_metadata", + "_batch_norm_impl_index", + "_cast_Byte", + "_cast_Char", + "_cast_Double", + "_cast_Float", + "_cast_Half", + "_cast_Int", + "_cast_Long", + "_cast_Short", + "_choose_qparams_per_tensor", + "_chunk_cat", + "_coalesce", + "_compute_linear_combination", + "_conj", + "_conj_copy", + "_conj_physical", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + "_convert_weight_to_int4pack", + "_convert_weight_to_int4pack_for_cpu", + "_convolution", + "_convolution_mode", + "_copy_from", + "_copy_from_and_resize", + "_cslt_compress", + "_cslt_sparse_mm", + "_cslt_sparse_mm_search", + "_ctc_loss", + "_cudnn_ctc_loss", + "_cudnn_init_dropout_state", + "_cudnn_rnn", + "_cudnn_rnn_flatten_weight", + "_cufft_clear_plan_cache", + "_cufft_get_plan_cache_max_size", + "_cufft_get_plan_cache_size", + "_cufft_set_plan_cache_max_size", + "_cummax_helper", + "_cummin_helper", + "_debug_has_internal_overlap", + "_dim_arange", + "_dirichlet_grad", + "_disable_functionalization", + "_dyn_quant_matmul_4bit", + "_dyn_quant_pack_4bit_weight", + "_efficientzerotensor", + "_embedding_bag", + "_embedding_bag_forward_only", + "_empty_affine_quantized", + "_empty_per_channel_affine_quantized", + "_enable_functionalization", + "_euclidean_dist", + "_fake_quantize_learnable_per_channel_affine", + "_fake_quantize_learnable_per_tensor_affine", + "_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + 
"_fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "_fft_c2c", + "_fft_c2r", + "_fft_r2c", + "_fill_mem_eff_dropout_mask_", + "_foobar", + "_foreach_abs", + "_foreach_abs_", + "_foreach_acos", + "_foreach_acos_", + "_foreach_add", + "_foreach_add_", + "_foreach_addcdiv", + "_foreach_addcdiv_", + "_foreach_addcmul", + "_foreach_addcmul_", + "_foreach_asin", + "_foreach_asin_", + "_foreach_atan", + "_foreach_atan_", + "_foreach_ceil", + "_foreach_ceil_", + "_foreach_clamp_max", + "_foreach_clamp_max_", + "_foreach_clamp_min", + "_foreach_clamp_min_", + "_foreach_copy_", + "_foreach_cos", + "_foreach_cos_", + "_foreach_cosh", + "_foreach_cosh_", + "_foreach_div", + "_foreach_div_", + "_foreach_erf", + "_foreach_erf_", + "_foreach_erfc", + "_foreach_erfc_", + "_foreach_exp", + "_foreach_exp_", + "_foreach_expm1", + "_foreach_expm1_", + "_foreach_floor", + "_foreach_floor_", + "_foreach_frac", + "_foreach_frac_", + "_foreach_lerp", + "_foreach_lerp_", + "_foreach_lgamma", + "_foreach_lgamma_", + "_foreach_log", + "_foreach_log10", + "_foreach_log10_", + "_foreach_log1p", + "_foreach_log1p_", + "_foreach_log2", + "_foreach_log2_", + "_foreach_log_", + "_foreach_max", + "_foreach_maximum", + "_foreach_maximum_", + "_foreach_minimum", + "_foreach_minimum_", + "_foreach_mul", + "_foreach_mul_", + "_foreach_neg", + "_foreach_neg_", + "_foreach_norm", + "_foreach_pow", + "_foreach_pow_", + "_foreach_reciprocal", + "_foreach_reciprocal_", + "_foreach_round", + "_foreach_round_", + "_foreach_rsqrt", + "_foreach_rsqrt_", + "_foreach_sigmoid", + "_foreach_sigmoid_", + "_foreach_sign", + "_foreach_sign_", + "_foreach_sin", + "_foreach_sin_", + "_foreach_sinh", + "_foreach_sinh_", + "_foreach_sqrt", + "_foreach_sqrt_", + "_foreach_sub", + "_foreach_sub_", + "_foreach_tan", + "_foreach_tan_", + "_foreach_tanh", + "_foreach_tanh_", + "_foreach_trunc", + "_foreach_trunc_", + "_foreach_zero_", + "_from_functional_tensor", + "_functional_assert_async", + "_functional_assert_scalar", + "_functional_sym_constrain_range", + "_functional_sym_constrain_range_for_size", + "_functionalize_apply_view_metas", + "_functionalize_are_all_mutations_hidden_from_autograd", + "_functionalize_are_all_mutations_under_no_grad_or_inference_mode", + "_functionalize_commit_update", + "_functionalize_has_metadata_mutation", + "_functionalize_is_symbolic", + "_functionalize_mark_mutation_hidden_from_autograd", + "_functionalize_replace", + "_functionalize_set_storage_changed", + "_functionalize_sync", + "_functionalize_unsafe_set", + "_functionalize_was_inductor_storage_resized", + "_functionalize_was_storage_changed", + "_fused_adagrad_", + "_fused_adam_", + "_fused_adamw_", + "_fused_dropout", + "_fused_moving_avg_obs_fq_helper", + "_fused_moving_avg_obs_fq_helper", + "_fused_rms_norm", + "_fused_sdp_choice", + "_fused_sgd_", + "_fw_primal_copy", + "_grid_sampler_2d_cpu_fallback", + "_grouped_mm", + "_has_compatible_shallow_copy_type", + "_histogramdd_bin_edges", + "_histogramdd_from_bin_cts", + "_histogramdd_from_bin_tensors", + "_index_put_impl_", + "_indices_copy", + "_int_mm", + "_is_all_true", + "_is_any_true", + "_is_functional_tensor", + "_is_functional_tensor_base", + "_is_zerotensor", + "_lazy_clone", + "_linalg_check_errors", + "_linalg_det", + "_linalg_det", + "_linalg_eigh", + "_linalg_eigh", + "_linalg_slogdet", + "_linalg_slogdet", + "_linalg_solve_ex", + "_linalg_solve_ex", + "_linalg_svd", + "_linalg_svd", + "_log_softmax", + "_log_softmax_backward_data", + "_logcumsumexp", + "_lstm_mps", + 
"_lu_with_info", + "_lu_with_info", + "_make_dep_token", + "_make_dual", + "_make_dual_copy", + "_make_per_channel_quantized_tensor", + "_make_per_tensor_quantized_tensor", + "_masked_scale", + "_masked_softmax", + "_mixed_dtypes_linear", + "_mkldnn_reshape", + "_mkldnn_transpose", + "_mkldnn_transpose_", + "_mps_convolution", + "_mps_convolution_transpose", + "_native_batch_norm_legit", + "_native_batch_norm_legit_no_training", + "_native_multi_head_attention", + "_neg_view", + "_neg_view_copy", + "_nested_compute_contiguous_strides_offsets", + "_nested_from_padded", + "_nested_from_padded_and_nested_example", + "_nested_from_padded_tensor", + "_nested_get_jagged_dummy", + "_nested_get_lengths", + "_nested_get_max_seqlen", + "_nested_get_min_seqlen", + "_nested_get_offsets", + "_nested_get_ragged_idx", + "_nested_get_values", + "_nested_get_values_copy", + "_nested_tensor_from_mask", + "_nested_tensor_from_mask_left_aligned", + "_nested_tensor_from_tensor_list", + "_nested_tensor_softmax_with_shape", + "_nested_view_from_buffer", + "_nested_view_from_buffer_copy", + "_nested_view_from_jagged", + "_nested_view_from_jagged_copy", + "_nnpack_available", + "_nnpack_spatial_convolution", + "_pack_padded_sequence", + "_pad_packed_sequence", + "_pin_memory", + "_prelu_kernel", + "_print", + "_propagate_xla_data", + "_remove_batch_dim", + "_reshape_alias_copy", + "_reshape_from_tensor", + "_resize_output_", + "_rowwise_prune", + "_safe_softmax", + "_sample_dirichlet", + "_saturate_weight_to_fp16", + "_scaled_dot_product_attention_math", + "_scaled_dot_product_attention_math_for_mps", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_flash_attention_for_cpu", + "_scaled_dot_product_flash_attention_for_cpu", + "_scaled_grouped_mm", + "_scaled_mm", + "_shape_as_tensor", + "_sobol_engine_draw", + "_sobol_engine_ff_", + "_sobol_engine_initialize_state_", + "_sobol_engine_scramble_", + "_softmax", + "_softmax_backward_data", + "_sparse_broadcast_to", + "_sparse_broadcast_to_copy", + "_sparse_csr_prod", + "_sparse_csr_sum", + "_sparse_log_softmax_backward_data", + "_sparse_semi_structured_addmm", + "_sparse_semi_structured_apply", + "_sparse_semi_structured_apply_dense", + "_sparse_semi_structured_linear", + "_sparse_semi_structured_mm", + "_sparse_semi_structured_tile", + "_sparse_softmax_backward_data", + "_sparse_sparse_matmul", + "_sparse_sum", + "_stack", + "_standard_gamma", + "_standard_gamma_grad", + "_sync", + "_test_autograd_multiple_dispatch", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_test_check_tensor", + "_test_functorch_fallback", + "_test_parallel_materialize", + "_test_serialization_subcmul", + "_to_cpu", + "_to_functional_tensor", + "_to_sparse_semi_structured", + "_transform_bias_rescale_qkv", + "_transformer_encoder_layer_fwd", + "_trilinear", + "_triton_multi_head_attention", + "_triton_scaled_dot_attention", + "_unique", + "_unique2", + "_unpack_dual", + "_unpack_dual", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", + "_use_cudnn_ctc_loss", + "_use_cudnn_rnn_flatten_weight", + "_validate_compressed_sparse_indices", + "_validate_sparse_bsc_tensor_args", + "_validate_sparse_bsr_tensor_args", + "_validate_sparse_compressed_tensor_args", + 
"_validate_sparse_coo_tensor_args", + "_validate_sparse_csc_tensor_args", + "_validate_sparse_csr_tensor_args", + "_values_copy", + "_weight_int4pack_mm", + "_weight_int4pack_mm_for_cpu", + "_weight_int4pack_mm_with_scales_and_zeros", + "_weight_int8pack_mm", + "_weight_norm", + "_weight_norm_interface", + "_wrapped_linear_prepack", + "_wrapped_quantized_linear_prepacked", + "abs", + "abs_", + "absolute", + "acos", + "acos_", + "acosh", + "acosh_", + "adaptive_avg_pool1d", + "adaptive_max_pool1d", + "add", + "addbmm", + "addcdiv", + "addcmul", + "addmm", + "addmv", + "addmv_", + "addr", + "adjoint", + "affine_grid_generator", + "alias_copy", + "all", + "allclose", + "alpha_dropout", + "alpha_dropout_", + "amax", + "amin", + "aminmax", + "aminmax", + "angle", + "any", + "arange", + "arccos", + "arccos_", + "arccosh", + "arccosh_", + "arcsin", + "arcsin_", + "arcsinh", + "arcsinh_", + "arctan", + "arctan2", + "arctan_", + "arctanh", + "arctanh_", + "argmax", + "argmin", + "argsort", + "argwhere", + "as_strided", + "as_strided_", + "as_strided_copy", + "as_strided_scatter", + "as_tensor", + "asarray", + "asin", + "asin_", + "asinh", + "asinh_", + "atan", + "atan2", + "atan_", + "atanh", + "atanh_", + "avg_pool1d", + "baddbmm", + "bartlett_window", + "batch_norm", + "batch_norm_backward_elemt", + "batch_norm_backward_reduce", + "batch_norm_elemt", + "batch_norm_gather_stats", + "batch_norm_gather_stats_with_counts", + "batch_norm_stats", + "batch_norm_update_stats", + "bernoulli", + "bilinear", + "binary_cross_entropy_with_logits", + "bincount", + "binomial", + "bitwise_and", + "bitwise_left_shift", + "bitwise_not", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "blackman_window", + "bmm", + "broadcast_to", + "bucketize", + "can_cast", + "cat", + "ccol_indices_copy", + "ceil", + "ceil_", + "celu", + "celu_", + "channel_shuffle", + "cholesky", + "cholesky_inverse", + "cholesky_solve", + "choose_qparams_optimized", + "chunk", + "clamp", + "clamp_", + "clamp_max", + "clamp_max_", + "clamp_min", + "clamp_min_", + "clip", + "clip_", + "clone", + "col_indices_copy", + "column_stack", + "combinations", + "complex", + "concat", + "concatenate", + "conj", + "conj_physical", + "conj_physical_", + "constant_pad_nd", + "conv1d", + "conv2d", + "conv3d", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "convolution", + "copysign", + "corrcoef", + "cos", + "cos_", + "cosh", + "cosh_", + "cosine_embedding_loss", + "cosine_similarity", + "count_nonzero", + "cov", + "cross", + "crow_indices_copy", + "ctc_loss", + "cudnn_affine_grid_generator", + "cudnn_batch_norm", + "cudnn_convolution", + "cudnn_convolution_add_relu", + "cudnn_convolution_relu", + "cudnn_convolution_transpose", + "cudnn_grid_sampler", + "cudnn_is_acceptable", + "cummax", + "cummax", + "cummin", + "cummin", + "cumprod", + "cumsum", + "cumulative_trapezoid", + "deg2rad", + "deg2rad_", + "dequantize", + "det", + "detach", + "detach_", + "detach_copy", + "diag", + "diag_embed", + "diagflat", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "diff", + "digamma", + "dist", + "div", + "divide", + "dot", + "dropout", + "dropout_", + "dsmm", + "dsplit", + "dstack", + "embedding", + "embedding_bag", + "embedding_renorm_", + "empty", + "empty_like", + "empty_permuted", + "empty_quantized", + "empty_strided", + "eq", + "equal", + "erf", + "erf_", + "erfc", + "erfc_", + "erfinv", + "exp", + "exp2", + "exp2_", + "exp_", + "expand_copy", + "expm1", + "expm1_", + "eye", + "fake_quantize_per_channel_affine", 
+ "fake_quantize_per_tensor_affine", + "fbgemm_linear_fp16_weight", + "fbgemm_linear_fp16_weight_fp32_activation", + "fbgemm_linear_int8_weight", + "fbgemm_linear_int8_weight_fp32_activation", + "fbgemm_linear_quantize_weight", + "fbgemm_pack_gemm_matrix_fp16", + "fbgemm_pack_quantized_matrix", + "feature_alpha_dropout", + "feature_alpha_dropout_", + "feature_dropout", + "feature_dropout_", + "fill", + "fill_", + "fix", + "fix_", + "flatten", + "flip", + "fliplr", + "flipud", + "float_power", + "floor", + "floor_", + "floor_divide", + "fmax", + "fmin", + "fmod", + "frac", + "frac_", + "frexp", + "frexp", + "frobenius_norm", + "from_file", + "from_numpy", + "frombuffer", + "full", + "full_like", + "fused_moving_avg_obs_fake_quant", + "gather", + "gcd", + "gcd_", + "ge", + "geqrf", + "geqrf", + "ger", + "get_default_dtype", + "get_num_interop_threads", + "get_num_threads", + "gradient", + "greater", + "greater_equal", + "grid_sampler", + "grid_sampler_2d", + "grid_sampler_3d", + "group_norm", + "gru", + "gru_cell", + "gt", + "hamming_window", + "hann_window", + "hardshrink", + "heaviside", + "hinge_embedding_loss", + "histc", + "histogram", + "histogram", + "histogramdd", + "histogramdd", + "hsmm", + "hsplit", + "hspmm", + "hstack", + "hypot", + "i0", + "i0_", + "igamma", + "igammac", + "imag", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_put_", + "index_reduce", + "index_select", + "indices_copy", + "init_num_threads", + "inner", + "instance_norm", + "int_repr", + "inverse", + "is_complex", + "is_conj", + "is_distributed", + "is_floating_point", + "is_grad_enabled", + "is_inference", + "is_inference_mode_enabled", + "is_neg", + "is_nonzero", + "is_same_size", + "is_signed", + "is_vulkan_available", + "isclose", + "isfinite", + "isin", + "isinf", + "isnan", + "isneginf", + "isposinf", + "isreal", + "istft", + "kaiser_window", + "kl_div", + "kron", + "kthvalue", + "kthvalue", + "layer_norm", + "lcm", + "lcm_", + "ldexp", + "ldexp_", + "le", + "lerp", + "less", + "less_equal", + "lgamma", + "linspace", + "log", + "log10", + "log10_", + "log1p", + "log1p_", + "log2", + "log2_", + "log_", + "log_softmax", + "logaddexp", + "logaddexp2", + "logcumsumexp", + "logdet", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logit_", + "logspace", + "logsumexp", + "lstm", + "lstm_cell", + "lt", + "lu_solve", + "lu_unpack", + "lu_unpack", + "margin_ranking_loss", + "masked_fill", + "masked_scatter", + "masked_select", + "matmul", + "matrix_exp", + "matrix_power", + "max", + "max", + "max_pool1d", + "max_pool1d_with_indices", + "max_pool2d", + "max_pool3d", + "maximum", + "mean", + "median", + "median", + "min", + "min", + "minimum", + "miopen_batch_norm", + "miopen_convolution", + "miopen_convolution_add_relu", + "miopen_convolution_relu", + "miopen_convolution_transpose", + "miopen_depthwise_convolution", + "miopen_rnn", + "mkldnn_adaptive_avg_pool2d", + "mkldnn_convolution", + "mkldnn_linear_backward_weights", + "mkldnn_max_pool2d", + "mkldnn_max_pool3d", + "mkldnn_rnn_layer", + "mm", + "mode", + "mode", + "moveaxis", + "movedim", + "msort", + "mul", + "multinomial", + "multiply", + "mv", + "mvlgamma", + "nan_to_num", + "nan_to_num_", + "nanmean", + "nanmedian", + "nanmedian", + "nanquantile", + "nansum", + "narrow", + "narrow_copy", + "native_batch_norm", + "native_channel_shuffle", + "native_dropout", + "native_group_norm", + "native_layer_norm", + "native_norm", + "ne", + "neg", + "neg_", + "negative", + "negative_", + "nextafter", + "nonzero", + 
"nonzero_static", + "norm_except_dim", + "normal", + "not_equal", + "nuclear_norm", + "numel", + "ones", + "ones_like", + "orgqr", + "ormqr", + "outer", + "pairwise_distance", + "pdist", + "permute", + "permute_copy", + "pinverse", + "pixel_shuffle", + "pixel_unshuffle", + "poisson", + "poisson_nll_loss", + "polar", + "polygamma", + "positive", + "pow", + "prelu", + "prod", + "promote_types", + "put", + "q_per_channel_axis", + "q_per_channel_scales", + "q_per_channel_zero_points", + "q_scale", + "q_zero_point", + "qr", + "qr", + "quantile", + "quantize_per_channel", + "quantize_per_tensor", + "quantize_per_tensor_dynamic", + "quantized_batch_norm", + "quantized_gru_cell", + "quantized_lstm_cell", + "quantized_max_pool1d", + "quantized_max_pool2d", + "quantized_max_pool3d", + "quantized_rnn_relu_cell", + "quantized_rnn_tanh_cell", + "rad2deg", + "rad2deg_", + "rand", + "rand_like", + "randint", + "randint_like", + "randn", + "randn_like", + "randperm", + "range", + "ravel", + "real", + "reciprocal", + "reciprocal_", + "relu", + "relu_", + "remainder", + "renorm", + "repeat_interleave", + "reshape", + "resize_as_", + "resize_as_sparse_", + "resolve_conj", + "resolve_neg", + "result_type", + "rms_norm", + "rnn_relu", + "rnn_relu_cell", + "rnn_tanh", + "rnn_tanh_cell", + "roll", + "rot90", + "round", + "round_", + "row_indices_copy", + "row_stack", + "rrelu", + "rrelu_", + "rsqrt", + "rsqrt_", + "rsub", + "saddmm", + "scalar_tensor", + "scatter", + "scatter_add", + "scatter_reduce", + "searchsorted", + "segment_reduce", + "select", + "select_copy", + "select_scatter", + "selu", + "selu_", + "set_flush_denormal", + "set_num_interop_threads", + "set_num_threads", + "sgn", + "sigmoid", + "sigmoid_", + "sign", + "signbit", + "sin", + "sin_", + "sinc", + "sinc_", + "sinh", + "sinh_", + "slice_copy", + "slice_inverse", + "slice_scatter", + "slogdet", + "slogdet", + "smm", + "softmax", + "sort", + "sort", + "sparse_bsc_tensor", + "sparse_bsr_tensor", + "sparse_compressed_tensor", + "sparse_coo_tensor", + "sparse_csc_tensor", + "sparse_csr_tensor", + "split_copy", + "split_with_sizes", + "split_with_sizes_copy", + "spmm", + "sqrt", + "sqrt_", + "square", + "square_", + "squeeze", + "squeeze_copy", + "sspaddmm", + "stack", + "std", + "std_mean", + "sub", + "subtract", + "sum", + "svd", + "svd", + "swapaxes", + "swapdims", + "sym_constrain_range", + "sym_constrain_range_for_size", + "t", + "t_copy", + "take", + "take_along_dim", + "tan", + "tan_", + "tanh", + "tanh_", + "tensor", + "tensor_split", + "threshold", + "threshold_", + "tile", + "topk", + "topk", + "trace", + "transpose", + "transpose_copy", + "trapezoid", + "trapz", + "triangular_solve", + "triangular_solve", + "tril", + "tril_indices", + "triplet_margin_loss", + "triu", + "triu_indices", + "true_divide", + "trunc", + "trunc_", + "unbind", + "unbind_copy", + "unflatten", + "unfold_copy", + "unique_dim", + "unsafe_chunk", + "unsafe_split", + "unsafe_split_with_sizes", + "unsqueeze", + "unsqueeze_copy", + "values_copy", + "vander", + "var", + "var_mean", + "vdot", + "view_as_complex", + "view_as_complex_copy", + "view_as_real", + "view_as_real_copy", + "view_copy", + "vsplit", + "vstack", + "where", + "xlogy", + "xlogy_", + "zero_", + "zeros", + "zeros_like", +] + +@overload +def __and__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __and__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __lshift__(input: Tensor, other: Tensor) -> Tensor: ... 
+@overload +def __lshift__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __or__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __rshift__(input: Tensor, other: Number | _complex) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def __xor__(input: Tensor, other: Number | _complex) -> Tensor: ... +def _adaptive_avg_pool2d( + input: Tensor, + output_size: _int | SymInt | Sequence[_int | SymInt], +) -> Tensor: ... +def _adaptive_avg_pool3d( + input: Tensor, + output_size: _int | SymInt | Sequence[_int | SymInt], +) -> Tensor: ... +def _add_batch_dim(input: Tensor, batch_dim: _int, level: _int) -> Tensor: ... +@overload +def _add_relu( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def _add_relu( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def _add_relu_( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def _add_relu_( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +def _addmm_activation( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + use_gelu: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def _aminmax(input: Tensor) -> tuple[Tensor, Tensor]: ... +@overload +def _aminmax( + input: Tensor, + dim: _int, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def _amp_foreach_non_finite_check_and_unscale_( + self: tuple[Tensor, ...] | list[Tensor] | None, + found_inf: Tensor, + inv_scale: Tensor, +) -> None: ... +def _amp_update_scale_( + input: Tensor, + growth_tracker: Tensor, + found_inf: Tensor, + scale_growth_factor: _float, + scale_backoff_factor: _float, + growth_interval: _int, +) -> Tensor: ... +@overload +def _assert_async(input: Tensor) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + +@overload +def _assert_async(input: Tensor, assert_msg: str) -> None: + r""" + _assert_async(tensor) -> void + + Asynchronously assert that the contents of tensor are nonzero. For CPU tensors, + this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for + CUDA tensors, we DO NOT synchronize and you may only find out the assertion + failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for + testing invariants in CUDA tensors without giving up performance. 
This function + is NOT intended to be used for regular error checking, as it will trash your CUDA + context if the assert fails (forcing you to restart your PyTorch process.) + + Args: + tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero + elements (including False for boolean tensors) cause an assertion failure + to be raised. + """ + +def _assert_scalar(self: Number | _complex, assert_msg: str) -> None: ... +def _assert_tensor_metadata( + a: Tensor, + size: Sequence[_int | SymInt] | None = None, + stride: Sequence[_int | SymInt] | None = None, + dtype: _dtype | None = None, + *, + device: DeviceLikeType | None = None, + layout: _layout | None = None, +) -> None: ... +def _batch_norm_impl_index( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor, _int]: ... +def _cast_Byte(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Char(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Double(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Float(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Half(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Int(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Long(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _cast_Short(input: Tensor, non_blocking: _bool = False) -> Tensor: ... +def _choose_qparams_per_tensor( + input: Tensor, + reduce_range: _bool = False, +) -> tuple[_float, _int]: ... +def _chunk_cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int, + num_chunks: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _coalesce(input: Tensor) -> Tensor: ... +def _compute_linear_combination( + input: Tensor, + coefficients: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _conj(input: Tensor) -> Tensor: ... +def _conj_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _conj_physical(input: Tensor) -> Tensor: ... +def _convert_indices_from_coo_to_csr( + input: Tensor, + size: _int, + *, + out_int32: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +def _convert_indices_from_csr_to_coo( + crow_indices: Tensor, + col_indices: Tensor, + *, + out_int32: _bool = False, + transpose: _bool = False, + out: Tensor | None = None, +) -> Tensor: ... +def _convert_weight_to_int4pack(input: Tensor, innerKTiles: _int) -> Tensor: ... +def _convert_weight_to_int4pack_for_cpu( + input: Tensor, + innerKTiles: _int, +) -> Tensor: ... +@overload +def _convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: _size, + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + cudnn_enabled: _bool, +) -> Tensor: ... +@overload +def _convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + cudnn_enabled: _bool, + allow_tf32: _bool, +) -> Tensor: ... 
+def _convolution_mode( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: str, + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def _copy_from( + input: Tensor, + dst: Tensor, + non_blocking: _bool = False, +) -> Tensor: ... +def _copy_from_and_resize(input: Tensor, dst: Tensor) -> Tensor: ... +def _cslt_compress(input: Tensor) -> Tensor: ... +def _cslt_sparse_mm( + compressed_A: Tensor, + dense_B: Tensor, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: _dtype | None = None, + transpose_result: _bool = False, + alg_id: _int = 0, + split_k: _int = 1, + split_k_mode: _int = -1, +) -> Tensor: ... +def _cslt_sparse_mm_search( + compressed_A: Tensor, + dense_B: Tensor, + bias: Tensor | None = None, + alpha: Tensor | None = None, + out_dtype: _dtype | None = None, + transpose_result: _bool = False, +) -> _int: ... +@overload +def _ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int = 0, + zero_infinity: _bool = False, +) -> tuple[Tensor, Tensor]: ... +@overload +def _ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int = 0, + zero_infinity: _bool = False, +) -> tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int, + deterministic: _bool, + zero_infinity: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def _cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int, + deterministic: _bool, + zero_infinity: _bool, +) -> tuple[Tensor, Tensor]: ... +def _cudnn_init_dropout_state( + dropout: _float, + train: _bool, + dropout_seed: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _cudnn_rnn( + input: Tensor, + weight: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + weight_buf: Tensor | None, + hx: Tensor, + cx: Tensor | None, + mode: _int, + hidden_size: _int | SymInt, + proj_size: _int | SymInt, + num_layers: _int, + batch_first: _bool, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_sizes: Sequence[_int | SymInt], + dropout_state: Tensor | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _cudnn_rnn_flatten_weight( + weight_arr: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + input_size: _int | SymInt, + mode: _int, + hidden_size: _int | SymInt, + proj_size: _int | SymInt, + num_layers: _int, + batch_first: _bool, + bidirectional: _bool, +) -> Tensor: ... +def _cufft_clear_plan_cache(device_index: _int) -> None: ... +def _cufft_get_plan_cache_max_size(device_index: _int) -> _int: ... +def _cufft_get_plan_cache_size(device_index: _int) -> _int: ... +def _cufft_set_plan_cache_max_size( + device_index: _int, + max_size: _int, +) -> None: ... +def _cummax_helper( + input: Tensor, + values: Tensor, + indices: Tensor, + dim: _int, +) -> None: ... +def _cummin_helper( + input: Tensor, + values: Tensor, + indices: Tensor, + dim: _int, +) -> None: ... +def _debug_has_internal_overlap(input: Tensor) -> _int: ... +def _dim_arange(like: Tensor, dim: _int) -> Tensor: ... +def _dirichlet_grad(x: Tensor, alpha: Tensor, total: Tensor) -> Tensor: ... 
+def _disable_functionalization(): ... +def _dyn_quant_matmul_4bit( + inp: Tensor, + packed_weights: Tensor, + block_size: _int, + in_features: _int, + out_features: _int, +) -> Tensor: ... +def _dyn_quant_pack_4bit_weight( + weights: Tensor, + scales_zeros: Tensor, + bias: Tensor | None, + block_size: _int, + in_features: _int, + out_features: _int, +) -> Tensor: ... +@overload +def _efficientzerotensor( + size: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _efficientzerotensor( + *size: _int | SymInt, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, + padding_idx: _int = -1, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def _embedding_bag_forward_only( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, + padding_idx: _int = -1, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def _empty_affine_quantized( + size: Sequence[_int | SymInt], + *, + scale: _float = 1, + zero_point: _int = 0, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_affine_quantized( + *size: _int | SymInt, + scale: _float = 1, + zero_point: _int = 0, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized( + size: Sequence[_int | SymInt], + *, + scales: Tensor, + zero_points: Tensor, + axis: _int, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +@overload +def _empty_per_channel_affine_quantized( + *size: _int | SymInt, + scales: Tensor, + zero_points: Tensor, + axis: _int, + memory_format: memory_format | None = contiguous_format, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _enable_functionalization(*, reapply_views: _bool = False) -> None: ... +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: ... +def _fake_quantize_learnable_per_channel_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, + quant_min: _int, + quant_max: _int, + grad_factor: _float = 1.0, +) -> Tensor: ... 
+def _fake_quantize_learnable_per_tensor_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + quant_min: _int, + quant_max: _int, + grad_factor: _float = 1.0, +) -> Tensor: ... +def _fake_quantize_per_tensor_affine_cachemask_tensor_qparams( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + fake_quant_enabled: Tensor, + quant_min: _int, + quant_max: _int, +) -> torch.return_types._fake_quantize_per_tensor_affine_cachemask_tensor_qparams: # fmt: skip + ... +def _fft_c2c( + input: Tensor, + dim: Sequence[_int | SymInt], + normalization: _int, + forward: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fft_c2r( + input: Tensor, + dim: _size, + normalization: _int, + last_dim_size: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fft_r2c( + input: Tensor, + dim: _size, + normalization: _int, + onesided: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _fill_mem_eff_dropout_mask_( + input: Tensor, + dropout_p: _float, + seed: _int, + offset: _int, +) -> Tensor: ... +def _foobar( + input: Tensor, + arg1: _bool = True, + arg2: _bool = True, + *, + arg3: _bool = True, +) -> Tensor: ... +def _foreach_abs( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_abs(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + +def _foreach_abs_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_abs_(self: List[Tensor]) -> None + + Apply :func:`torch.abs` to each Tensor of the input list. + """ + +def _foreach_acos( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_acos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + +def _foreach_acos_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_acos_(self: List[Tensor]) -> None + + Apply :func:`torch.acos` to each Tensor of the input list. + """ + +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_add_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... 
+@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> None: ... +@overload +def _foreach_addcdiv_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Tensor, +) -> None: ... +@overload +def _foreach_addcmul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensor1: tuple[Tensor, ...] | list[Tensor] | None, + tensor2: tuple[Tensor, ...] | list[Tensor] | None, + value: Number | _complex = 1, +) -> None: ... +def _foreach_asin( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_asin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + +def _foreach_asin_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_asin_(self: List[Tensor]) -> None + + Apply :func:`torch.asin` to each Tensor of the input list. + """ + +def _foreach_atan( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_atan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.atan` to each Tensor of the input list. + """ + +def _foreach_atan_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_atan_(self: List[Tensor]) -> None + + Apply :func:`torch.atan` to each Tensor of the input list. 
+ """ + +def _foreach_ceil( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_ceil(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + +def _foreach_ceil_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_ceil_(self: List[Tensor]) -> None + + Apply :func:`torch.ceil` to each Tensor of the input list. + """ + +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_clamp_max_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_clamp_min_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_copy_( + self: tuple[Tensor, ...] | list[Tensor] | None, + src: tuple[Tensor, ...] | list[Tensor] | None, + non_blocking: _bool = False, +) -> None: ... +def _foreach_cos( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_cos(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + +def _foreach_cos_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_cos_(self: List[Tensor]) -> None + + Apply :func:`torch.cos` to each Tensor of the input list. + """ + +def _foreach_cosh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_cosh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + +def _foreach_cosh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_cosh_(self: List[Tensor]) -> None + + Apply :func:`torch.cosh` to each Tensor of the input list. + """ + +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> tuple[Tensor, ...]: ... 
+@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_div_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_erf( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_erf(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + +def _foreach_erf_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_erf_(self: List[Tensor]) -> None + + Apply :func:`torch.erf` to each Tensor of the input list. + """ + +def _foreach_erfc( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_erfc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + +def _foreach_erfc_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_erfc_(self: List[Tensor]) -> None + + Apply :func:`torch.erfc` to each Tensor of the input list. + """ + +def _foreach_exp( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_exp(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + +def _foreach_exp_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_exp_(self: List[Tensor]) -> None + + Apply :func:`torch.exp` to each Tensor of the input list. + """ + +def _foreach_expm1( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_expm1(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + +def _foreach_expm1_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_expm1_(self: List[Tensor]) -> None + + Apply :func:`torch.expm1` to each Tensor of the input list. + """ + +def _foreach_floor( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_floor(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + +def _foreach_floor_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_floor_(self: List[Tensor]) -> None + + Apply :func:`torch.floor` to each Tensor of the input list. + """ + +def _foreach_frac( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_frac(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + +def _foreach_frac_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_frac_(self: List[Tensor]) -> None + + Apply :func:`torch.frac` to each Tensor of the input list. + """ + +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] 
| list[Tensor] | None, + weight: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_lerp( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weights: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Number | _complex, +) -> None: ... +@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weight: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_lerp_( + self: tuple[Tensor, ...] | list[Tensor] | None, + tensors1: tuple[Tensor, ...] | list[Tensor] | None, + weights: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_lgamma( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_lgamma(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + +def _foreach_lgamma_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_lgamma_(self: List[Tensor]) -> None + + Apply :func:`torch.lgamma` to each Tensor of the input list. + """ + +def _foreach_log( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log` to each Tensor of the input list. + """ + +def _foreach_log10( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log10(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + +def _foreach_log10_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log10_(self: List[Tensor]) -> None + + Apply :func:`torch.log10` to each Tensor of the input list. + """ + +def _foreach_log1p( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log1p(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + +def _foreach_log1p_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log1p_(self: List[Tensor]) -> None + + Apply :func:`torch.log1p` to each Tensor of the input list. + """ + +def _foreach_log2( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_log2(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + +def _foreach_log2_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log2_(self: List[Tensor]) -> None + + Apply :func:`torch.log2` to each Tensor of the input list. + """ + +def _foreach_log_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_log_(self: List[Tensor]) -> None + + Apply :func:`torch.log` to each Tensor of the input list. + """ + +def _foreach_max( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] 
| list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_maximum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_minimum_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: Tensor, +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +@overload +def _foreach_mul_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_neg( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_neg(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + +def _foreach_neg_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_neg_(self: List[Tensor]) -> None + + Apply :func:`torch.neg` to each Tensor of the input list. + """ + +def _foreach_norm( + self: tuple[Tensor, ...] | list[Tensor] | None, + ord: Number | _complex = 2, + dtype: _dtype | None = None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Number | _complex, +) -> tuple[Tensor, ...]: ... 
+@overload +def _foreach_pow( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow( + self: Number | _complex, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: Number | _complex, +) -> None: ... +@overload +def _foreach_pow_( + self: tuple[Tensor, ...] | list[Tensor] | None, + exponent: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: ... +def _foreach_reciprocal( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_reciprocal(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + +def _foreach_reciprocal_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_reciprocal_(self: List[Tensor]) -> None + + Apply :func:`torch.reciprocal` to each Tensor of the input list. + """ + +def _foreach_round( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_round(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.round` to each Tensor of the input list. + """ + +def _foreach_round_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_round_(self: List[Tensor]) -> None + + Apply :func:`torch.round` to each Tensor of the input list. + """ + +def _foreach_rsqrt( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _foreach_rsqrt_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: ... +def _foreach_sigmoid( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sigmoid(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + +def _foreach_sigmoid_( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> None: + r""" + _foreach_sigmoid_(self: List[Tensor]) -> None + + Apply :func:`torch.sigmoid` to each Tensor of the input list. + """ + +def _foreach_sign( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _foreach_sign_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: ... +def _foreach_sin( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sin(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + +def _foreach_sin_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_sin_(self: List[Tensor]) -> None + + Apply :func:`torch.sin` to each Tensor of the input list. + """ + +def _foreach_sinh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sinh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + +def _foreach_sinh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_sinh_(self: List[Tensor]) -> None + + Apply :func:`torch.sinh` to each Tensor of the input list. + """ + +def _foreach_sqrt( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_sqrt(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + +def _foreach_sqrt_(self: tuple[Tensor, ...] 
| list[Tensor] | None) -> None: + r""" + _foreach_sqrt_(self: List[Tensor]) -> None + + Apply :func:`torch.sqrt` to each Tensor of the input list. + """ + +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> tuple[Tensor, ...]: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalars: Sequence[Number | _complex], +) -> None: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + other: tuple[Tensor, ...] | list[Tensor] | None, + *, + alpha: Number | _complex = 1, +) -> None: ... +@overload +def _foreach_sub_( + self: tuple[Tensor, ...] | list[Tensor] | None, + scalar: Number | _complex, +) -> None: ... +def _foreach_tan( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_tan(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + +def _foreach_tan_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_tan_(self: List[Tensor]) -> None + + Apply :func:`torch.tan` to each Tensor of the input list. + """ + +def _foreach_tanh( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_tanh(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + +def _foreach_tanh_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_tanh_(self: List[Tensor]) -> None + + Apply :func:`torch.tanh` to each Tensor of the input list. + """ + +def _foreach_trunc( + self: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + _foreach_trunc(self: List[Tensor]) -> List[Tensor] + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + +def _foreach_trunc_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_trunc_(self: List[Tensor]) -> None + + Apply :func:`torch.trunc` to each Tensor of the input list. + """ + +def _foreach_zero_(self: tuple[Tensor, ...] | list[Tensor] | None) -> None: + r""" + _foreach_zero_(self: List[Tensor]) -> None + + Apply :func:`torch.zero` to each Tensor of the input list. + """ + +def _from_functional_tensor(t: Tensor) -> Tensor: ... +def _functional_assert_async( + input: Tensor, + assert_msg: str, + dep_token: Tensor, +) -> Tensor: ... +def _functional_assert_scalar( + self: Number | _complex, + assert_msg: str, + dep_token: Tensor, +) -> Tensor: ... +def _functional_sym_constrain_range( + size: Number | _complex, + min: _int | None, + max: _int | None, + dep_token: Tensor, +) -> Tensor: ... +def _functional_sym_constrain_range_for_size( + size: Number | _complex, + min: _int | None, + max: _int | None, + dep_token: Tensor, +) -> Tensor: ... +def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ... +def _functionalize_are_all_mutations_hidden_from_autograd( + t: Tensor, +) -> _bool: ... +def _functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t: Tensor, +) -> _bool: ... +def _functionalize_commit_update(t: Tensor) -> None: ... +def _functionalize_has_metadata_mutation(tensor: Tensor) -> _bool: ... 
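A minimal usage sketch (not part of the generated stub): the `_foreach_*` ops declared above apply one elementwise operation across an entire list of tensors in a single call, which is how `torch.optim` implements its multi-tensor fast path. The learning-rate value below is illustrative only.

import torch

params = [torch.ones(3), torch.ones(2, 2)]
grads = [torch.full((3,), 0.5), torch.full((2, 2), 0.25)]

torch._foreach_mul_(params, 0.9)                 # in-place scalar multiply on every tensor in the list
torch._foreach_add_(params, grads, alpha=-0.1)   # fused "param -= lr * grad"-style update across the list
negated = torch._foreach_neg(grads)              # out-of-place variant returns a new tuple of tensors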
+def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ... +def _functionalize_mark_mutation_hidden_from_autograd(t: Tensor) -> None: ... +def _functionalize_replace(self_: Tensor, other: Tensor) -> None: ... +def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ... +def _functionalize_sync(t: Tensor) -> None: ... +def _functionalize_unsafe_set(dst: Tensor, src: Tensor) -> None: ... +def _functionalize_was_inductor_storage_resized(t: Tensor) -> _bool: ... +def _functionalize_was_storage_changed(tensor: Tensor) -> _bool: ... +@overload +def _fused_adagrad_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + state_sums: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + lr_decay: _float, + weight_decay: _float, + eps: _float, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adagrad_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + state_sums: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: _float, + lr_decay: _float, + weight_decay: _float, + eps: _float, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adam_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adam_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: _float, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adamw_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] | list[Tensor] | None, + *, + lr: Tensor, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_adamw_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + exp_avgs: tuple[Tensor, ...] | list[Tensor] | None, + exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + max_exp_avg_sqs: tuple[Tensor, ...] | list[Tensor] | None, + state_steps: tuple[Tensor, ...] 
| list[Tensor] | None, + *, + lr: _float, + beta1: _float, + beta2: _float, + weight_decay: _float, + eps: _float, + amsgrad: _bool, + maximize: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +def _fused_dropout( + input: Tensor, + p: _float, + generator: Generator | None = None, +) -> tuple[Tensor, Tensor]: ... +def _fused_moving_avg_obs_fq_helper( + input: Tensor, + observer_on: Tensor, + fake_quant_on: Tensor, + running_min: Tensor, + running_max: Tensor, + scale: Tensor, + zero_point: Tensor, + averaging_const: _float, + quant_min: _int, + quant_max: _int, + ch_axis: _int, + per_row_fake_quant: _bool = False, + symmetric_quant: _bool = False, +) -> torch.return_types._fused_moving_avg_obs_fq_helper: ... +def _fused_rms_norm( + input: Tensor, + normalized_shape_ndim: _int, + weight: Tensor, + eps: _float, +) -> Tensor: ... +def _fused_sdp_choice( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + scale: _float | None = None, + enable_gqa: _bool = False, +) -> _int: ... +@overload +def _fused_sgd_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + momentum_buffer_list: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight_decay: _float, + momentum: _float, + lr: Tensor, + dampening: _float, + nesterov: _bool, + maximize: _bool, + is_first_step: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +@overload +def _fused_sgd_( + self: tuple[Tensor, ...] | list[Tensor] | None, + grads: tuple[Tensor, ...] | list[Tensor] | None, + momentum_buffer_list: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight_decay: _float, + momentum: _float, + lr: _float, + dampening: _float, + nesterov: _bool, + maximize: _bool, + is_first_step: _bool, + grad_scale: Tensor | None = None, + found_inf: Tensor | None = None, +) -> None: ... +def _fw_primal_copy( + input: Tensor, + level: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _grid_sampler_2d_cpu_fallback( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def _grouped_mm( + input: Tensor, + mat2: Tensor, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _has_compatible_shallow_copy_type( + input: Tensor, + from_: Tensor, +) -> _bool: ... +def _histogramdd_bin_edges( + input: Tensor, + bins: _size, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> tuple[Tensor, ...]: ... +def _histogramdd_from_bin_cts( + input: Tensor, + bins: _size, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> Tensor: ... +def _histogramdd_from_bin_tensors( + input: Tensor, + bins: tuple[Tensor, ...] | list[Tensor] | None, + *, + weight: Tensor | None = None, + density: _bool = False, +) -> Tensor: ... +def _index_put_impl_( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, + unsafe: _bool = False, +) -> Tensor: ... +def _indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _int_mm( + input: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _is_all_true(input: Tensor) -> Tensor: ... +def _is_any_true(input: Tensor) -> Tensor: ... 
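Hedged sketch of the public entry point over the fused optimizer kernels above: to my understanding, `fused=True` on `torch.optim.AdamW` routes the parameter update through `_fused_adamw_` (CUDA tensors only at the time of writing), while the default path uses the `_foreach_*` ops instead. Model size and learning rate below are placeholders.

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(4, 2, device="cuda")
    # fused=True is assumed here to dispatch to the _fused_adamw_ kernel declared above
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
    loss = model(torch.randn(8, 4, device="cuda")).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()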
+def _is_functional_tensor(t: Tensor) -> _bool: ... +def _is_functional_tensor_base(t: Tensor) -> _bool: ... +def _is_zerotensor(input: Tensor) -> _bool: ... +def _lazy_clone(input: Tensor) -> Tensor: ... +def _linalg_check_errors( + info: Tensor, + api_name: str, + *, + is_matrix: _bool, +) -> None: ... +def _linalg_det( + A: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_det: ... +def _linalg_eigh( + A: Tensor, + UPLO: str = "L", + compute_v: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_eigh: ... +def _linalg_slogdet( + A: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_slogdet: ... +def _linalg_solve_ex( + A: Tensor, + B: Tensor, + *, + left: _bool = True, + check_errors: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_solve_ex: ... +def _linalg_svd( + A: Tensor, + full_matrices: _bool = False, + compute_uv: _bool = True, + *, + driver: str | None = None, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types._linalg_svd: ... +def _log_softmax( + input: Tensor, + dim: _int, + half_to_float: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _log_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _logcumsumexp( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _lstm_mps( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _lu_with_info( + input: Tensor, + pivot: _bool = True, + check_errors: _bool = True, +) -> torch.return_types._lu_with_info: ... +def _make_dep_token( + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def _make_dual(primal: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _make_dual_copy( + primal: Tensor, + tangent: Tensor, + level: _int, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _make_per_channel_quantized_tensor( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, +) -> Tensor: ... +def _make_per_tensor_quantized_tensor( + input: Tensor, + scale: _float, + zero_point: _int, +) -> Tensor: ... +def _masked_scale(input: Tensor, mask: Tensor, scale: _float) -> Tensor: ... +def _masked_softmax( + input: Tensor, + mask: Tensor, + dim: _int | None = None, + mask_type: _int | None = None, +) -> Tensor: ... +def _mixed_dtypes_linear( + input: Tensor, + weight: Tensor, + scale: Tensor, + *, + bias: Tensor | None = None, + activation: str | None = None, +) -> Tensor: ... +def _mkldnn_reshape(input: Tensor, shape: _size) -> Tensor: ... +def _mkldnn_transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... +def _mkldnn_transpose_(input: Tensor, dim0: _int, dim1: _int) -> Tensor: ... 
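The `_linalg_*` helpers above back the public `torch.linalg` namespace; a short usage sketch of the documented wrappers:

import torch

A = torch.randn(3, 3)
sign, logabsdet = torch.linalg.slogdet(A)            # named tuple (sign, logabsdet)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)  # reduced SVD
evals, evecs = torch.linalg.eigh(A @ A.T)            # eigh expects a symmetric/Hermitian input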
+def _mps_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def _mps_convolution_transpose( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +@overload +def _native_batch_norm_legit( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor, + running_var: Tensor, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def _native_batch_norm_legit( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor, + running_var: Tensor, + momentum: _float, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _native_multi_head_attention( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim: _int, + num_head: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + mask: Tensor | None = None, + need_weights: _bool = True, + average_attn_weights: _bool = True, + mask_type: _int | None = None, +) -> tuple[Tensor, Tensor]: ... +def _neg_view(input: Tensor) -> Tensor: ... +def _neg_view_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _nested_compute_contiguous_strides_offsets( + nested_size: Tensor, +) -> tuple[Tensor, Tensor]: ... +def _nested_from_padded( + padded: Tensor, + cpu_nested_shape_example: Tensor, + fuse_transform_0213: _bool = False, +) -> Tensor: ... +def _nested_from_padded_and_nested_example( + padded: Tensor, + nt_example: Tensor, +) -> Tensor: ... +def _nested_from_padded_tensor( + padded: Tensor, + offsets: Tensor, + dummy: Tensor, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, + sum_S: _int | SymInt | None = None, +) -> Tensor: ... +def _nested_get_jagged_dummy(any: Tensor) -> Tensor: ... +def _nested_get_lengths(input: Tensor) -> Tensor: ... +def _nested_get_max_seqlen(input: Tensor) -> Tensor: ... +def _nested_get_min_seqlen(input: Tensor) -> Tensor: ... +def _nested_get_offsets(input: Tensor) -> Tensor: ... +def _nested_get_ragged_idx(input: Tensor) -> _int: ... +def _nested_get_values(input: Tensor) -> Tensor: ... +def _nested_get_values_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nested_tensor_from_mask( + t: Tensor, + mask: Tensor, + mask_check: _bool = True, +) -> Tensor: ... +def _nested_tensor_from_mask_left_aligned(t: Tensor, mask: Tensor) -> _bool: ... +def _nested_tensor_from_tensor_list( + list: tuple[Tensor, ...] | list[Tensor] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = None, +) -> Tensor: ... +def _nested_tensor_softmax_with_shape( + input: Tensor, + query: Tensor, +) -> Tensor: ... +def _nested_view_from_buffer( + input: Tensor, + nested_size: Tensor, + nested_strides: Tensor, + offsets: Tensor, +) -> Tensor: ... 
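Hedged sketch of the public wrappers over the `_nested_*` helpers above; `torch.nested.nested_tensor` and `torch.nested.to_padded_tensor` are the documented entry points for building a ragged batch and padding it back out.

import torch

a = torch.randn(3, 8)
b = torch.randn(5, 8)
nt = torch.nested.nested_tensor([a, b])          # ragged batch of 2 sequences of different lengths
padded = torch.nested.to_padded_tensor(nt, 0.0)  # dense (2, 5, 8) tensor, zero-padded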
+def _nested_view_from_buffer_copy( + input: Tensor, + nested_size: Tensor, + nested_strides: Tensor, + offsets: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nested_view_from_jagged( + input: Tensor, + offsets: Tensor, + dummy: Tensor, + lengths: Tensor | None = None, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, +) -> Tensor: ... +def _nested_view_from_jagged_copy( + input: Tensor, + offsets: Tensor, + dummy: Tensor, + lengths: Tensor | None = None, + ragged_idx: _int = 1, + min_seqlen: Tensor | None = None, + max_seqlen: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _nnpack_available() -> _bool: ... +def _nnpack_spatial_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: _int | SymInt | Sequence[_int | SymInt], + stride: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def _pack_padded_sequence( + input: Tensor, + lengths: Tensor, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def _pad_packed_sequence( + data: Tensor, + batch_sizes: Tensor, + batch_first: _bool, + padding_value: Number | _complex, + total_length: _int, +) -> tuple[Tensor, Tensor]: ... +def _pin_memory( + input: Tensor, + device: DeviceLikeType | None = None, +) -> Tensor: ... +def _prelu_kernel(input: Tensor, weight: Tensor) -> Tensor: ... +def _print(s: str) -> None: ... +def _propagate_xla_data(input: Tensor, output: Tensor) -> None: ... +def _remove_batch_dim( + input: Tensor, + level: _int, + batch_size: _int | SymInt, + out_dim: _int, +) -> Tensor: ... +def _reshape_alias_copy( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + *, + out: Tensor | None = None, +) -> Tensor: ... +def _reshape_from_tensor(input: Tensor, shape: Tensor) -> Tensor: ... +def _resize_output_( + input: Tensor, + size: Sequence[_int | SymInt], + device: DeviceLikeType | None, +) -> Tensor: ... +def _rowwise_prune( + weight: Tensor, + mask: Tensor, + compressed_indices_dtype: _dtype, +) -> tuple[Tensor, Tensor]: ... +def _safe_softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sample_dirichlet( + input: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +def _saturate_weight_to_fp16(weight: Tensor) -> Tensor: ... +def _scaled_dot_product_attention_math( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + dropout_mask: Tensor | None = None, + *, + scale: _float | None = None, + enable_gqa: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def _scaled_dot_product_attention_math_for_mps( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: _float = 0.0, + is_causal: _bool = False, + dropout_mask: Tensor | None = None, + *, + scale: _float | None = None, +) -> tuple[Tensor, Tensor]: ... +def _scaled_dot_product_cudnn_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor | None, + compute_log_sumexp: _bool, + dropout_p: _float = 0.0, + is_causal: _bool = False, + return_debug_mask: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_cudnn_attention: ... 
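The `_scaled_dot_product_*` kernels above sit behind the public attention entry point; the backend is chosen automatically at runtime. A quick usage sketch with illustrative shapes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 16, 32)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                # torch.Size([2, 4, 16, 32])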
+def _scaled_dot_product_efficient_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor | None, + compute_log_sumexp: _bool, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_efficient_attention: ... +def _scaled_dot_product_flash_attention( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: _float = 0.0, + is_causal: _bool = False, + return_debug_mask: _bool = False, + *, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_flash_attention: ... +def _scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: _float = 0.0, + is_causal: _bool = False, + *, + attn_mask: Tensor | None = None, + scale: _float | None = None, +) -> torch.return_types._scaled_dot_product_flash_attention_for_cpu: ... +def _scaled_grouped_mm( + input: Tensor, + mat2: Tensor, + scale_a: Tensor, + scale_b: Tensor, + offs: Tensor | None = None, + bias: Tensor | None = None, + scale_result: Tensor | None = None, + out_dtype: _dtype | None = None, + use_fast_accum: _bool = False, +) -> Tensor: ... +def _scaled_mm( + input: Tensor, + mat2: Tensor, + scale_a: Tensor, + scale_b: Tensor, + bias: Tensor | None = None, + scale_result: Tensor | None = None, + out_dtype: _dtype | None = None, + use_fast_accum: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _shape_as_tensor(input: Tensor) -> Tensor: ... +def _sobol_engine_draw( + quasi: Tensor, + n: _int, + sobolstate: Tensor, + dimension: _int, + num_generated: _int, + dtype: _dtype | None, +) -> tuple[Tensor, Tensor]: ... +def _sobol_engine_ff_( + input: Tensor, + n: _int, + sobolstate: Tensor, + dimension: _int, + num_generated: _int, +) -> Tensor: ... +def _sobol_engine_initialize_state_( + input: Tensor, + dimension: _int, +) -> Tensor: ... +def _sobol_engine_scramble_( + input: Tensor, + ltm: Tensor, + dimension: _int, +) -> Tensor: ... +def _softmax( + input: Tensor, + dim: _int, + half_to_float: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input_dtype: _dtype, + *, + grad_input: Tensor | None = None, +) -> Tensor: ... +def _sparse_broadcast_to(input: Tensor, size: _size) -> Tensor: ... +def _sparse_broadcast_to_copy( + input: Tensor, + size: _size, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _sparse_csr_prod( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_csr_sum( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_log_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input: Tensor, +) -> Tensor: ... +def _sparse_semi_structured_addmm( + input: Tensor, + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + alpha: Number | _complex = 1, + beta: Number | _complex = 1, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_semi_structured_apply( + input: Tensor, + thread_masks: Tensor, +) -> tuple[Tensor, Tensor]: ... +def _sparse_semi_structured_apply_dense( + input: Tensor, + thread_masks: Tensor, +) -> Tensor: ... +def _sparse_semi_structured_linear( + input: Tensor, + weight: Tensor, + meta: Tensor, + *, + bias: Tensor | None = None, + activation: str | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... 
+def _sparse_semi_structured_mm( + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + out_dtype: _dtype | None = None, +) -> Tensor: ... +def _sparse_semi_structured_tile( + input: Tensor, + algorithm: str = "", + use_cutlass: _bool = True, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def _sparse_softmax_backward_data( + grad_output: Tensor, + output: Tensor, + dim: _int, + input: Tensor, +) -> Tensor: ... +def _sparse_sparse_matmul(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, *, dtype: _dtype) -> Tensor: ... +@overload +def _sparse_sum(input: Tensor, dim: _int | _size) -> Tensor: ... +@overload +def _sparse_sum( + input: Tensor, + dim: _int | _size, + *, + dtype: _dtype, +) -> Tensor: ... +def _stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _standard_gamma( + input: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +def _standard_gamma_grad(input: Tensor, output: Tensor) -> Tensor: ... +def _sync(t: Tensor) -> None: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor) -> Tensor: ... +@overload +def _test_autograd_multiple_dispatch(input: Tensor, b: _bool) -> Tensor: ... +def _test_autograd_multiple_dispatch_view(input: Tensor) -> Tensor: ... +def _test_autograd_multiple_dispatch_view_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +def _test_check_tensor(input: Tensor) -> Tensor: ... +def _test_functorch_fallback(input: Tensor, other: Tensor) -> Tensor: ... +def _test_parallel_materialize( + input: Tensor, + num_parallel: _int, + skip_first: _bool = False, +) -> Tensor: ... +def _test_serialization_subcmul( + input: Tensor, + other: Tensor, + alpha: Number | _complex = 1, +) -> Tensor: ... +def _to_cpu( + tensors: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: ... +def _to_functional_tensor(t: Tensor) -> Tensor: ... +def _to_sparse_semi_structured(dense: Tensor) -> tuple[Tensor, Tensor]: ... +def _transform_bias_rescale_qkv( + qkv: Tensor, + qkv_bias: Tensor, + num_heads: _int, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _transformer_encoder_layer_fwd( + src: Tensor, + embed_dim: _int, + num_heads: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + use_gelu: _bool, + norm_first: _bool, + eps: _float, + norm_weight_1: Tensor, + norm_bias_1: Tensor, + norm_weight_2: Tensor, + norm_bias_2: Tensor, + ffn_weight_1: Tensor, + ffn_bias_1: Tensor, + ffn_weight_2: Tensor, + ffn_bias_2: Tensor, + mask: Tensor | None = None, + mask_type: _int | None = None, +) -> Tensor: ... +def _trilinear( + i1: Tensor, + i2: Tensor, + i3: Tensor, + expand1: _size, + expand2: _size, + expand3: _size, + sumdim: _size, + unroll_dim: _int = 1, +) -> Tensor: ... +def _triton_multi_head_attention( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim: _int, + num_head: _int, + qkv_weight: Tensor, + qkv_bias: Tensor, + proj_weight: Tensor, + proj_bias: Tensor, + mask: Tensor | None = None, +) -> Tensor: ... +def _triton_scaled_dot_attention( + q: Tensor, + k: Tensor, + v: Tensor, + dropout_p: _float = 0.0, +) -> Tensor: ... +def _unique( + input: Tensor, + sorted: _bool = True, + return_inverse: _bool = False, +) -> tuple[Tensor, Tensor]: ... 
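`_unique` / `_unique2` back the public `torch.unique`; a short sketch of the public API:

import torch

x = torch.tensor([1, 3, 2, 3, 1])
values, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
# values  -> tensor([1, 2, 3])                 sorted unique elements
# inverse -> tensor([0, 2, 1, 2, 0])           index of each input element in `values`
# counts  -> tensor([2, 1, 2])                 occurrences of each unique value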
+def _unique2( + input: Tensor, + sorted: _bool = True, + return_inverse: _bool = False, + return_counts: _bool = False, +) -> tuple[Tensor, Tensor, Tensor]: ... +def _unpack_dual( + dual: Tensor, + level: _int, +) -> torch.return_types._unpack_dual: ... +def _unsafe_index( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, +) -> Tensor: ... +def _unsafe_index_put( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def _unsafe_masked_index( + input: Tensor, + mask: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + fill: Number | _complex, +) -> Tensor: ... +def _unsafe_masked_index_put_accumulate( + input: Tensor, + mask: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, +) -> Tensor: ... +@overload +def _use_cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int, +) -> _bool: ... +@overload +def _use_cudnn_ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int, +) -> _bool: ... +def _use_cudnn_rnn_flatten_weight() -> _bool: ... +def _validate_compressed_sparse_indices( + is_crow: _bool, + compressed_idx: Tensor, + plain_idx: Tensor, + cdim: _int, + dim: _int, + nnz: _int, +) -> None: ... +def _validate_sparse_bsc_tensor_args( + ccol_indices: Tensor, + row_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_bsr_tensor_args( + crow_indices: Tensor, + col_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_compressed_tensor_args( + compressed_indices: Tensor, + plain_indices: Tensor, + values: Tensor, + size: _size, + layout: _layout, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_coo_tensor_args( + indices: Tensor, + values: Tensor, + size: _size, + is_coalesced: _bool | None = None, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_csc_tensor_args( + ccol_indices: Tensor, + row_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _validate_sparse_csr_tensor_args( + crow_indices: Tensor, + col_indices: Tensor, + values: Tensor, + size: _size, + check_pinning: _bool | None = None, +) -> None: ... +def _values_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def _weight_int4pack_mm( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScaleAndZeros: Tensor, +) -> Tensor: ... +def _weight_int4pack_mm_for_cpu( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScaleAndZeros: Tensor, +) -> Tensor: ... +def _weight_int4pack_mm_with_scales_and_zeros( + input: Tensor, + mat2: Tensor, + qGroupSize: _int, + qScale: Tensor, + qZeros: Tensor, +) -> Tensor: ... +def _weight_int8pack_mm( + input: Tensor, + mat2: Tensor, + scales: Tensor, +) -> Tensor: ... +def _weight_norm(v: Tensor, g: Tensor, dim: _int = 0) -> Tensor: ... +def _weight_norm_interface( + v: Tensor, + g: Tensor, + dim: _int = 0, +) -> tuple[Tensor, Tensor]: ... +def _wrapped_linear_prepack( + weight: Tensor, + weight_scale: Tensor, + weight_zero_point: Tensor, + bias: Tensor, +) -> Tensor: ... 
+def _wrapped_quantized_linear_prepacked( + input: Tensor, + input_scale: Tensor, + input_zero_point: Tensor, + packed_weight: Tensor, + output_scale: Tensor, + output_zero_point: Tensor, + out_channel: _int, +) -> Tensor: ... +def abs(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + abs(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the absolute value of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = |\text{input}_{i}| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.abs(torch.tensor([-1, -2, 3])) + tensor([ 1, 2, 3]) + """ + +def abs_(input: Tensor) -> Tensor: ... +def absolute(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + absolute(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.abs` + """ + +def acos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + acos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the inverse cosine of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \cos^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.3348, -0.5889, 0.2005, -0.1584]) + >>> torch.acos(a) + tensor([ 1.2294, 2.2004, 1.3690, 1.7298]) + """ + +def acos_(input: Tensor) -> Tensor: ... +def acosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + acosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh^{-1}(\text{input}_{i}) + + Note: + The domain of the inverse hyperbolic cosine is `[1, inf)` and values outside this range + will be mapped to ``NaN``, except for `+ INF` for which the output is mapped to `+ INF`. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(1, 2) + >>> a + tensor([ 1.3192, 1.9915, 1.9674, 1.7151 ]) + >>> torch.acosh(a) + tensor([ 0.7791, 1.3120, 1.2979, 1.1341 ]) + """ + +def acosh_(input: Tensor) -> Tensor: ... +def adaptive_avg_pool1d(input: Tensor, output_size: _int | _size) -> Tensor: ... +def adaptive_max_pool1d( + input: Tensor, + output_size: _int | _size, +) -> tuple[Tensor, Tensor]: ... +@overload +def add( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. 
+ + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def add(self: Tensor, alpha: Number | _complex, other: Tensor) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def add( + self: Tensor, + alpha: Number | _complex, + other: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + add(input, other, *, alpha=1, out=None) -> Tensor + + Adds :attr:`other`, scaled by :attr:`alpha`, to :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i + \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to add to :attr:`input`. + + Keyword arguments: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.0202, 1.0985, 1.3506, -0.6056]) + >>> torch.add(a, 20) + tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) + + >>> b = torch.randn(4) + >>> b + tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + >>> c = torch.randn(4, 1) + >>> c + tensor([[ 0.3743], + [-1.7724], + [-0.5811], + [-0.8017]]) + >>> torch.add(b, c, alpha=10) + tensor([[ 2.7695, 3.3930, 4.3672, 4.1450], + [-18.6971, -18.0736, -17.0994, -17.3216], + [ -6.7845, -6.1610, -5.1868, -5.4090], + [ -8.9902, -8.3667, -7.3925, -7.6147]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. 
+ + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. 
+ + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored + in :attr:`batch1` and :attr:`batch2`, + with a reduced add step (all matrix multiplications get accumulated + along the first dimension). + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the + same number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + .. math:: + out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` + must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.addbmm(M, batch1, batch2) + tensor([[ 6.6311, 0.0503, 6.9768, -12.0362, -2.1653], + [ -4.8185, -1.4255, -6.6760, 8.9453, 2.5743], + [ -3.8202, 4.3691, 1.0943, -1.1109, 5.4730]]) + """ + +@overload +def addcdiv( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. 
warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcdiv( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcdiv( + input: Tensor, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, + multiplies the result by the scalar :attr:`value` and adds it to :attr:`input`. + + .. 
warning:: + Integer division with addcdiv is no longer supported, and in a future + release addcdiv will perform a true division of tensor1 and tensor2. + The historic addcdiv behavior can be implemented as + (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) + for integer inputs and as (input + value * tensor1 / tensor2) for float inputs. + The future addcdiv behavior is just the latter implementation: + (input + value * tensor1 / tensor2), for all dtypes. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i} + + + The shapes of :attr:`input`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the numerator tensor + tensor2 (Tensor): the denominator tensor + + Keyword args: + value (Number, optional): multiplier for :math:`\text{tensor1} / \text{tensor2}` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcdiv(t, t1, t2, value=0.1) + tensor([[-0.2312, -3.6496, 0.1312], + [-1.0428, 3.4292, -0.1030], + [-0.5369, -0.9829, 0.0430]]) + """ + +@overload +def addcmul( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addcmul( + self: Tensor, + value: Number | _complex, + tensor1: Tensor, + tensor2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. 
+ + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addcmul( + input: Tensor, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor + + Performs the element-wise multiplication of :attr:`tensor1` + by :attr:`tensor2`, multiplies the result by the scalar :attr:`value` + and adds it to :attr:`input`. + + .. math:: + \text{out}_i = \text{input}_i + \text{value} \times \text{tensor1}_i \times \text{tensor2}_i + + The shapes of :attr:`tensor`, :attr:`tensor1`, and :attr:`tensor2` must be + :ref:`broadcastable `. + + For inputs of type `FloatTensor` or `DoubleTensor`, :attr:`value` must be + a real number, otherwise an integer. + + Args: + input (Tensor): the tensor to be added + tensor1 (Tensor): the tensor to be multiplied + tensor2 (Tensor): the tensor to be multiplied + + Keyword args: + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.randn(1, 3) + >>> t1 = torch.randn(3, 1) + >>> t2 = torch.randn(1, 3) + >>> torch.addcmul(t, t1, t2, value=0.1) + tensor([[-0.8635, -0.6391, 1.6174], + [-0.7617, -0.5879, 1.7388], + [-0.8353, -0.6249, 1.6511]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. 
+ + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
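+
+    .. note::
+        As an illustrative sketch only (not the underlying implementation), and assuming
+        dense floating-point inputs ``M``, ``mat1``, ``mat2`` like those in the example
+        below, the default call composes :func:`torch.mm` with an add::
+
+            >>> expected = M + torch.mm(mat1, mat2)   # beta=1, alpha=1 defaults
+            >>> torch.allclose(torch.addmm(M, mat1, mat2), expected)
+            True
+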
+ + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. 
+ + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. 
+ + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. + The matrix :attr:`input` is added to the final result. + + If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a :math:`(n \times p)` tensor + and :attr:`out` will be a :math:`(n \times p)` tensor. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operation has support for arguments with :ref:`sparse layouts`. If + :attr:`input` is sparse the result will have the same layout and if :attr:`out` + is provided it must have the same layout as :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): matrix to be added + mat1 (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(2, 3) + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.addmm(M, mat1, mat2) + tensor([[-4.8716, 1.4671, -1.3746], + [ 0.7573, -3.9555, -2.8681]]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
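+
+    .. note::
+        As an illustrative sketch only (not the underlying implementation), and assuming
+        floating-point inputs ``M``, ``mat``, ``vec`` like those in the example below,
+        the default call composes :func:`torch.mv` with an add::
+
+            >>> expected = M + torch.mv(mat, vec)   # beta=1, alpha=1 defaults
+            >>> torch.allclose(torch.addmv(M, mat, vec), expected)
+            True
+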
+ + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + input: Tensor, + mat: Tensor, + vec: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`mat` and + the vector :attr:`vec`. + The vector :attr:`input` is added to the final result. + + If :attr:`mat` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a 1-D tensor of size `n` and + :attr:`out` will be 1-D tensor of size `n`. + + :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between + :attr:`mat` and :attr:`vec` and the added tensor :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + Args: + input (Tensor): vector to be added + mat (Tensor): matrix to be matrix multiplied + vec (Tensor): vector to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(2) + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.addmv(M, mat, vec) + tensor([-0.3768, -5.5565]) + """ + +@overload +def addmv_( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat: Tensor, + vec: Tensor, +) -> Tensor: ... +@overload +def addmv_( + input: Tensor, + mat: Tensor, + vec: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def addmv_( + beta: Number | _complex, + self: Tensor, + mat: Tensor, + vec: Tensor, +) -> Tensor: ... +@overload +def addr( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + vec1: Tensor, + vec2: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. 
+ + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + vec1: Tensor, + vec2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + input: Tensor, + vec1: Tensor, + vec2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. 
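+
+    .. note::
+        As an illustrative sketch only (not the underlying implementation), and assuming
+        floating-point inputs ``M``, ``vec1``, ``vec2`` like those in the example below,
+        the default call adds a scaled :func:`torch.outer` product::
+
+            >>> expected = M + torch.outer(vec1, vec2)   # beta=1, alpha=1 defaults
+            >>> torch.allclose(torch.addr(M, vec1, vec2), expected)
+            True
+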
+ + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + vec1: Tensor, + vec2: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) + >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +@overload +def addr( + beta: Number | _complex, + self: Tensor, + vec1: Tensor, + vec2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor + + Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` + and adds it to the matrix :attr:`input`. + + Optional values :attr:`beta` and :attr:`alpha` are scaling factors on the + outer product between :attr:`vec1` and :attr:`vec2` and the added matrix + :attr:`input` respectively. + + .. math:: + \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector + of size `m`, then :attr:`input` must be + :ref:`broadcastable ` with a matrix of size + :math:`(n \times m)` and :attr:`out` will be a matrix of size + :math:`(n \times m)`. + + Args: + input (Tensor): matrix to be added + vec1 (Tensor): the first vector of the outer product + vec2 (Tensor): the second vector of the outer product + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{vec1} \otimes \text{vec2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> vec1 = torch.arange(1., 4.) 
+ >>> vec2 = torch.arange(1., 3.) + >>> M = torch.zeros(3, 2) + >>> torch.addr(M, vec1, vec2) + tensor([[ 1., 2.], + [ 2., 4.], + [ 3., 6.]]) + """ + +def adjoint(input: Tensor) -> Tensor: + r""" + adjoint(input: Tensor) -> Tensor + Returns a view of the tensor conjugated and with the last two dimensions transposed. + + ``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and + to ``x.transpose(-2, -1)`` for real tensors. + + Args: + {input} + + Example:: + + >>> x = torch.arange(4, dtype=torch.float) + >>> A = torch.complex(x, x).reshape(2, 2) + >>> A + tensor([[0.+0.j, 1.+1.j], + [2.+2.j, 3.+3.j]]) + >>> A.adjoint() + tensor([[0.-0.j, 2.-2.j], + [1.-1.j, 3.-3.j]]) + >>> (A.adjoint() == A.mH).all() + tensor(True) + """ + +def affine_grid_generator( + theta: Tensor, + size: Sequence[_int | SymInt], + align_corners: _bool, +) -> Tensor: ... +def alias_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.alias`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def all(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: _size | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. 
+ + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. 
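+
+    .. note::
+        Illustrative sketch of the :attr:`keepdim` behaviour, using the boolean tensor
+        ``a`` from the example below: the reduced dimension is retained with size 1
+        instead of being squeezed out::
+
+            >>> torch.all(a, dim=1, keepdim=True).shape
+            torch.Size([4, 1])
+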
+ + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +@overload +def all( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + all(input: Tensor, *, out=None) -> Tensor + + Tests if all elements in :attr:`input` evaluate to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.all(a) + tensor(False, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.all(a) + tensor(False) + + .. function:: all(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if all elements in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(4, 2).bool() + >>> a + tensor([[True, True], + [True, False], + [True, True], + [True, True]], dtype=torch.bool) + >>> torch.all(a, dim=1) + tensor([ True, False, True, True], dtype=torch.bool) + >>> torch.all(a, dim=0) + tensor([ True, False], dtype=torch.bool) + """ + +def allclose( + input: Tensor, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, +) -> _bool: + r""" + allclose(input: Tensor, other: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool + + This function checks if :attr:`input` and :attr:`other` satisfy the condition: + + .. math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{atol} + \texttt{rtol} \times \lvert \text{other}_i \rvert + + elementwise, for all elements of :attr:`input` and :attr:`other`. The behaviour of this function is analogous to + `numpy.allclose `_ + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + atol (float, optional): absolute tolerance. Default: 1e-08 + rtol (float, optional): relative tolerance. Default: 1e-05 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. 
Default: ``False`` + + Example:: + + >>> torch.allclose(torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08])) + False + >>> torch.allclose(torch.tensor([10000., 1e-08]), torch.tensor([10000.1, 1e-09])) + True + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')])) + False + >>> torch.allclose(torch.tensor([1.0, float('nan')]), torch.tensor([1.0, float('nan')]), equal_nan=True) + True + """ + +def alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def alpha_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def amax( + input: Tensor, + dim: _int | _size = (), + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + amax(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the maximum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.8177, 1.4878, -0.2491, 0.9130], + [-0.7158, 1.1775, 2.0992, 0.4817], + [-0.0053, 0.0164, -1.3738, -0.0507], + [ 1.9700, 1.1106, -1.0318, -1.0816]]) + >>> torch.amax(a, 1) + tensor([1.4878, 2.0992, 0.0164, 1.9700]) + """ + +def amin( + input: Tensor, + dim: _int | _size = (), + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + amin(input, dim, keepdim=False, *, out=None) -> Tensor + + Returns the minimum value of each slice of the :attr:`input` tensor in the given + dimension(s) :attr:`dim`. + + .. note:: + The difference between ``max``/``min`` and ``amax``/``amin`` is: + - ``amax``/``amin`` supports reducing on multiple dimensions, + - ``amax``/``amin`` does not return indices. + + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. 
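+
+    .. note::
+        Illustrative sketch of reducing over several dimensions at once, using the
+        4x4 tensor ``a`` from the example below: passing a tuple of dims collapses
+        all of them, here down to a 0-dimensional tensor::
+
+            >>> torch.amin(a, dim=(0, 1)).ndim
+            0
+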
+ + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.6451, -0.4866, 0.2987, -1.3312], + [-0.5744, 1.2980, 1.8397, -0.2713], + [ 0.9128, 0.9214, -1.7268, -0.2995], + [ 0.9023, 0.4853, 0.9075, -1.6165]]) + >>> torch.amin(a, 1) + tensor([-1.3312, -0.5744, -1.7268, -1.6165]) + """ + +def aminmax( + input: Tensor, + *, + dim: _int | None = None, + keepdim: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.aminmax: + r""" + aminmax(input, *, dim=None, keepdim=False, out=None) -> (Tensor min, Tensor max) + + Computes the minimum and maximum values of the :attr:`input` tensor. + + Args: + input (Tensor): + The input tensor + + Keyword Args: + dim (Optional[int]): + The dimension along which to compute the values. If `None`, + computes the values over the entire :attr:`input` tensor. + Default is `None`. + keepdim (bool): + If `True`, the reduced dimensions will be kept in the output + tensor as dimensions with size 1 for broadcasting, otherwise + they will be removed, as if calling (:func:`torch.squeeze`). + Default is `False`. + out (Optional[Tuple[Tensor, Tensor]]): + Optional tensors on which to write the result. Must have the same + shape and dtype as the expected output. + Default is `None`. + + Returns: + A named tuple `(min, max)` containing the minimum and maximum values. + + Raises: + RuntimeError + If any of the dimensions to compute the values over has size 0. + + .. note:: + NaN values are propagated to the output if at least one value is NaN. + + .. seealso:: + :func:`torch.amin` computes just the minimum value + :func:`torch.amax` computes just the maximum value + + Example:: + + >>> torch.aminmax(torch.tensor([1, -3, 5])) + torch.return_types.aminmax( + min=tensor(-3), + max=tensor(5)) + + >>> # aminmax propagates NaNs + >>> torch.aminmax(torch.tensor([1, -3, 5, torch.nan])) + torch.return_types.aminmax( + min=tensor(nan), + max=tensor(nan)) + + >>> t = torch.arange(10).view(2, 5) + >>> t + tensor([[0, 1, 2, 3, 4], + [5, 6, 7, 8, 9]]) + >>> t.aminmax(dim=0, keepdim=True) + torch.return_types.aminmax( + min=tensor([[0, 1, 2, 3, 4]]), + max=tensor([[5, 6, 7, 8, 9]])) + """ + +def angle(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + angle(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the element-wise angle (in radians) of the given :attr:`input` tensor. + + .. math:: + \text{out}_{i} = angle(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + .. note:: Starting in PyTorch 1.8, angle returns pi for negative real numbers, + zero for non-negative real numbers, and propagates NaNs. Previously + the function would return zero for all real numbers and not propagate + floating-point NaNs. + + Example:: + + >>> torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159 + tensor([ 135., 135, -45]) + """ + +@overload +def any(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
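+
+    .. note::
+        Illustrative sketch of the reduction identities: over an empty tensor,
+        :func:`torch.any` returns ``False`` (conversely, :func:`torch.all` returns
+        ``True``)::
+
+            >>> torch.any(torch.empty(0, dtype=torch.bool))
+            tensor(False)
+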
+ + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: _size | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def any( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + any(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Tests if any element in :attr:`input` evaluates to `True`. + + .. note:: This function matches the behaviour of NumPy in returning + output of dtype `bool` for all supported dtypes except `uint8`. + For `uint8` the dtype of output is `uint8` itself. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(1, 2).bool() + >>> a + tensor([[False, True]], dtype=torch.bool) + >>> torch.any(a) + tensor(True, dtype=torch.bool) + >>> a = torch.arange(0, 3) + >>> a + tensor([0, 1, 2]) + >>> torch.any(a) + tensor(True) + + .. function:: any(input, dim, keepdim=False, *, out=None) -> Tensor + :noindex: + + For each row of :attr:`input` in the given dimension :attr:`dim`, + returns `True` if any element in the row evaluate to `True` and `False` otherwise. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). 
+ + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4, 2) < 0 + >>> a + tensor([[ True, True], + [False, True], + [ True, True], + [False, False]]) + >>> torch.any(a, 1) + tensor([ True, True, True, False]) + >>> torch.any(a, 0) + tensor([True, True]) + """ + +@overload +def arange( + start: Number, + end: Number, + step: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
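+
+    .. note::
+        Illustrative sketch of the dtype inference described above: integer arguments
+        yield ``torch.int64``, while a floating-point argument switches the result to
+        the default floating-point dtype (``torch.float32`` unless the global default
+        has been changed)::
+
+            >>> torch.arange(5).dtype
+            torch.int64
+            >>> torch.arange(5.0).dtype
+            torch.float32
+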
+ + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number, + end: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + end: Number, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. 
+ + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + end: Number | _complex, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. 
math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number | _complex, + end: Number | _complex, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. 
If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +@overload +def arange( + start: Number | _complex, + end: Number | _complex, + step: Number | _complex = 1, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil` + with values from the interval ``[start, end)`` taken with common difference + :attr:`step` beginning from `start`. + + Note: When using floating-point dtypes (especially reduced precision types like ``bfloat16``), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer :attr:`step` is subject to floating point rounding errors when + comparing against :attr:`end`; to avoid inconsistency, we advise subtracting a small epsilon from :attr:`end` + in such cases. + + .. math:: + \text{out}_{{i+1}} = \text{out}_{i} + \text{step} + + Args: + start (Number, optional): the starting value for the set of points. Default: ``0``. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `stop` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.arange(5) + tensor([ 0, 1, 2, 3, 4]) + >>> torch.arange(1, 4) + tensor([ 1, 2, 3]) + >>> torch.arange(1, 2.5, 0.5) + tensor([ 1.0000, 1.5000, 2.0000]) + """ + +def arccos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arccos(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.acos`. + """ + +def arccos_(input: Tensor) -> Tensor: ... +def arccosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arccosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.acosh`. + """ + +def arccosh_(input: Tensor) -> Tensor: ... +def arcsin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arcsin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.asin`. + """ + +def arcsin_(input: Tensor) -> Tensor: ... +def arcsinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arcsinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.asinh`. + """ + +def arcsinh_(input: Tensor) -> Tensor: ... +def arctan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arctan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.atan`. + """ + +def arctan2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + Alias for :func:`torch.atan2`. + """ + +def arctan_(input: Tensor) -> Tensor: ... +def arctanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + arctanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Alias for :func:`torch.atanh`. + """ + +def arctanh_(input: Tensor) -> Tensor: ... +def argmax( + input: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + argmax(input) -> LongTensor + + Returns the indices of the maximum value of all elements in the :attr:`input` tensor. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple maximal values then the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a) + tensor(0) + + .. function:: argmax(input, dim, keepdim=False) -> LongTensor + :noindex: + + Returns the indices of the maximum values of a tensor across a dimension. + + This is the second value returned by :meth:`torch.max`. See its + documentation for the exact semantics of this method. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, the argmax of the flattened input is returned. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], + [-0.7401, -0.8805, -0.3402, -1.1936], + [ 0.4907, -1.3948, -1.0691, -0.3132], + [-1.6092, 0.5419, -0.2993, 0.3195]]) + >>> torch.argmax(a, dim=1) + tensor([ 0, 2, 0, 1]) + """ + +def argmin( + input: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + argmin(input, dim=None, keepdim=False) -> LongTensor + + Returns the indices of the minimum value(s) of the flattened tensor or along a dimension + + This is the second value returned by :meth:`torch.min`. See its + documentation for the exact semantics of this method. + + .. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, the argmin of the flattened input is returned. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.1139, 0.2254, -0.1381, 0.3687], + [ 1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [ 1.7809, -1.2960, 0.9384, 0.1438]]) + >>> torch.argmin(a) + tensor(13) + >>> torch.argmin(a, dim=1) + tensor([ 2, 1, 3, 1]) + >>> torch.argmin(a, dim=1, keepdim=True) + tensor([[2], + [1], + [3], + [1]]) + """ + +@overload +def argsort( + input: Tensor, + *, + stable: _bool, + dim: _int = -1, + descending: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +@overload +def argsort( + input: Tensor, + dim: _int = -1, + descending: _bool = False, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. 
+ dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +@overload +def argsort( + input: Tensor, + dim: str | EllipsisType | None, + descending: _bool = False, +) -> Tensor: + r""" + argsort(input, dim=-1, descending=False, stable=False) -> Tensor + + Returns the indices that sort a tensor along a given dimension in ascending + order by value. + + This is the second value returned by :meth:`torch.sort`. See its documentation + for the exact semantics of this method. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. If ``False``, the relative order of values + which compare equal is not guaranteed. ``True`` is slower. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): controls the relative order of equivalent elements + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0785, 1.5267, -0.8521, 0.4065], + [ 0.1598, 0.0788, -0.0745, -1.2700], + [ 1.2208, 1.0722, -0.7064, 1.2564], + [ 0.0669, -0.2318, -0.8229, -0.9280]]) + + + >>> torch.argsort(a, dim=1) + tensor([[2, 0, 3, 1], + [3, 2, 1, 0], + [2, 1, 0, 3], + [3, 2, 1, 0]]) + """ + +def argwhere(input: Tensor) -> Tensor: + r""" + argwhere(input) -> Tensor + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + .. note:: + This function is similar to NumPy's `argwhere`. + + When :attr:`input` is on CUDA, this function causes host-device synchronization. + + Args: + {input} + + Example:: + + >>> t = torch.tensor([1, 0, 1]) + >>> torch.argwhere(t) + tensor([[0], + [2]]) + >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) + >>> torch.argwhere(t) + tensor([[0, 0], + [0, 2], + [1, 1], + [1, 2]]) + """ + +def as_strided( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: + r""" + as_strided(input, size, stride, storage_offset=None) -> Tensor + + Create a view of an existing `torch.Tensor` :attr:`input` with specified + :attr:`size`, :attr:`stride` and :attr:`storage_offset`. + + .. warning:: + Prefer using other view functions, like :meth:`torch.Tensor.view` or + :meth:`torch.Tensor.expand`, to setting a view's strides manually with + `as_strided`, as this function will throw an error on non-standard Pytorch + backends (that do not have a concept of stride) and the result will depend + on the current layout in memory. 
The constructed view must only refer to + elements within the Tensor's storage or a runtime error will be thrown. + If the generated view is "overlapped" (with multiple indices referring to + the same element in memory), the behavior of inplace operations on this view + is undefined (and might not throw runtime errors). + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. + + Example:: + + >>> x = torch.randn(3, 3) + >>> x + tensor([[ 0.9039, 0.6291, 1.0795], + [ 0.1586, 2.1939, -0.4900], + [-0.1909, -0.7503, 1.9355]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2)) + >>> t + tensor([[0.9039, 1.0795], + [0.6291, 0.1586]]) + >>> t = torch.as_strided(x, (2, 2), (1, 2), 1) + tensor([[0.6291, 0.1586], + [1.0795, 2.1939]]) + """ + +def as_strided_( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: ... +def as_strided_copy( + input: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.as_strided`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def as_strided_scatter( + input: Tensor, + src: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, +) -> Tensor: + r""" + as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the elements corresponding to the result of calling + input.as_strided(size, stride, storage_offset). + + This function returns a tensor with fresh storage; it does not + return a view. + + Args: + input (Tensor): the input tensor. + size (tuple or ints): the shape of the output tensor + stride (tuple or ints): the stride of the output tensor + storage_offset (int, optional): the offset in the underlying storage of the output tensor + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + `torch.as_strided(input, size, stride, storage_offset)` + + Example:: + + >>> a = torch.arange(4).reshape(2, 2) + 1 + >>> a + tensor([[1, 2], + [3, 4]]) + >>> b = torch.zeros(3, 3) + >>> b + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> torch.as_strided_scatter(b, a, (2, 2), (1, 2)) + tensor([[1., 3., 2.], + [4., 0., 0.], + [0., 0., 0.]]) + """ + +def as_tensor( + data: Any, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, +) -> Tensor: + r""" + as_tensor(data: Any, dtype: Optional[dtype] = None, device: Optional[DeviceLikeType]) -> Tensor + + Converts :attr:`data` into a tensor, sharing data and preserving autograd + history if possible. + + If :attr:`data` is already a tensor with the requested dtype and device + then :attr:`data` itself is returned, but if :attr:`data` is a + tensor with a different dtype or device then it's copied as if using + `data.to(dtype=dtype, device=device)`. + + If :attr:`data` is a NumPy array (an ndarray) with the same dtype and device then a + tensor is constructed using :func:`torch.from_numpy`. 
+ + If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless + specifically overwritten by :attr:`device` or a default device. + + .. seealso:: + + :func:`torch.tensor` never shares its data and creates a new "leaf tensor" (see :doc:`/notes/autograd`). + + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.as_tensor(a, device=torch.device('cuda')) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([1, 2, 3]) + """ + +def asarray( + obj: Any, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + copy: _bool | None = None, + requires_grad: _bool = False, +) -> Tensor: + r""" + asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: bool = False) -> Tensor # noqa: B950 + + Converts :attr:`obj` to a tensor. + + :attr:`obj` can be one of: + + 1. a tensor + 2. a NumPy array or a NumPy scalar + 3. a DLPack capsule + 4. an object that implements Python's buffer protocol + 5. a scalar + 6. a sequence of scalars + + When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, + by default, not require a gradient, have the same datatype as :attr:`obj`, be on the + same device, and share memory with it. These properties can be controlled with the + :attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. + If the returned tensor is of a different datatype, on a different device, or a copy is + requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` + is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is + also a tensor with an autograd history then the returned tensor will have the same history. + + When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's + buffer protocol then the buffer is interpreted as an array of bytes grouped according to + the size of the datatype passed to the :attr:`dtype` keyword argument. (If no datatype is + passed then the default floating point datatype is used, instead.) The returned tensor + will have the specified datatype (or default floating point datatype if none is specified) + and, by default, be on the CPU device and share memory with the buffer. + + When :attr:`obj` is a NumPy scalar, the returned tensor will be a 0-dimensional tensor on + the CPU and that doesn't share its memory (i.e. ``copy=True``). By default datatype will + be the PyTorch datatype corresponding to the NumPy's scalar's datatype. + + When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the + returned tensor will, by default, infer its datatype from the scalar values, be on the + current default device, and not share its memory. + + .. 
seealso:: + + :func:`torch.tensor` creates a tensor that always copies the data from the input object. + :func:`torch.from_numpy` creates a tensor that always shares memory from NumPy arrays. + :func:`torch.frombuffer` creates a tensor that always shares memory from objects that + implement the buffer protocol. + :func:`torch.from_dlpack` creates a tensor that always shares memory from + DLPack capsules. + + Args: + obj (object): a tensor, NumPy array, DLPack Capsule, object that implements Python's + buffer protocol, scalar, or sequence of scalars. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the datatype of the returned tensor. + Default: ``None``, which causes the datatype of the returned tensor to be + inferred from :attr:`obj`. + copy (bool, optional): controls whether the returned tensor shares memory with :attr:`obj`. + Default: ``None``, which causes the returned tensor to share memory with :attr:`obj` + whenever possible. If ``True`` then the returned tensor does not share its memory. + If ``False`` then the returned tensor shares its memory with :attr:`obj` and an + error is thrown if it cannot. + device (:class:`torch.device`, optional): the device of the returned tensor. + Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if + :attr:`obj` is a Python sequence, the current default device will be used. + requires_grad (bool, optional): whether the returned tensor requires grad. + Default: ``False``, which causes the returned tensor not to require a gradient. + If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` + is also a tensor with an autograd history then the returned tensor will have + the same history. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> # Shares memory with tensor 'a' + >>> b = torch.asarray(a) + >>> a.data_ptr() == b.data_ptr() + True + >>> # Forces memory copy + >>> c = torch.asarray(a, copy=True) + >>> a.data_ptr() == c.data_ptr() + False + + >>> a = torch.tensor([1., 2., 3.], requires_grad=True) + >>> b = a + 2 + >>> b + tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', with no grad + >>> c = torch.asarray(b) + >>> c + tensor([3., 4., 5.]) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> d = torch.asarray(b, requires_grad=True) + >>> d + tensor([3., 4., 5.], grad_fn=) + + >>> array = numpy.array([1, 2, 3]) + >>> # Shares memory with array 'array' + >>> t1 = torch.asarray(array) + >>> array.__array_interface__['data'][0] == t1.data_ptr() + True + >>> # Copies memory due to dtype mismatch + >>> t2 = torch.asarray(array, dtype=torch.float32) + >>> array.__array_interface__['data'][0] == t2.data_ptr() + False + + >>> scalar = numpy.float64(0.5) + >>> torch.asarray(scalar) + tensor(0.5000, dtype=torch.float64) + """ + +def asin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + asin(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the arcsine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5962, 1.4985, -0.4396, 1.4525]) + >>> torch.asin(a) + tensor([-0.6387, nan, -0.4552, nan]) + """ + +def asin_(input: Tensor) -> Tensor: ... 
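+# Editor's sketch (not part of the generated stub): the copy/dtype semantics of
+# torch.asarray described above imply that ``copy=False`` fails whenever sharing is
+# impossible, e.g. when a dtype conversion is requested. The exact exception type is
+# not shown here; this is an illustration of the documented contract, not captured output.
+#
+#     >>> arr = numpy.array([1, 2, 3])
+#     >>> torch.asarray(arr).data_ptr() == arr.__array_interface__['data'][0]
+#     True   # default copy=None shares memory when possible
+#     >>> torch.asarray(arr, dtype=torch.float32, copy=False)
+#     # raises an error: the dtype change requires a copy, which copy=False forbids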
+def asinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + asinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1606, -1.4267, -1.0899, -1.0250 ]) + >>> torch.asinh(a) + tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) + """ + +def asinh_(input: Tensor) -> Tensor: ... +def atan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + atan(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the arctangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.2341, 0.2539, -0.6256, -0.6448]) + >>> torch.atan(a) + tensor([ 0.2299, 0.2487, -0.5591, -0.5727]) + """ + +def atan2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Element-wise arctangent of :math:`\text{input}_{i} / \text{other}_{i}` + with consideration of the quadrant. Returns a new tensor with the signed angles + in radians between vector :math:`(\text{other}_{i}, \text{input}_{i})` + and vector :math:`(1, 0)`. (Note that :math:`\text{other}_{i}`, the second + parameter, is the x-coordinate, while :math:`\text{input}_{i}`, the first + parameter, is the y-coordinate.) + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.9041, 0.0196, -0.3108, -2.4423]) + >>> torch.atan2(a, torch.randn(4)) + tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) + """ + +def atan_(input: Tensor) -> Tensor: ... +def atanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + atanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. + + Note: + The domain of the inverse hyperbolic tangent is `(-1, 1)` and values outside this range + will be mapped to ``NaN``, except for the values `1` and `-1` for which the output is + mapped to `+/-INF` respectively. + + .. math:: + \text{out}_{i} = \tanh^{-1}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4).uniform_(-1, 1) + >>> a + tensor([ -0.9385, 0.2968, -0.8591, -0.1871 ]) + >>> torch.atanh(a) + tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) + """ + +def atanh_(input: Tensor) -> Tensor: ... +def avg_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + ceil_mode: _bool = False, + count_include_pad: _bool = True, +) -> Tensor: ... 
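+# Editor's sketch (not part of the generated stub): avg_pool1d has no docstring here,
+# so the following assumes it matches torch.nn.functional.avg_pool1d -- 1D average
+# pooling over an (N, C, L) input, with ``stride`` defaulting to ``kernel_size``.
+#
+#     >>> x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)
+#     >>> torch.avg_pool1d(x, kernel_size=2)
+#     tensor([[[0.5000, 2.5000, 4.5000, 6.5000]]])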
+@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. 
+ + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + input: Tensor, + batch1: Tensor, + batch2: Tensor, + out_dtype: _dtype, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. 
math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def baddbmm( + beta: Number | _complex, + self: Tensor, + batch1: Tensor, + batch2: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + baddbmm(input, batch1, batch2, out_dtype=None, *, beta=1, alpha=1, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices in :attr:`batch1` + and :attr:`batch2`. + :attr:`input` is added to the final result. + + :attr:`batch1` and :attr:`batch2` must be 3-D tensors each containing the same + number of matrices. + + If :attr:`batch1` is a :math:`(b \times n \times m)` tensor, :attr:`batch2` is a + :math:`(b \times m \times p)` tensor, then :attr:`input` must be + :ref:`broadcastable ` with a + :math:`(b \times n \times p)` tensor and :attr:`out` will be a + :math:`(b \times n \times p)` tensor. Both :attr:`alpha` and :attr:`beta` mean the + same as the scaling factors used in :meth:`torch.addbmm`. + + .. math:: + \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + + If :attr:`beta` is 0, then the content of :attr:`input` will be ignored, and `nan` and `inf` in + it will not be propagated. + + For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and + :attr:`alpha` must be real numbers, otherwise they should be integers. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the tensor to be added + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{batch1} \mathbin{@} \text{batch2}` (:math:`\alpha`) + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> M = torch.randn(10, 3, 5) + >>> batch1 = torch.randn(10, 3, 4) + >>> batch2 = torch.randn(10, 4, 5) + >>> torch.baddbmm(M, batch1, batch2).size() + torch.Size([10, 3, 5]) + """ + +@overload +def bartlett_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def bartlett_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + bartlett_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Bartlett window function. + + .. math:: + w[n] = 1 - \left| \frac{2n}{N-1} - 1 \right| = \begin{cases} + \frac{2n}{N - 1} & \text{if } 0 \leq n \leq \frac{N - 1}{2} \\ + 2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\ + \end{cases}, + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. 
:attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.bartlett_window(L, periodic=True)`` equal to + ``torch.bartlett_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +def batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> Tensor: ... +def batch_norm_backward_elemt( + grad_out: Tensor, + input: Tensor, + mean: Tensor, + invstd: Tensor, + weight: Tensor | None, + sum_dy: Tensor, + sum_dy_xmu: Tensor, + count: Tensor, +) -> Tensor: ... +def batch_norm_backward_reduce( + grad_out: Tensor, + input: Tensor, + mean: Tensor, + invstd: Tensor, + weight: Tensor | None, + input_g: _bool, + weight_g: _bool, + bias_g: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def batch_norm_elemt( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + mean: Tensor, + invstd: Tensor, + eps: _float, + *, + out: Tensor | None = None, +) -> Tensor: ... +def batch_norm_gather_stats( + input: Tensor, + mean: Tensor, + invstd: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, + eps: _float, + count: _int, +) -> tuple[Tensor, Tensor]: ... +def batch_norm_gather_stats_with_counts( + input: Tensor, + mean: Tensor, + invstd: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, + eps: _float, + counts: Tensor, +) -> tuple[Tensor, Tensor]: ... +def batch_norm_stats(input: Tensor, eps: _float) -> tuple[Tensor, Tensor]: ... +def batch_norm_update_stats( + input: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + momentum: _float, +) -> tuple[Tensor, Tensor]: ... +@overload +def bernoulli( + input: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. 
+ + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + +@overload +def bernoulli( + input: Tensor, + p: _float, + *, + generator: Generator | None = None, +) -> Tensor: + r""" + bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor + + Draws binary random numbers (0 or 1) from a Bernoulli distribution. + + The :attr:`input` tensor should be a tensor containing probabilities + to be used for drawing the binary random number. + Hence, all values in :attr:`input` have to be in the range: + :math:`0 \leq \text{input}_i \leq 1`. + + The :math:`\text{i}^{th}` element of the output tensor will draw a + value :math:`1` according to the :math:`\text{i}^{th}` probability value given + in :attr:`input`. + + .. math:: + \text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i}) + + The returned :attr:`out` tensor only has values 0 or 1 and is of the same + shape as :attr:`input`. + + :attr:`out` can have integral ``dtype``, but :attr:`input` must have floating + point ``dtype``. + + Args: + input (Tensor): the input tensor of probability values for the Bernoulli distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.empty(3, 3).uniform_(0, 1) # generate a uniform random matrix with range [0, 1] + >>> a + tensor([[ 0.1737, 0.0950, 0.3609], + [ 0.7148, 0.0289, 0.2676], + [ 0.9456, 0.8937, 0.7202]]) + >>> torch.bernoulli(a) + tensor([[ 1., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.]]) + + >>> a = torch.ones(3, 3) # probability of drawing "1" is 1 + >>> torch.bernoulli(a) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + >>> a = torch.zeros(3, 3) # probability of drawing "1" is 0 + >>> torch.bernoulli(a) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 0.]]) + """ + +def bilinear( + input1: Tensor, + input2: Tensor, + weight: Tensor, + bias: Tensor | None = None, +) -> Tensor: ... 
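+# Editor's sketch (not part of the generated stub): bilinear is undocumented in this
+# stub, so the shapes below assume it follows the torch.nn.functional.bilinear contract,
+# y[n, o] = x1[n] @ weight[o] @ x2[n] + bias[o]; they are illustrative, not captured output.
+#
+#     >>> x1 = torch.randn(4, 3)       # (N, in1_features)
+#     >>> x2 = torch.randn(4, 5)       # (N, in2_features)
+#     >>> w = torch.randn(2, 3, 5)     # (out_features, in1_features, in2_features)
+#     >>> b = torch.randn(2)           # (out_features,)
+#     >>> torch.bilinear(x1, x2, w, b).shape
+#     torch.Size([4, 2])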
+def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Tensor | None = None, + pos_weight: Tensor | None = None, + reduction: _int = 1, +) -> Tensor: ... +def bincount( + input: Tensor, + weights: Tensor | None = None, + minlength: _int | SymInt = 0, +) -> Tensor: + r""" + bincount(input, weights=None, minlength=0) -> Tensor + + Count the frequency of each value in an array of non-negative ints. + + The number of bins (size 1) is one larger than the largest value in + :attr:`input` unless :attr:`input` is empty, in which case the result is a + tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least + :attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size + :attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``, + ``out[n] += weights[i]`` if :attr:`weights` is specified else + ``out[n] += 1``. + + Note: + This operation may produce nondeterministic gradients when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Arguments: + input (Tensor): 1-d int tensor + weights (Tensor): optional, weight for each value in the input tensor. + Should be of same size as input tensor. + minlength (int): optional, minimum number of bins. Should be non-negative. + + Returns: + output (Tensor): a tensor of shape ``Size([max(input) + 1])`` if + :attr:`input` is non-empty, else ``Size(0)`` + + Example:: + + >>> input = torch.randint(0, 8, (5,), dtype=torch.int64) + >>> weights = torch.linspace(0, 1, steps=5) + >>> input, weights + (tensor([4, 3, 6, 3, 4]), + tensor([ 0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + >>> torch.bincount(input) + tensor([0, 0, 0, 2, 2, 0, 1]) + + >>> input.bincount(weights) + tensor([0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 0.5000]) + """ + +def binomial( + count: Tensor, + prob: Tensor, + generator: Generator | None = None, +) -> Tensor: ... +@overload +def bitwise_and( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_and(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_and( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_and(input, other, *, out=None) -> Tensor + + Computes the bitwise AND of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical AND. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_and(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([1, 0, 3], dtype=torch.int8) + >>> torch.bitwise_and(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ False, True, False]) + """ + +@overload +def bitwise_left_shift( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +@overload +def bitwise_left_shift(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +@overload +def bitwise_left_shift( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_left_shift(input, other, *, out=None) -> Tensor + + Computes the left arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i << \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_left_shift(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 24], dtype=torch.int8) + """ + +def bitwise_not(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + bitwise_not(input, *, out=None) -> Tensor + + Computes the bitwise NOT of the given input tensor. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical NOT. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_not(torch.tensor([-1, -2, 3], dtype=torch.int8)) + tensor([ 0, 1, -4], dtype=torch.int8) + """ + +@overload +def bitwise_or( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_or(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_or( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical OR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_or(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -2, 3], dtype=torch.int8) + >>> torch.bitwise_or(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, True, False]) + """ + +@overload +def bitwise_right_shift( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. 
+ In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_right_shift(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_right_shift( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_right_shift(input, other, *, out=None) -> Tensor + + Computes the right arithmetic shift of :attr:`input` by :attr:`other` bits. + The input tensor must be of integral type. This operator supports + :ref:`broadcasting to a common shape ` and + :ref:`type promotion `. + In any case, if the value of the right operand is negative or is greater + or equal to the number of bits in the promoted left operand, the behavior is undefined. + + The operation applied is: + + .. math:: + \text{out}_i = \text{input}_i >> \text{other}_i + + Args: + input (Tensor or Scalar): the first input tensor + other (Tensor or Scalar): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_right_shift(torch.tensor([-2, -7, 31], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-1, -7, 3], dtype=torch.int8) + """ + +@overload +def bitwise_xor( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def bitwise_xor(self: Number | _complex, other: Tensor) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def bitwise_xor( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bitwise_xor(input, other, *, out=None) -> Tensor + + Computes the bitwise XOR of :attr:`input` and :attr:`other`. The input tensor must be of + integral or Boolean types. For bool tensors, it computes the logical XOR. + + Args: + input: the first input tensor + other: the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.bitwise_xor(torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8)) + tensor([-2, -2, 0], dtype=torch.int8) + >>> torch.bitwise_xor(torch.tensor([True, True, False]), torch.tensor([False, True, False])) + tensor([ True, False, False]) + """ + +@overload +def blackman_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1]``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. 
+ layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def blackman_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + blackman_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Blackman window function. + + .. math:: + w[n] = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{N - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{N - 1} \right) + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.blackman_window(L, periodic=True)`` equal to + ``torch.blackman_window(L + 1, periodic=False)[:-1]``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def bmm( + input: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. 
+ + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + +@overload +def bmm( + input: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + bmm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a batch matrix-matrix product of matrices stored in :attr:`input` + and :attr:`mat2`. + + :attr:`input` and :attr:`mat2` must be 3-D tensors each containing + the same number of matrices. + + If :attr:`input` is a :math:`(b \times n \times m)` tensor, :attr:`mat2` is a + :math:`(b \times m \times p)` tensor, :attr:`out` will be a + :math:`(b \times n \times p)` tensor. + + .. math:: + \text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Args: + input (Tensor): the first batch of matrices to be multiplied + mat2 (Tensor): the second batch of matrices to be multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword Args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.randn(10, 3, 4) + >>> mat2 = torch.randn(10, 4, 5) + >>> res = torch.bmm(input, mat2) + >>> res.size() + torch.Size([10, 3, 5]) + """ + +def broadcast_to(input: Tensor, size: Sequence[_int | SymInt]) -> Tensor: + r""" + broadcast_to(input, shape) -> Tensor + + Broadcasts :attr:`input` to the shape :attr:`\shape`. + Equivalent to calling ``input.expand(shape)``. See :meth:`~Tensor.expand` for details. + + Args: + input (Tensor): the input tensor. + shape (list, tuple, or :class:`torch.Size`): the new shape. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> torch.broadcast_to(x, (3, 3)) + tensor([[1, 2, 3], + [1, 2, 3], + [1, 2, 3]]) + """ + +@overload +def bucketize( + input: Tensor, + boundaries: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. 
If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): determines the behavior for values in :attr:`boundaries`. See the table above. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + +@overload +def bucketize( + self: Number | _complex, + boundaries: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, +) -> Tensor: + r""" + bucketize(input, boundaries, *, out_int32=False, right=False, out=None) -> Tensor + + Returns the indices of the buckets to which each value in the :attr:`input` belongs, where the + boundaries of the buckets are set by :attr:`boundaries`. Return a new tensor with the same size + as :attr:`input`. If :attr:`right` is False (default), then the left boundary is open. Note that + this behavior is opposite the behavior of + `numpy.digitize `_. + More formally, the returned index satisfies the following rules: + + .. list-table:: + :widths: 15 85 + :header-rows: 1 + + * - :attr:`right` + - *returned index satisfies* + * - False + - ``boundaries[i-1] < input[m][n]...[l][x] <= boundaries[i]`` + * - True + - ``boundaries[i-1] <= input[m][n]...[l][x] < boundaries[i]`` + + Args: + input (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + boundaries (Tensor): 1-D tensor, must contain a strictly increasing sequence, or the return value is undefined. + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): determines the behavior for values in :attr:`boundaries`. See the table above. + out (Tensor, optional): the output tensor, must be the same size as :attr:`input` if provided. + + + Example:: + + >>> boundaries = torch.tensor([1, 3, 5, 7, 9]) + >>> boundaries + tensor([1, 3, 5, 7, 9]) + >>> v = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> v + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.bucketize(v, boundaries) + tensor([[1, 3, 4], + [1, 3, 4]]) + >>> torch.bucketize(v, boundaries, right=True) + tensor([[2, 3, 5], + [2, 3, 5]]) + """ + +def can_cast(from_: _dtype, to: _dtype) -> _bool: + r""" + can_cast(from_, to) -> bool + + Determines if a type conversion is allowed under PyTorch casting rules + described in the type promotion :ref:`documentation `. 
+ + Args: + from\_ (dtype): The original :class:`torch.dtype`. + to (dtype): The target :class:`torch.dtype`. + + Example:: + + >>> torch.can_cast(torch.double, torch.float) + True + >>> torch.can_cast(torch.float, torch.int) + False + """ + +@overload +def cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): Non-empty tensors provided must have the same shape, + except in the cat dimension. + + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + +@overload +def cat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cat(tensors, dim=0, *, out=None) -> Tensor + + Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. + All tensors must either have the same shape (except in the concatenating + dimension) or be a 1-D empty tensor with size ``(0,)``. + + :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split` + and :func:`torch.chunk`. + + :func:`torch.cat` can be best understood via examples. + + .. seealso:: + + :func:`torch.stack` concatenates the given sequence along a new dimension. + + Args: + tensors (sequence of Tensors): Non-empty tensors provided must have the same shape, + except in the cat dimension. + + dim (int, optional): the dimension over which the tensors are concatenated + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 0) + tensor([[ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497], + [ 0.6580, -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497]]) + >>> torch.cat((x, x, x), 1) + tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, + -1.0969, -0.4614], + [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, + -0.5790, 0.1497]]) + """ + +def ccol_indices_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... 
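+# A short sketch contrasting `cat` (documented above) with `stack`, which its
+# seealso note points to: `cat` joins tensors along an existing dimension, while
+# `stack` inserts a new one. The shapes below are illustrative only.
+#
+# >>> import torch
+# >>> x = torch.randn(2, 3)
+# >>> torch.cat((x, x), dim=0).shape
+# torch.Size([4, 3])
+# >>> torch.stack((x, x), dim=0).shape
+# torch.Size([2, 2, 3])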
+def ceil(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + ceil(input, *, out=None) -> Tensor + + Returns a new tensor with the ceil of the elements of :attr:`input`, + the smallest integer greater than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + >>> torch.ceil(a) + tensor([-0., -1., -1., 1.]) + """ + +def ceil_(input: Tensor) -> Tensor: ... +def celu(input: Tensor, alpha: Number | _complex = 1.0) -> Tensor: ... +def celu_(input: Tensor, alpha: Number | _complex = 1.0) -> Tensor: ... +def channel_shuffle(input: Tensor, groups: _int | SymInt) -> Tensor: ... +def cholesky( + input: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky(input, upper=False, *, out=None) -> Tensor + + Computes the Cholesky decomposition of a symmetric positive-definite + matrix :math:`A` or for batches of symmetric positive-definite matrices. + + If :attr:`upper` is ``True``, the returned matrix ``U`` is upper-triangular, and + the decomposition has the form: + + .. math:: + + A = U^TU + + If :attr:`upper` is ``False``, the returned matrix ``L`` is lower-triangular, and + the decomposition has the form: + + .. math:: + + A = LL^T + + If :attr:`upper` is ``True``, and :math:`A` is a batch of symmetric positive-definite + matrices, then the returned tensor will be composed of upper-triangular Cholesky factors + of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned + tensor will be composed of lower-triangular Cholesky factors of each of the individual + matrices. + + .. warning:: + + :func:`torch.cholesky` is deprecated in favor of :func:`torch.linalg.cholesky` + and will be removed in a future PyTorch release. + + ``L = torch.cholesky(A)`` should be replaced with + + .. code:: python + + L = torch.linalg.cholesky(A) + + ``U = torch.cholesky(A, upper=True)`` should be replaced with + + .. code:: python + + U = torch.linalg.cholesky(A).mH + + This transform will produce equivalent results for all valid (symmetric positive definite) inputs. + + Args: + input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more + batch dimensions consisting of symmetric positive-definite matrices. + upper (bool, optional): flag that indicates whether to return a + upper or lower triangular matrix. 
Default: ``False`` + + Keyword args: + out (Tensor, optional): the output matrix + + Example:: + + >>> a = torch.randn(3, 3) + >>> a = a @ a.mT + 1e-3 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> a + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> l + tensor([[ 1.5528, 0.0000, 0.0000], + [-0.4821, 1.0592, 0.0000], + [ 0.9371, 0.5487, 0.7023]]) + >>> l @ l.mT + tensor([[ 2.4112, -0.7486, 1.4551], + [-0.7486, 1.3544, 0.1294], + [ 1.4551, 0.1294, 1.6724]]) + >>> a = torch.randn(3, 2, 2) # Example for batched input + >>> a = a @ a.mT + 1e-03 # make symmetric positive-definite + >>> l = torch.cholesky(a) + >>> z = l @ l.mT + >>> torch.dist(z, a) + tensor(2.3842e-07) + """ + +def cholesky_inverse( + input: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky_inverse(L, upper=False, *, out=None) -> Tensor + + Computes the inverse of a complex Hermitian or real symmetric + positive-definite matrix given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Computes the inverse matrix :math:`A^{-1}`. + + Supports input of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` is a batch of matrices + then the output has the same batch dimensions. + + Args: + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False`` + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> torch.cholesky_inverse(L) + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + >>> A.inverse() + tensor([[ 1.9314, 1.2251, -0.0889], + [ 1.2251, 2.4439, 0.2122], + [-0.0889, 0.2122, 0.1412]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(torch.inverse(A), torch.cholesky_inverse(L)) + tensor(5.6358e-7) + """ + +def cholesky_solve( + input: Tensor, + input2: Tensor, + upper: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cholesky_solve(B, L, upper=False, *, out=None) -> Tensor + + Computes the solution of a system of linear equations with complex Hermitian + or real symmetric positive-definite lhs given its Cholesky decomposition. + + Let :math:`A` be a complex Hermitian or real symmetric positive-definite matrix, + and :math:`L` its Cholesky decomposition such that: + + .. math:: + + A = LL^{\text{H}} + + where :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, + and the transpose when :math:`L` is real-valued. + + Returns the solution :math:`X` of the following linear system: + + .. 
math:: + + AX = B + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices + then the output has the same batch dimensions. + + Args: + B (Tensor): right-hand side tensor of shape `(*, n, k)` + where :math:`*` is zero or more batch dimensions + L (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of lower or upper triangular Cholesky decompositions of + symmetric or Hermitian positive-definite matrices. + upper (bool, optional): flag that indicates whether :math:`L` is lower triangular + or upper triangular. Default: ``False``. + + Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + + Example:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.T + torch.eye(3) * 1e-3 # Creates a symmetric positive-definite matrix + >>> L = torch.linalg.cholesky(A) # Extract Cholesky decomposition + >>> B = torch.randn(3, 2) + >>> torch.cholesky_solve(B, L) + tensor([[ -8.1625, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + >>> A.inverse() @ B + tensor([[ -8.1626, 19.6097], + [ -5.8398, 14.2387], + [ -4.3771, 10.4173]]) + + >>> A = torch.randn(3, 2, 2, dtype=torch.complex64) + >>> A = A @ A.mH + torch.eye(2) * 1e-3 # Batch of Hermitian positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> B = torch.randn(2, 1, dtype=torch.complex64) + >>> X = torch.cholesky_solve(B, L) + >>> torch.dist(X, A.inverse() @ B) + tensor(1.6881e-5) + """ + +def choose_qparams_optimized( + input: Tensor, + numel: _int, + n_bins: _int, + ratio: _float, + bit_width: _int, +) -> tuple[Tensor, Tensor]: ... +def chunk(input: Tensor, chunks: _int, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + chunk(input: Tensor, chunks: int, dim: int = 0) -> Tuple[Tensor, ...] + + Attempts to split a tensor into the specified number of chunks. Each chunk is a view of + the input tensor. + + + .. note:: + + This function may return fewer than the specified number of chunks! + + .. seealso:: + + :func:`torch.tensor_split` a function that always returns exactly the specified number of chunks + + If the tensor size along the given dimension :attr:`dim` is divisible by :attr:`chunks`, + all returned chunks will be the same size. + If the tensor size along the given dimension :attr:`dim` is not divisible by :attr:`chunks`, + all returned chunks will be the same size, except the last one. + If such division is not possible, this function may return fewer + than the specified number of chunks. + + Arguments: + input (Tensor): the tensor to split + chunks (int): number of chunks to return + dim (int): dimension along which to split the tensor + + Example: + >>> torch.arange(11).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10])) + >>> torch.arange(12).chunk(6) + (tensor([0, 1]), + tensor([2, 3]), + tensor([4, 5]), + tensor([6, 7]), + tensor([8, 9]), + tensor([10, 11])) + >>> torch.arange(13).chunk(6) + (tensor([0, 1, 2]), + tensor([3, 4, 5]), + tensor([6, 7, 8]), + tensor([ 9, 10, 11]), + tensor([12])) + """ + +@overload +def clamp( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. 
+ Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + +@overload +def clamp( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clamp(input, min=None, max=None, *, out=None) -> Tensor + + Clamps all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]`. + Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: + + .. math:: + y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i) + + If :attr:`min` is ``None``, there is no lower bound. + Or, if :attr:`max` is ``None`` there is no upper bound. + + + .. note:: + If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) ` + sets all elements in :attr:`input` to the value of :attr:`max`. + + Args: + input (Tensor): the input tensor. + min (Number or Tensor, optional): lower-bound of the range to be clamped to + max (Number or Tensor, optional): upper-bound of the range to be clamped to + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.7120, 0.1734, -0.0478, -0.0922]) + >>> torch.clamp(a, min=-0.5, max=0.5) + tensor([-0.5000, 0.1734, -0.0478, -0.0922]) + + >>> min = torch.linspace(-1, 1, steps=4) + >>> torch.clamp(a, min=min) + tensor([-1.0000, 0.1734, 0.3333, 1.0000]) + """ + +@overload +def clamp_( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, +) -> Tensor: ... +@overload +def clamp_max( + input: Tensor, + max: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_max( + input: Tensor, + max: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Tensor) -> Tensor: ... +@overload +def clamp_max_(input: Tensor, max: Number | _complex) -> Tensor: ... +@overload +def clamp_min( + input: Tensor, + min: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_min( + input: Tensor, + min: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Tensor) -> Tensor: ... +@overload +def clamp_min_(input: Tensor, min: Number | _complex) -> Tensor: ... 
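+# A minimal sketch of how the undocumented `clamp_max` / `clamp_min` overloads above
+# relate to `clamp`: each applies only one of the two bounds. The equivalences shown
+# follow from the documented clamp formula and are assumptions, not taken from the stubs.
+#
+# >>> import torch
+# >>> a = torch.tensor([-1.5, 0.2, 0.9])
+# >>> torch.clamp_max(a, 0.5)    # same result as torch.clamp(a, max=0.5)
+# tensor([-1.5000,  0.2000,  0.5000])
+# >>> torch.clamp_min(a, 0.0)    # same result as torch.clamp(a, min=0.0)
+# tensor([0.0000, 0.2000, 0.9000])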
+@overload +def clip( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + +@overload +def clip( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + clip(input, min=None, max=None, *, out=None) -> Tensor + + Alias for :func:`torch.clamp`. + """ + +@overload +def clip_( + input: Tensor, + min: Tensor | None = None, + max: Tensor | None = None, +) -> Tensor: ... +@overload +def clip_( + input: Tensor, + min: Number | _complex | None = None, + max: Number | _complex | None = None, +) -> Tensor: ... +def clone( + input: Tensor, + *, + memory_format: memory_format | None = None, +) -> Tensor: + r""" + clone(input, *, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of :attr:`input`. + + .. note:: + + This function is differentiable, so gradients will flow back from the + result of this operation to :attr:`input`. To create a tensor without an + autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned tensor. Default: ``torch.preserve_format``. + """ + +def col_indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.col_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def column_stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + column_stack(tensors, *, out=None) -> Tensor + + Creates a new tensor by horizontally stacking the tensors in :attr:`tensors`. + + Equivalent to ``torch.hstack(tensors)``, except each zero or one dimensional tensor ``t`` + in :attr:`tensors` is first reshaped into a ``(t.numel(), 1)`` column before being stacked horizontally. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.column_stack((a, b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + >>> a = torch.arange(5) + >>> b = torch.arange(10).reshape(5, 2) + >>> torch.column_stack((a, b, b)) + tensor([[0, 0, 1, 0, 1], + [1, 2, 3, 2, 3], + [2, 4, 5, 4, 5], + [3, 6, 7, 6, 7], + [4, 8, 9, 8, 9]]) + """ + +def combinations( + input: Tensor, + r: _int = 2, + with_replacement: _bool = False, +) -> Tensor: + r""" + combinations(input: Tensor, r: int = 2, with_replacement: bool = False) -> seq + + Compute combinations of length :math:`r` of the given tensor. The behavior is similar to + python's `itertools.combinations` when `with_replacement` is set to `False`, and + `itertools.combinations_with_replacement` when `with_replacement` is set to `True`. + + Arguments: + input (Tensor): 1D vector. + r (int, optional): number of elements to combine + with_replacement (bool, optional): whether to allow duplication in combination + + Returns: + Tensor: A tensor equivalent to converting all the input tensors into lists, do + `itertools.combinations` or `itertools.combinations_with_replacement` on these + lists, and finally convert the resulting list into tensor. 
+ + Example:: + + >>> a = [1, 2, 3] + >>> list(itertools.combinations(a, r=2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(itertools.combinations(a, r=3)) + [(1, 2, 3)] + >>> list(itertools.combinations_with_replacement(a, r=2)) + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + >>> tensor_a = torch.tensor(a) + >>> torch.combinations(tensor_a) + tensor([[1, 2], + [1, 3], + [2, 3]]) + >>> torch.combinations(tensor_a, r=3) + tensor([[1, 2, 3]]) + >>> torch.combinations(tensor_a, with_replacement=True) + tensor([[1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3]]) + """ + +def complex( + real: Tensor, + imag: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + complex(real, imag, *, out=None) -> Tensor + + Constructs a complex tensor with its real part equal to :attr:`real` and its + imaginary part equal to :attr:`imag`. + + Args: + real (Tensor): The real part of the complex tensor. Must be half, float or double. + imag (Tensor): The imaginary part of the complex tensor. Must be same dtype + as :attr:`real`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> real = torch.tensor([1, 2], dtype=torch.float32) + >>> imag = torch.tensor([3, 4], dtype=torch.float32) + >>> z = torch.complex(real, imag) + >>> z + tensor([(1.+3.j), (2.+4.j)]) + >>> z.dtype + torch.complex64 + """ + +@overload +def concat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concat( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concat(tensors, dim=0, *, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concatenate( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +@overload +def concatenate( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + concatenate(tensors, axis=0, out=None) -> Tensor + + Alias of :func:`torch.cat`. + """ + +def conj(input: Tensor) -> Tensor: + r""" + conj(input) -> Tensor + + Returns a view of :attr:`input` with a flipped conjugate bit. If :attr:`input` has a non-complex dtype, + this function just returns :attr:`input`. + + .. note:: + :func:`torch.conj` performs a lazy conjugation, but the actual conjugated tensor can be materialized + at any time using :func:`torch.resolve_conj`. + + .. warning:: In the future, :func:`torch.conj` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> x.is_conj() + False + >>> y = torch.conj(x) + >>> y.is_conj() + True + """ + +def conj_physical(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + conj_physical(input, *, out=None) -> Tensor + + Computes the element-wise conjugate of the given :attr:`input` tensor. 
+ If :attr:`input` has a non-complex dtype, this function just returns :attr:`input`. + + .. note:: + This performs the conjugate operation regardless of the fact conjugate bit is set or not. + + .. warning:: In the future, :func:`torch.conj_physical` may return a non-writeable view for an :attr:`input` of + non-complex dtype. It's recommended that programs not modify the tensor returned by :func:`torch.conj_physical` + when :attr:`input` is of non-complex dtype to be compatible with this change. + + .. math:: + \text{out}_{i} = conj(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.conj_physical(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])) + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + """ + +def conj_physical_(input: Tensor) -> Tensor: ... +def constant_pad_nd( + input: Tensor, + pad: Sequence[_int | SymInt], + value: Number | _complex = 0, +) -> Tensor: ... +@overload +def conv1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +@overload +def conv3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: str = "valid", + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, + groups: _int | SymInt = 1, +) -> Tensor: ... +def conv_tbc( + input: Tensor, + weight: Tensor, + bias: Tensor, + pad: _int = 0, +) -> Tensor: ... +def conv_transpose1d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def conv_transpose2d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... 
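+# A minimal shape sketch for the undocumented convolution stubs above, assuming the
+# usual torch.conv2d convention: input (N, C_in, H, W) and weight
+# (C_out, C_in / groups, kH, kW). The concrete sizes are illustrative only.
+#
+# >>> import torch
+# >>> x = torch.randn(1, 3, 32, 32)
+# >>> w = torch.randn(16, 3, 3, 3)
+# >>> torch.conv2d(x, w, stride=1, padding=1).shape
+# torch.Size([1, 16, 32, 32])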
+def conv_transpose3d( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, + stride: _int | SymInt | Sequence[_int | SymInt] = 1, + padding: _int | SymInt | Sequence[_int | SymInt] = 0, + output_padding: _int | SymInt | Sequence[_int | SymInt] = 0, + groups: _int | SymInt = 1, + dilation: _int | SymInt | Sequence[_int | SymInt] = 1, +) -> Tensor: ... +def convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + transposed: _bool, + output_padding: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +@overload +def copysign( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + +@overload +def copysign( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + copysign(input, other, *, out=None) -> Tensor + + Create a new floating-point tensor with the magnitude of :attr:`input` and the sign of :attr:`other`, elementwise. + + .. math:: + \text{out}_{i} = \begin{cases} + -|\text{input}_{i}| & \text{if } \text{other}_{i} \leq -0.0 \\ + |\text{input}_{i}| & \text{if } \text{other}_{i} \geq 0.0 \\ + \end{cases} + + + Supports :ref:`broadcasting to a common shape `, + and integer and float inputs. + + Args: + input (Tensor): magnitudes. + other (Tensor or Number): contains value(s) whose signbit(s) are + applied to the magnitudes in :attr:`input`. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.2557, -0.0026, -0.5387, 0.4740, -0.9244]) + >>> torch.copysign(a, 1) + tensor([1.2557, 0.0026, 0.5387, 0.4740, 0.9244]) + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.7079, 0.2778, -1.0249, 0.5719], + [-0.0059, -0.2600, -0.4475, -1.3948], + [ 0.3667, -0.9567, -2.5757, -0.1751], + [ 0.2046, -0.0742, 0.2998, -0.1054]]) + >>> b = torch.randn(4) + tensor([ 0.2373, 0.3120, 0.3190, -1.1128]) + >>> torch.copysign(a, b) + tensor([[ 0.7079, 0.2778, 1.0249, -0.5719], + [ 0.0059, 0.2600, 0.4475, -1.3948], + [ 0.3667, 0.9567, 2.5757, -0.1751], + [ 0.2046, 0.0742, 0.2998, -0.1054]]) + >>> a = torch.tensor([1.]) + >>> b = torch.tensor([-0.]) + >>> torch.copysign(a, b) + tensor([-1.]) + + .. note:: + copysign handles signed zeros. If the other argument has a negative zero (-0), + the corresponding output value will be negative. + """ + +def corrcoef(input: Tensor) -> Tensor: + r""" + corrcoef(input) -> Tensor + + Estimates the Pearson product-moment correlation coefficient matrix of the variables given by the :attr:`input` matrix, + where rows are the variables and columns are the observations. + + .. note:: + + The correlation coefficient matrix R is computed using the covariance matrix C as given by + :math:`R_{ij} = \frac{ C_{ij} } { \sqrt{ C_{ii} * C_{jj} } }` + + .. note:: + + Due to floating point rounding, the resulting array may not be Hermitian and its diagonal elements may not be 1. + The real and imaginary values are clipped to the interval [-1, 1] in an attempt to improve this situation. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. + + Returns: + (Tensor) The correlation coefficient matrix of the variables. + + .. seealso:: + + :func:`torch.cov` covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 1, 2], [2, 1, 0]]) + >>> torch.corrcoef(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> x = torch.randn(2, 4) + >>> x + tensor([[-0.2678, -0.0908, -0.3766, 0.2780], + [-0.5812, 0.1535, 0.2387, 0.2350]]) + >>> torch.corrcoef(x) + tensor([[1.0000, 0.3582], + [0.3582, 1.0000]]) + >>> torch.corrcoef(x[0]) + tensor(1.) + """ + +def cos(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + cos(input, *, out=None) -> Tensor + + Returns a new tensor with the cosine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \cos(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 1.4309, 1.2706, -0.8562, 0.9796]) + >>> torch.cos(a) + tensor([ 0.1395, 0.2957, 0.6553, 0.5574]) + """ + +def cos_(input: Tensor) -> Tensor: ... +def cosh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + cosh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic cosine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \cosh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.1632, 1.1835, -0.6979, -0.7325]) + >>> torch.cosh(a) + tensor([ 1.0133, 1.7860, 1.2536, 1.2805]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.cosh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + +def cosh_(input: Tensor) -> Tensor: ... 
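+# Editorial note (not generated): a small sketch of the corrcoef/cov relationship
+# stated in the corrcoef docstring above, R_ij = C_ij / sqrt(C_ii * C_jj).
+# The 3-variable, 10-observation input is hypothetical; torch.cov is the covariance
+# routine declared later in this stub, and the comparison assumes the default
+# torch.testing.assert_close tolerances absorb the [-1, 1] clipping that corrcoef
+# applies for floating-point rounding.
+#
+#     x = torch.randn(3, 10)                    # rows = variables, columns = observations
+#     c = torch.cov(x)
+#     d = torch.sqrt(torch.diag(c))
+#     r = c / torch.outer(d, d)                 # R_ij = C_ij / sqrt(C_ii * C_jj)
+#     torch.testing.assert_close(r, torch.corrcoef(x))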
+def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: _float = 0.0, + reduction: _int = 1, +) -> Tensor: ... +def cosine_similarity( + x1: Tensor, + x2: Tensor, + dim: _int = 1, + eps: _float = 1e-08, +) -> Tensor: ... +@overload +def count_nonzero(input: Tensor, dim: _int | None = None) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + +@overload +def count_nonzero(input: Tensor, dim: _size) -> Tensor: + r""" + count_nonzero(input, dim=None) -> Tensor + + Counts the number of non-zero values in the tensor :attr:`input` along the given :attr:`dim`. + If no dim is specified then all non-zeros in the tensor are counted. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): Dim or tuple of dims along which to count non-zeros. + + Example:: + + >>> x = torch.zeros(3,3) + >>> x[torch.randn(3,3) > 0.5] = 1 + >>> x + tensor([[0., 1., 1.], + [0., 0., 0.], + [0., 0., 1.]]) + >>> torch.count_nonzero(x) + tensor(3) + >>> torch.count_nonzero(x, dim=0) + tensor([0, 1, 2]) + """ + +def cov( + input: Tensor, + *, + correction: _int = 1, + fweights: Tensor | None = None, + aweights: Tensor | None = None, +) -> Tensor: + r""" + cov(input, *, correction=1, fweights=None, aweights=None) -> Tensor + + Estimates the covariance matrix of the variables given by the :attr:`input` matrix, where rows are + the variables and columns are the observations. + + A covariance matrix is a square matrix giving the covariance of each pair of variables. The diagonal contains + the variance of each variable (covariance of a variable with itself). By definition, if :attr:`input` represents + a single variable (Scalar or 1D) then its variance is returned. + + The sample covariance of the variables :math:`x` and :math:`y` is given by: + + .. math:: + \text{cov}(x,y) = \frac{\sum^{N}_{i = 1}(x_{i} - \bar{x})(y_{i} - \bar{y})}{\max(0,~N~-~\delta N)} + + where :math:`\bar{x}` and :math:`\bar{y}` are the simple means of the :math:`x` and :math:`y` respectively, and + :math:`\delta N` is the :attr:`correction`. + + If :attr:`fweights` and/or :attr:`aweights` are provided, the weighted covariance + is calculated, which is given by: + + .. math:: + \text{cov}_w(x,y) = \frac{\sum^{N}_{i = 1}w_i(x_{i} - \mu_x^*)(y_{i} - \mu_y^*)} + {\max(0,~\sum^{N}_{i = 1}w_i~-~\frac{\sum^{N}_{i = 1}w_ia_i}{\sum^{N}_{i = 1}w_i}~\delta N)} + + where :math:`w` denotes :attr:`fweights` or :attr:`aweights` (``f`` and ``a`` for brevity) based on whichever is + provided, or :math:`w = f \times a` if both are provided, and + :math:`\mu_x^* = \frac{\sum^{N}_{i = 1}w_ix_{i} }{\sum^{N}_{i = 1}w_i}` is the weighted mean of the variable. If not + provided, ``f`` and/or ``a`` can be seen as a :math:`\mathbb{1}` vector of appropriate size. + + Args: + input (Tensor): A 2D matrix containing multiple variables and observations, or a + Scalar or 1D vector representing a single variable. 
+ + Keyword Args: + correction (int, optional): difference between the sample size and sample degrees of freedom. + Defaults to Bessel's correction, ``correction = 1`` which returns the unbiased estimate, + even if both :attr:`fweights` and :attr:`aweights` are specified. ``correction = 0`` + will return the simple average. Defaults to ``1``. + fweights (tensor, optional): A Scalar or 1D tensor of observation vector frequencies representing the number of + times each observation should be repeated. Its numel must equal the number of columns of :attr:`input`. + Must have integral dtype. Ignored if ``None``. Defaults to ``None``. + aweights (tensor, optional): A Scalar or 1D array of observation vector weights. + These relative weights are typically large for observations considered "important" and smaller for + observations considered less "important". Its numel must equal the number of columns of :attr:`input`. + Must have floating point dtype. Ignored if ``None``. Defaults to ``None``. + + Returns: + (Tensor) The covariance matrix of the variables. + + .. seealso:: + + :func:`torch.corrcoef` normalized covariance matrix. + + Example:: + + >>> x = torch.tensor([[0, 2], [1, 1], [2, 0]]).T + >>> x + tensor([[0, 1, 2], + [2, 1, 0]]) + >>> torch.cov(x) + tensor([[ 1., -1.], + [-1., 1.]]) + >>> torch.cov(x, correction=0) + tensor([[ 0.6667, -0.6667], + [-0.6667, 0.6667]]) + >>> fw = torch.randint(1, 10, (3,)) + >>> fw + tensor([1, 6, 9]) + >>> aw = torch.rand(3) + >>> aw + tensor([0.4282, 0.0255, 0.4144]) + >>> torch.cov(x, fweights=fw, aweights=aw) + tensor([[ 0.4169, -0.4169], + [-0.4169, 0.4169]]) + """ + +def cross( + input: Tensor, + other: Tensor, + dim: _int | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + cross(input, other, dim=None, *, out=None) -> Tensor + + + Returns the cross product of vectors in dimension :attr:`dim` of :attr:`input` + and :attr:`other`. + + Supports input of float, double, cfloat and cdouble dtypes. Also supports batches + of vectors, for which it computes the product along the dimension :attr:`dim`. + In this case, the output has the same batch dimensions as the inputs. + + .. warning:: + If :attr:`dim` is not given, it defaults to the first dimension found + with the size 3. Note that this might be unexpected. + + This behavior is deprecated and will be changed to match that of :func:`torch.linalg.cross` + in a future release. + + .. seealso:: + :func:`torch.linalg.cross` which has dim=-1 as default. + + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + dim (int, optional): the dimension to take the cross-product in. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.cross(a, b, dim=1) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> torch.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + """ + +def crow_indices_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.crow_indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: _size, + target_lengths: _size, + blank: _int = 0, + reduction: _int = 1, + zero_infinity: _bool = False, +) -> Tensor: ... +@overload +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: _int = 0, + reduction: _int = 1, + zero_infinity: _bool = False, +) -> Tensor: ... +def cudnn_affine_grid_generator( + theta: Tensor, + N: _int, + C: _int, + H: _int, + W: _int, +) -> Tensor: ... +def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + exponential_average_factor: _float, + epsilon: _float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def cudnn_convolution( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + allow_tf32: _bool, + *, + out: Tensor | None = None, +) -> Tensor: ... +def cudnn_convolution_add_relu( + input: Tensor, + weight: Tensor, + z: Tensor, + alpha: Number | _complex | None, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def cudnn_convolution_relu( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def cudnn_convolution_transpose( + input: Tensor, + weight: Tensor, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, + allow_tf32: _bool, +) -> Tensor: ... +def cudnn_grid_sampler(input: Tensor, grid: Tensor) -> Tensor: ... +def cudnn_is_acceptable(input: Tensor) -> _bool: ... +@overload +def cummax( + input: Tensor, + dim: _int, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. 
math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + +@overload +def cummax( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummax: + r""" + cummax(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative maximum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) + """ + +@overload +def cummin( + input: Tensor, + dim: _int, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + +@overload +def cummin( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.cummin: + r""" + cummin(input, dim, *, out=None) -> (Tensor, LongTensor) + Returns a namedtuple ``(values, indices)`` where ``values`` is the cumulative minimum of + elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index + location of each maximum value found in the dimension :attr:`dim`. + + .. math:: + y_i = min(x_1, x_2, x_3, \dots, x_i) + + Args: + input (Tensor): the input tensor. 
+ dim (int): the dimension to do the operation over + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220, -0.3885, 1.1762, + 0.9165, 1.6684]) + >>> torch.cummin(a, dim=0) + torch.return_types.cummin( + values=tensor([-0.2284, -0.6628, -0.6628, -0.6628, -1.3298, -1.3298, -1.3298, -1.3298, + -1.3298, -1.3298]), + indices=tensor([0, 1, 1, 1, 4, 4, 4, 4, 4, 4])) + """ + +@overload +def cumprod( + input: Tensor, + dim: _int, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + +@overload +def cumprod( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumprod(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative product of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 \times x_2\times x_3\times \dots \times x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> a + tensor([ 0.6001, 0.2069, -0.1919, 0.9792, 0.6727, 1.0062, 0.4126, + -0.2129, -0.4206, 0.1968]) + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0158, -0.0065, + 0.0014, -0.0006, -0.0001]) + + >>> a[5] = 0.0 + >>> torch.cumprod(a, dim=0) + tensor([ 0.6001, 0.1241, -0.0238, -0.0233, -0.0157, -0.0000, -0.0000, + 0.0000, -0.0000, -0.0000]) + """ + +@overload +def cumsum( + input: Tensor, + dim: _int, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. 
+ + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + +@overload +def cumsum( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) + >>> torch.cumsum(a, dim=0) + tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) + """ + +@overload +def cumulative_trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + +@overload +def cumulative_trapezoid( + y: Tensor, + *, + dx: Number | _complex = 1, + dim: _int = -1, +) -> Tensor: + r""" + cumulative_trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Cumulatively computes the `trapezoidal rule `_ + along :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. + + For more details, please read :func:`torch.trapezoid`. The difference between :func:`torch.trapezoid` + and this function is that, :func:`torch.trapezoid` returns a value for each integration, + where as this function returns a cumulative value for every spacing within the integration. This + is analogous to how `.sum` returns a value and `.cumsum` returns a cumulative sum. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Cumulatively computes the trapezoidal rule in 1D, spacing is implicitly 1. 
+ >>> y = torch.tensor([1, 5, 10]) + >>> torch.cumulative_trapezoid(y) + tensor([3., 10.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> (1 + 5) / 2 + 3.0 + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Cumulatively computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.cumulative_trapezoid(y, dx=2) + tensor([6., 21.]) + + >>> # Cumulatively computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([6., 28.5]) + + >>> # Computes the same trapezoidal rule directly up to each element to verify + >>> ((3 - 1) * (1 + 5)) / 2 + 6.0 + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.cumulative_trapezoid(y) + tensor([[ 0.5, 2.], + [ 3.5, 8.], + [ 6.5, 14.]]) + + >>> # Cumulatively computes the trapezoidal rule for each column of the matrix + >>> torch.cumulative_trapezoid(y, dim=0) + tensor([[ 1.5, 2.5, 3.5], + [ 6.0, 8.0, 10.0]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[2., 5.], + [2., 5.], + [2., 5.]]) + + >>> # Cumulatively computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.cumulative_trapezoid(y, x) + tensor([[1., 2.], + [2., 4.], + [3., 6.]]) + """ + +def deg2rad(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + deg2rad(input, *, out=None) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in degrees to radians. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]]) + >>> torch.deg2rad(a) + tensor([[ 3.1416, -3.1416], + [ 6.2832, -6.2832], + [ 1.5708, -1.5708]]) + """ + +def deg2rad_(input: Tensor) -> Tensor: ... +@overload +def dequantize(input: Tensor) -> Tensor: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + +@overload +def dequantize( + tensors: tuple[Tensor, ...] | list[Tensor] | None, +) -> tuple[Tensor, ...]: + r""" + dequantize(tensor) -> Tensor + + Returns an fp32 Tensor by dequantizing a quantized Tensor + + Args: + tensor (Tensor): A quantized Tensor + + .. function:: dequantize(tensors) -> sequence of Tensors + :noindex: + + Given a list of quantized Tensors, dequantize them and return a list of fp32 Tensors + + Args: + tensors (sequence of Tensors): A list of quantized Tensors + """ + +def det(input: Tensor) -> Tensor: + r""" + det(input) -> Tensor + + Alias for :func:`torch.linalg.det` + """ + +def detach(input: Tensor) -> Tensor: ... +def detach_(input: Tensor) -> Tensor: ... 
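+# Editorial note (not generated): cumulative_trapezoid (documented above) relates to
+# torch.trapezoid the way .cumsum relates to .sum, so its last cumulative value equals
+# the full integral. The values reuse the docstring's own 1D example; the deg2rad
+# round trip is a separate small check of the conversion documented above.
+#
+#     y = torch.tensor([1., 5., 10.])
+#     x = torch.tensor([1., 3., 6.])
+#     ct = torch.cumulative_trapezoid(y, x)     # tensor([ 6.0000, 28.5000])
+#     torch.testing.assert_close(ct[-1], torch.trapezoid(y, x))
+#
+#     # deg2rad and rad2deg are inverses up to floating-point rounding:
+#     a = torch.tensor([180.0, -90.0])
+#     torch.testing.assert_close(torch.rad2deg(torch.deg2rad(a)), a)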
+def detach_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.detach`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def diag( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + diag(input, diagonal=0, *, out=None) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a matrix (2-D tensor), then returns a 1-D tensor with + the diagonal elements of :attr:`input`. + + The argument :attr:`diagonal` controls which diagonal to consider: + + - If :attr:`diagonal` = 0, it is the main diagonal. + - If :attr:`diagonal` > 0, it is above the main diagonal. + - If :attr:`diagonal` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.diagonal` always returns the diagonal of its input. + + :func:`torch.diagflat` always constructs a tensor with diagonal elements + specified by the input. + + Examples: + + Get the square matrix where the input vector is the diagonal:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.5950,-0.0872, 2.3298]) + >>> torch.diag(a) + tensor([[ 0.5950, 0.0000, 0.0000], + [ 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 2.3298]]) + >>> torch.diag(a, 1) + tensor([[ 0.0000, 0.5950, 0.0000, 0.0000], + [ 0.0000, 0.0000,-0.0872, 0.0000], + [ 0.0000, 0.0000, 0.0000, 2.3298], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + Get the k-th diagonal of a given matrix:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-0.4264, 0.0255,-0.1064], + [ 0.8795,-0.2429, 0.1374], + [ 0.1029,-0.6482,-1.6300]]) + >>> torch.diag(a, 0) + tensor([-0.4264,-0.2429,-1.6300]) + >>> torch.diag(a, 1) + tensor([ 0.0255, 0.1374]) + """ + +def diag_embed( + input: Tensor, + offset: _int = 0, + dim1: _int = -2, + dim2: _int = -1, +) -> Tensor: + r""" + diag_embed(input, offset=0, dim1=-2, dim2=-1) -> Tensor + + Creates a tensor whose diagonals of certain 2D planes (specified by + :attr:`dim1` and :attr:`dim2`) are filled by :attr:`input`. + To facilitate creating batched diagonal matrices, the 2D planes formed by + the last two dimensions of the returned tensor are chosen by default. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + The size of the new matrix will be calculated to make the specified diagonal + of the size of the last input dimension. + Note that for :attr:`offset` other than :math:`0`, the order of :attr:`dim1` + and :attr:`dim2` matters. Exchanging them is equivalent to changing the + sign of :attr:`offset`. + + Applying :meth:`torch.diagonal` to the output of this function with + the same arguments yields a matrix identical to input. However, + :meth:`torch.diagonal` has different default dimensions, so those + need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 1-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: -2. + dim2 (int, optional): second dimension with respect to which to + take diagonal. 
Default: -1. + + Example:: + + >>> a = torch.randn(2, 3) + >>> torch.diag_embed(a) + tensor([[[ 1.5410, 0.0000, 0.0000], + [ 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -2.1788]], + + [[ 0.5684, 0.0000, 0.0000], + [ 0.0000, -1.0845, 0.0000], + [ 0.0000, 0.0000, -1.3986]]]) + + >>> torch.diag_embed(a, offset=1, dim1=0, dim2=2) + tensor([[[ 0.0000, 1.5410, 0.0000, 0.0000], + [ 0.0000, 0.5684, 0.0000, 0.0000]], + + [[ 0.0000, 0.0000, -0.2934, 0.0000], + [ 0.0000, 0.0000, -1.0845, 0.0000]], + + [[ 0.0000, 0.0000, 0.0000, -2.1788], + [ 0.0000, 0.0000, 0.0000, -1.3986]], + + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]]]) + """ + +def diagflat(input: Tensor, offset: _int = 0) -> Tensor: + r""" + diagflat(input, offset=0) -> Tensor + + - If :attr:`input` is a vector (1-D tensor), then returns a 2-D square tensor + with the elements of :attr:`input` as the diagonal. + - If :attr:`input` is a tensor with more than one dimension, then returns a + 2-D tensor with diagonal elements equal to a flattened :attr:`input`. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. + offset (int, optional): the diagonal to consider. Default: 0 (main + diagonal). + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([-0.2956, -0.9068, 0.1695]) + >>> torch.diagflat(a) + tensor([[-0.2956, 0.0000, 0.0000], + [ 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.1695]]) + >>> torch.diagflat(a, 1) + tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.9068, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.1695], + [ 0.0000, 0.0000, 0.0000, 0.0000]]) + + >>> a = torch.randn(2, 2) + >>> a + tensor([[ 0.2094, -0.3018], + [-0.1516, 1.9342]]) + >>> torch.diagflat(a) + tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], + [ 0.0000, -0.3018, 0.0000, 0.0000], + [ 0.0000, 0.0000, -0.1516, 0.0000], + [ 0.0000, 0.0000, 0.0000, 1.9342]]) + """ + +@overload +def diagonal( + input: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, +) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. 
+ + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + >>> b = torch.randn(2, 5) + >>> b + tensor([[-1.7948, -1.2731, -0.3181, 2.0200, -1.6745], + [ 1.8262, -1.5049, 0.4114, 1.0704, -1.2607]]) + + >>> torch.diagonal(b, 1, 1, 0) + tensor([1.8262]) + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + +@overload +def diagonal( + input: Tensor, + *, + outdim: str | EllipsisType | None, + dim1: str | EllipsisType | None, + dim2: str | EllipsisType | None, + offset: _int = 0, +) -> Tensor: + r""" + diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor + + Returns a partial view of :attr:`input` with the its diagonal elements + with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension + at the end of the shape. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Applying :meth:`torch.diag_embed` to the output of this function with + the same arguments yields a diagonal matrix with the diagonal entries + of the input. However, :meth:`torch.diag_embed` has different default + dimensions, so those need to be explicitly specified. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: To take a batch diagonal, pass in dim1=-2, dim2=-1. + + Examples:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0854, 1.1431, -0.1752], + [ 0.8536, -0.0905, 0.0360], + [ 0.6927, -0.3735, -0.4945]]) + + + >>> torch.diagonal(a) + tensor([-1.0854, -0.0905, -0.4945]) + + + >>> torch.diagonal(a, 1) + tensor([ 1.1431, 0.0360]) + + >>> b = torch.randn(2, 5) + >>> b + tensor([[-1.7948, -1.2731, -0.3181, 2.0200, -1.6745], + [ 1.8262, -1.5049, 0.4114, 1.0704, -1.2607]]) + + >>> torch.diagonal(b, 1, 1, 0) + tensor([1.8262]) + + >>> x = torch.randn(2, 5, 4, 2) + >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2) + tensor([[[-1.2631, 0.3755, -1.5977, -1.8172], + [-1.1065, 1.0401, -0.2235, -0.7938]], + + [[-1.7325, -0.3081, 0.6166, 0.2335], + [ 1.0500, 0.7336, -0.3836, -1.1015]]]) + """ + +def diagonal_copy( + input: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.diagonal`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def diagonal_scatter( + input: Tensor, + src: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, +) -> Tensor: + r""" + diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` along + the diagonal elements of :attr:`input`, with respect to :attr:`dim1` + and :attr:`dim2`. 
+ + This function returns a tensor with fresh storage; it does not + return a view. + + The argument :attr:`offset` controls which diagonal to consider: + + - If :attr:`offset` = 0, it is the main diagonal. + - If :attr:`offset` > 0, it is above the main diagonal. + - If :attr:`offset` < 0, it is below the main diagonal. + + Args: + input (Tensor): the input tensor. Must be at least 2-dimensional. + src (Tensor): the tensor to embed into :attr:`input`. + offset (int, optional): which diagonal to consider. Default: 0 + (main diagonal). + dim1 (int, optional): first dimension with respect to which to + take diagonal. Default: 0. + dim2 (int, optional): second dimension with respect to which to + take diagonal. Default: 1. + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.diagonal(input, offset, dim1, dim2)`` + + Examples:: + + >>> a = torch.zeros(3, 3) + >>> a + tensor([[0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + + >>> torch.diagonal_scatter(a, torch.ones(3), 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + + >>> torch.diagonal_scatter(a, torch.ones(2), 1) + tensor([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) + """ + +def diff( + input: Tensor, + n: _int = 1, + dim: _int = -1, + prepend: Tensor | None = None, + append: Tensor | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + diff(input, n=1, dim=-1, prepend=None, append=None) -> Tensor + + Computes the n-th forward difference along the given dimension. + + The first-order differences are given by `out[i] = input[i + 1] - input[i]`. Higher-order + differences are calculated by using :func:`torch.diff` recursively. + + Args: + input (Tensor): the tensor to compute the differences on + n (int, optional): the number of times to recursively compute the difference + dim (int, optional): the dimension to compute the difference along. + Default is the last dimension. + prepend, append (Tensor, optional): values to prepend or append to + :attr:`input` along :attr:`dim` before computing the difference. + Their dimensions must be equivalent to that of input, and their shapes + must match input's shape except on :attr:`dim`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 3, 2]) + >>> torch.diff(a) + tensor([ 2, -1]) + >>> b = torch.tensor([4, 5]) + >>> torch.diff(a, append=b) + tensor([ 2, -1, 2, 1]) + >>> c = torch.tensor([[1, 2, 3], [3, 4, 5]]) + >>> torch.diff(c, dim=0) + tensor([[2, 2, 2]]) + >>> torch.diff(c, dim=1) + tensor([[1, 1], + [1, 1]]) + """ + +def digamma(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + digamma(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.digamma`. + """ + +def dist(input: Tensor, other: Tensor, p: Number | _complex = 2) -> Tensor: + r""" + dist(input, other, p=2) -> Tensor + + Returns the p-norm of (:attr:`input` - :attr:`other`) + + The shapes of :attr:`input` and :attr:`other` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + other (Tensor): the Right-hand-side input tensor + p (float, optional): the norm to be computed + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([-1.5393, -0.8675, 0.5916, 1.6321]) + >>> y = torch.randn(4) + >>> y + tensor([ 0.0967, -1.0511, 0.6295, 0.8360]) + >>> torch.dist(x, y, 3.5) + tensor(1.6727) + >>> torch.dist(x, y, 3) + tensor(1.6973) + >>> torch.dist(x, y, 0) + tensor(4.) 
+ >>> torch.dist(x, y, 1) + tensor(2.6537) + """ + +def div( + input: Tensor | Number, + other: Tensor | Number, + *, + rounding_mode: str | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + div(input, other, *, rounding_mode=None, out=None) -> Tensor + + Divides each element of the input ``input`` by the corresponding element of + :attr:`other`. + + .. math:: + \text{out}_i = \frac{\text{input}_i}{\text{other}_i} + + .. note:: + By default, this performs a "true" division like Python 3. + See the :attr:`rounding_mode` argument for floor division. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + Always promotes integer types to the default scalar type. + + Args: + input (Tensor): the dividend + other (Tensor or Number): the divisor + + Keyword args: + rounding_mode (str, optional): Type of rounding applied to the result: + + * None - default behavior. Performs no rounding and, if both :attr:`input` and + :attr:`other` are integer types, promotes the inputs to the default scalar type. + Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``. + * ``"trunc"`` - rounds the results of the division towards zero. + Equivalent to C-style integer division. + * ``"floor"`` - rounds the results of the division down. + Equivalent to floor division in Python (the ``//`` operator) and NumPy's ``np.floor_divide``. + + out (Tensor, optional): the output tensor. + + Examples:: + + >>> x = torch.tensor([ 0.3810, 1.2774, -0.2972, -0.3719, 0.4637]) + >>> torch.div(x, 0.5) + tensor([ 0.7620, 2.5548, -0.5944, -0.7438, 0.9274]) + + >>> a = torch.tensor([[-0.3711, -1.9353, -0.4605, -0.2917], + ... [ 0.1815, -1.0111, 0.9805, -1.5923], + ... [ 0.1062, 1.4581, 0.7759, -1.2344], + ... [-0.1830, -0.0313, 1.1908, -1.4757]]) + >>> b = torch.tensor([ 0.8032, 0.2930, -0.8113, -0.2308]) + >>> torch.div(a, b) + tensor([[-0.4620, -6.6051, 0.5676, 1.2639], + [ 0.2260, -3.4509, -1.2086, 6.8990], + [ 0.1322, 4.9764, -0.9564, 5.3484], + [-0.2278, -0.1068, -1.4678, 6.3938]]) + + >>> torch.div(a, b, rounding_mode='trunc') + tensor([[-0., -6., 0., 1.], + [ 0., -3., -1., 6.], + [ 0., 4., -0., 5.], + [-0., -0., -1., 6.]]) + + >>> torch.div(a, b, rounding_mode='floor') + tensor([[-1., -7., 0., 1.], + [ 0., -4., -2., 6.], + [ 0., 4., -1., 5.], + [-1., -1., -2., 6.]]) + """ + +@overload +def divide( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide( + input: Tensor, + other: Tensor, + *, + rounding_mode: str | None, + out: Tensor | None = None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide( + input: Tensor, + other: Number | _complex, + *, + rounding_mode: str | None, +) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +@overload +def divide(input: Tensor, other: Number | _complex) -> Tensor: + r""" + divide(input, other, *, rounding_mode=None, out=None) -> Tensor + + Alias for :func:`torch.div`. + """ + +def dot( + input: Tensor, + tensor: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + dot(input, tensor, *, out=None) -> Tensor + + Computes the dot product of two 1D tensors. + + .. 
note:: + + Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. + tensor (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + + >>> t1, t2 = torch.tensor([0, 1]), torch.tensor([2, 3]) + >>> torch.dot(t1, t2) + tensor(3) + """ + +def dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ... +@overload +def dsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + +@overload +def dsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + dsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with three or more dimensions, into multiple tensors + depthwise according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=2) + (the split dimension is 2), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.dsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(2, 2, 4) + >>> t + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.]], + [[ 8., 9., 10., 11.], + [12., 13., 14., 15.]]]) + >>> torch.dsplit(t, 2) + (tensor([[[ 0., 1.], + [ 4., 5.]], + [[ 8., 9.], + [12., 13.]]]), + tensor([[[ 2., 3.], + [ 6., 7.]], + [[10., 11.], + [14., 15.]]])) + + >>> torch.dsplit(t, [3, 6]) + (tensor([[[ 0., 1., 2.], + [ 4., 5., 6.]], + [[ 8., 9., 10.], + [12., 13., 14.]]]), + tensor([[[ 3.], + [ 7.]], + [[11.], + [15.]]]), + tensor([], size=(2, 2, 0))) + """ + +def dstack( + tensors: tuple[Tensor, ...] 
| list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + dstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence depthwise (along third axis). + + This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by :func:`torch.atleast_3d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.dstack((a,b)) + tensor([[[1, 4], + [2, 5], + [3, 6]]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.dstack((a,b)) + tensor([[[1, 4]], + [[2, 5]], + [[3, 6]]]) + """ + +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: _int | SymInt = -1, + scale_grad_by_freq: _bool = False, + sparse: _bool = False, +) -> Tensor: ... +@overload +def embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool, + mode: _int, + sparse: _bool, + per_sample_weights: Tensor | None, + include_last_offset: _bool, + padding_idx: _int | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def embedding_bag( + weight: Tensor, + indices: Tensor, + offsets: Tensor, + scale_grad_by_freq: _bool = False, + mode: _int = 0, + sparse: _bool = False, + per_sample_weights: Tensor | None = None, + include_last_offset: _bool = False, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +def embedding_renorm_( + input: Tensor, + indices: Tensor, + max_norm: _float, + norm_type: _float, +) -> Tensor: ... +@overload +def empty( + size: Sequence[_int | SymInt], + *, + memory_format: memory_format | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + *size: _int | SymInt, + memory_format: memory_format | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. 
The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +@overload +def empty( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format) -> Tensor + + Returns a tensor filled with uninitialized data. The shape of the tensor is + defined by the variable argument :attr:`size`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + + Example:: + + >>> torch.empty((2,3), dtype=torch.int64) + tensor([[ 9.4064e+13, 2.8000e+01, 9.3493e+13], + [ 7.5751e+18, 7.1428e+18, 7.5955e+18]]) + """ + +def empty_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns an uninitialized tensor with the same size as :attr:`input`. + ``torch.empty_like(input)`` is equivalent to + ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> a=torch.empty((2,3), dtype=torch.int32, device = 'cuda') + >>> torch.empty_like(a) + tensor([[0, 0, 0], + [0, 0, 0]], device='cuda:0', dtype=torch.int32) + """ + +def empty_permuted( + size: Sequence[_int | SymInt], + physical_layout: _size, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates an uninitialized, non-overlapping and dense tensor with the + specified :attr:`size`, with :attr:`physical_layout` specifying how the + dimensions are physically laid out in memory (each logical dimension is listed + from outermost to innermost). 
:attr:`physical_layout` is a generalization + of NCHW/NHWC notation: if each dimension is assigned a number according to + what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)`` + while NHWC is ``(0, 2, 3, 1)``. Equivalently, the strides of the output + tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]`` + (notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``). + + Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense + tensor with no overlaps. If possible, prefer using this function over + :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`. + + .. note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + physical_layout (tuple of int): the ordering of dimensions physically in memory + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Examples: + + >>> torch.empty((2, 3, 5, 7)).stride() + (105, 35, 7, 1) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride() + (105, 35, 7, 1) + >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride() + (105, 1, 21, 3) + >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order() + (0, 2, 3, 1) + """ + +def empty_quantized( + size: _size, + qtensor: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... +def empty_strided( + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + empty_strided(size, stride, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Creates a tensor with the specified :attr:`size` and :attr:`stride` and filled with undefined data. + + .. warning:: + If the constructed tensor is "overlapped" (with multiple indices referring to the same element + in memory) its behavior is undefined. + + .. 
note:: + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, the output tensor is initialized to prevent any possible + nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors + are filled with the maximum value. + + Args: + size (tuple of int): the shape of the output tensor + stride (tuple of int): the strides of the output tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> a = torch.empty_strided((2, 3), (1, 2)) + >>> a + tensor([[8.9683e-44, 4.4842e-44, 5.1239e+07], + [0.0000e+00, 0.0000e+00, 3.0705e-41]]) + >>> a.stride() + (1, 2) + >>> a.size() + torch.Size([2, 3]) + """ + +@overload +def eq( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + +@overload +def eq( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + eq(input, other, *, out=None) -> Tensor + + Computes element-wise equality + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[ True, False], + [False, True]]) + """ + +def equal(input: Tensor, other: Tensor) -> _bool: + r""" + equal(input, other) -> bool + + ``True`` if two tensors have the same size and elements, ``False`` otherwise. + + .. note:: + + Tensors containing NaNs are never equal to each other. Additionally, this function does not + differentiate between the data types of the tensors during comparison. For more thorough tensor checks, + use :meth:`torch.testing.assert_close`. 
+ + Example:: + + >>> torch.equal(torch.tensor([1, 2]), torch.tensor([1, 2])) + True + >>> torch.equal(torch.tensor([3, torch.nan]), torch.tensor([3, torch.nan])) + False + >>> torch.equal(torch.tensor([1, 2, 3], dtype=torch.int32), torch.tensor([1, 2, 3], dtype=torch.float32)) + True + """ + +def erf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erf(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erf`. + """ + +def erf_(input: Tensor) -> Tensor: ... +def erfc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erfc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfc`. + """ + +def erfc_(input: Tensor) -> Tensor: ... +def erfinv(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + erfinv(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.erfinv`. + """ + +def exp(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + exp(input, *, out=None) -> Tensor + + Returns a new tensor with the exponential of the elements + of the input tensor :attr:`input`. + + .. math:: + y_{i} = e^{x_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.exp(torch.tensor([0, math.log(2.)])) + tensor([ 1., 2.]) + """ + +def exp2(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + exp2(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.exp2`. + """ + +def exp2_(input: Tensor) -> Tensor: ... +def exp_(input: Tensor) -> Tensor: ... +def expand_copy( + input: Tensor, + size: Sequence[_int | SymInt], + *, + implicit: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.Tensor.expand`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def expm1(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + expm1(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expm1`. + """ + +def expm1_(input: Tensor) -> Tensor: ... +@overload +def eye( + n: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + +@overload +def eye( + n: _int | SymInt, + m: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows + m (int, optional): the number of columns with default being :attr:`n` + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 2-D tensor with ones on the diagonal and zeros elsewhere + + Example:: + + >>> torch.eye(3) + tensor([[ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]) + """ + +def fake_quantize_per_channel_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + axis: _int, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_channel_affine(input, scale, zero_point, axis, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized per channel using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`, across the channel specified by :attr:`axis`. + + .. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), in ``torch.float32`` + scale (Tensor): quantization scale, per channel in ``torch.float32`` + zero_point (Tensor): quantization zero_point, per channel in ``torch.int32`` or ``torch.half`` or ``torch.float32`` + axis (int32): channel axis + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized per channel ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(2, 2, 2) + >>> x + tensor([[[-0.2525, -0.0466], + [ 0.3491, -0.2168]], + + [[-0.5906, 1.6258], + [ 0.6444, -0.0542]]]) + >>> scales = (torch.randn(2) + 1) * 0.05 + >>> scales + tensor([0.0475, 0.0486]) + >>> zero_points = torch.zeros(2).to(torch.int32) + >>> zero_points + tensor([0, 0]) + >>> torch.fake_quantize_per_channel_affine(x, scales, zero_points, 1, 0, 255) + tensor([[[0.0000, 0.0000], + [0.3405, 0.0000]], + + [[0.0000, 1.6134], + [0.6323, 0.0000]]]) + """ + +@overload +def fake_quantize_per_tensor_affine( + input: Tensor, + scale: _float, + zero_point: _int, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + +@overload +def fake_quantize_per_tensor_affine( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + quant_min: _int, + quant_max: _int, +) -> Tensor: + r""" + fake_quantize_per_tensor_affine(input, scale, zero_point, quant_min, quant_max) -> Tensor + + Returns a new tensor with the data in :attr:`input` fake quantized using :attr:`scale`, + :attr:`zero_point`, :attr:`quant_min` and :attr:`quant_max`. + + .. 
math:: + \text{output} = ( + min( + \text{quant\_max}, + max( + \text{quant\_min}, + \text{std::nearby\_int}(\text{input} / \text{scale}) + \text{zero\_point} + ) + ) - \text{zero\_point} + ) \times \text{scale} + + Args: + input (Tensor): the input value(s), ``torch.float32`` tensor + scale (double scalar or ``float32`` Tensor): quantization scale + zero_point (int64 scalar or ``int32`` Tensor): quantization zero_point + quant_min (int64): lower bound of the quantized domain + quant_max (int64): upper bound of the quantized domain + + Returns: + Tensor: A newly fake_quantized ``torch.float32`` tensor + + Example:: + + >>> x = torch.randn(4) + >>> x + tensor([ 0.0552, 0.9730, 0.3973, -1.0780]) + >>> torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + >>> torch.fake_quantize_per_tensor_affine(x, torch.tensor(0.1), torch.tensor(0), 0, 255) + tensor([0.1000, 1.0000, 0.4000, 0.0000]) + """ + +def fbgemm_linear_fp16_weight( + input: Tensor, + packed_weight: Tensor, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_fp16_weight_fp32_activation( + input: Tensor, + packed_weight: Tensor, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_int8_weight( + input: Tensor, + weight: Tensor, + packed: Tensor, + col_offsets: Tensor, + weight_scale: Number | _complex, + weight_zero_point: Number | _complex, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_int8_weight_fp32_activation( + input: Tensor, + weight: Tensor, + packed: Tensor, + col_offsets: Tensor, + weight_scale: Number | _complex, + weight_zero_point: Number | _complex, + bias: Tensor, +) -> Tensor: ... +def fbgemm_linear_quantize_weight( + input: Tensor, +) -> tuple[Tensor, Tensor, _float, _int]: ... +def fbgemm_pack_gemm_matrix_fp16(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor) -> Tensor: ... +@overload +def fbgemm_pack_quantized_matrix(input: Tensor, K: _int, N: _int) -> Tensor: ... +def feature_alpha_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_alpha_dropout_( + input: Tensor, + p: _float, + train: _bool, +) -> Tensor: ... +def feature_dropout(input: Tensor, p: _float, train: _bool) -> Tensor: ... +def feature_dropout_(input: Tensor, p: _float, train: _bool) -> Tensor: ... +@overload +def fill(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill(input: Tensor, value: Number | _complex) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Tensor) -> Tensor: ... +@overload +def fill_(input: Tensor, value: Number | _complex) -> Tensor: ... +def fix(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + fix(input, *, out=None) -> Tensor + + Alias for :func:`torch.trunc` + """ + +def fix_(input: Tensor) -> Tensor: ... +@overload +def flatten( + input: Tensor, + start_dim: _int = 0, + end_dim: _int = -1, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. 
Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + start_dim: _int, + end_dim: _int, + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + start_dim: str | EllipsisType | None, + end_dim: str | EllipsisType | None, + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... 
[7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +@overload +def flatten( + input: Tensor, + dims: Sequence[str | EllipsisType | None], + out_dim: str | EllipsisType | None, +) -> Tensor: + r""" + flatten(input, start_dim=0, end_dim=-1) -> Tensor + + Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim` + are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened. + The order of elements in :attr:`input` is unchanged. + + Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view, + or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can + be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the + flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Args: + input (Tensor): the input tensor. + start_dim (int): the first dim to flatten + end_dim (int): the last dim to flatten + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.flatten(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + >>> torch.flatten(t, start_dim=1) + tensor([[1, 2, 3, 4], + [5, 6, 7, 8]]) + """ + +def flip(input: Tensor, dims: _size) -> Tensor: + r""" + flip(input, dims) -> Tensor + + Reverse the order of an n-D tensor along given axis in dims. + + .. note:: + `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flip` is expected to be slower than `np.flip`. + + Args: + input (Tensor): the input tensor. + dims (a list or tuple): axis to flip on + + Example:: + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[ 0, 1], + [ 2, 3]], + + [[ 4, 5], + [ 6, 7]]]) + >>> torch.flip(x, [0, 1]) + tensor([[[ 6, 7], + [ 4, 5]], + + [[ 2, 3], + [ 0, 1]]]) + """ + +def fliplr(input: Tensor) -> Tensor: + r""" + fliplr(input) -> Tensor + + Flip tensor in the left/right direction, returning a new tensor. + + Flip the entries in each row in the left/right direction. + Columns are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 2-D. + + .. note:: + `torch.fliplr` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.fliplr`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.fliplr` is expected to be slower than `np.fliplr`. + + Args: + input (Tensor): Must be at least 2-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.fliplr(x) + tensor([[1, 0], + [3, 2]]) + """ + +def flipud(input: Tensor) -> Tensor: + r""" + flipud(input) -> Tensor + + Flip tensor in the up/down direction, returning a new tensor. + + Flip the entries in each column in the up/down direction. + Rows are preserved, but appear in a different order than before. + + Note: + Requires the tensor to be at least 1-D. + + .. note:: + `torch.flipud` makes a copy of :attr:`input`'s data. 
This is different from NumPy's `np.flipud`, + which returns a view in constant time. Since copying a tensor's data is more work than viewing that data, + `torch.flipud` is expected to be slower than `np.flipud`. + + Args: + input (Tensor): Must be at least 1-dimensional. + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.flipud(x) + tensor([[2, 3], + [0, 1]]) + """ + +@overload +def float_power( + input: Tensor, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +@overload +def float_power( + self: Number | _complex, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +@overload +def float_power( + input: Tensor, + exponent: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + float_power(input, exponent, *, out=None) -> Tensor + + Raises :attr:`input` to the power of :attr:`exponent`, elementwise, in double precision. + If neither input is complex returns a ``torch.float64`` tensor, + and if one or more inputs is complex returns a ``torch.complex128`` tensor. + + .. note:: + This function always computes in double precision, unlike :func:`torch.pow`, + which implements more typical :ref:`type promotion `. + This is useful when the computation needs to be performed in a wider or more precise dtype, + or the results of the computation may contain fractional values not representable in the input dtypes, + like when an integer base is raised to a negative integer exponent. + + Args: + input (Tensor or Number): the base value(s) + exponent (Tensor or Number): the exponent value(s) + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randint(10, (4,)) + >>> a + tensor([6, 4, 7, 1]) + >>> torch.float_power(a, 2) + tensor([36., 16., 49., 1.], dtype=torch.float64) + + >>> a = torch.arange(1, 5) + >>> a + tensor([ 1, 2, 3, 4]) + >>> exp = torch.tensor([2, -3, 4, -5]) + >>> exp + tensor([ 2, -3, 4, -5]) + >>> torch.float_power(a, exp) + tensor([1.0000e+00, 1.2500e-01, 8.1000e+01, 9.7656e-04], dtype=torch.float64) + """ + +def floor(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + floor(input, *, out=None) -> Tensor + + Returns a new tensor with the floor of the elements of :attr:`input`, + the largest integer less than or equal to each element. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + .. math:: + \text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.8166, 1.5308, -0.2530, -0.2091]) + >>> torch.floor(a) + tensor([-1., 1., -1., -1.]) + """ + +def floor_(input: Tensor) -> Tensor: ... +def floor_divide( + input: Tensor | Number, + other: Tensor | Number, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + floor_divide(input, other, *, out=None) -> Tensor + + .. note:: + + Before PyTorch 1.13 :func:`torch.floor_divide` incorrectly performed + truncation division. To restore the previous behavior use + :func:`torch.div` with ``rounding_mode='trunc'``. + + Computes :attr:`input` divided by :attr:`other`, elementwise, and floors + the result. + + .. math:: + \text{{out}}_i = \text{floor} \left( \frac{{\text{{input}}_i}}{{\text{{other}}_i}} \right) + + + + Supports broadcasting to a common shape, type promotion, and integer and float inputs. + + Args: + input (Tensor or Number): the dividend + other (Tensor or Number): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([4.0, 3.0]) + >>> b = torch.tensor([2.0, 2.0]) + >>> torch.floor_divide(a, b) + tensor([2.0, 1.0]) + >>> torch.floor_divide(a, 1.4) + tensor([2.0, 2.0]) + """ + +def fmax( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmax(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.maximum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([9.7, float('nan'), 3.1, float('nan')]) + >>> b = torch.tensor([-2.2, 0.5, float('nan'), float('nan')]) + >>> torch.fmax(a, b) + tensor([9.7000, 0.5000, 3.1000, nan]) + """ + +def fmin( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmin(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + This is like :func:`torch.minimum` except it handles NaNs differently: + if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum. + Only if both elements are NaN is NaN propagated. + + This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function. + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and floating-point inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([2.2, float('nan'), 2.1, float('nan')]) + >>> b = torch.tensor([-9.3, 0.1, float('nan'), float('nan')]) + >>> torch.fmin(a, b) + tensor([-9.3000, 0.1000, 2.1000, nan]) + """ + +@overload +def fmod( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. 
+ + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + +@overload +def fmod( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + fmod(input, other, *, out=None) -> Tensor + + Applies C++'s `std::fmod `_ entrywise. + The result has the same sign as the dividend :attr:`input` and its absolute value + is less than that of :attr:`other`. + + This function may be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + + When the divisor is zero, returns ``NaN`` for floating point dtypes + on both CPU and GPU; raises ``RuntimeError`` for integer division by + zero on CPU; Integer division by zero on GPU may return any value. + + .. note:: + + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + + .. seealso:: + + :func:`torch.remainder` which implements Python's modulus operator. + This one is defined using division rounding down the result. + + Args: + input (Tensor): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.fmod(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([-1., -0., -1., 1., 0., 1.]) + >>> torch.fmod(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([1.0000, 0.5000, 0.0000, 1.0000, 0.5000]) + """ + +def frac(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + frac(input, *, out=None) -> Tensor + + Computes the fractional portion of each element in :attr:`input`. + + .. math:: + \text{out}_{i} = \text{input}_{i} - \left\lfloor |\text{input}_{i}| \right\rfloor * \operatorname{sgn}(\text{input}_{i}) + + Example:: + + >>> torch.frac(torch.tensor([1, 2.5, -3.2])) + tensor([ 0.0000, 0.5000, -0.2000]) + """ + +def frac_(input: Tensor) -> Tensor: ... +def frexp( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.frexp: + r""" + frexp(input, *, out=None) -> (Tensor mantissa, Tensor exponent) + + Decomposes :attr:`input` into mantissa and exponent tensors + such that :math:`\text{input} = \text{mantissa} \times 2^{\text{exponent}}`. + + The range of mantissa is the open interval (-1, 1). + + Supports float inputs. + + Args: + input (Tensor): the input tensor + + + Keyword args: + out (tuple, optional): the output tensors + + Example:: + + >>> x = torch.arange(9.) + >>> mantissa, exponent = torch.frexp(x) + >>> mantissa + tensor([0.0000, 0.5000, 0.5000, 0.7500, 0.5000, 0.6250, 0.7500, 0.8750, 0.5000]) + >>> exponent + tensor([0, 1, 2, 2, 3, 3, 3, 3, 4], dtype=torch.int32) + >>> torch.ldexp(mantissa, exponent) + tensor([0., 1., 2., 3., 4., 5., 6., 7., 8.]) + """ + +def frobenius_norm( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... 
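+# A rough usage sketch for the undocumented ``frobenius_norm`` stub above.
+# This is an assumption based on the legacy ``torch.norm(x, p="fro")`` behaviour,
+# not something stated by the stub: for a real-valued tensor ``x``,
+#
+#     >>> x = torch.arange(4.).reshape(2, 2)
+#     >>> torch.frobenius_norm(x, dim=(0, 1))  # sqrt(0**2 + 1**2 + 2**2 + 3**2)
+#     tensor(3.7417)
+#
+# On current releases ``torch.linalg.norm(x, ord="fro")`` is the documented API.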
+def from_file( + filename: str, + shared: _bool | None = None, + size: _int | None = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + from_file(filename, shared=None, size=0, *, dtype=None, layout=None, device=None, pin_memory=False) + + Creates a CPU tensor with a storage backed by a memory-mapped file. + + If ``shared`` is True, then memory is shared between processes. All changes are written to the file. + If ``shared`` is False, then changes to the tensor do not affect the file. + + ``size`` is the number of elements in the Tensor. If ``shared`` is ``False``, then the file must contain + at least ``size * sizeof(dtype)`` bytes. If ``shared`` is ``True`` the file will be created if needed. + + .. note:: + Only CPU tensors can be mapped to files. + + .. note:: + For now, tensors with storages backed by a memory-mapped file cannot be created in pinned memory. + + + Args: + filename (str): file name to map + shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the + underlying `mmap(2) call `_) + size (int): number of elements in the tensor + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> t = torch.randn(2, 5, dtype=torch.float64) + >>> t.numpy().tofile('storage.pt') + >>> t_mapped = torch.from_file('storage.pt', shared=False, size=10, dtype=torch.float64) + """ + +def from_numpy(ndarray) -> Tensor: + r""" + from_numpy(ndarray) -> Tensor + + Creates a :class:`Tensor` from a :class:`numpy.ndarray`. + + The returned tensor and :attr:`ndarray` share the same memory. Modifications to + the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned + tensor is not resizable. + + It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, + ``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``, + ``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, + and ``bool``. + + .. warning:: + Writing to a tensor created from a read-only NumPy array is not supported and will result in undefined behavior. + + Example:: + + >>> a = numpy.array([1, 2, 3]) + >>> t = torch.from_numpy(a) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + """ + +def frombuffer( + buffer: Any, + *, + dtype: _dtype, + count: int = -1, + offset: int = 0, + requires_grad: _bool = False, +) -> Tensor: + r""" + frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False) -> Tensor + + Creates a 1-dimensional :class:`Tensor` from an object that implements + the Python buffer protocol. 
+ + Skips the first :attr:`offset` bytes in the buffer, and interprets the rest of + the raw bytes as a 1-dimensional tensor of type :attr:`dtype` with :attr:`count` + elements. + + Note that either of the following must be true: + + 1. :attr:`count` is a positive non-zero number, and the total number of bytes + in the buffer is more than :attr:`offset` plus :attr:`count` times the size + (in bytes) of :attr:`dtype`. + + 2. :attr:`count` is negative, and the length (number of bytes) of the buffer + subtracted by the :attr:`offset` is a multiple of the size (in bytes) of + :attr:`dtype`. + + The returned tensor and buffer share the same memory. Modifications to + the tensor will be reflected in the buffer and vice versa. The returned + tensor is not resizable. + + .. note:: + This function increments the reference count for the object that + owns the shared memory. Therefore, such memory will not be deallocated + before the returned tensor goes out of scope. + + .. warning:: + This function's behavior is undefined when passed an object implementing + the buffer protocol whose data is not on the CPU. Doing so is likely to + cause a segmentation fault. + + .. warning:: + This function does not try to infer the :attr:`dtype` (hence, it is not + optional). Passing a different :attr:`dtype` than its source may result + in unexpected behavior. + + Args: + buffer (object): a Python object that exposes the buffer interface. + + Keyword args: + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + count (int, optional): the number of desired elements to be read. + If negative, all the elements (until the end of the buffer) will be + read. Default: -1. + offset (int, optional): the number of bytes to skip at the start of + the buffer. Default: 0. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> import array + >>> a = array.array('i', [1, 2, 3]) + >>> t = torch.frombuffer(a, dtype=torch.int32) + >>> t + tensor([ 1, 2, 3]) + >>> t[0] = -1 + >>> a + array([-1, 2, 3]) + + >>> # Interprets the signed char bytes as 32-bit integers. + >>> # Each 4 signed char elements will be interpreted as + >>> # 1 signed 32-bit integer. + >>> import array + >>> a = array.array('b', [-1, 0, 0, 0]) + >>> torch.frombuffer(a, dtype=torch.int32) + tensor([255], dtype=torch.int32) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + out: Tensor | None = None, + layout: _layout = strided, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + names: list[str | None], + layout: _layout = strided, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: Sequence[_int | SymInt], + fill_value: Number | _complex, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +@overload +def full( + size: _size, + fill_value: Number | _complex, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a tensor of size :attr:`size` filled with :attr:`fill_value`. The + tensor's dtype is inferred from :attr:`fill_value`. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.full((2, 3), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416]]) + """ + +def full_like( + input: Tensor, + fill_value: Number | _complex, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + full_like(input, fill_value, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`. + ``torch.full_like(input, fill_value)`` is equivalent to + ``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + fill_value: the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
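+
+    A minimal illustration (output shown assuming the default print options)::
+
+        >>> a = torch.empty(2, 3, dtype=torch.int64)
+        >>> torch.full_like(a, 7)
+        tensor([[7, 7, 7],
+                [7, 7, 7]])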
+ """ + +def fused_moving_avg_obs_fake_quant( + input: Tensor, + observer_on: Tensor, + fake_quant_on: Tensor, + running_min: Tensor, + running_max: Tensor, + scale: Tensor, + zero_point: Tensor, + averaging_const: _float, + quant_min: _int, + quant_max: _int, + ch_axis: _int, + per_row_fake_quant: _bool = False, + symmetric_quant: _bool = False, +) -> Tensor: ... +@overload +def gather( + input: Tensor, + dim: _int, + index: Tensor, + *, + sparse_grad: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + +@overload +def gather( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + *, + sparse_grad: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor + + Gathers values along an axis specified by `dim`. + + For a 3-D tensor the output is specified by:: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + :attr:`input` and :attr:`index` must have the same number of dimensions. + It is also required that ``index.size(d) <= input.size(d)`` for all + dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. + Note that ``input`` and ``index`` do not broadcast against each other. + + Args: + input (Tensor): the source tensor + dim (int): the axis along which to index + index (LongTensor): the indices of elements to gather + + Keyword arguments: + sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor + + Example:: + + >>> t = torch.tensor([[1, 2], [3, 4]]) + >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) + tensor([[ 1, 1], + [ 4, 3]]) + """ + +def gcd( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gcd(input, other, *, out=None) -> Tensor + + Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`gcd(0, 0) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.gcd(a, b) + tensor([1, 2, 5]) + >>> c = torch.tensor([3]) + >>> torch.gcd(a, c) + tensor([1, 1, 3]) + """ + +def gcd_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def ge( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + +@overload +def ge( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ge(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \geq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than or equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ge(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, True], [False, True]]) + """ + +def geqrf( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.geqrf: + r""" + geqrf(input, *, out=None) -> (Tensor, Tensor) + + This is a low-level function for calling LAPACK's geqrf directly. This function + returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . + + Computes a QR decomposition of :attr:`input`. + Both `Q` and `R` matrices are stored in the same output tensor `a`. + The elements of `R` are stored on and above the diagonal. + Elementary reflectors (or Householder vectors) implicitly defining matrix `Q` + are stored below the diagonal. + The results of this function can be used together with :func:`torch.linalg.householder_product` + to obtain the `Q` matrix or + with :func:`torch.ormqr`, which uses an implicit representation of the `Q` matrix, + for an efficient matrix-matrix multiplication. + + See `LAPACK documentation for geqrf`_ for further details. + + .. note:: + See also :func:`torch.linalg.qr`, which computes Q and R matrices, and :func:`torch.linalg.lstsq` + with the ``driver="gels"`` option for a function that can solve matrix equations using a QR decomposition. + + Args: + input (Tensor): the input matrix + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, Tensor). Ignored if `None`. Default: `None`. + + .. _LAPACK documentation for geqrf: + http://www.netlib.org/lapack/explore-html/df/dc5/group__variants_g_ecomputational_ga3766ea903391b5cf9008132f7440ec7b.html + """ + +def ger( + input: Tensor, + vec2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ger(input, vec2, *, out=None) -> Tensor + + Alias of :func:`torch.outer`. + + .. 
warning:: + This function is deprecated and will be removed in a future PyTorch release. + Use :func:`torch.outer` instead. + """ + +def get_default_dtype() -> _dtype: + r""" + get_default_dtype() -> torch.dtype + + Get the current default floating point :class:`torch.dtype`. + + Example:: + + >>> torch.get_default_dtype() # initial default for floating point is torch.float32 + torch.float32 + >>> torch.set_default_dtype(torch.float64) + >>> torch.get_default_dtype() # default is now changed to torch.float64 + torch.float64 + """ + +def get_num_interop_threads() -> _int: + r""" + get_num_interop_threads() -> int + + Returns the number of threads used for inter-op parallelism on CPU + (e.g. in JIT interpreter) + """ + +def get_num_threads() -> _int: + r""" + get_num_threads() -> int + + Returns the number of threads used for parallelizing CPU operations + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Number | _complex | None = None, + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. 
+ + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. 
+ >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Sequence[Number | _complex], + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. 
For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. 
+ >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Sequence[Number | _complex], + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. 
For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: tuple[Tensor, ...] 
| list[Tensor] | None, + dim: _int | None = None, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. 
Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: Number | _complex, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. 
By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. 
+ + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + spacing: tuple[Tensor, ...] | list[Tensor] | None, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. 
For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. 
Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. + >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def gradient( + input: Tensor, + *, + dim: _size, + edge_order: _int = 1, +) -> tuple[Tensor, ...]: + r""" + gradient(input, *, spacing=1, dim=None, edge_order=1) -> List of Tensors + + Estimates the gradient of a function :math:`g : \mathbb{R}^n \rightarrow \mathbb{R}` in + one or more dimensions using the `second-order accurate central differences method + `_ and + either first or second order estimates at the boundaries. + + The gradient of :math:`g` is estimated using samples. By default, when :attr:`spacing` is not + specified, the samples are entirely described by :attr:`input`, and the mapping of input coordinates + to an output is the same as the tensor's mapping of indices to values. For example, for a three-dimensional + :attr:`input` the function described is :math:`g : \mathbb{R}^3 \rightarrow \mathbb{R}`, and + :math:`g(1, 2, 3)\ == input[1, 2, 3]`. + + When :attr:`spacing` is specified, it modifies the relationship between :attr:`input` and input coordinates. + This is detailed in the "Keyword Arguments" section below. + + The gradient is estimated by estimating each partial derivative of :math:`g` independently. 
This estimation is + accurate if :math:`g` is in :math:`C^3` (it has at least 3 continuous derivatives), and the estimation can be + improved by providing closer samples. Mathematically, the value at each interior point of a partial derivative + is estimated using `Taylor's theorem with remainder `_. + Letting :math:`x` be an interior point with :math:`x-h_l` and :math:`x+h_r` be points neighboring + it to the left and right respectively, :math:`f(x+h_r)` and :math:`f(x-h_l)` can be estimated using: + + .. math:: + \begin{aligned} + f(x+h_r) = f(x) + h_r f'(x) + {h_r}^2 \frac{f''(x)}{2} + {h_r}^3 \frac{f'''(\xi_1)}{6}, \xi_1 \in (x, x+h_r) \\ + f(x-h_l) = f(x) - h_l f'(x) + {h_l}^2 \frac{f''(x)}{2} - {h_l}^3 \frac{f'''(\xi_2)}{6}, \xi_2 \in (x, x-h_l) \\ + \end{aligned} + + Using the fact that :math:`f \in C^3` and solving the linear system, we derive: + + .. math:: + f'(x) \approx \frac{ {h_l}^2 f(x+h_r) - {h_r}^2 f(x-h_l) + + ({h_r}^2-{h_l}^2 ) f(x) }{ {h_r} {h_l}^2 + {h_r}^2 {h_l} } + + .. note:: + We estimate the gradient of functions in complex domain + :math:`g : \mathbb{C}^n \rightarrow \mathbb{C}` in the same way. + + The value of each partial derivative at the boundary points is computed differently. See edge_order below. + + Args: + input (``Tensor``): the tensor that represents the values of the function + + Keyword args: + spacing (``scalar``, ``list of scalar``, ``list of Tensor``, optional): :attr:`spacing` can be used to modify + how the :attr:`input` tensor's indices relate to sample coordinates. If :attr:`spacing` is a scalar then + the indices are multiplied by the scalar to produce the coordinates. For example, if :attr:`spacing=2` the + indices (1, 2, 3) become coordinates (2, 4, 6). If :attr:`spacing` is a list of scalars then the corresponding + indices are multiplied. For example, if :attr:`spacing=(2, -1, 3)` the indices (1, 2, 3) become coordinates (2, -2, 9). + Finally, if :attr:`spacing` is a list of one-dimensional tensors then each tensor specifies the coordinates for + the corresponding dimension. For example, if the indices are (1, 2, 3) and the tensors are (t0, t1, t2), then + the coordinates are (t0[1], t1[2], t2[3]) + + dim (``int``, ``list of int``, optional): the dimension or dimensions to approximate the gradient over. By default + the partial gradient in every dimension is computed. Note that when :attr:`dim` is specified the elements of + the :attr:`spacing` argument must correspond with the specified dims." + + edge_order (``int``, optional): 1 or 2, for `first-order + `_ or + `second-order `_ + estimation of the boundary ("edge") values, respectively. + + Examples:: + + >>> # Estimates the gradient of f(x)=x^2 at points [-2, -1, 2, 4] + >>> coordinates = (torch.tensor([-2., -1., 1., 4.]),) + >>> values = torch.tensor([4., 1., 1., 16.], ) + >>> torch.gradient(values, spacing = coordinates) + (tensor([-3., -2., 2., 5.]),) + + >>> # Estimates the gradient of the R^2 -> R function whose samples are + >>> # described by the tensor t. Implicit coordinates are [0, 1] for the outermost + >>> # dimension and [0, 1, 2, 3] for the innermost dimension, and function estimates + >>> # partial derivative for both dimensions. 
+ >>> t = torch.tensor([[1, 2, 4, 8], [10, 20, 40, 80]]) + >>> torch.gradient(t) + (tensor([[ 9., 18., 36., 72.], + [ 9., 18., 36., 72.]]), + tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]])) + + >>> # A scalar value for spacing modifies the relationship between tensor indices + >>> # and input coordinates by multiplying the indices to find the + >>> # coordinates. For example, below the indices of the innermost + >>> # 0, 1, 2, 3 translate to coordinates of [0, 2, 4, 6], and the indices of + >>> # the outermost dimension 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = 2.0) # dim = None (implicitly [0, 1]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.5000, 0.7500, 1.5000, 2.0000], + [ 5.0000, 7.5000, 15.0000, 20.0000]])) + >>> # doubling the spacing between samples halves the estimated partial gradients. + + >>> + >>> # Estimates only the partial derivative for dimension 1 + >>> torch.gradient(t, dim = 1) # spacing = None (implicitly 1.) + (tensor([[ 1.0000, 1.5000, 3.0000, 4.0000], + [10.0000, 15.0000, 30.0000, 40.0000]]),) + + >>> # When spacing is a list of scalars, the relationship between the tensor + >>> # indices and input coordinates changes based on dimension. + >>> # For example, below, the indices of the innermost dimension 0, 1, 2, 3 translate + >>> # to coordinates of [0, 3, 6, 9], and the indices of the outermost dimension + >>> # 0, 1 translate to coordinates of [0, 2]. + >>> torch.gradient(t, spacing = [3., 2.]) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + + >>> # The following example is a replication of the previous one with explicit + >>> # coordinates. + >>> coords = (torch.tensor([0, 2]), torch.tensor([0, 3, 6, 9])) + >>> torch.gradient(t, spacing = coords) + (tensor([[ 4.5000, 9.0000, 18.0000, 36.0000], + [ 4.5000, 9.0000, 18.0000, 36.0000]]), + tensor([[ 0.3333, 0.5000, 1.0000, 1.3333], + [ 3.3333, 5.0000, 10.0000, 13.3333]])) + """ + +@overload +def greater( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + +@overload +def greater( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.gt`. + """ + +@overload +def greater_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + +@overload +def greater_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + greater_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ge`. + """ + +def grid_sampler( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def grid_sampler_2d( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... +def grid_sampler_3d( + input: Tensor, + grid: Tensor, + interpolation_mode: _int, + padding_mode: _int, + align_corners: _bool, +) -> Tensor: ... 
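+# Editorial note (not emitted by tools/pyi/gen_pyi.py): the ``grid_sampler``,
+# ``grid_sampler_2d``, and ``grid_sampler_3d`` stubs above are low-level entry points
+# that take integer codes for the interpolation and padding modes; typical user code
+# reaches them through ``torch.nn.functional.grid_sample``. A minimal sketch of that
+# public path, assuming a 4-D NCHW input and an identity sampling grid built with
+# ``F.affine_grid``:
+#
+#     >>> import torch
+#     >>> import torch.nn.functional as F
+#     >>> img = torch.arange(16.).reshape(1, 1, 4, 4)           # (N, C, H, W)
+#     >>> theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity affine transform
+#     >>> grid = F.affine_grid(theta, size=(1, 1, 4, 4), align_corners=False)
+#     >>> out = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
+#     >>> torch.allclose(out, img)   # an identity grid reproduces the input
+#     True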
+def group_norm( + input: Tensor, + num_groups: _int, + weight: Tensor | None = None, + bias: Tensor | None = None, + eps: _float = 1e-05, + cudnn_enabled: _bool = True, +) -> Tensor: ... +@overload +def gru( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def gru( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def gru_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +@overload +def gt( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + +@overload +def gt( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + gt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} > \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is greater than :attr:`other` and False elsewhere + + Example:: + + >>> torch.gt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [False, False]]) + """ + +@overload +def hamming_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. 
note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + alpha: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hamming_window( + window_length: _int, + periodic: _bool, + alpha: _float, + beta: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hamming window function. + + .. 
math:: + w[n] = \alpha - \beta\ \cos \left( \frac{2 \pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hamming_window(L, periodic=True)`` equal to + ``torch.hamming_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + .. note:: + This is a generalized version of :meth:`torch.hann_window`. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window. + """ + +@overload +def hann_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. 
If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +@overload +def hann_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + hann_window(window_length, periodic=True, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Hann window function. + + .. math:: + w[n] = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{N - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{N - 1} \right), + + where :math:`N` is the full window size. + + The input :attr:`window_length` is a positive integer controlling the + returned window size. :attr:`periodic` flag determines whether the returned + window trims off the last duplicate value from the symmetric window and is + ready to be used as a periodic window with functions like + :meth:`torch.stft`. Therefore, if :attr:`periodic` is true, the :math:`N` in + above formula is in fact :math:`\text{window\_length} + 1`. Also, we always have + ``torch.hann_window(L, periodic=True)`` equal to + ``torch.hann_window(L + 1, periodic=False)[:-1])``. + + .. note:: + If :attr:`window_length` :math:`=1`, the returned window contains a single value 1. + + Arguments: + window_length (int): the size of returned window + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). Only floating point types are supported. + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Returns: + Tensor: A 1-D tensor of size :math:`(\text{window\_length},)` containing the window + """ + +def hardshrink( + input: Tensor, + lambd: Number | _complex = 0.5, + *, + out: Tensor | None = None, +) -> Tensor: ... 
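+# NOTE (illustrative addition, not generated): hardshrink has no docstring in this
+# stub. It applies hard shrinkage elementwise, zeroing entries whose magnitude does
+# not exceed lambd; see torch.nn.functional.hardshrink for the documented wrapper.
+#
+#     x = torch.tensor([-1.0, -0.25, 0.0, 0.3, 2.0])
+#     torch.hardshrink(x, lambd=0.5)   # -> tensor([-1., 0., 0., 0., 2.])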
+def heaviside( + input: Tensor, + values: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + heaviside(input, values, *, out=None) -> Tensor + + Computes the Heaviside step function for each element in :attr:`input`. + The Heaviside step function is defined as: + + .. math:: + \text{{heaviside}}(input, values) = \begin{cases} + 0, & \text{if input < 0}\\ + values, & \text{if input == 0}\\ + 1, & \text{if input > 0} + \end{cases} + + + Args: + input (Tensor): the input tensor. + values (Tensor): The values to use where :attr:`input` is zero. + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> input = torch.tensor([-1.5, 0, 2.0]) + >>> values = torch.tensor([0.5]) + >>> torch.heaviside(input, values) + tensor([0.0000, 0.5000, 1.0000]) + >>> values = torch.tensor([1.2, -2.0, 3.5]) + >>> torch.heaviside(input, values) + tensor([0., -2., 1.]) + """ + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: _float = 1.0, + reduction: _int = 1, +) -> Tensor: ... +def histc( + input: Tensor, + bins: _int = 100, + min: Number | _complex = 0, + max: Number | _complex = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor + + Computes the histogram of a tensor. + + The elements are sorted into equal width bins between :attr:`min` and + :attr:`max`. If :attr:`min` and :attr:`max` are both zero, the minimum and + maximum values of the data are used. + + Elements lower than min and higher than max and ``NaN`` elements are ignored. + + Args: + input (Tensor): the input tensor. + bins (int): number of histogram bins + min (Scalar): lower end of the range (inclusive) + max (Scalar): upper end of the range (inclusive) + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: Histogram represented as a tensor + + Example:: + + >>> torch.histc(torch.tensor([1., 2, 1]), bins=4, min=0, max=3) + tensor([ 0., 2., 1., 0.]) + """ + +@overload +def histogram( + input: Tensor, + bins: Tensor, + *, + weight: Tensor | None = None, + density: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. 
+ If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + +@overload +def histogram( + input: Tensor, + bins: _int = 100, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor) + + Computes a histogram of the values in a tensor. + + :attr:`bins` can be an integer or a 1D tensor. + + If :attr:`bins` is an int, it specifies the number of equal-width bins. + By default, the lower and upper range of the bins is determined by the + minimum and maximum elements of the input tensor. The :attr:`range` + argument can be provided to specify a range for the bins. + + If :attr:`bins` is a 1D tensor, it specifies the sequence of bin edges + including the rightmost edge. It should contain at least 2 elements + and its elements should be increasing. + + Args: + input (Tensor): the input tensor. + bins: int or 1D Tensor. If int, defines the number of equal-width bins. If tensor, + defines the sequence of bin edges including the rightmost edge. + + Keyword args: + range (tuple of float): Defines the range of the bins. + weight (Tensor): If provided, weight should have the same shape as input. Each value in + input contributes its associated weight towards its bin's result. + density (bool): If False, the result will contain the count (or total weight) in each bin. + If True, the result is the value of the probability density function over the bins, + normalized such that the integral over the range of the bins is 1. + out (Tensor, optional): the output tensor. (tuple, optional): The result tuple of two output tensors (hist, bin_edges). + + Returns: + hist (Tensor): 1D Tensor containing the values of the histogram. + bin_edges(Tensor): 1D Tensor containing the edges of the histogram bins. + + Example:: + + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.])) + (tensor([ 0., 5., 2., 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + >>> torch.histogram(torch.tensor([1., 2, 1]), bins=4, range=(0., 3.), weight=torch.tensor([1., 2., 4.]), density=True) + (tensor([ 0., 0.9524, 0.3810, 0.]), tensor([0., 0.75, 1.5, 2.25, 3.])) + """ + +@overload +def histogramdd( + input: Tensor, + bins: _int, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. 
+ + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. 
+ bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +@overload +def histogramdd( + input: Tensor, + bins: _size, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. + + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. 
+ While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +@overload +def histogramdd( + input: Tensor, + bins: tuple[Tensor, ...] | list[Tensor] | None, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, +) -> torch.return_types.histogramdd: + r""" + histogramdd(input, bins, *, range=None, weight=None, density=False, out=None) -> (Tensor, Tensor[]) + + Computes a multi-dimensional histogram of the values in a tensor. + + Interprets the elements of an input tensor whose innermost dimension has size N + as a collection of N-dimensional points. Maps each of the points into a set of + N-dimensional bins and returns the number of points (or total weight) in each bin. + + :attr:`input` must be a tensor with at least 2 dimensions. + If input has shape (M, N), each of its M rows defines a point in N-dimensional space. + If input has three or more dimensions, all but the last dimension are flattened. + + Each dimension is independently associated with its own strictly increasing sequence + of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D + tensors. Alternatively, bin edges may be constructed automatically by passing a + sequence of integers specifying the number of equal-width bins in each dimension. 
+ + For each N-dimensional point in input: + - Each of its coordinates is binned independently among the bin edges + corresponding to its dimension + - Binning results are combined to identify the N-dimensional bin (if any) + into which the point falls + - If the point falls into a bin, the bin's count (or total weight) is incremented + - Points which do not fall into any bin do not contribute to the output + + :attr:`bins` can be a sequence of N 1D tensors, a sequence of N ints, or a single int. + + If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences + of bin edges. Each 1D tensor should contain a strictly increasing sequence with at + least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying + the left and right edges of all bins. Every bin is exclusive of its left edge. Only + the rightmost bin is inclusive of its right edge. + + If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins + in each dimension. By default, the leftmost and rightmost bin edges in each dimension + are determined by the minimum and maximum elements of the input tensor in the + corresponding dimension. The :attr:`range` argument can be provided to manually + specify the leftmost and rightmost bin edges in each dimension. + + If :attr:`bins` is an int, it specifies the number of equal-width bins for all dimensions. + + .. note:: + See also :func:`torch.histogram`, which specifically computes 1D histograms. + While :func:`torch.histogramdd` infers the dimensionality of its bins and + binned values from the shape of :attr:`input`, :func:`torch.histogram` + accepts and flattens :attr:`input` of any shape. + + Args: + input (Tensor): the input tensor. + bins: Tensor[], int[], or int. + If Tensor[], defines the sequences of bin edges. + If int[], defines the number of equal-width bins in each dimension. + If int, defines the number of equal-width bins for all dimensions. + Keyword args: + range (sequence of float): Defines the leftmost and rightmost bin edges + in each dimension. + weight (Tensor): By default, each value in the input has weight 1. If a weight + tensor is passed, each N-dimensional coordinate in input + contributes its associated weight towards its bin's result. + The weight tensor should have the same shape as the :attr:`input` + tensor excluding its innermost dimension N. + density (bool): If False (default), the result will contain the count (or total weight) + in each bin. If True, each count (weight) is divided by the total count + (total weight), then divided by the volume of its associated bin. + Returns: + hist (Tensor): N-dimensional Tensor containing the values of the histogram. + bin_edges(Tensor[]): sequence of N 1D Tensors containing the bin edges. + + Example:: + + >>> torch.histogramdd(torch.tensor([[0., 1.], [1., 0.], [2., 0.], [2., 2.]]), bins=[3, 3], + ... weight=torch.tensor([1., 2., 4., 8.])) + torch.return_types.histogramdd( + hist=tensor([[0., 1., 0.], + [2., 0., 0.], + [4., 0., 8.]]), + bin_edges=(tensor([0.0000, 0.6667, 1.3333, 2.0000]), + tensor([0.0000, 0.6667, 1.3333, 2.0000]))) + + >>> torch.histogramdd(torch.tensor([[0., 0.], [1., 1.], [2., 2.]]), bins=[2, 2], + ... range=[0., 1., 0., 1.], density=True) + torch.return_types.histogramdd( + hist=tensor([[2., 0.], + [0., 2.]]), + bin_edges=(tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]))) + """ + +def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ... 
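+# NOTE (illustrative addition, not generated): the histogramdd overloads above also
+# accept explicit 1D bin-edge tensors, one per dimension, instead of ints. The names
+# pts and edges below are only for illustration:
+#
+#     pts = torch.tensor([[0.1, 0.1], [0.9, 0.9], [0.9, 0.1]])
+#     edges = (torch.tensor([0.0, 0.5, 1.0]), torch.tensor([0.0, 0.5, 1.0]))
+#     hist, bin_edges = torch.histogramdd(pts, bins=edges)
+#     # hist should be tensor([[1., 0.],
+#     #                        [1., 1.]])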
+@overload +def hsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + +@overload +def hsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + hsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with one or more dimensions, into multiple tensors + horizontally according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + If :attr:`input` is one dimensional this is equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=0) (the split dimension is + zero), and if :attr:`input` has two or more dimensions it's equivalent to calling + torch.tensor_split(input, indices_or_sections, dim=1) (the split dimension is 1), + except that if :attr:`indices_or_sections` is an integer it must evenly divide + the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.hsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.hsplit(t, 2) + (tensor([[ 0., 1.], + [ 4., 5.], + [ 8., 9.], + [12., 13.]]), + tensor([[ 2., 3.], + [ 6., 7.], + [10., 11.], + [14., 15.]])) + >>> torch.hsplit(t, [3, 6]) + (tensor([[ 0., 1., 2.], + [ 4., 5., 6.], + [ 8., 9., 10.], + [12., 13., 14.]]), + tensor([[ 3.], + [ 7.], + [11.], + [15.]]), + tensor([], size=(4, 0))) + """ + +def hspmm( + mat1: Tensor, + mat2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hspmm(mat1, mat2, *, out=None) -> Tensor + + Performs a matrix multiplication of a :ref:`sparse COO matrix + ` :attr:`mat1` and a strided matrix :attr:`mat2`. The + result is a (1 + 1)-dimensional :ref:`hybrid COO matrix + `. + + Args: + mat1 (Tensor): the first sparse matrix to be matrix multiplied + mat2 (Tensor): the second strided matrix to be matrix multiplied + + Keyword args: + out (Tensor, optional): the output tensor. 
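+
+    Example (an illustrative sketch added here, not part of the generated
+    docstring; the 3x3 diagonal sparse COO matrix is an assumption)::
+
+        >>> i = torch.tensor([[0, 1, 2], [0, 1, 2]])
+        >>> v = torch.tensor([1., 2., 3.])
+        >>> s = torch.sparse_coo_tensor(i, v, (3, 3))
+        >>> torch.hspmm(s, torch.eye(3)).to_dense()
+        tensor([[1., 0., 0.],
+                [0., 2., 0.],
+                [0., 0., 3.]])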
+ """ + +def hstack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence horizontally (column wise). + + This is equivalent to concatenation along the first axis for 1-D tensors, and along the second axis for all other tensors. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.hstack((a,b)) + tensor([1, 2, 3, 4, 5, 6]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.hstack((a,b)) + tensor([[1, 4], + [2, 5], + [3, 6]]) + """ + +def hypot( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + hypot(input, other, *, out=None) -> Tensor + + Given the legs of a right triangle, return its hypotenuse. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}^{2} + \text{other}_{i}^{2}} + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])) + tensor([5.0000, 5.6569, 6.4031]) + """ + +def i0(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + i0(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.i0`. + """ + +def i0_(input: Tensor) -> Tensor: ... +def igamma( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + igamma(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammainc`. + """ + +def igammac( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + igammac(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.gammaincc`. + """ + +def imag(input: Tensor) -> Tensor: + r""" + imag(input) -> Tensor + + Returns a new tensor containing imaginary values of the :attr:`self` tensor. + The returned tensor and :attr:`self` share the same underlying storage. + + .. warning:: + :func:`imag` is only supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.imag + tensor([ 0.3553, -0.7896, -0.0633, -0.8119]) + """ + +@overload +def index_add( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_add( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: + r""" + index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_add_` for function description. 
+ """ + +@overload +def index_copy( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_copy( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, +) -> Tensor: + r""" + index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor + + See :meth:`~Tensor.index_add_` for function description. + """ + +@overload +def index_fill( + input: Tensor, + dim: _int, + index: Tensor, + value: Tensor, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Tensor, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, +) -> Tensor: ... +@overload +def index_fill( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, +) -> Tensor: ... +def index_put( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def index_put_( + input: Tensor, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def index_reduce( + input: Tensor, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + out: Tensor | None = None, +) -> Tensor: + r""" + index_reduce(input: Tensor, dim: int, index: Tensor, source: Tensor, reduce: str, *, include_self: bool = True, out: Optional[Tensor]) -> Tensor # noqa: B950 + + See :meth:`~Tensor.index_reduce_` for function description. + """ + +@overload +def index_select( + input: Tensor, + dim: _int, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + +@overload +def index_select( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + index_select(input, dim, index, *, out=None) -> Tensor + + Returns a new tensor which indexes the :attr:`input` tensor along dimension + :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. + + The returned tensor has the same number of dimensions as the original tensor + (:attr:`input`). The :attr:`dim`\ th dimension has the same size as the length + of :attr:`index`; other dimensions have the same size as in the original tensor. + + .. note:: The returned tensor does **not** use the same storage as the original + tensor. If :attr:`out` has a different shape than expected, we + silently change it to the correct shape, reallocating the underlying + storage if necessary. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension in which we index + index (IntTensor or LongTensor): the 1-D tensor containing the indices to index + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-0.4664, 0.2647, -0.1228, -1.1068], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> indices = torch.tensor([0, 2]) + >>> torch.index_select(x, 0, indices) + tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], + [-1.1734, -0.6571, 0.7230, -0.6004]]) + >>> torch.index_select(x, 1, indices) + tensor([[ 0.1427, -0.5414], + [-0.4664, -0.1228], + [-1.1734, 0.7230]]) + """ + +def indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.indices`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def init_num_threads() -> None: ... +def inner( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + inner(input, other, *, out=None) -> Tensor + + Computes the dot product for 1D tensors. For higher dimensions, sums the product + of elements from :attr:`input` and :attr:`other` along their last dimension. + + .. note:: + + If either :attr:`input` or :attr:`other` is a scalar, the result is equivalent + to `torch.mul(input, other)`. + + If both :attr:`input` and :attr:`other` are non-scalars, the size of their last + dimension must match and the result is equivalent to `torch.tensordot(input, + other, dims=([-1], [-1]))` + + Args: + input (Tensor): First input tensor + other (Tensor): Second input tensor + + Keyword args: + out (Tensor, optional): Optional output tensor to write result into. The output + shape is `input.shape[:-1] + other.shape[:-1]`. 
+ + Example:: + + # Dot product + >>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1])) + tensor(7) + + # Multidimensional input tensors + >>> a = torch.randn(2, 3) + >>> a + tensor([[0.8173, 1.0874, 1.1784], + [0.3279, 0.1234, 2.7894]]) + >>> b = torch.randn(2, 4, 3) + >>> b + tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + >>> torch.inner(a, b) + tensor([[[-0.9837, 1.1560, 0.2907, 2.6785], + [ 2.5671, 0.5452, -0.6912, -1.5509]], + + [[ 0.1782, 2.9843, 0.7366, 1.5672], + [ 3.5115, -0.4864, -1.2476, -4.4337]]]) + + # Scalar input + >>> torch.inner(a, torch.tensor(2)) + tensor([[1.6347, 2.1748, 2.3567], + [0.6558, 0.2469, 5.5787]]) + """ + +def instance_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + use_input_stats: _bool, + momentum: _float, + eps: _float, + cudnn_enabled: _bool, +) -> Tensor: ... +def int_repr(input: Tensor) -> Tensor: ... +def inverse(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + inverse(input, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.inv` + """ + +def is_complex(input: Tensor) -> _bool: + r""" + is_complex(input) -> (bool) + + Returns True if the data type of :attr:`input` is a complex data type i.e., + one of ``torch.complex64``, and ``torch.complex128``. + + Args: + input (Tensor): the input tensor. + """ + +def is_conj(input: Tensor) -> _bool: + r""" + is_conj(input) -> (bool) + + Returns True if the :attr:`input` is a conjugated tensor, i.e. its conjugate bit is set to `True`. + + Args: + input (Tensor): the input tensor. + """ + +def is_distributed(input: Tensor) -> _bool: ... +def is_floating_point(input: Tensor) -> _bool: + r""" + is_floating_point(input) -> (bool) + + Returns True if the data type of :attr:`input` is a floating point data type i.e., + one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. + + Args: + input (Tensor): the input tensor. + """ + +def is_grad_enabled() -> _bool: + r""" + is_grad_enabled() -> (bool) + + Returns True if grad mode is currently enabled. + """ + +def is_inference(input: Tensor) -> _bool: + r""" + is_inference(input) -> (bool) + + Returns True if :attr:`input` is an inference tensor. + + A non-view tensor is an inference tensor if and only if it was + allocated during inference mode. A view tensor is an inference + tensor if and only if the tensor it is a view of is an inference tensor. + + For details on inference mode please see + `Inference Mode `_. + + Args: + input (Tensor): the input tensor. + """ + +def is_inference_mode_enabled() -> _bool: + r""" + is_inference_mode_enabled() -> (bool) + + Returns True if inference mode is currently enabled. + """ + +def is_neg(input: Tensor) -> _bool: ... +def is_nonzero(input: Tensor) -> _bool: + r""" + is_nonzero(input) -> (bool) + + Returns True if the :attr:`input` is a single element tensor which is not equal to zero + after type conversions. + i.e. not equal to ``torch.tensor([0.])`` or ``torch.tensor([0])`` or + ``torch.tensor([False])``. + Throws a ``RuntimeError`` if ``torch.numel() != 1`` (even in case + of sparse tensors). + + Args: + input (Tensor): the input tensor. 
+ + Examples:: + + >>> torch.is_nonzero(torch.tensor([0.])) + False + >>> torch.is_nonzero(torch.tensor([1.5])) + True + >>> torch.is_nonzero(torch.tensor([False])) + False + >>> torch.is_nonzero(torch.tensor([3])) + True + >>> torch.is_nonzero(torch.tensor([1, 3, 5])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous + """ + +def is_same_size(input: Tensor, other: Tensor) -> _bool: ... +def is_signed(input: Tensor) -> _bool: ... +def is_vulkan_available() -> _bool: ... +def isclose( + input: Tensor, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, +) -> Tensor: + r""" + isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + Returns a new tensor with boolean elements representing if each element of + :attr:`input` is "close" to the corresponding element of :attr:`other`. + Closeness is defined as: + + .. math:: + \lvert \text{input}_i - \text{other}_i \rvert \leq \texttt{rtol} \times \lvert \text{other}_i \rvert + \texttt{atol} + + + where :attr:`input` and :attr:`other` are finite. Where :attr:`input` + and/or :attr:`other` are nonfinite they are close if and only if + they are equal, with NaNs being considered equal to each other when + :attr:`equal_nan` is True. + + Args: + input (Tensor): first tensor to compare + other (Tensor): second tensor to compare + rtol (float, optional): relative tolerance. Default: 1e-05 + atol (float, optional): absolute tolerance. Default: 1e-08 + equal_nan (bool, optional): if ``True``, then two ``NaN`` s will be considered equal. Default: ``False`` + + Examples:: + + >>> torch.isclose(torch.tensor((1., 2, 3)), torch.tensor((1 + 1e-10, 3, 4))) + tensor([ True, False, False]) + >>> torch.isclose(torch.tensor((float('inf'), 4)), torch.tensor((float('inf'), 6)), rtol=.5) + tensor([True, True]) + """ + +def isfinite(input: Tensor) -> Tensor: + r""" + isfinite(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element is `finite` or not. + + Real values are finite when they are not NaN, negative infinity, or infinity. + Complex values are finite when both their real and imaginary parts are finite. + + Args: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is finite and False elsewhere + + Example:: + + >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([True, False, True, False, False]) + """ + +@overload +def isin( + elements: Tensor, + test_elements: Tensor, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. 
Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +@overload +def isin( + element: Number | _complex, + test_elements: Tensor, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +@overload +def isin( + elements: Tensor, + test_element: Number | _complex, + *, + assume_unique: _bool = False, + invert: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor + + Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns + a boolean tensor of the same shape as :attr:`elements` that is True for elements + in :attr:`test_elements` and False otherwise. + + .. note:: + One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both. + + Args: + elements (Tensor or Scalar): Input elements + test_elements (Tensor or Scalar): Values against which to test for each input element + assume_unique (bool, optional): If True, assumes both :attr:`elements` and + :attr:`test_elements` contain unique elements, which can speed up the + calculation. Default: False + invert (bool, optional): If True, inverts the boolean return tensor, resulting in True + values for elements *not* in :attr:`test_elements`. Default: False + + Returns: + A boolean tensor of the same shape as :attr:`elements` that is True for elements in + :attr:`test_elements` and False otherwise + + Example: + >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3])) + tensor([[False, True], + [ True, False]]) + """ + +def isinf(input: Tensor) -> Tensor: + r""" + isinf(input) -> Tensor + + Tests if each element of :attr:`input` is infinite + (positive or negative infinity) or not. + + .. note:: + Complex values are infinite when their real or imaginary part is + infinite. + + Args: + input (Tensor): the input tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is infinite and False elsewhere + + Example:: + + >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) + tensor([False, True, False, True, False]) + """ + +def isnan(input: Tensor) -> Tensor: + r""" + isnan(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` + is NaN or not. Complex values are considered NaN when either their real + and/or imaginary part is NaN. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is NaN and False elsewhere + + Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([False, True, False]) + """ + +def isneginf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + isneginf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is negative infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isneginf(a) + tensor([ True, False, False]) + """ + +def isposinf(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + isposinf(input, *, out=None) -> Tensor + Tests if each element of :attr:`input` is positive infinity or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([-float('inf'), float('inf'), 1.2]) + >>> torch.isposinf(a) + tensor([False, True, False]) + """ + +def isreal(input: Tensor) -> Tensor: + r""" + isreal(input) -> Tensor + + Returns a new tensor with boolean elements representing if each element of :attr:`input` is real-valued or not. + All real-valued types are considered real. Complex values are considered real when their imaginary part is 0. + + Arguments: + input (Tensor): the input tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is real and False elsewhere + + Example:: + + >>> torch.isreal(torch.tensor([1, 1+1j, 2+0j])) + tensor([True, False, True]) + """ + +def istft( + input: Tensor, + n_fft: _int, + hop_length: _int | None = None, + win_length: _int | None = None, + window: Tensor | None = None, + center: _bool = True, + normalized: _bool = False, + onesided: _bool | None = None, + length: _int | None = None, + return_complex: _bool = False, +) -> Tensor: ... +@overload +def kaiser_window( + window_length: _int, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. 
+ The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +@overload +def kaiser_window( + window_length: _int, + periodic: _bool, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +@overload +def kaiser_window( + window_length: _int, + periodic: _bool, + beta: _float, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + kaiser_window(window_length, periodic=True, beta=12.0, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Computes the Kaiser window with window length :attr:`window_length` and shape parameter :attr:`beta`. + + Let I_0 be the zeroth order modified Bessel function of the first kind (see :func:`torch.i0`) and + ``N = L - 1`` if :attr:`periodic` is False and ``L`` if :attr:`periodic` is True, + where ``L`` is the :attr:`window_length`. This function computes: + + .. math:: + out_i = I_0 \left( \beta \sqrt{1 - \left( {\frac{i - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + + Calling ``torch.kaiser_window(L, B, periodic=True)`` is equivalent to calling + ``torch.kaiser_window(L + 1, B, periodic=False)[:-1])``. + The :attr:`periodic` argument is intended as a helpful shorthand + to produce a periodic window as input to functions like :func:`torch.stft`. + + .. note:: + If :attr:`window_length` is one, then the returned window is a single element tensor containing a one. + + + Args: + window_length (int): length of the window. + periodic (bool, optional): If True, returns a periodic window suitable for use in spectral analysis. + If False, returns a symmetric window suitable for use in filter design. + beta (float, optional): shape parameter for the window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only + ``torch.strided`` (dense layout) is supported. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + """ + +def kl_div( + input: Tensor, + target: Tensor, + reduction: _int = 1, + *, + log_target: _bool = False, +) -> Tensor: ... +def kron( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + kron(input, other, *, out=None) -> Tensor + + Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. + + If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a + :math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a + :math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: + + .. math:: + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + + where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. + If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. + + Supports real-valued and complex-valued inputs. + + .. 
note:: + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + + Arguments: + input (Tensor) + other (Tensor) + + Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + Examples:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) + """ + +@overload +def kthvalue( + input: Tensor, + k: _int | SymInt, + dim: _int = -1, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + +@overload +def kthvalue( + input: Tensor, + k: _int | SymInt, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.kthvalue: + r""" + kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th + smallest element of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each element found. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`keepdim` is ``True``, both the :attr:`values` and :attr:`indices` tensors + are the same size as :attr:`input`, except in the dimension :attr:`dim` where + they are of size 1. Otherwise, :attr:`dim` is squeezed + (see :func:`torch.squeeze`), resulting in both the :attr:`values` and + :attr:`indices` tensors having 1 fewer dimension than the :attr:`input` tensor. + + .. note:: + When :attr:`input` is a CUDA tensor and there are multiple valid + :attr:`k` th values, this function may nondeterministically return + :attr:`indices` for any of them. + + Args: + input (Tensor): the input tensor. + k (int): k for the k-th smallest element + dim (int, optional): the dimension to find the kth value along + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) + can be optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.kthvalue(x, 4) + torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3)) + + >>> x=torch.arange(1.,7.).resize_(2,3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.]]) + >>> torch.kthvalue(x, 2, 0, True) + torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]])) + """ + +def layer_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None = None, + bias: Tensor | None = None, + eps: _float = 1e-05, + cudnn_enable: _bool = True, +) -> Tensor: ... +def lcm( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lcm(input, other, *, out=None) -> Tensor + + Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. + + Both :attr:`input` and :attr:`other` must have integer types. + + .. note:: + This defines :math:`lcm(0, 0) = 0` and :math:`lcm(0, a) = 0`. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([5, 10, 15]) + >>> b = torch.tensor([3, 4, 5]) + >>> torch.lcm(a, b) + tensor([15, 20, 15]) + >>> c = torch.tensor([3]) + >>> torch.lcm(a, c) + tensor([15, 30, 15]) + """ + +def lcm_(input: Tensor, other: Tensor) -> Tensor: ... +def ldexp( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ldexp(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by 2 ** :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i * 2^\text{{other}}_i + + + Typically this function is used to construct floating point numbers by multiplying + mantissas in :attr:`input` with integral powers of two created from the exponents + in :attr:`other`. + + Args: + input (Tensor): the input tensor. + other (Tensor): a tensor of exponents, typically integers. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.ldexp(torch.tensor([1.]), torch.tensor([1])) + tensor([2.]) + >>> torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])) + tensor([ 2., 4., 8., 16.]) + """ + +def ldexp_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def le( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + +@overload +def le( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + le(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \leq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or Scalar): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than or equal to + :attr:`other` and False elsewhere + + Example:: + + >>> torch.le(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[True, False], [True, True]]) + """ + +@overload +def lerp( + input: Tensor, + end: Tensor, + weight: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + +@overload +def lerp( + input: Tensor, + end: Tensor, + weight: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lerp(input, end, weight, *, out=None) + + Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based + on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. + + .. 
math:: + \text{out}_i = \text{start}_i + \text{weight}_i \times (\text{end}_i - \text{start}_i) + + The shapes of :attr:`start` and :attr:`end` must be + :ref:`broadcastable `. If :attr:`weight` is a tensor, then + the shapes of :attr:`weight`, :attr:`start`, and :attr:`end` must be :ref:`broadcastable `. + + Args: + input (Tensor): the tensor with the starting points + end (Tensor): the tensor with the ending points + weight (float or tensor): the weight for the interpolation formula + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> start = torch.arange(1., 5.) + >>> end = torch.empty(4).fill_(10) + >>> start + tensor([ 1., 2., 3., 4.]) + >>> end + tensor([ 10., 10., 10., 10.]) + >>> torch.lerp(start, end, 0.5) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + >>> torch.lerp(start, end, torch.full_like(start, 0.5)) + tensor([ 5.5000, 6.0000, 6.5000, 7.0000]) + """ + +@overload +def less( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + +@overload +def less( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.lt`. + """ + +@overload +def less_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + +@overload +def less_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + less_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.le`. + """ + +def lgamma(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + lgamma(input, *, out=None) -> Tensor + + Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`. + + .. math:: + \text{out}_{i} = \ln |\Gamma(\text{input}_{i})| + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.arange(0.5, 2, 0.5) + >>> torch.lgamma(a) + tensor([ 0.5724, 0.0000, -0.1208]) + """ + +@overload +def linspace( + start: Number, + end: Number, + steps: _int | None = None, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. 
+ Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Tensor, + end: Tensor, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Number | _complex, + end: Tensor, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Tensor, + end: Number | _complex, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. 
math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +@overload +def linspace( + start: Number | _complex, + end: Number | _complex, + steps: _int, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + linspace(start, end, steps, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + + From PyTorch 1.11 linspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + + Example:: + + >>> torch.linspace(3, 10, steps=5) + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000]) + >>> torch.linspace(-10, 10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=5) + tensor([-10., -5., 0., 5., 10.]) + >>> torch.linspace(start=-10, end=10, steps=1) + tensor([-10.]) + """ + +def log(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{e} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) * 5 + >>> a + tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + >>> torch.log(a) + tensor([ 1.5637, 1.4640, 0.1952, -1.4226, 1.5204]) + """ + +def log10(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log10(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the logarithm to the base 10 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{10} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.5224, 0.9354, 0.7257, 0.1301, 0.2251]) + + + >>> torch.log10(a) + tensor([-0.2820, -0.0290, -0.1392, -0.8857, -0.6476]) + """ + +def log10_(input: Tensor) -> Tensor: ... +def log1p(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log1p(input, *, out=None) -> Tensor + + Returns a new tensor with the natural logarithm of (1 + :attr:`input`). + + .. math:: + y_i = \log_{e} (x_i + 1) + + .. note:: This function is more accurate than :func:`torch.log` for small + values of :attr:`input` + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([-1.0090, -0.9923, 1.0249, -0.5372, 0.2492]) + >>> torch.log1p(a) + tensor([ nan, -4.8653, 0.7055, -0.7705, 0.2225]) + """ + +def log1p_(input: Tensor) -> Tensor: ... +def log2(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + log2(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the logarithm to the base 2 of the elements + of :attr:`input`. + + .. math:: + y_{i} = \log_{2} (x_{i}) + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.rand(5) + >>> a + tensor([ 0.8419, 0.8003, 0.9971, 0.5287, 0.0490]) + + + >>> torch.log2(a) + tensor([-0.2483, -0.3213, -0.0042, -0.9196, -4.3504]) + """ + +def log2_(input: Tensor) -> Tensor: ... +def log_(input: Tensor) -> Tensor: ... +@overload +def log_softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def log_softmax( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, +) -> Tensor: ... 
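+# Editorial sketch (not part of the generated stub): a minimal, hedged usage
+# example for a few of the functions typed above -- log1p's small-value
+# accuracy claim, isclose, linspace, and log_softmax -- assuming a standard
+# CPU build of torch. It is kept in comments so this interface file stays free
+# of executable code, and printed values are deliberately omitted rather than
+# guessed.
+#
+#     import torch
+#
+#     x = torch.tensor([1e-8, 1e-4, 1.0])
+#     a = torch.log1p(x)      # numerically stable log(1 + x) for small x
+#     b = torch.log(1 + x)    # naive form; loses precision when x is ~1e-8
+#     close = torch.isclose(a, b, rtol=0.0, atol=1e-6)  # elementwise comparison
+#
+#     grid = torch.linspace(0.0, 1.0, steps=5)  # 5 evenly spaced points, inclusive
+#     logp = torch.log_softmax(grid, dim=0)     # log-probabilities along dim 0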
+def logaddexp( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logaddexp(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs. + + Calculates pointwise :math:`\log\left(e^x + e^y\right)`. This function is useful + in statistics where the calculated probabilities of events may be so small as to + exceed the range of normal floating point numbers. In such cases the logarithm + of the calculated probability is stored. This function allows adding + probabilities stored in such a fashion. + + This op should be disambiguated with :func:`torch.logsumexp` which performs a + reduction on a single tensor. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1.0, -2, -3])) + tensor([-0.3069, -0.6867, -0.8731]) + >>> torch.logaddexp(torch.tensor([-100.0, -200, -300]), torch.tensor([-1.0, -2, -3])) + tensor([-1., -2., -3.]) + >>> torch.logaddexp(torch.tensor([1.0, 2000, 30000]), torch.tensor([-1.0, -2, -3])) + tensor([1.1269e+00, 2.0000e+03, 3.0000e+04]) + """ + +def logaddexp2( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logaddexp2(input, other, *, out=None) -> Tensor + + Logarithm of the sum of exponentiations of the inputs in base-2. + + Calculates pointwise :math:`\log_2\left(2^x + 2^y\right)`. See + :func:`torch.logaddexp` for more details. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword arguments: + out (Tensor, optional): the output tensor. + """ + +@overload +def logcumsumexp( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{k=0}^{j} \exp(x_{ik}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + +@overload +def logcumsumexp( + input: Tensor, + dim: str | EllipsisType | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logcumsumexp(input, dim, *, out=None) -> Tensor + Returns the logarithm of the cumulative summation of the exponentiation of + elements of :attr:`input` in the dimension :attr:`dim`. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logcumsumexp}(x)_{ij} = \log \sum\limits_{k=0}^{j} \exp(x_{ik}) + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(10) + >>> torch.logcumsumexp(a, dim=0) + tensor([-0.42296738, -0.04462666, 0.86278635, 0.94622083, 1.05277811, + 1.39202815, 1.83525007, 1.84492621, 2.06084887, 2.06844475])) + """ + +def logdet(input: Tensor) -> Tensor: + r""" + logdet(input) -> Tensor + + Calculates log determinant of a square matrix or batches of square matrices. + + It returns ``-inf`` if the input has a determinant of zero, and ``NaN`` if it has + a negative determinant. + + .. note:: + Backward through :meth:`logdet` internally uses SVD results when :attr:`input` + is not invertible. In this case, double backward through :meth:`logdet` will + be unstable in when :attr:`input` doesn't have distinct singular values. See + :func:`torch.linalg.svd` for details. + + .. seealso:: + + :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the + absolute value of the determinant of real-valued (resp. complex) square matrices. + + Arguments: + input (Tensor): the input tensor of size ``(*, n, n)`` where ``*`` is zero or more + batch dimensions. + + Example:: + + >>> A = torch.randn(3, 3) + >>> torch.det(A) + tensor(0.2611) + >>> torch.logdet(A) + tensor(-1.3430) + >>> A + tensor([[[ 0.9254, -0.6213], + [-0.5787, 1.6843]], + + [[ 0.3242, -0.9665], + [ 0.4539, -0.0887]], + + [[ 1.1336, -0.4025], + [-0.7089, 0.9032]]]) + >>> A.det() + tensor([1.1990, 0.4099, 0.7386]) + >>> A.det().log() + tensor([ 0.1815, -0.8917, -0.3031]) + """ + +def logical_and( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_and(input, other, *, out=None) -> Tensor + + Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute AND with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_and(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, False]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_and(a, b) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b.double()) + tensor([False, False, True, False]) + >>> torch.logical_and(a.double(), b) + tensor([False, False, True, False]) + >>> torch.logical_and(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([False, False, True, False]) + """ + +def logical_not(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + logical_not(input, *, out=None) -> Tensor + + Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool + dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.logical_not(torch.tensor([True, False])) + tensor([False, True]) + >>> torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1.5, -10.], dtype=torch.double)) + tensor([ True, False, False]) + >>> torch.logical_not(torch.tensor([0., 1., -10.], dtype=torch.double), out=torch.empty(3, dtype=torch.int16)) + tensor([1, 0, 0], dtype=torch.int16) + """ + +def logical_or( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_or(input, other, *, out=None) -> Tensor + + Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute OR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_or(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([ True, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_or(a, b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b.double()) + tensor([ True, True, True, False]) + >>> torch.logical_or(a.double(), b) + tensor([ True, True, True, False]) + >>> torch.logical_or(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, True, False]) + """ + +def logical_xor( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor + + Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are + treated as ``True``. + + Args: + input (Tensor): the input tensor. + other (Tensor): the tensor to compute XOR with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.logical_xor(torch.tensor([True, False, True]), torch.tensor([True, False, False])) + tensor([False, False, True]) + >>> a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + >>> b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + >>> torch.logical_xor(a, b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b.double()) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a.double(), b) + tensor([ True, True, False, False]) + >>> torch.logical_xor(a, b, out=torch.empty(4, dtype=torch.bool)) + tensor([ True, True, False, False]) + """ + +def logit( + input: Tensor, + eps: _float | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logit(input, eps=None, *, out=None) -> Tensor + + Alias for :func:`torch.special.logit`. + """ + +def logit_(input: Tensor, eps: _float | None = None) -> Tensor: ... +@overload +def logspace( + start: Number, + end: Number, + steps: _int | None = None, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. 
That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Tensor, + end: Tensor, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. 
+ + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Number | _complex, + end: Tensor, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Tensor, + end: Number | _complex, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logspace( + start: Number | _complex, + end: Number | _complex, + steps: _int, + base: _float = 10.0, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to + :math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale + with base :attr:`base`. That is, the values are: + + .. math:: + (\text{base}^{\text{start}}, + \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \ldots, + \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, + \text{base}^{\text{end}}) + + + + From PyTorch 1.11 logspace requires the steps argument. Use steps=100 to restore the previous behavior. + + Args: + start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional + end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional + steps (int): size of the constructed tensor + base (float, optional): base of the logarithm function. Default: ``10.0``. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype (see torch.get_default_dtype()) + when both :attr:`start` and :attr:`end` are real, + and corresponding complex dtype when either is complex. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.logspace(start=-10, end=10, steps=5) + tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10]) + >>> torch.logspace(start=0.1, end=1.0, steps=5) + tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000]) + >>> torch.logspace(start=0.1, end=1.0, steps=1) + tensor([1.2589]) + >>> torch.logspace(start=2, end=2, steps=1, base=2) + tensor([4.0]) + """ + +@overload +def logsumexp( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. 
+ + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + +@overload +def logsumexp( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + logsumexp(input, dim, keepdim=False, *, out=None) + + Returns the log of summed exponentials of each row of the :attr:`input` + tensor in the given dimension :attr:`dim`. The computation is numerically + stabilized. + + For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is + + .. math:: + \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij}) + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints): the dimension or dimensions to reduce. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> torch.logsumexp(a, 1) + tensor([1.4907, 1.0593, 1.5696]) + >>> torch.dist(torch.logsumexp(a, 1), torch.log(torch.sum(torch.exp(a), 1))) + tensor(1.6859e-07) + """ + +@overload +def lstm( + data: Tensor, + batch_sizes: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def lstm( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor, Tensor]: ... +def lstm_cell( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> tuple[Tensor, Tensor]: ... +@overload +def lt( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. 
+ + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + +@overload +def lt( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lt(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} < \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is less than :attr:`other` and False elsewhere + + Example:: + + >>> torch.lt(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, False], [True, False]]) + """ + +def lu_solve( + input: Tensor, + LU_data: Tensor, + LU_pivots: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor + + Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted + LU factorization of A from :func:`~linalg.lu_factor`. + + This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`. + + .. warning:: + + :func:`torch.lu_solve` is deprecated in favor of :func:`torch.linalg.lu_solve`. + :func:`torch.lu_solve` will be removed in a future PyTorch release. + ``X = torch.lu_solve(B, LU, pivots)`` should be replaced with + + .. code:: python + + X = linalg.lu_solve(LU, pivots, B) + + Arguments: + b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*` + is zero or more batch dimensions. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`, + where :math:`*` is zero or more batch dimensions. + LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`, + where :math:`*` is zero or more batch dimensions. + The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of + :attr:`LU_data`. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(2, 3, 1) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> x = torch.lu_solve(b, LU, pivots) + >>> torch.dist(A @ x, b) + tensor(1.00000e-07 * + 2.8312) + """ + +def lu_unpack( + LU_data: Tensor, + LU_pivots: Tensor, + unpack_data: _bool = True, + unpack_pivots: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.lu_unpack: + r""" + lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Unpacks the LU decomposition returned by :func:`~linalg.lu_factor` into the `P, L, U` matrices. + + .. seealso:: + + :func:`~linalg.lu` returns the matrices from the LU decomposition. Its gradient formula is more efficient + than that of doing :func:`~linalg.lu_factor` followed by :func:`~linalg.lu_unpack`. 
+ + Args: + LU_data (Tensor): the packed LU factorization data + LU_pivots (Tensor): the packed LU factorization pivots + unpack_data (bool): flag indicating if the data should be unpacked. + If ``False``, then the returned ``L`` and ``U`` are empty tensors. + Default: ``True`` + unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``. + If ``False``, then the returned ``P`` is an empty tensor. + Default: ``True`` + + Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + + Returns: + A namedtuple ``(P, L, U)`` + + Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # We can recover A from the factorization + >>> A_ = P @ L @ U + >>> torch.allclose(A, A_) + True + + >>> # LU factorization of a rectangular matrix: + >>> A = torch.randn(2, 3, 2) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> P, L, U = torch.lu_unpack(LU, pivots) + >>> # P, L, U are the same as returned by linalg.lu + >>> P_, L_, U_ = torch.linalg.lu(A) + >>> torch.allclose(P, P_) and torch.allclose(L, L_) and torch.allclose(U, U_) + True + """ + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: _float = 0.0, + reduction: _int = 1, +) -> Tensor: ... +@overload +def masked_fill(input: Tensor, mask: Tensor, value: Tensor) -> Tensor: ... +@overload +def masked_fill( + input: Tensor, + mask: Tensor, + value: Number | _complex, +) -> Tensor: ... +def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor: ... +def masked_select( + input: Tensor, + mask: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + masked_select(input, mask, *, out=None) -> Tensor + + Returns a new 1-D tensor which indexes the :attr:`input` tensor according to + the boolean mask :attr:`mask` which is a `BoolTensor`. + + The shapes of the :attr:`mask` tensor and the :attr:`input` tensor don't need + to match, but they must be :ref:`broadcastable `. + + .. note:: The returned tensor does **not** use the same storage + as the original tensor + + Args: + input (Tensor): the input tensor. + mask (BoolTensor): the tensor containing the binary mask to index with + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.randn(3, 4) + >>> x + tensor([[ 0.3552, -2.3825, -0.8297, 0.3477], + [-1.2035, 1.2252, 0.5002, 0.6248], + [ 0.1307, -2.0608, 0.1244, 2.0139]]) + >>> mask = x.ge(0.5) + >>> mask + tensor([[False, False, False, False], + [False, True, True, True], + [False, False, False, True]]) + >>> torch.masked_select(x, mask) + tensor([ 1.2252, 0.5002, 0.6248, 2.0139]) + """ + +def matmul( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + matmul(input, other, *, out=None) -> Tensor + + Matrix product of two tensors. + + The behavior depends on the dimensionality of the tensors as follows: + + - If both tensors are 1-dimensional, the dot product (scalar) is returned. + - If both arguments are 2-dimensional, the matrix-matrix product is returned. + - If the first argument is 1-dimensional and the second argument is 2-dimensional, + a 1 is prepended to its dimension for the purpose of the matrix multiply. + After the matrix multiply, the prepended dimension is removed. + - If the first argument is 2-dimensional and the second argument is 1-dimensional, + the matrix-vector product is returned. 
+ - If both arguments are at least 1-dimensional and at least one argument is + N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first + argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the + batched matrix multiply and removed after. If the second argument is 1-dimensional, a + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. + + This operation has support for arguments with :ref:`sparse layouts`. In particular the + matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions + as :func:`torch.mm` + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + .. note:: + + The 1-dimensional dot product version of this function does not support an :attr:`out` parameter. + + Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> # vector x vector + >>> tensor1 = torch.randn(3) + >>> tensor2 = torch.randn(3) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([]) + >>> # matrix x vector + >>> tensor1 = torch.randn(3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([3]) + >>> # batched matrix x broadcasted vector + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3]) + >>> # batched matrix x batched matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(10, 4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + >>> # batched matrix x broadcasted matrix + >>> tensor1 = torch.randn(10, 3, 4) + >>> tensor2 = torch.randn(4, 5) + >>> torch.matmul(tensor1, tensor2).size() + torch.Size([10, 3, 5]) + """ + +def matrix_exp(input: Tensor) -> Tensor: + r""" + matrix_exp(A) -> Tensor + + Alias for :func:`torch.linalg.matrix_exp`. + """ + +def matrix_power( + input: Tensor, + n: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + matrix_power(input, n, *, out=None) -> Tensor + + Alias for :func:`torch.linalg.matrix_power` + """ + +@overload +def max(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. 
+ + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.max: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +@overload +def max( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.max: + r""" + max(input, *, out=None) -> Tensor + + Returns the maximum value of all elements in the ``input`` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6763, 0.7445, -2.2369]]) + >>> torch.max(a) + tensor(0.7445) + + .. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each maximum value found + (argmax). + + If ``keepdim`` is ``True``, the output tensors are of the same size + as ``input`` except in the dimension ``dim`` where they are of size 1. + Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than ``input``. + + .. note:: If there are multiple maximal values in a reduced row then + the indices of the first maximal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (max, max_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-1.2360, -0.2942, -0.1222, 0.8475], + [ 1.1949, -1.1127, -2.2379, -0.6702], + [ 1.5717, -0.9207, 0.1297, -1.8768], + [-0.6172, 1.0036, -0.6060, -0.2432]]) + >>> torch.max(a, 1) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) + >>> a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + >>> a.max(dim=1, keepdim=True) + torch.return_types.max( + values=tensor([[2.], [4.]]), + indices=tensor([[1], [1]])) + >>> a.max(dim=1, keepdim=False) + torch.return_types.max( + values=tensor([2., 4.]), + indices=tensor([1, 1])) + + .. function:: max(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.maximum`. + """ + +def max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def max_pool1d_with_indices( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> tuple[Tensor, Tensor]: ... +def max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def maximum( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + maximum(input, other, *, out=None) -> Tensor + + Computes the element-wise maximum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`maximum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. 
+ other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.maximum(a, b) + tensor([3, 2, 4]) + """ + +@overload +def mean( + input: Tensor, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def mean( + input: Tensor, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. 
+ + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + mean(input, *, dtype=None) -> Tensor + + .. note:: + If the `input` tensor is empty, ``torch.mean()`` returns ``nan``. + This behavior is consistent with NumPy and follows the definition + that the mean over an empty set is undefined. + + + Returns the mean value of all elements in the :attr:`input` tensor. Input must be floating point or complex. + + Args: + input (Tensor): + the input tensor, either of floating point or complex dtype + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.2294, -0.5481, 1.3288]]) + >>> torch.mean(a) + tensor(0.3367) + + .. function:: mean(input, dim, keepdim=False, *, dtype=None, out=None) -> Tensor + :noindex: + + Returns the mean value of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. 
+ + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.nanmean` computes the mean value of `non-NaN` elements. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.3841, 0.6320, 0.4254, -0.7384], + [-0.9644, 1.0131, -0.6549, -1.4279], + [-0.2951, -1.3350, -0.7694, 0.5600], + [ 1.0842, -0.9580, 0.3623, 0.2343]]) + >>> torch.mean(a, 1) + tensor([-0.0163, -0.5085, -0.4599, 0.1807]) + >>> torch.mean(a, 1, True) + tensor([[-0.0163], + [-0.5085], + [-0.4599], + [ 0.1807]]) + """ + +@overload +def median(input: Tensor) -> Tensor: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def median( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
+ + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def median( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.median: + r""" + median(input) -> Tensor + + Returns the median of the values in :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements. In this case the lower of the two medians is returned. To + compute the mean of both medians, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + This function produces deterministic (sub)gradients unlike ``median(dim=0)`` + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 1.5219, -1.5212, 0.2202]]) + >>> torch.median(a) + tensor(0.2202) + + .. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, and ``indices`` contains the index of the median values found in the dimension :attr:`dim`. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size + as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the outputs tensor having 1 fewer dimension than :attr:`input`. + + .. note:: + The median is not unique for :attr:`input` tensors with an even number + of elements in the dimension :attr:`dim`. In this case the lower of the + two medians is returned. To compute the mean of both medians in + :attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead. + + .. warning:: + ``indices`` does not necessarily contain the first occurrence of each + median value found, unless it is unique. + The exact implementation details are device-specific. + Do not expect the same result when run on CPU and GPU in general. + For the same reason do not expect the gradients to be deterministic. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. 
+ + Example:: + + >>> a = torch.randn(4, 5) + >>> a + tensor([[ 0.2505, -0.3982, -0.9948, 0.3518, -1.3131], + [ 0.3180, -0.6993, 1.0436, 0.0438, 0.2270], + [-0.2751, 0.7303, 0.2192, 0.3321, 0.2488], + [ 1.0778, -1.9510, 0.7048, 0.4742, -0.7125]]) + >>> torch.median(a, 1) + torch.return_types.median(values=tensor([-0.3982, 0.2270, 0.2488, 0.4742]), indices=tensor([1, 4, 4, 3])) + """ + +@overload +def min(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. 
note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.min: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +@overload +def min( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.min: + r""" + min(input, *, out=None) -> Tensor + + Returns the minimum value of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.6750, 1.0857, 1.7197]]) + >>> torch.min(a) + tensor(0.6750) + + .. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`. And ``indices`` is the index location of each minimum value found + (argmin). + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: If there are multiple minimal values in a reduced row then + the indices of the first minimal value are returned. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the tuple of two output tensors (min, min_indices) + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[-0.6248, 1.1334, -1.1899, -0.2803], + [-1.4644, -0.2635, -0.3651, 0.6134], + [ 0.2457, 0.0384, 1.0128, 0.7015], + [-0.1153, 2.9849, 2.1458, 0.5788]]) + >>> torch.min(a, 1) + torch.return_types.min(values=tensor([-1.1899, -1.4644, 0.0384, -0.1153]), indices=tensor([2, 0, 1, 0])) + + .. function:: min(input, other, *, out=None) -> Tensor + :noindex: + + See :func:`torch.minimum`. + """ + +def minimum( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + minimum(input, other, *, out=None) -> Tensor + + Computes the element-wise minimum of :attr:`input` and :attr:`other`. + + .. note:: + If one of the elements being compared is a NaN, then that element is returned. + :func:`minimum` is not supported for tensors with complex dtypes. + + Args: + input (Tensor): the input tensor. + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2, -1)) + >>> b = torch.tensor((3, 0, 4)) + >>> torch.minimum(a, b) + tensor([1, 0, -1]) + """ + +def miopen_batch_norm( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + exponential_average_factor: _float, + epsilon: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def miopen_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_convolution_add_relu( + input: Tensor, + weight: Tensor, + z: Tensor, + alpha: Number | _complex | None, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def miopen_convolution_relu( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + stride: Sequence[_int | SymInt], + padding: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... 
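+# A minimal doctest-style sketch (illustrative addition, not produced by gen_pyi.py)
+# contrasting the reduction overloads ``torch.min(input, dim)`` documented above,
+# which return a ``(values, indices)`` namedtuple, with the elementwise
+# ``torch.minimum(input, other)`` / ``torch.maximum(input, other)``:
+#
+# >>> import torch
+# >>> a = torch.tensor([[1.0, 4.0], [3.0, 2.0]])
+# >>> b = torch.full((2, 2), 2.0)
+# >>> torch.min(a, dim=1).values    # per-row minima (reduction form)
+# tensor([1., 2.])
+# >>> torch.min(a, dim=1).indices   # argmin along dim=1
+# tensor([0, 1])
+# >>> torch.minimum(a, b)           # elementwise minimum with broadcasting
+# tensor([[1., 2.],
+#         [2., 2.]])
+# >>> torch.maximum(a, b)           # elementwise maximum
+# tensor([[2., 4.],
+#         [3., 2.]])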
+def miopen_convolution_transpose( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + output_padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_depthwise_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, + benchmark: _bool, + deterministic: _bool, +) -> Tensor: ... +def miopen_rnn( + input: Tensor, + weight: tuple[Tensor, ...] | list[Tensor] | None, + weight_stride0: _int, + hx: Tensor, + cx: Tensor | None, + mode: _int, + hidden_size: _int, + num_layers: _int, + batch_first: _bool, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_sizes: _size, + dropout_state: Tensor | None, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: ... +def mkldnn_adaptive_avg_pool2d( + input: Tensor, + output_size: _int | _size, + *, + out: Tensor | None = None, +) -> Tensor: ... +def mkldnn_convolution( + input: Tensor, + weight: Tensor, + bias: Tensor | None, + padding: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + dilation: Sequence[_int | SymInt], + groups: _int | SymInt, +) -> Tensor: ... +def mkldnn_linear_backward_weights( + grad_output: Tensor, + input: Tensor, + weight: Tensor, + bias_defined: _bool, +) -> tuple[Tensor, Tensor]: ... +def mkldnn_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def mkldnn_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def mkldnn_rnn_layer( + input: Tensor, + weight0: Tensor, + weight1: Tensor, + weight2: Tensor, + weight3: Tensor, + hx_: Tensor, + cx_: Tensor, + reverse: _bool, + batch_sizes: _size, + mode: _int, + hidden_size: _int, + num_layers: _int, + has_biases: _bool, + bidirectional: _bool, + batch_first: _bool, + train: _bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: ... +@overload +def mm(input: Tensor, mat2: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + mm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided its layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. 
+ + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + +@overload +def mm( + input: Tensor, + mat2: Tensor, + out_dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mm(input, mat2, out_dtype=None, *, out=None) -> Tensor + + Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`mat2` is a + :math:`(m \times p)` tensor, :attr:`out` will be a :math:`(n \times p)` tensor. + + .. note:: This function does not :ref:`broadcast `. + For broadcasting matrix products, see :func:`torch.matmul`. + + Supports strided and sparse 2-D tensors as inputs, autograd with + respect to strided inputs. + + This operation has support for arguments with :ref:`sparse layouts`. + If :attr:`out` is provided its layout will be used. Otherwise, the result + layout will be deduced from that of :attr:`input`. + + + .. warning:: + Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, + or may not have autograd support. If you notice missing functionality please + open a feature request. + + This operator supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + input (Tensor): the first matrix to be matrix multiplied + mat2 (Tensor): the second matrix to be matrix multiplied + out_dtype (dtype, optional): the dtype of the output tensor, + Supported only on CUDA and for torch.float32 given + torch.float16/torch.bfloat16 input dtypes + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat1 = torch.randn(2, 3) + >>> mat2 = torch.randn(3, 3) + >>> torch.mm(mat1, mat2) + tensor([[ 0.4851, 0.5037, -0.3633], + [-0.0760, -3.6705, 2.4784]]) + """ + +@overload +def mode( + input: Tensor, + dim: _int = -1, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor([[0, 0, 0, 2, 0, 0, 2], + ... [0, 3, 0, 0, 2, 0, 1], + ... [2, 2, 2, 0, 0, 0, 3], + ... [2, 2, 3, 0, 1, 1, 0], + ... [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + +@overload +def mode( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.mode: + r""" + mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + + Returns a namedtuple ``(values, indices)`` where ``values`` is the mode + value of each row of the :attr:`input` tensor in the given dimension + :attr:`dim`, i.e. a value which appears most often + in that row, and ``indices`` is the index location of each mode value found. + + By default, :attr:`dim` is the last dimension of the :attr:`input` tensor. + + If :attr:`keepdim` is ``True``, the output tensors are of the same size as + :attr:`input` except in the dimension :attr:`dim` where they are of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting + in the output tensors having 1 fewer dimension than :attr:`input`. + + .. note:: This function is not defined for ``torch.cuda.Tensor`` yet. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out (tuple, optional): the result tuple of two output tensors (values, indices) + + Example:: + + >>> b = torch.tensor([[0, 0, 0, 2, 0, 0, 2], + ... [0, 3, 0, 0, 2, 0, 1], + ... [2, 2, 2, 0, 0, 0, 3], + ... [2, 2, 3, 0, 1, 1, 0], + ... [1, 1, 0, 0, 2, 0, 2]]) + >>> torch.mode(b, 0) + torch.return_types.mode( + values=tensor([0, 2, 0, 0, 0, 0, 2]), + indices=tensor([1, 3, 4, 4, 2, 4, 4])) + """ + +@overload +def moveaxis(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def moveaxis(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(input, source, destination) -> Tensor + + Alias for :func:`torch.movedim`. + + This function is equivalent to NumPy's moveaxis function. 
+ + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.moveaxis(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.moveaxis(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.moveaxis(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.moveaxis(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def movedim(input: Tensor, source: _int, destination: _int) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +@overload +def movedim(input: Tensor, source: _size, destination: _size) -> Tensor: + r""" + movedim(input, source, destination) -> Tensor + + Moves the dimension(s) of :attr:`input` at the position(s) in :attr:`source` + to the position(s) in :attr:`destination`. + + Other dimensions of :attr:`input` that are not explicitly moved remain in + their original order and appear at the positions not specified in :attr:`destination`. + + Args: + input (Tensor): the input tensor. + source (int or tuple of ints): Original positions of the dims to move. These must be unique. + destination (int or tuple of ints): Destination positions for each of the original dims. These must also be unique. + + Examples:: + + >>> t = torch.randn(3,2,1) + >>> t + tensor([[[-0.3362], + [-0.8437]], + + [[-0.9627], + [ 0.1727]], + + [[ 0.5173], + [-0.1398]]]) + >>> torch.movedim(t, 1, 0).shape + torch.Size([2, 3, 1]) + >>> torch.movedim(t, 1, 0) + tensor([[[-0.3362], + [-0.9627], + [ 0.5173]], + + [[-0.8437], + [ 0.1727], + [-0.1398]]]) + >>> torch.movedim(t, (1, 2), (0, 1)).shape + torch.Size([2, 1, 3]) + >>> torch.movedim(t, (1, 2), (0, 1)) + tensor([[[-0.3362, -0.9627, 0.5173]], + + [[-0.8437, 0.1727, -0.1398]]]) + """ + +def msort(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + msort(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Sorts the elements of the :attr:`input` tensor along its first dimension + in ascending order by value. + + .. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`. + See also :func:`torch.sort`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> t = torch.randn(3, 4) + >>> t + tensor([[-0.1321, 0.4370, -1.2631, -1.1289], + [-2.0527, -1.1250, 0.2275, 0.3077], + [-0.0881, -0.1259, -0.5495, 1.0284]]) + >>> torch.msort(t) + tensor([[-2.0527, -1.1250, -1.2631, -1.1289], + [-0.1321, -0.1259, -0.5495, 0.3077], + [-0.0881, 0.4370, 0.2275, 1.0284]]) + """ + +def mul( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mul(input, other, *, out=None) -> Tensor + + Multiplies :attr:`input` by :attr:`other`. + + + .. math:: + \text{out}_i = \text{input}_i \times \text{other}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number) - the tensor or number to multiply input by. + + Keyword args: + out (Tensor, optional): the output tensor. + + Examples:: + + >>> a = torch.randn(3) + >>> a + tensor([ 0.2015, -0.4255, 2.6087]) + >>> torch.mul(a, 100) + tensor([ 20.1494, -42.5491, 260.8663]) + + >>> b = torch.randn(4, 1) + >>> b + tensor([[ 1.1207], + [-0.3137], + [ 0.0700], + [ 0.8378]]) + >>> c = torch.randn(1, 4) + >>> c + tensor([[ 0.5146, 0.1216, -0.5244, 2.2382]]) + >>> torch.mul(b, c) + tensor([[ 0.5767, 0.1363, -0.5877, 2.5083], + [-0.1614, -0.0382, 0.1645, -0.7021], + [ 0.0360, 0.0085, -0.0367, 0.1567], + [ 0.4312, 0.1019, -0.4394, 1.8753]]) + """ + +def multinomial( + input: Tensor, + num_samples: _int | SymInt, + replacement: _bool = False, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor + + Returns a tensor where each row contains :attr:`num_samples` indices sampled + from the multinomial (a stricter definition would be multivariate, + refer to :class:`torch.distributions.multinomial.Multinomial` for more details) + probability distribution located in the corresponding row + of tensor :attr:`input`. + + .. note:: + The rows of :attr:`input` do not need to sum to one (in which case we use + the values as weights), but must be non-negative, finite and have + a non-zero sum. + + Indices are ordered from left to right according to when each was sampled + (first samples are placed in first column). + + If :attr:`input` is a vector, :attr:`out` is a vector of size :attr:`num_samples`. + + If :attr:`input` is a matrix with `m` rows, :attr:`out` is an matrix of shape + :math:`(m \times \text{num\_samples})`. + + If replacement is ``True``, samples are drawn with replacement. + + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + + .. note:: + When drawn without replacement, :attr:`num_samples` must be lower than + number of non-zero elements in :attr:`input` (or the min number of non-zero + elements in each row of :attr:`input` if it is a matrix). + + Args: + input (Tensor): the input tensor containing probabilities + num_samples (int): number of samples to draw + replacement (bool, optional): whether to draw with replacement or not + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights + >>> torch.multinomial(weights, 2) + tensor([1, 2]) + >>> torch.multinomial(weights, 5) # ERROR! + RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement + >>> torch.multinomial(weights, 4, replacement=True) + tensor([ 2, 1, 1, 1]) + """ + +@overload +def multiply( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + +@overload +def multiply(input: Tensor, other: Number | _complex) -> Tensor: + r""" + multiply(input, other, *, out=None) + + Alias for :func:`torch.mul`. + """ + +def mv(input: Tensor, vec: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + mv(input, vec, *, out=None) -> Tensor + + Performs a matrix-vector product of the matrix :attr:`input` and the vector + :attr:`vec`. + + If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`vec` is a 1-D tensor of + size :math:`m`, :attr:`out` will be 1-D of size :math:`n`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): matrix to be multiplied + vec (Tensor): vector to be multiplied + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> mat = torch.randn(2, 3) + >>> vec = torch.randn(3) + >>> torch.mv(mat, vec) + tensor([ 1.0404, -0.6361]) + """ + +def mvlgamma( + input: Tensor, + p: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + mvlgamma(input, p, *, out=None) -> Tensor + + Alias for :func:`torch.special.multigammaln`. + """ + +def nan_to_num( + input: Tensor, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None) -> Tensor + + Replaces :literal:`NaN`, positive infinity, and negative infinity values in :attr:`input` + with the values specified by :attr:`nan`, :attr:`posinf`, and :attr:`neginf`, respectively. + By default, :literal:`NaN`\ s are replaced with zero, positive infinity is replaced with the + greatest finite value representable by :attr:`input`'s dtype, and negative infinity + is replaced with the least finite value representable by :attr:`input`'s dtype. + + Args: + input (Tensor): the input tensor. + nan (Number, optional): the value to replace :literal:`NaN`\s with. Default is zero. + posinf (Number, optional): if a Number, the value to replace positive infinity values with. + If None, positive infinity values are replaced with the greatest finite value representable by :attr:`input`'s dtype. + Default is None. + neginf (Number, optional): if a Number, the value to replace negative infinity values with. + If None, negative infinity values are replaced with the lowest finite value representable by :attr:`input`'s dtype. + Default is None. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + >>> torch.nan_to_num(x) + tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0) + tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00]) + >>> torch.nan_to_num(x, nan=2.0, posinf=1.0) + tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00]) + """ + +def nan_to_num_( + input: Tensor, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, +) -> Tensor: ... 
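+# Illustrative sketch: nan_to_num_ above is the in-place counterpart of
+# torch.nan_to_num; it rewrites ``input`` directly and returns the same tensor
+# object, mirroring the replacement rules documented for nan_to_num:
+#
+#     >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
+#     >>> y = torch.nan_to_num_(x, nan=0.0)
+#     >>> y is x
+#     True
+#     >>> x
+#     tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])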
+def nanmean( + input: Tensor, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes the mean of all `non-NaN` elements along the specified dimensions. + Input must be floating point or complex. + + This function is identical to :func:`torch.mean` when there are no `NaN` values + in the :attr:`input` tensor. In the presence of `NaN`, :func:`torch.mean` will + propagate the `NaN` to the output whereas :func:`torch.nanmean` will ignore the + `NaN` values (`torch.nanmean(a)` is equivalent to `torch.mean(a[~a.isnan()])`). + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor, either of floating point or complex dtype + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + .. seealso:: + + :func:`torch.mean` computes the mean value, propagating `NaN`. + + Example:: + + >>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) + >>> x.mean() + tensor(nan) + >>> x.nanmean() + tensor(1.8000) + >>> x.mean(dim=0) + tensor([ nan, 1.5000, 2.5000]) + >>> x.nanmean(dim=0) + tensor([1.0000, 1.5000, 2.5000]) + + # If all elements in the reduced dimensions are NaN then the result is NaN + >>> torch.tensor([torch.nan]).nanmean() + tensor(nan) + """ + +@overload +def nanmedian(input: Tensor) -> Tensor: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. 
If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanmedian( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanmedian( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] 
| list[Tensor] | None = None, +) -> torch.return_types.nanmedian: + r""" + nanmedian(input) -> Tensor + + Returns the median of the values in :attr:`input`, ignoring ``NaN`` values. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in :attr:`input`. + When :attr:`input` has one or more ``NaN`` values, :func:`torch.median` will always return ``NaN``, + while this function will return the median of the non-``NaN`` elements in :attr:`input`. + If all the elements in :attr:`input` are ``NaN`` it will also return ``NaN``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.tensor([1, float('nan'), 3, 2]) + >>> a.median() + tensor(nan) + >>> a.nanmedian() + tensor(2.) + + .. function:: nanmedian(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) + :noindex: + + Returns a namedtuple ``(values, indices)`` where ``values`` contains the median of each row of :attr:`input` + in the dimension :attr:`dim`, ignoring ``NaN`` values, and ``indices`` contains the index of the median values + found in the dimension :attr:`dim`. + + This function is identical to :func:`torch.median` when there are no ``NaN`` values in a reduced row. When a reduced row has + one or more ``NaN`` values, :func:`torch.median` will always reduce it to ``NaN``, while this function will reduce it to the + median of the non-``NaN`` elements. If all the elements in a reduced row are ``NaN`` then it will be reduced to ``NaN``, too. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + out ((Tensor, Tensor), optional): The first tensor will be populated with the median values and the second + tensor, which must have dtype long, with their indices in the dimension + :attr:`dim` of :attr:`input`. + + Example:: + + >>> a = torch.tensor([[2, 3, 1], [float('nan'), 1, float('nan')]]) + >>> a + tensor([[2., 3., 1.], + [nan, 1., nan]]) + >>> a.median(0) + torch.return_types.median(values=tensor([nan, 1., nan]), indices=tensor([1, 1, 1])) + >>> a.nanmedian(0) + torch.return_types.nanmedian(values=tensor([2., 1., 1.]), indices=tensor([0, 1, 0])) + """ + +@overload +def nanquantile( + input: Tensor, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. 
+ out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + +@overload +def nanquantile( + input: Tensor, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + nanquantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + This is a variant of :func:`torch.quantile` that "ignores" ``NaN`` values, + computing the quantiles :attr:`q` as if ``NaN`` values in :attr:`input` did + not exist. If all values in a reduced row are ``NaN`` then the quantiles for + that reduction will be ``NaN``. See the documentation for :func:`torch.quantile`. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of quantile values in the range [0, 1] + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([float('nan'), 1, 2]) + >>> t.quantile(0.5) + tensor(nan) + >>> t.nanquantile(0.5) + tensor(1.5000) + >>> t = torch.tensor([[float('nan'), float('nan')], [1, 2]]) + >>> t + tensor([[nan, nan], + [1., 2.]]) + >>> t.nanquantile(0.5, dim=0) + tensor([1., 2.]) + >>> t.nanquantile(0.5, dim=1) + tensor([ nan, 1.5000]) + """ + +def nansum( + input: Tensor, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + nansum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements, treating Not a Numbers (NaNs) as zero. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.tensor([1., 2., float('nan'), 4.]) + >>> torch.nansum(a) + tensor(7.) + + .. function:: nansum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`, treating Not a Numbers (NaNs) as zero. + If :attr:`dim` is a list of dimensions, reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> torch.nansum(torch.tensor([1., float("nan")])) + tensor(1.) + >>> a = torch.tensor([[1, 2], [3., float("nan")]]) + >>> torch.nansum(a) + tensor(6.) + >>> torch.nansum(a, dim=0) + tensor([4., 2.]) + >>> torch.nansum(a, dim=1) + tensor([3., 3.]) + """ + +@overload +def narrow( + input: Tensor, + dim: _int, + start: Tensor, + length: _int | SymInt, +) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + +@overload +def narrow( + input: Tensor, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, +) -> Tensor: + r""" + narrow(input, dim, start, length) -> Tensor + + Returns a new tensor that is a narrowed version of :attr:`input` tensor. The + dimension :attr:`dim` is input from :attr:`start` to ``start + length``. The + returned tensor and :attr:`input` tensor share the same underlying storage. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) + """ + +def narrow_copy( + input: Tensor, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + narrow_copy(input, dim, start, length, *, out=None) -> Tensor + + Same as :meth:`Tensor.narrow` except this returns a copy rather + than shared storage. This is primarily for sparse tensors, which + do not have a shared-storage narrow method. + + Args: + input (Tensor): the tensor to narrow + dim (int): the dimension along which to narrow + start (int): index of the element to start the narrowed dimension from. 
Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> torch.narrow_copy(x, 0, 0, 2) + tensor([[ 1, 2, 3], + [ 4, 5, 6]]) + >>> torch.narrow_copy(x, 1, 1, 2) + tensor([[ 2, 3], + [ 5, 6], + [ 8, 9]]) + >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) + >>> torch.narrow_copy(s, 0, 0, 1) + tensor(indices=tensor([[0, 0], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + + .. seealso:: + + :func:`torch.narrow` for a non copy variant + """ + +def native_batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + running_mean: Tensor | None, + running_var: Tensor | None, + training: _bool, + momentum: _float, + eps: _float, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> tuple[Tensor, Tensor, Tensor]: ... +def native_channel_shuffle(input: Tensor, groups: _int | SymInt) -> Tensor: ... +def native_dropout( + input: Tensor, + p: _float, + train: _bool | None, +) -> tuple[Tensor, Tensor]: ... +def native_group_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + N: _int | SymInt, + C: _int | SymInt, + HxW: _int | SymInt, + group: _int, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +def native_layer_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None, + bias: Tensor | None, + eps: _float, +) -> tuple[Tensor, Tensor, Tensor]: ... +@overload +def native_norm( + input: Tensor, + p: Number | _complex | None, + dim: _int | _size, + keepdim: _bool, + dtype: _dtype | None, +) -> Tensor: ... +@overload +def native_norm(input: Tensor, p: Number | _complex = 2) -> Tensor: ... +@overload +def ne( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + +@overload +def ne( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ne(input, other, *, out=None) -> Tensor + + Computes :math:`\text{input} \neq \text{other}` element-wise. + + + The second argument can be a number or a tensor whose shape is + :ref:`broadcastable ` with the first argument. + + Args: + input (Tensor): the tensor to compare + other (Tensor or float): the tensor or value to compare + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Returns: + A boolean tensor that is True where :attr:`input` is not equal to :attr:`other` and False elsewhere + + Example:: + + >>> torch.ne(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]])) + tensor([[False, True], [True, False]]) + """ + +def neg(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + neg(input, *, out=None) -> Tensor + + Returns a new tensor with the negative of the elements of :attr:`input`. + + .. math:: + \text{out} = -1 \times \text{input} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(5) + >>> a + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.neg(a) + tensor([-0.0090, 0.2262, 0.0682, 0.2866, -0.3940]) + """ + +def neg_(input: Tensor) -> Tensor: ... +def negative(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + negative(input, *, out=None) -> Tensor + + Alias for :func:`torch.neg` + """ + +def negative_(input: Tensor) -> Tensor: ... +def nextafter( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + nextafter(input, other, *, out=None) -> Tensor + + Return the next floating-point value after :attr:`input` towards :attr:`other`, elementwise. + + The shapes of ``input`` and ``other`` must be + :ref:`broadcastable `. + + Args: + input (Tensor): the first input tensor + other (Tensor): the second input tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> eps = torch.finfo(torch.float32).eps + >>> torch.nextafter(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps]) + tensor([True, True]) + """ + +@overload +def nonzero( + input: Tensor, + *, + as_tuple: Literal[False] = False, + out: Tensor | None = None, +) -> Tensor: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. 
+ + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. + + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + +@overload +def nonzero( + input: Tensor, + *, + as_tuple: Literal[True], +) -> tuple[Tensor, ...]: + r""" + nonzero(input, *, out=None, as_tuple=False) -> LongTensor or tuple of LongTensors + + .. note:: + :func:`torch.nonzero(..., as_tuple=False) ` (default) returns a + 2-D tensor where each row is the index for a nonzero value. + + :func:`torch.nonzero(..., as_tuple=True) ` returns a tuple of 1-D + index tensors, allowing for advanced indexing, so ``x[x.nonzero(as_tuple=True)]`` + gives all nonzero values of tensor ``x``. Of the returned tuple, each index tensor + contains nonzero indices for a certain dimension. + + See below for more details on the two behaviors. + + When :attr:`input` is on CUDA, :func:`torch.nonzero() ` causes + host-device synchronization. + + **When** :attr:`as_tuple` **is** ``False`` **(default)**: + + Returns a tensor containing the indices of all non-zero elements of + :attr:`input`. Each row in the result contains the indices of a non-zero + element in :attr:`input`. The result is sorted lexicographically, with + the last index changing the fastest (C-style). + + If :attr:`input` has :math:`n` dimensions, then the resulting indices tensor + :attr:`out` is of size :math:`(z \times n)`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + **When** :attr:`as_tuple` **is** ``True``: + + Returns a tuple of 1-D tensors, one for each dimension in :attr:`input`, + each containing the indices (in that dimension) of all non-zero elements of + :attr:`input` . + + If :attr:`input` has :math:`n` dimensions, then the resulting tuple contains :math:`n` + tensors of size :math:`z`, where :math:`z` is the total number of + non-zero elements in the :attr:`input` tensor. + + As a special case, when :attr:`input` has zero dimensions and a nonzero scalar + value, it is treated as a one-dimensional tensor with one element. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (LongTensor, optional): the output tensor containing indices + + Returns: + LongTensor or tuple of LongTensor: If :attr:`as_tuple` is ``False``, the output + tensor containing indices. If :attr:`as_tuple` is ``True``, one 1-D tensor for + each dimension, containing the indices of each nonzero element along that + dimension. 
+ + Example:: + + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) + tensor([[ 0], + [ 1], + [ 2], + [ 4]]) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]])) + tensor([[ 0, 0], + [ 1, 1], + [ 2, 2], + [ 3, 3]]) + >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) + (tensor([0, 1, 2, 4]),) + >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], + ... [0.0, 0.4, 0.0, 0.0], + ... [0.0, 0.0, 1.2, 0.0], + ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) + (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) + >>> torch.nonzero(torch.tensor(5), as_tuple=True) + (tensor([0]),) + """ + +def nonzero_static( + input: Tensor, + *, + size: _int | SymInt, + fill_value: _int = -1, + out: Tensor | None = None, +) -> Tensor: ... +def norm_except_dim(v: Tensor, pow: _int = 2, dim: _int = 0) -> Tensor: ... +@overload +def normal( + mean: Tensor, + std: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. 
+ + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: Tensor, + std: _float = 1, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: _float, + std: Tensor, + *, + generator: Generator | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def normal( + mean: _float, + std: _float, + size: Sequence[_int | SymInt], + *, + generator: Generator | None = None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + normal(mean, std, *, generator=None, out=None) -> Tensor + + Returns a tensor of random numbers drawn from separate normal distributions + whose mean and standard deviation are given. + + The :attr:`mean` is a tensor with the mean of + each output element's normal distribution + + The :attr:`std` is a tensor with the standard deviation of + each output element's normal distribution + + The shapes of :attr:`mean` and :attr:`std` don't need to match, but the + total number of elements in each tensor need to be the same. + + .. note:: When the shapes do not match, the shape of :attr:`mean` + is used as the shape for the returned output tensor + + .. note:: When :attr:`std` is a CUDA tensor, this function synchronizes + its device with the CPU. + + Args: + mean (Tensor): the tensor of per-element means + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=torch.arange(1., 11.), std=torch.arange(1, 0, -0.1)) + tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, + 8.0505, 8.1408, 9.0563, 10.0566]) + + .. function:: normal(mean=0.0, std, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means are shared among all drawn + elements. + + Args: + mean (float, optional): the mean for all distributions + std (Tensor): the tensor of per-element standard deviations + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) + tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) + + .. function:: normal(mean, std=1.0, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the standard deviations are shared among + all drawn elements. + + Args: + mean (Tensor): the tensor of per-element means + std (float, optional): the standard deviation for all distributions + + Keyword args: + out (Tensor, optional): the output tensor + + Example:: + + >>> torch.normal(mean=torch.arange(1., 6.)) + tensor([ 1.1552, 2.6148, 2.6535, 5.8318, 4.2361]) + + .. function:: normal(mean, std, size, *, out=None) -> Tensor + :noindex: + + Similar to the function above, but the means and standard deviations are shared + among all drawn elements. The resulting tensor has size given by :attr:`size`. + + Args: + mean (float): the mean for all distributions + std (float): the standard deviation for all distributions + size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.normal(2, 3, size=(1, 4)) + tensor([[-1.3987, -1.9544, 3.6048, 0.7909]]) + """ + +@overload +def not_equal( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. 
+ """ + +@overload +def not_equal( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + not_equal(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.ne`. + """ + +@overload +def nuclear_norm( + input: Tensor, + dim: _int | _size, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +@overload +def nuclear_norm( + input: Tensor, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: ... +def numel(self: Tensor) -> _int: + r""" + numel(input: Tensor) -> int + + Returns the total number of elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> a = torch.randn(1, 2, 3, 4, 5) + >>> torch.numel(a) + 120 + >>> a = torch.zeros(4,4) + >>> torch.numel(a) + 16 + """ + +@overload +def ones( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +@overload +def ones( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword arguments: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
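+
+ .. note:: The shape may be given either as separate integers or as a single sequence; an
+ equivalent, illustrative call (output formatting mirrors the example below):
+
+ >>> torch.ones((2, 3))
+ tensor([[ 1., 1., 1.],
+ [ 1., 1., 1.]])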
+ + Example:: + + >>> torch.ones(2, 3) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + + >>> torch.ones(5) + tensor([ 1., 1., 1., 1., 1.]) + """ + +def ones_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `1`, with the same size as + :attr:`input`. ``torch.ones_like(input)`` is equivalent to + ``torch.ones(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.ones_like(input, out=output)`` is equivalent to + ``torch.ones(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword arguments: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.ones_like(input) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + +def orgqr( + input: Tensor, + input2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + orgqr(input, tau) -> Tensor + + Alias for :func:`torch.linalg.householder_product`. + """ + +def ormqr( + input: Tensor, + input2: Tensor, + input3: Tensor, + left: _bool = True, + transpose: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor + + Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. + + Multiplies a :math:`m \times n` matrix `C` (given by :attr:`other`) with a matrix `Q`, + where `Q` is represented using Householder reflectors `(input, tau)`. + See `Representation of Orthogonal or Unitary Matrices`_ for further details. + + If :attr:`left` is `True` then `op(Q)` times `C` is computed, otherwise the result is `C` times `op(Q)`. + When :attr:`left` is `True`, the implicit matrix `Q` has size :math:`m \times m`. + It has size :math:`n \times n` otherwise. + If :attr:`transpose` is `True` then `op` is the conjugate transpose operation, otherwise it's a no-op. + + Supports inputs of float, double, cfloat and cdouble dtypes. + Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + + .. seealso:: + :func:`torch.geqrf` can be used to form the Householder representation `(input, tau)` of matrix `Q` + from the QR decomposition. + + .. 
note:: + This function supports backward but it is only fast when ``(input, tau)`` do not require gradients + and/or ``tau.size(-1)`` is very small. + `` + + Args: + input (Tensor): tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions + and `mn` equals to `m` or `n` depending on the :attr:`left`. + tau (Tensor): tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions. + other (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + left (bool): controls the order of multiplication. + transpose (bool): controls whether the matrix `Q` is conjugate transposed or not. + + Keyword args: + out (Tensor, optional): the output Tensor. Ignored if `None`. Default: `None`. + + .. _Representation of Orthogonal or Unitary Matrices: + https://www.netlib.org/lapack/lug/node128.html + """ + +def outer( + input: Tensor, + vec2: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + outer(input, vec2, *, out=None) -> Tensor + + Outer product of :attr:`input` and :attr:`vec2`. + If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of + size :math:`m`, then :attr:`out` must be a matrix of size :math:`(n \times m)`. + + .. note:: This function does not :ref:`broadcast `. + + Args: + input (Tensor): 1-D input vector + vec2 (Tensor): 1-D input vector + + Keyword args: + out (Tensor, optional): optional output matrix + + Example:: + + >>> v1 = torch.arange(1., 5.) + >>> v2 = torch.arange(1., 4.) + >>> torch.outer(v1, v2) + tensor([[ 1., 2., 3.], + [ 2., 4., 6.], + [ 3., 6., 9.], + [ 4., 8., 12.]]) + """ + +def pairwise_distance( + x1: Tensor, + x2: Tensor, + p: _float = 2, + eps: _float = 1e-06, + keepdim: _bool = False, +) -> Tensor: ... +def pdist(input: Tensor, p: _float = 2) -> Tensor: ... +def permute(input: Tensor, dims: _size) -> Tensor: + r""" + permute(input, dims) -> Tensor + + Returns a view of the original tensor :attr:`input` with its dimensions permuted. + + Args: + input (Tensor): the input tensor. + dims (tuple of int): The desired ordering of dimensions + + Example: + >>> x = torch.randn(2, 3, 5) + >>> x.size() + torch.Size([2, 3, 5]) + >>> torch.permute(x, (2, 0, 1)).size() + torch.Size([5, 2, 3]) + """ + +def permute_copy( + input: Tensor, + dims: _size, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.permute`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def pinverse(input: Tensor, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse(input, rcond=1e-15) -> Tensor + + Alias for :func:`torch.linalg.pinv` + """ + +def pixel_shuffle(input: Tensor, upscale_factor: _int) -> Tensor: ... +def pixel_unshuffle(input: Tensor, downscale_factor: _int) -> Tensor: ... +def poisson(input: Tensor, generator: Generator | None = None) -> Tensor: + r""" + poisson(input, generator=None) -> Tensor + + Returns a tensor of the same size as :attr:`input` with each element + sampled from a Poisson distribution with rate parameter given by the corresponding + element in :attr:`input` i.e., + + .. math:: + \text{out}_i \sim \text{Poisson}(\text{input}_i) + + :attr:`input` must be non-negative. 
+ + Args: + input (Tensor): the input tensor containing the rates of the Poisson distribution + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + + Example:: + + >>> rates = torch.rand(4, 4) * 5 # rate parameter between 0 and 5 + >>> torch.poisson(rates) + tensor([[9., 1., 3., 5.], + [8., 6., 6., 0.], + [0., 4., 5., 3.], + [2., 1., 4., 2.]]) + """ + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: _bool, + full: _bool, + eps: _float, + reduction: _int, +) -> Tensor: ... +def polar( + abs: Tensor, + angle: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + polar(abs, angle, *, out=None) -> Tensor + + Constructs a complex tensor whose elements are Cartesian coordinates + corresponding to the polar coordinates with absolute value :attr:`abs` and angle + :attr:`angle`. + + .. math:: + \text{out} = \text{abs} \cdot \cos(\text{angle}) + \text{abs} \cdot \sin(\text{angle}) \cdot j + + .. note:: + `torch.polar` is similar to + `std::polar `_ + and does not compute the polar decomposition + of a complex tensor like Python's `cmath.polar` and SciPy's `linalg.polar` do. + The behavior of this function is undefined if `abs` is negative or NaN, or if `angle` is + infinite. + + + Args: + abs (Tensor): The absolute value the complex tensor. Must be float or double. + angle (Tensor): The angle of the complex tensor. Must be same dtype as + :attr:`abs`. + + Keyword args: + out (Tensor): If the inputs are ``torch.float32``, must be + ``torch.complex64``. If the inputs are ``torch.float64``, must be + ``torch.complex128``. + + Example:: + + >>> import numpy as np + >>> abs = torch.tensor([1, 2], dtype=torch.float64) + >>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64) + >>> z = torch.polar(abs, angle) + >>> z + tensor([(0.0000+1.0000j), (-1.4142-1.4142j)], dtype=torch.complex128) + """ + +def polygamma( + n: _int, + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + polygamma(n, input, *, out=None) -> Tensor + + Alias for :func:`torch.special.polygamma`. + """ + +def positive(input: Tensor) -> Tensor: + r""" + positive(input) -> Tensor + + Returns :attr:`input`. + Throws a runtime error if :attr:`input` is a bool tensor. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.randn(5) + >>> t + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + >>> torch.positive(t) + tensor([ 0.0090, -0.2262, -0.0682, -0.2866, 0.3940]) + """ + +@overload +def pow( + input: Tensor, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +@overload +def pow( + self: Number | _complex, + exponent: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +@overload +def pow( + input: Tensor, + exponent: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + pow(input, exponent, *, out=None) -> Tensor + + Takes the power of each element in :attr:`input` with :attr:`exponent` and + returns a tensor with the result. + + :attr:`exponent` can be either a single ``float`` number or a `Tensor` + with the same number of elements as :attr:`input`. + + When :attr:`exponent` is a scalar value, the operation applied is: + + .. math:: + \text{out}_i = x_i ^ \text{exponent} + + When :attr:`exponent` is a tensor, the operation applied is: + + .. 
math:: + \text{out}_i = x_i ^ {\text{exponent}_i} + + When :attr:`exponent` is a tensor, the shapes of :attr:`input` + and :attr:`exponent` must be :ref:`broadcastable `. + + Args: + input (Tensor): the input tensor. + exponent (float or tensor): the exponent value + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.4331, 1.2475, 0.6834, -0.2791]) + >>> torch.pow(a, 2) + tensor([ 0.1875, 1.5561, 0.4670, 0.0779]) + >>> exp = torch.arange(1., 5.) + + >>> a = torch.arange(1., 5.) + >>> a + tensor([ 1., 2., 3., 4.]) + >>> exp + tensor([ 1., 2., 3., 4.]) + >>> torch.pow(a, exp) + tensor([ 1., 4., 27., 256.]) + + .. function:: pow(self, exponent, *, out=None) -> Tensor + :noindex: + + :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. + The returned tensor :attr:`out` is of the same shape as :attr:`exponent` + + The operation applied is: + + .. math:: + \text{out}_i = \text{self} ^ {\text{exponent}_i} + + Args: + self (float): the scalar base value for the power operation + exponent (Tensor): the exponent tensor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> exp = torch.arange(1., 5.) + >>> base = 2 + >>> torch.pow(base, exp) + tensor([ 2., 4., 8., 16.]) + """ + +def prelu(input: Tensor, weight: Tensor) -> Tensor: ... +@overload +def prod(input: Tensor, *, dtype: _dtype | None = None) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +@overload +def prod( + input: Tensor, + dim: _int, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. 
+ + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +@overload +def prod( + input: Tensor, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor + + Returns the product of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[-0.8020, 0.5428, -1.5854]]) + >>> torch.prod(a) + tensor(0.6902) + + .. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the product of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in + the output tensor having 1 fewer dimension than :attr:`input`. + + Args: + input (Tensor): the input tensor. + + dim (int, optional): the dimension to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. 
+ + Example:: + + >>> a = torch.randn(4, 2) + >>> a + tensor([[ 0.5261, -0.3837], + [ 1.1857, -0.2498], + [-1.1646, 0.0705], + [ 1.1131, -1.0629]]) + >>> torch.prod(a, 1) + tensor([-0.2018, -0.2962, -0.0821, -1.1831]) + """ + +def promote_types(type1: _dtype, type2: _dtype) -> _dtype: + r""" + promote_types(type1, type2) -> dtype + + Returns the :class:`torch.dtype` with the smallest size and scalar kind that is + not smaller nor of lower kind than either `type1` or `type2`. See type promotion + :ref:`documentation ` for more information on the type + promotion logic. + + Args: + type1 (:class:`torch.dtype`) + type2 (:class:`torch.dtype`) + + Example:: + + >>> torch.promote_types(torch.int32, torch.float32) + torch.float32 + >>> torch.promote_types(torch.uint8, torch.long) + torch.long + """ + +def put( + input: Tensor, + index: Tensor, + source: Tensor, + accumulate: _bool = False, +) -> Tensor: ... +def q_per_channel_axis(input: Tensor) -> _int: ... +def q_per_channel_scales(input: Tensor) -> Tensor: ... +def q_per_channel_zero_points(input: Tensor) -> Tensor: ... +def q_scale(input: Tensor) -> _float: ... +def q_zero_point(input: Tensor) -> _int: ... +def qr( + input: Tensor, + some: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.qr: + r""" + qr(input: Tensor, some: bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None]) -> (Tensor, Tensor) + + Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, + and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` + with :math:`Q` being an orthogonal matrix or batch of orthogonal matrices and + :math:`R` being an upper triangular matrix or batch of upper triangular matrices. + + If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization. + Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization. + + .. warning:: + + :func:`torch.qr` is deprecated in favor of :func:`torch.linalg.qr` + and will be removed in a future PyTorch release. The boolean parameter :attr:`some` has been + replaced with a string parameter :attr:`mode`. + + ``Q, R = torch.qr(A)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A) + + ``Q, R = torch.qr(A, some=False)`` should be replaced with + + .. code:: python + + Q, R = torch.linalg.qr(A, mode="complete") + + .. warning:: + If you plan to backpropagate through QR, note that the current backward implementation + is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` + columns of :attr:`input` are linearly independent. + This behavior will probably change once QR supports pivoting. + + .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, + and may produce different (valid) decompositions on different device types + or different platforms. + + Args: + input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more + batch dimensions consisting of matrices of dimension :math:`m \times n`. + some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for + complete QR decomposition. If `k = min(m, n)` then: + + * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default) + + * ``'some=False'``: returns `(Q, R)` with dimensions (m, m), (m, n) + + Keyword args: + out (tuple, optional): tuple of `Q` and `R` tensors. + The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above. 
+ + Example:: + + >>> a = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> q, r = torch.qr(a) + >>> q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> r + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> torch.mm(q, r).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> torch.mm(q.t(), q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> a = torch.randn(3, 4, 5) + >>> q, r = torch.qr(a, some=False) + >>> torch.allclose(torch.matmul(q, r), a) + True + >>> torch.allclose(torch.matmul(q.mT, q), torch.eye(5)) + True + """ + +@overload +def quantile( + input: Tensor, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) 
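+ >>> # When q is a 1D tensor and dim is None, the flattened input is reduced and the
+ >>> # output has one value per requested quantile (a small illustrative check):
+ >>> torch.quantile(torch.arange(4.), torch.tensor([0.0, 1.0]))
+ tensor([0., 3.])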
+ """ + +@overload +def quantile( + input: Tensor, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + out: Tensor | None = None, +) -> Tensor: + r""" + quantile(input, q, dim=None, keepdim=False, *, interpolation='linear', out=None) -> Tensor + + Computes the q-th quantiles of each row of the :attr:`input` tensor along the dimension :attr:`dim`. + + To compute the quantile, we map q in [0, 1] to the range of indices [0, n] to find the location + of the quantile in the sorted input. If the quantile lies between two data points ``a < b`` with + indices ``i`` and ``j`` in the sorted order, result is computed according to the given + :attr:`interpolation` method as follows: + + - ``linear``: ``a + (b - a) * fraction``, where ``fraction`` is the fractional part of the computed quantile index. + - ``lower``: ``a``. + - ``higher``: ``b``. + - ``nearest``: ``a`` or ``b``, whichever's index is closer to the computed quantile index (rounding down for .5 fractions). + - ``midpoint``: ``(a + b) / 2``. + + If :attr:`q` is a 1D tensor, the first dimension of the output represents the quantiles and has size + equal to the size of :attr:`q`, the remaining dimensions are what remains from the reduction. + + .. note:: + By default :attr:`dim` is ``None`` resulting in the :attr:`input` tensor being flattened before computation. + + Args: + input (Tensor): the input tensor. + q (float or Tensor): a scalar or 1D tensor of values in the range [0, 1]. + + dim (int, optional): the dimension to reduce. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword arguments: + interpolation (str): interpolation method to use when the desired quantile lies between two data points. + Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. + Default is ``linear``. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(2, 3) + >>> a + tensor([[ 0.0795, -1.2117, 0.9765], + [ 1.1707, 0.6706, 0.4884]]) + >>> q = torch.tensor([0.25, 0.5, 0.75]) + >>> torch.quantile(a, q, dim=1, keepdim=True) + tensor([[[-0.5661], + [ 0.5795]], + + [[ 0.0795], + [ 0.6706]], + + [[ 0.5280], + [ 0.9206]]]) + >>> torch.quantile(a, q, dim=1, keepdim=True).shape + torch.Size([3, 2, 1]) + >>> a = torch.arange(4.) + >>> a + tensor([0., 1., 2., 3.]) + >>> torch.quantile(a, 0.6, interpolation='linear') + tensor(1.8000) + >>> torch.quantile(a, 0.6, interpolation='lower') + tensor(1.) + >>> torch.quantile(a, 0.6, interpolation='higher') + tensor(2.) + >>> torch.quantile(a, 0.6, interpolation='midpoint') + tensor(1.5000) + >>> torch.quantile(a, 0.6, interpolation='nearest') + tensor(2.) + >>> torch.quantile(a, 0.4, interpolation='nearest') + tensor(1.) + """ + +def quantize_per_channel( + input: Tensor, + scales: Tensor, + zero_points: Tensor, + axis: _int, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor + + Converts a float tensor to a per-channel quantized tensor with given scales and zero points. + + Arguments: + input (Tensor): float tensor to quantize + scales (Tensor): float 1D tensor of scales to use, size should match ``input.size(axis)`` + zero_points (int): integer 1D tensor of offset to use, size should match ``input.size(axis)`` + axis (int): dimension on which apply per-channel quantization + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor + + Example:: + + >>> x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]]) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8) + tensor([[-1., 0.], + [ 1., 2.]], size=(2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_channel_affine, + scale=tensor([0.1000, 0.0100], dtype=torch.float64), + zero_point=tensor([10, 0]), axis=0) + >>> torch.quantize_per_channel(x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8).int_repr() + tensor([[ 0, 10], + [100, 200]], dtype=torch.uint8) + """ + +@overload +def quantize_per_tensor( + input: Tensor, + scale: Tensor, + zero_point: Tensor, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +@overload +def quantize_per_tensor( + input: Tensor, + scale: _float, + zero_point: _int, + dtype: _dtype, +) -> Tensor: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. 
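+
+ .. note:: Per-tensor affine quantization stores, roughly, ``round(input / scale) + zero_point``
+ clamped to the representable range of :attr:`dtype`, and ``Tensor.dequantize`` maps it back as
+ ``(int_repr - zero_point) * scale``. A small sketch of that round trip, reusing the
+ scale and zero_point from the example below:
+
+ >>> q = torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8)
+ >>> q.dequantize()
+ tensor([-1.,  0.,  1.,  2.])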
+ + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +@overload +def quantize_per_tensor( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + scales: Tensor, + zero_points: Tensor, + dtype: _dtype, +) -> tuple[Tensor, ...]: + r""" + quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor + + Converts a float tensor to a quantized tensor with given scale and zero point. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + scale (float or Tensor): scale to apply in quantization formula + zero_point (int or Tensor): offset in integer value that maps to float zero + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8``, ``torch.qint32`` + + Returns: + Tensor: A newly quantized tensor or list of quantized tensors. + + Example:: + + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8).int_repr() + tensor([ 0, 10, 20, 30], dtype=torch.uint8) + >>> torch.quantize_per_tensor([torch.tensor([-1.0, 0.0]), torch.tensor([-2.0, 2.0])], + >>> torch.tensor([0.1, 0.2]), torch.tensor([10, 20]), torch.quint8) + (tensor([-1., 0.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.1, zero_point=10), + tensor([-2., 2.], size=(2,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=20)) + >>> torch.quantize_per_tensor(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.tensor(0.1), torch.tensor(10), torch.quint8) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.10, zero_point=10) + """ + +def quantize_per_tensor_dynamic( + input: Tensor, + dtype: _dtype, + reduce_range: _bool, +) -> Tensor: + r""" + quantize_per_tensor_dynamic(input, dtype, reduce_range) -> Tensor + + Converts a float tensor to a quantized tensor with scale and zero_point calculated + dynamically based on the input. + + Arguments: + input (Tensor): float tensor or list of tensors to quantize + dtype (:class:`torch.dtype`): the desired data type of returned tensor. 
+ Has to be one of the quantized dtypes: ``torch.quint8``, ``torch.qint8`` + reduce_range (bool): a flag to indicate whether to reduce the range of quantized + data by 1 bit, it's required to avoid instruction overflow for some hardwares + + Returns: + Tensor: A newly (dynamically) quantized tensor + + Example:: + + >>> t = torch.quantize_per_tensor_dynamic(torch.tensor([-1.0, 0.0, 1.0, 2.0]), torch.quint8, False) + >>> print(t) + tensor([-1., 0., 1., 2.], size=(4,), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.011764705882352941, + zero_point=85) + >>> t.int_repr() + tensor([ 0, 85, 170, 255], dtype=torch.uint8) + """ + +def quantized_batch_norm( + input: Tensor, + weight: Tensor | None, + bias: Tensor | None, + mean: Tensor, + var: Tensor, + eps: _float, + output_scale: _float, + output_zero_point: _int, +) -> Tensor: + r""" + quantized_batch_norm(input, weight=None, bias=None, mean, var, eps, output_scale, output_zero_point) -> Tensor + + Applies batch normalization on a 4D (NCHW) quantized tensor. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + Arguments: + input (Tensor): quantized tensor + weight (Tensor): float tensor that corresponds to the gamma, size C + bias (Tensor): float tensor that corresponds to the beta, size C + mean (Tensor): float mean value in batch normalization, size C + var (Tensor): float tensor for variance, size C + eps (float): a value added to the denominator for numerical stability. + output_scale (float): output quantized tensor scale + output_zero_point (int): output quantized tensor zero_point + + Returns: + Tensor: A quantized tensor with batch normalization applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_batch_norm(qx, torch.ones(2), torch.zeros(2), torch.rand(2), torch.rand(2), 0.00001, 0.2, 2) + tensor([[[[-0.2000, -0.2000], + [ 1.6000, -0.2000]], + + [[-0.4000, -0.4000], + [-0.4000, 0.6000]]], + + + [[[-0.2000, -0.2000], + [-0.2000, -0.2000]], + + [[ 0.6000, -0.4000], + [ 0.6000, -0.4000]]]], size=(2, 2, 2, 2), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=0.2, zero_point=2) + """ + +def quantized_gru_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def quantized_lstm_cell( + input: Tensor, + hx: tuple[Tensor, ...] | list[Tensor] | None, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> tuple[Tensor, Tensor]: ... +def quantized_max_pool1d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: + r""" + quantized_max_pool1d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 1D max pooling over an input quantized tensor composed of several input planes. 
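+
+ .. note:: The output length follows the usual max-pooling arithmetic (a sketch; see
+ :class:`torch.nn.MaxPool1d` for the authoritative formula):
+ ``L_out = floor((L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)``,
+ with ``stride`` defaulting to ``kernel_size`` when it is left empty. For instance:
+
+ >>> qx = torch.quantize_per_tensor(torch.rand(1, 8), 0.1, 0, torch.quint8)
+ >>> torch.quantized_max_pool1d(qx, [2], [2]).shape
+ torch.Size([1, 4])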
+ + Arguments: + input (Tensor): quantized tensor + kernel_size (list of int): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool1d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool1d(qx, [2]) + tensor([[0.0000], + [1.5000]], size=(2, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + +def quantized_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: + r""" + quantized_max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1, ceil_mode=False) -> Tensor + + Applies a 2D max pooling over an input quantized tensor composed of several input planes. + + Arguments: + input (Tensor): quantized tensor + kernel_size (``list of int``): the size of the sliding window + stride (``list of int``, optional): the stride of the sliding window + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 + ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. + Defaults to False. + + + Returns: + Tensor: A quantized tensor with max_pool2d applied. + + Example:: + + >>> qx = torch.quantize_per_tensor(torch.rand(2, 2, 2, 2), 1.5, 3, torch.quint8) + >>> torch.quantized_max_pool2d(qx, [2,2]) + tensor([[[[1.5000]], + + [[1.5000]]], + + + [[[0.0000]], + + [[0.0000]]]], size=(2, 2, 1, 1), dtype=torch.quint8, + quantization_scheme=torch.per_tensor_affine, scale=1.5, zero_point=3) + """ + +def quantized_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size = (), + padding: _int | _size = 0, + dilation: _int | _size = 1, + ceil_mode: _bool = False, +) -> Tensor: ... +def quantized_rnn_relu_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def quantized_rnn_tanh_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor, + b_hh: Tensor, + packed_ih: Tensor, + packed_hh: Tensor, + col_offsets_ih: Tensor, + col_offsets_hh: Tensor, + scale_ih: Number | _complex, + scale_hh: Number | _complex, + zero_point_ih: Number | _complex, + zero_point_hh: Number | _complex, +) -> Tensor: ... +def rad2deg(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + rad2deg(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with each of the elements of :attr:`input` + converted from angles in radians to degrees. + + Args: + input (Tensor): the input tensor. + + Keyword arguments: + out (Tensor, optional): the output tensor. 
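+
+ .. note:: The conversion is the elementwise identity ``degrees = radians * 180 / pi``; for
+ instance, ``3.142`` radians maps to roughly ``180.02`` degrees, as in the output below.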
+ + Example:: + + >>> a = torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]]) + >>> torch.rad2deg(a) + tensor([[ 180.0233, -180.0233], + [ 359.9894, -359.9894], + [ 89.9544, -89.9544]]) + """ + +def rad2deg_(input: Tensor) -> Tensor: ... +@overload +def rand( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
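+
+    .. note::
+        Passing a seeded :class:`torch.Generator` makes the draw reproducible.
+        A minimal sketch (the seed ``0`` below is arbitrary)::
+
+            >>> g = torch.Generator().manual_seed(0)
+            >>> x = torch.rand(2, 3, generator=g)
+            >>> torch.equal(x, torch.rand(2, 3, generator=torch.Generator().manual_seed(0)))
+            True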
+ + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + size: Sequence[_int | SymInt], + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. 
+ Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +@overload +def rand( + *size: _int | SymInt, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a tensor filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)` + + The shape of the tensor is defined by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.rand(4) + tensor([ 0.5204, 0.2503, 0.3525, 0.5673]) + >>> torch.rand(2, 3) + tensor([[ 0.8237, 0.5781, 0.6879], + [ 0.3816, 0.7249, 0.0998]]) + """ + +def rand_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a uniform distribution on the interval :math:`[0, 1)`. + ``torch.rand_like(input)`` is equivalent to + ``torch.rand(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. 
+ Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint( + low: _int, + high: _int, + size: _size, + *, + generator: Generator | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int, + size: _size, + *, + generator: Generator | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. 
+ high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
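+
+    .. note::
+        The ``dtype`` keyword overrides the default ``torch.int64`` result type.
+        A minimal sketch (the bounds and shape below are arbitrary)::
+
+            >>> t = torch.randint(0, 256, (2, 2), dtype=torch.uint8)
+            >>> t.dtype
+            torch.uint8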
+ + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + low: _int | SymInt, + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint( + low: _int | SymInt, + high: _int | SymInt, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint(low=0, high, size, \*, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with random integers generated uniformly + between :attr:`low` (inclusive) and :attr:`high` (exclusive). + + The shape of the tensor is defined by the variable argument :attr:`size`. + + .. note:: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + this function returns a tensor with dtype ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ + Example:: + + >>> torch.randint(3, 5, (3,)) + tensor([4, 3, 4]) + + + >>> torch.randint(10, (2, 2)) + tensor([[0, 2], + [5, 5]]) + + + >>> torch.randint(3, 10, (2, 2)) + tensor([[4, 5], + [6, 7]]) + """ + +@overload +def randint_like( + input: Tensor, + low: _int | SymInt, + high: _int | SymInt, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint_like( + input: Tensor, + high: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randint_like( + input: Tensor, + high: _int | SymInt, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randint_like(input, low=0, high, \*, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same shape as Tensor :attr:`input` filled with + random integers generated uniformly between :attr:`low` (inclusive) and + :attr:`high` (exclusive). + + .. note: + With the global dtype default (``torch.float32``), this function returns + a tensor with dtype ``torch.int64``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. 
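+
+    .. note::
+        Requesting a complex ``dtype`` samples from the complex normal
+        distribution described above. A minimal sketch (the shape and dtype
+        below are arbitrary choices)::
+
+            >>> z = torch.randn(3, dtype=torch.cfloat)
+            >>> z.dtype
+            torch.cfloat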
+ + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + generator: Generator | None, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. 
_complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. 
sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. 
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + size: Sequence[_int | SymInt], + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. 
Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +@overload +def randn( + *size: _int | SymInt, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + + Returns a tensor filled with random numbers from a normal distribution + with mean `0` and variance `1` (also called the standard normal + distribution). + + .. math:: + \text{out}_{i} \sim \mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a `complex normal distribution`_ with zero mean and + unit variance as + + .. math:: + \text{out}_{i} \sim \mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\operatorname{Re})` and imaginary + :math:`(\operatorname{Im})` part of :math:`\text{out}_i` as + + .. math:: + \operatorname{Re}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}),\quad + \operatorname{Im}(\text{out}_{i}) \sim \mathcal{N}(0, \frac{1}{2}) + + The shape of the tensor is defined by the variable argument :attr:`size`. + + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randn(4) + tensor([-2.1436, 0.9966, 2.3426, -0.6366]) + >>> torch.randn(2, 3) + tensor([[ 1.5954, 2.8929, -1.0923], + [ 1.1719, -0.4709, -0.1996]]) + + .. _complex normal distribution: https://en.wikipedia.org/wiki/Complex_normal_distribution + """ + +def randn_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor with the same size as :attr:`input` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. 
Please refer to :func:`torch.randn` for the + sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to + ``torch.randn(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + +@overload +def randperm( + n: _int | SymInt, + *, + generator: Generator | None, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + +@overload +def randperm( + n: _int | SymInt, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + randperm(n, *, generator=None, out=None, dtype=torch.int64,layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: ``torch.int64``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. 
+ Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> torch.randperm(4) + tensor([2, 1, 0, 3]) + """ + +def range( + start: Number, + end: Number, + step: Number = 1, + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + range(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1` + with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is + the gap between two values in the tensor. + + .. math:: + \text{out}_{i+1} = \text{out}_i + \text{step}. + + .. warning:: + This function is deprecated and will be removed in a future release because its behavior is inconsistent with + Python's range builtin. Instead, use :func:`torch.arange`, which produces values in [start, end). + + Args: + start (float, optional): the starting value for the set of points. Default: ``0``. + end (float): the ending value for the set of points + step (float, optional): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). If `dtype` is not given, infer the data type from the other input + arguments. If any of `start`, `end`, or `step` are floating-point, the + `dtype` is inferred to be the default dtype, see + :meth:`~torch.get_default_dtype`. Otherwise, the `dtype` is inferred to + be `torch.int64`. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.range(1, 4) + tensor([ 1., 2., 3., 4.]) + >>> torch.range(1, 4, 0.5) + tensor([ 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]) + """ + +def ravel(input: Tensor) -> Tensor: + r""" + ravel(input) -> Tensor + + Return a contiguous flattened tensor. A copy is made only if needed. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> t = torch.tensor([[[1, 2], + ... [3, 4]], + ... [[5, 6], + ... [7, 8]]]) + >>> torch.ravel(t) + tensor([1, 2, 3, 4, 5, 6, 7, 8]) + """ + +def real(input: Tensor) -> Tensor: + r""" + real(input) -> Tensor + + Returns a new tensor containing real values of the :attr:`self` tensor. 
+ The returned tensor and :attr:`self` share the same underlying storage. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j)]) + >>> x.real + tensor([ 0.3100, -0.5445, -1.6492, -0.0638]) + """ + +def reciprocal(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + reciprocal(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the elements of :attr:`input` + + .. math:: + \text{out}_{i} = \frac{1}{\text{input}_{i}} + + .. note:: + Unlike NumPy's reciprocal, torch.reciprocal supports integral inputs. Integral + inputs to reciprocal are automatically :ref:`promoted ` to + the default scalar type. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.4595, -2.1219, -1.4314, 0.7298]) + >>> torch.reciprocal(a) + tensor([-2.1763, -0.4713, -0.6986, 1.3702]) + """ + +def reciprocal_(input: Tensor) -> Tensor: ... +def relu(input: Tensor) -> Tensor: ... +def relu_(input: Tensor) -> Tensor: ... +@overload +def remainder( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +@overload +def remainder(self: Number | _complex, other: Tensor) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. 
+ + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +@overload +def remainder( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + remainder(input, other, *, out=None) -> Tensor + + Computes + `Python's modulus operation `_ + entrywise. The result has the same sign as the divisor :attr:`other` and its absolute value + is less than that of :attr:`other`. + + It may also be defined in terms of :func:`torch.div` as + + .. code:: python + + torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer and float inputs. + + .. note:: + Complex inputs are not supported. In some cases, it is not mathematically + possible to satisfy the definition of a modulo operation with complex numbers. + See :func:`torch.fmod` for how division by zero is handled. + + .. seealso:: + + :func:`torch.fmod` which implements C++'s `std::fmod `_. + This one is defined in terms of division rounding towards zero. + + Args: + input (Tensor or Scalar): the dividend + other (Tensor or Scalar): the divisor + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2) + tensor([ 1., 0., 1., 1., 0., 1.]) + >>> torch.remainder(torch.tensor([1, 2, 3, 4, 5]), -1.5) + tensor([ -0.5000, -1.0000, 0.0000, -0.5000, -1.0000 ]) + """ + +def renorm( + input: Tensor, + p: Number | _complex, + dim: _int, + maxnorm: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + renorm(input, p, dim, maxnorm, *, out=None) -> Tensor + + Returns a tensor where each sub-tensor of :attr:`input` along dimension + :attr:`dim` is normalized such that the `p`-norm of the sub-tensor is lower + than the value :attr:`maxnorm` + + .. note:: If the norm of a row is lower than `maxnorm`, the row is unchanged + + Args: + input (Tensor): the input tensor. + p (float): the power for the norm computation + dim (int): the dimension to slice over to get the sub-tensors + maxnorm (float): the maximum norm to keep each sub-tensor under + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> x = torch.ones(3, 3) + >>> x[1].fill_(2) + tensor([ 2., 2., 2.]) + >>> x[2].fill_(3) + tensor([ 3., 3., 3.]) + >>> x + tensor([[ 1., 1., 1.], + [ 2., 2., 2.], + [ 3., 3., 3.]]) + >>> torch.renorm(x, 1, 0, 5) + tensor([[ 1.0000, 1.0000, 1.0000], + [ 1.6667, 1.6667, 1.6667], + [ 1.6667, 1.6667, 1.6667]]) + """ + +@overload +def repeat_interleave( + input: Tensor, + repeats: Tensor, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. 
+ By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +@overload +def repeat_interleave( + repeats: Tensor, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. 
+ + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +@overload +def repeat_interleave( + input: Tensor, + repeats: _int | SymInt, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, +) -> Tensor: + r""" + repeat_interleave(input, repeats, dim=None, *, output_size=None) -> Tensor + + Repeat elements of a tensor. + + .. warning:: + + This is different from :meth:`torch.Tensor.repeat` but similar to ``numpy.repeat``. + + Args: + input (Tensor): the input tensor. + repeats (Tensor or int): The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. + By default, use the flattened input array, and return a flat output + array. + + Keyword args: + output_size (int, optional): Total output size for the given axis + ( e.g. sum of repeats). If given, it will avoid stream synchronization + needed to calculate output shape of the tensor. + + Returns: + Tensor: Repeated tensor which has the same shape as input, except along the given axis. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat_interleave(2) + tensor([1, 1, 2, 2, 3, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.repeat_interleave(y, 2) + tensor([1, 1, 2, 2, 3, 3, 4, 4]) + >>> torch.repeat_interleave(y, 3, dim=1) + tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + tensor([[1, 2], + [3, 4], + [3, 4]]) + >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) + tensor([[1, 2], + [3, 4], + [3, 4]]) + + If the `repeats` is `tensor([n1, n2, n3, ...])`, then the output will be + `tensor([0, 0, ..., 1, 1, ..., 2, 2, ..., ...])` where `0` appears `n1` times, + `1` appears `n2` times, `2` appears `n3` times, etc. + + .. function:: repeat_interleave(repeats, *) -> Tensor + :noindex: + + Repeats 0 repeats[0] times, 1 repeats[1] times, 2 repeats[2] times, etc. + + Args: + repeats (Tensor): The number of repetitions for each element. + + Returns: + Tensor: Repeated tensor of size `sum(repeats)`. + + Example:: + + >>> torch.repeat_interleave(torch.tensor([1, 2, 3])) + tensor([0, 1, 1, 2, 2, 2]) + """ + +def reshape(input: Tensor, shape: Sequence[_int | SymInt]) -> Tensor: + r""" + reshape(input, shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`input`, + but with the specified shape. When possible, the returned tensor will be a view + of :attr:`input`. Otherwise, it will be a copy. Contiguous inputs and inputs + with compatible strides can be reshaped without copying, but you should not + depend on the copying vs. viewing behavior. + + See :meth:`torch.Tensor.view` on when it is possible to return a view. + + A single dimension may be -1, in which case it's inferred from the remaining + dimensions and the number of elements in :attr:`input`. + + Args: + input (Tensor): the tensor to be reshaped + shape (tuple of int): the new shape + + Example:: + + >>> a = torch.arange(4.) + >>> torch.reshape(a, (2, 2)) + tensor([[ 0., 1.], + [ 2., 3.]]) + >>> b = torch.tensor([[0, 1], [2, 3]]) + >>> torch.reshape(b, (-1,)) + tensor([ 0, 1, 2, 3]) + """ + +def resize_as_( + input: Tensor, + the_template: Tensor, + *, + memory_format: memory_format | None = None, +) -> Tensor: ... +def resize_as_sparse_(input: Tensor, the_template: Tensor) -> Tensor: ... 
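+# Neither ``resize_as_`` nor ``resize_as_sparse_`` above carries a docstring, so here is a
+# brief, hedged usage sketch (illustrative only, not part of the generated stub; the shapes
+# and values are assumptions): ``Tensor.resize_as_(other)`` resizes the tensor in place to
+# match ``other``'s shape, analogous to ``resize_(other.size())``; newly allocated elements
+# are uninitialized.
+#
+#   >>> import torch
+#   >>> a = torch.zeros(5)
+#   >>> template = torch.ones(2, 3)
+#   >>> a.resize_as_(template).shape   # `a` now has the template's shape
+#   torch.Size([2, 3])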
+def resolve_conj(input: Tensor) -> Tensor: + r""" + resolve_conj(input) -> Tensor + + Returns a new tensor with materialized conjugation if :attr:`input`'s conjugate bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its conjugate bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> y.is_conj() + True + >>> z = y.resolve_conj() + >>> z + tensor([-1 - 1j, -2 - 2j, 3 + 3j]) + >>> z.is_conj() + False + """ + +def resolve_neg(input: Tensor) -> Tensor: + r""" + resolve_neg(input) -> Tensor + + Returns a new tensor with materialized negation if :attr:`input`'s negative bit is set to `True`, + else returns :attr:`input`. The output tensor will always have its negative bit set to `False`. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) + >>> y = x.conj() + >>> z = y.imag + >>> z.is_neg() + True + >>> out = z.resolve_neg() + >>> out + tensor([-1., -2., 3.]) + >>> out.is_neg() + False + """ + +@overload +def result_type(tensor: Tensor, other: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type(scalar: Number | _complex, tensor: Tensor) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type(tensor: Tensor, other: Number | _complex) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. + + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +@overload +def result_type( + scalar1: Number | _complex, + scalar2: Number | _complex, +) -> _dtype: + r""" + result_type(tensor1, tensor2) -> dtype + + Returns the :class:`torch.dtype` that would result from performing an arithmetic + operation on the provided input tensors. See type promotion :ref:`documentation ` + for more information on the type promotion logic. 
+ + Args: + tensor1 (Tensor or Number): an input tensor or number + tensor2 (Tensor or Number): an input tensor or number + + Example:: + + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.int), 1.0) + torch.float32 + >>> torch.result_type(torch.tensor([1, 2], dtype=torch.uint8), torch.tensor(1)) + torch.uint8 + """ + +def rms_norm( + input: Tensor, + normalized_shape: Sequence[_int | SymInt], + weight: Tensor | None = None, + eps: _float | None = None, +) -> Tensor: ... +@overload +def rnn_relu( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def rnn_relu( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def rnn_relu_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +@overload +def rnn_tanh( + data: Tensor, + batch_sizes: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, +) -> tuple[Tensor, Tensor]: ... +@overload +def rnn_tanh( + input: Tensor, + hx: Tensor, + params: tuple[Tensor, ...] | list[Tensor] | None, + has_biases: _bool, + num_layers: _int, + dropout: _float, + train: _bool, + bidirectional: _bool, + batch_first: _bool, +) -> tuple[Tensor, Tensor]: ... +def rnn_tanh_cell( + input: Tensor, + hx: Tensor, + w_ih: Tensor, + w_hh: Tensor, + b_ih: Tensor | None = None, + b_hh: Tensor | None = None, +) -> Tensor: ... +def roll( + input: Tensor, + shifts: _int | SymInt | Sequence[_int | SymInt], + dims: _int | _size = (), +) -> Tensor: + r""" + roll(input, shifts, dims=None) -> Tensor + + Roll the tensor :attr:`input` along the given dimension(s). Elements that are + shifted beyond the last position are re-introduced at the first position. If + :attr:`dims` is `None`, the tensor will be flattened before rolling and then + restored to the original shape. + + Args: + input (Tensor): the input tensor. + shifts (int or tuple of ints): The number of places by which the elements + of the tensor are shifted. If shifts is a tuple, dims must be a tuple of + the same size, and each dimension will be rolled by the corresponding + value + dims (int or tuple of ints): Axis along which to roll + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + >>> x + tensor([[1, 2], + [3, 4], + [5, 6], + [7, 8]]) + >>> torch.roll(x, 1) + tensor([[8, 1], + [2, 3], + [4, 5], + [6, 7]]) + >>> torch.roll(x, 1, 0) + tensor([[7, 8], + [1, 2], + [3, 4], + [5, 6]]) + >>> torch.roll(x, -1, 0) + tensor([[3, 4], + [5, 6], + [7, 8], + [1, 2]]) + >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) + tensor([[6, 5], + [8, 7], + [2, 1], + [4, 3]]) + """ + +def rot90(input: Tensor, k: _int = 1, dims: _size = (0, 1)) -> Tensor: + r""" + rot90(input, k=1, dims=(0, 1)) -> Tensor + + Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. + Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. + + Args: + input (Tensor): the input tensor. + k (int): number of times to rotate. 
Default value is 1 + dims (a list or tuple): axis to rotate. Default value is [0, 1] + + Example:: + + >>> x = torch.arange(4).view(2, 2) + >>> x + tensor([[0, 1], + [2, 3]]) + >>> torch.rot90(x, 1, [0, 1]) + tensor([[1, 3], + [0, 2]]) + + >>> x = torch.arange(8).view(2, 2, 2) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.rot90(x, 1, [1, 2]) + tensor([[[1, 3], + [0, 2]], + + [[5, 7], + [4, 6]]]) + """ + +@overload +def round(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + +@overload +def round( + input: Tensor, + *, + decimals: _int, + out: Tensor | None = None, +) -> Tensor: + r""" + round(input, *, decimals=0, out=None) -> Tensor + + Rounds elements of :attr:`input` to the nearest integer. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + The return type of output is same as that of input's dtype. + + .. note:: + This function implements the "round half to even" to + break ties when a number is equidistant from two + integers (e.g. `round(2.5)` is 2). + + When the :attr:\`decimals\` argument is specified the + algorithm used is similar to NumPy's `around`. This + algorithm is fast but inexact and it can easily + overflow for low precision dtypes. + Eg. `round(tensor([10000], dtype=torch.float16), decimals=3)` is `inf`. + + .. seealso:: + :func:`torch.ceil`, which rounds up. + :func:`torch.floor`, which rounds down. + :func:`torch.trunc`, which rounds towards zero. + + Args: + input (Tensor): the input tensor. + decimals (int): Number of decimal places to round to (default: 0). + If decimals is negative, it specifies the number of positions + to the left of the decimal point. + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> torch.round(torch.tensor((4.7, -2.3, 9.1, -7.7))) + tensor([ 5., -2., 9., -8.]) + + >>> # Values equidistant from two integers are rounded towards the + >>> # the nearest even value (zero is treated as even) + >>> torch.round(torch.tensor([-0.5, 0.5, 1.5, 2.5])) + tensor([-0., 0., 2., 2.]) + + >>> # A positive decimals argument rounds to the to that decimal place + >>> torch.round(torch.tensor([0.1234567]), decimals=3) + tensor([0.1230]) + + >>> # A negative decimals argument rounds to the left of the decimal + >>> torch.round(torch.tensor([1200.1234567]), decimals=-3) + tensor([1000.]) + """ + +@overload +def round_(input: Tensor) -> Tensor: ... +@overload +def round_(input: Tensor, *, decimals: _int) -> Tensor: ... +def row_indices_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... +def row_stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + row_stack(tensors, *, out=None) -> Tensor + + Alias of :func:`torch.vstack`. + """ + +def rrelu( + input: Tensor, + lower: Number | _complex = 0.125, + upper: Number | _complex = 0.3333333333333333, + training: _bool = False, + generator: Generator | None = None, +) -> Tensor: ... +def rrelu_( + input: Tensor, + lower: Number | _complex = 0.125, + upper: Number | _complex = 0.3333333333333333, + training: _bool = False, + generator: Generator | None = None, +) -> Tensor: ... +def rsqrt(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + rsqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the reciprocal of the square-root of each of + the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \frac{1}{\sqrt{\text{input}_{i}}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.0370, 0.2970, 1.5420, -0.9105]) + >>> torch.rsqrt(a) + tensor([ nan, 1.8351, 0.8053, nan]) + """ + +def rsqrt_(input: Tensor) -> Tensor: ... +@overload +def rsub( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, +) -> Tensor: ... +@overload +def rsub( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: ... +def saddmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + out: Tensor | None = None, +) -> Tensor: ... +def scalar_tensor( + s: Number | _complex, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: ... 
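+# ``scalar_tensor`` above has no docstring, so a minimal, hedged sketch of its behaviour
+# follows (illustrative only, not part of the generated stub): it wraps a Python number
+# into a 0-dimensional tensor, with the dtype either given explicitly or inferred from the
+# default dtype.
+#
+#   >>> import torch
+#   >>> torch.scalar_tensor(3.0)
+#   tensor(3.)
+#   >>> torch.scalar_tensor(3, dtype=torch.int64).dim()
+#   0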
+@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + reduce: str, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + reduce: str, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, +) -> Tensor: + r""" + scatter(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + +@overload +def scatter_add( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + +@overload +def scatter_add( + input: Tensor, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, +) -> Tensor: + r""" + scatter_add(input, dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + +def scatter_reduce( + input: Tensor, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + out: Tensor | None = None, +) -> Tensor: + r""" + scatter_reduce(input, dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + +@overload +def searchsorted( + sorted_sequence: Tensor, + input: Tensor, + *, + out_int32: _bool = False, + right: _bool = False, + side: str | None = None, + sorter: Tensor | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. 
list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. "left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. 
+ sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + +@overload +def searchsorted( + sorted_sequence: Tensor, + self: Number | _complex, + *, + out_int32: _bool = False, + right: _bool = False, + side: str | None = None, + sorter: Tensor | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) -> Tensor + + Find the indices from the *innermost* dimension of :attr:`sorted_sequence` such that, if the + corresponding values in :attr:`values` were inserted before the indices, when sorted, the order + of the corresponding *innermost* dimension within :attr:`sorted_sequence` would be preserved. + Return a new tensor with the same size as :attr:`values`. More formally, + the returned index satisfies the following rules: + + .. list-table:: + :widths: 12 10 78 + :header-rows: 1 + + * - :attr:`sorted_sequence` + - :attr:`right` + - *returned index satisfies* + * - 1-D + - False + - ``sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]`` + * - 1-D + - True + - ``sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]`` + * - N-D + - False + - ``sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]`` + * - N-D + - True + - ``sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]`` + + Args: + sorted_sequence (Tensor): N-D or 1-D tensor, containing monotonically increasing sequence on the *innermost* + dimension unless :attr:`sorter` is provided, in which case the sequence does not + need to be sorted + values (Tensor or Scalar): N-D tensor or a Scalar containing the search value(s). + + Keyword args: + out_int32 (bool, optional): indicate the output data type. torch.int32 if True, torch.int64 otherwise. + Default value is False, i.e. default output data type is torch.int64. + right (bool, optional): if False, return the first suitable location that is found. If True, return the + last such index. If no suitable index found, return 0 for non-numerical value + (eg. nan, inf) or the size of *innermost* dimension within :attr:`sorted_sequence` + (one pass the last index of the *innermost* dimension). In other words, if False, + gets the lower bound index for each value in :attr:`values` on the corresponding + *innermost* dimension of the :attr:`sorted_sequence`. If True, gets the upper + bound index instead. Default value is False. :attr:`side` does the same and is + preferred. It will error if :attr:`side` is set to "left" while this is True. + side (str, optional): the same as :attr:`right` but preferred. 
"left" corresponds to False for :attr:`right` + and "right" corresponds to True for :attr:`right`. It will error if this is set to + "left" while :attr:`right` is True. Default value is None. + out (Tensor, optional): the output tensor, must be the same size as :attr:`values` if provided. + sorter (LongTensor, optional): if provided, a tensor matching the shape of the unsorted + :attr:`sorted_sequence` containing a sequence of indices that sort it in the + ascending order on the innermost dimension + + + Example:: + + >>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + >>> sorted_sequence + tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + >>> values = torch.tensor([[3, 6, 9], [3, 6, 9]]) + >>> values + tensor([[3, 6, 9], + [3, 6, 9]]) + >>> torch.searchsorted(sorted_sequence, values) + tensor([[1, 3, 4], + [1, 2, 4]]) + >>> torch.searchsorted(sorted_sequence, values, side='right') + tensor([[2, 3, 5], + [1, 3, 4]]) + + >>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9]) + >>> sorted_sequence_1d + tensor([1, 3, 5, 7, 9]) + >>> torch.searchsorted(sorted_sequence_1d, values) + tensor([[1, 3, 4], + [1, 3, 4]]) + """ + +def segment_reduce( + data: Tensor, + reduce: str, + *, + lengths: Tensor | None = None, + indices: Tensor | None = None, + offsets: Tensor | None = None, + axis: _int = 0, + unsafe: _bool = False, + initial: Number | _complex | None = None, +) -> Tensor: ... +@overload +def select(input: Tensor, dim: _int, index: _int | SymInt) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + +@overload +def select( + input: Tensor, + dim: str | EllipsisType | None, + index: _int, +) -> Tensor: + r""" + select(input, dim, index) -> Tensor + + Slices the :attr:`input` tensor along the selected dimension at the given index. + This function returns a view of the original tensor with the given dimension removed. + + .. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to slice + index (int): the index to select with + + .. note:: + + :meth:`select` is equivalent to slicing. For example, + ``tensor.select(0, index)`` is equivalent to ``tensor[index]`` and + ``tensor.select(2, index)`` is equivalent to ``tensor[:,:,index]``. + """ + +def select_copy( + input: Tensor, + dim: _int, + index: _int | SymInt, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.select`, but all output tensors + are freshly created instead of aliasing the input. 
+ """ + +def select_scatter( + input: Tensor, + src: Tensor, + dim: _int, + index: _int | SymInt, +) -> Tensor: + r""" + select_scatter(input, src, dim, index) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given index. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. + src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into. + index (int): the index to select with + + .. note:: + + :attr:`src` must be of the proper size in order to be embedded + into :attr:`input`. Specifically, it should have the same shape as + ``torch.select(input, dim, index)`` + + Example:: + + >>> a = torch.zeros(2, 2) + >>> b = torch.ones(2) + >>> a.select_scatter(b, 0, 0) + tensor([[1., 1.], + [0., 0.]]) + """ + +def selu(input: Tensor) -> Tensor: ... +def selu_(input: Tensor) -> Tensor: ... +def set_flush_denormal(mode: _bool) -> _bool: + r""" + set_flush_denormal(mode) -> bool + + Disables denormal floating numbers on CPU. + + Returns ``True`` if your system supports flushing denormal numbers and it + successfully configures flush denormal mode. :meth:`~torch.set_flush_denormal` + is supported on x86 architectures supporting SSE3 and AArch64 architecture. + + Args: + mode (bool): Controls whether to enable flush denormal mode or not + + Example:: + + >>> torch.set_flush_denormal(True) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor([ 0.], dtype=torch.float64) + >>> torch.set_flush_denormal(False) + True + >>> torch.tensor([1e-323], dtype=torch.float64) + tensor(9.88131e-324 * + [ 1.0000], dtype=torch.float64) + """ + +def set_num_interop_threads(num: _int) -> None: + r""" + set_num_interop_threads(int) + + Sets the number of threads used for interop parallelism + (e.g. in JIT interpreter) on CPU. + + .. warning:: + Can only be called once and before any inter-op parallel work + is started (e.g. JIT execution). + """ + +def set_num_threads(num: _int) -> None: + r""" + set_num_threads(int) + + Sets the number of threads used for intraop parallelism on CPU. + + .. warning:: + To ensure that the correct number of threads is used, set_num_threads + must be called before running eager, JIT or autograd code. + """ + +def sgn(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sgn(input, *, out=None) -> Tensor + + This function is an extension of torch.sign() to complex tensors. + It computes a new tensor whose elements have + the same angles as the corresponding elements of :attr:`input` and + absolute values (i.e. magnitudes) of one for complex tensors and + is equivalent to torch.sign() for non-complex tensors. + + .. math:: + \text{out}_{i} = \begin{cases} + 0 & |\text{{input}}_i| == 0 \\ + \frac{{\text{{input}}_i}}{|{\text{{input}}_i}|} & \text{otherwise} + \end{cases} + + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([3+4j, 7-24j, 0, 1+2j]) + >>> t.sgn() + tensor([0.6000+0.8000j, 0.2800-0.9600j, 0.0000+0.0000j, 0.4472+0.8944j]) + """ + +def sigmoid(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sigmoid(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.expit`. + """ + +def sigmoid_(input: Tensor) -> Tensor: ... +def sign(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sign(input, *, out=None) -> Tensor + + Returns a new tensor with the signs of the elements of :attr:`input`. 
+ + .. math:: + \text{out}_{i} = \operatorname{sgn}(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> a + tensor([ 0.7000, -1.2000, 0.0000, 2.3000]) + >>> torch.sign(a) + tensor([ 1., -1., 0., 1.]) + """ + +def signbit(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + signbit(input, *, out=None) -> Tensor + + Tests if each element of :attr:`input` has its sign bit set or not. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([0.7, -1.2, 0., 2.3]) + >>> torch.signbit(a) + tensor([ False, True, False, False]) + >>> a = torch.tensor([-0.0, 0.0]) + >>> torch.signbit(a) + tensor([ True, False]) + + .. note:: + signbit handles signed zeros, so negative zero (-0) returns True. + """ + +def sin(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sin(input, *, out=None) -> Tensor + + Returns a new tensor with the sine of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sin(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-0.5461, 0.1347, -2.7266, -0.2746]) + >>> torch.sin(a) + tensor([-0.5194, 0.1343, -0.4032, -0.2711]) + """ + +def sin_(input: Tensor) -> Tensor: ... +def sinc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sinc(input, *, out=None) -> Tensor + + Alias for :func:`torch.special.sinc`. + """ + +def sinc_(input: Tensor) -> Tensor: ... +def sinh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sinh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic sine of the elements of + :attr:`input`. + + .. math:: + \text{out}_{i} = \sinh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.5380, -0.8632, -0.1265, 0.9399]) + >>> torch.sinh(a) + tensor([ 0.5644, -0.9744, -0.1268, 1.0845]) + + .. note:: + When :attr:`input` is on the CPU, the implementation of torch.sinh may use + the Sleef library, which rounds very large results to infinity or negative + infinity. See `here `_ for details. + """ + +def sinh_(input: Tensor) -> Tensor: ... +def slice_copy( + input: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.slice`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def slice_inverse( + input: Tensor, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, +) -> Tensor: ... +def slice_scatter( + input: Tensor, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + slice_scatter(input, src, dim=0, start=None, end=None, step=1) -> Tensor + + Embeds the values of the :attr:`src` tensor into :attr:`input` at the given + dimension. + This function returns a tensor with fresh storage; it does not create a view. + + + Args: + input (Tensor): the input tensor. 
+ src (Tensor): The tensor to embed into :attr:`input` + dim (int): the dimension to insert the slice into + start (Optional[int]): the start index of where to insert the slice + end (Optional[int]): the end index of where to insert the slice + step (int): the how many elements to skip in + + Example:: + + >>> a = torch.zeros(8, 8) + >>> b = torch.ones(2, 8) + >>> a.slice_scatter(b, start=6) + tensor([[0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1.]]) + + >>> b = torch.ones(8, 2) + >>> a.slice_scatter(b, dim=1, start=2, end=6, step=2) + tensor([[0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.], + [0., 0., 1., 0., 1., 0., 0., 0.]]) + """ + +def slogdet( + input: Tensor, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.slogdet: + r""" + slogdet(input) -> (Tensor, Tensor) + + Alias for :func:`torch.linalg.slogdet` + """ + +def smm(input: Tensor, mat2: Tensor) -> Tensor: + r""" + smm(input, mat) -> Tensor + + Performs a matrix multiplication of the sparse matrix :attr:`input` + with the dense matrix :attr:`mat`. + + Args: + input (Tensor): a sparse matrix to be matrix multiplied + mat (Tensor): a dense matrix to be matrix multiplied + """ + +@overload +def softmax( + input: Tensor, + dim: _int, + dtype: _dtype | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + +@overload +def softmax( + input: Tensor, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, +) -> Tensor: + r""" + softmax(input, dim, *, dtype=None) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + +@overload +def sort( + input: Tensor, + *, + stable: _bool | None, + dim: _int = -1, + descending: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + dim: _int = -1, + descending: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + *, + stable: _bool | None, + dim: str | EllipsisType | None, + descending: _bool = False, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +@overload +def sort( + input: Tensor, + dim: str | EllipsisType | None, + descending: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.sort: + r""" + sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) + + Sorts the elements of the :attr:`input` tensor along a given dimension + in ascending order by value. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`descending` is ``True`` then the elements are sorted in descending + order by value. + + If :attr:`stable` is ``True`` then the sorting routine becomes stable, preserving + the order of equivalent elements. + + A namedtuple of (values, indices) is returned, where the `values` are the + sorted values and `indices` are the indices of the elements in the original + `input` tensor. + + Args: + input (Tensor): the input tensor. + dim (int, optional): the dimension to sort along + descending (bool, optional): controls the sorting order (ascending or descending) + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. 
+ + Keyword args: + out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can + be optionally given to be used as output buffers + + Example:: + + >>> x = torch.randn(3, 4) + >>> sorted, indices = torch.sort(x) + >>> sorted + tensor([[-0.2162, 0.0608, 0.6719, 2.3332], + [-0.5793, 0.0061, 0.6058, 0.9497], + [-0.5071, 0.3343, 0.9553, 1.0960]]) + >>> indices + tensor([[ 1, 0, 2, 3], + [ 3, 1, 0, 2], + [ 0, 3, 1, 2]]) + + >>> sorted, indices = torch.sort(x, 0) + >>> sorted + tensor([[-0.5071, -0.2162, 0.6719, -0.5793], + [ 0.0608, 0.0061, 0.9497, 0.3343], + [ 0.6058, 0.9553, 1.0960, 2.3332]]) + >>> indices + tensor([[ 2, 0, 0, 1], + [ 0, 1, 1, 2], + [ 1, 2, 2, 0]]) + >>> x = torch.tensor([0, 1] * 9) + >>> x.sort() + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) + >>> x.sort(stable=True) + torch.return_types.sort( + values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), + indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17])) + """ + +def sparse_bsc_tensor( + ccol_indices: Tensor | list, + row_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse + Column)) ` with specified 2-dimensional blocks at the + given :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in BSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncolblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + index in values and row_indices depending on where the given + column starts. Each successive number in the tensor subtracted + by the number before it denotes the number of elements in a + given column. + row_indices (array_like): Row block co-ordinates of each block in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial blocks for the tensor. Can be a list, + tuple, NumPy ``ndarray``, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. 
+ device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> ccol_indices = [0, 1, 2] + >>> row_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 1, 2]), + row_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsc) + """ + +def sparse_bsr_tensor( + crow_indices: Tensor | list, + col_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) + ` with specified 2-dimensional blocks at the given + :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix + multiplication operations in BSR format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrowblocks + 1)``. The last element of each + batch is the number of non-zeros. This tensor encodes the + block index in values and col_indices depending on where the + given row block starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + blocks in a given row. + col_indices (array_like): Column block co-ordinates of each block + in values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the + number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize == + values.shape[1:3]``. If not provided, the size will be + inferred as the minimum size big enough to hold all non-zero + blocks. 
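+ 
+     .. note:: For a non-batched BSR tensor, the blocks that belong to row block ``i`` are
+        ``values[crow_indices[i]:crow_indices[i + 1]]``, and their column block positions are
+        ``col_indices[crow_indices[i]:crow_indices[i + 1]]``.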
+ + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> crow_indices = [0, 1, 2] + >>> col_indices = [0, 1] + >>> values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + >>> torch.sparse_bsr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 1, 2]), + col_indices=tensor([0, 1]), + values=tensor([[[1., 2.], + [3., 4.]], + [[5., 6.], + [7., 8.]]]), size=(2, 2), nnz=2, dtype=torch.float64, + layout=torch.sparse_bsr) + """ + +def sparse_compressed_tensor( + compressed_indices: Tensor | list, + plain_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, *, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, + CSC, BSR, or BSC - ` with specified values at + the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse + matrix multiplication operations in Compressed Sparse format are + typically faster than that for sparse tensors in COO format. Make you + have a look at :ref:`the note on the data type of the indices + `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + compressed_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, compressed_dim_size + 1)``. The last element of + each batch is the number of non-zero elements or blocks. This + tensor encodes the index in ``values`` and ``plain_indices`` + depending on where the given compressed dimension (row or + column) starts. Each successive number in the tensor + subtracted by the number before it denotes the number of + elements or blocks in a given compressed dimension. + plain_indices (array_like): Plain dimension (column or row) + co-ordinates of each element or block in values. (B+1)-dimensional + tensor with the same length as values. + + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types. 
that + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * + blocksize[1], *densesize)`` where ``blocksize[0] == + blocksize[1] == 1`` for CSR and CSC formats. If not provided, + the size will be inferred as the minimum size big enough to + hold all non-zero elements or blocks. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + layout (:class:`torch.layout`, required): the desired layout of + returned tensor: :attr:`torch.sparse_csr`, + :attr:`torch.sparse_csc`, :attr:`torch.sparse_bsr`, or + :attr:`torch.sparse_bsc`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> compressed_indices = [0, 2, 4] + >>> plain_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_compressed_tensor(torch.tensor(compressed_indices, dtype=torch.int64), + ... torch.tensor(plain_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double, layout=torch.sparse_csr) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + +def sparse_coo_tensor( + indices: Tensor, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, + is_coalesced: _bool | None = None, +) -> Tensor: + r""" + sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None, is_coalesced=None) -> Tensor + + Constructs a :ref:`sparse tensor in COO(rdinate) format + ` with specified values at the given + :attr:`indices`. + + .. note:: + + This function returns an :ref:`uncoalesced tensor + ` when :attr:`is_coalesced` is + unspecified or ``None``. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + indices (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` + internally. 
The indices are the coordinates of the non-zero values in the matrix, and thus + should be two-dimensional where the first dimension is the number of tensor dimensions and + the second dimension is the number of non-zero values. + values (array_like): Initial values for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + size (list, tuple, or :class:`torch.Size`, optional): Size of the sparse tensor. If not + provided the size will be inferred as the minimum size big enough to hold all non-zero + elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if None, infers data type from :attr:`values`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + is_coalesced (bool, optional): When``True``, the caller is + responsible for providing tensor indices that correspond to a + coalesced tensor. If the :attr:`check_invariants` flag is + False, no error will be raised if the prerequisites are not + met and this will lead to silently incorrect results. To force + coalescion please use :meth:`coalesce` on the resulting + Tensor. + Default: None: except for trivial cases (e.g. nnz < 2) the + resulting Tensor has is_coalesced set to ``False```. + + Example:: + + >>> i = torch.tensor([[0, 1, 1], + ... [2, 0, 2]]) + >>> v = torch.tensor([3, 4, 5], dtype=torch.float32) + >>> torch.sparse_coo_tensor(i, v, [2, 4]) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 4), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v) # Shape inference + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + size=(2, 3), nnz=3, layout=torch.sparse_coo) + + >>> torch.sparse_coo_tensor(i, v, [2, 4], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) + tensor(indices=tensor([[0, 1, 1], + [2, 0, 2]]), + values=tensor([3., 4., 5.]), + device='cuda:0', size=(2, 4), nnz=3, dtype=torch.float64, + layout=torch.sparse_coo) + + # Create an empty sparse tensor with the following invariants: + # 1. sparse_dim + dense_dim = len(SparseTensor.shape) + # 2. SparseTensor._indices().shape = (sparse_dim, nnz) + # 3. 
SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) + # + # For instance, to create an empty sparse tensor with nnz = 0, dense_dim = 0 and + # sparse_dim = 1 (hence indices is a 2D tensor of shape = (1, 0)) + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), [], [1]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0,)), + size=(1,), nnz=0, layout=torch.sparse_coo) + + # and to create an empty sparse tensor with nnz = 0, dense_dim = 1 and + # sparse_dim = 1 + >>> S = torch.sparse_coo_tensor(torch.empty([1, 0]), torch.empty([0, 2]), [1, 2]) + tensor(indices=tensor([], size=(1, 0)), + values=tensor([], size=(0, 2)), + size=(1, 2), nnz=0, layout=torch.sparse_coo) + + .. _torch.sparse: https://pytorch.org/docs/stable/sparse.html + """ + +def sparse_csc_tensor( + ccol_indices: Tensor | list, + row_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) + ` with specified values at the given + :attr:`ccol_indices` and :attr:`row_indices`. Sparse matrix + multiplication operations in CSC format are typically faster than that + for sparse tensors in COO format. Make you have a look at :ref:`the + note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + ccol_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, ncols + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and row_indices depending on where the given column + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + column. + row_indices (array_like): Row co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length as + values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> ccol_indices = [0, 2, 4] + >>> row_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csc_tensor(torch.tensor(ccol_indices, dtype=torch.int64), + ... torch.tensor(row_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(ccol_indices=tensor([0, 2, 4]), + row_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) + """ + +def sparse_csr_tensor( + crow_indices: Tensor | list, + col_indices: Tensor | list, + values: Tensor | list, + size: _size | None = None, + *, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + check_invariants: _bool | None = None, +) -> Tensor: + r""" + sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None) -> Tensor + + Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified + values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations + in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look + at :ref:`the note on the data type of the indices `. + + .. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor. + + Args: + crow_indices (array_like): (B+1)-dimensional array of size + ``(*batchsize, nrows + 1)``. The last element of each batch + is the number of non-zeros. This tensor encodes the index in + values and col_indices depending on where the given row + starts. Each successive number in the tensor subtracted by the + number before it denotes the number of elements in a given + row. + col_indices (array_like): Column co-ordinates of each element in + values. (B+1)-dimensional tensor with the same length + as values. + values (array_list): Initial values for the tensor. Can be a list, + tuple, NumPy ``ndarray``, scalar, and other types that + represents a (1+K)-dimensional tensor where ``K`` is the number + of dense dimensions. + size (list, tuple, :class:`torch.Size`, optional): Size of the + sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If + not provided, the size will be inferred as the minimum size + big enough to hold all non-zero elements. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of + returned tensor. Default: if None, infers data type from + :attr:`values`. + device (:class:`torch.device`, optional): the desired device of + returned tensor. Default: if None, uses the current device + for the default tensor type (see + :func:`torch.set_default_device`). :attr:`device` will be + the CPU for CPU tensor types and the current CUDA device for + CUDA tensor types. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. + + Example:: + + >>> crow_indices = [0, 2, 4] + >>> col_indices = [0, 1, 0, 1] + >>> values = [1, 2, 3, 4] + >>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64), + ... torch.tensor(col_indices, dtype=torch.int64), + ... torch.tensor(values), dtype=torch.double) + tensor(crow_indices=tensor([0, 2, 4]), + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64, layout=torch.sparse_csr) + """ + +def split_copy( + input: Tensor, + split_size: _int | SymInt, + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.split`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def split_with_sizes( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: ... +def split_with_sizes_copy( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.split_with_sizes`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def spmm(input: Tensor, mat2: Tensor) -> Tensor: ... +def sqrt(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + sqrt(input, *, out=None) -> Tensor + + Returns a new tensor with the square-root of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \sqrt{\text{input}_{i}} + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.sqrt(a) + tensor([ nan, 1.0112, 0.2883, 0.6933]) + """ + +def sqrt_(input: Tensor) -> Tensor: ... +def square(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + square(input: Tensor, *, out: Optional[Tensor]) -> Tensor + + Returns a new tensor with the square of the elements of :attr:`input`. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-2.0755, 1.0226, 0.0831, 0.4806]) + >>> torch.square(a) + tensor([ 4.3077, 1.0457, 0.0069, 0.2310]) + """ + +def square_(input: Tensor) -> Tensor: ... +@overload +def squeeze(input: Tensor) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. 
warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: _int) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: _size) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. 
Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze(input: Tensor, dim: str | EllipsisType | None) -> Tensor: + r""" + squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor + + Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. + + For example, if `input` is of shape: + :math:`(A \times 1 \times B \times C \times 1 \times D)` then the `input.squeeze()` + will be of shape: :math:`(A \times B \times C \times D)`. + + When :attr:`dim` is given, a squeeze operation is done only in the given + dimension(s). If `input` is of shape: :math:`(A \times 1 \times B)`, + ``squeeze(input, 0)`` leaves the tensor unchanged, but ``squeeze(input, 1)`` + will squeeze the tensor to the shape :math:`(A \times B)`. + + .. note:: The returned tensor shares the storage with the input tensor, + so changing the contents of one will change the contents of the other. + + .. warning:: If the tensor has a batch dimension of size 1, then `squeeze(input)` + will also remove the batch dimension, which can lead to unexpected + errors. Consider specifying only the dims you wish to be squeezed. + + Args: + input (Tensor): the input tensor. + dim (int or tuple of ints, optional): if given, the input will be squeezed + only in the specified dimensions. + + .. versionchanged:: 2.0 + :attr:`dim` now accepts tuples of dimensions. + + Example:: + + >>> x = torch.zeros(2, 1, 2, 1, 2) + >>> x.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x) + >>> y.size() + torch.Size([2, 2, 2]) + >>> y = torch.squeeze(x, 0) + >>> y.size() + torch.Size([2, 1, 2, 1, 2]) + >>> y = torch.squeeze(x, 1) + >>> y.size() + torch.Size([2, 2, 1, 2]) + >>> y = torch.squeeze(x, (1, 2, 3)) + torch.Size([2, 2, 2]) + """ + +@overload +def squeeze_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def squeeze_copy( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def squeeze_copy( + input: Tensor, + dim: _size, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.squeeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def sspaddmm( + beta: Number | _complex, + self: Tensor, + alpha: Number | _complex, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. 
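+ 
+     Following the same convention as :func:`torch.addmm` (see the note below), this computes
+ 
+     .. math::
+         \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1} \mathbin{@} \text{mat2})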
+ + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +@overload +def sspaddmm( + input: Tensor, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +@overload +def sspaddmm( + beta: Number | _complex, + self: Tensor, + mat1: Tensor, + mat2: Tensor, +) -> Tensor: + r""" + sspaddmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor + + Matrix multiplies a sparse tensor :attr:`mat1` with a dense tensor + :attr:`mat2`, then adds the sparse tensor :attr:`input` to the result. + + Note: This function is equivalent to :func:`torch.addmm`, except + :attr:`input` and :attr:`mat1` are sparse. + + Args: + input (Tensor): a sparse matrix to be added + mat1 (Tensor): a sparse matrix to be matrix multiplied + mat2 (Tensor): a dense matrix to be matrix multiplied + + Keyword args: + beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) + out (Tensor, optional): the output tensor. + """ + +def stack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + dim: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + stack(tensors, dim=0, *, out=None) -> Tensor + + Concatenates a sequence of tensors along a new dimension. + + All tensors need to be of the same size. + + .. seealso:: + + :func:`torch.cat` concatenates the given sequence along an existing dimension. + + Arguments: + tensors (sequence of Tensors): sequence of tensors to concatenate + dim (int, optional): dimension to insert. Has to be between 0 and the number + of dimensions of concatenated tensors (inclusive). Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. 
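+ 
+     .. note:: The result has one more dimension than each input tensor: stacking ``N`` tensors
+        of shape ``(A, B)`` along ``dim=0`` yields a tensor of shape ``(N, A, B)``; in general,
+        a new dimension of size ``N`` is inserted at position ``dim``.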
+ + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]) + >>> torch.stack((x, x)) # same as torch.stack((x, x), dim=0) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]], + + [[ 0.3367, 0.1288, 0.2345], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x)).size() + torch.Size([2, 2, 3]) + >>> torch.stack((x, x), dim=1) + tensor([[[ 0.3367, 0.1288, 0.2345], + [ 0.3367, 0.1288, 0.2345]], + + [[ 0.2303, -1.1229, -0.1863], + [ 0.2303, -1.1229, -0.1863]]]) + >>> torch.stack((x, x), dim=2) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + >>> torch.stack((x, x), dim=-1) + tensor([[[ 0.3367, 0.3367], + [ 0.1288, 0.1288], + [ 0.2345, 0.2345]], + + [[ 0.2303, 0.2303], + [-1.1229, -1.1229], + [-0.1863, -0.1863]]]) + """ + +@overload +def std( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. 
+ + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... 
[ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the standard deviation over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). 
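+ 
+     .. note:: The standard deviation is the square root of the variance, so
+        ``torch.std(input, dim, correction=c)`` is equivalent to
+        ``torch.sqrt(torch.var(input, dim, correction=c))``.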
+ + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + unbiased: _bool = True, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. 
+ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def std_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the standard deviation and mean over the dimensions specified by + :attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or + ``None`` to reduce over all dimensions. + + The standard deviation (:math:`\sigma`) is calculated as + + .. math:: \sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (std, mean) containing the standard deviation and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def sub( + input: Tensor | Number | _complex, + other: Tensor | Number | _complex, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def sub(self: Tensor, alpha: Number | _complex, other: Tensor) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. 
math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def sub( + self: Tensor, + alpha: Number | _complex, + other: Tensor, + *, + out: Tensor, +) -> Tensor: + r""" + sub(input, other, *, alpha=1, out=None) -> Tensor + + Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + + .. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i + + + Supports :ref:`broadcasting to a common shape `, + :ref:`type promotion `, and integer, float, and complex inputs. + + Args: + input (Tensor): the input tensor. + other (Tensor or Number): the tensor or number to subtract from :attr:`input`. + + Keyword args: + alpha (Number): the multiplier for :attr:`other`. + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) + """ + +@overload +def subtract( + input: Tensor, + other: Tensor, + *, + alpha: Number | _complex = 1, + out: Tensor | None = None, +) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + +@overload +def subtract( + input: Tensor, + other: Number | _complex, + alpha: Number | _complex = 1, +) -> Tensor: + r""" + subtract(input, other, *, alpha=1, out=None) -> Tensor + + Alias for :func:`torch.sub`. + """ + +@overload +def sum(input: Tensor, *, dtype: _dtype | None = None) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. 
+ + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +@overload +def sum( + input: Tensor, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +@overload +def sum( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + out: Tensor | None = None, +) -> Tensor: + r""" + sum(input, *, dtype=None) -> Tensor + + Returns the sum of all elements in the :attr:`input` tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
+ If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: Use the `dtype` argument if you need the result in a specific tensor type. + Otherwise, the result type may be automatically promoted (e.g., from `torch.int32` to `torch.int64`). + + Example:: + + >>> a = torch.randn(1, 3) + >>> a + tensor([[ 0.1133, -0.9567, 0.2958]]) + >>> torch.sum(a) + tensor(-0.5475) + + .. function:: sum(input, dim, keepdim=False, *, dtype=None) -> Tensor + :noindex: + + Returns the sum of each row of the :attr:`input` tensor in the given + dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, + reduce over all of them. + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + Example:: + + >>> a = torch.randn(4, 4) + >>> a + tensor([[ 0.0569, -0.2475, 0.0737, -0.3429], + [-0.2993, 0.9138, 0.9337, -1.6864], + [ 0.1132, 0.7892, -0.1003, 0.5688], + [ 0.3637, -0.9906, -0.4752, -1.5197]]) + >>> torch.sum(a, 1) + tensor([-0.4598, -0.1381, 1.3708, -2.6217]) + >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6) + >>> torch.sum(b, (2, 1)) + tensor([ 435., 1335., 2235., 3135.]) + """ + +def svd( + input: Tensor, + some: _bool = True, + compute_uv: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.svd: + r""" + svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) + + Computes the singular value decomposition of either a matrix or batch of + matrices :attr:`input`. The singular value decomposition is represented as a + namedtuple `(U, S, V)`, such that :attr:`input` :math:`= U \text{diag}(S) V^{\text{H}}`. + where :math:`V^{\text{H}}` is the transpose of `V` for real inputs, + and the conjugate transpose of `V` for complex inputs. + If :attr:`input` is a batch of matrices, then `U`, `S`, and `V` are also + batched with the same batch dimensions as :attr:`input`. + + If :attr:`some` is `True` (default), the method returns the reduced singular + value decomposition. In this case, if the last two dimensions of :attr:`input` are + `m` and `n`, then the returned `U` and `V` matrices will contain only + `min(n, m)` orthonormal columns. + + If :attr:`compute_uv` is `False`, the returned `U` and `V` will be + zero-filled matrices of shape `(m, m)` and `(n, n)` + respectively, and the same device as :attr:`input`. The argument :attr:`some` + has no effect when :attr:`compute_uv` is `False`. + + Supports :attr:`input` of float, double, cfloat and cdouble data types. + The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will + always be real-valued, even if :attr:`input` is complex. + + .. 
warning:: + + :func:`torch.svd` is deprecated in favor of :func:`torch.linalg.svd` + and will be removed in a future PyTorch release. + + ``U, S, V = torch.svd(A, some=some, compute_uv=True)`` (default) should be replaced with + + .. code:: python + + U, S, Vh = torch.linalg.svd(A, full_matrices=not some) + V = Vh.mH + + ``_, S, _ = torch.svd(A, some=some, compute_uv=False)`` should be replaced with + + .. code:: python + + S = torch.linalg.svdvals(A) + + .. note:: Differences with :func:`torch.linalg.svd`: + + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matrices`. Note that + default value for both is `True`, so the default behavior is + effectively the opposite. + * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns + `Vh`, that is, :math:`V^{\text{H}}`. + * If :attr:`compute_uv` is `False`, :func:`torch.svd` returns zero-filled + tensors for `U` and `Vh`, whereas :func:`torch.linalg.svd` returns + empty tensors. + + .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, + then the singular values of each matrix in the batch are returned in descending order. + + .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is `True`. + + .. note:: When :attr:`some` is `False`, the gradients on `U[..., :, min(m, n):]` + and `V[..., :, min(m, n):]` will be ignored in the backward pass, as those vectors + can be arbitrary bases of the corresponding subspaces. + + .. note:: The implementation of :func:`torch.linalg.svd` on CPU uses LAPACK's routine `?gesdd` + (a divide-and-conquer algorithm) instead of `?gesvd` for speed. Analogously, + on GPU, it uses cuSOLVER's routines `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 + and later, and MAGMA's routine `gesdd` on earlier versions of CUDA. + + .. note:: The returned `U` will not be contiguous. The matrix (or batch of matrices) will + be represented as a column-major matrix (i.e. Fortran-contiguous). + + .. warning:: The gradients with respect to `U` and `V` will only be finite when the input does not + have zero nor repeated singular values. + + .. warning:: If the distance between any two singular values is close to zero, the gradients with respect to + `U` and `V` will be numerically unstable, as they depends on + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. The same happens when the matrix + has small singular values, as these gradients also depend on `S^{-1}`. + + .. warning:: For complex-valued :attr:`input` the singular value decomposition is not unique, + as `U` and `V` may be multiplied by an arbitrary phase factor :math:`e^{i \phi}` on every column. + The same happens when :attr:`input` has repeated singular values, where one may multiply + the columns of the spanning subspace in `U` and `V` by a rotation matrix + and `the resulting vectors will span the same subspace`_. + Different platforms, like NumPy, or inputs on different device types, + may produce different `U` and `V` tensors. + + Args: + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m, n)` matrices. + some (bool, optional): controls whether to compute the reduced or full decomposition, and + consequently, the shape of returned `U` and `V`. Default: `True`. + compute_uv (bool, optional): controls whether to compute `U` and `V`. Default: `True`. 
+ + Keyword args: + out (tuple, optional): the output tuple of tensors + + Example:: + + >>> a = torch.randn(5, 3) + >>> a + tensor([[ 0.2364, -0.7752, 0.6372], + [ 1.7201, 0.7394, -0.0504], + [-0.3371, -1.0584, 0.5296], + [ 0.3550, -0.4022, 1.5569], + [ 0.2445, -0.0158, 1.1414]]) + >>> u, s, v = torch.svd(a) + >>> u + tensor([[ 0.4027, 0.0287, 0.5434], + [-0.1946, 0.8833, 0.3679], + [ 0.4296, -0.2890, 0.5261], + [ 0.6604, 0.2717, -0.2618], + [ 0.4234, 0.2481, -0.4733]]) + >>> s + tensor([2.3289, 2.0315, 0.7806]) + >>> v + tensor([[-0.0199, 0.8766, 0.4809], + [-0.5080, 0.4054, -0.7600], + [ 0.8611, 0.2594, -0.4373]]) + >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) + tensor(8.6531e-07) + >>> a_big = torch.randn(7, 5, 3) + >>> u, s, v = torch.svd(a_big) + >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.mT)) + tensor(2.6503e-06) + + .. _the resulting vectors will span the same subspace: + (https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD) + """ + +def swapaxes(input: Tensor, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(input, axis0, axis1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapaxes(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + +def swapdims(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(input, dim0, dim1) -> Tensor + + Alias for :func:`torch.transpose`. + + This function is equivalent to NumPy's swapaxes function. + + Examples:: + + >>> x = torch.tensor([[[0,1],[2,3]],[[4,5],[6,7]]]) + >>> x + tensor([[[0, 1], + [2, 3]], + + [[4, 5], + [6, 7]]]) + >>> torch.swapdims(x, 0, 1) + tensor([[[0, 1], + [4, 5]], + + [[2, 3], + [6, 7]]]) + >>> torch.swapdims(x, 0, 2) + tensor([[[0, 4], + [2, 6]], + + [[1, 5], + [3, 7]]]) + """ + +def sym_constrain_range( + size: Number | _complex, + *, + min: _int | None = None, + max: _int | None = None, +) -> None: ... +def sym_constrain_range_for_size( + size: Number | _complex, + *, + min: _int | None = None, + max: _int | None = None, +) -> None: ... +def t(input: Tensor) -> Tensor: + r""" + t(input) -> Tensor + + Expects :attr:`input` to be <= 2-D tensor and transposes dimensions 0 + and 1. + + 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this + is equivalent to ``transpose(input, 0, 1)``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x = torch.randn(()) + >>> x + tensor(0.1995) + >>> torch.t(x) + tensor(0.1995) + >>> x = torch.randn(3) + >>> x + tensor([ 2.4320, -0.4608, 0.7702]) + >>> torch.t(x) + tensor([ 2.4320, -0.4608, 0.7702]) + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 0.4875, 0.9158, -0.5872], + [ 0.3938, -0.6929, 0.6932]]) + >>> torch.t(x) + tensor([[ 0.4875, 0.3938], + [ 0.9158, -0.6929], + [-0.5872, 0.6932]]) + + See also :func:`torch.transpose`. + """ + +def t_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.t`, but all output tensors + are freshly created instead of aliasing the input. 
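+
+    Because the result does not alias :attr:`input`, in-place changes to the
+    output leave the input untouched. A minimal sketch::
+
+        >>> x = torch.zeros(2, 3)
+        >>> y = torch.t_copy(x)  # fresh tensor with the same values as torch.t(x)
+        >>> y[0, 0] = 1.0
+        >>> x[0, 0]  # the input is unchanged
+        tensor(0.)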
+ """ + +def take( + input: Tensor, + index: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + take(input, index) -> Tensor + + Returns a new tensor with the elements of :attr:`input` at the given indices. + The input tensor is treated as if it were viewed as a 1-D tensor. The result + takes the same shape as the indices. + + Args: + input (Tensor): the input tensor. + index (LongTensor): the indices into tensor + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> torch.take(src, torch.tensor([0, 2, 5])) + tensor([ 4, 5, 8]) + """ + +def take_along_dim( + input: Tensor, + indices: Tensor, + dim: _int | None = None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + take_along_dim(input, indices, dim=None, *, out=None) -> Tensor + + Selects values from :attr:`input` at the 1-dimensional indices from :attr:`indices` along the given :attr:`dim`. + + If :attr:`dim` is None, the input array is treated as if it has been flattened to 1d. + + Functions that return indices along a dimension, like :func:`torch.argmax` and :func:`torch.argsort`, + are designed to work with this function. See the examples below. + + .. note:: + This function is similar to NumPy's `take_along_axis`. + See also :func:`torch.gather`. + + Args: + input (Tensor): the input tensor. + indices (LongTensor): the indices into :attr:`input`. Must have long dtype. + dim (int, optional): dimension to select along. Default: 0 + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) + >>> max_idx = torch.argmax(t) + >>> torch.take_along_dim(t, max_idx) + tensor([60]) + >>> sorted_idx = torch.argsort(t, dim=1) + >>> torch.take_along_dim(t, sorted_idx, dim=1) + tensor([[10, 20, 30], + [40, 50, 60]]) + """ + +def tan(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + tan(input, *, out=None) -> Tensor + + Returns a new tensor with the tangent of the elements of :attr:`input`. + + .. math:: + \text{out}_{i} = \tan(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([-1.2027, -1.7687, 0.4412, -1.3856]) + >>> torch.tan(a) + tensor([-2.5930, 4.9859, 0.4722, -5.3366]) + """ + +def tan_(input: Tensor) -> Tensor: ... +def tanh(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + tanh(input, *, out=None) -> Tensor + + Returns a new tensor with the hyperbolic tangent of the elements + of :attr:`input`. + + .. math:: + \text{out}_{i} = \tanh(\text{input}_{i}) + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 0.8986, -0.7279, 1.1745, 0.2611]) + >>> torch.tanh(a) + tensor([ 0.7156, -0.6218, 0.8257, 0.2553]) + """ + +def tanh_(input: Tensor) -> Tensor: ... +def tensor( + data: Any, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, +) -> Tensor: + r""" + tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor + + Constructs a tensor with no autograd history (also known as a "leaf tensor", see :doc:`/notes/autograd`) by copying :attr:`data`. + + .. warning:: + + When working with tensors prefer using :func:`torch.Tensor.clone`, + :func:`torch.Tensor.detach`, and :func:`torch.Tensor.requires_grad_` for + readability. 
Letting `t` be a tensor, ``torch.tensor(t)`` is equivalent to + ``t.detach().clone()``, and ``torch.tensor(t, requires_grad=True)`` + is equivalent to ``t.detach().clone().requires_grad_(True)``. + + .. seealso:: + + :func:`torch.as_tensor` preserves autograd history and avoids copies where possible. + :func:`torch.from_numpy` creates a tensor that shares storage with a NumPy array. + + Args: + data (array_like): Initial data for the tensor. Can be a list, tuple, + NumPy ``ndarray``, scalar, and other types. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, infers data type from :attr:`data`. + device (:class:`torch.device`, optional): the device of the constructed tensor. If None and data is a tensor + then the device of data is used. If None and data is not a tensor then + the result tensor is constructed on the current device. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + + Example:: + + >>> torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + tensor([[ 0.1000, 1.2000], + [ 2.2000, 3.1000], + [ 4.9000, 5.2000]]) + + >>> torch.tensor([0, 1]) # Type inference on data + tensor([ 0, 1]) + + >>> torch.tensor([[0.11111, 0.222222, 0.3333333]], + ... dtype=torch.float64, + ... device=torch.device('cuda:0')) # creates a double tensor on a CUDA device + tensor([[ 0.1111, 0.2222, 0.3333]], dtype=torch.float64, device='cuda:0') + + >>> torch.tensor(3.14159) # Create a zero-dimensional (scalar) tensor + tensor(3.1416) + + >>> torch.tensor([]) # Create an empty tensor (of size (0,)) + tensor([]) + """ + +@overload +def tensor_split( + input: Tensor, + tensor_indices_or_sections: Tensor, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +@overload +def tensor_split( + input: Tensor, + sections: _int | SymInt, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. 
Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +@overload +def tensor_split( + input: Tensor, + indices: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + tensor_split(input, indices_or_sections, dim=0) -> List of Tensors + + Splits a tensor into multiple sub-tensors, all of which are views of :attr:`input`, + along dimension :attr:`dim` according to the indices or number of sections specified + by :attr:`indices_or_sections`. This function is based on NumPy's + :func:`numpy.array_split`. + + Args: + input (Tensor): the tensor to split + indices_or_sections (Tensor, int or list or tuple of ints): + If :attr:`indices_or_sections` is an integer ``n`` or a zero dimensional long tensor + with value ``n``, :attr:`input` is split into ``n`` sections along dimension :attr:`dim`. + If :attr:`input` is divisible by ``n`` along dimension :attr:`dim`, each + section will be of equal size, :code:`input.size(dim) / n`. If :attr:`input` + is not divisible by ``n``, the sizes of the first :code:`int(input.size(dim) % n)` + sections will have size :code:`int(input.size(dim) / n) + 1`, and the rest will + have size :code:`int(input.size(dim) / n)`. + + If :attr:`indices_or_sections` is a list or tuple of ints, or a one-dimensional long + tensor, then :attr:`input` is split along dimension :attr:`dim` at each of the indices + in the list, tuple or tensor. For instance, :code:`indices_or_sections=[2, 3]` and :code:`dim=0` + would result in the tensors :code:`input[:2]`, :code:`input[2:3]`, and :code:`input[3:]`. + + If :attr:`indices_or_sections` is a tensor, it must be a zero-dimensional or one-dimensional + long tensor on the CPU. + + dim (int, optional): dimension along which to split the tensor. Default: ``0`` + + Example:: + + >>> x = torch.arange(8) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) + + >>> x = torch.arange(7) + >>> torch.tensor_split(x, 3) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + >>> torch.tensor_split(x, (1, 6)) + (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) + + >>> x = torch.arange(14).reshape(2, 7) + >>> x + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13]]) + >>> torch.tensor_split(x, 3, dim=1) + (tensor([[0, 1, 2], + [7, 8, 9]]), + tensor([[ 3, 4], + [10, 11]]), + tensor([[ 5, 6], + [12, 13]])) + >>> torch.tensor_split(x, (1, 6), dim=1) + (tensor([[0], + [7]]), + tensor([[ 1, 2, 3, 4, 5], + [ 8, 9, 10, 11, 12]]), + tensor([[ 6], + [13]])) + """ + +def threshold( + input: Tensor, + threshold: Number | _complex, + value: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: ... +def threshold_( + input: Tensor, + threshold: Number | _complex, + value: Number | _complex, +) -> Tensor: ... 
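+
+# ``threshold`` and ``threshold_`` above carry no docstring. A hedged sketch of
+# their behaviour, assuming the same semantics as torch.nn.functional.threshold:
+# elements where ``input > threshold`` are kept, all others are replaced by
+# ``value``, e.g.
+#
+#     >>> torch.threshold(torch.tensor([1.0, 2.0, 3.0]), 2.0, 0.0)
+#     tensor([0., 0., 3.])
+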
+def tile(input: Tensor, dims: Sequence[_int | SymInt]) -> Tensor: + r""" + tile(input, dims) -> Tensor + + Constructs a tensor by repeating the elements of :attr:`input`. + The :attr:`dims` argument specifies the number of repetitions + in each dimension. + + If :attr:`dims` specifies fewer dimensions than :attr:`input` has, then + ones are prepended to :attr:`dims` until all dimensions are specified. + For example, if :attr:`input` has shape (8, 6, 4, 2) and :attr:`dims` + is (2, 2), then :attr:`dims` is treated as (1, 1, 2, 2). + + Analogously, if :attr:`input` has fewer dimensions than :attr:`dims` + specifies, then :attr:`input` is treated as if it were unsqueezed at + dimension zero until it has as many dimensions as :attr:`dims` specifies. + For example, if :attr:`input` has shape (4, 2) and :attr:`dims` + is (3, 3, 2, 2), then :attr:`input` is treated as if it had the + shape (1, 1, 4, 2). + + .. note:: + + This function is similar to NumPy's tile function. + + Args: + input (Tensor): the tensor whose elements to repeat. + dims (tuple): the number of repetitions per dimension. + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.tile((2,)) + tensor([1, 2, 3, 1, 2, 3]) + >>> y = torch.tensor([[1, 2], [3, 4]]) + >>> torch.tile(y, (2, 2)) + tensor([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + """ + +def topk( + input: Tensor, + k: _int | SymInt, + dim: _int = -1, + largest: _bool = True, + sorted: _bool = True, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.topk: + r""" + topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) + + Returns the :attr:`k` largest elements of the given :attr:`input` tensor along + a given dimension. + + If :attr:`dim` is not given, the last dimension of the `input` is chosen. + + If :attr:`largest` is ``False`` then the `k` smallest elements are returned. + + A namedtuple of `(values, indices)` is returned with the `values` and + `indices` of the largest `k` elements of each row of the `input` tensor in the + given dimension `dim`. + + The boolean option :attr:`sorted` if ``True``, will make sure that the returned + `k` elements are themselves sorted + + .. note:: + When using `torch.topk`, the indices of tied elements are not guaranteed to be stable + and may vary across different invocations. + + Args: + input (Tensor): the input tensor. + k (int): the k in "top-k" + dim (int, optional): the dimension to sort along + largest (bool, optional): controls whether to return largest or + smallest elements + sorted (bool, optional): controls whether to return the elements + in sorted order + + Keyword args: + out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be + optionally given to be used as output buffers + + Example:: + + >>> x = torch.arange(1., 6.) + >>> x + tensor([ 1., 2., 3., 4., 5.]) + >>> torch.topk(x, 3) + torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2])) + """ + +def trace(input: Tensor) -> Tensor: + r""" + trace(input) -> Tensor + + Returns the sum of the elements of the diagonal of the input 2-D matrix. + + Example:: + + >>> x = torch.arange(1., 10.).view(3, 3) + >>> x + tensor([[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]]) + >>> torch.trace(x) + tensor(15.) + """ + +@overload +def transpose(input: Tensor, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. 
+ The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + +@overload +def transpose( + input: Tensor, + dim0: str | EllipsisType | None, + dim1: str | EllipsisType | None, +) -> Tensor: + r""" + transpose(input, dim0, dim1) -> Tensor + + Returns a tensor that is a transposed version of :attr:`input`. + The given dimensions :attr:`dim0` and :attr:`dim1` are swapped. + + If :attr:`input` is a strided tensor then the resulting :attr:`out` + tensor shares its underlying storage with the :attr:`input` tensor, so + changing the content of one would change the content of the other. + + If :attr:`input` is a :ref:`sparse tensor ` then the + resulting :attr:`out` tensor *does not* share the underlying storage + with the :attr:`input` tensor. + + If :attr:`input` is a :ref:`sparse tensor ` with compressed + layout (SparseCSR, SparseBSR, SparseCSC or SparseBSC) the arguments + :attr:`dim0` and :attr:`dim1` must be both batch dimensions, or must + both be sparse dimensions. The batch dimensions of a sparse tensor are the + dimensions preceding the sparse dimensions. + + .. note:: + Transpositions which interchange the sparse dimensions of a `SparseCSR` + or `SparseCSC` layout tensor will result in the layout changing between + the two options. Transposition of the sparse dimensions of a ` SparseBSR` + or `SparseBSC` layout tensor will likewise generate a result with the + opposite layout. + + + Args: + input (Tensor): the input tensor. + dim0 (int): the first dimension to be transposed + dim1 (int): the second dimension to be transposed + + Example:: + + >>> x = torch.randn(2, 3) + >>> x + tensor([[ 1.0028, -0.9893, 0.5809], + [-0.1669, 0.7299, 0.4942]]) + >>> torch.transpose(x, 0, 1) + tensor([[ 1.0028, -0.1669], + [-0.9893, 0.7299], + [ 0.5809, 0.4942]]) + + See also :func:`torch.t`. + """ + +def transpose_copy( + input: Tensor, + dim0: _int, + dim1: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.transpose`, but all output tensors + are freshly created instead of aliasing the input. 
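+
+    A small sketch contrasting this with the view returned by
+    :func:`torch.transpose`::
+
+        >>> x = torch.randn(2, 3)
+        >>> torch.transpose(x, 0, 1).data_ptr() == x.data_ptr()  # the view shares storage
+        True
+        >>> torch.transpose_copy(x, 0, 1).data_ptr() == x.data_ptr()  # the copy does not
+        False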
+ """ + +@overload +def trapezoid(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. Only one of :attr:`x` or :attr:`dx` should be specified. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. 
+ + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + +@overload +def trapezoid( + y: Tensor, + *, + dx: Number | _complex = 1, + dim: _int = -1, +) -> Tensor: + r""" + trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor + + Computes the `trapezoidal rule `_ along + :attr:`dim`. By default the spacing between elements is assumed to be 1, but + :attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be + used to specify arbitrary spacing along :attr:`dim`. Only one of :attr:`x` or :attr:`dx` should be specified. + + + Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_n}`, + the default computation is + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{1}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`dx` is specified the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{\Delta x}{2} (y_i + y_{i-1}) + \end{aligned} + + effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified, + assuming :attr:`x` is also a one-dimensional tensor with + elements :math:`{x_0, x_1, ..., x_n}`, the computation becomes + + .. math:: + \begin{aligned} + \sum_{i = 1}^{n} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1}) + \end{aligned} + + When :attr:`x` and :attr:`y` have the same size, the computation is as described above and no broadcasting is needed. + The broadcasting behavior of this function is as follows when their sizes are different. For both :attr:`x` + and :attr:`y`, the function computes the difference between consecutive elements along + dimension :attr:`dim`. This effectively creates two tensors, `x_diff` and `y_diff`, that have + the same shape as the original tensors except their lengths along the dimension :attr:`dim` is reduced by 1. + After that, those two tensors are broadcast together to compute final output as part of the trapezoidal rule. + See the examples below for details. + + .. note:: + The trapezoidal rule is a technique for approximating the definite integral of a function + by averaging its left and right Riemann sums. 
The approximation becomes more accurate as + the resolution of the partition increases. + + Arguments: + y (Tensor): Values to use when computing the trapezoidal rule. + x (Tensor): If specified, defines spacing between values as specified above. + + Keyword arguments: + dx (float): constant spacing between values. If neither :attr:`x` or :attr:`dx` + are specified then this defaults to 1. Effectively multiplies the result by its value. + dim (int): The dimension along which to compute the trapezoidal rule. + The last (inner-most) dimension by default. + + Examples:: + + >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1 + >>> y = torch.tensor([1, 5, 10]) + >>> torch.trapezoid(y) + tensor(10.5) + + >>> # Computes the same trapezoidal rule directly to verify + >>> (1 + 10 + 10) / 2 + 10.5 + + >>> # Computes the trapezoidal rule in 1D with constant spacing of 2 + >>> # NOTE: the result is the same as before, but multiplied by 2 + >>> torch.trapezoid(y, dx=2) + 21.0 + + >>> # Computes the trapezoidal rule in 1D with arbitrary spacing + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + 28.5 + + >>> # Computes the same trapezoidal rule directly to verify + >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2 + 28.5 + + >>> # Computes the trapezoidal rule for each row of a 3x3 matrix + >>> y = torch.arange(9).reshape(3, 3) + tensor([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> torch.trapezoid(y) + tensor([ 2., 8., 14.]) + + >>> # Computes the trapezoidal rule for each column of the matrix + >>> torch.trapezoid(y, dim=0) + tensor([ 6., 8., 10.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with the same arbitrary spacing + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([1, 3, 6]) + >>> torch.trapezoid(y, x) + array([5., 5., 5.]) + + >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix + >>> # with different arbitrary spacing per row + >>> y = torch.ones(3, 3) + >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]]) + >>> torch.trapezoid(y, x) + array([2., 4., 6.]) + """ + +@overload +def trapz(y: Tensor, *, dx: _float = 1, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + +@overload +def trapz(y: Tensor, x: Tensor, *, dim: _int = -1) -> Tensor: + r""" + trapz(y, x, *, dim=-1) -> Tensor + + Alias for :func:`torch.trapezoid`. + """ + +def triangular_solve( + input: Tensor, + A: Tensor, + upper: _bool = True, + transpose: _bool = False, + unitriangular: _bool = False, + *, + out: Tensor | tuple[Tensor, ...] | list[Tensor] | None = None, +) -> torch.return_types.triangular_solve: + r""" + triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None) -> (Tensor, Tensor) + + Solves a system of equations with a square upper or lower triangular invertible matrix :math:`A` + and multiple right-hand sides :math:`b`. + + In symbols, it solves :math:`AX = b` and assumes :math:`A` is square upper-triangular + (or lower-triangular if :attr:`upper`\ `= False`) and does not have zeros on the diagonal. + + `torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are + batches of 2D matrices. If the inputs are batches, then returns + batched outputs `X` + + If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and + :attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, + the result may contain `NaN` s. 
+ + Supports input of float, double, cfloat and cdouble data types. + + .. warning:: + + :func:`torch.triangular_solve` is deprecated in favor of :func:`torch.linalg.solve_triangular` + and will be removed in a future PyTorch release. + :func:`torch.linalg.solve_triangular` has its arguments reversed and does not return a + copy of one of the inputs. + + ``X = torch.triangular_solve(B, A).solution`` should be replaced with + + .. code:: python + + X = torch.linalg.solve_triangular(A, B) + + Args: + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + upper (bool, optional): whether :math:`A` is upper or lower triangular. Default: ``True``. + transpose (bool, optional): solves `op(A)X = b` where `op(A) = A^T` if this flag is ``True``, + and `op(A) = A` if it is ``False``. Default: ``False``. + unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. + + Keyword args: + out ((Tensor, Tensor), optional): tuple of two tensors to write + the output to. Ignored if `None`. Default: `None`. + + Returns: + A namedtuple `(solution, cloned_coefficient)` where `cloned_coefficient` + is a clone of :math:`A` and `solution` is the solution :math:`X` to :math:`AX = b` + (or whatever variant of the system of equations, depending on the keyword arguments.) + + Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + torch.return_types.triangular_solve( + solution=tensor([[ 1.7841, 2.9046, -2.5405], + [ 1.9320, 0.9270, -1.2826]]), + cloned_coefficient=tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) + """ + +def tril( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + tril(input, diagonal=0, *, out=None) -> Tensor + + Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[-1.0813, -0.8619, 0.7105], + [ 0.0935, 0.1380, 2.2112], + [-0.3409, -0.9828, 0.0289]]) + >>> torch.tril(a) + tensor([[-1.0813, 0.0000, 0.0000], + [ 0.0935, 0.1380, 0.0000], + [-0.3409, -0.9828, 0.0289]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 1.2219, 0.5653, -0.2521, -0.2345, 1.2544, 0.3461], + [ 0.4785, -0.4477, 0.6049, 0.6368, 0.8775, 0.7145], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.3615, 0.6864], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0978]]) + >>> torch.tril(b, diagonal=1) + tensor([[ 1.2219, 0.5653, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, -0.4477, 0.6049, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, -1.1243, -0.5413, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024, 0.0000]]) + >>> torch.tril(b, diagonal=-1) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.4785, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [ 1.1502, 3.2716, 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]]) + """ + +def tril_indices( + row: _int, + col: _int, + offset: _int = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the lower triangular part of a :attr:`row`-by- + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row contains column coordinates. + Indices are ordered based on rows and then columns. + + The lower triangular part of the matrix is defined as the elements on and + below the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. 
+ + Example:: + + >>> a = torch.tril_indices(3, 3) + >>> a + tensor([[0, 1, 1, 2, 2, 2], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, -1) + >>> a + tensor([[1, 2, 2, 3, 3, 3], + [0, 0, 1, 0, 1, 2]]) + + >>> a = torch.tril_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]]) + """ + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: _float = 1.0, + p: _float = 2, + eps: _float = 1e-06, + swap: _bool = False, + reduction: _int = 1, +) -> Tensor: ... +def triu( + input: Tensor, + diagonal: _int = 0, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + triu(input, diagonal=0, *, out=None) -> Tensor + + Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`diagonal` controls which diagonal to consider. If + :attr:`diagonal` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + Args: + input (Tensor): the input tensor. + diagonal (int, optional): the diagonal to consider + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(3, 3) + >>> a + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.3480, -0.5211, -0.4573]]) + >>> torch.triu(a) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.0000, -1.0680, 0.6602], + [ 0.0000, 0.0000, -0.4573]]) + >>> torch.triu(a, diagonal=1) + tensor([[ 0.0000, 0.5207, 2.0049], + [ 0.0000, 0.0000, 0.6602], + [ 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(a, diagonal=-1) + tensor([[ 0.2309, 0.5207, 2.0049], + [ 0.2072, -1.0680, 0.6602], + [ 0.0000, -0.5211, -0.4573]]) + + >>> b = torch.randn(4, 6) + >>> b + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) + """ + +def triu_indices( + row: _int, + col: _int, + offset: _int = 0, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + triu_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor + + Returns the indices of the upper triangular part of a :attr:`row` by + :attr:`col` matrix in a 2-by-N Tensor, where the first row contains row + coordinates of all indices and the second row 
contains column coordinates. + Indices are ordered based on rows and then columns. + + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + The argument :attr:`offset` controls which diagonal to consider. If + :attr:`offset` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` + where :math:`d_{1}, d_{2}` are the dimensions of the matrix. + + .. note:: + When running on CUDA, ``row * col`` must be less than :math:`2^{59}` to + prevent overflow during calculation. + + Args: + row (``int``): number of rows in the 2-D matrix. + col (``int``): number of columns in the 2-D matrix. + offset (``int``): diagonal offset from the main diagonal. + Default: if not provided, 0. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. + + Example:: + + >>> a = torch.triu_indices(3, 3) + >>> a + tensor([[0, 0, 0, 1, 1, 2], + [0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, -1) + >>> a + tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3], + [0, 1, 2, 0, 1, 2, 1, 2, 2]]) + + >>> a = torch.triu_indices(4, 3, 1) + >>> a + tensor([[0, 0, 1], + [1, 2, 2]]) + """ + +def true_divide( + input: Tensor | Number, + other: Tensor | Number, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + true_divide(dividend, divisor, *, out) -> Tensor + + Alias for :func:`torch.div` with ``rounding_mode=None``. + """ + +def trunc(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + trunc(input, *, out=None) -> Tensor + + Returns a new tensor with the truncated integer values of + the elements of :attr:`input`. + + For integer inputs, follows the array-api convention of returning a + copy of the input tensor. + + Args: + input (Tensor): the input tensor. + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.randn(4) + >>> a + tensor([ 3.4742, 0.5466, -0.8008, -0.9079]) + >>> torch.trunc(a) + tensor([ 3., 0., -0., -0.]) + """ + +def trunc_(input: Tensor) -> Tensor: ... +@overload +def unbind(input: Tensor, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + +@overload +def unbind( + input: Tensor, + dim: str | EllipsisType | None, +) -> tuple[Tensor, ...]: + r""" + unbind(input, dim=0) -> seq + + Removes a tensor dimension. + + Returns a tuple of all slices along a given dimension, already without it. 
+ + Arguments: + input (Tensor): the tensor to unbind + dim (int): dimension to remove + + Example:: + + >>> torch.unbind(torch.tensor([[1, 2, 3], + >>> [4, 5, 6], + >>> [7, 8, 9]])) + (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9])) + """ + +def unbind_copy( + input: Tensor, + dim: _int = 0, + *, + out: tuple[Tensor, ...] | list[Tensor] | None = None, +) -> None: + r""" + Performs the same operation as :func:`torch.unbind`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def unflatten( + input: Tensor, + dim: str | EllipsisType | None, + sizes: Sequence[_int | SymInt], + names: Sequence[str | EllipsisType | None], +) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + +@overload +def unflatten( + input: Tensor, + dim: _int, + sizes: Sequence[_int | SymInt], +) -> Tensor: + r""" + unflatten(input, dim, sizes) -> Tensor + + Expands a dimension of the input tensor over multiple dimensions. + + .. seealso:: + + :func:`torch.flatten` the inverse of this function. It coalesces several dimensions into one. + + Args: + input (Tensor): the input tensor. + dim (int): Dimension to be unflattened, specified as an index into + ``input.shape``. + sizes (Tuple[int]): New shape of the unflattened dimension. + One of its elements can be `-1` in which case the corresponding output + dimension is inferred. Otherwise, the product of ``sizes`` *must* + equal ``input.shape[dim]``. + + Returns: + A View of input with the specified dimension unflattened. + + Examples:: + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape + torch.Size([3, 2, 2, 1]) + >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape + torch.Size([5, 2, 2, 3, 1, 1, 3]) + """ + +def unfold_copy( + input: Tensor, + dimension: _int, + size: _int, + step: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.unfold`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def unique_dim( + input: Tensor, + dim: _int, + sorted: _bool = True, + return_inverse: _bool = False, + return_counts: _bool = False, +) -> tuple[Tensor, Tensor, Tensor]: ... +def unsafe_chunk( + input: Tensor, + chunks: _int, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + unsafe_chunk(input, chunks, dim=0) -> List of Tensors + + Works like :func:`torch.chunk` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. 
warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + +def unsafe_split( + input: Tensor, + split_size: _int | SymInt, + dim: _int = 0, +) -> tuple[Tensor, ...]: + r""" + unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors + + Works like :func:`torch.split` but without enforcing the autograd restrictions + on inplace modification of the outputs. + + .. warning:: + This function is safe to use as long as only the input, or only the outputs + are modified inplace after calling this function. It is user's + responsibility to ensure that is the case. If both the input and one or more + of the outputs are modified inplace, gradients computed by autograd will be + silently incorrect. + """ + +def unsafe_split_with_sizes( + input: Tensor, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, +) -> tuple[Tensor, ...]: ... +def unsqueeze(input: Tensor, dim: _int) -> Tensor: + r""" + unsqueeze(input, dim) -> Tensor + + Returns a new tensor with a dimension of size one inserted at the + specified position. + + The returned tensor shares the same underlying data with this tensor. + + A :attr:`dim` value within the range ``[-input.dim() - 1, input.dim() + 1)`` + can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze` + applied at :attr:`dim` = ``dim + input.dim() + 1``. + + Args: + input (Tensor): the input tensor. + dim (int): the index at which to insert the singleton dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4]) + >>> torch.unsqueeze(x, 0) + tensor([[ 1, 2, 3, 4]]) + >>> torch.unsqueeze(x, 1) + tensor([[ 1], + [ 2], + [ 3], + [ 4]]) + """ + +def unsqueeze_copy( + input: Tensor, + dim: _int, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.unsqueeze`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def values_copy(input: Tensor, *, out: Tensor | None = None) -> Tensor: + r""" + Performs the same operation as :func:`torch.values`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def vander( + x: Tensor, + N: _int | None = None, + increasing: _bool = False, +) -> Tensor: + r""" + vander(x, N=None, increasing=False) -> Tensor + + Generates a Vandermonde matrix. + + The columns of the output matrix are elementwise powers of the input vector :math:`x^{(N-1)}, x^{(N-2)}, ..., x^0`. + If increasing is True, the order of the columns is reversed :math:`x^0, x^1, ..., x^{(N-1)}`. Such a + matrix with a geometric progression in each row is named for Alexandre-Theophile Vandermonde. + + Arguments: + x (Tensor): 1-D input tensor. + N (int, optional): Number of columns in the output. If N is not specified, + a square array is returned :math:`(N = len(x))`. + increasing (bool, optional): Order of the powers of the columns. If True, + the powers increase from left to right, if False (the default) they are reversed. + + Returns: + Tensor: Vandermonde matrix. If increasing is False, the first column is :math:`x^{(N-1)}`, + the second :math:`x^{(N-2)}` and so forth. If increasing is True, the columns + are :math:`x^0, x^1, ..., x^{(N-1)}`. 
+ + Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> torch.vander(x) + tensor([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + >>> torch.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 4, 2, 1], + [ 9, 3, 1], + [25, 5, 1]]) + >>> torch.vander(x, N=3, increasing=True) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) + """ + +@overload +def var( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. 
+ Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var(input: Tensor, unbiased: _bool = True) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. 
_Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + + Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` + can be a single dimension, list of dimensions, or ``None`` to reduce over all + dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. 
+ + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. 
+ :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + unbiased: _bool = True, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. 
Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +@overload +def var_mean( + input: Tensor, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, +) -> tuple[Tensor, Tensor]: + r""" + var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + + Calculates the variance and mean over the dimensions specified by :attr:`dim`. + :attr:`dim` can be a single dimension, list of dimensions, or ``None`` to + reduce over all dimensions. + + The variance (:math:`\sigma^2`) is calculated as + + .. 
math:: \sigma^2 = \frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + + where :math:`x` is the sample set of elements, :math:`\bar{x}` is the + sample mean, :math:`N` is the number of samples and :math:`\delta N` is + the :attr:`correction`. + + + + If :attr:`keepdim` is ``True``, the output tensor is of the same size + as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. + Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the + output tensor having 1 (or ``len(dim)``) fewer dimension(s). + + + Args: + input (Tensor): the input tensor. + + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. + + + Keyword args: + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + + .. versionchanged:: 2.0 + Previously this argument was called ``unbiased`` and was a boolean + with ``True`` corresponding to ``correction=1`` and ``False`` being + ``correction=0``. + + keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. + + out (Tensor, optional): the output tensor. + + Returns: + A tuple (var, mean) containing the variance and mean. + + Example: + + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]] + ... ) # fmt: skip + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) + + .. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction + """ + +def vdot( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + vdot(input, other, *, out=None) -> Tensor + + Computes the dot product of two 1D vectors along a dimension. + + In symbols, this function computes + + .. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + + where :math:`\overline{x_i}` denotes the conjugate for complex + vectors, and it is the identity for real vectors. + + .. note:: + + Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product + of two 1D tensors with the same number of elements. + + .. seealso:: + + :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension. + + Args: + input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex. + other (Tensor): second tensor in the dot product, must be 1D. + + Keyword args: + + .. note:: out (Tensor, optional): the output tensor. + + + Example:: + + >>> torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1])) + tensor(7) + >>> a = torch.tensor((1 +2j, 3 - 1j)) + >>> b = torch.tensor((2 +1j, 4 - 0j)) + >>> torch.vdot(a, b) + tensor([16.+1.j]) + >>> torch.vdot(b, a) + tensor([16.-1.j]) + """ + +def view_as_complex(input: Tensor) -> Tensor: + r""" + view_as_complex(input) -> Tensor + + Returns a view of :attr:`input` as a complex tensor. For an input complex + tensor of :attr:`size` :math:`m1, m2, \dots, mi, 2`, this function returns a + new complex tensor of :attr:`size` :math:`m1, m2, \dots, mi` where the last + dimension of the input tensor is expected to represent the real and imaginary + components of complex numbers. + + .. warning:: + :func:`view_as_complex` is only supported for tensors with + :class:`torch.dtype` ``torch.float64`` and ``torch.float32``. 
The input is + expected to have the last dimension of :attr:`size` 2. In addition, the + tensor must have a `stride` of 1 for its last dimension. The strides of all + other dimensions must be even numbers. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, 2) + >>> x + tensor([[ 1.6116, -0.5772], + [-1.4606, -0.9120], + [ 0.0786, -1.7497], + [-0.6561, -1.6623]]) + >>> torch.view_as_complex(x) + tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)]) + """ + +def view_as_complex_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_complex`, but all output tensors + are freshly created instead of aliasing the input. + """ + +def view_as_real(input: Tensor) -> Tensor: + r""" + view_as_real(input) -> Tensor + + Returns a view of :attr:`input` as a real tensor. For an input complex tensor of + :attr:`size` :math:`m1, m2, \dots, mi`, this function returns a new + real tensor of size :math:`m1, m2, \dots, mi, 2`, where the last dimension of size 2 + represents the real and imaginary components of complex numbers. + + .. warning:: + :func:`view_as_real` is only supported for tensors with ``complex dtypes``. + + Args: + input (Tensor): the input tensor. + + Example:: + + >>> x=torch.randn(4, dtype=torch.cfloat) + >>> x + tensor([(0.4737-0.3839j), (-0.2098-0.6699j), (0.3470-0.9451j), (-0.5174-1.3136j)]) + >>> torch.view_as_real(x) + tensor([[ 0.4737, -0.3839], + [-0.2098, -0.6699], + [ 0.3470, -0.9451], + [-0.5174, -1.3136]]) + """ + +def view_as_real_copy( + input: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view_as_real`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def view_copy( + input: Tensor, + dtype: _dtype, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def view_copy( + input: Tensor, + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, +) -> Tensor: + r""" + Performs the same operation as :func:`torch.view`, but all output tensors + are freshly created instead of aliasing the input. + """ + +@overload +def vsplit(input: Tensor, sections: _int) -> tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. 
+ + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + +@overload +def vsplit(input: Tensor, indices: _size) -> tuple[Tensor, ...]: + r""" + vsplit(input, indices_or_sections) -> List of Tensors + + Splits :attr:`input`, a tensor with two or more dimensions, into multiple tensors + vertically according to :attr:`indices_or_sections`. Each split is a view of + :attr:`input`. + + This is equivalent to calling torch.tensor_split(input, indices_or_sections, dim=0) + (the split dimension is 0), except that if :attr:`indices_or_sections` is an integer + it must evenly divide the split dimension or a runtime error will be thrown. + + This function is based on NumPy's :func:`numpy.vsplit`. + + Args: + input (Tensor): tensor to split. + indices_or_sections (int or list or tuple of ints): See argument in :func:`torch.tensor_split`. + + Example:: + + >>> t = torch.arange(16.0).reshape(4,4) + >>> t + tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [12., 13., 14., 15.]]) + >>> torch.vsplit(t, 2) + (tensor([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), + tensor([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])) + >>> torch.vsplit(t, [3, 6]) + (tensor([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]]), + tensor([[12., 13., 14., 15.]]), + tensor([], size=(0, 4))) + """ + +def vstack( + tensors: tuple[Tensor, ...] | list[Tensor] | None, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + vstack(tensors, *, out=None) -> Tensor + + Stack tensors in sequence vertically (row wise). + + This is equivalent to concatenation along the first axis after all 1-D tensors have been reshaped by :func:`torch.atleast_2d`. + + Args: + tensors (sequence of Tensors): sequence of tensors to concatenate + + Keyword args: + out (Tensor, optional): the output tensor. + + Example:: + + >>> a = torch.tensor([1, 2, 3]) + >>> b = torch.tensor([4, 5, 6]) + >>> torch.vstack((a,b)) + tensor([[1, 2, 3], + [4, 5, 6]]) + >>> a = torch.tensor([[1],[2],[3]]) + >>> b = torch.tensor([[4],[5],[6]]) + >>> torch.vstack((a,b)) + tensor([[1], + [2], + [3], + [4], + [5], + [6]]) + """ + +@overload +def where(condition: Tensor) -> tuple[Tensor, ...]: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. 
+ + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + self: Number | _complex, + other: Tensor, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. 
+ + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + input: Tensor, + other: Number | _complex, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. + + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def where( + condition: Tensor, + self: Number | _complex, + other: Number | _complex, +) -> Tensor: + r""" + where(condition, input, other, *, out=None) -> Tensor + + Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. 
+ + The operation is defined as: + + .. math:: + \text{out}_i = \begin{cases} + \text{input}_i & \text{if } \text{condition}_i \\ + \text{other}_i & \text{otherwise} \\ + \end{cases} + + .. note:: + The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + + Arguments: + condition (BoolTensor): When True (nonzero), yield input, otherwise yield other + input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices + where :attr:`condition` is ``True`` + other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices + where :attr:`condition` is ``False`` + + Keyword args: + out (Tensor, optional): the output tensor. + + Returns: + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + + Example:: + + >>> x = torch.randn(3, 2) + >>> y = torch.ones(3, 2) + >>> x + tensor([[-0.4620, 0.3139], + [ 0.3898, -0.7197], + [ 0.0478, -0.1657]]) + >>> torch.where(x > 0, 1.0, 0.0) + tensor([[0., 1.], + [1., 0.], + [1., 0.]]) + >>> torch.where(x > 0, x, y) + tensor([[ 1.0000, 0.3139], + [ 0.3898, 1.0000], + [ 0.0478, 1.0000]]) + >>> x = torch.randn(2, 2, dtype=torch.double) + >>> x + tensor([[ 1.0779, 0.0383], + [-0.8785, -1.1089]], dtype=torch.float64) + >>> torch.where(x > 0, x, 0.) + tensor([[1.0779, 0.0383], + [0.0000, 0.0000]], dtype=torch.float64) + + .. function:: where(condition) -> tuple of LongTensor + :noindex: + + ``torch.where(condition)`` is identical to + ``torch.nonzero(condition, as_tuple=True)``. + + .. note:: + See also :func:`torch.nonzero`. + """ + +@overload +def xlogy( + input: Tensor, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy( + self: Number | _complex, + other: Tensor, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy( + input: Tensor, + other: Number | _complex, + *, + out: Tensor | None = None, +) -> Tensor: + r""" + xlogy(input, other, *, out=None) -> Tensor + + Alias for :func:`torch.special.xlogy`. + """ + +@overload +def xlogy_(input: Tensor, other: Tensor) -> Tensor: ... +@overload +def xlogy_(input: Tensor, other: Number | _complex) -> Tensor: ... +def zero_(input: Tensor) -> Tensor: ... +@overload +def zeros( + size: Sequence[_int | SymInt], + *, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + *size: _int | SymInt, + out: Tensor | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + size: _size, + *, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. 
+ requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +@overload +def zeros( + *size: _int, + names: Sequence[str | EllipsisType | None] | None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the shape defined + by the variable argument :attr:`size`. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + out (Tensor, optional): the output tensor. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + + Example:: + + >>> torch.zeros(2, 3) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + + >>> torch.zeros(5) + tensor([ 0., 0., 0., 0., 0.]) + """ + +def zeros_like( + input: Tensor, + *, + memory_format: memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, +) -> Tensor: + r""" + zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor + + Returns a tensor filled with the scalar value `0`, with the same size as + :attr:`input`. ``torch.zeros_like(input)`` is equivalent to + ``torch.zeros(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``. + + .. warning:: + As of 0.4, this function does not support an :attr:`out` keyword. As an alternative, + the old ``torch.zeros_like(input, out=output)`` is equivalent to + ``torch.zeros(input.size(), out=output)``. + + Args: + input (Tensor): the size of :attr:`input` will determine size of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: if ``None``, defaults to the dtype of :attr:`input`. + layout (:class:`torch.layout`, optional): the desired layout of returned tensor. + Default: if ``None``, defaults to the layout of :attr:`input`. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, defaults to the device of :attr:`input`. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
+ + Example:: + + >>> input = torch.empty(2, 3) + >>> torch.zeros_like(input) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]]) + """ diff --git a/phivenv/Lib/site-packages/torch/_C/__init__.pyi b/phivenv/Lib/site-packages/torch/_C/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..102db955944256c2d8ea8ef56900eb65c61f7bab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/__init__.pyi @@ -0,0 +1,12850 @@ +# @generated by tools/pyi/gen_pyi.py from torch/_C/__init__.pyi.in +# mypy: disable-error-code="type-arg" +# mypy: allow-untyped-defs +# ruff: noqa: F401 + +from collections.abc import Iterable, Iterator, Sequence +from enum import Enum, IntEnum +from pathlib import Path +from types import EllipsisType +from typing import ( + Any, + AnyStr, + Callable, + Generic, + IO, + Literal, + NamedTuple, + overload, + SupportsIndex, + TypeVar, +) +from typing_extensions import ParamSpec, Protocol, runtime_checkable, Self, TypeAlias + +import numpy + +import torch +from torch import inf, SymInt, Tensor +from torch._C import ( + _aoti, + _cpu, + _dynamo, + _export, + _functorch, + _lazy, + _lazy_ts_backend, + _nn, + _onnx, + _VariableFunctions, + _verbose, +) +from torch._prims_common import DeviceLikeType +from torch.autograd.graph import Node as _Node +from torch.fx.node import Node as FxNode +from torch.package import PackageExporter +from torch.storage import TypedStorage, UntypedStorage +from torch.types import ( + _bool, + _bytes, + _complex, + _device, + _dispatchkey, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, + _str, + _symsize, + Device, + IntLikeType, + Number, + Storage, +) +from torch.utils._python_dispatch import TorchDispatchMode + +# This module is defined in torch/csrc/Module.cpp + +K = TypeVar("K") # noqa: PYI001 +T = TypeVar("T") # noqa: PYI001 +S = TypeVar("S", bound=torch.Tensor) # noqa: PYI001 +P = ParamSpec("P") # noqa: PYI001 +R = TypeVar("R", covariant=True) # return value (always covariant) # noqa: PYI001 +T_co = TypeVar("T_co", covariant=True) # noqa: PYI001 + +@runtime_checkable +class _NestedSequence(Protocol[T_co]): + """A protocol for representing nested sequences. + + References:: + `numpy._typing._NestedSequence` + + """ + + def __len__(self, /) -> _int: ... + def __getitem__(self, index: _int, /) -> T_co | _NestedSequence[T_co]: ... + def __contains__(self, x: object, /) -> _bool: ... + def __iter__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ... + def __reversed__(self, /) -> Iterator[T_co | _NestedSequence[T_co]]: ... + def count(self, value: Any, /) -> _int: ... + def index(self, value: Any, /) -> _int: ... + +# Defined in torch/csrc/Device.cpp +class device: + type: str # THPDevice_type + index: _int # THPDevice_index + + def __get__(self, instance, owner=None) -> device: ... + + # THPDevice_pynew + @overload + def __init__(self, device: DeviceLikeType) -> None: ... + @overload + def __init__(self, type: str, index: _int) -> None: ... + + # Uncomment if we ever make torch.device a decorator + # def __call__(self, func: T) -> T: ... + + def __enter__(self) -> Self: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + def __reduce__(self) -> tuple[Any, ...]: ... # THPDevice_reduce + +# Defined in torch/csrc/Stream.cpp +class Stream: + stream_id: _int # Stream id + device_index: _int + device_type: _int + + device: _device # The device of the stream + + @overload + def __new__( + cls, + device: DeviceLikeType | None = None, + *, + priority: _int = 0, + ) -> Self: ... 
+ @overload + def __new__( + cls, + stream_id: _int, + device_index: _int, + device_type: _int, + *, + priority: _int = 0, + ) -> Self: ... + def query(self) -> _bool: ... + def synchronize(self) -> None: ... + def wait_event(self, event: Event) -> None: ... + def wait_stream(self, other: Stream) -> None: ... + def record_event(self, event: Event | None = None) -> Event: ... + def __hash__(self) -> _int: ... + def __eq__(self, other: object) -> _bool: ... + def __enter__(self) -> Self: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... + +# Defined in torch/csrc/Event.cpp +class Event: + device: _device # The device of the Event + event_id: _int # The raw event created by device backend + + def __new__( + cls, + device: DeviceLikeType | None = None, + *, + enable_timing: _bool = False, + blocking: _bool = False, + interprocess: _bool = False, + ) -> Self: ... + @classmethod + def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> Event: ... + def record(self, stream: Stream | None = None) -> None: ... + def wait(self, stream: Stream | None = None) -> None: ... + def query(self) -> _bool: ... + def elapsed_time(self, other: Event) -> _float: ... + def synchronize(self) -> None: ... + def ipc_handle(self) -> bytes: ... + +# Defined in torch/csrc/Size.cpp +class Size(tuple[_int, ...]): + # TODO: __reduce__ + + @overload + def __getitem__(self: Size, key: SupportsIndex, /) -> _int: ... + @overload + def __getitem__(self: Size, key: slice, /) -> Size: ... + # Note: torch.Size does not support adding non-integer tuples. + def __add__(self, other: tuple[_int, ...], /) -> Size: ... # type: ignore[override] + def __radd__(self: Size, other: tuple[_int, ...], /) -> Size: ... + def __mul__(self, other: SupportsIndex, /) -> Size: ... + def __rmul__(self, other: SupportsIndex, /) -> Size: ... + def numel(self: Size, /) -> _int: ... + +# Defined in torch/csrc/Dtype.cpp +class dtype: + # TODO: __reduce__ + is_floating_point: _bool + is_complex: _bool + is_signed: _bool + itemsize: _int + def to_real(self) -> dtype: ... + def to_complex(self) -> dtype: ... + +# Defined in torch/csrc/TypeInfo.cpp +class iinfo: + bits: _int + min: _int + max: _int + dtype: str + + def __init__(self, dtype: _dtype) -> None: ... + +class finfo: + bits: _int + min: _float + max: _float + eps: _float + tiny: _float + smallest_normal: _float + resolution: _float + dtype: str + + @overload + def __init__(self, dtype: _dtype) -> None: ... + @overload + def __init__(self) -> None: ... + +float32: dtype = ... +float: dtype = ... +float64: dtype = ... +double: dtype = ... +float16: dtype = ... +bfloat16: dtype = ... +float8_e4m3fn: dtype = ... +float8_e4m3fnuz: dtype = ... +float8_e5m2: dtype = ... +float8_e5m2fnuz: dtype = ... +float8_e8m0fnu: dtype = ... +float4_e2m1fn_x2: dtype = ... +half: dtype = ... +uint8: dtype = ... +uint16: dtype = ... +uint32: dtype = ... +uint64: dtype = ... +int8: dtype = ... +int16: dtype = ... +short: dtype = ... +int32: dtype = ... +int: dtype = ... +int64: dtype = ... +long: dtype = ... +complex32: dtype = ... +complex64: dtype = ... +chalf: dtype = ... +cfloat: dtype = ... +complex128: dtype = ... +cdouble: dtype = ... +quint8: dtype = ... +qint8: dtype = ... +qint32: dtype = ... +bool: dtype = ... +quint4x2: dtype = ... +quint2x4: dtype = ... +bits1x8: dtype = ... +bits2x4: dtype = ... +bits4x2: dtype = ... +bits8: dtype = ... +bits16: dtype = ... + +# Defined in torch/csrc/Layout.cpp +class layout: ... 
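# The dtype/iinfo/finfo/Size stubs above correspond to the public torch.iinfo,
# torch.finfo and torch.Size helpers. A minimal usage sketch (assumes a standard
# torch install; the numeric values shown are the usual IEEE / two's-complement ones):
import torch

assert torch.iinfo(torch.int8).min == -128      # integer range queries
assert torch.iinfo(torch.int8).max == 127

f32 = torch.finfo(torch.float32)                # floating-point characteristics
print(f32.eps, f32.tiny, f32.max)               # machine epsilon, smallest normal, largest finite

s = torch.empty(2, 3, 4).size()                 # a torch.Size, i.e. a tuple[int, ...] subclass
assert s.numel() == 24                          # product of the dimensions
assert s + (5,) == torch.Size([2, 3, 4, 5])     # tuple-style concatenation, per __add__ above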
+ +# Defined in torch/csrc/utils/disable_torch_function.cpp +def DisableTorchFunction(): ... +def DisableTorchFunctionSubclass(): ... + +# Defined in torch/csrc/utils/tensor_layouts.cpp +strided: layout = ... +sparse_coo: layout = ... +sparse_csr: layout = ... +sparse_csc: layout = ... +sparse_bsr: layout = ... +sparse_bsc: layout = ... +_mkldnn: layout = ... +jagged: layout = ... + +# Defined in torch/csrc/MemoryFormat.cpp +class memory_format: ... + +# Defined in torch/csrc/utils/tensor_memoryformats.cpp +contiguous_format: memory_format = ... +channels_last: memory_format = ... +channels_last_3d: memory_format = ... +preserve_format: memory_format = ... + +# Defined in torch/csrc/QScheme.cpp +class qscheme: ... + +# Defined in torch/csrc/utils/tensor_qschemes.h +per_tensor_affine: qscheme = ... +per_channel_affine: qscheme = ... +per_tensor_symmetric: qscheme = ... +per_channel_symmetric: qscheme = ... +per_channel_affine_float_qparams: qscheme = ... + +# Defined in torch/csrc/autograd/python_function.cpp +class _FunctionBase: + saved_tensors: tuple[Tensor] + _raw_saved_tensors: tuple[Any] + next_functions: tuple[tuple[Any, _int], ...] + needs_input_grad: tuple[_bool] + metadata: dict + _materialize_non_diff_grads: _bool + # skip adding type hints for the fields that have wrappers defined + # in torch/autograd/function.py + +# Defined in torch/csrc/autograd/python_legacy_variable.cpp +class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy + def __init__( + self, + data: Tensor | None = ..., + requires_grad: _bool | None = ..., + volatile: _bool | None = ..., + _grad_fn: _FunctionBase | None = ..., + ) -> None: ... + +# Defined in torch/csrc/jit/python/init.cpp +class IODescriptor: ... +class JITException(Exception): ... + +class Future(Generic[T]): + def __init__(self, devices: list[device]) -> None: ... + def done(self) -> _bool: ... + def value(self) -> T: ... + def wait(self) -> T: ... + def add_done_callback(self, callback: Callable) -> None: ... + def then(self, callback: Callable) -> Future[T]: ... + def set_result(self, result: T) -> None: ... + def _set_unwrap_func(self, callback: Callable) -> None: ... + +class _Await: + def __init__(self) -> None: ... + def fn(self) -> Callable: ... + def args(self) -> tuple[Any, ...]: ... + def is_nowait(self) -> _bool: ... + +def _jit_set_num_profiled_runs(num: _size) -> _size: ... + +# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h +class _MobileOptimizerType: ... + +CONV_BN_FUSION: _MobileOptimizerType +INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType +REMOVE_DROPOUT: _MobileOptimizerType +FUSE_ADD_RELU: _MobileOptimizerType +HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType +VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType + +def fork(*args: Any, **kwargs: Any) -> Future: ... +def wait(fut: Future) -> Any: ... +def _awaitable(*args: Any, **kwargs: Any) -> _Await: ... +def _awaitable_wait(aw: _Await) -> Any: ... +def _awaitable_nowait(x: Any) -> _Await: ... +def _collect_all(futures: list[Future]) -> Future: ... +def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ... +def unify_type_list(types: list[JitType]) -> JitType: ... +def _freeze_module( + module: ScriptModule, + preserved_attrs: list[str] = ..., + freeze_interfaces: _bool = True, + preserveParameters: _bool = True, +) -> ScriptModule: ... +def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ... 
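# The layout and memory_format singletons declared above are the values accepted by the
# public factory and conversion APIs. A short, hedged sketch of typical usage:
import torch

x = torch.randn(8, 3, 32, 32)
y = x.to(memory_format=torch.channels_last)            # request NHWC physical layout
assert y.is_contiguous(memory_format=torch.channels_last)

dense = torch.zeros(2, 2, layout=torch.strided)        # the default dense layout
sparse = torch.zeros(2, 2, layout=torch.sparse_coo)    # COO sparse storage
assert sparse.layout == torch.sparse_coo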
+def _jit_pass_optimize_for_inference( + module: torch.jit.ScriptModule, + other_methods: list[str] = ..., +) -> None: ... +def _jit_pass_fold_frozen_conv_bn(graph: Graph): ... +def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ... +def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ... +def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ... +def _jit_pass_concat_frozen_linear(graph: Graph): ... +def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ... +def _jit_pass_transpose_frozen_linear(graph: Graph): ... +def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ... +def _is_tracing() -> _bool: ... +def _jit_init() -> _bool: ... +def _jit_flatten(arg: Any) -> tuple[list[Tensor], IODescriptor]: ... +def _jit_unflatten(vars: list[Tensor], desc: IODescriptor) -> Any: ... +def _jit_get_operation(op_name: str) -> tuple[Callable, list[str]]: ... +def _get_operation_overload( + op_name: str, + op_overload_name: str, +) -> tuple[Callable, Callable, list[Any]]: ... +def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ... +def _jit_pass_optimize_for_mobile( + module: torch.jit.ScriptModule, + optimization_blocklist: set[_MobileOptimizerType], + preserved_methods: list[AnyStr], +) -> torch.jit.ScriptModule: ... +def _clone_module_with_class( + module: torch.jit.ScriptModule, + ignored_methods: list[AnyStr], + ignored_attributes: list[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_vulkan_optimize_for_mobile( + module: torch.jit.ScriptModule, + optimization_blocklist: set[_MobileOptimizerType], + preserved_methods: list[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_metal_optimize_for_mobile( + module: torch.jit.ScriptModule, + preserved_methods: list[AnyStr], +) -> torch.jit.ScriptModule: ... +def _jit_pass_inline(Graph) -> None: ... +def _jit_pass_constant_propagation(Graph) -> None: ... +def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ... +def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ... +def _jit_erase_non_input_shape_information(Graph) -> None: ... +def _jit_get_schemas_for_operator(name: str) -> list[FunctionSchema]: ... +def _jit_get_all_schemas() -> list[FunctionSchema]: ... +def _jit_check_alias_annotation( + g: Graph, + args: tuple[Any, ...], + unqualified_op_name: str, +): ... +def _jit_can_fuse_on_cpu() -> _bool: ... +def _jit_can_fuse_on_gpu() -> _bool: ... +def _jit_can_fuse_on_cpu_legacy() -> _bool: ... +def _debug_get_fusion_group_inlining() -> _bool: ... +def _debug_set_fusion_group_inlining(enable: _bool): ... +def _jit_texpr_fuser_enabled() -> _bool: ... +def _jit_nvfuser_enabled() -> _bool: ... +def _jit_llga_enabled() -> _bool: ... +def _jit_set_llga_enabled(enable: _bool): ... +def _llvm_enabled() -> _bool: ... +def _jit_override_can_fuse_on_cpu(override: _bool): ... +def _jit_override_can_fuse_on_gpu(override: _bool): ... +def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ... +def _jit_set_symbolic_shapes_test_mode(override: _bool): ... +def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ... +def _jit_set_texpr_fuser_enabled(enable: _bool): ... +def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ... +def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ... +def _jit_cat_wo_conditionals(optimize_cat: _bool): ... +def _jit_opt_conditionals(opt_conds: _bool): ... +def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ... +def _jit_pass_erase_shape_information(graph: Graph): ... 
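# The _jit_pass_* and fuser-toggle bindings above operate on the torch._C.Graph behind a
# scripted function; they are normally driven by TorchScript itself rather than called by
# hand. A sketch of obtaining such a graph through the public API:
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1.0

scripted = torch.jit.script(f)
print(scripted.graph)                       # the Graph object the passes above transform
print(torch._C._jit_can_fuse_on_cpu())      # whether the CPU tensor-expression fuser is enabled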
+def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ... +def _jit_pass_insert_observers( + module: torch.jit.ScriptModule, + method_name: str, + qconfig_dict: dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... +def _jit_pass_insert_quant_dequant( + module: torch.jit.ScriptModule, + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... +def _jit_pass_insert_quant_dequant_for_ondevice_ptq( + module: torch.jit.ScriptModule, + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... +def _jit_pass_quant_finalize( + module: torch.jit.ScriptModule, + quant_type: _int, + preserved_attrs: Sequence[str], +): ... +def _jit_pass_quant_finalize_for_ondevice_ptq( + module: torch.jit.ScriptModule, + quant_type: _int, + method_name: str, +): ... +def _jit_pass_insert_observer_method_for_ondevice_ptq( + module: torch.jit.ScriptModule, + method_name: str, + qconfig_dict: dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... +def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ... +def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ... +def _jit_set_fusion_strategy( + strategy: list[tuple[str, _int]], +) -> list[tuple[str, _int]]: ... +def _jit_try_infer_type(obj: Any) -> InferredType: ... +def _jit_get_trigger_value(trigger_name: str) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]] + +# Defined in torch/csrc/jit/python/script_init.cpp +# and torch/csrc/jit/python/init.cpp +def _maybe_call_torch_function_for_op_packet( + op_overload_packet: Any, + *args: Any, + **kwargs: Any, +) -> Any: ... +def _check_schema_allow_fake_script_object( + schema: FunctionSchema, + *args: Any, + **kwargs: Any, +) -> _bool: ... +def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ... +def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... +def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... +def _jit_assert_is_instance(obj: Any, type: JitType): ... +def _jit_clear_class_registry() -> None: ... +def _jit_set_emit_hooks( + ModuleHook: Callable | None, + FunctionHook: Callable | None, +) -> None: ... +def _jit_get_emit_hooks() -> tuple[Callable, Callable]: ... +def _load_for_lite_interpreter( + filename: str | Path, + map_location: DeviceLikeType | None, +): ... +def _load_for_lite_interpreter_from_buffer( + buffer: IO[bytes], + map_location: DeviceLikeType | None, +): ... +def _export_operator_list(module: LiteScriptModule): ... +def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ... +def _get_model_bytecode_version(filename: str | Path) -> _int: ... +def _get_model_bytecode_version_from_buffer(buffer: IO[bytes]) -> _int: ... +def _backport_for_mobile( + filename_input: str | Path, + filename_output: str | Path, + to_version: _int, +) -> None: ... +def _backport_for_mobile_from_buffer( + buffer: IO[bytes], + filename_output: str | Path, + to_version: _int, +) -> None: ... +def _backport_for_mobile_to_buffer( + filename_input: str | Path, + to_version: _int, +) -> bytes: ... +def _backport_for_mobile_from_buffer_to_buffer( + buffer: IO[bytes], + to_version: _int, +) -> bytes: ... +def _get_model_ops_and_info(filename: str | Path): ... +def _get_model_ops_and_info_from_buffer(buffer: IO[bytes]): ... +def _get_mobile_model_contained_types(filename: str | Path): ... +def _get_mobile_model_contained_types_from_buffer(buffer: IO[bytes]): ... 
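# _jit_set_fusion_strategy above is surfaced publicly as torch.jit.set_fusion_strategy.
# A minimal sketch; the default strategy varies by build, so the returned value here is
# only illustrative:
import torch

# Each ("STATIC"/"DYNAMIC", depth) pair caps how many specializations the profiling
# executor compiles before falling back; the previously active strategy is returned.
previous = torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])
print(previous)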
+def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ... +def _get_graph_executor_optimize(optimize: _bool | None = None) -> _bool: ... +def _set_graph_executor_optimize(optimize: _bool): ... +def _export_opnames(module: ScriptModule) -> list[str]: ... +def _create_function_from_trace( + qualname: str, + func: Callable[..., Any], + input_tuple: tuple[Any, ...], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: list[str], +) -> tuple[Graph, Stack]: ... +def _create_function_from_trace_with_dict( + qualname: str, + func: Callable[..., Any], + input_dict: dict[str, Any], + var_lookup_fn: Callable[[Tensor], str], + strict: _bool, + force_outplace: _bool, + argument_names: list[str], +) -> tuple[Graph, Stack]: ... +def _jit_is_script_object(obj: Any) -> _bool: ... +def _last_executed_optimized_graph() -> Graph: ... +def parse_type_comment(comment: str) -> Decl: ... +def _get_upgraders_map_size() -> _int: ... +def _get_upgraders_entry_map() -> dict[str, str]: ... +def _dump_upgraders_map() -> dict[str, str]: ... +def _test_only_populate_upgraders(content: dict[str, str]) -> None: ... +def _test_only_remove_upgraders(content: dict[str, str]) -> None: ... +def merge_type_from_type_comment( + decl: Decl, + type_annotation_decl: Decl, + is_method: _bool, +) -> Decl: ... +def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ... +def parse_schema(schema: str) -> FunctionSchema: ... +def get_device(input: Tensor) -> _int: ... +def _resolve_type_from_object( + obj: Any, + range: SourceRange, + rcb: ResolutionCallback, +) -> JitType: ... +def _create_module_with_type(ty: JitType) -> ScriptModule: ... +def _create_object_with_type(ty: ClassType) -> ScriptObject: ... +def _run_emit_module_hook(m: ScriptModule): ... +def _replace_overloaded_method_decl( + overload_decl: Decl, + implementation_def: Def, + new_name: str, +) -> Def: ... +def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... +def _jit_pass_onnx_set_dynamic_input_shape( + graph: Graph, + dynamic_axes: dict[str, dict[_int, str]], + input_names: list[str], +) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference( + graph: Graph, + params_dict: dict[str, IValue], + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_assign_output_shape( + graph: Graph, + tensors: list[Tensor], + desc: IODescriptor, + onnx_shape_inference: _bool, + is_script: _bool, + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_remove_inplace_ops_for_onnx( + graph: Graph, + module: ScriptModule | None = None, +) -> None: ... +def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... +def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... +def _jit_pass_peephole( + graph: Graph, + disable_shape_peepholes: _bool = False, +) -> None: ... +def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ... +def _jit_pass_fuse_addmm(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess(graph: Graph) -> None: ... +def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ... +def _jit_pass_onnx_remove_print(graph: Graph) -> None: ... +def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ... +def _jit_pass_onnx_unpack_quantized_weights( + graph: Graph, + paramsDict: dict[str, IValue], +) -> dict[str, IValue]: ... +def _jit_pass_onnx_quantization_insert_permutes( + graph: Graph, + paramsDict: dict[str, IValue], +) -> dict[str, IValue]: ... 
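# parse_schema and parse_ir above are the round-trip helpers PyTorch's own tests use for
# operator schemas and IR text. A small sketch, assuming the usual torch._C entry points:
import torch

schema = torch._C.parse_schema(
    "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
)
print(schema.name, schema.overload_name)       # 'aten::add' 'Tensor'
print([a.name for a in schema.arguments])      # ['self', 'other', 'alpha']

graph = torch._C.parse_ir("graph(%x : Tensor):\n  return (%x)")
print(graph)                                   # a torch._C.Graph reconstructed from text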
+def _jit_pass_custom_pattern_based_rewrite_graph( + pattern: str, + fused_node_name: str, + graph: Graph, +) -> None: ... +def _jit_onnx_list_model_parameters( + module: ScriptModule, +) -> tuple[ScriptModule, list[IValue]]: ... +def _jit_pass_erase_number_types(graph: Graph) -> None: ... +def _jit_pass_onnx_lint(graph: Graph) -> None: ... +def _jit_pass_onnx( + graph: Graph, + _jit_pass_onnx: _onnx.OperatorExportTypes, +) -> Graph: ... +def _jit_pass_onnx_scalar_type_analysis( + graph: Graph, + lowprecision_cast: _bool, + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_peephole( + graph: Graph, + opset_version: _int, + fixed_batch_size: _bool, +) -> None: ... +def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ... +def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ... +def _jit_pass_onnx_function_extraction( + graph: Graph, + module_names: set[str], + param_names: list[str], +) -> dict[Node, dict[str, str]]: ... +def _jit_pass_onnx_clear_scope_records() -> None: ... +def _jit_pass_onnx_track_scope_attributes( + graph: Graph, + onnx_attrs: dict[str, Any], +) -> None: ... +def _jit_is_onnx_log_enabled() -> _bool: ... +def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ... +def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ... +def _jit_onnx_log(*args: Any) -> None: ... +def _jit_pass_lower_graph(graph: Graph, m: Module) -> tuple[Graph, list[IValue]]: ... +def _jit_pass_inline_fork_wait(graph: Graph) -> None: ... +def _jit_pass_onnx_deduplicate_initializers( + graph: Graph, + params_dict: dict[str, IValue], + is_train: _bool, +) -> dict[str, IValue]: ... +def _jit_pass_onnx_eval_peephole( + graph: Graph, + paramsDict: dict[str, IValue], +) -> dict[str, IValue]: ... +def _jit_pass_onnx_constant_fold( + graph: Graph, + paramsDict: dict[str, IValue], + opset_version: _int, +) -> dict[str, IValue]: ... +def _jit_pass_onnx_eliminate_unused_items( + graph: Graph, + paramsDict: dict[str, IValue], +) -> dict[str, IValue]: ... +def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... +def _jit_pass_filter_non_tensor_arguments( + params: dict[str, IValue], +) -> dict[str, Tensor]: ... +def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... +def _jit_pass_onnx_node_shape_type_inference( + n: Node, + paramsDict: dict[str, IValue], + opset_version: _int, +) -> None: ... +def _jit_onnx_convert_pattern_from_subblock( + block: Block, + n: Node, + env: dict[Value, Value], + values_in_env: set[Value], +) -> list[Value]: ... +def _jit_pass_onnx_block( + old_block: Block, + new_block: Block, + operator_export_type: _onnx.OperatorExportTypes, + env: dict[Value, Value], + values_in_env: set[Value], + is_sub_block: _bool, +) -> dict[Value, Value]: ... +def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ... +def _jit_pass_fixup_onnx_controlflow_node( + n: Node, + opset_version: _int, +) -> list[Value]: ... +def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ... +def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ... +def _generate_upgraders_graph() -> dict[str, Graph]: ... +def _calculate_package_version_based_on_upgraders(val: _bool): ... +def _get_version_calculator_flag() -> _bool: ... +def _jit_script_interface_compile( + name: str, + class_def: ClassDef, + rcb: ResolutionCallback, + is_module: _bool, +): ... 
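# The _jit_pass_onnx_* bindings above are orchestrated by the TorchScript-based ONNX
# exporter rather than called directly. A hedged end-to-end sketch through the public
# entry point ("model.onnx" and opset 17 are arbitrary choices here):
import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = Tiny().eval()
example = torch.randn(1, 4)
# The classic exporter traces/scripts the module and then runs the ONNX graph passes
# declared above (peephole, constant folding, shape/type inference, ...).
torch.onnx.export(model, (example,), "model.onnx", opset_version=17)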
+def _jit_script_compile_overload( + qualname: str, + overload_decl: Decl, + implementation_def: Def, + rcb: ResolutionCallback, + implementation_defaults: dict[str, Any], + signature: Any, +): ... +def _jit_script_compile( + qual_name: str, + definition: Def, + rcb: ResolutionCallback, + defaults: dict[str, Any], +): ... +def _jit_script_class_compile( + qual_name: str, + definition: ClassDef, + defaults: dict[str, dict[str, Any]], + rcb: ResolutionCallback, +): ... +def _parse_source_def(src: str) -> Def: ... +def import_ir_module( + cu: CompilationUnit, + filename: str | Path, + map_location: DeviceLikeType | None, + extra_files: dict[str, Any], +) -> ScriptModule: ... +def import_ir_module_from_buffer( + cu: CompilationUnit, + buffer: IO[bytes], + map_location: DeviceLikeType | None, + extra_files: dict[str, Any], +) -> ScriptModule: ... +def _import_ir_module_from_package( + cu: CompilationUnit, + reader: PyTorchFileReader, + storage_context: DeserializationStorageContext, + map_location: DeviceLikeType | None, + ts_id: str, +) -> ScriptModule: ... +def _assign_output_shapes(graph: Graph, inputs: list[Tensor]) -> Graph: ... +def _check_onnx_proto(proto: str) -> None: ... +def _propagate_and_assign_input_shapes( + graph: Graph, + inputs: tuple[Tensor, ...], + param_count_list: list[_int], + with_grad: _bool, + propagate: _bool, +) -> Graph: ... + +# Defined in torch/csrc/jit/runtime/graph_executor.h +class GraphExecutorState: ... + +# Defined in torch/torch/csrc/jit/ir/alias_analysis.h +class AliasDb: ... + +class _InsertPoint: + def __enter__(self) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Use: + @property + def user(self) -> Node: ... + @property + def offset(self) -> _int: ... + def isAfter(self, other: Use) -> _bool: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Value: + def type(self) -> JitType: ... + def setType(self, t: JitType) -> Value: ... + def setTypeAs(self, other: Value) -> Value: ... + def inferTypeFrom(self, t: Tensor) -> None: ... + def debugName(self) -> str: ... + def setDebugName(self, name: str) -> None: ... + def unique(self) -> _int: ... + def offset(self) -> _int: ... + def node(self) -> Node: ... + def uses(self) -> list[Use]: ... + def replaceAllUsesWith(self, val: Value) -> None: ... + def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ... + def requires_grad(self) -> _bool: ... + def requiresGrad(self) -> _bool: ... + def copyMetadata(self, other: Value) -> Value: ... + def isCompleteTensor(self) -> _bool: ... + def toIValue(self) -> IValue: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Block: + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... + def nodes(self) -> Iterator[Node]: ... + def paramNode(self) -> Node: ... + def returnNode(self) -> Node: ... + def owningNode(self) -> Node: ... + def registerOutput(self, n: Value) -> _int: ... + def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ... + +# Defined in torch/csrc/jit/ir/ir.h +class Node: + def __getitem__(self, key: str) -> Any: ... + def schema(self) -> str: ... + def input(self) -> Value: ... + def inputs(self) -> Iterator[Value]: ... + def inputsAt(self, idx: _int) -> Value: ... + def inputsSize(self) -> _int: ... + def output(self) -> Value: ... + def outputs(self) -> Iterator[Value]: ... + def outputsAt(self, idx: _int) -> Value: ... + def outputsSize(self) -> _int: ... + def hasMultipleOutputs(self) -> _bool: ... + def blocks(self) -> list[Block]: ... 
+ def addBlock(self) -> Block: ... + def mustBeNone(self) -> _bool: ... + def matches(self, pattern: str) -> _bool: ... + def kind(self) -> str: ... + def kindOf(self, name: str) -> str: ... + def addInput(self, name: str) -> Value: ... + def replaceInput(self, i: _int, newValue: Value) -> Value: ... + def replaceInputWith(self, from_: Value, to: Value) -> None: ... + def replaceAllUsesWith(self, n: Node) -> None: ... + def insertBefore(self, n: Node) -> Node: ... + def insertAfter(self, n: Node) -> Node: ... + def isBefore(self, n: Node) -> _bool: ... + def isAfter(self, n: Node) -> _bool: ... + def moveBefore(self, n: Node) -> None: ... + def moveAfter(self, n: Node) -> None: ... + def removeInput(self, i: _int) -> None: ... + def removeAllInputs(self, i: _int) -> None: ... + def hasUses(self) -> _bool: ... + def eraseOutput(self, i: _int) -> None: ... + def addOutput(self) -> Value: ... + def scopeName(self) -> str: ... + def isNondeterministic(self) -> _bool: ... + def copyAttributes(self, rhs: Node) -> Node: ... + def copyMetadata(self, rhs: Node) -> Node: ... + def hasAttributes(self) -> _bool: ... + def hasAttribute(self, name: str) -> _bool: ... + def removeAttribute(self, attr: str) -> Node: ... + def namedInput(self, name: str) -> Value: ... + def sourceRange(self) -> SourceRange: ... + def owningBlock(self) -> Block: ... + def findNode(self, kind: str, recurse: _bool = True) -> Node: ... + def findAllNodes(self, kind: str, recurse: _bool = True) -> list[Node]: ... + def getModuleHierarchy(self) -> str: ... + def prev(self) -> Node: ... + def destroy(self) -> None: ... + def attributeNames(self) -> list[str]: ... + + # Accessors for attributes as types. + def f(self, name: str) -> _float: ... + def f_(self, name: str, val: _float) -> Node: ... + def fs(self, name: str) -> list[_float]: ... + def fs_(self, name: str, val: list[_float]) -> Node: ... + def c(self, name: str) -> complex: ... + def c_(self, name: str, val: complex) -> Node: ... + def s(self, name: str) -> str: ... + def s_(self, name: str, val: str) -> Node: ... + def ss(self, name: str) -> list[str]: ... + def ss_(self, name: str, val: list[str]) -> Node: ... + def i(self, name: str) -> _int: ... + def i_(self, name: str, val: _int) -> Node: ... + # Cannot define "is" like this because it's a reserved keyword in python. + # def is(self, name: str) -> List[_int]: ... + # def is_(self, name: str, val: List[_int]) -> Node: ... + def g(self, name: str) -> Graph: ... + def g_(self, name: str, val: Graph) -> Node: ... + def gs(self, name: str) -> list[Graph]: ... + def gs_(self, name: str, val: list[Graph]) -> Node: ... + def ival(self, name: str) -> IValue: ... + def ival_(self, name: str, val: IValue) -> Node: ... + def t(self, name: str) -> Tensor: ... + def t_(self, name: str, val: Tensor) -> Node: ... + def ts(self, name: str) -> list[Tensor]: ... + def ts_(self, name: str, val: list[Tensor]) -> Node: ... + def ty(self, name: str) -> JitType: ... + def ty_(self, name: str, val: JitType) -> Node: ... + def tys(self, name: str) -> list[JitType]: ... + def tys_(self, name: str, val: list[JitType]) -> Node: ... + +# Defined in torch/torch/csrc/jit/ir/ir.h +class Graph: + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... + def nodes(self) -> Iterator[Node]: ... + def param_node(self) -> Node: ... + def return_node(self) -> Node: ... + def addInput(self, name: str = "") -> Value: ... + def eraseInput(self, i: _int) -> None: ... + def registerOutput(self, n: Value) -> _int: ... 
+ def eraseOutput(self, i: _int) -> None: ... + def create(self, name: str, args, num_outputs: _int) -> Node: ... + def appendNode(self, n: Node) -> Node: ... + def prependNode(self, n: Node) -> Node: ... + def insertNode(self, n: Node) -> Node: ... + def block(self) -> Block: ... + def lint(self) -> None: ... + def alias_db(self) -> AliasDb: ... + def setInsertPoint(self, n: Block | Node) -> None: ... + def insert_point_guard(self, n: Block | Node) -> _InsertPoint: ... + def insertPoint(self) -> Node: ... + def insertGraph(self, callee: Graph, inputs: list[Value]) -> list[Value]: ... + def makeMultiOutputIntoTuple(self) -> None: ... + def copy(self) -> Graph: ... + +# Defined in torch/aten/src/ATen/core/alias_info.h +class AliasInfo: + is_write: _bool + before_set: set[str] + after_set: set[str] + def __init__( + self, + is_write: _bool, + before_set: set[str], + after_set: set[str], + ) -> None: ... + +# Defined in torch/aten/src/ATen/core/function_schema.h +class Argument: + name: str + type: JitType + default_value: Any | None + def has_default_value(self) -> _bool: ... + kwarg_only: _bool + is_out: _bool + alias_info: AliasInfo | None + is_write: _bool + real_type: JitType + def __init__( + self, + name: str, + type: JitType, + N: _int | None, + defualt_value: Any | None, + kwarg_only: _bool, + alias_info: AliasInfo | None, + ) -> None: ... + +class FunctionSchema: + arguments: list[Argument] + returns: list[Argument] + name: str + overload_name: str + is_mutable: _bool + def __init__( + self, + name: str, + overload_name: str, + arguments: list[Argument], + returns: list[Argument], + is_vararg: _bool, + is_varret: _bool, + ) -> None: ... + +class _UpgraderEntry: + bumped_at_version: _int + upgrader_name: str + old_schema: str + def __init__( + self, + bumped_at_version: _int, + upgrader_name: str, + old_schema: str, + ) -> None: ... + +class _UpgraderRange: + min_version: _int + max_version: _int + +def _get_max_operator_version() -> _int: ... +def _get_operator_version_map() -> dict[str, list[_UpgraderEntry]]: ... +def _get_upgrader_ranges(name: str) -> list[_UpgraderRange]: ... +def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ... +def _test_only_remove_entry_to_op_version(op_name: str) -> None: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class ScriptModuleSerializer: + def __init__(self, export_writer: PyTorchFileWriter) -> None: ... + def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ... + def write_files(self) -> None: ... + def storage_context(self) -> SerializationStorageContext: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class SerializationStorageContext: + def __init__(self) -> None: ... + def has_storage(self, storage: Storage) -> _bool: ... + def get_or_add_storage(self, storage: Storage) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class DeserializationStorageContext: + def __init__(self) -> None: ... + def get_storage(self, name: str, dtype: _dtype) -> Tensor: ... + def has_storage(self, name: str) -> _bool: ... + def add_storage(self, name: str, tensor: Tensor) -> _int: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class ConcreteModuleTypeBuilder: + def __init__(self, obj: Any) -> None: ... + def set_module_dict(self): ... + def set_module_list(self): ... + def set_parameter_list(self): ... + def set_parameter_dict(self): ... + def add_attribute( + self, + name: str, + ty: JitType, + is_param: _bool, + is_buffer: _bool, + ): ... 
+ def add_module(self, name: str, meta: ConcreteModuleType): ... + def add_constant(self, name: str, value: Any): ... + def add_overload(self, method_name: str, overloaded_method_names: list[str]): ... + def add_builtin_function(self, name: str, symbol_name: str): ... + def add_failed_attribute(self, name: str, failure_reason: str): ... + def add_function_attribute( + self, + name: str, + ty: JitType, + func: Callable[..., Any], + ): ... + def add_ignored_attribute(self, name: str): ... + def add_ignored_attributes(self, names: list[str]): ... + def add_forward_hook(self, hook: Callable[..., Any]): ... + def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ... + +class ConcreteModuleType: + def get_constants(self) -> dict[str, Any]: ... + def equals(self, other: ConcreteModuleType) -> _bool: ... + @staticmethod + def from_jit_type(ty: JitType) -> ConcreteModuleType: ... + +class CallStack: + def __init__(self, name: str, range: SourceRange) -> None: ... + +class ErrorReport: + def __init__(self, range: SourceRange) -> None: ... + def what(self) -> str: ... + @staticmethod + def call_stack() -> str: ... + +class CompilationUnit: + def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ... + def find_function(self, name: str) -> ScriptFunction: ... + def __getattr__(self, name: str) -> ScriptFunction: ... + def define( + self, + script: str, + rcb: ResolutionCallback = ..., + _frames_up: _int = ..., + ): ... + def get_interface(self, name: str) -> InterfaceType: ... + def get_functions(self) -> list[ScriptFunction]: ... + def create_function( + self, + name: str, + graph: Graph, + shouldMangle: _bool = ..., + ) -> ScriptFunction: ... + def get_class(self, name: str) -> ClassType: ... + +class ScriptObject: + def setattr(self, name: str, value: Any): ... + def _get_method(self, name: str) -> ScriptMethod: ... + def _type(self) -> ClassType: ... + +class ScriptModule(ScriptObject): + def _method_names(self) -> list[str]: ... + def _get_method(self, name: str) -> ScriptMethod: ... + +class LiteScriptModule: + def __call__(self, *input): ... + def find_method(self, method_name: str): ... + def forward(self, *input) -> list[str]: ... + def run_method(self, method_name: str, *input): ... + +# NOTE: switch to collections.abc.Callable in python 3.9 +class ScriptFunction(Generic[P, R]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + def save(self, filename: str, _extra_files: dict[str, bytes]) -> None: ... + def save_to_buffer(self, _extra_files: dict[str, bytes]) -> bytes: ... + @property + def graph(self) -> Graph: ... + def inlined_graph(self) -> Graph: ... + def schema(self) -> FunctionSchema: ... + def code(self) -> str: ... + def name(self) -> str: ... + @property + def qualified_name(self) -> str: ... + +# NOTE: switch to collections.abc.Callable in python 3.9 +class ScriptMethod(Generic[P, R]): + graph: Graph + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + @property + def owner(self) -> ScriptModule: ... + @property + def name(self) -> str: ... + @property + def schema(self) -> FunctionSchema: ... + +class ScriptDict(Generic[K, T]): + def __init__(self, dict: dict[K, T]) -> None: ... + def __len__(self) -> _int: ... + def __contains__(self, key: K) -> _bool: ... + def __getitem__(self, key: K) -> T: ... + def __setitem__(self, key: K, value: T) -> None: ... + def __delitem__(self, key: K) -> None: ... + def __iter__(self) -> Iterator[K]: ... + def items(self) -> Iterator[tuple[K, T]]: ... + def keys(self) -> Iterator[K]: ... 
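# CompilationUnit and ScriptFunction above back the public torch.jit.CompilationUnit API:
# compiling TorchScript source on a unit yields callable ScriptFunctions. A minimal sketch
# (unannotated arguments default to Tensor in TorchScript):
import torch

cu = torch.jit.CompilationUnit("""
def add_one(x):
    return x + 1
""")

fn = cu.find_function("add_one")      # a ScriptFunction, per the stub above
print(fn(torch.zeros(3)))             # tensor([1., 1., 1.])
print(cu.add_one(torch.zeros(3)))     # attribute access resolves to the same function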
+ +class ScriptList(Generic[T]): + def __init__(self, list: list[T]) -> None: ... + def __len__(self) -> _int: ... + def __contains__(self, item: T) -> _bool: ... + @overload + def __getitem__(self, idx: _int) -> T: ... + @overload + def __getitem__(self, idx: slice) -> ScriptList[T]: ... + @overload + def __setitem__(self, idx: _int, value: T) -> None: ... + @overload + def __setitem__(self, idx: slice, value: list[T]) -> None: ... + def __delitem__(self, idx: _int) -> None: ... + def __iter__(self) -> Iterator[T]: ... + def count(self, value: T) -> _int: ... + def remove(self, value: T) -> None: ... + def append(self, value: T) -> None: ... + def clear(self) -> None: ... + @overload + def extend(self, values: list[T]) -> None: ... + @overload + def extend(self, values: Iterable[T]) -> None: ... + @overload + def pop(self) -> T: ... + @overload + def pop(self, idx: _int) -> T: ... + +class ModuleDict: + def __init__(self, mod: ScriptModule) -> None: ... + def items(self) -> list[tuple[str, Any]]: ... + +class ParameterDict: + def __init__(self, mod: ScriptModule) -> None: ... + +class BufferDict: + def __init__(self, mod: ScriptModule) -> None: ... + +# Defined in torch/csrc/jit/api/module.h +class Module: ... + +# Defined in torch/csrc/Module.cpp +def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension +def _autograd_init() -> _bool: ... # THPAutograd_initExtension +def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr +def _init_names(arg: Sequence[type]) -> None: ... # THPModule_initNames +def _has_distributed() -> _bool: ... # THPModule_hasDistributed +def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType +def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype +def _infer_size(arg1: Size, arg2: Size) -> Size: ... # THPModule_inferSize +def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN +def _crash_if_csrc_ubsan() -> _int: ... # THPModule_crashIfCsrcUBSAN +def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN +def _show_config() -> str: ... # THPModule_showConfig +def _cxx_flags() -> str: ... # THPModule_cxxFlags +def _parallel_info() -> str: ... # THPModule_parallelInfo +def _get_cpu_capability() -> str: ... # THPModule_getCpuCapability +def _set_backcompat_broadcast_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatBroadcastWarn +def _get_backcompat_broadcast_warn() -> ( + _bool +): ... # THPModule_getBackcompatBroadcastWarn +def _set_backcompat_keepdim_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatKeepdimWarn +def _get_backcompat_keepdim_warn() -> _bool: ... # THPModule_getBackcompatKeepdimWarn +def get_num_thread() -> _int: ... # THPModule_getNumThreads +def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads +def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads +def set_num_interop_threads( + nthreads: _int, +) -> None: ... # THPModule_setNumInteropThreads +def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN +def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN +def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP +def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash +def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_mem_efficient( + arg: _bool, +) -> None: ... # THPModule_setSDPUseMemEfficient +def _get_math_sdp_enabled() -> _bool: ... 
# THPModule_userEnabledMathSDP +def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath +def _get_math_sdp_allow_fp16_bf16_reduction() -> ( + _bool +): ... # THPModule_allowFP16BF16ReductionMathSDP +def _set_math_sdp_allow_fp16_bf16_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16BF16ReductionMathSDP +def _get_overrideable_sdp_enabled() -> ( + _bool +): ... # THPModule_userEnabledOverrideableSDP +def _set_sdp_use_overrideable( + arg: _bool, +) -> None: ... # THPModule_setSDPUseOverrideable +def _get_sdp_priority_order() -> list[_int]: ... # THPModule_getSDPPriorityOrder +def _set_sdp_priority_order( + arg: list[_int], +) -> None: ... # THPModule_setSDPPriorityOrder +def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_cudnn(arg: _bool) -> None: ... # THPModule_setSDPUseMath +def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn +def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn +def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN +def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN +def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN +def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN +def _get_mkldnn_deterministic() -> _bool: ... # THPModule_deterministicMkldnn +def _set_mkldnn_deterministic( + arg: _bool, +) -> None: ... # THPModule_setDeterministicMkldnn +def _get_onednn_allow_tf32() -> _bool: ... # THPModule_allowTF32OneDNN +def _set_onednn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32OneDNN +def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms +def _get_deterministic_algorithms_warn_only() -> ( + _bool +): ... # THPModule_deterministicAlgorithmsWarnOnly +def _set_deterministic_algorithms( + mode: _bool, + *, + warn_only: _bool = ..., +) -> None: ... # THPModule_setDeterministicAlgorithms +def _get_deterministic_fill_uninitialized_memory() -> ( + _bool +): ... # THPModule_deterministicFillUninitializedMemory +def _set_deterministic_fill_uninitialized_memory( + arg: _bool, +) -> None: ... # THPModule_setDeterministicFillUninitializedMemory +def _get_nnpack_enabled() -> _bool: ... # THPModule_userEnabledNNPACK +def _set_nnpack_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledNNPACK +def _get_warnAlways() -> _bool: ... # THPModule_warnAlways +def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways +def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN +def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN +def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS +def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS +def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecision +def _set_float32_matmul_precision( + arg: str, +) -> None: ... # THPModule_setFloat32MatmulPrecision +def _get_cublas_allow_fp16_reduced_precision_reduction() -> ( + _bool +): ... # THPModule_allowFP16ReductionCuBLAS +def _set_cublas_allow_fp16_reduced_precision_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS +def _get_cublas_allow_bf16_reduced_precision_reduction() -> ( + _bool +): ... # THPModule_allowBF16ReductionCuBLAS +def _set_cublas_allow_bf16_reduced_precision_reduction( + arg: _bool, +) -> None: ... 
# THPModule_setAllowBF16ReductionCuBLAS +def _get_cublas_allow_fp16_accumulation() -> ( + _bool +): ... # THPModule_allowFP16AccumulationCuBLAS +def _set_cublas_allow_fp16_accumulation( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS +def _get_sm_carveout_experimental() -> _int | None: ... +def _set_sm_carveout_experimental(arg: _int | None) -> None: ... +def _set_conj(x: Tensor, conj: _bool) -> None: ... +def _set_neg(x: Tensor, neg: _bool) -> None: ... +def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... +def _meta_in_tls_dispatch_include() -> _bool: ... +def _stash_obj_in_tls(key: str, arg: Any) -> None: ... +def _get_obj_in_tls(key: str) -> Any: ... +def _is_key_in_tls(key: str) -> _bool: ... +def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ... +def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... +def _conv_determine_backend_memory_format( + input: Tensor, + weight: Tensor, + backend: ConvBackend, +) -> memory_format: ... +def _has_storage(x: Tensor) -> _bool: ... +def _construct_storage_from_data_pointer( + data_ptr: _int, + device: torch.device, + size: _int, +) -> Storage: ... +def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... +def _group_tensors_by_device_and_dtype( + nested_tensorlists: list[list[Tensor | None]], + with_indices: _bool = False, +) -> dict[ + tuple[torch.device, torch.dtype], + tuple[list[list[Tensor | None]], list[_int]], +]: ... + +# NB: There is no Capsule type in typing, see +# https://github.com/python/cpython/issues/109562 +def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack +def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack +def _get_cpp_backtrace( + frames_to_skip: _int, + maximum_number_of_frames: _int, +) -> str: ... # THPModule_getCppBacktrace +def set_flush_denormal(arg: _bool) -> _bool: ... # THPModule_setFlushDenormal +def get_default_dtype() -> _dtype: ... # THPModule_getDefaultDtype +def _get_default_device() -> str: ... # THPModule_getDefaultDevice +def _get_qengine() -> _int: ... # THPModule_qEngine +def _set_qengine(qengine: _int) -> None: ... # THPModule_setQEngine +def _supported_qengines() -> list[_int]: ... # THPModule_supportedQEngines +def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK +def _check_sparse_tensor_invariants() -> ( + _bool +): ... # THPModule_checkSparseTensorInvariants +def _set_check_sparse_tensor_invariants( + arg: _bool, +) -> None: ... # THPModule_setCheckSparseTensorInvariants +def _is_default_mobile_cpu_allocator_set() -> ( + _bool +): ... # THPModule_isDefaultMobileCPUAllocatorSet +def _set_default_mobile_cpu_allocator() -> ( + None +): ... # THPModule_setDefaultMobileCPUAllocator +def _unset_default_mobile_cpu_allocator() -> ( + None +): ... # THPModule_unsetDefaultMobileCPUAllocator +def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction +def _is_torch_function_all_disabled() -> ( + _bool +): ... # THPModule_isAllDisabledTorchFunction +def _has_torch_function( + args: Iterable[Any], +) -> _bool: ... # THPModule_has_torch_function +def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary +def _has_torch_function_variadic( + *args: Any, +) -> _bool: ... # THPModule_has_torch_function_variadic +def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting +def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting +def _log_api_usage_once(str) -> None: ... 
# LogAPIUsageOnceFromPython +def _log_api_usage_metadata( + event: str, + metadata_map: dict[str, str], +) -> None: ... # LogAPIUsageMetadataFromPython +def _demangle(str) -> str: ... # c10::demangle +def _disabled_torch_function_impl( + func: Callable, + types: Iterable[type], + args: tuple, + kwargs: dict, +) -> Any: ... # THPModule_disable_torch_function +def _disabled_torch_dispatch_impl( + func: Callable, + types: Iterable[type], + args: tuple, + kwargs: dict, +) -> Any: ... # THPModule_disable_dispatch_function +def _get_linalg_preferred_backend() -> _LinalgBackend: ... +def _set_linalg_preferred_backend(arg: _LinalgBackend): ... + +class _LinalgBackend: + Default: _LinalgBackend + Cusolver: _LinalgBackend + Magma: _LinalgBackend + +# mypy error: +# Detected enum "torch._C.BatchNormBackend" in a type stub with zero +# members. There is a chance this is due to a recent change in the semantics +# of enum membership. If so, use `member = value` to mark an enum member, +# instead of `member: type` +class BatchNormBackend(Enum): ... # type: ignore[misc] + +def _get_blas_preferred_backend() -> _BlasBackend: ... +def _set_blas_preferred_backend(arg: _BlasBackend): ... + +class _BlasBackend: + Default: _BlasBackend + Cublas: _BlasBackend + Cublaslt: _BlasBackend + Ck: _BlasBackend + +def _get_rocm_fa_preferred_backend() -> torch._C._ROCmFABackend: ... +def _set_rocm_fa_preferred_backend(arg: torch._C._ROCmFABackend): ... + +class _ROCmFABackend: + Default: _ROCmFABackend + AOTriton: _ROCmFABackend + Ck: _ROCmFABackend + +# mypy error: +# Error (MYPY) [misc] +# Detected enum "torch._C.ConvBackend" in a type stub with zero members. +# There is a chance this is due to a recent change in the semantics of enum +# membership. If so, use `member = value` to mark an enum member, instead of +# `member: type` +class ConvBackend(Enum): ... # type: ignore[misc] + +class Tag(Enum): + core = 0 + cudagraph_unsafe = 1 + data_dependent_output = 2 + dynamic_output_shape = 3 + flexible_layout = 4 + generated = 5 + inplace_view = 6 + maybe_aliasing_or_mutating = 7 + needs_contiguous_strides = 8 + needs_exact_strides = 9 + needs_fixed_stride_order = 10 + nondeterministic_bitwise = 11 + nondeterministic_seeded = 12 + pointwise = 13 + pt2_compliant_tag = 14 + view_copy = 15 + +# Defined in `valgrind.h` and `callgrind.h` respectively. +def _valgrind_supported_platform() -> _bool: ... # NVALGRIND +def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT +def _valgrind_toggle_and_dump_stats() -> ( + None +): ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS + +has_openmp: _bool +has_mkl: _bool +_has_kleidiai: _bool +_has_mps: _bool +has_lapack: _bool +_has_cuda: _bool +_has_magma: _bool +_has_xpu: _bool +_has_mkldnn: _bool +_has_cudnn: _bool +_has_cusparselt: _bool +has_spectral: _bool +_GLIBCXX_USE_CXX11_ABI: _bool +default_generator: Generator + +# Defined in torch/csrc/autograd/init.cpp +def _set_grad_enabled(enabled: _bool) -> None: ... +def is_grad_enabled() -> _bool: ... +def _set_fwd_grad_enabled(enabled: _bool) -> None: ... +def _is_fwd_grad_enabled() -> _bool: ... +def _any_requires_grad(*args, **kwargs) -> _bool: ... +def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ... +def is_inference_mode_enabled() -> _bool: ... +@overload +def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ... +@overload +def set_autocast_enabled(enabled: _bool) -> None: ... +@overload +def is_autocast_enabled(device_type: str) -> _bool: ... 
+@overload +def is_autocast_enabled() -> _bool: ... +def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ... +def get_autocast_dtype(device_type: str) -> _dtype: ... +def clear_autocast_cache() -> None: ... +def set_autocast_cpu_enabled(enabled: _bool) -> None: ... +def is_autocast_cpu_enabled() -> _bool: ... +def _is_any_autocast_enabled() -> _bool: ... +def _is_autocast_available(device_type: str) -> _bool: ... +def set_autocast_cpu_dtype(dtype: _dtype) -> None: ... +def set_autocast_gpu_dtype(dtype: _dtype) -> None: ... +def get_autocast_cpu_dtype() -> _dtype: ... +def get_autocast_gpu_dtype() -> _dtype: ... +def autocast_increment_nesting() -> _int: ... +def autocast_decrement_nesting() -> _int: ... +def is_autocast_cache_enabled() -> _bool: ... +def set_autocast_cache_enabled(enabled: _bool) -> None: ... +def _increment_version(tensors: Iterable[Tensor]) -> None: ... +def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ... +def is_anomaly_enabled() -> _bool: ... +def is_anomaly_check_nan_enabled() -> _bool: ... +def _is_multithreading_enabled() -> _bool: ... +def _set_multithreading_enabled(enabled: _bool) -> None: ... +def _set_view_replay_enabled(enabled: _bool) -> None: ... +def _is_view_replay_enabled() -> _bool: ... +def _enter_dual_level() -> _int: ... +def _exit_dual_level(level: _int) -> None: ... +def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ... +def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ... +def __set_forward_AD_enabled(enabled: _bool) -> None: ... +def __is_forward_AD_enabled() -> _bool: ... +def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ... +def _reset_default_hooks() -> None: ... +def _is_torch_function_mode_enabled() -> _bool: ... +def _push_on_torch_function_stack(cls: Any) -> None: ... +def _pop_torch_function_stack() -> Any: ... +def _get_function_stack_at(idx: _int) -> Any: ... +def _len_torch_function_stack() -> _int: ... +def _set_torch_dispatch_mode(cls: Any) -> None: ... +def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ... +def _pop_torch_dispatch_stack(mode_key: _TorchDispatchModeKey | None = None) -> Any: ... +def _get_dispatch_mode(mode_key: _TorchDispatchModeKey | None) -> Any: ... +def _unset_dispatch_mode(mode: _TorchDispatchModeKey) -> TorchDispatchMode | None: ... +def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ... +def _get_dispatch_stack_at(idx: _int) -> Any: ... +def _len_torch_dispatch_stack() -> _int: ... +def _activate_gpu_trace() -> None: ... + +class _DisableTorchDispatch: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _EnableTorchFunction: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _EnablePythonDispatcher: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _DisablePythonDispatcher: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _EnablePreDispatch: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _DisableFuncTorch: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _DisableAutocast: + def __init__(self) -> None: ... + def __enter__(self): ... 
+ def __exit__(self, *exc_info: object) -> None: ... + +class _InferenceMode: + def __init__(self, enabled: _bool) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +def _set_autograd_fallback_mode(mode: str) -> None: ... +def _get_autograd_fallback_mode() -> str: ... + +# Defined in torch/csrc/jit/python/script_init.cpp +class LoggerBase: ... +class NoopLogger(LoggerBase): ... +class LockingLogger(LoggerBase): ... + +class AggregationType(Enum): + SUM = 0 + AVG = 1 + +class FileCheck: + def run(self, test_string: str) -> None: ... + def check(self, test_string: str) -> FileCheck: ... + def check_not(self, test_string: str) -> FileCheck: ... + def check_same(self, test_string: str) -> FileCheck: ... + def check_next(self, test_string: str) -> FileCheck: ... + def check_count( + self, + test_string: str, + count: _int, + exactly: _bool = False, + ) -> FileCheck: ... + def check_dag(self, test_string: str) -> FileCheck: ... + def check_source_highlighted(self, test_string: str) -> FileCheck: ... + def check_regex(self, test_string: str) -> FileCheck: ... + +# Defined in torch/csrc/jit/python/init.cpp +class PyTorchFileReader: + @overload + def __init__(self, name: str) -> None: ... + @overload + def __init__(self, buffer: IO[bytes]) -> None: ... + def get_record(self, name: str) -> bytes: ... + def get_all_records(self) -> list[str]: ... + def serialization_id(self) -> str: ... + +class PyTorchFileWriter: + @overload + def __init__( + self, + name: str, + compute_crc32: _bool = True, + storage_alignment: _int = 64, + ) -> None: ... + @overload + def __init__( + self, + buffer: IO[bytes], + compute_crc32: _bool = True, + storage_alignment: _int = 64, + ) -> None: ... + def write_record( + self, + name: str, + data: Storage | bytes | _int, + size: _int, + ) -> None: ... + def write_end_of_file(self) -> None: ... + def set_min_version(self, version: _int) -> None: ... + def get_all_written_records(self) -> list[str]: ... + def archive_name(self) -> str: ... + def serialization_id(self) -> str: ... + +def _jit_get_inline_everything_mode() -> _bool: ... +def _jit_set_inline_everything_mode(enabled: _bool) -> None: ... +def _jit_get_logging_option() -> str: ... +def _jit_set_logging_option(option: str) -> None: ... +def _jit_set_logging_stream(stream_name: str) -> None: ... +def _jit_pass_cse(Graph) -> _bool: ... +def _jit_pass_dce(Graph) -> None: ... +def _jit_pass_dce_graph(Graph) -> None: ... +def _jit_pass_lint(Graph) -> None: ... + +# Defined in torch/csrc/jit/python/python_custom_class.cpp +def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ... + +# Defined in torch/csrc/Module.cpp +def _rename_privateuse1_backend(backend: str) -> None: ... +def _get_privateuse1_backend_name() -> str: ... + +# Defined in torch/csrc/Generator.cpp +class Generator: + device: _device + def __init__(self, device: DeviceLikeType | None = None) -> None: ... + def __reduce__( + self, + ) -> tuple[type[Generator], tuple[_device], tuple[_int, _int | None, Tensor]]: ... + def __setstate__(self, state: tuple[_int, _int | None, Tensor]) -> None: ... + def get_state(self) -> Tensor: ... + def set_state(self, _new_state: Tensor) -> Generator: ... + def clone_state(self) -> Generator: ... + def graphsafe_get_state(self) -> Generator: ... + def graphsafe_set_state(self, _new_state: Generator) -> Generator: ... + def set_offset(self, offset: _int) -> Generator: ... + def get_offset(self) -> _int: ... + def manual_seed(self, seed: _int) -> Generator: ... 
+ def seed(self) -> _int: ... + def initial_seed(self) -> _int: ... + +# Defined in torch/csrc/utils/python_dispatch.cpp + +class _DispatchOperatorHandle: + def schema(self) -> FunctionSchema: ... + def debug(self) -> str: ... + def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ... + +class _DispatchModule: + def reset(self) -> None: ... + def def_(self, schema: str, alias: str = "") -> _DispatchModule: ... + def def_legacy(self, schema: str) -> _DispatchModule: ... + def def_name_t_t( + self, + name: str, + dispatch: str, + debug: str = "default_def_name_t_t", + ) -> _DispatchModule: ... + def def_schema_t_t( + self, + schema: str, + dispatch: str, + alias: str, + debug: str = "default_def_schema_t_t", + ) -> _DispatchModule: ... + def impl_t_t( + self, + name: str, + dispatch: str, + debug: str = "impl_t_t", + ) -> _DispatchModule: ... + def impl_with_aoti_compile( + self, + ns: str, + op_name_with_overload: str, + dispatch: _dispatchkey, + ) -> None: ... + def impl(self, name: str, dispatch: _dispatchkey, func: Callable) -> None: ... + def define(self, schema: str, alias: str = "") -> str: ... + def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ... + def fallback( + self, + dispatch: _dispatchkey, + func: Callable, + with_keyset: _bool = False, + ) -> None: ... + +_after_ADInplaceOrView_keyset: DispatchKeySet +_after_autograd_keyset: DispatchKeySet + +def _dispatch_library( + kind: str, + name: str, + dispatch: str, + file: str = "", + linenum: Any = 0, +) -> _DispatchModule: ... +def _dispatch_dump(name: str) -> str: ... +def _dispatch_dump_table(name: str) -> str: ... +def _dispatch_check_invariants(name: str) -> None: ... +def _dispatch_check_all_invariants() -> None: ... +def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ... +def _dispatch_find_schema_or_throw( + name: str, + overload_name: str, +) -> _DispatchOperatorHandle: ... +def _dispatch_set_report_error_callback( + handle: _DispatchOperatorHandle, + callback: Callable, +) -> None: ... +def _dispatch_has_kernel(name: str) -> _bool: ... +def _dispatch_has_kernel_for_dispatch_key( + name: str, + dispatch: _dispatchkey, +) -> _bool: ... +def _dispatch_has_kernel_for_any_dispatch_key( + name: str, + dispatch_key_set: DispatchKeySet, +) -> _bool: ... +def _dispatch_kernel_for_dispatch_key_is_fallthrough( + name: str, + dispatch: _dispatchkey, +) -> _bool: ... +def _dispatch_has_computed_kernel_for_dispatch_key( + name: str, + dispatch: _dispatchkey, +) -> _bool: ... +def _dispatch_find_dangling_impls() -> list[str]: ... +def _dispatch_get_all_op_names() -> list[str]: ... +def _dispatch_tls_set_dispatch_key_excluded( + dispatch: _dispatchkey, + val: _bool, +) -> None: ... +def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_tls_set_dispatch_key_included( + dispatch: _dispatchkey, + val: _bool, +) -> None: ... +def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ... +def _dispatch_key_name(dispatch: _dispatchkey) -> str: ... +def _dispatch_key_for_device(device_type: str) -> str: ... +def _parse_dispatch_key(key: str) -> DispatchKey | None: ... +def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ... +def _dispatch_num_backends() -> _int: ... +def _dispatch_pystub(name: str, overload: str) -> tuple[str, str] | None: ... +def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ... 
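# The Generator stub above is the public torch.Generator RNG handle. A short sketch of
# seeding and state round-tripping (CPU device assumed):
import torch

g = torch.Generator(device="cpu")
g.manual_seed(1234)

a = torch.randn(3, generator=g)
state = g.get_state()                 # opaque Tensor snapshot of the RNG state
b = torch.randn(3, generator=g)

g.set_state(state)                    # rewind to the snapshot...
c = torch.randn(3, generator=g)
assert torch.equal(b, c)              # ...so the next draw repeats exactly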
+def _functionality_to_backend_keys(dispatch: _dispatchkey) -> list[DispatchKey]: ... +def _functionalization_reapply_views_tls() -> _bool: ... +def _only_lift_cpu_tensors() -> _bool: ... +def _set_only_lift_cpu_tensors(value: _bool) -> None: ... +def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ... +def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ... + +class DispatchKey(Enum): + Undefined = ... + FPGA = ... + MAIA = ... + Vulkan = ... + Metal = ... + MKLDNN = ... + OpenGL = ... + OpenCL = ... + IDEEP = ... + CustomRNGKeyId = ... + MkldnnCPU = ... + Sparse = ... + SparseCsr = ... + NestedTensor = ... + Dense = ... + PythonTLSSnapshot = ... + PreDispatch = ... + PythonDispatcher = ... + Python = ... + FuncTorchDynamicLayerBackMode = ... + ZeroTensor = ... + Conjugate = ... + Negative = ... + BackendSelect = ... + Named = ... + AutogradOther = ... + AutogradFunctionality = ... + AutogradNestedTensor = ... + Tracer = ... + Autocast = ... + AutocastCPU = ... + AutocastCUDA = ... + Batched = ... + VmapMode = ... + FuncTorchGradWrapper = ... + FuncTorchBatched = ... + BatchedNestedTensor = ... + FuncTorchVmapMode = ... + FuncTorchDynamicLayerFrontMode = ... + Functionalize = ... + TESTING_ONLY_GenericWrapper = ... + TESTING_ONLY_GenericMode = ... + ADInplaceOrView = ... + Autograd = ... + CompositeImplicitAutograd = ... + CompositeImplicitAutogradNestedTensor = ... + CompositeExplicitAutograd = ... + CompositeExplicitAutogradNonFunctional = ... + FuncTorchBatchedDecomposition = ... + CPU = ... + CUDA = ... + HIP = ... + XLA = ... + MTIA = ... + MPS = ... + IPU = ... + XPU = ... + HPU = ... + VE = ... + Lazy = ... + Meta = ... + PrivateUse1 = ... + PrivateUse2 = ... + PrivateUse3 = ... + QuantizedCPU = ... + QuantizedCUDA = ... + QuantizedHIP = ... + QuantizedXLA = ... + QuantizedMTIA = ... + QuantizedMPS = ... + QuantizedIPU = ... + QuantizedXPU = ... + QuantizedHPU = ... + QuantizedVE = ... + QuantizedLazy = ... + QuantizedMeta = ... + QuantizedPrivateUse1 = ... + QuantizedPrivateUse2 = ... + QuantizedPrivateUse3 = ... + SparseCPU = ... + SparseCUDA = ... + SparseHIP = ... + SparseXLA = ... + SparseMTIA = ... + SparseMPS = ... + SparseIPU = ... + SparseXPU = ... + SparseHPU = ... + SparseVE = ... + SparseLazy = ... + SparseMeta = ... + SparsePrivateUse1 = ... + SparsePrivateUse2 = ... + SparsePrivateUse3 = ... + SparseCsrCPU = ... + SparseCsrCUDA = ... + SparseCsrHIP = ... + SparseCsrXLA = ... + SparseCsrMTIA = ... + SparseCsrMPS = ... + SparseCsrIPU = ... + SparseCsrXPU = ... + SparseCsrHPU = ... + SparseCsrVE = ... + SparseCsrLazy = ... + SparseCsrMeta = ... + SparseCsrPrivateUse1 = ... + SparseCsrPrivateUse2 = ... + SparseCsrPrivateUse3 = ... + NestedTensorCPU = ... + NestedTensorCUDA = ... + NestedTensorHIP = ... + NestedTensorXLA = ... + NestedTensorMTIA = ... + NestedTensorMPS = ... + NestedTensorIPU = ... + NestedTensorXPU = ... + NestedTensorHPU = ... + NestedTensorVE = ... + NestedTensorLazy = ... + NestedTensorMeta = ... + NestedTensorPrivateUse1 = ... + NestedTensorPrivateUse2 = ... + NestedTensorPrivateUse3 = ... + AutogradCPU = ... + AutogradCUDA = ... + AutogradHIP = ... + AutogradXLA = ... + AutogradMTIA = ... + AutogradMPS = ... + AutogradIPU = ... + AutogradXPU = ... + AutogradHPU = ... + AutogradVE = ... + AutogradLazy = ... + AutogradMeta = ... + AutogradPrivateUse1 = ... + AutogradPrivateUse2 = ... + AutogradPrivateUse3 = ... + +class DispatchKeySet: + def __init__(self, key: DispatchKey) -> None: ... 
+ def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def raw_repr(self) -> _int: ... + @staticmethod + def from_raw_repr(raw: _int) -> DispatchKeySet: ... + def highestPriorityTypeId(self) -> DispatchKey: ... + def has(self, k: _dispatchkey) -> _bool: ... + def add(self, k: _dispatchkey) -> DispatchKeySet: ... + def remove(self, k: _dispatchkey) -> DispatchKeySet: ... + +_dispatch_autogradother_backends: DispatchKeySet +_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet + +def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ... +def _dispatch_keyset_full() -> DispatchKeySet: ... +def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ... +def _dispatch_get_backend_keyset_from_autograd( + dispatch: _dispatchkey, +) -> DispatchKeySet: ... +def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ... +def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ... +def _dispatch_tls_local_include_set() -> DispatchKeySet: ... +def _dispatch_is_included_in_alias( + dispatch_a: _dispatchkey, + dispatch_b: _dispatchkey, +) -> _bool: ... +def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ... +def _replace_(a: Tensor, b: Tensor) -> None: ... +def _commit_update(a: Tensor) -> None: ... + +class _ExcludeDispatchKeyGuard: + def __init__(self, keyset: DispatchKeySet) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _IncludeDispatchKeyGuard: + def __init__(self, k: DispatchKey) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _ForceDispatchKeyGuard: + def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _PreserveDispatchKeyGuard: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _AutoDispatchBelowAutograd: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +class _AutoDispatchBelowADInplaceOrView: + def __init__(self) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ... +def _dispatch_get_registrations_for_dispatch_key( + dispatch_key: str = "", +) -> list[str]: ... +def _are_functorch_transforms_active() -> _bool: ... + +# Define in torch/csrc/autograd/init.cpp +def _set_python_dispatcher(dispatcher: object) -> None: ... +def _get_nested_int(id: _int, coeff: _int) -> SymInt: ... +def _get_constant_bool_symnode(val: _bool) -> Any: ... + +class _TorchDispatchModeKey(Enum): + FAKE = ... + PROXY = ... + FUNCTIONAL = ... + +class _SetExcludeDispatchKeyGuard: + def __init__(self, k: DispatchKey, enabled: _bool) -> None: ... + def __enter__(self): ... + def __exit__(self, *exc_info: object) -> None: ... + +# Defined in torch/csrc/utils/schema_info.h + +class _SchemaInfo: + def __init__(self, schema: FunctionSchema) -> None: ... + @overload + def is_mutable(self) -> _bool: ... + @overload + def is_mutable(self, name: str) -> _bool: ... + def has_argument(self, name: str) -> _bool: ... 
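``DispatchKeySet`` and the guard classes above mirror the C++ dispatcher state. A hedged sketch of how a tensor's key set might be inspected through these declarations; this is internal API, so the exact keys reported should be treated as illustrative:

import torch

t = torch.randn(2, 2, requires_grad=True)
ks = torch._C._dispatch_keys(t)                    # DispatchKeySet carried by the tensor
print(torch._C._dispatch_keyset_to_string(ks))
print(ks.highestPriorityTypeId())                  # e.g. DispatchKey.AutogradCPU on CPU

# Build a key set by hand and query membership.
manual = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU).add(torch._C.DispatchKey.AutogradCPU)
print(manual.has(torch._C.DispatchKey.CPU))        # True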
+ +# Defined in torch/csrc/utils/init.cpp +class BenchmarkConfig: + num_calling_threads: _int + num_worker_threads: _int + num_warmup_iters: _int + num_iters: _int + profiler_output_path: str + +class BenchmarkExecutionStats: + latency_avg_ms: _float + num_iters: _int + +class ThroughputBenchmark: + def __init__(self, module: Any) -> None: ... + def add_input(self, *args: Any, **kwargs: Any) -> None: ... + def run_once(self, *args: Any, **kwargs: Any) -> Any: ... + def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ... + +# Defined in torch/csrc/Storage.cpp +class StorageBase: ... + +# TODO: where +class DoubleTensor(Tensor): ... +class FloatTensor(Tensor): ... +class BFloat16Tensor(Tensor): ... +class LongTensor(Tensor): ... +class IntTensor(Tensor): ... +class ShortTensor(Tensor): ... +class HalfTensor(Tensor): ... +class CharTensor(Tensor): ... +class ByteTensor(Tensor): ... +class BoolTensor(Tensor): ... + +# Defined in torch/csrc/autograd/python_engine.cpp +class _ImperativeEngine: + def queue_callback(self, callback: Callable[[], None]) -> None: ... + def run_backward(self, *args: Any, **kwargs: Any) -> tuple[Tensor, ...]: ... + def is_checkpoint_valid(self) -> _bool: ... + +# Defined in torch/csrc/autograd/python_variable.cpp +class _TensorMeta(type): ... + +_Index: TypeAlias = SupportsIndex | _bool | _int | slice | EllipsisType | Tensor | None | _NestedSequence[_bool | _int | slice | EllipsisType | Tensor | None] # fmt: skip + +# Defined in torch/csrc/autograd/python_variable.cpp +class TensorBase(metaclass=_TensorMeta): + requires_grad: _bool + retains_grad: _bool + shape: Size + data: Tensor + names: list[str] + device: _device + dtype: _dtype + layout: _layout + real: Tensor + imag: Tensor + T: Tensor + H: Tensor + mT: Tensor + mH: Tensor + ndim: _int + output_nr: _int + _version: _int + _base: Tensor | None + _cdata: _int + grad_fn: _Node | None + _grad_fn: Any + _grad: Tensor | None + grad: Tensor | None + _backward_hooks: dict[_int, Callable[[Tensor], Tensor | None]] | None + nbytes: _int + itemsize: _int + _has_symbolic_sizes_strides: _bool + + def _view_func_unsafe( + self, + new_base: Tensor, + symint_visitor_fn: Callable[[_int], _int] | None = None, + tensor_visitor_fn: Callable[[Tensor], Tensor] | None = None, + ): ... + def __abs__(self) -> Tensor: ... + def __add__(self, other: Tensor | Number | _complex) -> Tensor: ... + @overload + def __and__(self, other: Tensor) -> Tensor: ... + @overload + def __and__(self, other: Number | _complex) -> Tensor: ... + @overload + def __and__(self, other: Tensor | _int) -> Tensor: ... + def __bool__(self) -> _bool: ... + def __complex__(self) -> _complex: ... + def __contains__(self, item: Any, /) -> _bool: ... + def __div__(self, other: Tensor | Number | _complex) -> Tensor: ... + @overload + def __eq__(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[overload-overlap] + @overload + def __eq__(self, other: object) -> _bool: ... + def __float__(self) -> _float: ... + def __floordiv__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __ge__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __getitem__(self, indices: _Index | tuple[_Index, ...], /) -> Tensor: ... + def __gt__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __iadd__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + @overload + def __iand__(self, other: Tensor) -> Tensor: ... + @overload + def __iand__(self, other: Number | _complex) -> Tensor: ... 
+ @overload + def __iand__(self, other: Tensor | _int) -> Tensor: ... + def __idiv__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + def __ifloordiv__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + @overload + def __ilshift__(self, other: Tensor) -> Tensor: ... + @overload + def __ilshift__(self, other: Number | _complex) -> Tensor: ... + @overload + def __ilshift__(self, other: Tensor | _int) -> Tensor: ... + def __imod__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + def __imul__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + def __index__(self) -> _int: ... + @overload + def __init__( + self, + *args: Any, + device: DeviceLikeType | None = None, + ) -> None: ... + @overload + def __init__(self, storage: Storage) -> None: ... + @overload + def __init__(self, other: Tensor) -> None: ... + @overload + def __init__( + self, + size: _size, + *, + device: DeviceLikeType | None = None, + ) -> None: ... + def __int__(self) -> _int: ... + def __invert__(self) -> Tensor: ... + @overload + def __ior__(self, other: Tensor) -> Tensor: ... + @overload + def __ior__(self, other: Number | _complex) -> Tensor: ... + @overload + def __ior__(self, other: Tensor | _int) -> Tensor: ... + @overload + def __irshift__(self, other: Tensor) -> Tensor: ... + @overload + def __irshift__(self, other: Number | _complex) -> Tensor: ... + @overload + def __irshift__(self, other: Tensor | _int) -> Tensor: ... + def __isub__(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034 + @overload + def __ixor__(self, other: Tensor) -> Tensor: ... + @overload + def __ixor__(self, other: Number | _complex) -> Tensor: ... + @overload + def __ixor__(self, other: Tensor | _int) -> Tensor: ... + def __le__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __long__(self) -> _int: ... + @overload + def __lshift__(self, other: Tensor) -> Tensor: ... + @overload + def __lshift__(self, other: Number | _complex) -> Tensor: ... + @overload + def __lshift__(self, other: Tensor | _int) -> Tensor: ... + def __lt__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __matmul__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __mod__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __mul__(self, other: Tensor | Number | _complex) -> Tensor: ... + @overload + def __ne__(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[overload-overlap] + @overload + def __ne__(self, other: object) -> _bool: ... + def __neg__(self) -> Tensor: ... + def __new__(cls, *args, **kwargs) -> Self: ... + def __nonzero__(self) -> _bool: ... + @overload + def __or__(self, other: Tensor) -> Tensor: ... + @overload + def __or__(self, other: Number | _complex) -> Tensor: ... + @overload + def __or__(self, other: Tensor | _int) -> Tensor: ... + def __pow__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __radd__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __rand__(self, other: Tensor | _int) -> Tensor: ... + def __rfloordiv__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __rmul__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __ror__(self, other: Tensor | _int) -> Tensor: ... + def __rpow__(self, other: Tensor | Number | _complex) -> Tensor: ... # type: ignore[has-type] + @overload + def __rshift__(self, other: Tensor) -> Tensor: ... + @overload + def __rshift__(self, other: Number | _complex) -> Tensor: ... 
+ @overload + def __rshift__(self, other: Tensor | _int) -> Tensor: ... + def __rsub__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __rtruediv__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __rxor__(self, other: Tensor | _int) -> Tensor: ... + def __setitem__( + self, + indices: _Index | tuple[_Index, ...], + value: Tensor | Number, + /, + ) -> None: ... + def __sub__(self, other: Tensor | Number | _complex) -> Tensor: ... + def __truediv__(self, other: Tensor | Number | _complex) -> Tensor: ... + @overload + def __xor__(self, other: Tensor) -> Tensor: ... + @overload + def __xor__(self, other: Number | _complex) -> Tensor: ... + @overload + def __xor__(self, other: Tensor | _int) -> Tensor: ... + def _addmm_activation( + self, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + use_gelu: _bool = False, + ) -> Tensor: ... + def _autocast_to_full_precision( + self, + cuda_enabled: _bool, + cpu_enabled: _bool, + ) -> Tensor: ... + def _autocast_to_reduced_precision( + self, + cuda_enabled: _bool, + cpu_enabled: _bool, + cuda_dtype: _dtype, + cpu_dtype: _dtype, + ) -> Tensor: ... + def _coalesced_(self, coalesced: _bool) -> Tensor: ... + def _conj(self) -> Tensor: ... + def _conj_physical(self) -> Tensor: ... + def _dimI(self) -> _int: ... + def _dimV(self) -> _int: ... + def _indices(self) -> Tensor: ... + def _is_all_true(self) -> Tensor: ... + def _is_any_true(self) -> Tensor: ... + def _is_view(self) -> _bool: ... + def _is_zerotensor(self) -> _bool: ... + def _lazy_clone(self) -> Tensor: ... + @staticmethod + def _make_subclass( + cls: type[S], + data: Tensor, + require_grad: _bool = False, + dispatch_strides: _bool = False, + dispatch_device: _bool = False, + device_for_backend_keys: _device | None = None, + ) -> S: ... + @staticmethod + def _make_wrapper_subclass( + cls: type[S], + size: Sequence[_int | SymInt], + strides: Sequence[_int | SymInt] | None = None, + storage_offset: _int | SymInt | None = None, + memory_format: torch.memory_format | None = None, + dtype: _dtype | None = None, + layout: _layout = strided, + device: _device | None = None, + pin_memory: _bool = False, + requires_grad: _bool = False, + dispatch_sizes_strides_policy: str | None = None, + dispatch_device: _bool = False, + dispatch_layout: _bool = False, + _extra_dispatch_keys: torch.DispatchKeySet | None = None, + storage_size: _int | SymInt | None = None, + ) -> S: ... + def _neg_view(self) -> Tensor: ... + def _nested_tensor_size(self) -> Tensor: ... + def _nested_tensor_storage_offsets(self) -> Tensor: ... + def _nested_tensor_strides(self) -> Tensor: ... + def _nnz(self) -> _int: ... + def _sparse_mask_projection( + self, + mask: Tensor, + accumulate_matches: _bool = False, + ) -> Tensor: ... + def _to_dense( + self, + dtype: _dtype | None = None, + masked_grad: _bool | None = None, + ) -> Tensor: ... + @overload + def _to_sparse( + self, + *, + layout: _layout | None = None, + blocksize: _int | _size | None = None, + dense_dim: _int | None = None, + ) -> Tensor: ... + @overload + def _to_sparse(self, sparse_dim: _int) -> Tensor: ... + def _to_sparse_bsc( + self, + blocksize: _int | _size, + dense_dim: _int | None = None, + ) -> Tensor: ... + def _to_sparse_bsr( + self, + blocksize: _int | _size, + dense_dim: _int | None = None, + ) -> Tensor: ... + def _to_sparse_csc(self, dense_dim: _int | None = None) -> Tensor: ... + def _to_sparse_csr(self, dense_dim: _int | None = None) -> Tensor: ... + def _values(self) -> Tensor: ... 
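``_make_subclass`` and ``_make_wrapper_subclass`` declared above are the usual entry points for building ``Tensor`` subclasses. A minimal sketch of a plain subclass built with ``_make_subclass``; the class name ``MyTensor`` is illustrative and not part of the stubs:

import torch

class MyTensor(torch.Tensor):
    # Re-wrap existing data in the subclass without copying its storage.
    @staticmethod
    def __new__(cls, data, requires_grad=False):
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def doubled(self):
        return self * 2

x = MyTensor(torch.arange(3.0))
print(type(x).__name__)   # MyTensor
print(x.doubled())        # values 0., 2., 4.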
+ def abs(self) -> Tensor: + r""" + abs() -> Tensor + + See :func:`torch.abs` + """ + + def abs_(self) -> Tensor: + r""" + abs_() -> Tensor + + In-place version of :meth:`~Tensor.abs` + """ + + def absolute(self) -> Tensor: + r""" + absolute() -> Tensor + + Alias for :func:`abs` + """ + + def absolute_(self) -> Tensor: + r""" + absolute_() -> Tensor + + In-place version of :meth:`~Tensor.absolute` + Alias for :func:`abs_` + """ + + def acos(self) -> Tensor: + r""" + acos() -> Tensor + + See :func:`torch.acos` + """ + + def acos_(self) -> Tensor: + r""" + acos_() -> Tensor + + In-place version of :meth:`~Tensor.acos` + """ + + def acosh(self) -> Tensor: + r""" + acosh() -> Tensor + + See :func:`torch.acosh` + """ + + def acosh_(self) -> Tensor: + r""" + acosh_() -> Tensor + + In-place version of :meth:`~Tensor.acosh` + """ + + def add( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, + ) -> Tensor: + r""" + add(other, *, alpha=1) -> Tensor + + Add a scalar or tensor to :attr:`self` tensor. If both :attr:`alpha` + and :attr:`other` are specified, each element of :attr:`other` is scaled by + :attr:`alpha` before being used. + + When :attr:`other` is a tensor, the shape of :attr:`other` must be + :ref:`broadcastable ` with the shape of the underlying + tensor + + See :func:`torch.add` + """ + + def add_( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + *, + alpha: Number | _complex | None = 1, + ) -> Tensor: + r""" + add_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.add` + """ + + def addbmm( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addbmm` + """ + + def addbmm_( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addbmm` + """ + + def addcdiv( + self, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + ) -> Tensor: + r""" + addcdiv(tensor1, tensor2, *, value=1) -> Tensor + + See :func:`torch.addcdiv` + """ + + def addcdiv_( + self, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + ) -> Tensor: + r""" + addcdiv_(tensor1, tensor2, *, value=1) -> Tensor + + In-place version of :meth:`~Tensor.addcdiv` + """ + + def addcmul( + self, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + ) -> Tensor: + r""" + addcmul(tensor1, tensor2, *, value=1) -> Tensor + + See :func:`torch.addcmul` + """ + + def addcmul_( + self, + tensor1: Tensor, + tensor2: Tensor, + *, + value: Number | _complex = 1, + ) -> Tensor: + r""" + addcmul_(tensor1, tensor2, *, value=1) -> Tensor + + In-place version of :meth:`~Tensor.addcmul` + """ + + def addmm( + self, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addmm` + """ + + def addmm_( + self, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addmm` + """ + + def addmv( + self, + mat: Tensor, + vec: Tensor, 
+ *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addmv(mat, vec, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addmv` + """ + + def addmv_( + self, + mat: Tensor, + vec: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addmv` + """ + + def addr( + self, + vec1: Tensor, + vec2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.addr` + """ + + def addr_( + self, + vec1: Tensor, + vec2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.addr` + """ + + def adjoint(self) -> Tensor: + r""" + adjoint() -> Tensor + + Alias for :func:`adjoint` + """ + + def align_as(self, other: Tensor) -> Tensor: + r""" + align_as(other) -> Tensor + + Permutes the dimensions of the :attr:`self` tensor to match the dimension order + in the :attr:`other` tensor, adding size-one dims for any new names. + + This operation is useful for explicit broadcasting by names (see examples). + + All of the dims of :attr:`self` must be named in order to use this method. + The resulting tensor is a view on the original tensor. + + All dimension names of :attr:`self` must be present in ``other.names``. + :attr:`other` may contain named dimensions that are not in ``self.names``; + the output tensor has a size-one dimension for each of those new names. + + To align a tensor to a specific order, use :meth:`~Tensor.align_to`. + + Examples:: + + # Example 1: Applying a mask + >>> mask = torch.randint(2, [127, 128], dtype=torch.bool).refine_names('W', 'H') + >>> imgs = torch.randn(32, 128, 127, 3, names=('N', 'H', 'W', 'C')) + >>> imgs.masked_fill_(mask.align_as(imgs), 0) + + + # Example 2: Applying a per-channel-scale + >>> def scale_channels(input, scale): + >>> scale = scale.refine_names('C') + >>> return input * scale.align_as(input) + + >>> num_channels = 3 + >>> scale = torch.randn(num_channels, names=('C',)) + >>> imgs = torch.rand(32, 128, 128, num_channels, names=('N', 'H', 'W', 'C')) + >>> more_imgs = torch.rand(32, num_channels, 128, 128, names=('N', 'C', 'H', 'W')) + >>> videos = torch.randn(3, num_channels, 128, 128, 128, names=('N', 'C', 'H', 'W', 'D')) + + # scale_channels is agnostic to the dimension order of the input + >>> scale_channels(imgs, scale) + >>> scale_channels(more_imgs, scale) + >>> scale_channels(videos, scale) + + .. warning:: + The named tensor API is experimental and subject to change. + """ + + @overload + def align_to( + self, + order: Sequence[str | EllipsisType | None], + ellipsis_idx: _int, + ) -> Tensor: ... + @overload + def align_to(self, names: Sequence[str | EllipsisType | None]) -> Tensor: ... 
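The fused ``add*`` methods above all follow the same ``beta * self + alpha * op(...)`` pattern. A small check, with arbitrary shapes, that ``addmv`` matches the unfused expression:

import torch

v = torch.randn(3)            # the added vector (``self``)
mat = torch.randn(3, 4)
vec = torch.randn(4)

fused = v.addmv(mat, vec, beta=0.5, alpha=2.0)
manual = 0.5 * v + 2.0 * (mat @ vec)
print(torch.allclose(fused, manual))   # expected: True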
+ @overload + def all(self) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + + @overload + def all(self, dim: _size | None = None, keepdim: _bool = False) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + + @overload + def all(self, dim: _int, keepdim: _bool = False) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + + @overload + def all( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> Tensor: + r""" + all(dim=None, keepdim=False) -> Tensor + + See :func:`torch.all` + """ + + def allclose( + self, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, + ) -> _bool: + r""" + allclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + See :func:`torch.allclose` + """ + + def amax(self, dim: _int | _size = (), keepdim: _bool = False) -> Tensor: + r""" + amax(dim=None, keepdim=False) -> Tensor + + See :func:`torch.amax` + """ + + def amin(self, dim: _int | _size = (), keepdim: _bool = False) -> Tensor: + r""" + amin(dim=None, keepdim=False) -> Tensor + + See :func:`torch.amin` + """ + + def aminmax( + self, + *, + dim: _int | None = None, + keepdim: _bool = False, + ) -> torch.return_types.aminmax: + r""" + aminmax(*, dim=None, keepdim=False) -> (Tensor min, Tensor max) + + See :func:`torch.aminmax` + """ + + def angle(self) -> Tensor: + r""" + angle() -> Tensor + + See :func:`torch.angle` + """ + + @overload + def any(self) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + + @overload + def any(self, dim: _size | None = None, keepdim: _bool = False) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + + @overload + def any(self, dim: _int, keepdim: _bool = False) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + + @overload + def any( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> Tensor: + r""" + any(dim=None, keepdim=False) -> Tensor + + See :func:`torch.any` + """ + + def apply_(self, callable: Callable) -> Tensor: + r""" + apply_(callable) -> Tensor + + Applies the function :attr:`callable` to each element in the tensor, replacing + each element with the value returned by :attr:`callable`. + + .. note:: + + This function only works with CPU tensors and should not be used in code + sections that require high performance. 
+ """ + + def arccos(self) -> Tensor: + r""" + arccos() -> Tensor + + See :func:`torch.arccos` + """ + + def arccos_(self) -> Tensor: + r""" + arccos_() -> Tensor + + In-place version of :meth:`~Tensor.arccos` + """ + + def arccosh(self) -> Tensor: + r""" + acosh() -> Tensor + + See :func:`torch.arccosh` + """ + + def arccosh_(self) -> Tensor: + r""" + acosh_() -> Tensor + + In-place version of :meth:`~Tensor.arccosh` + """ + + def arcsin(self) -> Tensor: + r""" + arcsin() -> Tensor + + See :func:`torch.arcsin` + """ + + def arcsin_(self) -> Tensor: + r""" + arcsin_() -> Tensor + + In-place version of :meth:`~Tensor.arcsin` + """ + + def arcsinh(self) -> Tensor: + r""" + arcsinh() -> Tensor + + See :func:`torch.arcsinh` + """ + + def arcsinh_(self) -> Tensor: + r""" + arcsinh_() -> Tensor + + In-place version of :meth:`~Tensor.arcsinh` + """ + + def arctan(self) -> Tensor: + r""" + arctan() -> Tensor + + See :func:`torch.arctan` + """ + + def arctan2(self, other: Tensor) -> Tensor: + r""" + arctan2(other) -> Tensor + + See :func:`torch.arctan2` + """ + + def arctan2_(self, other: Tensor) -> Tensor: + r""" + atan2_(other) -> Tensor + + In-place version of :meth:`~Tensor.arctan2` + """ + + def arctan_(self) -> Tensor: + r""" + arctan_() -> Tensor + + In-place version of :meth:`~Tensor.arctan` + """ + + def arctanh(self) -> Tensor: + r""" + arctanh() -> Tensor + + See :func:`torch.arctanh` + """ + + def arctanh_(self) -> Tensor: + r""" + arctanh_(other) -> Tensor + + In-place version of :meth:`~Tensor.arctanh` + """ + + def argmax(self, dim: _int | None = None, keepdim: _bool = False) -> Tensor: + r""" + argmax(dim=None, keepdim=False) -> LongTensor + + See :func:`torch.argmax` + """ + + def argmin(self, dim: _int | None = None, keepdim: _bool = False) -> Tensor: + r""" + argmin(dim=None, keepdim=False) -> LongTensor + + See :func:`torch.argmin` + """ + + @overload + def argsort( + self, + *, + stable: _bool, + dim: _int = -1, + descending: _bool = False, + ) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + + @overload + def argsort(self, dim: _int = -1, descending: _bool = False) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + + @overload + def argsort( + self, + dim: str | EllipsisType | None, + descending: _bool = False, + ) -> Tensor: + r""" + argsort(dim=-1, descending=False) -> LongTensor + + See :func:`torch.argsort` + """ + + def argwhere(self) -> Tensor: + r""" + argwhere() -> Tensor + + See :func:`torch.argwhere` + """ + + def as_strided( + self, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, + ) -> Tensor: + r""" + as_strided(size, stride, storage_offset=None) -> Tensor + + See :func:`torch.as_strided` + """ + + def as_strided_( + self, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, + ) -> Tensor: + r""" + as_strided_(size, stride, storage_offset=None) -> Tensor + + In-place version of :meth:`~Tensor.as_strided` + """ + + def as_strided_scatter( + self, + src: Tensor, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + storage_offset: _int | SymInt | None = None, + ) -> Tensor: + r""" + as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor + + See :func:`torch.as_strided_scatter` + """ + + def as_subclass(self, cls: type[S]) -> S: + r""" + as_subclass(cls) -> Tensor + + Makes a ``cls`` instance with the 
same data pointer as ``self``. Changes + in the output mirror changes in ``self``, and the output stays attached + to the autograd graph. ``cls`` must be a subclass of ``Tensor``. + """ + + def asin(self) -> Tensor: + r""" + asin() -> Tensor + + See :func:`torch.asin` + """ + + def asin_(self) -> Tensor: + r""" + asin_() -> Tensor + + In-place version of :meth:`~Tensor.asin` + """ + + def asinh(self) -> Tensor: + r""" + asinh() -> Tensor + + See :func:`torch.asinh` + """ + + def asinh_(self) -> Tensor: + r""" + asinh_() -> Tensor + + In-place version of :meth:`~Tensor.asinh` + """ + + def atan(self) -> Tensor: + r""" + atan() -> Tensor + + See :func:`torch.atan` + """ + + def atan2(self, other: Tensor) -> Tensor: + r""" + atan2(other) -> Tensor + + See :func:`torch.atan2` + """ + + def atan2_(self, other: Tensor) -> Tensor: + r""" + atan2_(other) -> Tensor + + In-place version of :meth:`~Tensor.atan2` + """ + + def atan_(self) -> Tensor: + r""" + atan_() -> Tensor + + In-place version of :meth:`~Tensor.atan` + """ + + def atanh(self) -> Tensor: + r""" + atanh() -> Tensor + + See :func:`torch.atanh` + """ + + def atanh_(self) -> Tensor: + r""" + atanh_(other) -> Tensor + + In-place version of :meth:`~Tensor.atanh` + """ + + def baddbmm( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.baddbmm` + """ + + def baddbmm_( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.baddbmm` + """ + + @overload + def bernoulli(self, *, generator: Generator | None = None) -> Tensor: + r""" + bernoulli(*, generator=None) -> Tensor + + Returns a result tensor where each :math:`\texttt{result[i]}` is independently + sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have + floating point ``dtype``, and the result will have the same ``dtype``. + + See :func:`torch.bernoulli` + """ + + @overload + def bernoulli( + self, + p: _float, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + bernoulli(*, generator=None) -> Tensor + + Returns a result tensor where each :math:`\texttt{result[i]}` is independently + sampled from :math:`\text{Bernoulli}(\texttt{self[i]})`. :attr:`self` must have + floating point ``dtype``, and the result will have the same ``dtype``. + + See :func:`torch.bernoulli` + """ + + @overload + def bernoulli_( + self, + p: Tensor, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + bernoulli_(p=0.5, *, generator=None) -> Tensor + + Fills each location of :attr:`self` with an independent sample from + :math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral + ``dtype``. + + :attr:`p` should either be a scalar or tensor containing probabilities to be + used for drawing the binary random number. + + If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor + will be set to a value sampled from + :math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have + floating point ``dtype``. 
+ + See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` + """ + + @overload + def bernoulli_( + self, + p: _float = 0.5, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + bernoulli_(p=0.5, *, generator=None) -> Tensor + + Fills each location of :attr:`self` with an independent sample from + :math:`\text{Bernoulli}(\texttt{p})`. :attr:`self` can have integral + ``dtype``. + + :attr:`p` should either be a scalar or tensor containing probabilities to be + used for drawing the binary random number. + + If it is a tensor, the :math:`\text{i}^{th}` element of :attr:`self` tensor + will be set to a value sampled from + :math:`\text{Bernoulli}(\texttt{p\_tensor[i]})`. In this case `p` must have + floating point ``dtype``. + + See also :meth:`~Tensor.bernoulli` and :func:`torch.bernoulli` + """ + + def bfloat16(self) -> Tensor: + r""" + bfloat16(memory_format=torch.preserve_format) -> Tensor + ``self.bfloat16()`` is equivalent to ``self.to(torch.bfloat16)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def bincount( + self, + weights: Tensor | None = None, + minlength: _int | SymInt = 0, + ) -> Tensor: + r""" + bincount(weights=None, minlength=0) -> Tensor + + See :func:`torch.bincount` + """ + + @overload + def bitwise_and(self, other: Tensor) -> Tensor: + r""" + bitwise_and() -> Tensor + + See :func:`torch.bitwise_and` + """ + + @overload + def bitwise_and(self, other: Number | _complex) -> Tensor: + r""" + bitwise_and() -> Tensor + + See :func:`torch.bitwise_and` + """ + + @overload + def bitwise_and_(self, other: Tensor) -> Tensor: + r""" + bitwise_and_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_and` + """ + + @overload + def bitwise_and_(self, other: Number | _complex) -> Tensor: + r""" + bitwise_and_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_and` + """ + + @overload + def bitwise_left_shift(self, other: Tensor) -> Tensor: + r""" + bitwise_left_shift(other) -> Tensor + + See :func:`torch.bitwise_left_shift` + """ + + @overload + def bitwise_left_shift(self, other: Number | _complex) -> Tensor: + r""" + bitwise_left_shift(other) -> Tensor + + See :func:`torch.bitwise_left_shift` + """ + + @overload + def bitwise_left_shift_(self, other: Tensor) -> Tensor: + r""" + bitwise_left_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_left_shift` + """ + + @overload + def bitwise_left_shift_(self, other: Number | _complex) -> Tensor: + r""" + bitwise_left_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_left_shift` + """ + + def bitwise_not(self) -> Tensor: + r""" + bitwise_not() -> Tensor + + See :func:`torch.bitwise_not` + """ + + def bitwise_not_(self) -> Tensor: + r""" + bitwise_not_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_not` + """ + + @overload + def bitwise_or(self, other: Tensor) -> Tensor: + r""" + bitwise_or() -> Tensor + + See :func:`torch.bitwise_or` + """ + + @overload + def bitwise_or(self, other: Number | _complex) -> Tensor: + r""" + bitwise_or() -> Tensor + + See :func:`torch.bitwise_or` + """ + + @overload + def bitwise_or_(self, other: Tensor) -> Tensor: + r""" + bitwise_or_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_or` + """ + + @overload + def bitwise_or_(self, other: Number | _complex) -> Tensor: + r""" + bitwise_or_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_or` + """ + + @overload + def 
bitwise_right_shift(self, other: Tensor) -> Tensor: + r""" + bitwise_right_shift(other) -> Tensor + + See :func:`torch.bitwise_right_shift` + """ + + @overload + def bitwise_right_shift(self, other: Number | _complex) -> Tensor: + r""" + bitwise_right_shift(other) -> Tensor + + See :func:`torch.bitwise_right_shift` + """ + + @overload + def bitwise_right_shift_(self, other: Tensor) -> Tensor: + r""" + bitwise_right_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_right_shift` + """ + + @overload + def bitwise_right_shift_(self, other: Number | _complex) -> Tensor: + r""" + bitwise_right_shift_(other) -> Tensor + + In-place version of :meth:`~Tensor.bitwise_right_shift` + """ + + @overload + def bitwise_xor(self, other: Tensor) -> Tensor: + r""" + bitwise_xor() -> Tensor + + See :func:`torch.bitwise_xor` + """ + + @overload + def bitwise_xor(self, other: Number | _complex) -> Tensor: + r""" + bitwise_xor() -> Tensor + + See :func:`torch.bitwise_xor` + """ + + @overload + def bitwise_xor_(self, other: Tensor) -> Tensor: + r""" + bitwise_xor_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_xor` + """ + + @overload + def bitwise_xor_(self, other: Number | _complex) -> Tensor: + r""" + bitwise_xor_() -> Tensor + + In-place version of :meth:`~Tensor.bitwise_xor` + """ + + def bmm(self, mat2: Tensor) -> Tensor: + r""" + bmm(batch2) -> Tensor + + See :func:`torch.bmm` + """ + + def bool(self) -> Tensor: + r""" + bool(memory_format=torch.preserve_format) -> Tensor + + ``self.bool()`` is equivalent to ``self.to(torch.bool)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + @overload + def broadcast_to(self, size: Sequence[_int | SymInt]) -> Tensor: + r""" + broadcast_to(shape) -> Tensor + + See :func:`torch.broadcast_to`. + """ + + @overload + def broadcast_to(self, *size: _int | SymInt) -> Tensor: + r""" + broadcast_to(shape) -> Tensor + + See :func:`torch.broadcast_to`. + """ + + def byte(self) -> Tensor: + r""" + byte(memory_format=torch.preserve_format) -> Tensor + + ``self.byte()`` is equivalent to ``self.to(torch.uint8)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def cauchy_( + self, + median: _float = 0, + sigma: _float = 1, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + cauchy_(median=0, sigma=1, *, generator=None) -> Tensor + + Fills the tensor with numbers drawn from the Cauchy distribution: + + .. math:: + + f(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2} + + .. note:: + Sigma (:math:`\sigma`) is used to denote the scale parameter in Cauchy distribution. + """ + + def ccol_indices(self) -> Tensor: ... + def ceil(self) -> Tensor: + r""" + ceil() -> Tensor + + See :func:`torch.ceil` + """ + + def ceil_(self) -> Tensor: + r""" + ceil_() -> Tensor + + In-place version of :meth:`~Tensor.ceil` + """ + + def chalf(self, *, memory_format: memory_format | None = None) -> Tensor: + r""" + chalf(memory_format=torch.preserve_format) -> Tensor + + ``self.chalf()`` is equivalent to ``self.to(torch.complex32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
+ """ + + def char(self) -> Tensor: + r""" + char(memory_format=torch.preserve_format) -> Tensor + + ``self.char()`` is equivalent to ``self.to(torch.int8)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def cholesky(self, upper: _bool = False) -> Tensor: + r""" + cholesky(upper=False) -> Tensor + + See :func:`torch.cholesky` + """ + + def cholesky_inverse(self, upper: _bool = False) -> Tensor: + r""" + cholesky_inverse(upper=False) -> Tensor + + See :func:`torch.cholesky_inverse` + """ + + def cholesky_solve(self, input2: Tensor, upper: _bool = False) -> Tensor: + r""" + cholesky_solve(input2, upper=False) -> Tensor + + See :func:`torch.cholesky_solve` + """ + + def chunk(self, chunks: _int, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + chunk(chunks, dim=0) -> List of Tensors + + See :func:`torch.chunk` + """ + + @overload + def clamp( + self, + min: Tensor | None = None, + max: Tensor | None = None, + ) -> Tensor: + r""" + clamp(min=None, max=None) -> Tensor + + See :func:`torch.clamp` + """ + + @overload + def clamp( + self, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + ) -> Tensor: + r""" + clamp(min=None, max=None) -> Tensor + + See :func:`torch.clamp` + """ + + @overload + def clamp_( + self, + min: Tensor | None = None, + max: Tensor | None = None, + ) -> Tensor: + r""" + clamp_(min=None, max=None) -> Tensor + + In-place version of :meth:`~Tensor.clamp` + """ + + @overload + def clamp_( + self, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + ) -> Tensor: + r""" + clamp_(min=None, max=None) -> Tensor + + In-place version of :meth:`~Tensor.clamp` + """ + + @overload + def clamp_max(self, max: Tensor) -> Tensor: ... + @overload + def clamp_max(self, max: Number | _complex) -> Tensor: ... + @overload + def clamp_max_(self, max: Tensor) -> Tensor: ... + @overload + def clamp_max_(self, max: Number | _complex) -> Tensor: ... + @overload + def clamp_min(self, min: Tensor) -> Tensor: ... + @overload + def clamp_min(self, min: Number | _complex) -> Tensor: ... + @overload + def clamp_min_(self, min: Tensor) -> Tensor: ... + @overload + def clamp_min_(self, min: Number | _complex) -> Tensor: ... + @overload + def clip( + self, + min: Tensor | None = None, + max: Tensor | None = None, + ) -> Tensor: + r""" + clip(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp`. + """ + + @overload + def clip( + self, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + ) -> Tensor: + r""" + clip(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp`. + """ + + @overload + def clip_( + self, + min: Tensor | None = None, + max: Tensor | None = None, + ) -> Tensor: + r""" + clip_(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp_`. + """ + + @overload + def clip_( + self, + min: Number | _complex | None = None, + max: Number | _complex | None = None, + ) -> Tensor: + r""" + clip_(min=None, max=None) -> Tensor + + Alias for :meth:`~Tensor.clamp_`. + """ + + def clone(self, *, memory_format: memory_format | None = None) -> Tensor: + r""" + clone(*, memory_format=torch.preserve_format) -> Tensor + + See :func:`torch.clone` + """ + + def coalesce(self) -> Tensor: + r""" + coalesce() -> Tensor + + Returns a coalesced copy of :attr:`self` if :attr:`self` is an + :ref:`uncoalesced tensor `. 
+ + Returns :attr:`self` if :attr:`self` is a coalesced tensor. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + """ + + def col_indices(self) -> Tensor: + r""" + col_indices() -> IntTensor + + Returns the tensor containing the column indices of the :attr:`self` + tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. + The ``col_indices`` tensor is strictly of shape (:attr:`self`.nnz()) + and of type ``int32`` or ``int64``. When using MKL routines such as sparse + matrix multiplication, it is necessary to use ``int32`` indexing in order + to avoid downcasting and potentially losing information. + + Example:: + + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.col_indices() + tensor([0, 1, 2, 3, 4], dtype=torch.int32) + """ + + def conj(self) -> Tensor: + r""" + conj() -> Tensor + + See :func:`torch.conj` + """ + + def conj_physical(self) -> Tensor: + r""" + conj_physical() -> Tensor + + See :func:`torch.conj_physical` + """ + + def conj_physical_(self) -> Tensor: + r""" + conj_physical_() -> Tensor + + In-place version of :meth:`~Tensor.conj_physical` + """ + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Tensor: + r""" + contiguous(memory_format=torch.contiguous_format) -> Tensor + + Returns a contiguous in memory tensor containing the same data as :attr:`self` tensor. If + :attr:`self` tensor is already in the specified memory format, this function returns the + :attr:`self` tensor. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + """ + + def copy_(self, other: Tensor, non_blocking: _bool = False) -> Tensor: + r""" + copy_(src, non_blocking=False) -> Tensor + + Copies the elements from :attr:`src` into :attr:`self` tensor and returns + :attr:`self`. + + The :attr:`src` tensor must be :ref:`broadcastable ` + with the :attr:`self` tensor. It may be of a different data type or reside on a + different device. + + Args: + src (Tensor): the source tensor to copy from + non_blocking (bool): if ``True`` and this copy is between CPU and GPU, + the copy may occur asynchronously with respect to the host. For other + cases, this argument has no effect. 
+ """ + + @overload + def copysign(self, other: Tensor) -> Tensor: + r""" + copysign(other) -> Tensor + + See :func:`torch.copysign` + """ + + @overload + def copysign(self, other: Number | _complex) -> Tensor: + r""" + copysign(other) -> Tensor + + See :func:`torch.copysign` + """ + + @overload + def copysign_(self, other: Tensor) -> Tensor: + r""" + copysign_(other) -> Tensor + + In-place version of :meth:`~Tensor.copysign` + """ + + @overload + def copysign_(self, other: Number | _complex) -> Tensor: + r""" + copysign_(other) -> Tensor + + In-place version of :meth:`~Tensor.copysign` + """ + + def corrcoef(self) -> Tensor: + r""" + corrcoef() -> Tensor + + See :func:`torch.corrcoef` + """ + + def cos(self) -> Tensor: + r""" + cos() -> Tensor + + See :func:`torch.cos` + """ + + def cos_(self) -> Tensor: + r""" + cos_() -> Tensor + + In-place version of :meth:`~Tensor.cos` + """ + + def cosh(self) -> Tensor: + r""" + cosh() -> Tensor + + See :func:`torch.cosh` + """ + + def cosh_(self) -> Tensor: + r""" + cosh_() -> Tensor + + In-place version of :meth:`~Tensor.cosh` + """ + + @overload + def count_nonzero(self, dim: _int | None = None) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + + @overload + def count_nonzero(self, dim: _size) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + + @overload + def count_nonzero(self, *dim: _int) -> Tensor: + r""" + count_nonzero(dim=None) -> Tensor + + See :func:`torch.count_nonzero` + """ + + def cov( + self, + *, + correction: _int = 1, + fweights: Tensor | None = None, + aweights: Tensor | None = None, + ) -> Tensor: + r""" + cov(*, correction=1, fweights=None, aweights=None) -> Tensor + + See :func:`torch.cov` + """ + + def cpu( + self, + memory_format: torch.memory_format = torch.preserve_format, + ) -> Tensor: + r""" + cpu(memory_format=torch.preserve_format) -> Tensor + + Returns a copy of this object in CPU memory. + + If this object is already in CPU memory, + then no copy is performed and the original object is returned. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def cross(self, other: Tensor, dim: _int | None = None) -> Tensor: + r""" + cross(other, dim=None) -> Tensor + + See :func:`torch.cross` + """ + + def crow_indices(self) -> Tensor: + r""" + crow_indices() -> IntTensor + + Returns the tensor containing the compressed row indices of the :attr:`self` + tensor when :attr:`self` is a sparse CSR tensor of layout ``sparse_csr``. + The ``crow_indices`` tensor is strictly of shape (:attr:`self`.size(0) + 1) + and of type ``int32`` or ``int64``. When using MKL routines such as sparse + matrix multiplication, it is necessary to use ``int32`` indexing in order + to avoid downcasting and potentially losing information. + + Example:: + + >>> csr = torch.eye(5,5).to_sparse_csr() + >>> csr.crow_indices() + tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32) + """ + + def cuda( + self, + device: _device | _int | str | None = None, + non_blocking: _bool = False, + memory_format: torch.memory_format = torch.preserve_format, + ) -> Tensor: + r""" + cuda(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of this object in CUDA memory. + + If this object is already in CUDA memory and on the correct device, + then no copy is performed and the original object is returned. 
+ + Args: + device (:class:`torch.device`): The destination GPU device. + Defaults to the current CUDA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + @overload + def cummax(self, dim: _int) -> torch.return_types.cummax: + r""" + cummax(dim) -> (Tensor, Tensor) + + See :func:`torch.cummax` + """ + + @overload + def cummax( + self, + dim: str | EllipsisType | None, + ) -> torch.return_types.cummax: + r""" + cummax(dim) -> (Tensor, Tensor) + + See :func:`torch.cummax` + """ + + @overload + def cummin(self, dim: _int) -> torch.return_types.cummin: + r""" + cummin(dim) -> (Tensor, Tensor) + + See :func:`torch.cummin` + """ + + @overload + def cummin( + self, + dim: str | EllipsisType | None, + ) -> torch.return_types.cummin: + r""" + cummin(dim) -> (Tensor, Tensor) + + See :func:`torch.cummin` + """ + + @overload + def cumprod(self, dim: _int, *, dtype: _dtype | None = None) -> Tensor: + r""" + cumprod(dim, dtype=None) -> Tensor + + See :func:`torch.cumprod` + """ + + @overload + def cumprod( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + cumprod(dim, dtype=None) -> Tensor + + See :func:`torch.cumprod` + """ + + @overload + def cumprod_(self, dim: _int, *, dtype: _dtype | None = None) -> Tensor: + r""" + cumprod_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumprod` + """ + + @overload + def cumprod_( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + cumprod_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumprod` + """ + + @overload + def cumsum(self, dim: _int, *, dtype: _dtype | None = None) -> Tensor: + r""" + cumsum(dim, dtype=None) -> Tensor + + See :func:`torch.cumsum` + """ + + @overload + def cumsum( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + cumsum(dim, dtype=None) -> Tensor + + See :func:`torch.cumsum` + """ + + @overload + def cumsum_(self, dim: _int, *, dtype: _dtype | None = None) -> Tensor: + r""" + cumsum_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumsum` + """ + + @overload + def cumsum_( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + cumsum_(dim, dtype=None) -> Tensor + + In-place version of :meth:`~Tensor.cumsum` + """ + + def data_ptr(self) -> _int: + r""" + data_ptr() -> int + + Returns the address of the first element of :attr:`self` tensor. + """ + + def deg2rad(self) -> Tensor: + r""" + deg2rad() -> Tensor + + See :func:`torch.deg2rad` + """ + + def deg2rad_(self) -> Tensor: + r""" + deg2rad_() -> Tensor + + In-place version of :meth:`~Tensor.deg2rad` + """ + + def dense_dim(self) -> _int: + r""" + dense_dim() -> int + + Return the number of dense dimensions in a :ref:`sparse tensor ` :attr:`self`. + + .. note:: + Returns ``len(self.shape)`` if :attr:`self` is not a sparse tensor. + + See also :meth:`Tensor.sparse_dim` and :ref:`hybrid tensors `. + """ + + def dequantize(self) -> Tensor: + r""" + dequantize() -> Tensor + + Given a quantized Tensor, dequantize it and return the dequantized float Tensor. 
+ """ + + def det(self) -> Tensor: + r""" + det() -> Tensor + + See :func:`torch.det` + """ + + def detach(self) -> Tensor: ... + def detach_(self) -> Tensor: ... + def diag(self, diagonal: _int = 0) -> Tensor: + r""" + diag(diagonal=0) -> Tensor + + See :func:`torch.diag` + """ + + def diag_embed( + self, + offset: _int = 0, + dim1: _int = -2, + dim2: _int = -1, + ) -> Tensor: + r""" + diag_embed(offset=0, dim1=-2, dim2=-1) -> Tensor + + See :func:`torch.diag_embed` + """ + + def diagflat(self, offset: _int = 0) -> Tensor: + r""" + diagflat(offset=0) -> Tensor + + See :func:`torch.diagflat` + """ + + @overload + def diagonal( + self, + *, + outdim: str | EllipsisType | None, + dim1: str | EllipsisType | None, + dim2: str | EllipsisType | None, + offset: _int = 0, + ) -> Tensor: + r""" + diagonal(offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal` + """ + + @overload + def diagonal( + self, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, + ) -> Tensor: + r""" + diagonal(offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal` + """ + + def diagonal_scatter( + self, + src: Tensor, + offset: _int = 0, + dim1: _int = 0, + dim2: _int = 1, + ) -> Tensor: + r""" + diagonal_scatter(src, offset=0, dim1=0, dim2=1) -> Tensor + + See :func:`torch.diagonal_scatter` + """ + + def diff( + self, + n: _int = 1, + dim: _int = -1, + prepend: Tensor | None = None, + append: Tensor | None = None, + ) -> Tensor: + r""" + diff(n=1, dim=-1, prepend=None, append=None) -> Tensor + + See :func:`torch.diff` + """ + + def digamma(self) -> Tensor: + r""" + digamma() -> Tensor + + See :func:`torch.digamma` + """ + + def digamma_(self) -> Tensor: + r""" + digamma_() -> Tensor + + In-place version of :meth:`~Tensor.digamma` + """ + + def dim(self) -> _int: + r""" + dim() -> int + + Returns the number of dimensions of :attr:`self` tensor. 
+ """ + + def dist(self, other: Tensor, p: Number | _complex = 2) -> Tensor: + r""" + dist(other, p=2) -> Tensor + + See :func:`torch.dist` + """ + + def div( + self, + other: Tensor | Number, + *, + rounding_mode: str | None = None, + ) -> Tensor: + r""" + div(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.div` + """ + + def div_( + self, + other: Tensor | Number, + *, + rounding_mode: str | None = None, + ) -> Tensor: + r""" + div_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.div` + """ + + @overload + def divide(self, other: Tensor) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + + @overload + def divide(self, other: Tensor, *, rounding_mode: str | None) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + + @overload + def divide( + self, + other: Number | _complex, + *, + rounding_mode: str | None, + ) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + + @overload + def divide(self, other: Number | _complex) -> Tensor: + r""" + divide(value, *, rounding_mode=None) -> Tensor + + See :func:`torch.divide` + """ + + @overload + def divide_(self, other: Tensor) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + + @overload + def divide_(self, other: Tensor, *, rounding_mode: str | None) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + + @overload + def divide_( + self, + other: Number | _complex, + *, + rounding_mode: str | None, + ) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + + @overload + def divide_(self, other: Number | _complex) -> Tensor: + r""" + divide_(value, *, rounding_mode=None) -> Tensor + + In-place version of :meth:`~Tensor.divide` + """ + + def dot(self, tensor: Tensor) -> Tensor: + r""" + dot(other) -> Tensor + + See :func:`torch.dot` + """ + + def double(self) -> Tensor: + r""" + double(memory_format=torch.preserve_format) -> Tensor + + ``self.double()`` is equivalent to ``self.to(torch.float64)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + @overload + def dsplit(self, sections: _int) -> tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + + @overload + def dsplit(self, indices: _size) -> tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + + @overload + def dsplit(self, *indices: _int) -> tuple[Tensor, ...]: + r""" + dsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.dsplit` + """ + + def element_size(self) -> _int: + r""" + element_size() -> int + + Returns the size in bytes of an individual element. 
+ + Example:: + + >>> torch.tensor([]).element_size() + 4 + >>> torch.tensor([], dtype=torch.uint8).element_size() + 1 + """ + + @overload + def eq(self, other: Tensor) -> Tensor: + r""" + eq(other) -> Tensor + + See :func:`torch.eq` + """ + + @overload + def eq(self, other: Number | _complex) -> Tensor: + r""" + eq(other) -> Tensor + + See :func:`torch.eq` + """ + + @overload + def eq_(self, other: Tensor) -> Tensor: + r""" + eq_(other) -> Tensor + + In-place version of :meth:`~Tensor.eq` + """ + + @overload + def eq_(self, other: Number | _complex) -> Tensor: + r""" + eq_(other) -> Tensor + + In-place version of :meth:`~Tensor.eq` + """ + + def equal(self, other: Tensor) -> _bool: + r""" + equal(other) -> bool + + See :func:`torch.equal` + """ + + def erf(self) -> Tensor: + r""" + erf() -> Tensor + + See :func:`torch.erf` + """ + + def erf_(self) -> Tensor: + r""" + erf_() -> Tensor + + In-place version of :meth:`~Tensor.erf` + """ + + def erfc(self) -> Tensor: + r""" + erfc() -> Tensor + + See :func:`torch.erfc` + """ + + def erfc_(self) -> Tensor: + r""" + erfc_() -> Tensor + + In-place version of :meth:`~Tensor.erfc` + """ + + def erfinv(self) -> Tensor: + r""" + erfinv() -> Tensor + + See :func:`torch.erfinv` + """ + + def erfinv_(self) -> Tensor: + r""" + erfinv_() -> Tensor + + In-place version of :meth:`~Tensor.erfinv` + """ + + def exp(self) -> Tensor: + r""" + exp() -> Tensor + + See :func:`torch.exp` + """ + + def exp2(self) -> Tensor: + r""" + exp2() -> Tensor + + See :func:`torch.exp2` + """ + + def exp2_(self) -> Tensor: + r""" + exp2_() -> Tensor + + In-place version of :meth:`~Tensor.exp2` + """ + + def exp_(self) -> Tensor: + r""" + exp_() -> Tensor + + In-place version of :meth:`~Tensor.exp` + """ + + @overload + def expand( + self, + size: Sequence[_int | SymInt], + *, + implicit: _bool = False, + ) -> Tensor: + r""" + expand(*sizes) -> Tensor + + Returns a new view of the :attr:`self` tensor with singleton dimensions expanded + to a larger size. + + Passing -1 as the size for a dimension means not changing the size of + that dimension. + + Tensor can be also expanded to a larger number of dimensions, and the + new ones will be appended at the front. For the new dimensions, the + size cannot be set to -1. + + Expanding a tensor does not allocate new memory, but only creates a + new view on the existing tensor where a dimension of size one is + expanded to a larger size by setting the ``stride`` to 0. Any dimension + of size 1 can be expanded to an arbitrary value without allocating new + memory. + + Args: + *sizes (torch.Size or int...): the desired expanded size + + .. warning:: + + More than one element of an expanded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + + Example:: + + >>> x = torch.tensor([[1], [2], [3]]) + >>> x.size() + torch.Size([3, 1]) + >>> x.expand(3, 4) + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + >>> x.expand(-1, 4) # -1 means not changing the size of that dimension + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + """ + + @overload + def expand(self, *size: _int | SymInt, implicit: _bool = False) -> Tensor: + r""" + expand(*sizes) -> Tensor + + Returns a new view of the :attr:`self` tensor with singleton dimensions expanded + to a larger size. + + Passing -1 as the size for a dimension means not changing the size of + that dimension. 
+ + Tensor can be also expanded to a larger number of dimensions, and the + new ones will be appended at the front. For the new dimensions, the + size cannot be set to -1. + + Expanding a tensor does not allocate new memory, but only creates a + new view on the existing tensor where a dimension of size one is + expanded to a larger size by setting the ``stride`` to 0. Any dimension + of size 1 can be expanded to an arbitrary value without allocating new + memory. + + Args: + *sizes (torch.Size or int...): the desired expanded size + + .. warning:: + + More than one element of an expanded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensors, please clone them first. + + Example:: + + >>> x = torch.tensor([[1], [2], [3]]) + >>> x.size() + torch.Size([3, 1]) + >>> x.expand(3, 4) + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + >>> x.expand(-1, 4) # -1 means not changing the size of that dimension + tensor([[ 1, 1, 1, 1], + [ 2, 2, 2, 2], + [ 3, 3, 3, 3]]) + """ + + def expand_as(self, other: Tensor) -> Tensor: + r""" + expand_as(other) -> Tensor + + Expand this tensor to the same size as :attr:`other`. + ``self.expand_as(other)`` is equivalent to ``self.expand(other.size())``. + + Please see :meth:`~Tensor.expand` for more information about ``expand``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. + """ + + def expm1(self) -> Tensor: + r""" + expm1() -> Tensor + + See :func:`torch.expm1` + """ + + def expm1_(self) -> Tensor: + r""" + expm1_() -> Tensor + + In-place version of :meth:`~Tensor.expm1` + """ + + def exponential_( + self, + lambd: _float = 1, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + exponential_(lambd=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements drawn from the PDF (probability density function): + + .. math:: + + f(x) = \lambda e^{-\lambda x}, x > 0 + + .. note:: + In probability theory, exponential distribution is supported on interval [0, :math:`\inf`) (i.e., :math:`x >= 0`) + implying that zero can be sampled from the exponential distribution. + However, :func:`torch.Tensor.exponential_` does not sample zero, + which means that its actual support is the interval (0, :math:`\inf`). + + Note that :func:`torch.distributions.exponential.Exponential` is supported on the interval [0, :math:`\inf`) and can sample zero. + """ + + @overload + def fill_(self, value: Tensor) -> Tensor: + r""" + fill_(value) -> Tensor + + Fills :attr:`self` tensor with the specified value. + """ + + @overload + def fill_(self, value: Number | _complex) -> Tensor: + r""" + fill_(value) -> Tensor + + Fills :attr:`self` tensor with the specified value. + """ + + def fill_diagonal_( + self, + fill_value: Number | _complex, + wrap: _bool = False, + ) -> Tensor: + r""" + fill_diagonal_(fill_value, wrap=False) -> Tensor + + Fill the main diagonal of a tensor that has at least 2-dimensions. + When dims>2, all dimensions of input must be of equal length. + This function modifies the input tensor in-place, and returns the input tensor. + + Arguments: + fill_value (Scalar): the fill value + wrap (bool): the diagonal 'wrapped' after N columns for tall matrices. 
+ + Example:: + + >>> a = torch.zeros(3, 3) + >>> a.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + >>> b = torch.zeros(7, 3) + >>> b.fill_diagonal_(5) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.]]) + >>> c = torch.zeros(7, 3) + >>> c.fill_diagonal_(5, wrap=True) + tensor([[5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.], + [0., 0., 0.], + [5., 0., 0.], + [0., 5., 0.], + [0., 0., 5.]]) + """ + + def fix(self) -> Tensor: + r""" + fix() -> Tensor + + See :func:`torch.fix`. + """ + + def fix_(self) -> Tensor: + r""" + fix_() -> Tensor + + In-place version of :meth:`~Tensor.fix` + """ + + @overload + def flatten( + self, + start_dim: _int, + end_dim: _int, + out_dim: str | EllipsisType | None, + ) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + + @overload + def flatten(self, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + + @overload + def flatten( + self, + start_dim: str | EllipsisType | None, + end_dim: str | EllipsisType | None, + out_dim: str | EllipsisType | None, + ) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + + @overload + def flatten( + self, + dims: Sequence[str | EllipsisType | None], + out_dim: str | EllipsisType | None, + ) -> Tensor: + r""" + flatten(start_dim=0, end_dim=-1) -> Tensor + + See :func:`torch.flatten` + """ + + @overload + def flip(self, dims: _size) -> Tensor: + r""" + flip(dims) -> Tensor + + See :func:`torch.flip` + """ + + @overload + def flip(self, *dims: _int) -> Tensor: + r""" + flip(dims) -> Tensor + + See :func:`torch.flip` + """ + + def fliplr(self) -> Tensor: + r""" + fliplr() -> Tensor + + See :func:`torch.fliplr` + """ + + def flipud(self) -> Tensor: + r""" + flipud() -> Tensor + + See :func:`torch.flipud` + """ + + def float(self) -> Tensor: + r""" + float(memory_format=torch.preserve_format) -> Tensor + + ``self.float()`` is equivalent to ``self.to(torch.float32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. 
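+
+        The values are preserved and only the dtype changes::
+
+            >>> t = torch.tensor([1, 2, 3])
+            >>> t.float().dtype
+            torch.float32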
+ """ + + @overload + def float_power(self, exponent: Tensor) -> Tensor: + r""" + float_power(exponent) -> Tensor + + See :func:`torch.float_power` + """ + + @overload + def float_power(self, exponent: Number | _complex) -> Tensor: + r""" + float_power(exponent) -> Tensor + + See :func:`torch.float_power` + """ + + @overload + def float_power_(self, exponent: Tensor) -> Tensor: + r""" + float_power_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.float_power` + """ + + @overload + def float_power_(self, exponent: Number | _complex) -> Tensor: + r""" + float_power_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.float_power` + """ + + def floor(self) -> Tensor: + r""" + floor() -> Tensor + + See :func:`torch.floor` + """ + + def floor_(self) -> Tensor: + r""" + floor_() -> Tensor + + In-place version of :meth:`~Tensor.floor` + """ + + def floor_divide( + self, + other: Tensor | Number | torch.SymInt | torch.SymFloat, + *, + out: Tensor | None = None, + ) -> Tensor: + r""" + floor_divide(value) -> Tensor + + See :func:`torch.floor_divide` + """ + + def floor_divide_( + self, + other: Tensor | Number | torch.SymInt | torch.SymFloat, + ) -> Tensor: + r""" + floor_divide_(value) -> Tensor + + In-place version of :meth:`~Tensor.floor_divide` + """ + + def fmax(self, other: Tensor) -> Tensor: + r""" + fmax(other) -> Tensor + + See :func:`torch.fmax` + """ + + def fmin(self, other: Tensor) -> Tensor: + r""" + fmin(other) -> Tensor + + See :func:`torch.fmin` + """ + + @overload + def fmod(self, other: Tensor) -> Tensor: + r""" + fmod(divisor) -> Tensor + + See :func:`torch.fmod` + """ + + @overload + def fmod(self, other: Number | _complex) -> Tensor: + r""" + fmod(divisor) -> Tensor + + See :func:`torch.fmod` + """ + + @overload + def fmod_(self, other: Tensor) -> Tensor: + r""" + fmod_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.fmod` + """ + + @overload + def fmod_(self, other: Number | _complex) -> Tensor: + r""" + fmod_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.fmod` + """ + + def frac(self) -> Tensor: + r""" + frac() -> Tensor + + See :func:`torch.frac` + """ + + def frac_(self) -> Tensor: + r""" + frac_() -> Tensor + + In-place version of :meth:`~Tensor.frac` + """ + + def frexp(self) -> torch.return_types.frexp: + r""" + frexp(input) -> (Tensor mantissa, Tensor exponent) + + See :func:`torch.frexp` + """ + + @overload + def gather( + self, + dim: _int, + index: Tensor, + *, + sparse_grad: _bool = False, + ) -> Tensor: + r""" + gather(dim, index) -> Tensor + + See :func:`torch.gather` + """ + + @overload + def gather( + self, + dim: str | EllipsisType | None, + index: Tensor, + *, + sparse_grad: _bool = False, + ) -> Tensor: + r""" + gather(dim, index) -> Tensor + + See :func:`torch.gather` + """ + + def gcd(self, other: Tensor) -> Tensor: + r""" + gcd(other) -> Tensor + + See :func:`torch.gcd` + """ + + def gcd_(self, other: Tensor) -> Tensor: + r""" + gcd_(other) -> Tensor + + In-place version of :meth:`~Tensor.gcd` + """ + + @overload + def ge(self, other: Tensor) -> Tensor: + r""" + ge(other) -> Tensor + + See :func:`torch.ge`. + """ + + @overload + def ge(self, other: Number | _complex) -> Tensor: + r""" + ge(other) -> Tensor + + See :func:`torch.ge`. + """ + + @overload + def ge_(self, other: Tensor) -> Tensor: + r""" + ge_(other) -> Tensor + + In-place version of :meth:`~Tensor.ge`. + """ + + @overload + def ge_(self, other: Number | _complex) -> Tensor: + r""" + ge_(other) -> Tensor + + In-place version of :meth:`~Tensor.ge`. 
+ """ + + def geometric_( + self, + p: _float, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + geometric_(p, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements drawn from the geometric distribution: + + .. math:: + + P(X=k) = (1 - p)^{k - 1} p, k = 1, 2, ... + + .. note:: + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`, whereas + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`. + """ + + def geqrf(self) -> torch.return_types.geqrf: + r""" + geqrf() -> (Tensor, Tensor) + + See :func:`torch.geqrf` + """ + + def ger(self, vec2: Tensor) -> Tensor: + r""" + ger(vec2) -> Tensor + + See :func:`torch.ger` + """ + + def get_device(self) -> _int: + r""" + get_device() -> Device ordinal (Integer) + + For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. + For CPU tensors, this function returns `-1`. + + Example:: + + >>> x = torch.randn(3, 4, 5, device='cuda:0') + >>> x.get_device() + 0 + >>> x.cpu().get_device() + -1 + """ + + @overload + def greater(self, other: Tensor) -> Tensor: + r""" + greater(other) -> Tensor + + See :func:`torch.greater`. + """ + + @overload + def greater(self, other: Number | _complex) -> Tensor: + r""" + greater(other) -> Tensor + + See :func:`torch.greater`. + """ + + @overload + def greater_(self, other: Tensor) -> Tensor: + r""" + greater_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater`. + """ + + @overload + def greater_(self, other: Number | _complex) -> Tensor: + r""" + greater_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater`. + """ + + @overload + def greater_equal(self, other: Tensor) -> Tensor: + r""" + greater_equal(other) -> Tensor + + See :func:`torch.greater_equal`. + """ + + @overload + def greater_equal(self, other: Number | _complex) -> Tensor: + r""" + greater_equal(other) -> Tensor + + See :func:`torch.greater_equal`. + """ + + @overload + def greater_equal_(self, other: Tensor) -> Tensor: + r""" + greater_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater_equal`. + """ + + @overload + def greater_equal_(self, other: Number | _complex) -> Tensor: + r""" + greater_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.greater_equal`. + """ + + @overload + def gt(self, other: Tensor) -> Tensor: + r""" + gt(other) -> Tensor + + See :func:`torch.gt`. + """ + + @overload + def gt(self, other: Number | _complex) -> Tensor: + r""" + gt(other) -> Tensor + + See :func:`torch.gt`. + """ + + @overload + def gt_(self, other: Tensor) -> Tensor: + r""" + gt_(other) -> Tensor + + In-place version of :meth:`~Tensor.gt`. + """ + + @overload + def gt_(self, other: Number | _complex) -> Tensor: + r""" + gt_(other) -> Tensor + + In-place version of :meth:`~Tensor.gt`. + """ + + def half(self) -> Tensor: + r""" + half(memory_format=torch.preserve_format) -> Tensor + + ``self.half()`` is equivalent to ``self.to(torch.float16)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def hardshrink(self, lambd: Number | _complex = 0.5) -> Tensor: + r""" + hardshrink(lambd=0.5) -> Tensor + + See :func:`torch.nn.functional.hardshrink` + """ + + def has_names(self) -> _bool: + r""" + Is ``True`` if any of this tensor's dimensions are named. 
Otherwise, is ``False``. + """ + + def heaviside(self, values: Tensor) -> Tensor: + r""" + heaviside(values) -> Tensor + + See :func:`torch.heaviside` + """ + + def heaviside_(self, values: Tensor) -> Tensor: + r""" + heaviside_(values) -> Tensor + + In-place version of :meth:`~Tensor.heaviside` + """ + + def histc( + self, + bins: _int = 100, + min: Number | _complex = 0, + max: Number | _complex = 0, + ) -> Tensor: + r""" + histc(bins=100, min=0, max=0) -> Tensor + + See :func:`torch.histc` + """ + + @overload + def histogram( + self, + bins: Tensor, + *, + weight: Tensor | None = None, + density: _bool = False, + ) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + + See :func:`torch.histogram` + """ + + @overload + def histogram( + self, + bins: _int = 100, + *, + range: Sequence[_float] | None = None, + weight: Tensor | None = None, + density: _bool = False, + ) -> torch.return_types.histogram: + r""" + histogram(input, bins, *, range=None, weight=None, density=False) -> (Tensor, Tensor) + + See :func:`torch.histogram` + """ + + @overload + def hsplit(self, sections: _int) -> tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + + @overload + def hsplit(self, indices: _size) -> tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + + @overload + def hsplit(self, *indices: _int) -> tuple[Tensor, ...]: + r""" + hsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.hsplit` + """ + + def hypot(self, other: Tensor) -> Tensor: + r""" + hypot(other) -> Tensor + + See :func:`torch.hypot` + """ + + def hypot_(self, other: Tensor) -> Tensor: + r""" + hypot_(other) -> Tensor + + In-place version of :meth:`~Tensor.hypot` + """ + + def i0(self) -> Tensor: + r""" + i0() -> Tensor + + See :func:`torch.i0` + """ + + def i0_(self) -> Tensor: + r""" + i0_() -> Tensor + + In-place version of :meth:`~Tensor.i0` + """ + + def igamma(self, other: Tensor) -> Tensor: + r""" + igamma(other) -> Tensor + + See :func:`torch.igamma` + """ + + def igamma_(self, other: Tensor) -> Tensor: + r""" + igamma_(other) -> Tensor + + In-place version of :meth:`~Tensor.igamma` + """ + + def igammac(self, other: Tensor) -> Tensor: + r""" + igammac(other) -> Tensor + See :func:`torch.igammac` + """ + + def igammac_(self, other: Tensor) -> Tensor: + r""" + igammac_(other) -> Tensor + In-place version of :meth:`~Tensor.igammac` + """ + + @overload + def index_add( + self, + dim: _int, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + index_add(dim, index, source, *, alpha=1) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_add_`. + """ + + @overload + def index_add( + self, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + index_add(dim, index, source, *, alpha=1) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_add_`. + """ + + def index_add_( + self, + dim: _int, + index: Tensor, + source: Tensor, + *, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + index_add_(dim, index, source, *, alpha=1) -> Tensor + + Accumulate the elements of :attr:`alpha` times ``source`` into the :attr:`self` + tensor by adding to the indices in the order given in :attr:`index`. 
For example, + if ``dim == 0``, ``index[i] == j``, and ``alpha=-1``, then the ``i``\ th row of + ``source`` is subtracted from the ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of ``source`` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + For a 3-D tensor the output is given as:: + + self[index[i], :, :] += alpha * src[i, :, :] # if dim == 0 + self[:, index[i], :] += alpha * src[:, i, :] # if dim == 1 + self[:, :, index[i]] += alpha * src[:, :, i] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (Tensor): the tensor containing values to add + + Keyword args: + alpha (Number): the scalar multiplier for ``source`` + + Example:: + + >>> x = torch.ones(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_add_(0, index, t) + tensor([[ 2., 3., 4.], + [ 1., 1., 1.], + [ 8., 9., 10.], + [ 1., 1., 1.], + [ 5., 6., 7.]]) + >>> x.index_add_(0, index, t, alpha=-1) + tensor([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]]) + """ + + @overload + def index_copy(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy(dim, index, tensor2) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_copy_`. + """ + + @overload + def index_copy( + self, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, + ) -> Tensor: + r""" + index_copy(dim, index, tensor2) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_copy_`. + """ + + @overload + def index_copy_(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: + r""" + index_copy_(dim, index, tensor) -> Tensor + + Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting + the indices in the order given in :attr:`index`. For example, if ``dim == 0`` + and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the + ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + .. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. 
+ + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + + Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) + """ + + @overload + def index_copy_( + self, + dim: str | EllipsisType | None, + index: Tensor, + source: Tensor, + ) -> Tensor: + r""" + index_copy_(dim, index, tensor) -> Tensor + + Copies the elements of :attr:`tensor` into the :attr:`self` tensor by selecting + the indices in the order given in :attr:`index`. For example, if ``dim == 0`` + and ``index[i] == j``, then the ``i``\ th row of :attr:`tensor` is copied to the + ``j``\ th row of :attr:`self`. + + The :attr:`dim`\ th dimension of :attr:`tensor` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + .. note:: + If :attr:`index` contains duplicate entries, multiple elements from + :attr:`tensor` will be copied to the same index of :attr:`self`. The result + is nondeterministic since it depends on which copy occurs last. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`tensor` to select from + tensor (Tensor): the tensor containing values to copy + + Example:: + + >>> x = torch.zeros(5, 3) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2]) + >>> x.index_copy_(0, index, t) + tensor([[ 1., 2., 3.], + [ 0., 0., 0.], + [ 7., 8., 9.], + [ 0., 0., 0.], + [ 4., 5., 6.]]) + """ + + @overload + def index_fill(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + + @overload + def index_fill( + self, + dim: str | EllipsisType | None, + index: Tensor, + value: Tensor, + ) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + + @overload + def index_fill( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + + @overload + def index_fill( + self, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + index_fill(dim, index, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.index_fill_`. + """ + + @overload + def index_fill_(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. 
+ + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + + @overload + def index_fill_( + self, + dim: str | EllipsisType | None, + index: Tensor, + value: Tensor, + ) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + + @overload + def index_fill_( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + + @overload + def index_fill_( + self, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + index_fill_(dim, index, value) -> Tensor + + Fills the elements of the :attr:`self` tensor with value :attr:`value` by + selecting the indices in the order given in :attr:`index`. + + Args: + dim (int): dimension along which to index + index (LongTensor): indices of :attr:`self` tensor to fill in + value (float): the value to fill with + + Example:: + + >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) + >>> index = torch.tensor([0, 2]) + >>> x.index_fill_(1, index, -1) + tensor([[-1., 2., -1.], + [-1., 5., -1.], + [-1., 8., -1.]]) + """ + + def index_put( + self, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, + ) -> Tensor: + r""" + index_put(indices, values, accumulate=False) -> Tensor + + Out-place version of :meth:`~Tensor.index_put_`. + """ + + def index_put_( + self, + indices: tuple[Tensor, ...] | list[Tensor] | None, + values: Tensor, + accumulate: _bool = False, + ) -> Tensor: + r""" + index_put_(indices, values, accumulate=False) -> Tensor + + Puts values from the tensor :attr:`values` into the tensor :attr:`self` using + the indices specified in :attr:`indices` (which is a tuple of Tensors). The + expression ``tensor.index_put_(indices, values)`` is equivalent to + ``tensor[indices] = values``. Returns :attr:`self`. + + If :attr:`accumulate` is ``True``, the elements in :attr:`values` are added to + :attr:`self`. If accumulate is ``False``, the behavior is undefined if indices + contain duplicate elements. + + Args: + indices (tuple of LongTensor): tensors used to index into `self`. 
+ values (Tensor): tensor of same dtype as `self`. + accumulate (bool): whether to accumulate into self + """ + + def index_reduce( + self, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: ... + def index_reduce_( + self, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: + r""" + index_reduce_(dim, index, source, reduce, *, include_self=True) -> Tensor + + Accumulate the elements of ``source`` into the :attr:`self` + tensor by accumulating to the indices in the order given in :attr:`index` + using the reduction given by the ``reduce`` argument. For example, if ``dim == 0``, + ``index[i] == j``, ``reduce == prod`` and ``include_self == True`` then the ``i``\ th + row of ``source`` is multiplied by the ``j``\ th row of :attr:`self`. If + :obj:`include_self="True"`, the values in the :attr:`self` tensor are included + in the reduction, otherwise, rows in the :attr:`self` tensor that are accumulated + to are treated as if they were filled with the reduction identites. + + The :attr:`dim`\ th dimension of ``source`` must have the same size as the + length of :attr:`index` (which must be a vector), and all other dimensions must + match :attr:`self`, or an error will be raised. + + For a 3-D tensor with :obj:`reduce="prod"` and :obj:`include_self=True` the + output is given as:: + + self[index[i], :, :] *= src[i, :, :] # if dim == 0 + self[:, index[i], :] *= src[:, i, :] # if dim == 1 + self[:, :, index[i]] *= src[:, :, i] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + This function only supports floating point tensors. + + .. warning:: + + This function is in beta and may change in the near future. + + Args: + dim (int): dimension along which to index + index (Tensor): indices of ``source`` to select from, + should have dtype either `torch.int64` or `torch.int32` + source (FloatTensor): the tensor containing values to accumulate + reduce (str): the reduction operation to apply + (:obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + + Keyword args: + include_self (bool): whether the elements from the ``self`` tensor are + included in the reduction + + Example:: + + >>> x = torch.empty(5, 3).fill_(2) + >>> t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float) + >>> index = torch.tensor([0, 4, 2, 0]) + >>> x.index_reduce_(0, index, t, 'prod') + tensor([[20., 44., 72.], + [ 2., 2., 2.], + [14., 16., 18.], + [ 2., 2., 2.], + [ 8., 10., 12.]]) + >>> x = torch.empty(5, 3).fill_(2) + >>> x.index_reduce_(0, index, t, 'prod', include_self=False) + tensor([[10., 22., 36.], + [ 2., 2., 2.], + [ 7., 8., 9.], + [ 2., 2., 2.], + [ 4., 5., 6.]]) + """ + + @overload + def index_select(self, dim: _int, index: Tensor) -> Tensor: + r""" + index_select(dim, index) -> Tensor + + See :func:`torch.index_select` + """ + + @overload + def index_select( + self, + dim: str | EllipsisType | None, + index: Tensor, + ) -> Tensor: + r""" + index_select(dim, index) -> Tensor + + See :func:`torch.index_select` + """ + + def indices(self) -> Tensor: + r""" + indices() -> Tensor + + Return the indices tensor of a :ref:`sparse COO tensor `. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See also :meth:`Tensor.values`. + + .. note:: + This method can only be called on a coalesced sparse tensor. 
See + :meth:`Tensor.coalesce` for details. + """ + + def inner(self, other: Tensor) -> Tensor: + r""" + inner(other) -> Tensor + + See :func:`torch.inner`. + """ + + def int(self) -> Tensor: + r""" + int(memory_format=torch.preserve_format) -> Tensor + + ``self.int()`` is equivalent to ``self.to(torch.int32)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def int_repr(self) -> Tensor: + r""" + int_repr() -> Tensor + + Given a quantized Tensor, + ``self.int_repr()`` returns a CPU Tensor with uint8_t as data type that stores the + underlying uint8_t values of the given Tensor. + """ + + def inverse(self) -> Tensor: + r""" + inverse() -> Tensor + + See :func:`torch.inverse` + """ + + def is_coalesced(self) -> _bool: + r""" + is_coalesced() -> bool + + Returns ``True`` if :attr:`self` is a :ref:`sparse COO tensor + ` that is coalesced, ``False`` otherwise. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See :meth:`coalesce` and :ref:`uncoalesced tensors `. + """ + + def is_complex(self) -> _bool: + r""" + is_complex() -> bool + + Returns True if the data type of :attr:`self` is a complex data type. + """ + + def is_conj(self) -> _bool: + r""" + is_conj() -> bool + + Returns True if the conjugate bit of :attr:`self` is set to true. + """ + + def is_contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> _bool: + r""" + is_contiguous(memory_format=torch.contiguous_format) -> bool + + Returns True if :attr:`self` tensor is contiguous in memory in the order specified + by memory format. + + Args: + memory_format (:class:`torch.memory_format`, optional): Specifies memory allocation + order. Default: ``torch.contiguous_format``. + """ + is_cpu: _bool + r"""Is ``True`` if the Tensor is stored on the CPU, ``False`` otherwise.""" + is_cuda: _bool + r"""Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise.""" + + def is_distributed(self) -> _bool: ... + def is_floating_point(self) -> _bool: + r""" + is_floating_point() -> bool + + Returns True if the data type of :attr:`self` is a floating point data type. + """ + + def is_inference(self) -> _bool: + r""" + is_inference() -> bool + + See :func:`torch.is_inference` + """ + is_ipu: _bool + r"""Is ``True`` if the Tensor is stored on the IPU, ``False`` otherwise.""" + is_leaf: _bool + r"""All Tensors that have :attr:`requires_grad` which is ``False`` will be leaf Tensors by convention. + + For Tensors that have :attr:`requires_grad` which is ``True``, they will be leaf Tensors if they were + created by the user. This means that they are not the result of an operation and so + :attr:`grad_fn` is None. + + Only leaf Tensors will have their :attr:`grad` populated during a call to :func:`backward`. + To get :attr:`grad` populated for non-leaf Tensors, you can use :func:`retain_grad`. 
+ + Example:: + + >>> a = torch.rand(10, requires_grad=True) + >>> a.is_leaf + True + >>> b = torch.rand(10, requires_grad=True).cuda() + >>> b.is_leaf + False + # b was created by the operation that cast a cpu Tensor into a cuda Tensor + >>> c = torch.rand(10, requires_grad=True) + 2 + >>> c.is_leaf + False + # c was created by the addition operation + >>> d = torch.rand(10).cuda() + >>> d.is_leaf + True + # d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + >>> e = torch.rand(10).cuda().requires_grad_() + >>> e.is_leaf + True + # e requires gradients and has no operations creating it + >>> f = torch.rand(10, requires_grad=True, device="cuda") + >>> f.is_leaf + True + # f requires grad, has no operation creating it""" + is_maia: _bool + is_meta: _bool + r"""Is ``True`` if the Tensor is a meta tensor, ``False`` otherwise. Meta tensors + are like normal tensors, but they carry no data.""" + is_mkldnn: _bool + is_mps: _bool + r"""Is ``True`` if the Tensor is stored on the MPS device, ``False`` otherwise.""" + is_mtia: _bool + def is_neg(self) -> _bool: + r""" + is_neg() -> bool + + Returns True if the negative bit of :attr:`self` is set to true. + """ + is_nested: _bool + def is_nonzero(self) -> _bool: ... + def is_pinned(self, device: DeviceLikeType | None = None) -> _bool: + r""" + Returns true if this tensor resides in pinned memory. + By default, the device pinned memory on will be the current :ref:`accelerator`. + """ + is_quantized: _bool + r"""Is ``True`` if the Tensor is quantized, ``False`` otherwise.""" + + def is_same_size(self, other: Tensor) -> _bool: ... + def is_set_to(self, tensor: Tensor) -> _bool: + r""" + is_set_to(tensor) -> bool + + Returns True if both tensors are pointing to the exact same memory (same + storage, offset, size and stride). + """ + + def is_signed(self) -> _bool: + r""" + is_signed() -> bool + + Returns True if the data type of :attr:`self` is a signed data type. 
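+
+        Floating point and signed integer dtypes report ``True``; unsigned
+        integer dtypes report ``False``::
+
+            >>> torch.tensor([1, 2, 3]).is_signed()
+            True
+            >>> torch.tensor([1, 2, 3], dtype=torch.uint8).is_signed()
+            False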
+ """ + is_sparse: _bool + r"""Is ``True`` if the Tensor uses sparse COO storage layout, ``False`` otherwise.""" + is_sparse_csr: _bool + r"""Is ``True`` if the Tensor uses sparse CSR storage layout, ``False`` otherwise.""" + is_vulkan: _bool + is_xpu: _bool + r"""Is ``True`` if the Tensor is stored on the XPU, ``False`` otherwise.""" + + def isclose( + self, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, + ) -> Tensor: + r""" + isclose(other, rtol=1e-05, atol=1e-08, equal_nan=False) -> Tensor + + See :func:`torch.isclose` + """ + + def isfinite(self) -> Tensor: + r""" + isfinite() -> Tensor + + See :func:`torch.isfinite` + """ + + def isinf(self) -> Tensor: + r""" + isinf() -> Tensor + + See :func:`torch.isinf` + """ + + def isnan(self) -> Tensor: + r""" + isnan() -> Tensor + + See :func:`torch.isnan` + """ + + def isneginf(self) -> Tensor: + r""" + isneginf() -> Tensor + + See :func:`torch.isneginf` + """ + + def isposinf(self) -> Tensor: + r""" + isposinf() -> Tensor + + See :func:`torch.isposinf` + """ + + def isreal(self) -> Tensor: + r""" + isreal() -> Tensor + + See :func:`torch.isreal` + """ + + def istft( + self, + n_fft: _int, + hop_length: _int | None = None, + win_length: _int | None = None, + window: Tensor | None = None, + center: _bool = True, + normalized: _bool = False, + onesided: _bool | None = None, + length: _int | None = None, + return_complex: _bool = False, + ) -> Tensor: + r""" + istft(n_fft, hop_length=None, win_length=None, window=None, + center=True, normalized=False, onesided=True, length=None) -> Tensor + + See :func:`torch.istft` + """ + + def item(self) -> Number: + r""" + item() -> number + + Returns the value of this tensor as a standard Python number. This only works + for tensors with one element. For other cases, see :meth:`~Tensor.tolist`. + + This operation is not differentiable. + + Example:: + + >>> x = torch.tensor([1.0]) + >>> x.item() + 1.0 + """ + + def kron(self, other: Tensor) -> Tensor: + r""" + kron(other) -> Tensor + + See :func:`torch.kron` + """ + + @overload + def kthvalue( + self, + k: _int | SymInt, + dim: _int = -1, + keepdim: _bool = False, + ) -> torch.return_types.kthvalue: + r""" + kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.kthvalue` + """ + + @overload + def kthvalue( + self, + k: _int | SymInt, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.kthvalue: + r""" + kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.kthvalue` + """ + + def lcm(self, other: Tensor) -> Tensor: + r""" + lcm(other) -> Tensor + + See :func:`torch.lcm` + """ + + def lcm_(self, other: Tensor) -> Tensor: + r""" + lcm_(other) -> Tensor + + In-place version of :meth:`~Tensor.lcm` + """ + + def ldexp(self, other: Tensor) -> Tensor: + r""" + ldexp(other) -> Tensor + + See :func:`torch.ldexp` + """ + + def ldexp_(self, other: Tensor) -> Tensor: + r""" + ldexp_(other) -> Tensor + + In-place version of :meth:`~Tensor.ldexp` + """ + + @overload + def le(self, other: Tensor) -> Tensor: + r""" + le(other) -> Tensor + + See :func:`torch.le`. + """ + + @overload + def le(self, other: Number | _complex) -> Tensor: + r""" + le(other) -> Tensor + + See :func:`torch.le`. + """ + + @overload + def le_(self, other: Tensor) -> Tensor: + r""" + le_(other) -> Tensor + + In-place version of :meth:`~Tensor.le`. 
+ """ + + @overload + def le_(self, other: Number | _complex) -> Tensor: + r""" + le_(other) -> Tensor + + In-place version of :meth:`~Tensor.le`. + """ + + @overload + def lerp(self, end: Tensor, weight: Tensor) -> Tensor: + r""" + lerp(end, weight) -> Tensor + + See :func:`torch.lerp` + """ + + @overload + def lerp(self, end: Tensor, weight: Number | _complex) -> Tensor: + r""" + lerp(end, weight) -> Tensor + + See :func:`torch.lerp` + """ + + @overload + def lerp_(self, end: Tensor, weight: Tensor) -> Tensor: + r""" + lerp_(end, weight) -> Tensor + + In-place version of :meth:`~Tensor.lerp` + """ + + @overload + def lerp_(self, end: Tensor, weight: Number | _complex) -> Tensor: + r""" + lerp_(end, weight) -> Tensor + + In-place version of :meth:`~Tensor.lerp` + """ + + @overload + def less(self, other: Tensor) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.less`. + """ + + @overload + def less(self, other: Number | _complex) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.less`. + """ + + @overload + def less_(self, other: Tensor) -> Tensor: + r""" + less_(other) -> Tensor + + In-place version of :meth:`~Tensor.less`. + """ + + @overload + def less_(self, other: Number | _complex) -> Tensor: + r""" + less_(other) -> Tensor + + In-place version of :meth:`~Tensor.less`. + """ + + @overload + def less_equal(self, other: Tensor) -> Tensor: + r""" + less_equal(other) -> Tensor + + See :func:`torch.less_equal`. + """ + + @overload + def less_equal(self, other: Number | _complex) -> Tensor: + r""" + less_equal(other) -> Tensor + + See :func:`torch.less_equal`. + """ + + @overload + def less_equal_(self, other: Tensor) -> Tensor: + r""" + less_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.less_equal`. + """ + + @overload + def less_equal_(self, other: Number | _complex) -> Tensor: + r""" + less_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.less_equal`. + """ + + def lgamma(self) -> Tensor: + r""" + lgamma() -> Tensor + + See :func:`torch.lgamma` + """ + + def lgamma_(self) -> Tensor: + r""" + lgamma_() -> Tensor + + In-place version of :meth:`~Tensor.lgamma` + """ + + def log(self) -> Tensor: + r""" + log() -> Tensor + + See :func:`torch.log` + """ + + def log10(self) -> Tensor: + r""" + log10() -> Tensor + + See :func:`torch.log10` + """ + + def log10_(self) -> Tensor: + r""" + log10_() -> Tensor + + In-place version of :meth:`~Tensor.log10` + """ + + def log1p(self) -> Tensor: + r""" + log1p() -> Tensor + + See :func:`torch.log1p` + """ + + def log1p_(self) -> Tensor: + r""" + log1p_() -> Tensor + + In-place version of :meth:`~Tensor.log1p` + """ + + def log2(self) -> Tensor: + r""" + log2() -> Tensor + + See :func:`torch.log2` + """ + + def log2_(self) -> Tensor: + r""" + log2_() -> Tensor + + In-place version of :meth:`~Tensor.log2` + """ + + def log_(self) -> Tensor: + r""" + log_() -> Tensor + + In-place version of :meth:`~Tensor.log` + """ + + def log_normal_( + self, + mean: _float = 1, + std: _float = 2, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + log_normal_(mean=1, std=2, *, generator=None) + + Fills :attr:`self` tensor with numbers samples from the log-normal distribution + parameterized by the given mean :math:`\mu` and standard deviation + :math:`\sigma`. Note that :attr:`mean` and :attr:`std` are the mean and + standard deviation of the underlying normal distribution, and not of the + returned distribution: + + .. 
math:: + + f(x) = \dfrac{1}{x \sigma \sqrt{2\pi}}\ e^{-\frac{(\ln x - \mu)^2}{2\sigma^2}} + """ + + @overload + def log_softmax(self, dim: _int, dtype: _dtype | None = None) -> Tensor: ... + @overload + def log_softmax( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: ... + def logaddexp(self, other: Tensor) -> Tensor: + r""" + logaddexp(other) -> Tensor + + See :func:`torch.logaddexp` + """ + + def logaddexp2(self, other: Tensor) -> Tensor: + r""" + logaddexp2(other) -> Tensor + + See :func:`torch.logaddexp2` + """ + + @overload + def logcumsumexp(self, dim: _int) -> Tensor: + r""" + logcumsumexp(dim) -> Tensor + + See :func:`torch.logcumsumexp` + """ + + @overload + def logcumsumexp(self, dim: str | EllipsisType | None) -> Tensor: + r""" + logcumsumexp(dim) -> Tensor + + See :func:`torch.logcumsumexp` + """ + + def logdet(self) -> Tensor: + r""" + logdet() -> Tensor + + See :func:`torch.logdet` + """ + + def logical_and(self, other: Tensor) -> Tensor: + r""" + logical_and() -> Tensor + + See :func:`torch.logical_and` + """ + + def logical_and_(self, other: Tensor) -> Tensor: + r""" + logical_and_() -> Tensor + + In-place version of :meth:`~Tensor.logical_and` + """ + + def logical_not(self) -> Tensor: + r""" + logical_not() -> Tensor + + See :func:`torch.logical_not` + """ + + def logical_not_(self) -> Tensor: + r""" + logical_not_() -> Tensor + + In-place version of :meth:`~Tensor.logical_not` + """ + + def logical_or(self, other: Tensor) -> Tensor: + r""" + logical_or() -> Tensor + + See :func:`torch.logical_or` + """ + + def logical_or_(self, other: Tensor) -> Tensor: + r""" + logical_or_() -> Tensor + + In-place version of :meth:`~Tensor.logical_or` + """ + + def logical_xor(self, other: Tensor) -> Tensor: + r""" + logical_xor() -> Tensor + + See :func:`torch.logical_xor` + """ + + def logical_xor_(self, other: Tensor) -> Tensor: + r""" + logical_xor_() -> Tensor + + In-place version of :meth:`~Tensor.logical_xor` + """ + + def logit(self, eps: _float | None = None) -> Tensor: + r""" + logit() -> Tensor + + See :func:`torch.logit` + """ + + def logit_(self, eps: _float | None = None) -> Tensor: + r""" + logit_() -> Tensor + + In-place version of :meth:`~Tensor.logit` + """ + + @overload + def logsumexp(self, dim: _int | _size, keepdim: _bool = False) -> Tensor: + r""" + logsumexp(dim, keepdim=False) -> Tensor + + See :func:`torch.logsumexp` + """ + + @overload + def logsumexp( + self, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + ) -> Tensor: + r""" + logsumexp(dim, keepdim=False) -> Tensor + + See :func:`torch.logsumexp` + """ + + def long(self) -> Tensor: + r""" + long(memory_format=torch.preserve_format) -> Tensor + + ``self.long()`` is equivalent to ``self.to(torch.int64)``. See :func:`to`. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + @overload + def lt(self, other: Tensor) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.lt`. + """ + + @overload + def lt(self, other: Number | _complex) -> Tensor: + r""" + lt(other) -> Tensor + + See :func:`torch.lt`. + """ + + @overload + def lt_(self, other: Tensor) -> Tensor: + r""" + lt_(other) -> Tensor + + In-place version of :meth:`~Tensor.lt`. + """ + + @overload + def lt_(self, other: Number | _complex) -> Tensor: + r""" + lt_(other) -> Tensor + + In-place version of :meth:`~Tensor.lt`. 
+ """ + + def lu_solve(self, LU_data: Tensor, LU_pivots: Tensor) -> Tensor: + r""" + lu_solve(LU_data, LU_pivots) -> Tensor + + See :func:`torch.lu_solve` + """ + + def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ... + def map_(self, other: Tensor, callable: Callable) -> Tensor: + r""" + map_(tensor, callable) + + Applies :attr:`callable` for each element in :attr:`self` tensor and the given + :attr:`tensor` and stores the results in :attr:`self` tensor. :attr:`self` tensor and + the given :attr:`tensor` must be :ref:`broadcastable `. + + The :attr:`callable` should have the signature:: + + def callable(a, b) -> number + """ + + @overload + def masked_fill(self, mask: Tensor, value: Tensor) -> Tensor: + r""" + masked_fill(mask, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_fill_` + """ + + @overload + def masked_fill(self, mask: Tensor, value: Number | _complex) -> Tensor: + r""" + masked_fill(mask, value) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_fill_` + """ + + @overload + def masked_fill_(self, mask: Tensor, value: Tensor) -> Tensor: + r""" + masked_fill_(mask, value) + + Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is + True. The shape of :attr:`mask` must be + :ref:`broadcastable ` with the shape of the underlying + tensor. + + Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with + """ + + @overload + def masked_fill_(self, mask: Tensor, value: Number | _complex) -> Tensor: + r""" + masked_fill_(mask, value) + + Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is + True. The shape of :attr:`mask` must be + :ref:`broadcastable ` with the shape of the underlying + tensor. + + Args: + mask (BoolTensor): the boolean mask + value (float): the value to fill in with + """ + + def masked_scatter(self, mask: Tensor, source: Tensor) -> Tensor: + r""" + masked_scatter(mask, tensor) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.masked_scatter_` + + .. note:: + + The inputs :attr:`self` and :attr:`mask` + :ref:`broadcast `. + + Example: + + >>> self = torch.tensor([0, 0, 0, 0, 0]) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... ) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + """ + + def masked_scatter_(self, mask: Tensor, source: Tensor) -> Tensor: + r""" + masked_scatter_(mask, source) + + Copies elements from :attr:`source` into :attr:`self` tensor at positions where + the :attr:`mask` is True. Elements from :attr:`source` are copied into :attr:`self` + starting at position 0 of :attr:`source` and continuing in order one-by-one for each + occurrence of :attr:`mask` being True. + The shape of :attr:`mask` must be :ref:`broadcastable ` + with the shape of the underlying tensor. The :attr:`source` should have at least + as many elements as the number of ones in :attr:`mask`. + + Args: + mask (BoolTensor): the boolean mask + source (Tensor): the tensor to copy from + + .. note:: + + The :attr:`mask` operates on the :attr:`self` tensor, not on the given + :attr:`source` tensor. + + Example: + + >>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) + >>> mask = torch.tensor( + ... [[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], + ... dtype=torch.bool, + ... 
) + >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + >>> self.masked_scatter_(mask, source) + tensor([[0, 0, 0, 0, 1], + [2, 3, 0, 4, 5]]) + """ + + def masked_select(self, mask: Tensor) -> Tensor: + r""" + masked_select(mask) -> Tensor + + See :func:`torch.masked_select` + """ + + def matmul(self, other: Tensor) -> Tensor: + r""" + matmul(tensor2) -> Tensor + + See :func:`torch.matmul` + """ + + def matrix_exp(self) -> Tensor: + r""" + matrix_exp() -> Tensor + + See :func:`torch.matrix_exp` + """ + + def matrix_power(self, n: _int) -> Tensor: + r""" + matrix_power(n) -> Tensor + + .. note:: :meth:`~Tensor.matrix_power` is deprecated, use :func:`torch.linalg.matrix_power` instead. + + Alias for :func:`torch.linalg.matrix_power` + """ + + @overload + def max(self) -> Tensor: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + + @overload + def max(self, other: Tensor) -> Tensor: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + + @overload + def max(self, dim: _int, keepdim: _bool = False) -> torch.return_types.max: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + + @overload + def max( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.max: + r""" + max(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.max` + """ + + def maximum(self, other: Tensor) -> Tensor: + r""" + maximum(other) -> Tensor + + See :func:`torch.maximum` + """ + + @overload + def mean(self, *, dtype: _dtype | None = None) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + + @overload + def mean( + self, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + + @overload + def mean( + self, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + mean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.mean` + """ + + @overload + def median(self) -> Tensor: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + + @overload + def median( + self, + dim: _int, + keepdim: _bool = False, + ) -> torch.return_types.median: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + + @overload + def median( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.median: + r""" + median(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.median` + """ + + @overload + def min(self) -> Tensor: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + + @overload + def min(self, other: Tensor) -> Tensor: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + + @overload + def min(self, dim: _int, keepdim: _bool = False) -> torch.return_types.min: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + + @overload + def min( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.min: + r""" + min(dim=None, keepdim=False) -> Tensor or (Tensor, Tensor) + + See :func:`torch.min` + """ + + def 
minimum(self, other: Tensor) -> Tensor: + r""" + minimum(other) -> Tensor + + See :func:`torch.minimum` + """ + + def mm(self, mat2: Tensor) -> Tensor: + r""" + mm(mat2) -> Tensor + + See :func:`torch.mm` + """ + + @overload + def mode( + self, + dim: _int = -1, + keepdim: _bool = False, + ) -> torch.return_types.mode: + r""" + mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.mode` + """ + + @overload + def mode( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.mode: + r""" + mode(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.mode` + """ + + @overload + def moveaxis(self, source: _int, destination: _int) -> Tensor: + r""" + moveaxis(source, destination) -> Tensor + + See :func:`torch.moveaxis` + """ + + @overload + def moveaxis(self, source: _size, destination: _size) -> Tensor: + r""" + moveaxis(source, destination) -> Tensor + + See :func:`torch.moveaxis` + """ + + @overload + def movedim(self, source: _int, destination: _int) -> Tensor: + r""" + movedim(source, destination) -> Tensor + + See :func:`torch.movedim` + """ + + @overload + def movedim(self, source: _size, destination: _size) -> Tensor: + r""" + movedim(source, destination) -> Tensor + + See :func:`torch.movedim` + """ + + def msort(self) -> Tensor: + r""" + msort() -> Tensor + + See :func:`torch.msort` + """ + + def mul( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + *, + out: Tensor | None = None, + ) -> Tensor: + r""" + mul(value) -> Tensor + + See :func:`torch.mul`. + """ + + def mul_( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + ) -> Tensor: + r""" + mul_(value) -> Tensor + + In-place version of :meth:`~Tensor.mul`. + """ + + def multinomial( + self, + num_samples: _int | SymInt, + replacement: _bool = False, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + multinomial(num_samples, replacement=False, *, generator=None) -> Tensor + + See :func:`torch.multinomial` + """ + + @overload + def multiply(self, other: Tensor) -> Tensor: + r""" + multiply(value) -> Tensor + + See :func:`torch.multiply`. + """ + + @overload + def multiply(self, other: Number | _complex) -> Tensor: + r""" + multiply(value) -> Tensor + + See :func:`torch.multiply`. + """ + + @overload + def multiply_(self, other: Tensor) -> Tensor: + r""" + multiply_(value) -> Tensor + + In-place version of :meth:`~Tensor.multiply`. + """ + + @overload + def multiply_(self, other: Number | _complex) -> Tensor: + r""" + multiply_(value) -> Tensor + + In-place version of :meth:`~Tensor.multiply`. + """ + + def mv(self, vec: Tensor) -> Tensor: + r""" + mv(vec) -> Tensor + + See :func:`torch.mv` + """ + + def mvlgamma(self, p: _int) -> Tensor: + r""" + mvlgamma(p) -> Tensor + + See :func:`torch.mvlgamma` + """ + + def mvlgamma_(self, p: _int) -> Tensor: + r""" + mvlgamma_(p) -> Tensor + + In-place version of :meth:`~Tensor.mvlgamma` + """ + + def nan_to_num( + self, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, + ) -> Tensor: + r""" + nan_to_num(nan=0.0, posinf=None, neginf=None) -> Tensor + + See :func:`torch.nan_to_num`. + """ + + def nan_to_num_( + self, + nan: _float | None = None, + posinf: _float | None = None, + neginf: _float | None = None, + ) -> Tensor: + r""" + nan_to_num_(nan=0.0, posinf=None, neginf=None) -> Tensor + + In-place version of :meth:`~Tensor.nan_to_num`. 
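+
+        With explicit replacement values the substitution is easy to follow::
+
+            >>> x = torch.tensor([float('nan'), float('inf'), -float('inf'), 2.0])
+            >>> x.nan_to_num_(nan=0.0, posinf=1.0, neginf=-1.0)
+            tensor([ 0.,  1., -1.,  2.])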
+ """ + + def nanmean( + self, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + nanmean(dim=None, keepdim=False, *, dtype=None) -> Tensor + + See :func:`torch.nanmean` + """ + + @overload + def nanmedian(self) -> Tensor: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + + @overload + def nanmedian( + self, + dim: _int, + keepdim: _bool = False, + ) -> torch.return_types.nanmedian: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + + @overload + def nanmedian( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + ) -> torch.return_types.nanmedian: + r""" + nanmedian(dim=None, keepdim=False) -> (Tensor, LongTensor) + + See :func:`torch.nanmedian` + """ + + @overload + def nanquantile( + self, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: + r""" + nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.nanquantile` + """ + + @overload + def nanquantile( + self, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: + r""" + nanquantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.nanquantile` + """ + + def nansum( + self, + dim: _int | _size | None = None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + nansum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.nansum` + """ + + @overload + def narrow(self, dim: _int, start: Tensor, length: _int | SymInt) -> Tensor: + r""" + narrow(dimension, start, length) -> Tensor + + See :func:`torch.narrow`. + """ + + @overload + def narrow( + self, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, + ) -> Tensor: + r""" + narrow(dimension, start, length) -> Tensor + + See :func:`torch.narrow`. + """ + + def narrow_copy( + self, + dim: _int, + start: _int | SymInt, + length: _int | SymInt, + ) -> Tensor: + r""" + narrow_copy(dimension, start, length) -> Tensor + + See :func:`torch.narrow_copy`. + """ + + def ndimension(self) -> _int: + r""" + ndimension() -> int + + Alias for :meth:`~Tensor.dim()` + """ + + @overload + def ne(self, other: Tensor) -> Tensor: + r""" + ne(other) -> Tensor + + See :func:`torch.ne`. + """ + + @overload + def ne(self, other: Number | _complex) -> Tensor: + r""" + ne(other) -> Tensor + + See :func:`torch.ne`. + """ + + @overload + def ne_(self, other: Tensor) -> Tensor: + r""" + ne_(other) -> Tensor + + In-place version of :meth:`~Tensor.ne`. + """ + + @overload + def ne_(self, other: Number | _complex) -> Tensor: + r""" + ne_(other) -> Tensor + + In-place version of :meth:`~Tensor.ne`. + """ + + def neg(self) -> Tensor: + r""" + neg() -> Tensor + + See :func:`torch.neg` + """ + + def neg_(self) -> Tensor: + r""" + neg_() -> Tensor + + In-place version of :meth:`~Tensor.neg` + """ + + def negative(self) -> Tensor: + r""" + negative() -> Tensor + + See :func:`torch.negative` + """ + + def negative_(self) -> Tensor: + r""" + negative_() -> Tensor + + In-place version of :meth:`~Tensor.negative` + """ + + def nelement(self) -> _int: + r""" + nelement() -> int + + Alias for :meth:`~Tensor.numel` + """ + + @overload + def new(cls, *args: Any, device: DeviceLikeType | None = None) -> Self: ... + @overload + def new(cls, storage: Storage) -> Self: ... 
+ @overload + def new(cls, other: Tensor) -> Self: ... + @overload + def new(cls, size: _size, *, device: DeviceLikeType | None = None) -> Self: ... + @overload + def new_empty( + self, + size: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with uninitialized data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + + @overload + def new_empty( + self, + *size: _int | SymInt, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_empty(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with uninitialized data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty((2, 3)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + + def new_empty_strided( + self, + size: Sequence[_int | SymInt], + stride: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_empty_strided(size, stride, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` and strides :attr:`stride` filled with + uninitialized data. By default, the returned Tensor has the same + :class:`torch.dtype` and :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones(()) + >>> tensor.new_empty_strided((2, 3), (3, 1)) + tensor([[ 5.8182e-18, 4.5765e-41, -1.0545e+30], + [ 3.0949e-41, 4.4842e-44, 0.0000e+00]]) + """ + + def new_full( + self, + size: Sequence[_int | SymInt], + fill_value: Number | _complex, + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_full(size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with :attr:`fill_value`. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + fill_value (scalar): the number to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ + Example:: + + >>> tensor = torch.ones((2,), dtype=torch.float64) + >>> tensor.new_full((3, 4), 3.141592) + tensor([[ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416], + [ 3.1416, 3.1416, 3.1416, 3.1416]], dtype=torch.float64) + """ + + @overload + def new_ones( + self, + size: _size, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, + ) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + + @overload + def new_ones( + self, + size: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + + @overload + def new_ones( + self, + *size: _int | SymInt, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_ones(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``1``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.int32) + >>> tensor.new_ones((2, 3)) + tensor([[ 1, 1, 1], + [ 1, 1, 1]], dtype=torch.int32) + """ + + def new_tensor( + self, + data: Any, + dtype: _dtype | None = None, + device: DeviceLikeType | None = None, + requires_grad: _bool = False, + pin_memory: _bool = False, + ) -> Tensor: + r""" + new_tensor(data, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a new Tensor with :attr:`data` as the tensor data. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + .. warning:: + + :func:`new_tensor` always copies :attr:`data`. If you have a Tensor + ``data`` and want to avoid a copy, use :func:`torch.Tensor.requires_grad_` + or :func:`torch.Tensor.detach`. + If you have a numpy array and want to avoid a copy, use + :func:`torch.from_numpy`. + + .. warning:: + + When data is a tensor `x`, :func:`new_tensor()` reads out 'the data' from whatever it is passed, + and constructs a leaf variable. Therefore ``tensor.new_tensor(x)`` is equivalent to ``x.detach().clone()`` + and ``tensor.new_tensor(x, requires_grad=True)`` is equivalent to ``x.detach().clone().requires_grad_(True)``. + The equivalents using ``detach()`` and ``clone()`` are recommended. + + Args: + data (array_like): The returned Tensor copies :attr:`data`. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. 
+ pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.ones((2,), dtype=torch.int8) + >>> data = [[0, 1], [2, 3]] + >>> tensor.new_tensor(data) + tensor([[ 0, 1], + [ 2, 3]], dtype=torch.int8) + """ + + @overload + def new_zeros( + self, + size: Sequence[_int | SymInt], + *, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``0``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + + Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + """ + + @overload + def new_zeros( + self, + *size: _int | SymInt, + dtype: _dtype | None = None, + layout: _layout | None = None, + device: DeviceLikeType | None = None, + pin_memory: _bool | None = False, + requires_grad: _bool | None = False, + ) -> Tensor: + r""" + new_zeros(size, *, dtype=None, device=None, requires_grad=False, layout=torch.strided, pin_memory=False) -> Tensor + + + Returns a Tensor of size :attr:`size` filled with ``0``. + By default, the returned Tensor has the same :class:`torch.dtype` and + :class:`torch.device` as this tensor. + + Args: + size (int...): a list, tuple, or :class:`torch.Size` of integers defining the + shape of the output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired type of returned tensor. + Default: if None, same :class:`torch.dtype` as this tensor. + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if None, same :class:`torch.device` as this tensor. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. 
+ + Example:: + + >>> tensor = torch.tensor((), dtype=torch.float64) + >>> tensor.new_zeros((2, 3)) + tensor([[ 0., 0., 0.], + [ 0., 0., 0.]], dtype=torch.float64) + """ + + def nextafter(self, other: Tensor) -> Tensor: + r""" + nextafter(other) -> Tensor + See :func:`torch.nextafter` + """ + + def nextafter_(self, other: Tensor) -> Tensor: + r""" + nextafter_(other) -> Tensor + In-place version of :meth:`~Tensor.nextafter` + """ + + @overload + def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: + r""" + nonzero() -> LongTensor + + See :func:`torch.nonzero` + """ + + @overload + def nonzero(self, *, as_tuple: Literal[True]) -> tuple[Tensor, ...]: + r""" + nonzero() -> LongTensor + + See :func:`torch.nonzero` + """ + + def nonzero_static( + self, + *, + size: _int | SymInt, + fill_value: _int = -1, + ) -> Tensor: + r""" + nonzero_static(input, *, size, fill_value=-1) -> Tensor + + Returns a 2-D tensor where each row is the index for a non-zero value. + The returned Tensor has the same `torch.dtype` as `torch.nonzero()`. + + Args: + input (Tensor): the input tensor to count non-zero elements. + + Keyword args: + size (int): the size of non-zero elements expected to be included in the out + tensor. Pad the out tensor with `fill_value` if the `size` is larger + than total number of non-zero elements, truncate out tensor if `size` + is smaller. The size must be a non-negative integer. + fill_value (int): the value to fill the output tensor with when `size` is larger + than the total number of non-zero elements. Default is `-1` to represent + invalid index. + + Example: + + # Example 1: Padding + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 4 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([[ 0, 0], + [ 1, 0], + [ 1, 1], + [ -1, -1]], dtype=torch.int64) + + # Example 2: Truncating + >>> input_tensor = torch.tensor([[1, 0], [3, 2]]) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([[ 0, 0], + [ 1, 0]], dtype=torch.int64) + + # Example 3: 0 size + >>> input_tensor = torch.tensor([10]) + >>> static_size = 0 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([], size=(0, 1), dtype=torch.int64) + + # Example 4: 0 rank input + >>> input_tensor = torch.tensor(10) + >>> static_size = 2 + >>> t = torch.nonzero_static(input_tensor, size=static_size) + tensor([], size=(2, 0), dtype=torch.int64) + """ + + def normal_( + self, + mean: _float = 0, + std: _float = 1, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + normal_(mean=0, std=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with elements samples from the normal distribution + parameterized by :attr:`mean` and :attr:`std`. + """ + + @overload + def not_equal(self, other: Tensor) -> Tensor: + r""" + not_equal(other) -> Tensor + + See :func:`torch.not_equal`. + """ + + @overload + def not_equal(self, other: Number | _complex) -> Tensor: + r""" + not_equal(other) -> Tensor + + See :func:`torch.not_equal`. + """ + + @overload + def not_equal_(self, other: Tensor) -> Tensor: + r""" + not_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.not_equal`. + """ + + @overload + def not_equal_(self, other: Number | _complex) -> Tensor: + r""" + not_equal_(other) -> Tensor + + In-place version of :meth:`~Tensor.not_equal`. 
+ """ + + def numel(self) -> _int: + r""" + numel() -> int + + See :func:`torch.numel` + """ + + def numpy(self, *, force: _bool = False) -> numpy.ndarray: + r""" + numpy(*, force=False) -> numpy.ndarray + + Returns the tensor as a NumPy :class:`ndarray`. + + If :attr:`force` is ``False`` (the default), the conversion + is performed only if the tensor is on the CPU, does not require grad, + does not have its conjugate bit set, and is a dtype and layout that + NumPy supports. The returned ndarray and the tensor will share their + storage, so changes to the tensor will be reflected in the ndarray + and vice versa. + + If :attr:`force` is ``True`` this is equivalent to + calling ``t.detach().cpu().resolve_conj().resolve_neg().numpy()``. + If the tensor isn't on the CPU or the conjugate or negative bit is set, + the tensor won't share its storage with the returned ndarray. + Setting :attr:`force` to ``True`` can be a useful shorthand. + + Args: + force (bool): if ``True``, the ndarray may be a copy of the tensor + instead of always sharing memory, defaults to ``False``. + """ + + def orgqr(self, input2: Tensor) -> Tensor: + r""" + orgqr(input2) -> Tensor + + See :func:`torch.orgqr` + """ + + def ormqr( + self, + input2: Tensor, + input3: Tensor, + left: _bool = True, + transpose: _bool = False, + ) -> Tensor: + r""" + ormqr(input2, input3, left=True, transpose=False) -> Tensor + + See :func:`torch.ormqr` + """ + + def outer(self, vec2: Tensor) -> Tensor: + r""" + outer(vec2) -> Tensor + + See :func:`torch.outer`. + """ + + @overload + def permute(self, dims: _size) -> Tensor: + r""" + permute(*dims) -> Tensor + + See :func:`torch.permute` + """ + + @overload + def permute(self, *dims: _int) -> Tensor: + r""" + permute(*dims) -> Tensor + + See :func:`torch.permute` + """ + + def pin_memory(self, device: DeviceLikeType | None = None) -> Tensor: + r""" + pin_memory() -> Tensor + + Copies the tensor to pinned memory, if it's not already pinned. + By default, the device pinned memory on will be the current :ref:`accelerator`. + """ + + def pinverse(self, rcond: _float = 1e-15) -> Tensor: + r""" + pinverse() -> Tensor + + See :func:`torch.pinverse` + """ + + def polygamma(self, n: _int) -> Tensor: + r""" + polygamma(n) -> Tensor + + See :func:`torch.polygamma` + """ + + def polygamma_(self, n: _int) -> Tensor: + r""" + polygamma_(n) -> Tensor + + In-place version of :meth:`~Tensor.polygamma` + """ + + def positive(self) -> Tensor: + r""" + positive() -> Tensor + + See :func:`torch.positive` + """ + + @overload + def pow(self, exponent: Tensor) -> Tensor: + r""" + pow(exponent) -> Tensor + + See :func:`torch.pow` + """ + + @overload + def pow(self, exponent: Number | _complex) -> Tensor: + r""" + pow(exponent) -> Tensor + + See :func:`torch.pow` + """ + + @overload + def pow_(self, exponent: Tensor) -> Tensor: + r""" + pow_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.pow` + """ + + @overload + def pow_(self, exponent: Number | _complex) -> Tensor: + r""" + pow_(exponent) -> Tensor + + In-place version of :meth:`~Tensor.pow` + """ + + def prelu(self, weight: Tensor) -> Tensor: ... 
+ @overload + def prod(self, *, dtype: _dtype | None = None) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + + @overload + def prod( + self, + dim: _int, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + + @overload + def prod( + self, + dim: str | EllipsisType | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + prod(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.prod` + """ + + def put( + self, + index: Tensor, + source: Tensor, + accumulate: _bool = False, + ) -> Tensor: + r""" + put(input, index, source, accumulate=False) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.put_`. + `input` corresponds to `self` in :meth:`torch.Tensor.put_`. + """ + + def put_( + self, + index: Tensor, + source: Tensor, + accumulate: _bool = False, + ) -> Tensor: + r""" + put_(index, source, accumulate=False) -> Tensor + + Copies the elements from :attr:`source` into the positions specified by + :attr:`index`. For the purpose of indexing, the :attr:`self` tensor is treated as if + it were a 1-D tensor. + + :attr:`index` and :attr:`source` need to have the same number of elements, but not necessarily + the same shape. + + If :attr:`accumulate` is ``True``, the elements in :attr:`source` are added to + :attr:`self`. If accumulate is ``False``, the behavior is undefined if :attr:`index` + contain duplicate elements. + + Args: + index (LongTensor): the indices into self + source (Tensor): the tensor containing values to copy from + accumulate (bool): whether to accumulate into self + + Example:: + + >>> src = torch.tensor([[4, 3, 5], + ... [6, 7, 8]]) + >>> src.put_(torch.tensor([1, 3]), torch.tensor([9, 10])) + tensor([[ 4, 9, 5], + [ 10, 7, 8]]) + """ + + def q_per_channel_axis(self) -> _int: + r""" + q_per_channel_axis() -> int + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns the index of dimension on which per-channel quantization is applied. + """ + + def q_per_channel_scales(self) -> Tensor: + r""" + q_per_channel_scales() -> Tensor + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns a Tensor of scales of the underlying quantizer. It has the number of + elements that matches the corresponding dimensions (from q_per_channel_axis) of + the tensor. + """ + + def q_per_channel_zero_points(self) -> Tensor: + r""" + q_per_channel_zero_points() -> Tensor + + Given a Tensor quantized by linear (affine) per-channel quantization, + returns a tensor of zero_points of the underlying quantizer. It has the number of + elements that matches the corresponding dimensions (from q_per_channel_axis) of + the tensor. + """ + + def q_scale(self) -> _float: + r""" + q_scale() -> float + + Given a Tensor quantized by linear(affine) quantization, + returns the scale of the underlying quantizer(). + """ + + def q_zero_point(self) -> _int: + r""" + q_zero_point() -> int + + Given a Tensor quantized by linear(affine) quantization, + returns the zero_point of the underlying quantizer(). + """ + + def qr(self, some: _bool = True) -> torch.return_types.qr: + r""" + qr(some=True) -> (Tensor, Tensor) + + See :func:`torch.qr` + """ + + def qscheme(self) -> _qscheme: + r""" + qscheme() -> torch.qscheme + + Returns the quantization scheme of a given QTensor. 
+ """ + + @overload + def quantile( + self, + q: Tensor, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: + r""" + quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.quantile` + """ + + @overload + def quantile( + self, + q: _float, + dim: _int | None = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: + r""" + quantile(q, dim=None, keepdim=False, *, interpolation='linear') -> Tensor + + See :func:`torch.quantile` + """ + + def rad2deg(self) -> Tensor: + r""" + rad2deg() -> Tensor + + See :func:`torch.rad2deg` + """ + + def rad2deg_(self) -> Tensor: + r""" + rad2deg_() -> Tensor + + In-place version of :meth:`~Tensor.rad2deg` + """ + + @overload + def random_(self, *, generator: Generator | None = None) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + + @overload + def random_( + self, + from_: _int, + to: _int | None, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + + @overload + def random_( + self, + to: _int, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + random_(from=0, to=None, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the discrete uniform + distribution over ``[from, to - 1]``. If not specified, the values are usually + only bounded by :attr:`self` tensor's data type. However, for floating point + types, if unspecified, range will be ``[0, 2^mantissa]`` to ensure that every + value is representable. For example, `torch.tensor(1, dtype=torch.double).random_()` + will be uniform in ``[0, 2^53]``. + """ + + def ravel(self) -> Tensor: + r""" + ravel() -> Tensor + + see :func:`torch.ravel` + """ + + def reciprocal(self) -> Tensor: + r""" + reciprocal() -> Tensor + + See :func:`torch.reciprocal` + """ + + def reciprocal_(self) -> Tensor: + r""" + reciprocal_() -> Tensor + + In-place version of :meth:`~Tensor.reciprocal` + """ + + def record_stream(self, s: Stream) -> None: + r""" + record_stream(stream) + + Marks the tensor as having been used by this stream. When the tensor + is deallocated, ensure the tensor memory is not reused for another tensor + until all work queued on :attr:`stream` at the time of deallocation is + complete. + + .. note:: + + The caching allocator is aware of only the stream where a tensor was + allocated. Due to the awareness, it already correctly manages the life + cycle of tensors on only one stream. 
But if a tensor is used on a stream + different from the stream of origin, the allocator might reuse the memory + unexpectedly. Calling this method lets the allocator know which streams + have used the tensor. + + .. warning:: + + This method is most suitable for use cases where you are providing a + function that created a tensor on a side stream, and want users to be able + to make use of the tensor without having to think carefully about stream + safety when making use of them. These safety guarantees come at some + performance and predictability cost (analogous to the tradeoff between GC + and manual memory management), so if you are in a situation where + you manage the full lifetime of your tensors, you may consider instead + manually managing CUDA events so that calling this method is not necessary. + In particular, when you call this method, on later allocations the + allocator will poll the recorded stream to see if all operations have + completed yet; you can potentially race with side stream computation and + non-deterministically reuse or fail to reuse memory for an allocation. + + You can safely use tensors allocated on side streams without + :meth:`~Tensor.record_stream`; you must manually ensure that + any non-creation stream uses of a tensor are synced back to the creation + stream before you deallocate the tensor. As the CUDA caching allocator + guarantees that the memory will only be reused with the same creation stream, + this is sufficient to ensure that writes to future reallocations of the + memory will be delayed until non-creation stream uses are done. + (Counterintuitively, you may observe that on the CPU side we have already + reallocated the tensor, even though CUDA kernels on the old tensor are + still in progress. This is fine, because CUDA operations on the new + tensor will appropriately wait for the old operations to complete, as they + are all on the same stream.) + + Concretely, this looks like this:: + + with torch.cuda.stream(s0): + x = torch.zeros(N) + + s1.wait_stream(s0) + with torch.cuda.stream(s1): + y = some_comm_op(x) + + ... some compute on s0 ... + + # synchronize creation stream s0 to side stream s1 + # before deallocating x + s0.wait_stream(s1) + del x + + Note that some discretion is required when deciding when to perform + ``s0.wait_stream(s1)``. In particular, if we were to wait immediately + after ``some_comm_op``, there wouldn't be any point in having the side + stream; it would be equivalent to have run ``some_comm_op`` on ``s0``. + Instead, the synchronization must be placed at some appropriate, later + point in time where you expect the side stream ``s1`` to have finished + work. This location is typically identified via profiling, e.g., using + Chrome traces produced + :meth:`torch.autograd.profiler.profile.export_chrome_trace`. If you + place the wait too early, work on s0 will block until ``s1`` has finished, + preventing further overlapping of communication and computation. If you + place the wait too late, you will use more memory than is strictly + necessary (as you are keeping ``x`` live for longer.) For a concrete + example of how this guidance can be applied in practice, see this post: + `FSDP and CUDACachingAllocator + `_. + """ + + def refine_names( + self, + names: Sequence[str | EllipsisType | None], + ) -> Tensor: ... + def relu(self) -> Tensor: ... + def relu_(self) -> Tensor: ... 
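# --- Illustrative usage sketch (not part of the generated stub) ---
# A rough sketch of the record_stream() pattern described above, assuming a CUDA
# device is available; the shapes and the matmul stand in for "some_comm_op" and
# are illustrative only. record_stream() tells the caching allocator not to reuse
# x's memory until the side stream s1 has finished the work queued so far.
import torch

if torch.cuda.is_available():
    s1 = torch.cuda.Stream()
    x = torch.zeros(1024, 1024, device="cuda")    # allocated on the current (default) stream
    s1.wait_stream(torch.cuda.current_stream())   # side stream waits until x is ready
    with torch.cuda.stream(s1):
        y = x @ x                                 # side-stream use of x
        x.record_stream(s1)                       # defer reuse of x's memory until s1 is done
    del x                                         # safe even though s1 may still be running
    torch.cuda.current_stream().wait_stream(s1)   # sync before using y on the default stream
    print(y.sum())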
+ @overload + def remainder(self, other: Tensor) -> Tensor: + r""" + remainder(divisor) -> Tensor + + See :func:`torch.remainder` + """ + + @overload + def remainder(self, other: Number | _complex) -> Tensor: + r""" + remainder(divisor) -> Tensor + + See :func:`torch.remainder` + """ + + @overload + def remainder_(self, other: Tensor) -> Tensor: + r""" + remainder_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.remainder` + """ + + @overload + def remainder_(self, other: Number | _complex) -> Tensor: + r""" + remainder_(divisor) -> Tensor + + In-place version of :meth:`~Tensor.remainder` + """ + + def rename( + self, + names: Sequence[str | EllipsisType | None] | None, + ) -> Tensor: ... + def rename_( + self, + names: Sequence[str | EllipsisType | None] | None, + ) -> Tensor: ... + def renorm( + self, + p: Number | _complex, + dim: _int, + maxnorm: Number | _complex, + ) -> Tensor: + r""" + renorm(p, dim, maxnorm) -> Tensor + + See :func:`torch.renorm` + """ + + def renorm_( + self, + p: Number | _complex, + dim: _int, + maxnorm: Number | _complex, + ) -> Tensor: + r""" + renorm_(p, dim, maxnorm) -> Tensor + + In-place version of :meth:`~Tensor.renorm` + """ + + @overload + def repeat(self, repeats: Sequence[_int | SymInt]) -> Tensor: + r""" + repeat(*repeats) -> Tensor + + Repeats this tensor along the specified dimensions. + + Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + + .. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. + + Args: + repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along each dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) + """ + + @overload + def repeat(self, *repeats: _int | SymInt) -> Tensor: + r""" + repeat(*repeats) -> Tensor + + Repeats this tensor along the specified dimensions. + + Unlike :meth:`~Tensor.expand`, this function copies the tensor's data. + + .. warning:: + + :meth:`~Tensor.repeat` behaves differently from + `numpy.repeat `_, + but is more similar to + `numpy.tile `_. + For the operator similar to `numpy.repeat`, see :func:`torch.repeat_interleave`. + + Args: + repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along each dimension + + Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> x.repeat(4, 2) + tensor([[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]]) + >>> x.repeat(4, 2, 1).size() + torch.Size([4, 2, 3]) + """ + + @overload + def repeat_interleave( + self, + repeats: Tensor, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, + ) -> Tensor: + r""" + repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + + See :func:`torch.repeat_interleave`. + """ + + @overload + def repeat_interleave( + self, + repeats: _int | SymInt, + dim: _int | None = None, + *, + output_size: _int | SymInt | None = None, + ) -> Tensor: + r""" + repeat_interleave(repeats, dim=None, *, output_size=None) -> Tensor + + See :func:`torch.repeat_interleave`. 
+ """ + + def requires_grad_(self, mode: _bool = True) -> Tensor: + r""" + requires_grad_(requires_grad=True) -> Tensor + + Change if autograd should record operations on this tensor: sets this tensor's + :attr:`requires_grad` attribute in-place. Returns this tensor. + + :func:`requires_grad_`'s main use case is to tell autograd to begin recording + operations on a Tensor ``tensor``. If ``tensor`` has ``requires_grad=False`` + (because it was obtained through a DataLoader, or required preprocessing or + initialization), ``tensor.requires_grad_()`` makes it so that autograd will + begin to record operations on ``tensor``. + + Args: + requires_grad (bool): If autograd should record operations on this tensor. + Default: ``True``. + + Example:: + + >>> # Let's say we want to preprocess some saved weights and use + >>> # the result as new weights. + >>> saved_weights = [0.1, 0.2, 0.3, 0.25] + >>> loaded_weights = torch.tensor(saved_weights) + >>> weights = preprocess(loaded_weights) # some function + >>> weights + tensor([-0.5503, 0.4926, -2.1158, -0.8303]) + + >>> # Now, start to record operations done to weights + >>> weights.requires_grad_() + >>> out = weights.pow(2).sum() + >>> out.backward() + >>> weights.grad + tensor([-1.1007, 0.9853, -4.2316, -1.6606]) + """ + + @overload + def reshape(self, shape: Sequence[_int | SymInt]) -> Tensor: + r""" + reshape(*shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`self` + but with the specified shape. This method returns a view if :attr:`shape` is + compatible with the current shape. See :meth:`torch.Tensor.view` on when it is + possible to return a view. + + See :func:`torch.reshape` + + Args: + shape (tuple of ints or int...): the desired shape + """ + + @overload + def reshape(self, *shape: _int | SymInt) -> Tensor: + r""" + reshape(*shape) -> Tensor + + Returns a tensor with the same data and number of elements as :attr:`self` + but with the specified shape. This method returns a view if :attr:`shape` is + compatible with the current shape. See :meth:`torch.Tensor.view` on when it is + possible to return a view. + + See :func:`torch.reshape` + + Args: + shape (tuple of ints or int...): the desired shape + """ + + def reshape_as(self, other: Tensor) -> Tensor: + r""" + reshape_as(other) -> Tensor + + Returns this tensor as the same shape as :attr:`other`. + ``self.reshape_as(other)`` is equivalent to ``self.reshape(other.sizes())``. + This method returns a view if ``other.sizes()`` is compatible with the current + shape. See :meth:`torch.Tensor.view` on when it is possible to return a view. + + Please see :meth:`reshape` for more information about ``reshape``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same shape + as :attr:`other`. + """ + + @overload + def resize_( + self, + size: Sequence[_int | SymInt], + *, + memory_format: memory_format | None = None, + ) -> Tensor: + r""" + resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + + Resizes :attr:`self` tensor to the specified size. If the number of elements is + larger than the current storage size, then the underlying storage is resized + to fit the new number of elements. If the number of elements is smaller, the + underlying storage is not changed. Existing elements are preserved but any new + memory is uninitialized. + + .. warning:: + + This is a low-level method. 
The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + + .. note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + + Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + + Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) + """ + + @overload + def resize_( + self, + *size: _int | SymInt, + memory_format: memory_format | None = None, + ) -> Tensor: + r""" + resize_(*sizes, memory_format=torch.contiguous_format) -> Tensor + + Resizes :attr:`self` tensor to the specified size. If the number of elements is + larger than the current storage size, then the underlying storage is resized + to fit the new number of elements. If the number of elements is smaller, the + underlying storage is not changed. Existing elements are preserved but any new + memory is uninitialized. + + .. warning:: + + This is a low-level method. The storage is reinterpreted as C-contiguous, + ignoring the current strides (unless the target size equals the current + size, in which case the tensor is left unchanged). For most purposes, you + will instead want to use :meth:`~Tensor.view()`, which checks for + contiguity, or :meth:`~Tensor.reshape()`, which copies data if needed. To + change the size in-place with custom strides, see :meth:`~Tensor.set_()`. + + .. note:: + + If :func:`torch.use_deterministic_algorithms()` and + :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to + ``True``, new elements are initialized to prevent nondeterministic behavior + from using the result as an input to an operation. Floating point and + complex values are set to NaN, and integer values are set to the maximum + value. + + Args: + sizes (torch.Size or int...): the desired size + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``sizes``. + + Example:: + + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> x.resize_(2, 2) + tensor([[ 1, 2], + [ 3, 4]]) + """ + + def resize_as_( + self, + the_template: Tensor, + *, + memory_format: memory_format | None = None, + ) -> Tensor: + r""" + resize_as_(tensor, memory_format=torch.contiguous_format) -> Tensor + + Resizes the :attr:`self` tensor to be the same size as the specified + :attr:`tensor`. This is equivalent to ``self.resize_(tensor.size())``. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + Tensor. Default: ``torch.contiguous_format``. 
Note that memory format of + :attr:`self` is going to be unaffected if ``self.size()`` matches ``tensor.size()``. + """ + + def resize_as_sparse_(self, the_template: Tensor) -> Tensor: ... + def resolve_conj(self) -> Tensor: + r""" + resolve_conj() -> Tensor + + See :func:`torch.resolve_conj` + """ + + def resolve_neg(self) -> Tensor: + r""" + resolve_neg() -> Tensor + + See :func:`torch.resolve_neg` + """ + + def retain_grad(self) -> None: + r""" + retain_grad() -> None + + Enables this Tensor to have their :attr:`grad` populated during + :func:`backward`. This is a no-op for leaf tensors. + """ + + def roll( + self, + shifts: _int | SymInt | Sequence[_int | SymInt], + dims: _int | _size = (), + ) -> Tensor: + r""" + roll(shifts, dims) -> Tensor + + See :func:`torch.roll` + """ + + def rot90(self, k: _int = 1, dims: _size = (0, 1)) -> Tensor: + r""" + rot90(k, dims) -> Tensor + + See :func:`torch.rot90` + """ + + @overload + def round(self) -> Tensor: + r""" + round(decimals=0) -> Tensor + + See :func:`torch.round` + """ + + @overload + def round(self, *, decimals: _int) -> Tensor: + r""" + round(decimals=0) -> Tensor + + See :func:`torch.round` + """ + + @overload + def round_(self) -> Tensor: + r""" + round_(decimals=0) -> Tensor + + In-place version of :meth:`~Tensor.round` + """ + + @overload + def round_(self, *, decimals: _int) -> Tensor: + r""" + round_(decimals=0) -> Tensor + + In-place version of :meth:`~Tensor.round` + """ + + def row_indices(self) -> Tensor: ... + def rsqrt(self) -> Tensor: + r""" + rsqrt() -> Tensor + + See :func:`torch.rsqrt` + """ + + def rsqrt_(self) -> Tensor: + r""" + rsqrt_() -> Tensor + + In-place version of :meth:`~Tensor.rsqrt` + """ + + @overload + def scatter(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter( + self, + dim: _int, + index: Tensor, + src: Tensor, + *, + reduce: str, + ) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + reduce: str, + ) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter( + self, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, + ) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter( + self, + dim: str | EllipsisType | None, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + scatter(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_` + """ + + @overload + def scatter_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. 
+ + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. 
function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + + @overload + def scatter_( + self, + dim: _int, + index: Tensor, + src: Tensor, + *, + reduce: str, + ) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. 
Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + + @overload + def scatter_( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + *, + reduce: str, + ) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. 
warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + + @overload + def scatter_( + self, + dim: _int, + index: Tensor, + value: Number | _complex, + ) -> Tensor: + r""" + scatter_(dim, index, src, *, reduce=None) -> Tensor + + Writes all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor. For each value in :attr:`src`, its output + index is specified by its index in :attr:`src` for ``dimension != dim`` and by + the corresponding value in :attr:`index` for ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + This is the reverse operation of the manner described in :meth:`~Tensor.gather`. + + :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be + between ``0`` and ``self.size(dim) - 1`` inclusive. + + .. warning:: + + When indices are not unique, the behavior is non-deterministic (one of the + values from ``src`` will be picked arbitrarily) and the gradient will be + incorrect (it will be propagated to all locations in the source that + correspond to the same index)! + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + Additionally accepts an optional :attr:`reduce` argument that allows + specification of an optional reduction operation, which is applied to all + values in the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index`. For each value in :attr:`src`, the reduction + operation is applied to an index in :attr:`self` which is specified by + its index in :attr:`src` for ``dimension != dim`` and by the corresponding + value in :attr:`index` for ``dimension = dim``. + + Given a 3-D tensor and reduction using the multiplication operation, :attr:`self` + is updated as:: + + self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2 + + Reducing with the addition operation is the same as using + :meth:`~torch.Tensor.scatter_add_`. + + .. warning:: + The reduce argument with Tensor ``src`` is deprecated and will be removed in + a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_` + instead for more reduction options. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + src (Tensor): the source element(s) to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. 
+ + Example:: + + >>> src = torch.arange(1, 11).reshape((2, 5)) + >>> src + tensor([[ 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10]]) + >>> index = torch.tensor([[0, 1, 2, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) + tensor([[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]) + >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) + tensor([[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]) + + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='multiply') + tensor([[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]) + >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), + ... 1.23, reduce='add') + tensor([[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + .. function:: scatter_(dim, index, value, *, reduce=None) -> Tensor: + :noindex: + + Writes the value from :attr:`value` into :attr:`self` at the indices + specified in the :attr:`index` tensor. This operation is equivalent to the previous version, + with the :attr:`src` tensor filled entirely with :attr:`value`. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter, can be either empty + or of the same dimensionality as ``src``. When empty, the operation + returns ``self`` unchanged. + value (Scalar): the value to scatter. + + Keyword args: + reduce (str, optional): reduction operation to apply, can be either + ``'add'`` or ``'multiply'``. + + Example:: + + >>> index = torch.tensor([[0, 1]]) + >>> value = 2 + >>> torch.zeros(3, 5).scatter_(0, index, value) + tensor([[2., 0., 0., 0., 0.], + [0., 2., 0., 0., 0.], + [0., 0., 0., 0., 0.]]) + """ + + @overload + def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + + @overload + def scatter_add( + self, + dim: str | EllipsisType | None, + index: Tensor, + src: Tensor, + ) -> Tensor: + r""" + scatter_add(dim, index, src) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_add_` + """ + + def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: + r""" + scatter_add_(dim, index, src) -> Tensor + + Adds all values from the tensor :attr:`src` into :attr:`self` at the indices + specified in the :attr:`index` tensor in a similar fashion as + :meth:`~torch.Tensor.scatter_`. For each value in :attr:`src`, it is added to + an index in :attr:`self` which is specified by its index in :attr:`src` + for ``dimension != dim`` and by the corresponding value in :attr:`index` for + ``dimension = dim``. + + For a 3-D tensor, :attr:`self` is updated as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + + :attr:`self`, :attr:`index` and :attr:`src` should have same number of + dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all + dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions + ``d != dim``. Note that ``index`` and ``src`` do not broadcast. + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. 
+ + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and add, can be + either empty or of the same dimensionality as ``src``. When empty, the + operation returns ``self`` unchanged. + src (Tensor): the source elements to scatter and add + + Example:: + + >>> src = torch.ones((2, 5)) + >>> index = torch.tensor([[0, 1, 2, 0, 0]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[1., 0., 0., 1., 1.], + [0., 1., 0., 0., 0.], + [0., 0., 1., 0., 0.]]) + >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) + >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) + tensor([[2., 0., 0., 1., 1.], + [0., 2., 0., 0., 0.], + [0., 0., 2., 1., 1.]]) + """ + + def scatter_reduce( + self, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: + r""" + scatter_reduce(dim, index, src, reduce, *, include_self=True) -> Tensor + + Out-of-place version of :meth:`torch.Tensor.scatter_reduce_` + """ + + def scatter_reduce_( + self, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: + r""" + scatter_reduce_(dim, index, src, reduce, *, include_self=True) -> Tensor + + Reduces all values from the :attr:`src` tensor to the indices specified in + the :attr:`index` tensor in the :attr:`self` tensor using the applied reduction + defined via the :attr:`reduce` argument (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, + :obj:`"amax"`, :obj:`"amin"`). For each value in :attr:`src`, it is reduced to an + index in :attr:`self` which is specified by its index in :attr:`src` for + ``dimension != dim`` and by the corresponding value in :attr:`index` for + ``dimension = dim``. If :obj:`include_self="True"`, the values in the :attr:`self` + tensor are included in the reduction. + + :attr:`self`, :attr:`index` and :attr:`src` should all have + the same number of dimensions. It is also required that + ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that + ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. + Note that ``index`` and ``src`` do not broadcast. + + For a 3-D tensor with :obj:`reduce="sum"` and :obj:`include_self=True` the + output is given as:: + + self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + + Note: + This operation may behave nondeterministically when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information. + + .. note:: + + The backward pass is implemented only for ``src.shape == index.shape``. + + .. warning:: + + This function is in beta and may change in the near future. + + Args: + dim (int): the axis along which to index + index (LongTensor): the indices of elements to scatter and reduce. 
+ src (Tensor): the source elements to scatter and reduce + reduce (str): the reduction operation to apply for non-unique indices + (:obj:`"sum"`, :obj:`"prod"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`) + include_self (bool): whether elements from the :attr:`self` tensor are + included in the reduction + + Example:: + + >>> src = torch.tensor([1., 2., 3., 4., 5., 6.]) + >>> index = torch.tensor([0, 1, 0, 1, 2, 1]) + >>> input = torch.tensor([1., 2., 3., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum") + tensor([5., 14., 8., 4.]) + >>> input.scatter_reduce(0, index, src, reduce="sum", include_self=False) + tensor([4., 12., 5., 4.]) + >>> input2 = torch.tensor([5., 4., 3., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax") + tensor([5., 6., 5., 2.]) + >>> input2.scatter_reduce(0, index, src, reduce="amax", include_self=False) + tensor([3., 6., 5., 2.]) + """ + + @overload + def select(self, dim: _int, index: _int | SymInt) -> Tensor: + r""" + select(dim, index) -> Tensor + + See :func:`torch.select` + """ + + @overload + def select(self, dim: str | EllipsisType | None, index: _int) -> Tensor: + r""" + select(dim, index) -> Tensor + + See :func:`torch.select` + """ + + def select_scatter( + self, + src: Tensor, + dim: _int, + index: _int | SymInt, + ) -> Tensor: + r""" + select_scatter(src, dim, index) -> Tensor + + See :func:`torch.select_scatter` + """ + + @overload + def set_( + self, + source: Storage | TypedStorage | UntypedStorage, + storage_offset: IntLikeType, + size: _symsize, + stride: _symsize, + ) -> Tensor: + r""" + set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + + Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, + :attr:`self` tensor will share the same storage and have the same size and + strides as :attr:`source`. Changes to elements in one tensor will be reflected + in the other. + + If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying + storage, offset, size, and stride. + + Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. + """ + + @overload + def set_(self, source: Storage | TypedStorage | UntypedStorage) -> Tensor: + r""" + set_(source=None, storage_offset=0, size=None, stride=None) -> Tensor + + Sets the underlying storage, size, and strides. If :attr:`source` is a tensor, + :attr:`self` tensor will share the same storage and have the same size and + strides as :attr:`source`. Changes to elements in one tensor will be reflected + in the other. + + If :attr:`source` is a :class:`~torch.Storage`, the method sets the underlying + storage, offset, size, and stride. + + Args: + source (Tensor or Storage): the tensor or storage to use + storage_offset (int, optional): the offset in the storage + size (torch.Size, optional): the desired size. Defaults to the size of the source. + stride (tuple, optional): the desired stride. Defaults to C-contiguous strides. + """ + + def sgn(self) -> Tensor: + r""" + sgn() -> Tensor + + See :func:`torch.sgn` + """ + + def sgn_(self) -> Tensor: + r""" + sgn_() -> Tensor + + In-place version of :meth:`~Tensor.sgn` + """ + + def short(self) -> Tensor: + r""" + short(memory_format=torch.preserve_format) -> Tensor + + ``self.short()`` is equivalent to ``self.to(torch.int16)``. See :func:`to`. 
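+
+        A minimal illustrative example (not from the generated docstring; the fractional part is truncated on conversion, and the printed values assume this small input)::
+
+            >>> t = torch.tensor([1.7, -2.3])
+            >>> t.short()
+            tensor([ 1, -2], dtype=torch.int16)
+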
+ + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def sigmoid(self) -> Tensor: + r""" + sigmoid() -> Tensor + + See :func:`torch.sigmoid` + """ + + def sigmoid_(self) -> Tensor: + r""" + sigmoid_() -> Tensor + + In-place version of :meth:`~Tensor.sigmoid` + """ + + def sign(self) -> Tensor: + r""" + sign() -> Tensor + + See :func:`torch.sign` + """ + + def sign_(self) -> Tensor: + r""" + sign_() -> Tensor + + In-place version of :meth:`~Tensor.sign` + """ + + def signbit(self) -> Tensor: + r""" + signbit() -> Tensor + + See :func:`torch.signbit` + """ + + def sin(self) -> Tensor: + r""" + sin() -> Tensor + + See :func:`torch.sin` + """ + + def sin_(self) -> Tensor: + r""" + sin_() -> Tensor + + In-place version of :meth:`~Tensor.sin` + """ + + def sinc(self) -> Tensor: + r""" + sinc() -> Tensor + + See :func:`torch.sinc` + """ + + def sinc_(self) -> Tensor: + r""" + sinc_() -> Tensor + + In-place version of :meth:`~Tensor.sinc` + """ + + def sinh(self) -> Tensor: + r""" + sinh() -> Tensor + + See :func:`torch.sinh` + """ + + def sinh_(self) -> Tensor: + r""" + sinh_() -> Tensor + + In-place version of :meth:`~Tensor.sinh` + """ + + @overload + def size(self, dim: None = None) -> Size: + r""" + size(dim=None) -> torch.Size or int + + Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, + the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. + If ``dim`` is specified, returns an int holding the size of that dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + + Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + """ + + @overload + def size(self, dim: _int) -> _int: + r""" + size(dim=None) -> torch.Size or int + + Returns the size of the :attr:`self` tensor. If ``dim`` is not specified, + the returned value is a :class:`torch.Size`, a subclass of :class:`tuple`. + If ``dim`` is specified, returns an int holding the size of that dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + + Example:: + + >>> t = torch.empty(3, 4, 5) + >>> t.size() + torch.Size([3, 4, 5]) + >>> t.size(dim=1) + 4 + """ + + def slice_inverse( + self, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + ) -> Tensor: ... + def slice_scatter( + self, + src: Tensor, + dim: _int = 0, + start: _int | SymInt | None = None, + end: _int | SymInt | None = None, + step: _int | SymInt = 1, + ) -> Tensor: + r""" + slice_scatter(src, dim=0, start=None, end=None, step=1) -> Tensor + + See :func:`torch.slice_scatter` + """ + + def slogdet(self) -> torch.return_types.slogdet: + r""" + slogdet() -> (Tensor, Tensor) + + See :func:`torch.slogdet` + """ + + def smm(self, mat2: Tensor) -> Tensor: + r""" + smm(mat) -> Tensor + + See :func:`torch.smm` + """ + + @overload + def softmax(self, dim: _int, dtype: _dtype | None = None) -> Tensor: + r""" + softmax(dim) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. + """ + + @overload + def softmax( + self, + dim: str | EllipsisType | None, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + softmax(dim) -> Tensor + + Alias for :func:`torch.nn.functional.softmax`. 
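+
+        A minimal illustrative example (added here as a sketch; printed values are rounded)::
+
+            >>> x = torch.tensor([[1.0, 2.0, 3.0]])
+            >>> x.softmax(dim=1)
+            tensor([[0.0900, 0.2447, 0.6652]])
+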
+ """ + + @overload + def sort( + self, + *, + stable: _bool | None, + dim: _int = -1, + descending: _bool = False, + ) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + + @overload + def sort( + self, + dim: _int = -1, + descending: _bool = False, + ) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + + @overload + def sort( + self, + *, + stable: _bool | None, + dim: str | EllipsisType | None, + descending: _bool = False, + ) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + + @overload + def sort( + self, + dim: str | EllipsisType | None, + descending: _bool = False, + ) -> torch.return_types.sort: + r""" + sort(dim=-1, descending=False) -> (Tensor, LongTensor) + + See :func:`torch.sort` + """ + + def sparse_dim(self) -> _int: + r""" + sparse_dim() -> int + + Return the number of sparse dimensions in a :ref:`sparse tensor ` :attr:`self`. + + .. note:: + Returns ``0`` if :attr:`self` is not a sparse tensor. + + See also :meth:`Tensor.dense_dim` and :ref:`hybrid tensors `. + """ + + def sparse_mask(self, mask: Tensor) -> Tensor: + r""" + sparse_mask(mask) -> Tensor + + Returns a new :ref:`sparse tensor ` with values from a + strided tensor :attr:`self` filtered by the indices of the sparse + tensor :attr:`mask`. The values of :attr:`mask` sparse tensor are + ignored. :attr:`self` and :attr:`mask` tensors must have the same + shape. + + .. note:: + + The returned sparse tensor might contain duplicate values if :attr:`mask` + is not coalesced. It is therefore advisable to pass ``mask.coalesce()`` + if such behavior is not desired. + + .. note:: + + The returned sparse tensor has the same indices as the sparse tensor + :attr:`mask`, even when the corresponding values in :attr:`self` are + zeros. + + Args: + mask (Tensor): a sparse tensor whose indices are used as a filter + + Example:: + + >>> nse = 5 + >>> dims = (5, 5, 2, 2) + >>> I = torch.cat([torch.randint(0, dims[0], size=(nse,)), + ... torch.randint(0, dims[1], size=(nse,))], 0).reshape(2, nse) + >>> V = torch.randn(nse, dims[2], dims[3]) + >>> S = torch.sparse_coo_tensor(I, V, dims).coalesce() + >>> D = torch.randn(dims) + >>> D.sparse_mask(S) + tensor(indices=tensor([[0, 0, 0, 2], + [0, 1, 4, 3]]), + values=tensor([[[ 1.6550, 0.2397], + [-0.1611, -0.0779]], + + [[ 0.2326, -1.0558], + [ 1.4711, 1.9678]], + + [[-0.5138, -0.0411], + [ 1.9417, 0.5158]], + + [[ 0.0793, 0.0036], + [-0.2569, -0.1055]]]), + size=(5, 5, 2, 2), nnz=4, layout=torch.sparse_coo) + """ + + def sparse_resize_( + self, + size: _size, + sparse_dim: _int, + dense_dim: _int, + ) -> Tensor: + r""" + sparse_resize_(size, sparse_dim, dense_dim) -> Tensor + + Resizes :attr:`self` :ref:`sparse tensor ` to the desired + size and the number of sparse and dense dimensions. + + .. note:: + If the number of specified elements in :attr:`self` is zero, then + :attr:`size`, :attr:`sparse_dim`, and :attr:`dense_dim` can be any + size and positive integers such that ``len(size) == sparse_dim + + dense_dim``. + + If :attr:`self` specifies one or more elements, however, then each + dimension in :attr:`size` must not be smaller than the corresponding + dimension of :attr:`self`, :attr:`sparse_dim` must equal the number + of sparse dimensions in :attr:`self`, and :attr:`dense_dim` must + equal the number of dense dimensions in :attr:`self`. + + .. 
warning:: + Throws an error if :attr:`self` is not a sparse tensor. + + Args: + size (torch.Size): the desired size. If :attr:`self` is non-empty + sparse tensor, the desired size cannot be smaller than the + original size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions + """ + + def sparse_resize_and_clear_( + self, + size: _size, + sparse_dim: _int, + dense_dim: _int, + ) -> Tensor: + r""" + sparse_resize_and_clear_(size, sparse_dim, dense_dim) -> Tensor + + Removes all specified elements from a :ref:`sparse tensor + ` :attr:`self` and resizes :attr:`self` to the desired + size and the number of sparse and dense dimensions. + + .. warning: + Throws an error if :attr:`self` is not a sparse tensor. + + Args: + size (torch.Size): the desired size. + sparse_dim (int): the number of sparse dimensions + dense_dim (int): the number of dense dimensions + """ + + @overload + def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ... + @overload + def split( + self, + split_size: tuple[_int, ...], + dim: _int = 0, + ) -> Sequence[Tensor]: ... + def split_with_sizes( + self, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, + ) -> tuple[Tensor, ...]: ... + def sqrt(self) -> Tensor: + r""" + sqrt() -> Tensor + + See :func:`torch.sqrt` + """ + + def sqrt_(self) -> Tensor: + r""" + sqrt_() -> Tensor + + In-place version of :meth:`~Tensor.sqrt` + """ + + def square(self) -> Tensor: + r""" + square() -> Tensor + + See :func:`torch.square` + """ + + def square_(self) -> Tensor: + r""" + square_() -> Tensor + + In-place version of :meth:`~Tensor.square` + """ + + @overload + def squeeze(self) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + + @overload + def squeeze(self, dim: _int) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + + @overload + def squeeze(self, dim: _size) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + + @overload + def squeeze(self, *dim: _int) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + + @overload + def squeeze(self, dim: str | EllipsisType | None) -> Tensor: + r""" + squeeze(dim=None) -> Tensor + + See :func:`torch.squeeze` + """ + + @overload + def squeeze_(self) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + + @overload + def squeeze_(self, dim: _int) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + + @overload + def squeeze_(self, dim: _size) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + + @overload + def squeeze_(self, *dim: _int) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + + @overload + def squeeze_(self, dim: str | EllipsisType | None) -> Tensor: + r""" + squeeze_(dim=None) -> Tensor + + In-place version of :meth:`~Tensor.squeeze` + """ + + def sspaddmm( + self, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number | _complex = 1, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + sspaddmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor + + See :func:`torch.sspaddmm` + """ + + @overload + def std( + self, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + + @overload + def std( + 
self, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + ) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + + @overload + def std(self, unbiased: _bool = True) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + + @overload + def std( + self, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + + @overload + def std( + self, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + ) -> Tensor: + r""" + std(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.std` + """ + + def untyped_storage(self) -> UntypedStorage: ... + def storage_offset(self) -> _int | SymInt: + r""" + storage_offset() -> int + + Returns :attr:`self` tensor's offset in the underlying storage in terms of + number of storage elements (not bytes). + + Example:: + + >>> x = torch.tensor([1, 2, 3, 4, 5]) + >>> x.storage_offset() + 0 + >>> x[3:].storage_offset() + 3 + """ + + def storage_type(self) -> Storage: ... + @overload + def stride(self, dim: None = None) -> tuple[_int, ...]: + r""" + stride(dim) -> tuple or int + + Returns the stride of :attr:`self` tensor. + + Stride is the jump necessary to go from one element to the next one in the + specified dimension :attr:`dim`. A tuple of all strides is returned when no + argument is passed in. Otherwise, an integer value is returned as the stride in + the particular dimension :attr:`dim`. + + Args: + dim (int, optional): the desired dimension in which stride is required + + Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + """ + + @overload + def stride(self, dim: _int) -> _int: + r""" + stride(dim) -> tuple or int + + Returns the stride of :attr:`self` tensor. + + Stride is the jump necessary to go from one element to the next one in the + specified dimension :attr:`dim`. A tuple of all strides is returned when no + argument is passed in. Otherwise, an integer value is returned as the stride in + the particular dimension :attr:`dim`. + + Args: + dim (int, optional): the desired dimension in which stride is required + + Example:: + + >>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> x.stride() + (5, 1) + >>> x.stride(0) + 5 + >>> x.stride(-1) + 1 + """ + + def sub( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + *, + alpha: Number | _complex | None = 1, + out: Tensor | None = None, + ) -> Tensor: + r""" + sub(other, *, alpha=1) -> Tensor + + See :func:`torch.sub`. + """ + + def sub_( + self, + other: Tensor | Number | _complex | torch.SymInt | torch.SymFloat, + *, + alpha: Number | _complex | None = 1, + ) -> Tensor: + r""" + sub_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.sub` + """ + + @overload + def subtract( + self, + other: Tensor, + *, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + subtract(other, *, alpha=1) -> Tensor + + See :func:`torch.subtract`. + """ + + @overload + def subtract( + self, + other: Number | _complex, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + subtract(other, *, alpha=1) -> Tensor + + See :func:`torch.subtract`. 
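+
+        A minimal illustrative example (added as a sketch; ``subtract`` computes ``self - alpha * other``)::
+
+            >>> a = torch.tensor([4., 6.])
+            >>> a.subtract(2, alpha=2)
+            tensor([0., 2.])
+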
+ """ + + @overload + def subtract_( + self, + other: Tensor, + *, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + subtract_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.subtract`. + """ + + @overload + def subtract_( + self, + other: Number | _complex, + alpha: Number | _complex = 1, + ) -> Tensor: + r""" + subtract_(other, *, alpha=1) -> Tensor + + In-place version of :meth:`~Tensor.subtract`. + """ + + @overload + def sum(self, *, dtype: _dtype | None = None) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + + @overload + def sum( + self, + dim: _int | _size | None, + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + + @overload + def sum( + self, + dim: Sequence[str | EllipsisType | None], + keepdim: _bool = False, + *, + dtype: _dtype | None = None, + ) -> Tensor: + r""" + sum(dim=None, keepdim=False, dtype=None) -> Tensor + + See :func:`torch.sum` + """ + + @overload + def sum_to_size(self, size: Sequence[_int | SymInt]) -> Tensor: + r""" + sum_to_size(*size) -> Tensor + + Sum ``this`` tensor to :attr:`size`. + :attr:`size` must be broadcastable to ``this`` tensor size. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + """ + + @overload + def sum_to_size(self, *size: _int | SymInt) -> Tensor: + r""" + sum_to_size(*size) -> Tensor + + Sum ``this`` tensor to :attr:`size`. + :attr:`size` must be broadcastable to ``this`` tensor size. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + """ + + def svd( + self, + some: _bool = True, + compute_uv: _bool = True, + ) -> torch.return_types.svd: + r""" + svd(some=True, compute_uv=True) -> (Tensor, Tensor, Tensor) + + See :func:`torch.svd` + """ + + def swapaxes(self, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes(axis0, axis1) -> Tensor + + See :func:`torch.swapaxes` + """ + + def swapaxes_(self, axis0: _int, axis1: _int) -> Tensor: + r""" + swapaxes_(axis0, axis1) -> Tensor + + In-place version of :meth:`~Tensor.swapaxes` + """ + + def swapdims(self, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims(dim0, dim1) -> Tensor + + See :func:`torch.swapdims` + """ + + def swapdims_(self, dim0: _int, dim1: _int) -> Tensor: + r""" + swapdims_(dim0, dim1) -> Tensor + + In-place version of :meth:`~Tensor.swapdims` + """ + + def t(self) -> Tensor: + r""" + t() -> Tensor + + See :func:`torch.t` + """ + + def t_(self) -> Tensor: + r""" + t_() -> Tensor + + In-place version of :meth:`~Tensor.t` + """ + + def take(self, index: Tensor) -> Tensor: + r""" + take(indices) -> Tensor + + See :func:`torch.take` + """ + + def take_along_dim( + self, + indices: Tensor, + dim: _int | None = None, + ) -> Tensor: + r""" + take_along_dim(indices, dim) -> Tensor + + See :func:`torch.take_along_dim` + """ + + def tan(self) -> Tensor: + r""" + tan() -> Tensor + + See :func:`torch.tan` + """ + + def tan_(self) -> Tensor: + r""" + tan_() -> Tensor + + In-place version of :meth:`~Tensor.tan` + """ + + def tanh(self) -> Tensor: + r""" + tanh() -> Tensor + + See :func:`torch.tanh` + """ + + def tanh_(self) -> Tensor: + r""" + tanh_() -> Tensor + + In-place version of :meth:`~Tensor.tanh` + """ + + @overload + def tensor_split( + self, + indices: Sequence[_int | SymInt], + dim: _int = 0, + ) -> tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + 
See :func:`torch.tensor_split` + """ + + @overload + def tensor_split( + self, + tensor_indices_or_sections: Tensor, + dim: _int = 0, + ) -> tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + See :func:`torch.tensor_split` + """ + + @overload + def tensor_split( + self, + sections: _int | SymInt, + dim: _int = 0, + ) -> tuple[Tensor, ...]: + r""" + tensor_split(indices_or_sections, dim=0) -> List of Tensors + + See :func:`torch.tensor_split` + """ + + @overload + def tile(self, dims: Sequence[_int | SymInt]) -> Tensor: + r""" + tile(dims) -> Tensor + + See :func:`torch.tile` + """ + + @overload + def tile(self, *dims: _int | SymInt) -> Tensor: + r""" + tile(dims) -> Tensor + + See :func:`torch.tile` + """ + + @overload + def to( + self, + dtype: _dtype, + non_blocking: _bool = False, + copy: _bool = False, + *, + memory_format: torch.memory_format | None = None, + ) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + .. note:: + + If ``self`` requires gradients (``requires_grad=True``) but the target + ``dtype`` specified is an integer type, the returned tensor will implicitly + set ``requires_grad=False``. This is because only tensors with + floating-point or complex dtypes can require gradients. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. note:: + + According to `C++ type conversion rules `_, + converting floating point value to integer type will truncate the fractional part. + If the truncated value cannot fit into the target type (e.g., casting ``torch.inf`` to ``torch.long``), + the behavior is undefined and the result may vary across platforms. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. 
+ When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + + @overload + def to( + self, + device: DeviceLikeType | None = None, + dtype: _dtype | None = None, + non_blocking: _bool = False, + copy: _bool = False, + *, + memory_format: torch.memory_format | None = None, + ) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + .. note:: + + If ``self`` requires gradients (``requires_grad=True``) but the target + ``dtype`` specified is an integer type, the returned tensor will implicitly + set ``requires_grad=False``. This is because only tensors with + floating-point or complex dtypes can require gradients. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. note:: + + According to `C++ type conversion rules `_, + converting floating point value to integer type will truncate the fractional part. + If the truncated value cannot fit into the target type (e.g., casting ``torch.inf`` to ``torch.long``), + the behavior is undefined and the result may vary across platforms. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. + When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. 
+ When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. + When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + + @overload + def to( + self, + other: Tensor, + non_blocking: _bool = False, + copy: _bool = False, + *, + memory_format: torch.memory_format | None = None, + ) -> Tensor: + r""" + to(*args, **kwargs) -> Tensor + + Performs Tensor dtype and/or device conversion. A :class:`torch.dtype` and :class:`torch.device` are + inferred from the arguments of ``self.to(*args, **kwargs)``. + + .. note:: + + If the ``self`` Tensor already + has the correct :class:`torch.dtype` and :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired + :class:`torch.dtype` and :class:`torch.device`. + + .. note:: + + If ``self`` requires gradients (``requires_grad=True``) but the target + ``dtype`` specified is an integer type, the returned tensor will implicitly + set ``requires_grad=False``. This is because only tensors with + floating-point or complex dtypes can require gradients. + + Here are the ways to call ``to``: + + .. method:: to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`dtype` + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. note:: + + According to `C++ type conversion rules `_, + converting floating point value to integer type will truncate the fractional part. + If the truncated value cannot fit into the target type (e.g., casting ``torch.inf`` to ``torch.long``), + the behavior is undefined and the result may vary across platforms. + + .. method:: to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) -> Tensor + :noindex: + + Returns a Tensor with the specified :attr:`device` and (optional) + :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``. 
+ When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Args: + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + + .. method:: to(other, non_blocking=False, copy=False) -> Tensor + :noindex: + + Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as + the Tensor :attr:`other`. + When :attr:`non_blocking` is set to ``True``, the function attempts to perform + the conversion asynchronously with respect to the host, if possible. This + asynchronous behavior applies to both pinned and pageable memory. However, + caution is advised when using this feature. For more information, refer to the + `tutorial on good usage of non_blocking and pin_memory `__. + When :attr:`copy` is set, a new Tensor is created even when the Tensor + already matches the desired conversion. + + Example:: + + >>> tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu + >>> tensor.to(torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64) + + >>> cuda0 = torch.device('cuda:0') + >>> tensor.to(cuda0) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], device='cuda:0') + + >>> tensor.to(cuda0, dtype=torch.float64) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + + >>> other = torch.randn((), dtype=torch.float64, device=cuda0) + >>> tensor.to(other, non_blocking=True) + tensor([[-0.5044, 0.0005], + [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0') + """ + + def to_dense( + self, + dtype: _dtype | None = None, + *, + masked_grad: _bool | None = None, + ) -> Tensor: + r""" + to_dense(dtype=None, *, masked_grad=True) -> Tensor + + Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`. + + Keyword args: + {dtype} + masked_grad (bool, optional): If set to ``True`` (default) and + :attr:`self` has a sparse layout then the backward of + :meth:`to_dense` returns ``grad.sparse_mask(self)``. + + Example:: + + >>> s = torch.sparse_coo_tensor( + ... torch.tensor([[1, 1], + ... [0, 2]]), + ... torch.tensor([9, 10]), + ... size=(3, 3)) + >>> s.to_dense() + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + """ + + def to_mkldnn(self, dtype: _dtype | None = None) -> Tensor: + r""" + to_mkldnn() -> Tensor + Returns a copy of the tensor in ``torch.mkldnn`` layout. + """ + + def to_padded_tensor( + self, + padding: _float, + output_size: Sequence[_int | SymInt] | None = None, + ) -> Tensor: + r""" + to_padded_tensor(padding, output_size=None) -> Tensor + See :func:`to_padded_tensor` + """ + + @overload + def to_sparse( + self, + *, + layout: _layout | None = None, + blocksize: _int | _size | None = None, + dense_dim: _int | None = None, + ) -> Tensor: + r""" + to_sparse(sparseDims) -> Tensor + + Returns a sparse copy of the tensor. PyTorch supports sparse tensors in + :ref:`coordinate format `. 
+ + Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + + Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + + .. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + + Returns a sparse tensor with the specified layout and blocksize. If + the :attr:`self` is strided, the number of dense dimensions could be + specified, and a hybrid sparse tensor will be created, with + `dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch + dimension. + + .. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + + Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + """ + + @overload + def to_sparse(self, sparse_dim: _int) -> Tensor: + r""" + to_sparse(sparseDims) -> Tensor + + Returns a sparse copy of the tensor. PyTorch supports sparse tensors in + :ref:`coordinate format `. 
+ + Args: + sparseDims (int, optional): the number of sparse dimensions to include in the new sparse tensor + + Example:: + + >>> d = torch.tensor([[0, 0, 0], [9, 0, 10], [0, 0, 0]]) + >>> d + tensor([[ 0, 0, 0], + [ 9, 0, 10], + [ 0, 0, 0]]) + >>> d.to_sparse() + tensor(indices=tensor([[1, 1], + [0, 2]]), + values=tensor([ 9, 10]), + size=(3, 3), nnz=2, layout=torch.sparse_coo) + >>> d.to_sparse(1) + tensor(indices=tensor([[1]]), + values=tensor([[ 9, 0, 10]]), + size=(3, 3), nnz=1, layout=torch.sparse_coo) + + .. method:: to_sparse(*, layout=None, blocksize=None, dense_dim=None) -> Tensor + :noindex: + + Returns a sparse tensor with the specified layout and blocksize. If + the :attr:`self` is strided, the number of dense dimensions could be + specified, and a hybrid sparse tensor will be created, with + `dense_dim` dense dimensions and `self.dim() - 2 - dense_dim` batch + dimension. + + .. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + + Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR, CSC, BSR or BSC tensor. This argument should be + used only if :attr:`self` is a strided tensor, and must be a + value between 0 and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize + + >>> x = torch.tensor([[[1], [0]], [[0], [0]], [[2], [3]]]) + >>> x.to_sparse(layout=torch.sparse_csr, dense_dim=1) + tensor(crow_indices=tensor([0, 1, 1, 3]), + col_indices=tensor([0, 0, 1]), + values=tensor([[1], + [2], + [3]]), size=(3, 2, 1), nnz=3, layout=torch.sparse_csr) + """ + + def to_sparse_bsc( + self, + blocksize: _int | _size, + dense_dim: _int | None = None, + ) -> Tensor: + r""" + to_sparse_bsc(blocksize, dense_dim) -> Tensor + + Convert a tensor to a block sparse column (BSC) storage format of + given blocksize. If the :attr:`self` is strided, then the number of + dense dimensions could be specified, and a hybrid BSC tensor will be + created, with `dense_dim` dense dimensions and `self.dim() - 2 - + dense_dim` batch dimension. + + Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSC tensor. 
A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsc = sparse.to_sparse_bsc((5, 5)) + >>> sparse_bsc.row_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsc((2, 1), 1) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 1, 0]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsc) + """ + + def to_sparse_bsr( + self, + blocksize: _int | _size, + dense_dim: _int | None = None, + ) -> Tensor: + r""" + to_sparse_bsr(blocksize, dense_dim) -> Tensor + + Convert a tensor to a block sparse row (BSR) storage format of given + blocksize. If the :attr:`self` is strided, then the number of dense + dimensions could be specified, and a hybrid BSR tensor will be + created, with `dense_dim` dense dimensions and `self.dim() - 2 - + dense_dim` batch dimension. + + Args: + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR tensor. A block size must be a tuple of + length two such that its items evenly divide the two sparse + dimensions. + + dense_dim (int, optional): Number of dense dimensions of the + resulting BSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(10, 10) + >>> sparse = dense.to_sparse_csr() + >>> sparse_bsr = sparse.to_sparse_bsr((5, 5)) + >>> sparse_bsr.col_indices() + tensor([0, 1, 0, 1]) + + >>> dense = torch.zeros(4, 3, 1) + >>> dense[0:2, 0] = dense[0:2, 2] = dense[2:4, 1] = 1 + >>> dense.to_sparse_bsr((2, 1), 1) + tensor(crow_indices=tensor([0, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]], + + + [[[1.]], + + [[1.]]]]), size=(4, 3, 1), nnz=3, + layout=torch.sparse_bsr) + """ + + def to_sparse_csc(self, dense_dim: _int | None = None) -> Tensor: + r""" + to_sparse_csc() -> Tensor + + Convert a tensor to compressed column storage (CSC) format. Except + for strided tensors, only works with 2D tensors. If the :attr:`self` + is strided, then the number of dense dimensions could be specified, + and a hybrid CSC tensor will be created, with `dense_dim` dense + dimensions and `self.dim() - 2 - dense_dim` batch dimension. + + Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSC tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. 
+ + Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csc() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csc(dense_dim=2) + tensor(ccol_indices=tensor([0, 1, 2, 3]), + row_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csc) + """ + + def to_sparse_csr(self, dense_dim: _int | None = None) -> Tensor: + r""" + to_sparse_csr(dense_dim=None) -> Tensor + + Convert a tensor to compressed row storage format (CSR). Except for + strided tensors, only works with 2D tensors. If the :attr:`self` is + strided, then the number of dense dimensions could be specified, and a + hybrid CSR tensor will be created, with `dense_dim` dense dimensions + and `self.dim() - 2 - dense_dim` batch dimension. + + Args: + + dense_dim (int, optional): Number of dense dimensions of the + resulting CSR tensor. This argument should be used only if + :attr:`self` is a strided tensor, and must be a value between 0 + and dimension of :attr:`self` tensor minus two. + + Example:: + + >>> dense = torch.randn(5, 5) + >>> sparse = dense.to_sparse_csr() + >>> sparse._nnz() + 25 + + >>> dense = torch.zeros(3, 3, 1, 1) + >>> dense[0, 0] = dense[1, 2] = dense[2, 1] = 1 + >>> dense.to_sparse_csr(dense_dim=2) + tensor(crow_indices=tensor([0, 1, 2, 3]), + col_indices=tensor([0, 2, 1]), + values=tensor([[[1.]], + + [[1.]], + + [[1.]]]), size=(3, 3, 1, 1), nnz=3, + layout=torch.sparse_csr) + """ + + def tolist(self) -> list: + r""" + tolist() -> list or number + + Returns the tensor as a (nested) list. For scalars, a standard + Python number is returned, just like with :meth:`~Tensor.item`. + Tensors are automatically moved to the CPU first if necessary. + + This operation is not differentiable. 
+ + Examples:: + + >>> a = torch.randn(2, 2) + >>> a.tolist() + [[0.012766935862600803, 0.5415473580360413], + [-0.08909505605697632, 0.7729271650314331]] + >>> a[0,0].tolist() + 0.012766935862600803 + """ + + def topk( + self, + k: _int | SymInt, + dim: _int = -1, + largest: _bool = True, + sorted: _bool = True, + ) -> torch.return_types.topk: + r""" + topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor) + + See :func:`torch.topk` + """ + + def trace(self) -> Tensor: + r""" + trace() -> Tensor + + See :func:`torch.trace` + """ + + @overload + def transpose(self, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose(dim0, dim1) -> Tensor + + See :func:`torch.transpose` + """ + + @overload + def transpose( + self, + dim0: str | EllipsisType | None, + dim1: str | EllipsisType | None, + ) -> Tensor: + r""" + transpose(dim0, dim1) -> Tensor + + See :func:`torch.transpose` + """ + + def transpose_(self, dim0: _int, dim1: _int) -> Tensor: + r""" + transpose_(dim0, dim1) -> Tensor + + In-place version of :meth:`~Tensor.transpose` + """ + + def triangular_solve( + self, + A: Tensor, + upper: _bool = True, + transpose: _bool = False, + unitriangular: _bool = False, + ) -> torch.return_types.triangular_solve: + r""" + triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + + See :func:`torch.triangular_solve` + """ + + def tril(self, diagonal: _int = 0) -> Tensor: + r""" + tril(diagonal=0) -> Tensor + + See :func:`torch.tril` + """ + + def tril_(self, diagonal: _int = 0) -> Tensor: + r""" + tril_(diagonal=0) -> Tensor + + In-place version of :meth:`~Tensor.tril` + """ + + def triu(self, diagonal: _int = 0) -> Tensor: + r""" + triu(diagonal=0) -> Tensor + + See :func:`torch.triu` + """ + + def triu_(self, diagonal: _int = 0) -> Tensor: + r""" + triu_(diagonal=0) -> Tensor + + In-place version of :meth:`~Tensor.triu` + """ + + def true_divide( + self, + other: Tensor | Number | torch.SymInt | torch.SymFloat, + *, + out: Tensor | None = None, + ) -> Tensor: + r""" + true_divide(value) -> Tensor + + See :func:`torch.true_divide` + """ + + def true_divide_( + self, + other: Tensor | Number | torch.SymInt | torch.SymFloat, + ) -> Tensor: + r""" + true_divide_(value) -> Tensor + + In-place version of :meth:`~Tensor.true_divide_` + """ + + def trunc(self) -> Tensor: + r""" + trunc() -> Tensor + + See :func:`torch.trunc` + """ + + def trunc_(self) -> Tensor: + r""" + trunc_() -> Tensor + + In-place version of :meth:`~Tensor.trunc` + """ + + @overload + def type(self, dtype: None = None, non_blocking: _bool = False) -> str: + r""" + type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor + Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. 
+ """ + + @overload + def type(self, dtype: str | _dtype, non_blocking: _bool = False) -> Tensor: + r""" + type(dtype=None, non_blocking=False, **kwargs) -> str or Tensor + Returns the type if `dtype` is not provided, else casts this object to + the specified type. + + If this is already of the correct type, no copy is performed and the + original object is returned. + + Args: + dtype (dtype or string): The desired type + non_blocking (bool): If ``True``, and the source is in pinned memory + and destination is on the GPU or vice versa, the copy is performed + asynchronously with respect to the host. Otherwise, the argument + has no effect. + **kwargs: For compatibility, may contain the key ``async`` in place of + the ``non_blocking`` argument. The ``async`` arg is deprecated. + """ + + def type_as(self, other: Tensor) -> Tensor: + r""" + type_as(tensor) -> Tensor + + Returns this tensor cast to the type of the given tensor. + + This is a no-op if the tensor is already of the correct type. This is + equivalent to ``self.type(tensor.type())`` + + Args: + tensor (Tensor): the tensor which has the desired type + """ + + @overload + def unbind(self, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + unbind(dim=0) -> seq + + See :func:`torch.unbind` + """ + + @overload + def unbind(self, dim: str | EllipsisType | None) -> tuple[Tensor, ...]: + r""" + unbind(dim=0) -> seq + + See :func:`torch.unbind` + """ + + @overload + def unflatten( + self, + dim: str | EllipsisType | None, + sizes: Sequence[_int | SymInt], + names: Sequence[str | EllipsisType | None], + ) -> Tensor: ... + @overload + def unflatten(self, dim: _int, sizes: Sequence[_int | SymInt]) -> Tensor: ... + def unfold(self, dimension: _int, size: _int, step: _int) -> Tensor: + r""" + unfold(dimension, size, step) -> Tensor + + Returns a view of the original tensor which contains all slices of size :attr:`size` from + :attr:`self` tensor in the dimension :attr:`dimension`. + + Step between two slices is given by :attr:`step`. + + If `sizedim` is the size of dimension :attr:`dimension` for :attr:`self`, the size of + dimension :attr:`dimension` in the returned tensor will be + `(sizedim - size) / step + 1`. + + An additional dimension of size :attr:`size` is appended in the returned tensor. + + Args: + dimension (int): dimension in which unfolding happens + size (int): the size of each slice that is unfolded + step (int): the step between each slice + + Example:: + + >>> x = torch.arange(1., 8) + >>> x + tensor([ 1., 2., 3., 4., 5., 6., 7.]) + >>> x.unfold(0, 2, 1) + tensor([[ 1., 2.], + [ 2., 3.], + [ 3., 4.], + [ 4., 5.], + [ 5., 6.], + [ 6., 7.]]) + >>> x.unfold(0, 2, 2) + tensor([[ 1., 2.], + [ 3., 4.], + [ 5., 6.]]) + """ + + def uniform_( + self, + from_: _float = 0, + to: _float = 1, + *, + generator: Generator | None = None, + ) -> Tensor: + r""" + uniform_(from=0, to=1, *, generator=None) -> Tensor + + Fills :attr:`self` tensor with numbers sampled from the continuous uniform + distribution: + + .. 
math:: + f(x) = \dfrac{1}{\text{to} - \text{from}} + """ + + def unsafe_chunk(self, chunks: _int, dim: _int = 0) -> tuple[Tensor, ...]: + r""" + unsafe_chunk(chunks, dim=0) -> List of Tensors + + See :func:`torch.unsafe_chunk` + """ + + def unsafe_split( + self, + split_size: _int | SymInt, + dim: _int = 0, + ) -> tuple[Tensor, ...]: + r""" + unsafe_split(split_size, dim=0) -> List of Tensors + + See :func:`torch.unsafe_split` + """ + + def unsafe_split_with_sizes( + self, + split_sizes: Sequence[_int | SymInt], + dim: _int = 0, + ) -> tuple[Tensor, ...]: ... + def unsqueeze(self, dim: _int) -> Tensor: + r""" + unsqueeze(dim) -> Tensor + + See :func:`torch.unsqueeze` + """ + + def unsqueeze_(self, dim: _int) -> Tensor: + r""" + unsqueeze_(dim) -> Tensor + + In-place version of :meth:`~Tensor.unsqueeze` + """ + + def values(self) -> Tensor: + r""" + values() -> Tensor + + Return the values tensor of a :ref:`sparse COO tensor `. + + .. warning:: + Throws an error if :attr:`self` is not a sparse COO tensor. + + See also :meth:`Tensor.indices`. + + .. note:: + This method can only be called on a coalesced sparse tensor. See + :meth:`Tensor.coalesce` for details. + """ + + @overload + def var( + self, + dim: _int | _size | None, + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + + @overload + def var( + self, + dim: _int | _size | None = None, + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + ) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + + @overload + def var(self, unbiased: _bool = True) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + + @overload + def var( + self, + dim: Sequence[str | EllipsisType | None], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + + @overload + def var( + self, + dim: Sequence[str | EllipsisType | None], + *, + correction: Number | _complex | None = None, + keepdim: _bool = False, + ) -> Tensor: + r""" + var(dim=None, *, correction=1, keepdim=False) -> Tensor + + See :func:`torch.var` + """ + + def vdot(self, other: Tensor) -> Tensor: + r""" + vdot(other) -> Tensor + + See :func:`torch.vdot` + """ + + @overload + def view(self, dtype: _dtype) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. 
+ + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. 
+ + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + + @overload + def view(self, size: Sequence[_int | SymInt]) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. 
For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. + + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + + @overload + def view(self, *size: _int | SymInt) -> Tensor: + r""" + view(*shape) -> Tensor + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. + + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. 
math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (torch.Size or int...): the desired size + + Example:: + + >>> x = torch.randn(4, 4) + >>> x.size() + torch.Size([4, 4]) + >>> y = x.view(16) + >>> y.size() + torch.Size([16]) + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size() + torch.Size([2, 8]) + + >>> a = torch.randn(1, 2, 3, 4) + >>> a.size() + torch.Size([1, 2, 3, 4]) + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size() + torch.Size([1, 3, 2, 4]) + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size() + torch.Size([1, 3, 2, 4]) + >>> torch.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. + + .. warning:: + + This overload is not supported by TorchScript, and using it in a Torchscript + program will cause undefined behavior. 
+ + + Args: + dtype (:class:`torch.dtype`): the desired dtype + + Example:: + + >>> x = torch.randn(4, 4) + >>> x + tensor([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + torch.float32 + + >>> y = x.view(torch.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=torch.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(torch.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(torch.cfloat).size() + torch.Size([4, 2]) + + >>> x.view(torch.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=torch.uint8) + >>> x.view(torch.uint8).size() + torch.Size([4, 16]) + """ + + def view_as(self, other: Tensor) -> Tensor: + r""" + view_as(other) -> Tensor + + View this tensor as the same size as :attr:`other`. + ``self.view_as(other)`` is equivalent to ``self.view(other.size())``. + + Please see :meth:`~Tensor.view` for more information about ``view``. + + Args: + other (:class:`torch.Tensor`): The result tensor has the same size + as :attr:`other`. + """ + + @overload + def vsplit(self, sections: _int) -> tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + + @overload + def vsplit(self, indices: _size) -> tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + + @overload + def vsplit(self, *indices: _int) -> tuple[Tensor, ...]: + r""" + vsplit(split_size_or_sections) -> List of Tensors + + See :func:`torch.vsplit` + """ + + @overload + def where(self, condition: Tensor, other: Tensor) -> Tensor: + r""" + where(condition, y) -> Tensor + + ``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. + See :func:`torch.where` + """ + + @overload + def where(self, condition: Tensor, other: Number | _complex) -> Tensor: + r""" + where(condition, y) -> Tensor + + ``self.where(condition, y)`` is equivalent to ``torch.where(condition, self, y)``. 
+ See :func:`torch.where` + """ + + @overload + def xlogy(self, other: Tensor) -> Tensor: + r""" + xlogy(other) -> Tensor + + See :func:`torch.xlogy` + """ + + @overload + def xlogy(self, other: Number | _complex) -> Tensor: + r""" + xlogy(other) -> Tensor + + See :func:`torch.xlogy` + """ + + @overload + def xlogy_(self, other: Tensor) -> Tensor: + r""" + xlogy_(other) -> Tensor + + In-place version of :meth:`~Tensor.xlogy` + """ + + @overload + def xlogy_(self, other: Number | _complex) -> Tensor: + r""" + xlogy_(other) -> Tensor + + In-place version of :meth:`~Tensor.xlogy` + """ + + def xpu( + self, + device: _device | _int | str | None = None, + non_blocking: _bool = False, + memory_format: torch.memory_format = torch.preserve_format, + ) -> Tensor: + r""" + xpu(device=None, non_blocking=False, memory_format=torch.preserve_format) -> Tensor + + Returns a copy of this object in XPU memory. + + If this object is already in XPU memory and on the correct device, + then no copy is performed and the original object is returned. + + Args: + device (:class:`torch.device`): The destination XPU device. + Defaults to the current XPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. + Otherwise, the argument has no effect. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.preserve_format``. + """ + + def zero_(self) -> Tensor: + r""" + zero_() -> Tensor + + Fills :attr:`self` tensor with zeros. + """ + +_TensorBase = TensorBase + +# Defined in torch/csrc/multiprocessing/init.cpp +def _multiprocessing_init() -> None: ... +def _set_thread_name(name: str) -> None: ... +def _get_thread_name() -> str: ... + +# Defined in torch/csrc/Module.cpp +def _accelerator_hooks_device_count() -> _int: ... +def _accelerator_hooks_set_current_device(device_index: _int) -> None: ... +def _accelerator_hooks_get_current_device() -> _int: ... +def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ... +def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ... +def _get_accelerator(check: _bool = False) -> _device: ... +def _storage_Use_Count(storage_ptr: _int) -> _int: ... + +# Defined in torch/csrc/mtia/Module.cpp +def _mtia_init() -> None: ... +def _mtia_isBuilt() -> _bool: ... +def _mtia_isInBadFork() -> _bool: ... +def _mtia_deviceSynchronize() -> None: ... +def _mtia_getCurrentStream(device: _int) -> Stream: ... +def _mtia_setCurrentStream(stream: Stream) -> None: ... +def _mtia_getDefaultStream(device: _int) -> Stream: ... +def _mtia_memoryStats(device: _int) -> dict[str, Any]: ... +def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ... +def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ... +def _mtia_emptyCache() -> None: ... +def _mtia_recordMemoryHistory( + enabled: str | None, + stacks: str, + max_entries, +) -> None: ... +def _mtia_memorySnapshot() -> dict[str, Any]: ... +def _mtia_attachOutOfMemoryObserver( + observer: Callable[[_int, _int, _int, _int], None], +) -> None: ... +def _mtia_getDeviceCount() -> _int: ... +def _mtia_resetPeakMemoryStats(device: _int) -> None: ... + +# Defined in torch/csrc/mps/Module.cpp +def _mps_deviceSynchronize() -> None: ... +def _mps_get_default_generator() -> Generator: ... +def _mps_emptyCache() -> None: ... +def _mps_setMemoryFraction(fraction: _float) -> None: ... +def _mps_currentAllocatedMemory() -> _int: ... 
+def _mps_driverAllocatedMemory() -> _int: ... +def _mps_recommendedMaxMemory() -> _int: ... +def _mps_is_available() -> _bool: ... +def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ... +def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ... +def _mps_profilerStopTrace() -> None: ... +def _mps_acquireEvent(enable_timing: _bool) -> _int: ... +def _mps_releaseEvent(event_id: _int) -> None: ... +def _mps_recordEvent(event_id: _int) -> None: ... +def _mps_waitForEvent(event_id: _int) -> None: ... +def _mps_synchronizeEvent(event_id: _int) -> None: ... +def _mps_queryEvent(event_id: _int) -> _bool: ... +def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ... +def _mps_isCaptureEnabled() -> _bool: ... +def _mps_isCapturing() -> _bool: ... +def _mps_startCapture(name: str) -> None: ... +def _mps_stopCapture() -> None: ... + +# Defined in torch/csrc/cuda/Module.cpp +def _cuda_getCurrentStream(device: _int) -> tuple: ... +def _cuda_getCurrentRawStream(device: _int) -> _int: ... +def _cuda_getDefaultStream(device: _int) -> tuple: ... +def _cuda_getStreamFromExternal(data_ptr: _int, device_index: _int) -> tuple: ... +def _cuda_getCurrentBlasHandle() -> _int: ... +def _cuda_clearCublasWorkspaces() -> None: ... +def _cuda_setDevice(device: _int) -> None: ... +def _cuda_exchangeDevice(device: _int) -> _int: ... +def _cuda_maybeExchangeDevice(device: _int) -> _int: ... +def _cuda_getDevice() -> _int: ... +def _cuda_getDeviceCount() -> _int: ... +def _cuda_set_sync_debug_mode(warn_level: _int | str) -> None: ... +def _cuda_get_sync_debug_mode() -> _int: ... +def _cuda_sleep(cycles: _int) -> None: ... +def _cuda_synchronize() -> None: ... +def _cuda_ipc_collect() -> None: ... +def _cuda_getArchFlags() -> str | None: ... +def _cuda_init() -> None: ... +def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +def _cuda_getCompiledVersion() -> _int: ... +def _cuda_cudaHostAllocator() -> _int: ... +def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... +def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... +def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ... +def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ... +def _cuda_beginAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ... +def _cuda_beginAllocateCurrentThreadToPool( + device: _int, + mempool_id: tuple[_int, _int], +) -> None: ... +def _cuda_endAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ... +def _cuda_beginAllocateCurrentStreamToPool( + device: _int, + mempool_id: tuple[_int, _int], +) -> None: ... +def _cuda_releasePool(device: _int, mempool_id: tuple[_int, _int]) -> None: ... +def _cuda_checkPoolLiveAllocations( + device: _int, + mempool_id: tuple[_int, _int], + expected_live_allocations: set, +) -> _bool: ... +def _cuda_setCheckpointPoolState( + device: _int, + state: _cuda_CUDAAllocator_AllocatorState, + stale_storages: list[_int], + storages_to_add_deleters_to: list[_int], +) -> None: ... +def _cuda_getMemoryFraction(device: _int) -> _float: ... +def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ... +def _cuda_emptyCache() -> None: ... +def _cuda_memoryStats(device: _int) -> dict[str, Any]: ... +def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ... +def _cuda_resetPeakMemoryStats(device: _int) -> None: ... +def _cuda_hostMemoryStats() -> dict[str, Any]: ... 
+def _cuda_resetAccumulatedHostMemoryStats() -> None: ... +def _cuda_resetPeakHostMemoryStats() -> None: ... +def _cuda_memorySnapshot(mempool_id: tuple[_int, _int] | None) -> dict[str, Any]: ... +def _cuda_record_memory_history_legacy( + enabled: _bool, + record_context: _bool, + record_context_cpp: _bool, + alloc_trace_max_entries: _int, + alloc_trace_record_context: _bool, + clear_history: _bool, + compile_context: _bool, +) -> None: ... +def _cuda_record_memory_history( + enabled: str | None, + context: str | None, + stacks: str, + max_entries: _int, + clear_history: _bool, + compile_context: _bool, +) -> None: ... +def _cuda_isHistoryEnabled() -> _bool: ... +def _cuda_getAllocatorBackend() -> str: ... + +class _cuda_CUDAAllocator_AllocatorState: ... + +def _cuda_getCheckpointState( + device: _int, + mempool: tuple[_int, _int], +) -> _cuda_CUDAAllocator_AllocatorState: ... +def _set_cached_tensors_enabled(enabled: _bool) -> None: ... +def _add_cached_tensor(t: Tensor) -> None: ... +def _remove_cached_tensor(t: Tensor) -> None: ... +def _tensors_data_ptrs_at_indices_equal( + tensors: list[Tensor | _int], + ptrs: list[_int | None], + indices: list[_int], +) -> _bool: ... +def _construct_CUDA_Tensor_From_Storage_And_Metadata( + metadata: dict, + storage: Storage, +) -> Tensor: ... +def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ... +def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ... +def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ... +def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ... + +class _cuda_CUDAAllocator: ... + +def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ... +def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ... +def _cuda_getAllocator() -> _cuda_CUDAAllocator: ... +def _cuda_lock_mutex() -> None: ... +def _cuda_unlock_mutex() -> None: ... +def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... +def _cuda_jiterator_compile_and_launch_kernel( + code_string: str, + kernel_name: str, + return_by_ref: _bool, + num_outputs: _int, + tensors: tuple, + kwargs: dict[str, _int | _float | _bool], +) -> Tensor: ... +def _cuda_get_cudnn_benchmark_limit() -> _int: ... +def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ... +def _cuda_get_conv_benchmark_empty_cache() -> _bool: ... +def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ... +def _nccl_version() -> _int: ... +def _nccl_version_suffix() -> bytes: ... +def _nccl_unique_id() -> bytes: ... +def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ... +def _nccl_reduce( + input: Sequence[Tensor], + output: Tensor, + root: _int, + op: _int, + streams: Sequence[_CudaStreamBase] | None, + comms: Sequence[object] | None, +) -> None: ... +def _nccl_all_reduce( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Sequence[_CudaStreamBase] | None, + comms: Sequence[object] | None, +) -> None: ... +def _nccl_broadcast( + input: Sequence[Tensor], + root: _int, + streams: Sequence[_CudaStreamBase] | None, + comms: Sequence[object] | None, +) -> None: ... +def _nccl_all_gather( + input: Sequence[Tensor], + output: Sequence[Tensor], + streams: Sequence[_CudaStreamBase] | None, + comms: Sequence[object] | None, +) -> None: ... +def _nccl_reduce_scatter( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Sequence[_CudaStreamBase] | None, + comms: Sequence[object] | None, +) -> None: ... 
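The `_nccl_*` bindings above back the thin Python wrappers in `torch.cuda.nccl`. A minimal sketch of the public wrapper, assuming an NCCL-enabled build with at least two visible CUDA devices (both assumptions are guarded or noted in the comments; shapes are arbitrary):

    import torch
    import torch.cuda.nccl as nccl

    # Requires an NCCL-enabled build (Linux) and at least two CUDA devices.
    if torch.cuda.device_count() >= 2:
        xs = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
        nccl.all_reduce(xs)  # sums across devices, writing the result back in place
        assert all(bool(t.eq(2).all()) for t in xs)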
+def _rocm_is_backward_pass() -> _bool: ... +def _cuda_tunableop_enable(val: _bool) -> None: ... +def _cuda_tunableop_is_enabled() -> _bool: ... +def _cuda_tunableop_tuning_enable(val: _bool) -> None: ... +def _cuda_tunableop_tuning_is_enabled() -> _bool: ... +def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_duration() -> _int: ... +def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ... +def _cuda_tunableop_get_max_tuning_iterations() -> _int: ... +def _cuda_tunableop_set_filename( + filename: str, + insert_device_ordinal: _bool | None, +) -> None: ... +def _cuda_tunableop_get_filename() -> str: ... +def _cuda_tunableop_write_file(filename: str | None) -> _bool: ... +def _cuda_tunableop_read_file(filename: str | None) -> _bool: ... +def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ... +def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ... +def _cuda_tunableop_get_validators() -> tuple[str, str]: ... +def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ... +def _cuda_tunableop_get_rotation_buffer_size() -> _int: ... + +class _CudaDeviceProperties: + name: str + major: _int + minor: _int + multi_processor_count: _int + total_memory: _int + is_integrated: _int + is_multi_gpu_board: _int + max_threads_per_multi_processor: _int + gcnArchName: str + warp_size: _int + uuid: str + L2_cache_size: _int + +# Functions related to SDPA +class _SDPAParams: + query: Tensor + key: Tensor + value: Tensor + attn_mask: Tensor | None + dropout: _float + is_causal: _bool + enable_gqa: _bool + def __init__( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None, + dropout: _float, + is_causal: _bool, + enable_gqa: _bool, + ) -> None: ... + +class _SDPBackend(Enum): + ERROR = -1 + MATH = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + CUDNN_ATTENTION = 3 + +def _is_flash_attention_available() -> _bool: ... +def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... +def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... +def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... + +# Defined in torch/csrc/cuda/GdsFile.cpp +def _gds_register_buffer(t: Storage) -> None: ... +def _gds_deregister_buffer(t: Storage) -> None: ... +def _gds_register_handle(fd: _int) -> _int: ... +def _gds_deregister_handle(handle: _int) -> None: ... +def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ... +def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ... + +# Defined in torch/csrc/cuda/python_comm.cpp +def _broadcast(tensor: Tensor, devices: list[_int]) -> list[Tensor]: ... +def _broadcast_out(tensor: Tensor, out_tensors: list[Tensor]) -> list[Tensor]: ... +def _broadcast_coalesced( + tensors: list[Tensor], + devices: list[_int], + buffer_size: _int, +) -> list[list[Tensor]]: ... +def _scatter( + tensor: Tensor, + devices: list[_int], + chunk_sizes: list[_int] | None, + dim: _int, + streams: list[Stream] | None, +) -> list[Tensor]: ... +def _scatter_out( + tensor: Tensor, + out_tensors: list[Tensor], + dim: _int, + streams: list[Stream] | None, +) -> list[Tensor]: ... +def _gather( + tensors: list[Tensor], + dim: _int, + destination_index: _int | None, +) -> Tensor: ... +def _gather_out(tensors: list[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... 
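Likewise, the `_broadcast*`, `_scatter*`, and `_gather*` entries are the bindings underneath the public `torch.cuda.comm` helpers. A small sketch using those wrappers; the device-count guard and tensor shape are illustrative only:

    import torch
    import torch.cuda.comm as comm

    if torch.cuda.device_count() >= 2:
        x = torch.randn(4, 6, device="cuda:0")
        copies = comm.broadcast(x, devices=[0, 1])            # full copy on each device
        chunks = comm.scatter(x, devices=[0, 1], dim=0)       # rows split across devices
        gathered = comm.gather(chunks, dim=0, destination=0)  # reassembled on cuda:0
        assert torch.equal(gathered, x)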
+ +# Defined in torch/csrc/cuda/Stream.cpp +class _CudaStreamBase(Stream): + stream_id: _int + device_index: _int + device_type: _int + + device: _device + cuda_stream: _int + priority: _int + + def __new__( + cls, + priority: _int = 0, + stream_id: _int = 0, + device_index: _int = 0, + stream_ptr: _int = 0, + ) -> Self: ... + def query(self) -> _bool: ... + def synchronize(self) -> None: ... + def priority_range(self) -> tuple[_int, _int]: ... + +# Defined in torch/csrc/cuda/Event.cpp +class _CudaEventBase: + device: _device + cuda_event: _int + + def __new__( + cls, + enable_timing: _bool = False, + blocking: _bool = False, + interprocess: _bool = False, + external: _bool = False, + ) -> Self: ... + @classmethod + def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ... + def record(self, stream: _CudaStreamBase) -> None: ... + def wait(self, stream: _CudaStreamBase) -> None: ... + def query(self) -> _bool: ... + def elapsed_time(self, other: _CudaEventBase) -> _float: ... + def synchronize(self) -> None: ... + def ipc_handle(self) -> bytes: ... + +# Defined in torch/csrc/cuda/Graph.cpp +class _CUDAGraph: + def __new__(cls, keep_graph: _bool = ...) -> Self: ... + def capture_begin( + self, + pool: tuple[_int, _int] | None = ..., + capture_error_mode: str = "global", + ) -> None: ... + def capture_end(self) -> None: ... + def instantiate(self) -> None: ... + def register_generator_state(self, Generator) -> None: ... + def replay(self) -> None: ... + def reset(self) -> None: ... + def pool(self) -> tuple[_int, _int]: ... + def enable_debug_mode(self) -> None: ... + def debug_dump(self, debug_path: str) -> None: ... + def raw_cuda_graph(self) -> _int: ... + +# Defined in torch/csrc/cuda/MemPool.cpp +class _MemPool: + def __init__( + self, + allocator: _cuda_CUDAAllocator | None = None, + is_user_created: _bool = True, + use_on_oom: _bool = False, + symmetric: _bool = False, + ) -> None: ... + @property + def id(self) -> tuple[_int, _int]: ... + @property + def is_symmetric(self) -> _bool: ... + @property + def allocator(self) -> _cuda_CUDAAllocator | None: ... + def use_count(self) -> _int: ... + +def _cuda_isCurrentStreamCapturing() -> _bool: ... +def _graph_pool_handle() -> tuple[_int, _int]: ... + +# Defined in torch/csrc/xpu/Module.cpp +def _xpu_setDevice(device: _int) -> None: ... +def _xpu_exchangeDevice(device: _int) -> _int: ... +def _xpu_maybeExchangeDevice(device: _int) -> _int: ... +def _xpu_getDevice() -> _int: ... +def _xpu_getDeviceCount() -> _int: ... +def _xpu_getArchFlags() -> str | None: ... +def _xpu_init() -> None: ... +def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +def _xpu_getCurrentStream(device: _int) -> tuple: ... +def _xpu_getCurrentRawStream(device: _int) -> _int: ... +def _xpu_getStreamFromExternal(data_ptr: _int, device_index: _int) -> tuple: ... +def _xpu_synchronize(device: _int) -> None: ... +def _xpu_emptyCache() -> None: ... +def _xpu_memoryStats(device: _int) -> dict[str, Any]: ... +def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... +def _xpu_resetPeakMemoryStats(device: _int) -> None: ... +def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ... 
+ +class _XpuDeviceProperties: + name: str + platform_name: str + vendor: str + driver_version: str + version: str + max_compute_units: _int + gpu_eu_count: _int + max_work_group_size: _int + max_num_sub_groups: _int + sub_group_sizes: list[_int] + has_fp16: _bool + has_fp64: _bool + has_atomic64: _bool + has_bfloat16_conversions: _bool + has_subgroup_matrix_multiply_accumulate: _bool + has_subgroup_matrix_multiply_accumulate_tensor_float32: _bool + has_subgroup_2d_block_io: _bool + total_memory: _int + gpu_subslice_count: _int + architecture: _int + type: str + +# Defined in torch/csrc/xpu/Stream.cpp +class _XpuStreamBase(Stream): + stream_id: _int + device_index: _int + device_type: _int + + device: _device + sycl_queue: _int + priority: _int + + def __new__( + cls, + priority: _int = 0, + stream_id: _int = 0, + device_index: _int = 0, + device_type: _int = 0, + ) -> Self: ... + def query(self) -> _bool: ... + def synchronize(self) -> None: ... + @staticmethod + def priority_range() -> tuple: ... + +# Defined in torch/csrc/xpu/Event.cpp +class _XpuEventBase: + device: _device + sycl_event: _int + + def __new__(cls, enable_timing: _bool = False) -> Self: ... + def record(self, stream: _XpuEventBase) -> None: ... + def wait(self, stream: _XpuStreamBase) -> None: ... + def query(self) -> _bool: ... + def elapsed_time(self, other: _XpuEventBase) -> _float: ... + def synchronize(self) -> None: ... + +# Defined in torch/csrc/DataLoader.cpp +def _set_worker_signal_handlers( + *arg: Any, +) -> None: ... # THPModule_setWorkerSignalHandlers +def _set_worker_pids( + key: _int, + child_pids: tuple[_int, ...], +) -> None: ... # THPModule_setWorkerPIDs +def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs +def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails + +# Defined in torch/csrc/DeviceAccelerator.cpp +def _accelerator_getAccelerator() -> _device: ... +def _accelerator_setDeviceIndex(device_index: _int) -> None: ... +def _accelerator_getDeviceIndex() -> _int: ... +def _accelerator_setStream(Stream) -> None: ... +def _accelerator_getStream(device_index: _int) -> Stream: ... +def _accelerator_synchronizeDevice(device_index: _int) -> None: ... +def _accelerator_exchangeDevice(device_index: _int) -> _int: ... +def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... + +# Defined in torch/csrc/jit/python/python_tracer.cpp +class TracingState: + def push_scope(self, scope_name: str) -> None: ... + def pop_scope(self) -> None: ... + def current_scope(self) -> str: ... + def set_graph(self, graph: Graph) -> None: ... + def graph(self) -> Graph: ... + +def _create_graph_by_tracing( + func: Callable[..., Any], + inputs: Any, + var_name_lookup_fn: Callable[[Tensor], str], + strict: Any, + force_outplace: Any, + self: Any = None, + argument_names: list[str] = ..., +) -> tuple[Graph, Stack]: ... +def _tracer_warn_use_python(): ... +def _get_tracing_state() -> TracingState: ... + +# Defined in torch/csrc/jit/python/python_ir.cpp +# Not actually defined in python_ir.cpp, not sure where they are. +class IValue: ... + +Stack: TypeAlias = list[IValue] + +class JitType: + annotation_str: str + def isSubtypeOf(self, other: JitType) -> _bool: ... + def with_dtype(self, dtype: _dtype) -> JitType: ... + def with_sizes(self, sizes: list[_int | None]) -> JitType: ... + def kind(self) -> str: ... + def scalarType(self) -> str | None: ... + def getElementType(self) -> JitType: ... + def dtype(self) -> _dtype | None: ... 
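These JIT type classes are importable from `torch._C`, so the declarations above can be exercised directly. A minimal sketch that uses only methods declared in this stub (the exact annotation string varies by build, so it is printed rather than asserted):

    import torch
    from torch._C import AnyType, TensorType

    t = torch.zeros(2, 3)
    tt = TensorType.create_from_tensor(t)     # TensorType refined with t's dtype and sizes
    assert tt.isSubtypeOf(TensorType.get())   # the refined type is a subtype of plain Tensor
    assert tt.isSubtypeOf(AnyType.get())      # every JIT type is a subtype of Any
    print(tt.annotation_str)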
+ +class InferredType: + def __init__(self, arg: JitType | str) -> None: ... + def type(self) -> JitType: ... + def success(self) -> _bool: ... + def reason(self) -> str: ... + +class Type(JitType): + def str(self) -> _str: ... + def containedTypes(self) -> list[JitType]: ... + def dim(self) -> _int | None: ... + def undefined(self) -> _bool | None: ... + def sizes(self) -> list[_int] | None: ... + def symbol_sizes(self) -> list[_int] | None: ... + def varyingSizes(self) -> list[_int | None] | None: ... + def strides(self) -> list[_int] | None: ... + def contiguous(self) -> Self: ... + def device(self) -> _device | None: ... + def is_interface_type(self) -> _bool: ... + def requires_grad(self) -> _bool: ... + @property + def annotation_string(self) -> _str: ... + +class AnyType(JitType): + @staticmethod + def get() -> AnyType: ... + +class NoneType(JitType): + @staticmethod + def get() -> NoneType: ... + +class BoolType(JitType): + @staticmethod + def get() -> BoolType: ... + +class FloatType(JitType): + @staticmethod + def get() -> FloatType: ... + +class ComplexType(JitType): + @staticmethod + def get() -> ComplexType: ... + +class IntType(JitType): + @staticmethod + def get() -> IntType: ... + +class SymIntType(JitType): + @staticmethod + def get() -> SymIntType: ... + +class SymBoolType(JitType): + @staticmethod + def get() -> SymBoolType: ... + +class NumberType(JitType): + @staticmethod + def get() -> NumberType: ... + +class StringType(JitType): + @staticmethod + def get() -> StringType: ... + +class DeviceObjType(JitType): + @staticmethod + def get() -> DeviceObjType: ... + +class _GeneratorType(JitType): + @staticmethod + def get() -> _GeneratorType: ... + +class StreamObjType(JitType): + @staticmethod + def get() -> StreamObjType: ... + +class ListType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + @staticmethod + def ofInts() -> ListType: ... + @staticmethod + def ofTensors() -> ListType: ... + @staticmethod + def ofFloats() -> ListType: ... + @staticmethod + def ofComplexDoubles() -> ListType: ... + @staticmethod + def ofBools() -> ListType: ... + @staticmethod + def ofStrings() -> ListType: ... + +class DictType(JitType): + def __init__(self, key: JitType, value: JitType) -> None: ... + def getKeyType(self) -> JitType: ... + def getValueType(self) -> JitType: ... + +class TupleType(JitType): + def __init__(self, a: list[JitType | None]) -> None: ... + def elements(self) -> list[JitType]: ... + +class UnionType(JitType): + def __init__(self, a: list[JitType]) -> None: ... + +class ClassType(JitType): + def __init__(self, qualified_name: str) -> None: ... + def qualified_name(self) -> str: ... + +class InterfaceType(JitType): + def __init__(self, qualified_name: str) -> None: ... + def getMethod(self, name: str) -> FunctionSchema | None: ... + def getMethodNames(self) -> list[str]: ... + +JitTypeT = TypeVar("JitTypeT", bound=JitType) # noqa: PYI001 + +class OptionalType(JitType, Generic[JitTypeT]): + def __init__(self, a: JitTypeT) -> None: ... + def getElementType(self) -> JitTypeT: ... + @staticmethod + def ofTensor() -> OptionalType: ... + +class FutureType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + +class AwaitType(JitType): + def __init__(self, a: JitType) -> None: ... + def getElementType(self) -> JitType: ... + +class RRefType(JitType): + def __init__(self, a: JitType) -> None: ... 
+ +class EnumType(JitType): + def __init__( + self, + qualified_name: str, + value_type: JitType, + enum_names_values: list[Any], + ) -> None: ... + +class TensorType(JitType): + @classmethod + def get(cls) -> TensorType: ... + @classmethod + def getInferred(cls) -> TensorType: ... + def with_sizes(self, other: list[_int | None] | None) -> TensorType: ... + def sizes(self) -> list[_int] | None: ... + def varyingSizes(self) -> list[_int | None] | None: ... + def strides(self) -> list[_int] | None: ... + def device(self) -> _device | None: ... + def dim(self) -> _int: ... + def dtype(self) -> _dtype | None: ... + @staticmethod + def create_from_tensor(t: Tensor) -> TensorType: ... + +# Defined in torch/csrc/jit/python/python_tree_views.cpp +class SourceRange: ... +class TreeView: ... + +class Ident(TreeView): + @property + def name(self) -> str: ... + +class ClassDef(TreeView): ... + +class Def(TreeView): + def name(self) -> Ident: ... + +class Decl(TreeView): ... + +# Defined in torch/csrc/distributed/rpc/init.cpp +def _rpc_init() -> _bool: ... + +# Defined in torch/csrc/distributed/autograd/init.cpp +def _dist_autograd_init() -> _bool: ... + +# Defined in torch/csrc/distributed/c10d/init.cpp +def _c10d_init() -> _bool: ... + +# Defined in torch/csrc/distributed/rpc/testing/init.cpp +def _faulty_agent_init() -> _bool: ... +def _register_py_class_for_device(device: str, cls: Any) -> None: ... + +# Defined in torch/csrc/Module.cpp +def _current_graph_task_id() -> _int: ... +def _current_autograd_node() -> _Node: ... +def _will_engine_execute_node(node: _Node) -> _bool: ... +def _dispatch_key_set(tensor) -> str: ... + +# Defined in torch/csrc/Exceptions.cpp +class AcceleratorError(RuntimeError): ... +class OutOfMemoryError(RuntimeError): ... +class _DistError(RuntimeError): ... +class _DistBackendError(RuntimeError): ... +class _DistStoreError(RuntimeError): ... +class _DistNetworkError(RuntimeError): ... +class _DistQueueEmptyError(_DistStoreError): ... + +# Defined in torch/csrc/profiler/init.cpp +class CapturedTraceback: ... + +def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ... +def symbolize_tracebacks( + tracebacks: list[CapturedTraceback], +) -> list[dict[str, Any]]: ... +def _load_mobile_module_from_file(filename: str): ... +def _load_mobile_module_from_bytes(bytes_: bytes): ... +def _load_jit_module_from_file(filename: str): ... +def _load_jit_module_from_bytes(bytes_: bytes): ... +def _save_mobile_module(m: LiteScriptModule, filename: str): ... +def _save_jit_module(m: ScriptModule, filename: str, extra_files: dict[str, Any]): ... +def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ... +def _save_jit_module_to_bytes( + m: ScriptModule, + extra_files: dict[str, Any], +) -> bytes: ... +def _get_module_info_from_flatbuffer(data: bytes): ... +def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ... +def _swap_tensor_impl(t1: Tensor, t2: Tensor): ... +def _pickle_save(obj: Any) -> bytes: ... +def _pickle_load_obj(bs: bytes) -> Any: ... + +# Defined in torch/csrc/jit/runtime/static/init.cpp +def _jit_to_static_module(graph_or_module: Graph | ScriptModule) -> Any: ... +def _fuse_to_static_module( + graph_or_module: Graph | ScriptModule, + min_size: _int, +) -> Any: ... + +# Defined in torch/csrc/fx/node.cpp +def _fx_map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: ... +def _fx_map_arg(a: Any, fn: Callable[[Any], Any]) -> Any: ... 
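`_fx_map_aggregate` and `_fx_map_arg` are the C++ fast paths behind `torch.fx.node.map_aggregate` and `map_arg`, which apply a function to every leaf (respectively every `Node`) of an arbitrarily nested args/kwargs structure. A small sketch of the public wrapper; the nested value is arbitrary:

    from torch.fx.node import map_aggregate

    nested = ((1, 2), {"k": 3}, [4])
    doubled = map_aggregate(nested, lambda leaf: leaf * 2)
    # Containers are rebuilt (lists/dicts come back as fx immutable variants,
    # which still compare equal to their plain counterparts).
    assert doubled == ((2, 4), {"k": 6}, [8])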
+ +class _NodeBase: + _erased: _bool + _prev: FxNode + _next: FxNode + def __init__( + self, + graph: Any, + name: str, + op: str, + target: Any, + return_type: Any, + ) -> None: ... + def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... + +class _NodeIter(Iterator[FxNode]): + def __init__(self, root: FxNode, reversed: _bool) -> None: ... + def __iter__(self) -> Self: ... + def __next__(self) -> FxNode: ... + +# Defined in torch/csrc/inductor/static_cuda_launcher.cpp +class _StaticCudaLauncher: + @staticmethod + def _load_kernel( + cubin_file: str, + func_name: str, + shared_mem_bytes: _int, + device: _int, + ) -> tuple[_int, _int, _int]: ... + @staticmethod + def _launch_kernel( + func: _int, + grid_x: _int, + grid_y: _int, + grid_z: _int, + num_warps: _int, + shared_mem_bytes: _int, + arg_types: str, + args: tuple[Any, ...], + stream: _int, + ) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_aoti.pyi b/phivenv/Lib/site-packages/torch/_C/_aoti.pyi new file mode 100644 index 0000000000000000000000000000000000000000..aefa10318f3c8efd48ddf4d30602ac500803a37f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_aoti.pyi @@ -0,0 +1,164 @@ +from ctypes import c_void_p +from typing import overload, Protocol + +from torch import Tensor + +# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp + +# Tensor to AtenTensorHandle +def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ... +def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ... + +# AtenTensorHandle to Tensor +def alloc_tensors_by_stealing_from_void_ptrs( + handles: list[c_void_p], +) -> list[Tensor]: ... +def alloc_tensor_by_stealing_from_void_ptr( + handle: c_void_p, +) -> Tensor: ... + +class AOTIModelContainerRunner(Protocol): + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerCpu: + def __init__(self, model_so_path: str, num_models: int) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerCuda: + @overload + def __init__(self, model_so_path: str, num_models: int) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str + ) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str + ) -> None: ... 
+ def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerXpu: + @overload + def __init__(self, model_so_path: str, num_models: int) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str + ) -> None: ... + @overload + def __init__( + self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str + ) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +class AOTIModelContainerRunnerMps: + def __init__(self, model_so_path: str, num_models: int) -> None: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_names_to_original_fqns(self) -> dict[str, str]: ... + def get_constant_names_to_dtypes(self) -> dict[str, int]: ... + def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... + def swap_constant_buffer(self) -> None: ... + def free_inactive_constant_buffer(self) -> None: ... + +# Defined in torch/csrc/inductor/aoti_package/pybind.cpp +class AOTIModelPackageLoader: + def __init__( + self, + model_package_path: str, + model_name: str, + run_single_threaded: bool, + num_runners: int, + device_index: int, + ) -> None: ... + def get_metadata(self) -> dict[str, str]: ... + def run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def boxed_run( + self, inputs: list[Tensor], stream_handle: c_void_p = ... + ) -> list[Tensor]: ... + def get_call_spec(self) -> list[str]: ... + def get_constant_fqns(self) -> list[str]: ... + def load_constants( + self, + constants_map: dict[str, Tensor], + use_inactive: bool, + check_full_update: bool, + user_managed: bool = ..., + ) -> None: ... + def update_constant_buffer( + self, + tensor_map: dict[str, Tensor], + use_inactive: bool, + validate_full_updates: bool, + user_managed: bool = ..., + ) -> None: ... 
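A minimal usage sketch for the AOTIModelPackageLoader stub above, assuming a .pt2 package already produced by AOTInductor; the path, model name, and argument values below are placeholders chosen to mirror the stub's __init__ signature, not verified defaults:

import torch
from torch._C._aoti import AOTIModelPackageLoader

# Placeholder package path; arguments follow the stub's positional order:
# (model_package_path, model_name, run_single_threaded, num_runners, device_index).
loader = AOTIModelPackageLoader("model.pt2", "model", False, 1, 0)
example_inputs = [torch.randn(2, 8)]        # whatever inputs the packaged model expects
outputs = loader.run(example_inputs)        # list[Tensor], per the stub
print(loader.get_metadata())
print(loader.get_call_spec())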
diff --git a/phivenv/Lib/site-packages/torch/_C/_autograd.pyi b/phivenv/Lib/site-packages/torch/_C/_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..9f35f079f001ef888f49f593eb7d81cf920212d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_autograd.pyi @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import Any, Callable + +import torch +from torch._C._profiler import ( + _ProfilerEvent, + ActiveProfilerType, + ProfilerActivity, + ProfilerConfig, +) + +# Defined in torch/csrc/autograd/init.cpp + +class DeviceType(Enum): + CPU = ... + CUDA = ... + XPU = ... + MKLDNN = ... + OPENGL = ... + OPENCL = ... + IDEEP = ... + HIP = ... + FPGA = ... + MAIA = ... + XLA = ... + MTIA = ... + MPS = ... + HPU = ... + Meta = ... + Vulkan = ... + Metal = ... + PrivateUse1 = ... + +class ProfilerEvent: + def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cpu_memory_usage(self) -> int: ... + def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ... + def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ... + def cuda_memory_usage(self) -> int: ... + def device(self) -> int: ... + def handle(self) -> int: ... + def has_cuda(self) -> bool: ... + def is_remote(self) -> bool: ... + def kind(self) -> int: ... + def name(self) -> str: ... + def node_id(self) -> int: ... + def sequence_nr(self) -> int: ... + def shapes(self) -> list[list[int]]: ... + def thread_id(self) -> int: ... + def flops(self) -> float: ... + def is_async(self) -> bool: ... + +class _KinetoEvent: + def name(self) -> str: ... + def overload_name(self) -> str: ... + def device_index(self) -> int: ... + def device_resource_id(self) -> int: ... + def start_ns(self) -> int: ... + def end_ns(self) -> int: ... + def duration_ns(self) -> int: ... + def is_async(self) -> bool: ... + def linked_correlation_id(self) -> int: ... + def shapes(self) -> list[list[int]]: ... + def dtypes(self) -> list[str]: ... + def concrete_inputs(self) -> list[Any]: ... + def kwinputs(self) -> dict[str, Any]: ... + def device_type(self) -> DeviceType: ... + def start_thread_id(self) -> int: ... + def end_thread_id(self) -> int: ... + def correlation_id(self) -> int: ... + def fwd_thread_id(self) -> int: ... + def stack(self) -> list[str]: ... + def scope(self) -> int: ... + def sequence_nr(self) -> int: ... + def flops(self) -> int: ... + def cuda_elapsed_us(self) -> int: ... + def privateuse1_elapsed_us(self) -> int: ... + def is_user_annotation(self) -> bool: ... + +class _ProfilerResult: + def events(self) -> list[_KinetoEvent]: ... + def legacy_events(self) -> list[list[ProfilerEvent]]: ... + def save(self, path: str) -> None: ... + def experimental_event_tree(self) -> list[_ProfilerEvent]: ... + def trace_start_ns(self) -> int: ... + +class SavedTensor: ... + +def _enable_profiler( + config: ProfilerConfig, + activities: set[ProfilerActivity], +) -> None: ... +def _prepare_profiler( + config: ProfilerConfig, + activities: set[ProfilerActivity], +) -> None: ... +def _toggle_collection_dynamic( + enable: bool, + activities: set[ProfilerActivity], +) -> None: ... +def _disable_profiler() -> _ProfilerResult: ... +def _profiler_enabled() -> bool: ... +def _add_metadata_json(key: str, value: str) -> None: ... +def _kineto_step() -> None: ... +def _get_current_graph_task_keep_graph() -> bool: ... +def _get_sequence_nr() -> int: ... +def kineto_available() -> bool: ... +def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ... 
+def _record_function_with_args_exit(handle: torch.Tensor) -> None: ... +def _supported_activities() -> set[ProfilerActivity]: ... +def _enable_record_function(enable: bool) -> None: ... +def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... +def _push_saved_tensors_default_hooks( + pack_hook: Callable[[torch.Tensor], Any], + unpack_hook: Callable[[Any], torch.Tensor], +) -> None: ... +def _pop_saved_tensors_default_hooks() -> None: ... +def _top_saved_tensors_default_hooks( + ignore_is_tracing: bool, +) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ... +def _unsafe_set_version_counter( + t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...] +) -> None: ... +def _enable_profiler_legacy(config: ProfilerConfig) -> None: ... +def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ... +def _profiler_type() -> ActiveProfilerType: ... +def _saved_tensors_hooks_enable() -> None: ... +def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ... +def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ... +def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ... + +class CreationMeta(Enum): + DEFAULT = ... + IN_CUSTOM_FUNCTION = ... + MULTI_OUTPUT_NODE = ... + NO_GRAD_MODE = ... + INFERENCE_MODE = ... + +def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ... +def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_cpu.pyi b/phivenv/Lib/site-packages/torch/_C/_cpu.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6804872c957177762c58bb598c74e5bcd8bc26f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_cpu.pyi @@ -0,0 +1,13 @@ +from torch.types import _bool, _int + +# Defined in torch/csrc/cpu/Module.cpp + +def _is_avx2_supported() -> _bool: ... +def _is_avx512_supported() -> _bool: ... +def _is_avx512_vnni_supported() -> _bool: ... +def _is_avx512_bf16_supported() -> _bool: ... +def _is_amx_tile_supported() -> _bool: ... +def _is_amx_fp16_supported() -> _bool: ... +def _init_amx() -> _bool: ... +def _L1d_cache_size() -> _int: ... +def _L2_cache_size() -> _int: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_cudnn.pyi b/phivenv/Lib/site-packages/torch/_C/_cudnn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..fab07805443d63f7a343f69d0be5b334e7547fa8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_cudnn.pyi @@ -0,0 +1,14 @@ +from enum import IntEnum + +# Defined in torch/csrc/cuda/shared/cudnn.cpp +is_cuda: bool + +def getRuntimeVersion() -> tuple[int, int, int]: ... +def getCompileVersion() -> tuple[int, int, int]: ... +def getVersionInt() -> int: ... + +class RNNMode(IntEnum): + rnn_relu = ... + rnn_tanh = ... + lstm = ... + gru = ... diff --git a/phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi b/phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi new file mode 100644 index 0000000000000000000000000000000000000000..47bbaacf10f1a6755fa1c2e81cf48f237b07fed8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_cusparselt.pyi @@ -0,0 +1 @@ +def getVersionInt() -> int: ... 
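A quick sketch exercising the CPU-capability probes declared in _cpu.pyi above; each probe returns a bool for the host CPU, and the cache-size helpers return integers (assumed to be bytes). The grouping below is illustrative, not part of the stubs:

import torch

# Each probe reports whether the host CPU supports the given instruction set.
probes = {
    "avx2": torch._C._cpu._is_avx2_supported,
    "avx512": torch._C._cpu._is_avx512_supported,
    "avx512_vnni": torch._C._cpu._is_avx512_vnni_supported,
    "avx512_bf16": torch._C._cpu._is_avx512_bf16_supported,
    "amx_tile": torch._C._cpu._is_amx_tile_supported,
    "amx_fp16": torch._C._cpu._is_amx_fp16_supported,
}
for name, probe in probes.items():
    print(f"{name}: {probe()}")

# Cache sizes exposed by the same module (assumed to be reported in bytes).
print("L1d:", torch._C._cpu._L1d_cache_size(), "L2:", torch._C._cpu._L2_cache_size())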
diff --git a/phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi b/phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a2b4aedd34acb358f57b8d4a80fab399a7121ebf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_distributed_autograd.pyi @@ -0,0 +1,26 @@ +from typing import Any + +import torch + +# This module is defined in torch/csrc/distributed/autograd/init.cpp + +class DistAutogradContext: + def _context_id(self) -> int: ... + def _recv_functions(self) -> dict[int, Any]: ... + def _send_functions(self) -> dict[int, Any]: ... + def _known_worker_ids(self) -> set[int]: ... + +def _new_context() -> DistAutogradContext: ... +def _release_context(context_id: int) -> None: ... +def _get_max_id() -> int: ... +def _is_valid_context(worker_id: int) -> bool: ... +def _retrieve_context(context_id: int) -> DistAutogradContext: ... +def _current_context() -> DistAutogradContext: ... +def _init(worker_id: int) -> None: ... +def _get_debug_info() -> dict[str, str]: ... +def backward( + context_id: int, + roots: list[torch.Tensor], + retain_graph: bool = False, +) -> None: ... +def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi b/phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1a0ca2343b3860f267c05914d826553f10e371f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_distributed_c10d.pyi @@ -0,0 +1,797 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="type-arg" +from datetime import timedelta +from enum import Enum +from typing import Any, Optional, overload, Union + +import torch +from torch import Tensor +from torch._C import ScriptObject +from torch._C._autograd import DeviceType +from torch.futures import Future + +# This module is defined in torch/csrc/distributed/c10d/init.cpp + +_DEFAULT_FIRST_BUCKET_BYTES: int +_DEFAULT_NO_TIMEOUT: timedelta +_DEFAULT_PG_TIMEOUT: timedelta +_DEFAULT_PG_NCCL_TIMEOUT: timedelta + +class BuiltinCommHookType(Enum): + ALLREDUCE = ... + FP16_COMPRESS = ... + +def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ... +def _register_builtin_comm_hook( + reducer: Reducer, + comm_hook_type: BuiltinCommHookType, +): ... +def _set_global_rank(rank: int) -> None: ... +def _hash_tensors(tensors: list[Tensor]) -> int: ... + +class GradBucket: + def index(self) -> int: ... + def buffer(self) -> Tensor: ... + def gradients(self) -> list[Tensor]: ... + def is_last(self) -> bool: ... + def set_buffer(self, tensor: Tensor) -> None: ... + def parameters(self) -> list[Tensor]: ... + +class Reducer: + def __init__( + self, + params: list[Tensor], + bucket_indices: list[list[int]], + per_bucket_size_limits: list[int], + process_group: ProcessGroup, + expect_sparse_gradients: list[bool] = ..., + bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp + find_unused_parameters: bool = ..., + gradient_as_bucket_view: bool = ..., + param_to_name_mapping: dict[int, str] = ..., + first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp + skip_all_reduce_unused_params: bool = ..., + use_python_reducer: bool = ..., + ) -> None: ... + def prepare_for_forward(self) -> None: ... + def prepare_for_backward(self, output: list[Tensor]) -> None: ... + def get_backward_stats(self) -> list[int]: ... + def _install_post_backward_futures(self, futures: list[Future]) -> None: ... 
+ def _rebuild_buckets(self) -> bool: ... + def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ... + def _push_all_rebuilt_params(self) -> None: ... + def _set_forward_pass_work_handle( + self, + work: Work, + use_static_world_size: bool, + ): ... + def _get_local_used_map(self) -> Tensor: ... + def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ... + def _set_static_graph(self) -> None: ... + def _run_comm_hook(self, bucket: GradBucket) -> Future: ... + def set_logger(self, logger: Logger) -> None: ... + def _remove_autograd_hooks(self) -> None: ... + def _check_reducer_finalized(self) -> None: ... + def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ... + def _reset_state(self) -> None: ... + def _update_process_group(self, new_process_group: ProcessGroup) -> None: ... + +class DDPLoggingData: + strs_map: dict[str, str] + ints_map: dict[str, int] + +class Logger: + def __init__(self, reducer: Reducer) -> None: ... + def set_construction_data_and_log( + self, + module_name: str, + device_ids: list[int], + output_device: int, + broadcast_buffers: bool, + has_sync_bn: bool, + static_graph: bool, + ): ... + def set_runtime_stats_and_log(self) -> None: ... + def set_error_and_log(self, error: str) -> None: ... + def _get_ddp_logging_data(self) -> DDPLoggingData: ... + def _set_comm_hook_name(self, comm_hook: str) -> None: ... + def _set_uneven_input_join(self) -> None: ... + def _set_static_graph(self) -> None: ... + +class _WorkerServer: + def __init__(self, socket_path: str) -> None: ... + def shutdown(self) -> None: ... + +def get_debug_level(): ... +def set_debug_level(): ... +def set_debug_level_from_env(): ... + +class DebugLevel(Enum): + OFF = ... + INFO = ... + DETAIL = ... + +class ReduceOp: + def __init__(self, op: RedOpType) -> None: ... + + SUM: RedOpType = ... + AVG: RedOpType = ... + PRODUCT: RedOpType = ... + MIN: RedOpType = ... + MAX: RedOpType = ... + BAND: RedOpType = ... + BOR: RedOpType = ... + BXOR: RedOpType = ... + PREMUL_SUM: RedOpType = ... + UNUSED: RedOpType = ... + + # mypy error being ignored: + # Detected enum "torch._C._distributed_c10d.ReduceOp.RedOpType" in a type + # stub with zero members. There is a chance this is due to a recent change + # in the semantics of enum membership. If so, use `member = value` to mark + # an enum member, instead of `member: type` + class RedOpType(Enum): ... # type: ignore[misc] + +class BroadcastOptions: + rootRank: int + rootTensor: int + timeout: timedelta + asyncOp: bool + +class AllreduceOptions: + reduceOp: ReduceOp + timeout: timedelta + asyncOp: bool + sparseIndices: Optional[Tensor] + +class AllreduceCoalescedOptions(AllreduceOptions): ... + +class ReduceOptions: + reduceOp: ReduceOp + rootRank: int + rootTensor: int + timeout: timedelta + asyncOp: bool + +class AllgatherOptions: + timeout: timedelta + asyncOp: bool + +class GatherOptions: + rootRank: int + timeout: timedelta + asyncOp: bool + +class ScatterOptions: + rootRank: int + timeout: timedelta + asyncOp: bool + +class ReduceScatterOptions: + reduceOp: ReduceOp + timeout: timedelta + asyncOp: bool + +class BarrierOptions: + device_ids: list[int] + device: torch.device + timeout: timedelta + asyncOp: bool + +class AllToAllOptions: + timeout: timedelta + asyncOp: bool + +class Store: + def set(self, key: str, value: str): ... + def get(self, key: str) -> bytes: ... + def add(self, key: str, value: int) -> int: ... + def check(self, keys: list[str]) -> bool: ... 
+ def compare_set( + self, + key: str, + expected_value: str, + desired_value: str, + ) -> bytes: ... + def delete_key(self, key: str) -> bool: ... + def num_keys(self) -> int: ... + def set_timeout(self, timeout: timedelta): ... + @overload + def wait(self, keys: list[str]): ... + @overload + def wait(self, keys: list[str], timeout: timedelta): ... + def queue_pop(self, key: str, block: bool = True) -> bytes: ... + def queue_push(self, key: str, value: Union[bytes, str]) -> None: ... + def queue_len(self, key: str) -> int: ... + +class FileStore(Store): + def __init__(self, path: str, numWorkers: int = ...) -> None: ... + +class HashStore(Store): + def __init__(self) -> None: ... + +class TCPStore(Store): + def __init__( + self, + host_name: str, + port: int, + world_size: int | None = ..., + is_master: bool = ..., + timeout: timedelta = ..., + wait_for_workers: bool = ..., + multi_tenant: bool = ..., + master_listen_fd: int | None = ..., + use_libuv: bool | None = ..., + ) -> None: ... + @property + def host(self) -> str: ... + @property + def port(self) -> int: ... + +class PrefixStore(Store): + def __init__(self, prefix: str, store: Store) -> None: ... + @property + def underlying_store(self) -> Store: ... + +class _ControlCollectives: + def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ... + def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def broadcast_recv(self, key: str, timeout: timedelta) -> str: ... + def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def gather_recv(self, key: str, timeout: timedelta) -> str: ... + def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ... + def scatter_recv(self, key: str, timeout: timedelta) -> str: ... + def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ... + def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ... + +class _StoreCollectives(_ControlCollectives): + def __init__(self, store: Store, rank: int, world_size: int) -> None: ... + +class _DistributedBackendOptions: + def __init__(self) -> None: ... + @property + def store(self) -> Store: ... + @store.setter + def store(self, store: Store) -> None: ... + @property + def group_rank(self) -> int: ... + @group_rank.setter + def group_rank(self, rank: int) -> None: ... + @property + def group_size(self) -> int: ... + @group_size.setter + def group_size(self, size: int) -> None: ... + @property + def timeout(self) -> timedelta: ... + @timeout.setter + def timeout(self, timeout: timedelta) -> None: ... + @property + def group_id(self) -> str: ... + @group_id.setter + def group_id(self, group_id: str) -> None: ... + @property + def global_ranks_in_group(self) -> list[int]: ... + @global_ranks_in_group.setter + def global_ranks_in_group(self, ranks: list[int]) -> None: ... + +class Work: + def is_completed(self) -> bool: ... + def is_success(self) -> bool: ... + def exception(self) -> Any: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def get_future(self) -> Future: ... + def source_rank(self) -> int: ... + def _source_rank(self) -> int: ... + def result(self) -> list[Tensor]: ... + def synchronize(self): ... + def boxed(self) -> ScriptObject: ... + @staticmethod + def unbox(obj: ScriptObject) -> Work: ... + +class Backend: + class Options: + def __init__(self, backend: str, timeout: timedelta = ...) -> None: ... + @property + def backend(self) -> str: ... + @property + def _timeout(self) -> timedelta: ... 
+ @_timeout.setter + def _timeout(self, val: timedelta) -> None: ... + + def __init__( + self, + rank: int, + size: int, + ) -> None: ... + @property + def supports_splitting(self) -> bool: ... + @property + def supports_coalescing(self) -> bool: ... + @property + def supports_time_estimate(self) -> bool: ... + @property + def options(self) -> Options: ... + def rank(self) -> int: ... + def size(self) -> int: ... + def abort(self) -> None: ... + def shutdown(self) -> None: ... + def eager_connect_single_device(self, device: torch.device | None) -> None: ... + def _set_sequence_number_for_group(self) -> None: ... + def _set_default_timeout(self, timeout: timedelta) -> None: ... + def get_error(self) -> ErrorType: ... + def supports_tensor_alloc(self, device: torch.device) -> bool: ... + def allocate_tensor( + self, + size: int, + *, + dtype: torch.dtype, + device: torch.device, + ) -> Tensor: ... + @property + def mem_allocator(self) -> Any: ... + +class ProcessGroup: + class BackendType(Enum): + UNDEFINED = ... + GLOO = ... + NCCL = ... + UCC = ... + MPI = ... + XCCL = ... + CUSTOM = ... + + def __init__( + self, + store: Store, + rank: int, + size: int, + ) -> None: ... + def rank(self) -> int: ... + def size(self) -> int: ... + def abort(self) -> None: ... + def shutdown(self) -> None: ... + @overload + def broadcast( + self, + tensors: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def broadcast( + self, + tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def allreduce( + self, + tensors: list[Tensor], + opts: AllreduceOptions = ..., + ) -> Work: ... + @overload + def allreduce( + self, + tensors: list[Tensor], + op=..., + ) -> Work: ... + @overload + def allreduce( + self, + tensor: Tensor, + op=..., + ) -> Work: ... + def allreduce_coalesced( + self, + tensors: list[Tensor], + opts=..., + ) -> Work: ... + def reduce_scatter_tensor_coalesced( + self, + outputTensors: list[Tensor], + inputTensors: list[Tensor], + opts: ReduceScatterOptions | None = None, + ) -> Work: ... + @overload + def reduce( + self, + tensors: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def reduce( + self, + tensor: Tensor, + root: int, + op=..., + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: list[list[Tensor]], + input_tensors: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def allgather( + self, + output_tensors: list[Tensor], + input_tensor: Tensor, + ) -> Work: ... + def _allgather_base( + self, + output: Tensor, + input: Tensor, + opts=..., + ) -> Work: ... + def allgather_coalesced( + self, + output_lists: list[list[Tensor]], + input_list: list[Tensor], + opts=..., + ) -> Work: ... + def allgather_into_tensor_coalesced( + self, + output_lists: list[Tensor], + input_list: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def gather( + self, + output_tensors: list[list[Tensor]], + input_tensors: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def gather( + self, + output_tensors: list[Tensor], + input_tensor: Tensor, + root: int, + ) -> Work: ... + @overload + def scatter( + self, + output_tensors: list[Tensor], + input_tensors: list[list[Tensor]], + opts=..., + ) -> Work: ... + @overload + def scatter( + self, + output_tensor: Tensor, + input_tensors: list[Tensor], + root: int, + ) -> Work: ... + @overload + def reduce_scatter( + self, + output_tensors: list[Tensor], + input_tensors: list[list[Tensor]], + opts=..., + ) -> Work: ... 
+ @overload + def reduce_scatter( + self, + output_tensors: Tensor, + input_tensor: list[Tensor], + ) -> Work: ... + def _reduce_scatter_base( + self, + outputTensor: Tensor, + inputTensor: Tensor, + opts: ReduceScatterOptions | None, + ) -> Work: ... + @overload + def alltoall_base( + self, + output_tensor: Tensor, + input_tensor: Tensor, + output_split_sizes: list[int], + input_split_sizes: list[int], + opts=..., + ) -> Work: ... + @overload + def alltoall_base( + self, + output: Tensor, + input: Tensor, + output_split_sizes: list[int], + input_split_sizes: list[int], + ) -> Work: ... + @overload + def alltoall( + self, + output_tensor: list[Tensor], + input_tensor: list[Tensor], + opts=..., + ) -> Work: ... + @overload + def alltoall( + self, + output: list[Tensor], + input: list[Tensor], + ) -> Work: ... + def send( + self, + tensors: list[Tensor], + dstRank: int, + tag: int, + ) -> Work: ... + def recv( + self, + tensors: list[Tensor], + srcRank: int, + tag: int, + ) -> Work: ... + def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ... + def barrier(self, opts=...) -> Work: ... + def boxed(self) -> ScriptObject: ... + @staticmethod + def unbox(obj: ScriptObject) -> ProcessGroup: ... + def _start_coalescing(self, device: torch.device) -> None: ... + def _end_coalescing(self, device: torch.device) -> Work: ... + def _get_backend_name(self) -> str: ... + def _backend_id(self, backend_type: BackendType) -> int: ... + @property + def _device_types(self) -> list[torch.device]: ... + def _get_backend(self, device: torch.device) -> Backend: ... + def _set_default_backend(self, backend_type: BackendType) -> None: ... + def _register_backend( + self, + device: torch.device, + backend_type: BackendType, + backend: Backend | None, + ) -> None: ... + def _set_group_name(self, name: str) -> None: ... + def _set_group_desc(self, desc: str) -> None: ... + def name(self) -> str: ... + def _has_hooks(self) -> bool: ... + def _wait_for_pending_works(self) -> None: ... + def _set_sequence_number_for_group(self) -> None: ... + @property + def bound_device_id(self) -> torch.device | None: ... + @bound_device_id.setter + def bound_device_id(self, device: torch.device | None) -> None: ... + @property + def group_name(self) -> str: ... + @property + def group_desc(self) -> str: ... + +class FakeProcessGroup(Backend): + def __init__(self, rank: int, world_size: int) -> None: ... + +class FakeWork(Work): + seq_id: int + def __init__(self) -> None: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def getFuture(self) -> Future: ... + +class ProcessGroupGloo(Backend): + class Device: ... + + class Options(Backend.Options): + devices: list[ProcessGroupGloo.Device] + threads: int + global_ranks_in_group: list[int] + group_name: str + + def __init__(self): ... + + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ) -> None: ... + @staticmethod + def create_device(hostname="", interface="", lazy_init=None) -> Device: ... + @staticmethod + def create_default_device(lazy_init=None) -> Device: ... + def _set_default_timeout(self, timeout) -> None: ... + @property + def options(self) -> Options: ... # type: ignore[override] + +class _ProcessGroupWrapper(Backend): + def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ... + wrapped_pg: Backend + +class ErrorType(Enum): + SUCCESS = ... + TIMEOUT = ... + COMM_ERROR = ... + REMOTE_ERROR = ... 
+ +class ProcessGroupNCCL(Backend): + class NCCLConfig: + blocking: int + cga_cluster_size: int + min_ctas: int + max_ctas: int + + class Options(Backend.Options): + config: ProcessGroupNCCL.NCCLConfig + is_high_priority_stream: bool + split_from: ProcessGroupNCCL + split_color: int + global_ranks_in_group: list[int] + group_name: str + + def __init__(self, is_high_priority_stream: bool = False): ... + + def __init__( + self, + store: Store, + rank: int, + size: int, + options: Options, + ) -> None: ... + def _group_start(self) -> None: ... + def _group_end(self) -> None: ... + def _start_time_estimate(self) -> None: ... + def _end_time_estimate(self) -> float: ... + def _set_default_timeout(self, timeout) -> None: ... + def perform_nocolor_split(self, device: torch.device) -> None: ... + def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ... + def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ... + def comm_split_count(self) -> int: ... + def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ... + def abort(self) -> None: ... + def _is_initialized(self) -> bool: ... + @property + def uid(self) -> int: ... + @property + def options(self) -> Options: ... # type: ignore[override] + @staticmethod + def get_build_nccl_version(self) -> tuple[int, int, int]: ... + @staticmethod + def get_runtime_nccl_version(self) -> tuple[int, int, int]: ... + +class ProcessGroupUCC(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta, + ) -> None: ... + +class ProcessGroupMPI(Backend): + def __init__( + self, + rank: int, + size: int, + pgComm: int, + ) -> None: ... + @staticmethod + def create(ranks: list[int]) -> ProcessGroupMPI: ... + +def _compute_bucket_assignment_by_size( + tensors: list[Tensor], + bucket_size_limits: list[int], + expect_sparse_gradient: list[bool] = ..., + tensor_indices: list[int] = ..., +) -> tuple[list[list[int]], list[int]]: ... +def _broadcast_coalesced( + process_group: ProcessGroup, + tensors: list[Tensor], + buffer_size: int, + src: int, +): ... +def _test_python_store(store: Store): ... +def _verify_params_across_processes( + process_group: ProcessGroup, + params: list[Tensor], + logger: Logger | None, +): ... +def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ... +def _register_process_group( + group_name: str, + process_group: ProcessGroup, +) -> None: ... +def _resolve_process_group(group_name: str) -> ProcessGroup: ... +def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ... +def _get_work_registry_size() -> int: ... +def _set_allow_inflight_collective_as_graph_input( + value: bool, +) -> None: ... +def _allow_inflight_collective_as_graph_input() -> bool: ... +def _unregister_all_process_groups() -> None: ... +def _unregister_process_group(group_name: str) -> None: ... + +# Intializes the device state in CUmodule so that it’s able to perform NVSHMEM +# operations. CUmodule is a pointer to a CUDA module, carried by a int64 in +# Python. At C++ interface, it is converted to a uintptr_t. +def _nvshmemx_cumodule_init(module: int) -> None: ... + +# Check if NVSHMEM is available on current system. +def _is_nvshmem_available() -> bool: ... + +class _SymmetricMemory: + @staticmethod + def set_group_info( + group_name: str, + rank: int, + world_size: int, + store: Store, + ) -> None: ... 
+ @staticmethod + def empty_strided_p2p( + size: torch.types._size, + stride: torch.types._size, + dtype: torch.dtype, + device: torch.device, + group_name: str | None = None, + alloc_id: int | None = None, + ) -> torch.Tensor: ... + @staticmethod + def has_multicast_support( + device_type: DeviceType, + device_idx: int, + ) -> bool: ... + @property + def rank(self) -> int: ... + @property + def world_size(self) -> int: ... + @staticmethod + def rendezvous( + tensor: torch.Tensor, group_name: str | None = None + ) -> _SymmetricMemory: ... + def get_buffer( + self, + rank: int, + sizes: torch.types._size, + dtype: torch.dtype, + storage_offset: int | None = 0, + ) -> torch.Tensor: ... + def get_signal_pad( + self, + rank: int, + sizes: torch.types._size = [], + dtype: torch.dtype | None = None, + storage_offset: int | None = 0, + ) -> torch.Tensor: ... + def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: ... + def put_signal( + self, + dst_rank: int, + channel: int = 0, + timeout_ms: int = 0, + ) -> None: ... + def wait_signal( + self, + src_rank: int, + channel: int = 0, + timeout_ms: int = 0, + ) -> None: ... + @staticmethod + def memset32( + tensor: torch.Tensor, offset: int, val: int, count: int = 1 + ) -> torch.Tensor: ... + @staticmethod + def stream_write_value32( + tensor: torch.Tensor, offset: int, val: int + ) -> torch.Tensor: ... + @property + def buffer_ptrs(self) -> list[int]: ... + @property + def buffer_ptrs_dev(self) -> int: ... + @property + def signal_pad_ptrs(self) -> list[int]: ... + @property + def signal_pad_ptrs_dev(self) -> int: ... + @property + def multicast_ptr(self) -> int: ... + @property + def buffer_size(self) -> int: ... + @property + def signal_pad_size(self) -> int: ... + +class ProcessGroupXCCL(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + ): ... diff --git a/phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi b/phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d533cf88dcb3ca89d7e707161fc89bd4c71062da --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_distributed_rpc.pyi @@ -0,0 +1,188 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="type-arg" +from datetime import timedelta +from typing import Any, Generic, overload, TypeVar + +import torch +from torch._C import Future +from torch._C._autograd import ProfilerEvent +from torch._C._distributed_c10d import Store +from torch._C._profiler import ProfilerConfig + +# This module is defined in torch/csrc/distributed/rpc/init.cpp + +_DEFAULT_INIT_METHOD: str +_DEFAULT_NUM_WORKER_THREADS: int +_UNSET_RPC_TIMEOUT: float +_DEFAULT_RPC_TIMEOUT_SEC: float + +_T = TypeVar("_T") + +class RpcBackendOptions: + rpc_timeout: float + init_method: str + def __init__( + self, + rpc_timeout: float = ..., + init_method: str = ..., + ) -> None: ... + +class WorkerInfo: + def __init__(self, name: str, worker_id: int) -> None: ... + @property + def name(self) -> str: ... + @property + def id(self) -> int: ... + def __eq__(self, other: object) -> bool: ... + +class RpcAgent: + def join(self, shutdown: bool = False, timeout: float = 0): ... + def sync(self): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + def get_worker_infos(self) -> list[WorkerInfo]: ... + def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ... 
+ def get_debug_info(self) -> dict[str, str]: ... + def get_metrics(self) -> dict[str, str]: ... + +class PyRRef(Generic[_T]): + def __init__(self, value: _T, type_hint: Any = None) -> None: ... + def is_owner(self) -> bool: ... + def confirmed_by_owner(self) -> bool: ... + def owner(self) -> WorkerInfo: ... + def owner_name(self) -> str: ... + def to_here(self, timeout: float = ...) -> _T: ... + def local_value(self) -> Any: ... + def rpc_sync(self, timeout: float = ...) -> Any: ... + def rpc_async(self, timeout: float = ...) -> Any: ... + def remote(self, timeout: float = ...) -> Any: ... + def _serialize(self) -> tuple: ... + @staticmethod + def _deserialize(tp: tuple) -> PyRRef: ... + def _get_type(self) -> type[_T]: ... + def _get_future(self) -> Future[_T]: ... + def _get_profiling_future(self) -> Future[_T]: ... + def _set_profiling_future(self, profilingFuture: Future[_T]): ... + +class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): + num_worker_threads: int + device_maps: dict[str, dict[torch.device, torch.device]] + devices: list[torch.device] + def __init__( + self, + num_worker_threads: int, + _transports: list | None, + _channels: list | None, + rpc_timeout: float = ..., + init_method: str = ..., + device_maps: dict[str, dict[torch.device, torch.device]] = {}, # noqa: B006 + devices: list[torch.device] = [], # noqa: B006 + ) -> None: ... + def _set_device_map( + self, + to: str, + device_map: dict[torch.device, torch.device], + ): ... + +class TensorPipeAgent(RpcAgent): + def __init__( + self, + store: Store, + name: str, + worker_id: int, + world_size: int | None, + opts: _TensorPipeRpcBackendOptionsBase, + reverse_device_maps: dict[str, dict[torch.device, torch.device]], + devices: list[torch.device], + ) -> None: ... + def join(self, shutdown: bool = False, timeout: float = 0): ... + def shutdown(self): ... + @overload + def get_worker_info(self) -> WorkerInfo: ... + @overload + def get_worker_info(self, workerName: str) -> WorkerInfo: ... + @overload + def get_worker_info(self, id: int) -> WorkerInfo: ... + def get_worker_infos(self) -> list[WorkerInfo]: ... + def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ... + def _update_group_membership( + self, + worker_info: WorkerInfo, + my_devices: list[torch.device], + reverse_device_map: dict[str, dict[torch.device, torch.device]], + is_join: bool, + ): ... + def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ... + @property + def is_static_group(self) -> bool: ... + @property + def store(self) -> Store: ... + +def _is_current_rpc_agent_set() -> bool: ... +def _get_current_rpc_agent() -> RpcAgent: ... +def _set_and_start_rpc_agent(agent: RpcAgent): ... +def _reset_current_rpc_agent(): ... +def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ... +def _destroy_rref_context(ignoreRRefLeak: bool): ... +def _rref_context_get_debug_info() -> dict[str, str]: ... +def _cleanup_python_rpc_handler(): ... +def _invoke_rpc_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any, +): ... +def _invoke_rpc_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: list[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... +def _invoke_rpc_torchscript( + dstWorkerName: str, + qualifiedNameStr: str, + argsTuple: tuple, + kwargsDict: dict, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... 
+def _invoke_remote_builtin( + dst: WorkerInfo, + opName: str, + rpcTimeoutSeconds: float, + *args: Any, + **kwargs: Any, +): ... +def _invoke_remote_python_udf( + dst: WorkerInfo, + pickledPythonUDF: str, + tensors: list[torch.Tensor], + rpcTimeoutSeconds: float, + isAsyncExecution: bool, +): ... +def _invoke_remote_torchscript( + dstWorkerName: WorkerInfo, + qualifiedNameStr: str, + rpcTimeoutSeconds: float, + isAsyncExecution: bool, + *args: Any, + **kwargs: Any, +): ... +def get_rpc_timeout() -> float: ... +def enable_gil_profiling(flag: bool): ... +def _set_rpc_timeout(rpcTimeoutSeconds: float): ... + +class RemoteProfilerManager: + @staticmethod + def set_current_profiling_key(key: str): ... + +def _enable_server_process_global_profiler(new_config: ProfilerConfig): ... +def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ... +def _set_profiler_node_id(default_node_id: int): ... +def _enable_jit_rref_pickle(): ... +def _disable_jit_rref_pickle(): ... diff --git a/phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi b/phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bd3520ce71a05e2dcec91f4940172fa08ae141c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_distributed_rpc_testing.pyi @@ -0,0 +1,32 @@ +import torch +from torch._C._distributed_c10d import Store +from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent + +# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp + +class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + def __init__( + self, + num_worker_threads: int, + rpc_timeout: float, + init_method: str, + messages_to_fail: list[str], + messages_to_delay: dict[str, float], + num_fail_sends: int, + ) -> None: ... + num_send_recv_threads: int + messages_to_fail: list[str] + messages_to_delay: dict[str, float] + num_fail_sends: int + +class FaultyTensorPipeAgent(TensorPipeAgent): + def __init__( + self, + store: Store, + name: str, + rank: int, + world_size: int, + options: FaultyTensorPipeRpcBackendOptions, + reverse_device_maps: dict[str, dict[torch.device, torch.device]], + devices: list[torch.device], + ) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi b/phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..2b60c3c147eba84e940a51ce988a6f76454e984f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_dynamo/__init__.pyi @@ -0,0 +1,4 @@ +from . import compiled_autograd, eval_frame, guards # noqa: F401 + +def strip_function_call(name: str) -> str: ... +def is_valid_var_name(name: str) -> bool | int: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi b/phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..648a5e61e7c6b47027245015d6f6ff9fd6e6fb2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_dynamo/compiled_autograd.pyi @@ -0,0 +1,13 @@ +from typing import Callable + +from torch import Tensor +from torch._dynamo.compiled_autograd import AutogradCompilerInstance + +def set_autograd_compiler( + autograd_compiler: Callable[[], AutogradCompilerInstance] | None, + dynamic: bool, +) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ... +def clear_cache() -> None: ... +def is_cache_empty() -> bool: ... 
+def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... +def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi b/phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi new file mode 100644 index 0000000000000000000000000000000000000000..55afcba2d834cacffe5c39554aad96ac72b19dc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_dynamo/eval_frame.pyi @@ -0,0 +1,71 @@ +import enum +import types +from typing import Optional, overload + +from torch._dynamo.types import ( + DynamoCallback, + DynamoGuardCompleteHook, + DynamoGuardHook, + GuardFn, +) + +def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... +def set_skip_guard_eval_unsafe(value: bool) -> bool: ... +def get_eval_frame_callback() -> DynamoCallback: ... +def reset_code(code: types.CodeType) -> None: ... +def unsupported(obj1: object, obj2: object) -> object: ... +def set_code_exec_strategy( + code: types.CodeType, strategy: _FrameExecStrategy +) -> None: ... +def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_guard_complete_hook( + hook: Optional[DynamoGuardCompleteHook], +) -> Optional[DynamoGuardCompleteHook]: ... +def raise_sigtrap() -> None: ... + +class _CacheEntry: + def check_fn(self, *args: object, **kwargs: object) -> bool: ... + code: types.CodeType + next: _CacheEntry | None + +class _ExtraState: + def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... + +class _FrameAction(enum.IntEnum): + DEFAULT = 0 + SKIP = 1 + RUN_ONLY = 2 + +class _FrameExecStrategy: + cur_action: _FrameAction + recursive_action: _FrameAction + + @overload + def __init__(self) -> None: ... + @overload + def __init__( + self, cur_action: _FrameAction, recursive_action: _FrameAction + ) -> None: ... + +# This is an object that encapsulates the Python FrameType, and exposes +# properties Dynamo cares about for a frame. +class _PyInterpreterFrame: + f_code: types.CodeType + f_locals: dict[str, object] + f_globals: dict[str, object] + f_builtins: dict[str, object] + f_lasti: int + f_lineo: int + f_back: types.FrameType + # A tuple containing cell objects captured by this frame. + closure: tuple[types.CellType] + +def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... + +py_opcode_caches: list[int] + +def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... +def _load_precompile_entry( + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi b/phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bd79bc17ace992f0117d2087c6c108536a523a64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_dynamo/guards.pyi @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable + +import torch + +class GlobalStateGuard: + def check(self) -> bool: ... + def reason(self) -> str: ... + +class LeafGuard: ... +class GuardDebugInfo: ... + +class GuardManager: + def check(self, value) -> bool: ... + def check_verbose(self, value) -> GuardDebugInfo: ... + + # Accessors + def globals_dict_manager( + self, + f_globals: dict[str, Any], + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... 
+ def framelocals_manager( + self, + key: tuple[str, int], + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def dict_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def global_weakref_manager( + self, + global_name: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def type_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def getattr_manager( + self, + attr: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_size_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_shape_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_storage_offset_manager( + self, + idx: None, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def indexed_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def lambda_manager( + self, + python_lambda, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + + # Leaf guards + def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ... + def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ... + def add_equals_match_guard( + self, + equals_val, + verbose_code_parts: list[str], + ) -> None: ... + def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... + def add_torch_function_mode_stack_guard( + self, initial_stack, verbose_code_parts: list[str] + ) -> None: ... + def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ... + +class RootGuardManager(GuardManager): + def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... + def add_epilogue_lambda_guard( + self, + guard: LeafGuard, + verbose_code_parts: list[str], + ) -> None: ... + def clone_manager( + self, clone_filter_fn: Callable[[GuardManager], bool] + ) -> RootGuardManager: ... + +class DictGuardManager(GuardManager): + def get_key_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def get_value_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + +def install_object_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +): ... +def install_no_tensor_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +): ... +def install_storage_overlapping_guard( + overlapping_guard_managers: list[GuardManager], + non_overlapping_guard_managers: list[GuardManager], + verbose_code_parts: list[str], +): ... +def install_symbolic_shape_guard( + guard_managers: list[GuardManager], + nargs_int: int, + nargs_float: int, + py_addr: int, + py_addr_keep_alive: Any, + verbose_code_parts: list[str], +): ... +def profile_guard_manager( + guard_manager: GuardManager, + f_locals: dict[str, Any], + n_iters: int, +) -> float: ... + +class TensorGuards: + def __init__( + self, + *, + dynamic_dims_sizes: list[torch.SymInt | None] | None = None, + dynamic_dims_strides: list[torch.SymInt | None] | None = None, + ) -> None: ... + def check(self, *args) -> bool: ... + def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ... 
+ +def assert_size_stride( + item: torch.Tensor, + size: torch.types._size, + stride: torch.types._size, + op_name: str | None = None, +): ... +def assert_alignment( + item: torch.Tensor, + alignment: int, + op_name: str | None = None, +): ... +def check_obj_id(obj: object, expected: int) -> bool: ... +def check_type_id(obj: object, expected: int) -> bool: ... +def dict_version(d: dict[Any, Any]) -> int: ... +def compute_overlapping_tensors( + tensors: list[torch.Tensor], symbolic: bool = True +) -> set[int]: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi b/phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b590ec0f7e482f873b7c9be6b26c5ac80d94fb18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_export/__init__.pyi @@ -0,0 +1,9 @@ +# Defined in torch/csrc/export/pybind.cpp +class CppExportedProgram: ... + +def deserialize_exported_program( + serialized_program: str, +) -> CppExportedProgram: ... +def serialize_exported_program( + cpp_exported_program: CppExportedProgram, +) -> str: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi b/phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3bff0768c9925b697ac65c8ac18289db46691917 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_export/pt2_archive_constants.pyi @@ -0,0 +1,22 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h + +ARCHIVE_ROOT_NAME: str = ... +ARCHIVE_FORMAT_PATH: str = ... +ARCHIVE_FORMAT_VALUE: str = ... +ARCHIVE_VERSION_PATH: str = ... +ARCHIVE_VERSION_VALUE: str = ... +MODELS_DIR: str = ... +MODELS_FILENAME_FORMAT: str = ... +AOTINDUCTOR_DIR: str = ... +MTIA_DIR: str = ... +WEIGHTS_DIR: str = ... +WEIGHT_FILENAME_PREFIX: str = ... +CONSTANTS_DIR: str = ... +TENSOR_CONSTANT_FILENAME_PREFIX: str = ... +CUSTOM_OBJ_FILENAME_PREFIX: str = ... +SAMPLE_INPUTS_DIR: str = ... +SAMPLE_INPUTS_FILENAME_FORMAT: str = ... +EXTRA_DIR: str = ... +MODULE_INFO_PATH: str = ... +XL_MODEL_WEIGHTS_DIR: str = ... +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ... diff --git a/phivenv/Lib/site-packages/torch/_C/_functions.pyi b/phivenv/Lib/site-packages/torch/_C/_functions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..4ffc1c43abc5a62d7be8969be29c25101faed2d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_functions.pyi @@ -0,0 +1,19 @@ +from typing import AnyStr, overload + +from torch import Tensor + +class UndefinedGrad: + def __init__(self) -> None: ... + def __call__(self, *inputs: Tensor) -> list[Tensor]: ... + +class DelayedError: + def __init__(self, msg: AnyStr, num_inputs: int) -> None: ... + + # __call__ should really be a higher-kinded type: + # def __call__(self, arg: Tensor) -> Tensor: ... + # def __call__(self, *args: Tensor * num_inputs) -> Tuple[Tensor * num_inputs]: ... + + @overload + def __call__(self, i0: Tensor) -> Tensor: ... + @overload + def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ... 
diff --git a/phivenv/Lib/site-packages/torch/_C/_functorch.pyi b/phivenv/Lib/site-packages/torch/_C/_functorch.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0b96eef2a8ab698824b7a31c22c754cd5eacfbb0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_functorch.pyi @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +from enum import Enum + +from torch import Tensor + +# Defined in torch/csrc/functorch/init.cpp + +def _set_dynamic_layer_keys_included(included: bool) -> None: ... +def get_unwrapped(tensor: Tensor) -> Tensor: ... +def is_batchedtensor(tensor: Tensor) -> bool: ... +def is_functionaltensor(tensor: Tensor) -> bool: ... +def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ... +def is_gradtrackingtensor(tensor: Tensor) -> bool: ... +def is_legacy_batchedtensor(tensor: Tensor) -> bool: ... +def maybe_get_bdim(tensor: Tensor) -> int: ... +def maybe_get_level(tensor: Tensor) -> int: ... +def maybe_current_level() -> int | None: ... +def unwrap_if_dead(tensor: Tensor) -> Tensor: ... +def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... +def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... +def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ... +def current_level() -> int: ... +def count_jvp_interpreters() -> int: ... +def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ... +def set_single_level_autograd_function_allowed(allowed: bool) -> None: ... +def get_single_level_autograd_function_allowed() -> bool: ... +def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ... +def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ... +def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ... +def _vmap_decrement_nesting() -> int: ... +def _grad_increment_nesting() -> int: ... +def _grad_decrement_nesting() -> int: ... +def _jvp_increment_nesting() -> int: ... +def _jvp_decrement_nesting() -> int: ... + +# Defined in aten/src/ATen/functorch/Interpreter.h +class TransformType(Enum): + Torch = ... + Vmap = ... + Grad = ... + Jvp = ... + Functionalize = ... + +class RandomnessType(Enum): + Error = ... + Same = ... + Different = ... + +class CInterpreter: + def key(self) -> TransformType: ... + def level(self) -> int: ... + def serialize(self) -> bytes: ... + @staticmethod + def deserialize(bytes) -> CInterpreter: ... + +class CGradInterpreterPtr: + def __init__(self, interpreter: CInterpreter) -> None: ... + def lift(self, Tensor) -> Tensor: ... + def prevGradMode(self) -> bool: ... + +class CJvpInterpreterPtr: + def __init__(self, interpreter: CInterpreter) -> None: ... + def lift(self, Tensor) -> Tensor: ... + def prevFwdGradMode(self) -> bool: ... + +class CFunctionalizeInterpreterPtr: + def __init__(self, interpreter: CInterpreter) -> None: ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def functionalizeAddBackViews(self) -> bool: ... + +class CVmapInterpreterPtr: + def __init__(self, interpreter: CInterpreter) -> None: ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def batchSize(self) -> int: ... + def randomness(self) -> RandomnessType: ... + +class DynamicLayer: ... + +def get_dynamic_layer_stack_depth() -> int: ... +def get_interpreter_stack() -> list[CInterpreter]: ... +def peek_interpreter_stack() -> CInterpreter: ... +def pop_dynamic_layer_stack() -> DynamicLayer: ... +def pop_dynamic_layer_stack_and_undo_to_depth(int) -> None: ... +def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ... 
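A hedged sketch of the batching helpers stubbed above used together with torch.func.vmap: inside a vmapped function the argument is expected to be a functorch BatchedTensor, so is_batchedtensor returns True and get_unwrapped exposes the underlying tensor, while a plain tensor outside vmap is not wrapped (behavior assumed from the declared signatures, not guaranteed across torch versions):

import torch
from torch._C._functorch import get_unwrapped, is_batchedtensor

def per_example_sum(x):
    # Inside vmap, x should be a BatchedTensor wrapper around the full batch.
    if is_batchedtensor(x):
        inner = get_unwrapped(x)
        print("wrapped tensor, underlying shape:", tuple(inner.shape))
    return x.sum()

batch = torch.randn(4, 3)
print("outside vmap:", is_batchedtensor(batch))   # False for an ordinary tensor
result = torch.func.vmap(per_example_sum)(batch)  # shape (4,), one sum per example
print(result.shape)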
diff --git a/phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi b/phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d46f7aa9639542b9e6b2cdf3cb9456aa4ec11a64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_instruction_counter.pyi @@ -0,0 +1,4 @@ +# Defined in torch/csrc/instruction_counter/Module.cpp + +def start() -> int: ... +def end(id: int) -> int: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_itt.pyi b/phivenv/Lib/site-packages/torch/_C/_itt.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a6f2559396fde84b318a768d3e6563ba6be93873 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_itt.pyi @@ -0,0 +1,5 @@ +# Defined in torch/csrc/itt.cpp +def is_available() -> None: ... +def rangePush(message: str) -> None: ... +def rangePop() -> None: ... +def mark(message: str) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_lazy.pyi b/phivenv/Lib/site-packages/torch/_C/_lazy.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5872a056255779c8a4a9571d65baef3731066b92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_lazy.pyi @@ -0,0 +1,26 @@ +from torch import Tensor + +# defined in torch/csrc/lazy/python/init.cpp +def _mark_step(device: str, devices: list[str], wait: bool) -> None: ... +def _wait_device_ops(devices: list[str]) -> None: ... +def _reset_metrics() -> None: ... +def _counter_names() -> list[str]: ... +def _counter_value(name: str) -> int: ... +def _metrics_report() -> str: ... +def _get_graph_hash(tensors: list[Tensor]) -> str: ... +def _sync_multi( + tensors: list[Tensor], + devices: list[str], + wait: bool = True, + sync_ltc_data: bool = True, +) -> None: ... +def _get_tensor_id(tensor: Tensor) -> int: ... +def _get_tensors_text(tensors: list[Tensor]) -> str: ... +def _get_tensors_dot(tensors: list[Tensor]) -> str: ... +def _get_tensors_backend(tensors: list[Tensor]) -> str: ... +def _get_force_fallback() -> str: ... +def _set_force_fallback(newval: str) -> None: ... +def _clear_ir_cache() -> None: ... +def _dump_ir_cache(filename: str) -> None: ... +def _set_reuse_ir(val: bool) -> None: ... +def _get_default_device_type() -> str: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi b/phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a7fc4991e4abd3b8cb554053e6486fe2f3c1ff83 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_lazy_ts_backend.pyi @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +# defined in torch/csrc/lazy/python/init.cpp + +from typing import Any + +from torch import Tensor + +def _init(): ... +def _get_tensors_ts_device_data_node( + tensors: list[Tensor], +) -> tuple[list[int], list[Any]]: ... +def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_monitor.pyi b/phivenv/Lib/site-packages/torch/_C/_monitor.pyi new file mode 100644 index 0000000000000000000000000000000000000000..298549aedd15cddbf57593ee92aba6a16a434e4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_monitor.pyi @@ -0,0 +1,58 @@ +# Defined in torch/csrc/monitor/python_init.cpp + +import datetime +from enum import Enum +from types import TracebackType +from typing import Callable + +class Aggregation(Enum): + VALUE = ... + MEAN = ... + COUNT = ... + SUM = ... + MAX = ... + MIN = ... 
+ +class Stat: + name: str + count: int + def __init__( + self, + name: str, + aggregations: list[Aggregation], + window_size: int, + max_samples: int = -1, + ) -> None: ... + def add(self, v: float) -> None: ... + def get(self) -> dict[Aggregation, float]: ... + +class Event: + name: str + timestamp: datetime.datetime + data: dict[str, int | float | bool | str] + def __init__( + self, + name: str, + timestamp: datetime.datetime, + data: dict[str, int | float | bool | str], + ) -> None: ... + +def log_event(e: Event) -> None: ... + +class EventHandlerHandle: ... + +def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ... +def unregister_event_handler(handle: EventHandlerHandle) -> None: ... + +class _WaitCounterTracker: + def __enter__(self) -> None: ... + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: ... + +class _WaitCounter: + def __init__(self, key: str) -> None: ... + def guard(self) -> _WaitCounterTracker: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_nn.pyi b/phivenv/Lib/site-packages/torch/_C/_nn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3aae1223b4d5f8d7984c71bc3c99bdc86f109b4a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_nn.pyi @@ -0,0 +1,175 @@ +# @generated by tools/pyi/gen_pyi.py from torch/_C/_nn.pyi.in +# mypy: disable-error-code="type-arg" + +from collections.abc import Sequence +from typing import Literal, overload + +from torch import memory_format, Tensor +from torch.types import _bool, _device, _dtype, _int, _size + +# Defined in tools/autograd/templates/python_nn_functions.cpp + +def adaptive_avg_pool2d(input: Tensor, output_size: _int | _size) -> Tensor: ... +def adaptive_avg_pool3d(input: Tensor, output_size: _int | _size) -> Tensor: ... +def adaptive_max_pool2d( + input: Tensor, + output_size: _int | _size, +) -> tuple[Tensor, Tensor]: ... +def adaptive_max_pool3d( + input: Tensor, + output_size: _int | _size, +) -> tuple[Tensor, Tensor]: ... +def avg_pool2d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: int | None = None, +) -> Tensor: ... +def avg_pool3d( + input: Tensor, + kernel_size: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: int | None = None, +) -> Tensor: ... +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Tensor | None = None, + reduction: str = ..., +) -> Tensor: ... +def col2im( + input: Tensor, + output_size: _int | _size, + kernel_size: _int | _size, + dilation: _int | _size, + stride: _int | _size | None = None, + padding: _int | _size = 0, +) -> Tensor: ... +def elu_(input: Tensor, alpha: float = ...) -> Tensor: ... +def fractional_max_pool2d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size, + _random_samples: Tensor, +) -> tuple[Tensor, Tensor]: ... +def fractional_max_pool3d( + input: Tensor, + kernel_size: _int | _size, + output_size: _int | _size, + _random_samples: Tensor, +) -> tuple[Tensor, Tensor]: ... +def gelu(input: Tensor, approximate: str = ...) -> Tensor: ... +def hardsigmoid(input: Tensor, *, out: Tensor | None = None) -> Tensor: ... 
+def hardtanh( + input: Tensor, + min_val: float = ..., + max_val: float = ..., + *, + out: Tensor | None = None, +) -> Tensor: ... +def hardtanh_( + input: Tensor, + min_val: float = ..., + max_val: float = ..., +) -> Tensor: ... +def leaky_relu( + input: Tensor, + negative_slope: float = ..., + *, + out: Tensor | None = None, +) -> Tensor: ... +def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ... +def linear( + input: Tensor, + weight: Tensor, + bias: Tensor | None = None, +) -> Tensor: ... +def log_sigmoid(input: Tensor) -> Tensor: ... +def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ... +def pad( + input: Tensor, + pad: Sequence[int], + mode: str = ..., + value: float | None = None, +) -> Tensor: ... +def scaled_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, +) -> Tensor: ... +def softplus( + input: Tensor, + beta: float = ..., + threshold: float = ..., +) -> Tensor: ... +def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ... + +# Defined in aten/src/ATen/native/mkldnn/Linear.cpp +def mkldnn_linear(input: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor: ... + +# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +def mkldnn_reorder_conv2d_weight( + self: Tensor, + padding: list, + stride: list, + dilatation: list, + groups: int, +) -> Tensor: ... +def mkldnn_reorder_conv3d_weight( + self: Tensor, + padding: list, + stride: list, + dilatation: list, + groups: int, +) -> Tensor: ... + +# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp +def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ... + +# Defined at tools/autograd/templates/python_nn_functions.cpp +@overload +def _parse_to( + device: _device, + dtype: _dtype, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to( + dtype: _dtype, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> tuple[_device, _dtype, _bool, memory_format]: ... +@overload +def _parse_to( + tensor: Tensor, + non_blocking: _bool, + copy: _bool, + *, + memory_format: memory_format, +) -> tuple[_device, _dtype, _bool, memory_format]: ... + +# Defined in aten/src/ATen/native/PackedSequence.cpp +def pad_sequence( + sequences: list[Tensor] | tuple[Tensor, ...], + batch_first: bool = False, + padding_value: float = 0.0, + padding_side: Literal["left", "right"] = "right", +) -> Tensor: ... +def flatten_dense_tensors(tensors: list[Tensor]) -> Tensor: ... +def unflatten_dense_tensors(flat: Tensor, tensors: list[Tensor]) -> list[Tensor]: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_nvtx.pyi b/phivenv/Lib/site-packages/torch/_C/_nvtx.pyi new file mode 100644 index 0000000000000000000000000000000000000000..655523e202054d13227fbd55c9c8aeee199c1056 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_nvtx.pyi @@ -0,0 +1,9 @@ +# mypy: allow-untyped-defs +# Defined in torch/csrc/cuda/shared/nvtx.cpp +def rangePushA(message: str) -> int: ... +def rangePop() -> int: ... +def rangeStartA(message: str) -> int: ... +def rangeEnd(int) -> None: ... +def markA(message: str) -> None: ... +def deviceRangeStart(message: str, stream: int) -> object: ... +def deviceRangeEnd(range_handle: object, stream: int) -> None: ... 
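The `torch._C._nn` entry points stubbed above are normally reached through the public `torch.nn.functional` wrappers rather than called directly; a minimal sketch exercising a few of them:

```python
# Minimal sketch: public torch.nn.functional wrappers over the _nn bindings
# stubbed above (scaled_dot_product_attention, one_hot, pad).
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 8, 16)            # (batch, heads, seq, head_dim)
attn = F.scaled_dot_product_attention(q, k, v)  # -> (2, 4, 8, 16)
onehot = F.one_hot(torch.tensor([0, 2, 1]), num_classes=3)
padded = F.pad(torch.randn(1, 3), (1, 1), mode="constant", value=0.0)
print(attn.shape, onehot.shape, padded.shape)
```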
diff --git a/phivenv/Lib/site-packages/torch/_C/_onnx.pyi b/phivenv/Lib/site-packages/torch/_C/_onnx.pyi new file mode 100644 index 0000000000000000000000000000000000000000..cf0f794429c52de08dbbe2b72188a3213d105d53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_onnx.pyi @@ -0,0 +1,39 @@ +# Defined in torch/csrc/onnx/init.cpp + +from enum import Enum + +PRODUCER_VERSION: str + +class TensorProtoDataType(Enum): + UNDEFINED = ... + FLOAT = ... + UINT8 = ... + INT8 = ... + UINT16 = ... + INT16 = ... + INT32 = ... + INT64 = ... + STRING = ... + BOOL = ... + FLOAT16 = ... + DOUBLE = ... + UINT32 = ... + UINT64 = ... + COMPLEX64 = ... + COMPLEX128 = ... + BFLOAT16 = ... + FLOAT8E5M2 = ... + FLOAT8E4M3FN = ... + FLOAT8E5M2FNUZ = ... + FLOAT8E4M3FNUZ = ... + +class OperatorExportTypes(Enum): + ONNX = ... + ONNX_ATEN = ... + ONNX_ATEN_FALLBACK = ... + ONNX_FALLTHROUGH = ... + +class TrainingMode(Enum): + EVAL = ... + PRESERVE = ... + TRAINING = ... diff --git a/phivenv/Lib/site-packages/torch/_C/_profiler.pyi b/phivenv/Lib/site-packages/torch/_C/_profiler.pyi new file mode 100644 index 0000000000000000000000000000000000000000..62ef0587dd08e80fadcc12922f61b98c20e108be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_profiler.pyi @@ -0,0 +1,246 @@ +from enum import Enum +from typing import Literal +from typing_extensions import TypeAlias + +from torch._C import device, dtype, layout + +# defined in torch/csrc/profiler/python/init.cpp + +class RecordScope(Enum): + FUNCTION = ... + BACKWARD_FUNCTION = ... + TORCHSCRIPT_FUNCTION = ... + KERNEL_FUNCTION_DTYPE = ... + CUSTOM_CLASS = ... + BUILD_FEATURE = ... + LITE_INTERPRETER = ... + USER_SCOPE = ... + STATIC_RUNTIME_OP = ... + STATIC_RUNTIME_MODEL = ... + +class ProfilerState(Enum): + Disable = ... + CPU = ... + CUDA = ... + NVTX = ... + ITT = ... + KINETO = ... + KINETO_GPU_FALLBACK = ... + KINETO_PRIVATEUSE1_FALLBACK = ... + KINETO_PRIVATEUSE1 = ... + +class ActiveProfilerType(Enum): + NONE = ... + LEGACY = ... + KINETO = ... + NVTX = ... + ITT = ... + +class ProfilerActivity(Enum): + CPU = ... + CUDA = ... + XPU = ... + MTIA = ... + HPU = ... + PrivateUse1 = ... + +class _EventType(Enum): + TorchOp = ... + Backend = ... + Allocation = ... + OutOfMemory = ... + PyCall = ... + PyCCall = ... + Kineto = ... + +class _ExperimentalConfig: + def __init__( + self, + profiler_metrics: list[str] = ..., + profiler_measure_per_kernel: bool = ..., + verbose: bool = ..., + performance_events: list[str] = ..., + enable_cuda_sync_events: bool = ..., + ) -> None: ... + +class ProfilerConfig: + def __init__( + self, + state: ProfilerState, + report_input_shapes: bool, + profile_memory: bool, + with_stack: bool, + with_flops: bool, + with_modules: bool, + experimental_config: _ExperimentalConfig, + trace_id: str | None = None, + ) -> None: ... 
+ +class _ProfilerEvent: + start_tid: int + start_time_ns: int + children: list[_ProfilerEvent] + + # TODO(robieta): remove in favor of `self.typed` + extra_fields: ( + _ExtraFields_TorchOp + | _ExtraFields_Backend + | _ExtraFields_Allocation + | _ExtraFields_OutOfMemory + | _ExtraFields_PyCall + | _ExtraFields_PyCCall + | _ExtraFields_Kineto + ) + + @property + def typed( + self, + ) -> ( + tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp] + | tuple[Literal[_EventType.Backend], _ExtraFields_Backend] + | tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation] + | tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory] + | tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall] + | tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall] + | tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto] + ): ... + @property + def name(self) -> str: ... + @property + def tag(self) -> _EventType: ... + @property + def id(self) -> int: ... + @property + def parent(self) -> _ProfilerEvent | None: ... + @property + def correlation_id(self) -> int: ... + @property + def end_time_ns(self) -> int: ... + @property + def duration_time_ns(self) -> int: ... + +class _TensorMetadata: + impl_ptr: int | None + storage_data_ptr: int | None + id: int | None + + @property + def allocation_id(self) -> int | None: ... + @property + def layout(self) -> layout: ... + @property + def device(self) -> device: ... + @property + def dtype(self) -> dtype: ... + @property + def sizes(self) -> list[int]: ... + @property + def strides(self) -> list[int]: ... + +Scalar: TypeAlias = int | float | bool | complex +Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None + +class _ExtraFields_TorchOp: + name: str + sequence_number: int + allow_tf32_cublas: bool + + @property + def inputs(self) -> list[Input]: ... + @property + def scope(self) -> RecordScope: ... + +class _ExtraFields_Backend: ... + +class _ExtraFields_Allocation: + ptr: int + id: int | None + alloc_size: int + total_allocated: int + total_reserved: int + + @property + def allocation_id(self) -> int | None: ... + @property + def device(self) -> device: ... + +class _ExtraFields_OutOfMemory: ... + +class _PyFrameState: + line_number: int + function_name: str + + @property + def file_name(self) -> str: ... + +class _NNModuleInfo: + @property + def self_ptr(self) -> int: ... + @property + def cls_ptr(self) -> int: ... + @property + def cls_name(self) -> str: ... + @property + def parameters( + self, + ) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ... + +class _OptimizerInfo: + @property + def parameters( + self, + ) -> list[ + tuple[ + # Parameter + _TensorMetadata, + # + # Gradient (if present during optimizer.step()) + _TensorMetadata | None, + # + # Optimizer state for Parameter as (name, tensor) pairs + list[tuple[str, _TensorMetadata]], + ] + ]: ... + +class _ExtraFields_PyCCall: + @property + def caller(self) -> _PyFrameState: ... + +class _ExtraFields_PyCall: + @property + def callsite(self) -> _PyFrameState: ... + @property + def caller(self) -> _PyFrameState: ... + @property + def module(self) -> _NNModuleInfo | None: ... + @property + def optimizer(self) -> _OptimizerInfo | None: ... + +class _ExtraFields_Kineto: ... + +def _add_execution_trace_observer(output_file_path: str) -> bool: ... +def _remove_execution_trace_observer() -> None: ... +def _enable_execution_trace_observer() -> None: ... +def _disable_execution_trace_observer() -> None: ... 
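The profiler configuration types stubbed above (`ProfilerActivity`, `ProfilerConfig`, the `_ExtraFields_*` event payloads) are driven through the public `torch.profiler` API; a minimal CPU-only sketch, offered only as orientation for how these stubs are exercised:

```python
# Minimal sketch: collecting a CPU trace via the public torch.profiler API,
# which configures the ProfilerActivity/ProfilerConfig machinery stubbed above.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```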
+def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ... +def _set_fwd_bwd_enabled_val(val: bool) -> None: ... +def _set_cuda_sync_enabled_val(val: bool) -> None: ... + +class CapturedTraceback: ... + +def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ... + +# The Dict has name, filename, line +def symbolize_tracebacks( + to_symbolize: list[CapturedTraceback], +) -> list[list[dict[str, str]]]: ... + +class _RecordFunctionFast: + def __init__( + self, + name: str, + input_values: list | tuple | None = None, + keyword_values: dict | None = None, + ) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, *exc_info: object) -> None: ... diff --git a/phivenv/Lib/site-packages/torch/_C/_verbose.pyi b/phivenv/Lib/site-packages/torch/_C/_verbose.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6d1dbfda288978aa1680412ad24bf488160ba854 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C/_verbose.pyi @@ -0,0 +1,3 @@ +# Defined in torch/csrc/utils/verbose.cpp +def mkl_set_verbose(enable: int) -> int: ... +def mkldnn_set_verbose(level: int) -> int: ... diff --git a/phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi b/phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6ff722fbf70721517fa3a68aa5c292eb5cc787e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_C_flatbuffer/__init__.pyi @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from torch._C import LiteScriptModule, ScriptModule + +def _load_mobile_module_from_file(filename: str): ... +def _load_mobile_module_from_bytes(bytes_: bytes): ... +def _load_jit_module_from_file(filename: str): ... +def _load_jit_module_from_bytes(bytes_: bytes): ... +def _save_mobile_module(m: LiteScriptModule, filename: str): ... +def _save_jit_module(m: ScriptModule, filename: str): ... +def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ... +def _save_jit_module_to_bytes(m: ScriptModule) -> bytes: ... diff --git a/phivenv/Lib/site-packages/torch/_awaits/__init__.py b/phivenv/Lib/site-packages/torch/_awaits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0803f93467bf78e4092478e7823a0e31daf840f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_awaits/__init__.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +import torch + +__all__ = ['Await'] + +W = TypeVar("W") + +class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef] + pass + +class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta): + r""" + Wrapper around a ``torch._C.Await`` which encapsulates delayed execution + of a callable. All manipulations happen with functions ``torch.jit._awaitable``, + ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``. + + Torch scriptable manipulations: + ``torch.jit._awaitable(func, *args)`` + Creates ``Await[W]`` object, where W is return type of func. + + Returns: + ``torch.jit._awaitable_wait(Await[W])`` + Returns the result of the function, specified at ``_awaitable``, with specified arguments. + + Returns: + The result of type ``W`` of the function call. The result is owned by ``Await[W]`` + and returned on all following ``_awaitable_wait`` calls. + + + ``torch.jit._awaitable_nowait(W)`` + Returns: + Trivial ``Await[W]`` with specified result. 
+ + + Only in eager mode: + ``fn() -> Callable[Tuple[Any], W]`` + Returns: + Specified at ``_awaitable`` python function ``func``. + + ``args() -> Tuple[Any]`` + Returns: + Specified at ``_awaitable`` python args. + + ``is_nowait() -> _bool`` + Returns: + ``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`). + + In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``, + ``_awaitable_wait()`` call will be transparently added. + """ diff --git a/phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93cadd0bdad02ed7500384b9c0f1022af035297b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_awaits/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_custom_op/__init__.py b/phivenv/Lib/site-packages/torch/_custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a10f171a48441e6d4f1c5bd8d2cbb5faff9dc5f6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eee697643229c5c469aa869e8e2935d1a070badf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/autograd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09a478be646df3d5a1bf9116901aa787098ec2a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_custom_op/__pycache__/impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_custom_op/autograd.py b/phivenv/Lib/site-packages/torch/_custom_op/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..073c3f712a98db12cc104bcb39c148926349725e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_custom_op/autograd.py @@ -0,0 +1,307 @@ +# mypy: allow-untyped-defs +import functools +from collections import namedtuple + +import torch +import torch.utils._pytree as pytree + + +# NOTE [CustomOp autograd kernel indirection] +# We register `inner` as the autograd kernel for this custom_op. +# `inner` either calls the autograd formula registered by the user, +# or goes into an `autograd_not_implemented` kernel. +# +# The reason why this indirection exists is +# so that we can swap out the autograd kernel (the PyTorch dispatcher +# doesn't actually allow us to do this). 
By default, we want +# the `autograd_not_implemented` behavior, but then the user may come +# and register something that is actually a backward formula +def autograd_kernel_indirection(custom_op): + autograd_fallback = autograd_not_implemented(custom_op) + + def inner(*args, **kwargs): + if custom_op._has_impl("autograd"): + kernel = custom_op._get_impl("autograd").func + return kernel(*args, **kwargs) + # As explained in NOTE ["backward", "save_for_backward", and "autograd"], + # after the user gives us "backward" and "save_for_backward", we generate + # the "autograd" impl. If the user only provided one, then we tell + # the user they've done something wrong. + if custom_op._has_impl("save_for_backward") or custom_op._has_impl("backward"): + missing = ( + "save_for_backward" if custom_op._has_impl("backward") else "backward" + ) + found = "save_for_backward" if missing == "backward" else "backward" + loc = custom_op._get_impl(found).location + raise RuntimeError( + f"We found a '{found}' registration for {custom_op} at " + f"{loc} but were unable to find a '{missing}' registration. " + f"To use the CustomOp API to register a backward formula, " + f"please provide us both a backward function and a " + f"'save for backward' function via `impl_backward` and " + f"`impl_save_for_backward` respectively." + ) + return autograd_fallback(*args, **kwargs) + + return inner + + +# TODO(#101191): Use the actual C++ autograd not implemented fallback, +# or change the default autograd fallback to the autograd not implemented fallback. +def autograd_not_implemented(custom_op): + def kernel(*args, **kwargs): + if torch.is_grad_enabled() and pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) + ): + raise RuntimeError("Autograd has not been implemented for operator") + with torch._C._AutoDispatchBelowAutograd(): + return custom_op(*args, **kwargs) + + return kernel + + +def mark_non_differentiable(ctx, output, output_differentiability): + # Output types are restricted to be: + # - Tensor + # - Tensor[] + # - int, bool, Scalar, float + # See _check_can_register_backward + if output_differentiability is not None: + if not isinstance(output, tuple): + tuple_output = (output,) + else: + tuple_output = output # type: ignore[assignment] + assert len(output_differentiability) == len(tuple_output) + non_differentiable_tensors = [] + for idx, (differentiable, out) in enumerate( + zip(output_differentiability, tuple_output) + ): + if isinstance(out, torch.Tensor): + if not differentiable: + non_differentiable_tensors.append(out) + continue + if isinstance(out, list): + if not differentiable: + non_differentiable_tensors.extend(out) + continue + if differentiable: + raise RuntimeError( + f"With output_differentiability={output_differentiability}. " + f"At idx {idx}, we received an object of type {type(out)} that " + f"is not a Tensor, so it cannot have be marked as differentiable in " + f"output_differentiability." 
+ ) + if non_differentiable_tensors: + ctx.mark_non_differentiable(*non_differentiable_tensors) + + +def construct_autograd_kernel( + schema, + output_differentiability, + custom_op, + op_overload, + save_for_backward_fn, + backward_fn, +): + def apply(*args): + flat_args, spec = pytree.tree_flatten(args) + out_spec = None + + def forward(ctx, *flat_args): + ctx.set_materialize_grads(True) + args = pytree.tree_unflatten(list(flat_args), spec) + with torch._C._AutoDispatchBelowAutograd(): + output = op_overload(*args) + + # We use the info about args to give better error messages in backward + args_info = namedtuple_args(schema, pytree.tree_map(type, args)) + + save_for_backward_fn_inputs = namedtuple_args(schema, args) + to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) + + save_pytree_for_backward(ctx, (to_save, args_info)) + mark_non_differentiable(ctx, output, output_differentiability) + + nonlocal out_spec + flat_output, out_spec = pytree.tree_flatten(output) + return tuple(flat_output) + + def backward(ctx, *flat_grad_output): + assert out_spec is not None + grads = pytree.tree_unflatten(list(flat_grad_output), out_spec) + saved, args_info = unpack_saved(ctx) + # There is nothing on the ctx object for now, it is just there so + # that we can add additional things in the future. + inner_ctx = object() + if not isinstance(grads, tuple): + grads = (grads,) + grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) + + # Massage the grad_inputs_dict to a form acceptable by + # autograd.Function. + validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info) + return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) + + generated_cls = gen_autograd_function( + custom_op._opname + "_customop", forward, backward + ) + + flat_output = generated_cls.apply(*flat_args) + assert out_spec is not None + return pytree.tree_unflatten(list(flat_output), out_spec) + + return apply + + +def gen_autograd_function(name, forward, backward): + generated_cls = type( + name, + (torch.autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, + ) + return generated_cls + + +@functools.lru_cache +def namedtuple_args_cls(schema): + attribs = [arg.name for arg in schema.arguments.flat_all] + name = str(schema.name) + "_args" + # mypy doesn't support dynamic namedtuple name + tuple_cls = namedtuple(name, attribs) # type: ignore[misc] + return tuple_cls + + +def namedtuple_args(schema, args): + assert isinstance(args, tuple) + tuple_cls = namedtuple_args_cls(schema) + return tuple_cls(*args) + + +def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): + def error(what): + backward = forward_op._get_impl("backward") + raise RuntimeError( + f"In the backward function defined for {forward_op} at " + f"{backward.location} using the CustomOp API, {what}" + ) + + if not isinstance(grad_inputs_dict, dict): + error( + f"expected the output of the backward function to be a dict but " + f"got {type(grad_inputs_dict)}" + ) + + expected_keys = { + arg.name + for arg in forward_op._schema.arguments.flat_all + if arg.type.is_tensor_like() + } + actual_keys = grad_inputs_dict.keys() + if expected_keys != actual_keys: + error( + f"expected the returned grad_input dict to have keys " + f"{expected_keys} but got {actual_keys}. The backward " + f"function must return a gradient (can be None) for each arg " + f"to the CustomOp that may be a Tensor or Sequence[Tensor]. 
" + f"Args declared to be non-Tensor-like types should not appear " + f"in the grad_input dict" + ) + + for name, grad in grad_inputs_dict.items(): + arg_info = getattr(args_info, name) + + if isinstance(arg_info, list): + if not isinstance(grad, (tuple, list)): + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of gradients but got object of type " + f"{type(grad)}." + ) + if not len(grad) == len(arg_info): + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of {len(arg_info)} gradients but got " + f"{len(grad)}" + ) + for idx, (g, info) in enumerate(zip(grad, arg_info)): + if g is None: + continue + if not isinstance(g, torch.Tensor): + error( + f"for input '{name}' expected the grad_input dict to " + f"hold a list of None or Tensor gradients but got " + f"object of {type(g)} at index {idx}" + ) + if not issubclass(info, torch.Tensor): + error( + f"for input '{name}', got a Tensor as the gradient " + f"for the {idx}-th value but expected None because " + f"the {idx}-th value was not a Tensor (it was " + f"type {arg_info}" + ) + continue + + if grad is None: + continue + if not isinstance(grad, torch.Tensor): + error( + f"got object of type {type(grad)} as the gradient for input " + f"'{name}', " + f"but expected the gradient to be either None or a Tensor" + ) + if not issubclass(arg_info, torch.Tensor): + error( + f"got a Tensor as the gradient for input '{name}' but " + f"expected None as the gradient because input '{name}' " + f"was not a Tensor (it was type {arg_info})." + ) + + +def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): + result = [] + for name, arg_info in args_info._asdict().items(): + if name not in grad_inputs_dict: + result.append(pytree.tree_map(lambda x: None, arg_info)) + continue + result.append(grad_inputs_dict[name]) + return tuple(pytree.tree_leaves(result)) + + +# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. +# autograd.Function prefers that users use ctx.save_for_backward to +# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the +# ctx object. 
+def save_pytree_for_backward(ctx, stuff): + flat_stuff, spec = pytree.tree_flatten(stuff) + num_elts = len(flat_stuff) + tensor_idxs = [ + idx for idx, thing in enumerate(flat_stuff) if isinstance(thing, torch.Tensor) + ] + non_tensor_idxs = [ + idx + for idx, thing in enumerate(flat_stuff) + if not isinstance(thing, torch.Tensor) + ] + tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] + non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] + + ctx.spec = spec + ctx.num_elts = num_elts + ctx.save_for_backward(*tensors) + ctx.tensor_idxs = tensor_idxs + ctx.saved_non_tensors = non_tensors + ctx.non_tensor_idxs = non_tensor_idxs + + +# Inverse operation to save_pytree_for_backward +def unpack_saved(ctx): + flat_stuff = [None] * ctx.num_elts + for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): + flat_stuff[idx] = tensor + for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): + flat_stuff[idx] = non_tensor + stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) + return stuff diff --git a/phivenv/Lib/site-packages/torch/_custom_op/impl.py b/phivenv/Lib/site-packages/torch/_custom_op/impl.py new file mode 100644 index 0000000000000000000000000000000000000000..4b35931b6eeafac0147833381332dd3252f6c9ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_custom_op/impl.py @@ -0,0 +1,715 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import inspect +import sys +import typing +import warnings +import weakref + +import torch +import torch._C as _C +import torch._library.infer_schema +import torch.library as library +from torch._library.infer_schema import infer_schema +from torch.library import get_ctx +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + SchemaKind, +) + +from .autograd import autograd_kernel_indirection, construct_autograd_kernel + + +""" +torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library. +Please use those APIs instead. +""" + +__all__ = ["custom_op", "CustomOp", "get_ctx"] + + +SUPPORTED_DEVICE_TYPE_TO_KEY = { + "cpu": "CPU", + "cuda": "CUDA", +} + +# We will not let users register CustomOps with anything that could look like +# PyTorch internals to avoid confusion. +RESERVED_NS = { + "prim", + "prims", + "aten", + "at", + "torch", + "pytorch", +} + + +def warn_deprecated(): + warnings.warn( + "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please " + "use the equivalent torch.library API instead.", + DeprecationWarning, + ) + + +def custom_op( + qualname: str, manual_schema: typing.Optional[str] = None +) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + warn_deprecated() + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + ns, name = parse_qualname(qualname) + validate_namespace(ns) + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. 
" + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = ( + infer_schema(func, mutates_args=()) + if manual_schema is None + else manual_schema + ) + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + if manual_schema is not None: + validate_function_matches_schema(function_schema, func) + + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp( + lib, ns, function_schema, name, ophandle, _private_access=True + ) + + result.__name__ = func.__name__ + result.__module__ = func.__module__ + result.__doc__ = func.__doc__ + + library.impl(lib, result._opname, "Autograd")( + autograd_kernel_indirection(weakref.proxy(result)) + ) + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + + return result + + return inner + + +# Global dictionary holding references to all CustomOp objects +# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime]) +# Used to query the CustomOp associated with a specific C++ dispatcher operator. +# An example usage is FakeTensor: FakeTensor checks if a specific operator +# has an implementation registered via the CustomOp API. +# Indexed by qualname (e.g. aten::foo) +global_registry: dict[str, "CustomOp"] = {} + + +class CustomOp: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + + def __init__( + self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False + ): + super().__init__() + warn_deprecated() + if not _private_access: + raise RuntimeError( + "The CustomOp constructor is private and we do not guarantee " + "BC for it. Please use custom_op(...) to create a CustomOp object" + ) + name = f"{cpp_ns}::{operator_name}" + self._schema = schema + self._cpp_ns = cpp_ns + self._lib: library.Library = lib + self._ophandle: _C._DispatchOperatorHandle = ophandle + # Has the name of the op, e.g. "foo". We cache here for convenience. + self._opname: str = operator_name + # this is _opname but with namespace. e.g. "custom::foo" + self._qualname: str = name + self.__name__ = None # mypy requires this + # NB: Some of these impls are registered as kernels to DispatchKeys. + # Modifying the _impls dict directly won't do anything in that case. + self._impls: dict[str, typing.Optional[FuncAndLocation]] = {} + # See NOTE [CustomOp autograd kernel indirection] + self._registered_autograd_kernel_indirection = False + + global_registry[self._qualname] = self + + def _register_autograd_kernel_indirection(self): + assert not self._registered_autograd_kernel_indirection + self._lib.impl( + self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd" + ) + self._registered_autograd_kernel_indirection = True + + # Records the impl and the source location in self._impls + # Note that this doesn't cause torch.library to use the impl, that + # needs to be done in a separate self._lib.impl call. + def _register_impl(self, kind, func, stacklevel=2): + if self._has_impl(kind): + func_and_location = self._impls[kind] + assert func_and_location is not None # Pacify mypy + location = func_and_location.location + raise RuntimeError( + f"Attempting to register a {kind} impl for operator {self._qualname} " + f"that already has a {kind} impl registered from Python at " + f"{location}. This is not supported." 
+ ) + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + location = f"{frame.filename}:{frame.lineno}" + self._impls[kind] = FuncAndLocation(func, location) + + def _get_impl(self, kind): + return self._impls[kind] + + def _has_impl(self, kind): + return kind in self._impls + + def _destroy(self): + # NOTE: [CustomOp lifetime] + # A CustomOp, once created, lives forever. The mechanism is that the + # global registry holds a reference to it. However, to make testing + # easier, we want to be able to destroy CustomOp objects. + # CustomOp._destroy does the job, though it leaves the CustomOp + # in a garbage state. + del self._lib + + opnamespace = getattr(torch.ops, self._cpp_ns) + if hasattr(opnamespace, self._opname): + delattr(opnamespace, self._opname) + + del global_registry[self._qualname] + + def __repr__(self): + return f'' + + def __call__(self, *args, **kwargs): + # Bypass torch.ops.* and directly do OperatorHandle::callBoxed. + # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime + # issues from caching operators that make testing CustomOp difficult). + result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs) + return result + + def impl( + self, + device_types: typing.Union[str, typing.Iterable[str]], + _stacklevel=2, + ) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + if isinstance(device_types, str): + device_types = [device_types] + for device_type in device_types: + validate_device_type(device_type) + + def inner(f): + for device_type in set(device_types): + self._check_doesnt_have_library_impl(device_type) + self._register_impl(device_type, f, stacklevel=_stacklevel) + dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + library.impl(self._lib, self._opname, dispatch_key)(f) + return f + + return inner + + def _check_doesnt_have_library_impl(self, device_type): + if self._has_impl(device_type): + return + key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl(..., device_types={device_type}): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing torch.library or TORCH_LIBRARY registration." + ) + + def impl_factory(self) -> typing.Callable: + r"""Register an implementation for a factory function.""" + + def inner(f): + self._register_impl("factory", f) + library.impl(self._lib, self._opname, "BackendSelect")(f) + return f + + return inner + + def impl_abstract(self, _stacklevel=2) -> typing.Callable: + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + + def inner(f): + self._check_doesnt_have_library_meta_impl() + self._register_impl("abstract", f, stacklevel=_stacklevel) + location = self._get_impl("abstract").location + + qualname = self._qualname + + # Handle DispatchKey.Meta registration + @functools.wraps(f) + def f_with_ctx(*args, **kwargs): + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname}." + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior. 
Otherwise, please remove the call to get_ctx() " + f"in the implementation registered with impl_abstract " + f"at {location}" + ) + + with torch._library.fake_impl.set_ctx_getter(error_on_ctx): + return f(*args, **kwargs) + + self._lib.impl(self._opname, f_with_ctx, "Meta") + return f + + return inner + + def _check_can_register_backward(self): + def error(detail): + raise RuntimeError( + f"Cannot use torch._custom_ops APIs to register backward " + f"formula for {detail}. Got operator " + f"{self._qualname} with schema: {schema}" + ) + + schema = self._schema + if schema.kind() != SchemaKind.functional: + error("non-functional operator") + + rets = schema.returns + if not schema.returns: + error("operator with no returns") + + assert len(rets) > 0 + is_non_mutating_view = any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + error("operator that returns views") + + # We make assumptions about the schema's return types. + allowed_return_types = { + BaseType(BaseTy.int): "int", + BaseType(BaseTy.SymInt): "SymInt", + BaseType(BaseTy.bool): "bool", + BaseType(BaseTy.float): "float", + BaseType(BaseTy.Tensor): "Tensor", + ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]", + } + for ret in schema.returns: + if ret.type in allowed_return_types: + continue + error( + f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})" + ) + + def _check_doesnt_have_library_autograd_impl(self): + if self._registered_autograd_kernel_indirection: + return + + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an autograd formula; " + f"instead, the operator will decompose into its constituents and those " + f"can have autograd formulas defined on them." + ) + + # We can improve this by adding "all Autograd keys", but + # realistically people will just be using this API for CPU/CUDA for now. + for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]: + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: " + f"the operator {self._qualname} already has an Autograd kernel " + f"registered to DispatchKey::{key} vi a pre-existing " + f"torch.library or TORCH_LIBRARY registration. Please either " + f"remove those registrations or don't use the torch._custom_ops APIs" + ) + + def _check_doesnt_have_library_meta_impl(self): + if self._has_impl("abstract"): + return + + # If the user's operator is CompositeExplicitAutograd, + # allow them to impl_abstract. This is being pragmatic + # (existing custom ops may have CompositeExplicitAutograd + # registration that don't work with Meta kernels, so this + # gives them an escape hatch). + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeExplicitAutograd" + ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): + return + + # Otherwise, if the user's already has a Meta kernel or their + # op is CompositeImplicitAutograd or some other alias dispatch key, + # raise. 
+ + # Special case for CompositeImplicitAutograd + if _C._dispatch_has_kernel_for_dispatch_key( + self._qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an abstract impl; " + f"instead, the operator will decompose into its constituents and those " + f"can have abstract impls defined on them." + ) + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call impl_abstract." + ) + + # NOTE ["backward", "save_for_backward", and "autograd"] + # As a part of the explicit autograd API, a user must provide us + # a "save_for_backward" function and a "backward" function. + # When both of these have been provided, then we automatically + # construct the "autograd" kernel. + def _register_autograd_kernel(self): + assert self._has_impl("backward") + assert self._has_impl("save_for_backward") + kernel = construct_autograd_kernel( + self._schema, + self._output_differentiability, + self, + get_op(self._qualname), + self._get_impl("save_for_backward").func, + self._get_impl("backward").func, + ) + self._register_impl("autograd", kernel) + + def impl_save_for_backward(self, _stacklevel=2): + r"""Register a function that tells us what to save for backward. + + Please see impl_backward for more details. + """ + + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("save_for_backward", f, stacklevel=_stacklevel) + if self._has_impl("backward"): + self._register_autograd_kernel() + + return inner + + def impl_backward(self, output_differentiability=None, _stacklevel=2): + r""" + This API is deprecated, please use torch.library.custom_op instead + """ + if output_differentiability is not None: + + def yell(): + raise RuntimeError( + f"impl_backward(output_differentiability): expected " + f"output_differentiability to be a list of bools with " + f"length equal to the number of outputs of this CustomOp " + f"got: {output_differentiability}" + ) + + if not isinstance(output_differentiability, list): + yell() + for diff in output_differentiability: + if not isinstance(diff, bool): + yell() + if len(self._schema.returns) != len(output_differentiability): + yell() + + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("backward", f, stacklevel=_stacklevel) + self._output_differentiability = output_differentiability + if self._has_impl("save_for_backward"): + self._register_autograd_kernel() + + return inner + + +@dataclasses.dataclass +class FuncAndLocation: + func: typing.Callable + location: str + + +def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName): + overload_name = ( + "" if operator_name.overload_name is None else operator_name.overload_name + ) + return _C._dispatch_find_schema_or_throw( + 
f"{cpp_ns}::{str(operator_name.name)}", overload_name + ) + + +def validate_namespace(ns: str) -> None: + if "." in ns: + raise ValueError( + f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a ' + f"valid variable name)" + ) + if ns in RESERVED_NS: + raise ValueError( + f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, " + f"please choose something else. " + ) + + +def validate_schema(schema: FunctionSchema) -> None: + if not torch._library.utils.is_functional_schema(schema): + raise ValueError( + f"custom_op only supports functional operators " + f"(ops that do not mutate any inputs, do not return " + f"views of the inputs, and has at least one return). " + f"Got the following non-functional schema: {schema}" + ) + + # For simplicity: don't allow self arguments + if schema.arguments.self_arg is not None: + raise ValueError( + f"custom_op does not support arguments named 'self'. Please " + f"rename your argument. Got: {schema}" + ) + + +def parse_qualname(qualname: str) -> tuple[str, str]: + names = qualname.split("::", 1) + if len(names) != 2: + raise ValueError( + f"Expected there to be a namespace in {qualname}, i.e. The " + f"operator name should look something like ns::foo" + ) + if "." in names[1]: + raise ValueError( + f"The torch.custom_ops APIs do not handle overloads, " + f"i.e. operator names with '.' in them. " + f"Please name your operator something like ns::foo. " + f"Got: {qualname}" + ) + return names[0], names[1] + + +def validate_device_type(device_type: str) -> None: + if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY: + raise ValueError( + f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type " + f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}." + ) + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def validate_function_matches_schema( + schema: FunctionSchema, func: typing.Callable +) -> None: + sig = inspect.signature(func) + + if not all(supported_param(p) for _, p in sig.parameters.items()): + raise ValueError( + f"custom_op(..., manual_schema)(func): positional-only args, " + f"varargs, and kwargs are not supported. Please rewrite `func` " + f"to not have them. Got `func` with signature: {sig}" + ) + + if ( + any( + p.annotation is not inspect.Parameter.empty + for _, p in sig.parameters.items() + ) + or sig.return_annotation is not inspect.Signature.empty + ): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func` to have no type annotations to avoid " + f"ambiguity. Got `func` with signature: {sig}" + ) + + positional = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwargonly = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + def error(): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func`'s signature to match `manual_schema` " + f"(aside from type annotations). " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def error_default_args(): + raise ValueError( + f"custom_op(..., manual_schema)(func): " + f"neither func nor manual_schema should have default " + f"arguments. 
Got " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def compare(sig_args, schema_args): + if len(sig_args) != len(schema_args): + error() + for (name, param), arg in zip(sig_args, schema_args): + if name != arg.name: + error() + if param.default is not inspect.Parameter.empty or arg.default is not None: + error_default_args() + + compare(positional, schema.arguments.flat_positional) + compare(kwargonly, schema.arguments.flat_kwarg_only) + + +def report_error_callback(custom_op: typing.Any, key: str) -> None: + if key == "Undefined": + raise NotImplementedError( + f"{custom_op}: There were no Tensor inputs to this operator " + f"(e.g. you passed an empty list of Tensors). If your operator is a " + f"factory function (that is, it takes no Tensors and constructs " + f"a new one), then please use CustomOp.impl_factory to register " + f"an implementation for it" + ) + if key == "Meta": + raise NotImplementedError( + f"{custom_op}: when running with device='Meta' tensors: there is no " + f"abstract impl registered for this CustomOp. Please register one via " + f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors" + ) + if key in ("CPU", "CUDA"): + device = key.lower() + raise NotImplementedError( + f"{custom_op}: when running with device='{device}' tensors: there is no " + f"{device} impl registered for this CustomOp. Please register one via " + f"CustomOp.impl(device_type='{device}')" + ) + raise NotImplementedError( + f"{custom_op}: No implementation for dispatch key {key}. It is likely " + f"that we have not added this functionality yet, please either open an " + f"issue or if you're feeling adventurous, use the low-level " + f"torch.library API" + ) + + +def custom_op_from_existing(op): + ns = op.namespace + lib = torch.library.Library(ns, "FRAGMENT") + name = op.name().split("::")[-1] + schema_str = str(op._schema) + # CustomOp expects the schema string without the namespace + schema_str = schema_str.split("::")[-1] + schema = FunctionSchema.parse(schema_str) + return CustomOp(lib, ns, schema, name, op, _private_access=True) + + +def get_op(qualname): + def error_not_found(): + raise ValueError( + f"Could not find the operator {qualname}. Please make sure you have " + f"already registered the operator and (if registered from C++) " + f"loaded it via torch.ops.load_library." + ) + + ns, name = parse_qualname(qualname) + if not hasattr(torch.ops, ns): + error_not_found() + opnamespace = getattr(torch.ops, ns) + if not hasattr(opnamespace, name): + error_not_found() + packet = getattr(opnamespace, name) + if not hasattr(packet, "default"): + error_not_found() + return packet.default + + +def _find_custom_op(qualname, also_check_torch_library=False): + if qualname in global_registry: + return global_registry[qualname] + if not also_check_torch_library: + raise RuntimeError( + f'Could not find custom op "{qualname}". Did you register it via ' + f"the torch._custom_ops API?" 
+ ) + overload = get_op(qualname) + result = custom_op_from_existing(overload) + return result + + +def get_abstract_impl(qualname): + if qualname not in torch._custom_op.impl.global_registry: + return None + custom_op = torch._custom_op.impl.global_registry[qualname] + if custom_op is None: + return None + if not custom_op._has_impl("abstract"): + return None + return custom_op._get_impl("abstract").func + + +def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True): + ns, name = qualname.split("::") + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else [] + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str, tags=tags) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + result._register_autograd_kernel_indirection() + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + return get_op(qualname) diff --git a/phivenv/Lib/site-packages/torch/_decomp/__init__.py b/phivenv/Lib/site-packages/torch/_decomp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d66f64707b907b0bfdc1b2d41449e2f6b9af8bf6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_decomp/__init__.py @@ -0,0 +1,544 @@ +# mypy: allow-untyped-defs +import inspect +from collections import defaultdict +from collections.abc import Sequence +from functools import lru_cache, partial, wraps +from itertools import chain +from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + + +if TYPE_CHECKING: + from torch.export.decomp_utils import CustomDecompTable + +import torch +import torch.library +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket +from torch._prims_common import CustomOutParamAnnotation +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.utils import _pytree as pytree + + +__all__ = [ + "decomposition_table", + "pre_autograd_decomposition_table", + "meta_table", + "register_decomposition", + "get_decompositions", + "core_aten_decompositions", + "_should_decompose_because_unsafe_op", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# TODO: relax key type here; torch registrations should be possible to; but +# right now this type is accurate +global_decomposition_table: dict[str, dict[torch._ops.OperatorBase, Callable]] = ( + defaultdict(dict) +) + +decomposition_table = global_decomposition_table["post_autograd"] +pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"] +meta_table = global_decomposition_table["meta"] + + +def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool: + """ + Returns True if the op must always decompose in export/compile tracing system + + In export, we always decompose certain CIA ops that are tagged with + maybe_aliasing_or_mutating because we statically need to know if the op is + mutating or not. But these CIA ops could have different behaviour in runtime. + + native_batch_norm is a prim op which has a wrong schema and it needs to be replaced + with correct schema. But until then, we will force decompose it via this tag. 
+ """ + if not isinstance(op, torch._ops.OpOverload): + return False + if torch.Tag.maybe_aliasing_or_mutating in op.tags: + return True + return op == torch.ops.aten.native_batch_norm.default + + +def _add_op_to_registry(registry, op, fn): + """ + This is an internal API for adding an op to the decomposition table. + + If op is OpOverload, it will be added to the registry directly. + If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. + """ + overloads: list[Union[torch._ops.OperatorBase]] = [] + if isinstance(op, HigherOrderOperator): + # There's no concept of overloads for HigherOrderOperator + registry[op] = fn + return + elif isinstance(op, OpOverload): + overloads.append(op) + else: + assert isinstance(op, OpOverloadPacket) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) + + for op_overload in overloads: + if op_overload in registry: + raise RuntimeError(f"duplicate registrations for {op_overload}") + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out, e.g aten.add.float_int + if torch._C._dispatch_has_kernel(op_overload.name()): + registry[op_overload] = fn + + +def _convert_out_params(f): + out_annotation = f.__annotations__.get("out") + + # If there are no out params, do not wrap the function. + if not out_annotation: + return f + + # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this + if getattr(out_annotation, "__origin__", None) is tuple: + sig = inspect.signature(f) + out_names = sig.return_annotation._fields + # If out is a tuple, we need to register a function that unpacks all the out + # elements as this is what native_functions.yaml expects + + @wraps(f) + def _fn(*args, **kwargs): + out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) + # Either all of the out kwargs are set or none of them + is_none = out_kwargs[0] is None + assert all((o is None) == is_none for o in out_kwargs) + return f(*args, **kwargs, out=None if is_none else out_kwargs) + + out_params = [ + inspect.Parameter( + o, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=t, + ) + for o, t in zip(out_names, out_annotation.__args__) + ] + # Drop the out parameter and concatenate the new kwargs in the signature + params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, # type: ignore[arg-type] + return_annotation=sig.return_annotation, + ) + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + for o in out_params: + _fn.__annotations__[o.name] = o.annotation + + # Propagate that this function is wrapped by `out_wrapper` + _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined] + + return _fn + + # Alternatively, there may be a single tensor out parameter with a name + # other than "out". This will need special treatment and is indicated by an + # annotation, which we will remove here so it is not exposed after wrapping. 
+ custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None) + if custom_out_param_name: + + @wraps(f) + def _fn(*args, **kwargs): + out_kwarg = kwargs.pop(custom_out_param_name, None) + return f(*args, **kwargs, out=out_kwarg) + + out_param = inspect.Parameter( + custom_out_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_annotation, + ) + + # Drop the out parameter and concatenate the new kwarg in the signature + sig = inspect.signature(f) + params = chain( + (v for k, v in sig.parameters.items() if k != "out"), (out_param,) + ) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, # type: ignore[arg-type] + return_annotation=sig.return_annotation, + ) + + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + _fn.__annotations__[out_param.name] = out_param.annotation + + return _fn + + return f + + +def register_decomposition( + aten_op, registry=None, *, type="post_autograd", unsafe=False +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + A decorator to register a function as a decomposition to the Python + decomposition table. Use it like this:: + + @register_decomposition(torch.ops.aten.clamp_min) + def clamp_min(x): + return torch.clamp(self, min=min) + + If you are writing a new decomposition, consider contributing it + directly to PyTorch in torch._decomp.decompositions. + + This API is experimental; we are almost certainly going to extend + the API when we make decompositions eligible for use in transforms (e.g., + autograd) and not just backend tracing, where we then need to know if a + decomposition can be used to simulate a transform. + + By default, we also will register it to the Meta key of dispatcher, + and replace the c++ Meta implementation if there is already one. + + unsafe kwarg is for reuse of this function for registering non-function + things + """ + + assert type in {"post_autograd", "pre_autograd", "meta"} + + def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]: + orig_fn = fn + if not unsafe: + fn = _convert_out_params(fn) + + nonlocal registry + if registry is None: + registry = global_decomposition_table[type] + + def register(op): + _add_op_to_registry(registry, op, fn) + + # To handle allowing multiple aten_ops at once + pytree.tree_map_(register, aten_op) + return orig_fn + + return decomposition_decorator + + +def get_decompositions( + aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], + type: str = "post_autograd", +) -> dict[torch._ops.OperatorBase, Callable]: + """ + Retrieve a dictionary of decompositions corresponding to the list of + operator overloads and overload packets passed as input. Overload + packets will include all decomposed overloads in the packet. If there is + no decomposition for a requested operator, it is silently ignored. + + This API is experimental; we are almost certainly going to give an alternate, + more recommended formulation, where a user provides the set of operators + they know how to implement, and we provide decompositions for everything + not in this set. 
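+
+    A minimal usage sketch (illustrative only; it assumes the aten ops named below
+    have decompositions registered in the requested table)::
+
+        from torch._decomp import get_decompositions
+        aten = torch.ops.aten
+        table = get_decompositions([aten.hardswish, aten.nll_loss_backward])
+        # keys are resolved OpOverloads (e.g. aten.hardswish.default) and values
+        # are the Python callables registered via @register_decomposition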
+ """ + assert type in {"post_autograd", "pre_autograd", "meta"} + + registry = global_decomposition_table[type] + packets_to_overloads = defaultdict(list) + for opo in registry: + if isinstance(opo, (OpOverload, OpOverloadPacket)): + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions: dict[torch._ops.OperatorBase, Callable] = {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif isinstance(op, (torch._ops.OperatorBase)) and op in registry: + decompositions[op] = registry[op] + return decompositions + + +def remove_decompositions( + decompositions: dict[torch._ops.OperatorBase, Callable], + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], +) -> None: + """ + Given a dictionary of decompositions obtained from get_decompositions(), removes + operators associated with a list of operator overloads and overload packets passed + as input. If the decomposition dictionary does not contain a decomposition that is + specified to be removed, it is silently ignored. + """ + for op in aten_ops: + if isinstance(op, OpOverloadPacket): + for overload_name in op.overloads(): + opo = getattr(op, overload_name) + decompositions.pop(opo, None) + elif isinstance(op, OpOverload): + decompositions.pop(op, None) + + +# populate the table +import torch._decomp.decompositions +import torch._refs + + +def core_aten_decompositions() -> "CustomDecompTable": + from torch.export.exported_program import default_decompositions + + return default_decompositions() + + +# See NOTE [Core ATen Ops] +# +# list was copied from torch/_inductor/decomposition.py +# excluding decompositions that results in prim ops +# Resulting opset of decomposition is core aten ops +def _core_aten_decompositions_post_autograd() -> dict[ + torch._ops.OperatorBase, Callable +]: + aten = torch.ops.aten + return get_decompositions( + [ + aten.addcdiv, + aten.addcdiv_, + aten.addcmul, + aten.addcmul_, + aten.addr, + aten.affine_grid_generator, + aten.alias_copy, + aten.all, + aten.aminmax, + aten.arange.default, + aten.arange.start, + aten.avg_pool2d_backward, + aten.baddbmm, + aten.binary_cross_entropy, + aten.binary_cross_entropy_backward, + aten.binary_cross_entropy_with_logits, + aten.block_diag, + aten.bernoulli.p, + aten.bernoulli.default, + aten.celu, + aten.celu_, + aten.channel_shuffle, + aten.clamp_max, + aten.clamp_min, + aten.col2im, + aten.count_nonzero, + aten.linalg_cross, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.miopen_batch_norm_backward, + aten.deg2rad, + aten.deg2rad_, + aten.detach, + aten.diag_embed, + aten.diagonal_backward, + aten.diagonal_copy, + aten.dot, + aten.vdot, + aten.elu_, + aten.elu_backward, + aten._embedding_bag, + aten.embedding_dense_backward, + aten.empty_like, + aten._euclidean_dist.default, + aten.expand_as, + aten.expand_copy, + aten.eye, + aten.fill, + aten.fill_, + aten.floor_divide, + aten.frac, + aten.frac_, + aten._fused_moving_avg_obs_fq_helper, + aten.gelu_, + aten.gelu_backward, + aten.glu, + aten.glu_backward, + aten.hardshrink, + aten.hardsigmoid, + aten.hardsigmoid_, + aten.hardsigmoid_backward, + aten.hardswish, + aten.hardswish_, + aten.hardswish_backward, + aten.hardtanh_, + aten.hardtanh_backward, + aten.heaviside, + aten.heaviside_, + aten.huber_loss, + aten.huber_loss_backward, + aten.im2col, + aten.index_add.out, + aten.index_add.default, + aten.index_add_, + aten.index_copy.out, + 
aten.index_copy.default, + aten.index_copy_, + aten.index_fill.int_Scalar, + aten.index_fill.int_Tensor, + aten.index_fill.int_Scalar_out, + aten.index_fill.int_Tensor_out, + aten.index_fill_, + aten.isin, + aten.isneginf, + aten.isposinf, + aten.l1_loss, + aten._lazy_clone, + aten._test_parallel_materialize, + aten.leaky_relu_, + aten.leaky_relu_backward, + aten.lerp, + aten.lerp_, + aten.linspace, + aten.logaddexp, + aten.logaddexp2, + aten.logit, + aten.logit_, + aten.logit_backward, + aten.log_sigmoid_backward, + aten.log_sigmoid_forward, + aten._log_softmax_backward_data, + aten.logspace, + aten.logsumexp.default, + aten.masked_fill, + aten.masked_fill_, + aten.max_unpool2d, + aten.max_unpool3d, + aten.mish, + aten.mish_, + aten.mse_loss, + aten.mse_loss_backward, + aten.multi_margin_loss, + aten.multilabel_margin_loss_forward, + aten.mv, + aten.mvlgamma, + aten.mvlgamma_, + aten.nansum, + aten.nan_to_num, + aten.nan_to_num_, + aten.narrow, + aten.native_batch_norm_backward, + aten.native_dropout_backward, + aten.native_group_norm_backward, + aten.native_layer_norm_backward, + aten.new_empty, + aten.new_full, + aten.new_ones, + aten.new_zeros, + aten.nll_loss2d_forward, + aten.nll_loss2d_backward, + aten.nll_loss_backward, + aten.nll_loss_forward, + aten.norm.ScalarOpt_dtype, + aten.norm.Scalar, + aten.norm.ScalarOpt_dim_dtype, + aten.norm.ScalarOpt_dim, + aten.norm.dtype_out, + aten.norm.out, + aten.norm.names_dtype_out, + aten.norm.names_out, + aten.norm.ScalarOpt_dtype_out, + aten.norm.Scalar_out, + aten.ones, + aten.ones_like, + aten.pixel_shuffle, + aten.pixel_unshuffle, + aten._prelu_kernel, + aten._prelu_kernel_backward, + aten._reshape_alias, + aten.rad2deg, + aten.rad2deg_, + aten.reflection_pad1d, + aten.reflection_pad1d_backward, + aten.reflection_pad2d, + aten.reflection_pad2d_backward, + aten.reflection_pad3d, + aten.reflection_pad3d_backward, + aten.replication_pad1d, + aten.replication_pad2d, + aten.replication_pad3d, + aten.renorm, + aten.renorm_, + aten.replication_pad2d, + aten.resize_as, + aten.roll, + aten.rot90, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.rsub, + aten._safe_softmax, + aten._scaled_dot_product_flash_attention_for_cpu.default, + aten.select_backward, + aten.select_scatter, + aten.sgn, + aten.sgn_, + aten.sigmoid_backward, + aten.silu, + aten.silu_, + aten.silu_backward.grad_input, + aten.sinc, + aten.sinc_, + aten.slice_backward, + aten.smooth_l1_loss, + aten.smooth_l1_loss_backward, + aten.soft_margin_loss, + aten.soft_margin_loss_backward, + aten._softmax_backward_data, + aten.softplus, + aten.softplus_backward, + aten.softshrink, + aten.special_entr, + aten.special_log_ndtr, + aten.special_xlog1py, + aten.split.Tensor, + aten.split_with_sizes_copy, + aten.squeeze_copy, + aten.squeeze.default, + aten.squeeze.dim, + aten.std.correction, + aten.std.out, + aten.std.correction_out, + aten.std.names_out, + aten.std.correction_names_out, + aten.std_mean.correction, + aten.std_mean.correction_out, + aten.stack, + aten.sum.default, + aten.sum.out, + aten.t, + aten.t_copy, + aten.take, + aten.tanh_backward, + aten.threshold, + aten.threshold_, + aten.threshold_backward, + aten.trace, + aten.transpose.int, + aten.transpose_copy, + aten.tril, + aten.tril_, + aten.triu, + aten.triu_, + aten.unbind, + aten.unfold_backward, + aten.unfold_copy, + aten._unsafe_index, + aten._unsafe_index_put, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten.unsafe_split.Tensor, + aten.unsafe_split_with_sizes, + aten.unsqueeze_copy, + 
aten._unsafe_view, + aten.upsample_linear1d, + aten.upsample_bilinear2d.out, + aten.upsample_trilinear3d.out, + aten.upsample_nearest2d_backward, + aten.view_as_complex, + aten.xlogy, + aten.xlogy_, + aten.zero, + aten.zero_, + aten.zeros, + aten.zeros_like, + aten._chunk_cat, + aten._weight_norm_interface, + ] + ) diff --git a/phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3182de40f6e3d6844f7cfb04eec0cee52f981dbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99f578b8d99161066f4f76f3a28c1d2ef47acbd3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33426a8b553b35214f161e1929b5f06a75adc487 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_decomp/decompositions.py b/phivenv/Lib/site-packages/torch/_decomp/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..ee304129228610325745f71be4791cfb243e3192 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_decomp/decompositions.py @@ -0,0 +1,5224 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import numbers +import operator +import sys +from collections.abc import Iterable +from enum import Enum +from functools import partial, reduce +from itertools import chain, product +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch._meta_registrations +import torch._prims as prims +import torch._prims_common as utils +import torch.nn.functional as F +from torch import sym_float, sym_int, Tensor +from torch._decomp import register_decomposition +from torch._higher_order_ops.out_dtype import out_dtype +from torch._prims_common import ( + IntLike, + NumberType, + suggest_memory_format, + TensorLike, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + + +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + +# None of these functions are publicly accessible; get at them +# from torch._decomps +__all__: list[str] = [] + +aten = torch._ops.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided +# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +# Will need to validate the non-elementwise uses +def type_casts( + f: Callable, + type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND, + 
compute_dtype_only: bool = False, + include_non_tensor_args: bool = False, +): + @functools.wraps(f) + def inner(*args, **kwargs): + allowed_types = ( + (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,) + ) # type: ignore[arg-type] + flat_args = [ + x + for x in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(x, allowed_types) + ] + computation_dtype, result_dtype = utils.elementwise_dtypes( + *flat_args, type_promotion_kind=type_promotion + ) + + # TODO: pretty sure this is not quite right + def increase_prec(x): + if isinstance(x, Tensor): + return x.to(computation_dtype) + else: + return x + + def decrease_prec(x): + if isinstance(x, Tensor): + return x.to(result_dtype) + else: + return x + + r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs)) + if compute_dtype_only: + return r + else: + return tree_map(decrease_prec, r) + + return inner + + +compute_only_pw_cast_for_opmath = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + compute_dtype_only=True, +) +pw_cast_for_opmath = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +) +pw_cast_for_opmath_non_tensor_args = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + include_non_tensor_args=True, +) +pw_cast_for_int_to_real = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) + + +# This expands x until x.dim() == dim. Might be useful as an operator +def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor: + for _ in range(dim - x.dim()): + x = x.unsqueeze(-1) + return x + + +@register_decomposition(aten.tanh_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def tanh_backward(out_grad: Tensor, y: Tensor): + return out_grad * (1 - y * y).conj_physical() + + +@register_decomposition(aten.sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def sigmoid_backward(out_grad: Tensor, y: Tensor): + return out_grad * (y * (1 - y)).conj_physical() + + +@register_decomposition(aten.softplus_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float): + z = (x * beta).exp() + return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) + + +@register_decomposition(aten.elu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def elu_backward( + grad_output: Tensor, + alpha: float, + scale: float, + input_scale: float, + is_result: bool, + self_or_result: Tensor, +): + negcoef = alpha * scale + poscoef = scale + negiptcoef = input_scale + if is_result: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * (self_or_result + negcoef), + grad_output * poscoef, + ) + else: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef), + grad_output * poscoef, + ) + + +@register_decomposition([aten.fill.Scalar]) +def fill_scalar(self, value): + return torch.full_like(self, value) + + +@register_decomposition([aten.fill.Tensor]) +def fill_tensor(self, value: Tensor): + torch._check( + value.dim() == 0, + lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", + ) + return aten.copy(self, value) + + +@register_decomposition(aten.hardsigmoid) +@out_wrapper() +@pw_cast_for_opmath +def hardsigmoid(self: Tensor) -> Tensor: + return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + 
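+
+# A small, hedged sketch (illustrative only, not part of the library): decompositions
+# such as `hardsigmoid` above are ordinary PyTorch functions, so they can be checked
+# numerically against the eager kernel. The helper name below is hypothetical.
+def _example_check_hardsigmoid_decomp() -> bool:
+    x = torch.linspace(-6.0, 6.0, steps=25)
+    # The decomposition computes clamp(x + 3, 0, 6) / 6, which should match
+    # torch.nn.functional.hardsigmoid elementwise.
+    return bool(torch.allclose(hardsigmoid(x), F.hardsigmoid(x)))
+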
+@register_decomposition(aten.hardsigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def hardsigmoid_backward(grad_output: Tensor, self: Tensor): + return torch.where( + (self > -3.0) & (self < 3.0), + grad_output * (1.0 / 6.0), + 0.0, + ) + + +@register_decomposition(aten.hardtanh_backward) +@out_wrapper("grad_input") +def hardtanh_backward( + grad_output: Tensor, self: Tensor, min_val: float, max_val: float +): + return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output) + + +@register_decomposition(aten.hardswish) +@out_wrapper() +@pw_cast_for_opmath +def hardswish(self: Tensor) -> Tensor: + return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardswish_backward) +@out_wrapper() +@pw_cast_for_opmath +def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor: + return torch.where( + self <= -3, + 0.0, + torch.where(self < 3, grad_output * ((self / 3) + 0.5), grad_output), + ) + + +@register_decomposition(aten.threshold_backward) +@out_wrapper("grad_input") +def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float): + return torch.where(self <= threshold, 0, grad_output) + + +@register_decomposition(aten.leaky_relu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def leaky_relu_backward( + grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool +): + return torch.where(self > 0, grad_output, grad_output * negative_slope) + + +@register_decomposition(aten.gelu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"): + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + x_sq = self * self + x_cube = x_sq * self + inner = kBeta * (self + kKappa * x_cube) + tanh_inner = torch.tanh(inner) + + left = 0.5 * self + right = 1 + tanh_inner + + left_derivative = 0.5 * right + + tanh_derivative = 1 - tanh_inner * tanh_inner + inner_derivative = kBeta * (1 + 3 * kKappa * x_sq) + right_derivative = left * tanh_derivative * inner_derivative + + return grad * (left_derivative + right_derivative) + else: + kAlpha = M_SQRT1_2 + kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 + cdf = 0.5 * (1 + torch.erf(self * kAlpha)) + pdf = kBeta * torch.exp(self * self * -0.5) + return grad * (cdf + self * pdf) + + +@register_decomposition(aten.mish_backward) +@pw_cast_for_opmath +def mish_backward(grad_output: Tensor, input: Tensor): + input_tanh_softplus = torch.tanh(F.softplus(input)) + input_sigmoid = torch.sigmoid(input) + out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus) + return grad_output * (input_tanh_softplus + out) + + +@register_decomposition(aten.silu) +@out_wrapper() +@pw_cast_for_opmath +def silu(self: Tensor) -> Tensor: + return self * torch.sigmoid(self) + + +@register_decomposition(aten.silu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor: + sigmoid = 1 / (1 + torch.exp(-self)) + return grad_output * sigmoid * (1 + self * (1 - sigmoid)) + + +@register_decomposition(aten._prelu_kernel) +def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor: + return torch.where(self > 0, self, weight * self) + + +@register_decomposition(aten._prelu_kernel_backward) +def _prelu_kernel_backward( + grad_output: Tensor, + self: Tensor, + weight: Tensor, +) -> 
tuple[Tensor, Tensor]: + input_grad = torch.where(self > 0, grad_output, weight * grad_output) + weight_grad = torch.where(self > 0, 0.0, self * grad_output) + return (input_grad, weight_grad) + + +@register_decomposition(aten.rrelu_with_noise_backward) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise_backward( + grad_output: Tensor, + self: Tensor, + noise: Tensor, + lower: float, + upper: float, + training: bool, + self_is_result: bool, +) -> Tensor: + if training and upper - lower > 1e-6: + return grad_output.mul(noise) + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu_backward( + grad_output, self, negative_slope, self_is_result + ) + + +@register_decomposition(aten.log_sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor: + in_negative = self < 0 + max_deriv = torch.where(in_negative, 1, 0) + sign = torch.where(in_negative, 1, -1) + z = torch.exp(-torch.abs(self)) + return grad_output * (max_deriv - sign * (z / (1 + z))) + # CPU has a special formula that uses buffer, but disabled for convenience sake + # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output + + +def apply_loss_reduction(loss: Tensor, reduction: int): + if reduction == Reduction.MEAN.value: + return torch.mean(loss) + elif reduction == Reduction.SUM.value: + return torch.sum(loss) + else: + return loss + + +def to_real_dtype(dtype: torch.dtype): + if dtype == torch.complex32: + return torch.float16 + elif dtype == torch.complex64: + return torch.float32 + elif dtype == torch.complex128: + return torch.float64 + + +# TODO: None of these loss castings are quite correct, see +# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels +# perform the pointwise portion in opmath, but don't maintain it between the +# pointwise portion and the reduction + + +@register_decomposition(aten.mse_loss) +@out_wrapper() +@pw_cast_for_opmath +def mse_loss( + self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value +) -> Tensor: + loss = (self - target) ** 2 + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.mse_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def mse_loss_backward( + grad_output: Tensor, input: Tensor, target: Tensor, reduction: int +): + norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0 + return norm * (input - target) * grad_output + + +@register_decomposition(aten._safe_softmax) +def safe_softmax(self, dim, dtype=None): + out = torch.softmax(self, dim=dim, dtype=dtype) + masked = self.eq(float("-inf")) + masked_rows = torch.all(masked, dim=dim, keepdim=True) + zeros = torch.zeros_like(out) + return torch.where(masked_rows, zeros, out) + + +@register_decomposition(aten.smooth_l1_loss) +@out_wrapper() +@pw_cast_for_opmath +def smooth_l1_loss( + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, + beta: float = 1.0, +): + loss = (self - target).abs() + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.smooth_l1_loss_backward.default) +@pw_cast_for_opmath +def smooth_l1_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + abs_x = torch.abs(x) + norm_grad = norm * grad_output + return torch.where( + abs_x < beta, + 
norm_grad * x / beta, + norm_grad * torch.sign(x), + ) + + +@register_decomposition(aten.smooth_l1_loss_backward.grad_input) +@pw_cast_for_opmath +def smooth_l1_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + beta: float, + grad_input: Tensor, +): + result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +@register_decomposition(aten.huber_loss_backward.default) +@pw_cast_for_opmath +def huber_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + return torch.where( + x < -delta, + -norm * grad_output * delta, + torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output), + ) + + +# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input' +@register_decomposition(aten.huber_loss_backward.out) +@pw_cast_for_opmath +def huber_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + delta: float, + grad_input: Tensor, +): + result = huber_loss_backward(grad_output, self, target, reduction, delta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +def _nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + channel_dim = 0 if self.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(self) + grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(self.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + grad_output = grad_output * weight + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + return grad_input * grad_output + + +@register_decomposition(aten.glu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor: + assert self.dim() > 0, "glu does not support 0-dimensional tensors" + wrap_dim = utils.canonicalize_dim(self.dim(), dim) + nIn = self.size(wrap_dim) + assert nIn % 2 == 0, ( + f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" + ) + inputSize = nIn // 2 + firstHalf = self.narrow(wrap_dim, 0, inputSize) + secondHalf = self.narrow(wrap_dim, inputSize, inputSize) + gradInputFirstHalf = torch.sigmoid(secondHalf) + gradInputSecondHalf = ( + (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + ) + gradInputFirstHalf = gradInputFirstHalf * grad_output + return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim) + + +@register_decomposition(aten.nll_loss_backward) +@out_wrapper("grad_input") +def nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert 0 <= self.dim() <= 2, "input 
tensor should be 1D or 2D" + assert target.dim() <= 1, ( + "0D or 1D target tensor expected, multi-target not supported" + ) + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or (self.shape[0] == target.shape[0]), ( + f"size mismatch (got input: {self.shape}, target: {target.shape})" + ) + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, got: ", + f"{total_weight.shape} ({total_weight.numel()} elements)", + ) + + assert weight is None or weight.numel() == self.shape[-1], ( + "weight tensor should be defined either for all or no classes" + ) + + if reduction == Reduction.NONE.value and self.dim() == 2: + assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], ( + f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but " + f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}" + ) + else: + assert grad_output.dim() <= 1 and grad_output.numel() == 1, ( + f"Expected a single element grad_output tensor, but got: {grad_output.shape}" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.nll_loss2d_backward) +@out_wrapper("grad_input") +def nll_loss2d_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert self.dim() == 4, ( + f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}" + ) + + assert target.dim() == 3, ( + f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}" + ) + + assert ( + self.shape[0] == target.shape[0] + and self.shape[2] == target.shape[1] + and self.shape[3] == target.shape[2] + ), f"size mismatch (got input: {self.shape}, target: {target.shape}" + + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, " + f"got: {total_weight.shape} ( {total_weight.numel()}, elements)" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.binary_cross_entropy) +@out_wrapper() +@pw_cast_for_opmath +def binary_cross_entropy( + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + # We cannot currently model this without introducing data-dependent control flow + # TORCH_CHECK( + # (input_val >= 0) && (input_val <= 1), + # "all elements of input should be between 0 and 1" + # ) + loss = (target - 1) * torch.maximum( + torch.log1p(-self), self.new_full((), -100) + ) - target * torch.maximum(torch.log(self), self.new_full((), -100)) + if weight is not None: + loss = loss * weight + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.binary_cross_entropy_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def binary_cross_entropy_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + EPSILON = 1e-12 + result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON) + if weight is not None: + result = result * weight + if reduction == Reduction.MEAN.value: + result = result / self.numel() + return result + + +@register_decomposition(aten.soft_margin_loss) +@out_wrapper() +@pw_cast_for_opmath +def 
soft_margin_loss( + input: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + loss = torch.log1p(torch.exp(-input * target)) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.soft_margin_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def soft_margin_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + grad_input = target * grad_output * (torch.sigmoid(target * self) - 1) + if reduction == Reduction.MEAN.value: + grad_input = grad_input / self.numel() + return grad_input + + +@register_decomposition(aten.dist) +@out_wrapper() +def dist(input: Tensor, other: Tensor, p: float = 2): + return aten.norm(input - other, p=p) + + +@register_decomposition(aten._euclidean_dist) +@out_wrapper() +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: + x1_norm = x1.pow(2).sum(-1, True) + x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format) + x2_norm = x2.pow(2).sum(-1, True) + x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format) + x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1) + x2_ = torch.cat([x2, x2_pad, x2_norm], -1) + result = x1_.matmul(x2_.mT) + return result.clamp_min(0).sqrt() + + +@register_decomposition(aten.slice_backward) +@out_wrapper() +def slice_backward( + grad_output: Tensor, + input_sizes: list[int], + dim: int, + start: int, + end: int, + step: int, +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.slice_scatter(grad_input, grad_output, dim, start, end, step) + + +@register_decomposition(aten.slice.Tensor) +def slice_forward( + # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1 + self: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + statically_known_true, + ) + + ndim = self.dim() + if ndim == 0: + raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") + dim = utils.canonicalize_dim(self.dim(), dim) + sizes = list(self.size()) + strides = list(self.stride()) + + if step <= 0: + raise RuntimeError("slice step must be positive") + + start_val = start if start is not None else 0 + end_val = end if end is not None else sys.maxsize # 2^63 - 1 + + if guard_size_oblivious(start_val < 0): + start_val += sizes[dim] + + if guard_size_oblivious(end_val < 0): + end_val += sizes[dim] + + if guard_size_oblivious(start_val < 0): + start_val = 0 + elif guard_size_oblivious(start_val > sizes[dim]): + start_val = sizes[dim] + + if statically_known_true(end_val == sys.maxsize): + end_val = sizes[dim] + elif guard_size_oblivious(end_val < start_val): + end_val = start_val + elif guard_size_oblivious(end_val > sizes[dim]): + end_val = sizes[dim] + + storage_offset = self.storage_offset() + start_val * strides[dim] + len = end_val - start_val + sizes[dim] = (len + step - 1) // step + strides[dim] *= step + + if self.is_quantized: + raise NotImplementedError( + "Slice decomposition for quantized tensors aren't implemented" + ) + else: + return self.as_strided(sizes, strides, storage_offset) + + +def _normalize_start_end( + x: Tensor, dim: int, start: Optional[int], end: Optional[int] +) -> tuple[int, int]: + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. 
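+
+    For example, with dim_size == 10: start=-3 normalizes to 7, end=None defaults
+    to 10, and (start=8, end=2) clamps to (8, 8), i.e. an empty slice.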
+ """ + dim_size = x.shape[dim] + + def clamp_wrap(val, lower, upper, default) -> int: + if val is None: + return default + if val < 0: + val = val + dim_size + return min(max(val, lower), upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + +# This is not in torch._refs because aten.index used by +# aten._unsafe_masked_index does not have a decomposition. +@register_decomposition(aten.slice_scatter) +@out_wrapper() +def slice_scatter( + input: Tensor, + src: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + dim = utils.canonicalize_dim(input.ndim, dim) + dim_size = input.shape[dim] + start, end = _normalize_start_end(input, dim, start, end) + + src_size = list(input.shape) + src_size[dim] = (end - start + (step - 1)) // step + src = src.expand(src_size) + + if start == 0 and end == dim_size and step == 1: + return src.clone() + + indices = [None] * input.dim() + idx = torch.arange(dim_size, device=input.device) + indices[dim] = (idx - start) // step + + mask = torch.ones(dim_size, device=input.device, dtype=torch.bool) + if start != 0: + mask = torch.logical_and(mask, idx >= start) + + if end != dim_size: + mask = torch.logical_and(mask, idx < end) + + if step != 1: + mask = torch.logical_and(mask, (idx - start) % step == 0) + + mask_shape = [1] * input.dim() + mask_shape[dim] = -1 + mask = mask.view(mask_shape) + return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input) + + +@register_decomposition(aten.select_backward) +@out_wrapper() +def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int): + grad_input = grad_output.new_zeros(input_sizes) + return torch.select_scatter(grad_input, grad_output, dim, index) + + +@register_decomposition(aten.diagonal_backward) +@out_wrapper() +def diagonal_backward( + grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _cast_grad_to_input_dtype( + grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype +): + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input + + +@register_decomposition(aten._softmax_backward_data) +@out_wrapper("grad_input") +@compute_only_pw_cast_for_opmath +def _softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + new_grad_output = grad_output * output + grad_input = new_grad_output - output * torch.sum( + new_grad_output, dim=dim, keepdim=True + ) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() + + +@register_decomposition(aten._log_softmax_backward_data) +@out_wrapper() +@compute_only_pw_cast_for_opmath +def _log_softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + grad_input = grad_output - torch.exp(output) * torch.sum( + grad_output, dim=dim, keepdim=True + ) + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + +def _im2col_col2im_indices_along_dim( + input_d, kernel_d, dilation_d, padding_d, stride_d, device +): + """Utility function to implement im2col and col2im""" + blocks_d = input_d + 
padding_d * 2 - dilation_d * (kernel_d - 1) + + arange_kw = partial(torch.arange, dtype=torch.int64, device=device) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1) + + # Broadcast and add kernel starting positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + return blocks_d_indices + kernel_grid + + +@register_decomposition(aten.im2col) +@out_wrapper() +def im2col( + input: Tensor, + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], +) -> Tensor: + torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") + torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater {'than' zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(dilation, "padding", strict=False) + check_positive(stride, "stride") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4) and all(d != 0 for d in shape[-3:]), + lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + output_size = tuple( + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + shape[-2:], padding, dilation, kernel_size, stride + ) + ) + torch._check( + all(c > 0 for c in output_size), + lambda: f"Given an input with spacial size {tuple(shape[-2:])}, " + f"kernel_size={kernel_size}, dilation={dilation}, " + f"padding={padding}, stride={stride}, " + "the calculated shape of the array of sliding blocks " + f"is {output_size}, but its components must be at least one.", + ) + batched_input = ndim == 4 + if not batched_input: + input = input.unsqueeze(0) + + batch_dim, channel_dim, input_h, input_w = input.shape + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + blocks_row_indices = _im2col_col2im_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + blocks_col_indices = _im2col_col2im_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom) + # ugh + padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h)) + + blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1) + output = padded_input[:, :, blocks_row_indices, blocks_col_indices] + output = output.permute(0, 1, 2, 4, 3, 5) + num_blocks_row = blocks_row_indices.size(1) + num_blocks_col = blocks_col_indices.size(1) + output = output.reshape( + batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col + ) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.col2im) +@out_wrapper() +@pw_cast_for_opmath +def col2im( + input: Tensor, + output_size: 
list[int], + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], +) -> Tensor: + torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") + torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "only 2D padding supported") + torch._check(len(stride) == 2, lambda: "only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: "{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(padding, "padding", strict=False) + check_positive(stride, "stride") + check_positive(output_size, "output_size") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (2, 3) and all(d != 0 for d in shape[-2:]), + lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + prod_kernel_size = kernel_size[0] * kernel_size[1] + torch._check( + shape[-2] % prod_kernel_size == 0, + lambda: "Expected size of input's first non-batch dimension to be divisible by the " + f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " + f"kernel_size={kernel_size}", + ) + col = [ + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + output_size, padding, dilation, kernel_size, stride + ) + ] + L = col[0] * col[1] + torch._check( + shape[-1] == L, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + torch._check( + L > 0, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + batched_input = ndim == 3 + if not batched_input: + input = input.unsqueeze(0) + + shape = input.shape + + out_h, out_w = output_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand + input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col) + input = input.permute(0, 1, 2, 4, 3, 5) + + indices_row = _im2col_col2im_indices_along_dim( + out_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + indices_row = _unsqueeze_to_dim(indices_row, 4) + indices_col = _im2col_col2im_indices_along_dim( + out_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)] + output = input.new_zeros( + [shape[0], shape[1] // prod(kernel_size)] + output_padded_size + ) + idx = (None, None, indices_row, indices_col) + output = aten._unsafe_index_put(output, idx, input, accumulate=True) + output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h)) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.native_dropout_backward) +@out_wrapper() +def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): + # According to the CUDA kernel 
implementation we should have this test; + # but it seems to fail tests! + # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r + + +@register_decomposition(aten.unfold_backward) +@out_wrapper() +def unfold_backward( + grad: Tensor, input_size: list[int], dimension: int, size: int, step: int +) -> Tensor: + if len(input_size) == 0: + return torch.squeeze_copy(grad, 0) + dim = utils.canonicalize_dim(len(input_size), dimension) + idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32) + idx = idx.unfold(0, size, step).flatten() + grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1) + # nb. At the moment this generates two kernels in triton + # It could potentially be fused into one call to scatter_reduce, + # in the case step <= size provided scatter_reduce generates 1 kernel + grad_input = grad.new_zeros(input_size) + index = (None,) * dim + (idx,) + return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous() + + +@register_decomposition(aten.logit_backward.default) +@pw_cast_for_opmath +def logit_backward( + grad_output: Tensor, self: Tensor, eps: Optional[float] = None +) -> Tensor: + if eps is not None: + lo = eps + hi = 1.0 - lo + return torch.where( + torch.logical_and(self >= lo, self <= hi), + grad_output / (self * (1.0 - self)), + 0.0, + ) + else: + return torch.where( + torch.logical_and(self >= 0.0, self <= 1.0), + grad_output / (self * (1.0 - self)), + self.new_full((), float("nan")), + ) + + +@register_decomposition(aten.dropout) +@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.dropout.default.py_impl(DispatchKey.Autograd) +def dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + return aten.native_dropout(input, p, train)[0] + else: + return input.clone() + + +@register_decomposition(aten.native_dropout) +@out_wrapper("out0", "out1") +def native_dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + if p == 1: + return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)) + if not input.dtype.is_floating_point: + raise RuntimeError( + "result type Float can't be cast to the desired output type Long" + ) + bool_mask = torch.rand_like(input) > p + res = bool_mask * input * float(1.0 / (1.0 - p)) + return (res, bool_mask) + else: + return (input, torch.ones_like(input, dtype=torch.bool)) + + +@register_decomposition(aten._softmax) +@out_wrapper() +def _softmax(x: Tensor, dim: int, half_to_float: bool): + # eager softmax returns a contiguous tensor. Ensure that decomp also returns + # a contiguous tensor. 
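+    # The decomposition below also uses the numerically stable formulation:
+    # it subtracts amax(x, dim) before exponentiating so exp() cannot overflow.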
+ x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + unnormalized = torch.exp(x) + else: + x_max = torch.amax(x, dim, keepdim=True) + unnormalized = torch.exp(x - x_max) + result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten._log_softmax) +@out_wrapper(exact_dtype=True) +def _log_softmax(x: Tensor, dim: int, half_to_float: bool): + # eager log_softmax returns a contiguous tensor. Ensure that decomp also + # returns a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + shifted = x - x_max + shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True)) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten.embedding) +@out_wrapper() +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + assert weight.dim() == 2, "'weight' must be 2-D" + # Nb. scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] + + +@register_decomposition(aten.embedding_dense_backward) +@out_wrapper() +def embedding_dense_backward( + grad_output: Tensor, + indices: Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +): + computation_dtype, result_dtype = utils.elementwise_dtypes( + grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + grad_output = grad_output.to(computation_dtype) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] + if scale_grad_by_freq: + counts = indices.new_zeros((num_weights,)) + ones = torch.ones_like(indices) + counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(-1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] + ) + return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to( + result_dtype + ) + + +def prod(x: list[int]): + r = 1 + for i in x: + r *= i + return r + + +def _pad_chunk( + tensors: list[Tensor], + dim: int, + num_chunks: int, +) -> list[Tensor]: + padded_tensors = [] + for tensor in tensors: + tensor_size = tensor.size() + pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks + if pad_along_dim != tensor_size[dim]: + # Use aten.constant_pad_nd instead of copy_ for functionalization + pad = [0] * 2 * (tensor.ndim - dim - 1) + [ + 0, + pad_along_dim - tensor_size[dim], + ] + tensor = aten.constant_pad_nd(tensor, pad, 0) + view_size = 
tensor_size[:dim] + torch.Size([num_chunks, -1]) + padded_tensors.append(tensor.reshape(view_size)) + return padded_tensors + + +def have_same_ndims(tensors: list[Tensor]): + ndim = tensors[0].ndim + for tensor in tensors: + if tensor.ndim != ndim: + return False + return True + + +def leading_dimension_matches(tensors: list[Tensor], dim: int): + leading_dim_sizes = tensors[0].size()[:dim] + for tensor in tensors: + torch._check( + tensor.size()[:dim] == leading_dim_sizes, + lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors", + ) + + +def _preprocess_chunk_cat_inputs( + tensors: list[Tensor], + dim: int, + num_chunks: int, +): + torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks") + torch._check( + len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list" + ) + expected_dtype = tensors[0].dtype + expected_device = tensors[0].device + for tensor in tensors: + torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor") + torch._check( + tensor.dtype == expected_dtype, + lambda: "_chunk_cat expects all input tensors with the same dtype", + ) + torch._check( + tensor.device == expected_device, + lambda: "_chunk_cat expects all inputs tensors on the same device", + ) + if have_same_ndims(tensors): + dim = utils.canonicalize_dim(tensors[0].dim(), dim) + else: + torch._check( + dim >= 0, + lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims", + ) + for tensor in tensors: + torch._check( + dim < tensor.ndim, + lambda: "_chunk_cat expects dim < ndim for all input tensors", + ) + leading_dimension_matches(tensors, dim) + return dim + + +@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out]) +def _chunk_cat( + tensors: list[Tensor], + dim: int, + num_chunks: int, + out: Optional[Tensor] = None, +) -> Tensor: + dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks) + padded_tensors = _pad_chunk(tensors, dim, num_chunks) + if out is None: + return torch.cat(padded_tensors, dim + 1) + else: + torch.cat(padded_tensors, dim + 1, out=out) + return out + + +# out_wrapper currently does not allow optional outputs +@register_decomposition( + [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out] +) +def split_with_sizes_copy( + self: Tensor, + split_sizes: list[int], + dim: int = 0, + out: Optional[list[Tensor]] = None, +) -> Optional[list[Tensor]]: + splits = aten.split_with_sizes(self, split_sizes, dim=dim) + if out is None: + return [s.clone(memory_format=torch.contiguous_format) for s in splits] + else: + for output, split in zip(out, splits): + _maybe_resize_out(output, split.shape) + _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True) + return None + + +@register_decomposition(aten.unsafe_split.Tensor) +def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: + return aten.split.Tensor(input, split_size, dim) + + +@register_decomposition(aten.unsafe_split_with_sizes.default) +def unsafe_split_with_sizes( + input: Tensor, split_sizes: list[int], dim: int = 0 +) -> tuple[Tensor, ...]: + return aten.split_with_sizes.default(input, split_sizes, dim) + + +@register_decomposition(aten.split.Tensor) +def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: + input_sizes = self.shape + dim_size = input_sizes[dim] + if split_size == 0: + assert dim_size == 0 + return (self.detach(),) + chunks = (dim_size + split_size - 1) // split_size + + # Avoid importing sympy at a module level + 
from torch.fx.experimental.symbolic_shapes import guard_int + + chunks = guard_int(chunks) + split_sizes = [split_size for i in range(chunks)] + split_sizes[-1] = split_size - (split_size * chunks - dim_size) + return torch.split(self, split_sizes, dim) + + +@aten.tensor_split.tensor_indices_or_sections.py_impl( + DispatchKey.CompositeImplicitAutograd +) +def tensor_split_tensor_indices_or_sections_py_impl( + self: Tensor, + tensor_indices_or_sections: Tensor, + dim: int = 0, +) -> tuple[Tensor, ...]: + assert tensor_indices_or_sections.device.type == "cpu" + assert tensor_indices_or_sections.dtype == torch.int64 + split_dim = tensor_indices_or_sections.dim() + torch._check( + split_dim == 1 or split_dim == 0, + lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional " + f"or one-dimensional tensor, but got a tensor with {split_dim} dims", + ) + if split_dim == 0: + sections = tensor_indices_or_sections.item() + assert isinstance(sections, IntLike) + return self.tensor_split(sections, dim) + else: + indices = [i.item() for i in tensor_indices_or_sections] + # WARNING: Tempted to torch._check_is_size on the indices here? You + # can't: tensor_split works with negative values in indices: + # + # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5])) + # (tensor([ 0.3540, 2.1074, -0.8507, 1.1639, 0.3055]), tensor([]), + # tensor([-0.4285, 1.0692, -0.1776, 0.9362, 1.6143])) + # + # Sorry, I don't make the rules. Explicitly do the item call in user + # code if you KNOW that they are non-negative. + return self.tensor_split(indices, dim) + + +# TODO: this doesn't appear to have enough precision in bfloat16 +@register_decomposition(aten.addmm) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mm(mat1, mat2) + if beta == 0: + return out + + # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition. + # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided. + # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition. + # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input. + # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases. + # This implementation is not ideal, and we should revisit this when we have a better solution. 
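+ # Illustrative sanity check of the identity used here (a hedged sketch with
+ # small double-precision tensors, not part of the decomposition itself):
+ # >>> s, m1, m2 = (torch.randn(3, 5, dtype=torch.float64),
+ # ...              torch.randn(3, 4, dtype=torch.float64),
+ # ...              torch.randn(4, 5, dtype=torch.float64))
+ # >>> torch.allclose(torch.addmm(s, m1, m2, beta=2.0, alpha=3.0),
+ # ...                3.0 * (m1 @ m2) + 2.0 * s)
+ # True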
+ return out + beta * self + + +@register_decomposition(aten._addmm_activation) +@out_wrapper() +@pw_cast_for_opmath +def _addmm_activation( + self: Tensor, + mat1: Tensor, + mat2: Tensor, + beta: int = 1, + alpha: int = 1, + use_gelu: bool = False, +): + out = addmm(self, mat1, mat2, beta, alpha) + if use_gelu: + if self.is_cuda: + return aten.gelu(out, approximate="tanh") + else: + return aten.gelu(out) + return aten.relu(out) + + +@register_decomposition(aten.addmv) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mv(mat1, vec) + if beta == 0: + return out + if out.numel() == 0: # handle empty matrix + return beta * self + return out + beta * self + + +@register_decomposition(aten.native_group_norm_backward.default) +@pw_cast_for_opmath +def native_group_norm_backward( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + utils.check_same_device( + grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False + ) + utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) + utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) + torch._check( + input.numel() == N * C * HxW, + lambda: f"Expect input to have {N * C * HxW} elements", + ) + torch._check( + mean.shape == (N, group), + lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", + ) + torch._check( + gamma is None or gamma.numel() == C, + lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", + ) + + cpg, _rem = divmod(C, group) + torch._check( + _rem == 0, + lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}", + ) + + # Compute Internal gradients + ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2]) + db = grad_output.view(N, C, HxW).sum(dim=[2]) + + d_input: Optional[Tensor] = None + d_gamma: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + s = 1.0 / (HxW * cpg) + if gamma is not None: + ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + gamma.reshape(1, group, cpg), + ) + else: + ds_val = ds.reshape(N, group, cpg).sum(2) + db_val = db.reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + torch.ones((1, group, cpg), device=rstd.device), + ) + c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s + c3 = -c2 * mean - db_val * rstd * s + + c1 = c1.unsqueeze(-1) + c2 = _unsqueeze_to_dim(c2, 4) + c3 = _unsqueeze_to_dim(c3, 4) + d_input = ( + torch.mul(grad_output.reshape(N, group, cpg, HxW), c1) + + torch.mul(input.reshape(N, group, cpg, HxW), c2) + + c3 + ) + d_input = d_input.reshape(input.shape).to(input.dtype) + if output_mask[1]: + d_gamma = ( + ( + (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1)) + * rstd.unsqueeze(-1) + ) + .sum(dim=[0]) + .reshape(C) + ) + if output_mask[2]: + d_bias = db.sum(dim=[0]) + + return (d_input, d_gamma, d_bias) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_group_norm_backward.out) +def 
native_group_norm_backward_out( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_group_norm_backward( + grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]: + if x is not None: + return x.to(dtype) + return x + + +# TODO: Take a closer look at the type promotion semantics +@register_decomposition(aten.native_layer_norm_backward.default) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + grad_out_cast, input_cast, weight_cast, bias_cast = ( + x.to(computation_dtype, memory_format=torch.contiguous_format) + if x is not None + else x + for x in (grad_out, input, weight, bias) + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + input.new_zeros(input_shape[axis:]) if output_mask[2] else None, + ) + mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + x_hat = (input_cast - mean) * rstd + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + + inner = a - b - c3 + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + d_input = (rstd / N) * inner + + if output_mask[1] and weight_cast is not None: + if len(outer_dim_indices) > 0: + d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False) + else: + d_weight = grad_out_cast * x_hat + + if output_mask[2] and bias_cast is not None: + if len(outer_dim_indices) > 0: + d_bias = torch.sum(grad_out_cast, outer_dim_indices, False) + else: + d_bias = grad_out_cast.clone() + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + _maybe_cast(d_bias, input.dtype), + ) + + +# out_wrapper currently does not allow optional outputs 
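+# (The .out overload registered just below reuses the functional decomposition
+# above, resizing each preallocated out tensor if needed and copying the
+# corresponding non-None result into it.)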
+@register_decomposition(aten.native_layer_norm_backward.out) +def native_layer_norm_backward_out( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_layer_norm_backward( + grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def native_batch_norm_helper( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, + functional: bool, +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + reduction_dims = [0] + list(range(2, input.dim())) + computation_dtype = utils.get_computation_dtype(input.dtype) + new_running_mean = running_mean + new_running_var = running_var + if training: + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = input.to(dtype=computation_dtype) + biased_var, mean = torch.var_mean( + input_acc, dim=reduction_dims, correction=0, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + + output = (input - mean) * rstd + + save_mean = torch.squeeze(mean, reduction_dims) + save_rstd = torch.squeeze(rstd, reduction_dims) + if running_mean is not None: + new_running_mean = momentum * save_mean + (1 - momentum) * running_mean + if not functional: + running_mean.copy_(new_running_mean) + if running_var is not None: + n = input.numel() / input.shape[1] + # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction + # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose + # numerics probably don't matter. 
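+ # The biased (correction=0) variance computed above is rescaled by n / (n - 1)
+ # below to obtain the unbiased estimate that the running buffers track.
+ # Illustrative check (hedged sketch, assuming a small NCHW batch where
+ # n = numel / C = 8 * 4 * 4 = 128):
+ # >>> xb = torch.randn(8, 3, 4, 4, dtype=torch.float64)
+ # >>> biased = xb.var(dim=(0, 2, 3), correction=0)
+ # >>> torch.allclose(biased * (128 / 127), xb.var(dim=(0, 2, 3), correction=1))
+ # True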
+ squeezed_var = torch.squeeze(biased_var, reduction_dims) + unbiased_var = squeezed_var * (n / (n - 1)) + new_running_var = momentum * unbiased_var + (1 - momentum) * running_var + if not functional: + running_var.copy_(new_running_var) + else: + assert running_mean is not None and running_var is not None + running_mean = running_mean.to(dtype=computation_dtype, copy=True) + new_running_mean = running_mean + running_var = running_var.to(dtype=computation_dtype, copy=True) + new_running_var = running_var + mean = running_mean + invstd = 1 / (torch.sqrt(running_var + eps)) + # Very annoying inconsistency where CPU and CUDA give different shapes + if input.device.type != "cpu": + save_mean = running_mean + save_rstd = invstd + else: + save_mean = input.new_zeros((0,)) + save_rstd = input.new_zeros((0,)) + mean = _unsqueeze_to_dim(mean, input.dim() - 1) + invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) + output = (input - mean) * invstd + + if weight is not None: + weight = weight.flatten() + weight = _unsqueeze_to_dim(weight, input.dim() - 1) + output = output * weight + + if bias is not None: + bias = bias.flatten() + bias = _unsqueeze_to_dim(bias, input.dim() - 1) + output = output + bias + + if input.device.type == "cpu": + save_mean = save_mean.to(dtype=input.dtype) + save_rstd = save_rstd.to(dtype=input.dtype) + return ( + output.to(dtype=input.dtype), + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) + + +@register_decomposition(aten.native_batch_norm) +@out_wrapper("out", "save_mean", "save_invstd") +def native_batch_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm +# with our new correctly schema'd _native_batch_norm_legit and its variants, but +# we cannot do that immediately in the C++ because it would be forwards incompatible +# with some mobile use cases. +# +# Since this change is most impactful for aot autograd/functionalization, we simply +# register this decomposition on the Autograd key for the python dispatcher (which is +# currently only used by aot autograd/functionalization and no one else, really). +# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm +# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). +@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd) +@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def native_batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + if running_mean is None and running_var is None: + return aten._native_batch_norm_legit( + input, weight, bias, training, momentum, eps + ) + if running_mean is None: + raise RuntimeError( + "running_mean is None, but running_var is provided. " + "They should both be None or both be provided." 
+ ) + if running_var is None: + raise RuntimeError( + "running_var is None, but running_mean is provided. " + "They should both be None or both be provided." + ) + if training: + # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg. + return aten._native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + else: + return aten._native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]: + dim_size = tensor.size(dim) + split_size = (dim_size + chunks - 1) // chunks + + if split_size == 0 and dim_size == 0: + split_sizes = [split_size for _ in chunks] + split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim) + return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim) + + +@register_decomposition(aten._native_batch_norm_legit_no_training.default) +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + return aten._native_batch_norm_legit.default( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + ) + + +@register_decomposition(aten._native_batch_norm_legit.default) +def _native_batch_norm_legit( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit.no_stats) +def _native_batch_norm_legit_no_stats( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, None, None, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit_functional.default) +def _native_batch_norm_legit_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, True + ) + assert new_running_mean is not None, "new_running_mean should not be None" + assert new_running_var is not None, "new_running_var should not be None" + return output, save_mean, save_rstd, new_running_mean, new_running_var + + +def _get_batch_norm_reserve_tensor( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + eps: float, + training: bool, +) -> Tensor: + """ + Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the + backward pass. 
This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`, + which support a variety of backends including cudnn. We create this tensor here to get + the correct shape in the traced graph if we detect that will call the cudnn kernel, + and rely on DCE to avoid materializing this tensor. + """ + backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined] + input, weight, bias, running_mean, running_var, True, eps + ) + reserve_size = 0 + if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined] + reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size( # type: ignore[attr-defined] + input, training + ) + return torch.empty( + reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device + ) + + +@register_decomposition(aten._batch_norm_with_update.default) +def _batch_norm_with_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + True, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._batch_norm_with_update_functional.default) +def _batch_norm_with_update_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_rm, + new_rv, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, True, momentum, eps, True + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + assert new_rm is not None, "new_running_mean should not be None" + assert new_rv is not None, "new_running_var should not be None" + return (output, save_mean, save_rstd, reserve, new_rm, new_rv) + + +@register_decomposition(aten._batch_norm_no_update.default) +def _batch_norm_no_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=False + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._fused_dropout) +@out_wrapper("out0", "out1") +@pw_cast_for_opmath +def _fused_dropout_decomposition(input, p, generator=None): + assert generator is None + mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) + res = mask.type_as(input) * input * (1.0 / p) + return (res, mask) + + +@register_decomposition(aten._to_copy) +@out_wrapper() +def _to_copy( + x: Union[Tensor, NumberType], + *, + dtype: Optional[torch.dtype] = None, + layout=None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: Optional[torch.memory_format] = None, +): + assert not layout or layout == 
torch.strided, "TODO" + assert not pin_memory, "TODO" + assert isinstance(x, (torch.Tensor, int, float, bool, complex)) + if device is None and dtype is None and memory_format is None: + if isinstance(x, torch.Tensor): + return x.clone() + else: + return x + dtype_converted = False + + if isinstance(x, torch.Tensor): + x_tensor = x + else: + x_tensor = torch.scalar_tensor(x) + + if device is not None and device != x_tensor.device: + # avoid conversions on cpu + if dtype is not None and device.type == "cpu": + x_tensor = torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + x_tensor = torch._prims.device_put(x_tensor, device, non_blocking) + + if dtype is not None and not dtype_converted: + x_tensor = torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + + if memory_format is not None: # no ref/prim for memory format + return torch.clone(x_tensor, memory_format=memory_format) + return x_tensor + + +# Questionable decompositions +# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. +# Note that this decomposition causes issues with in-place ops +@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) +@out_wrapper() +def nop_decomposition(x): + return aten.alias(x) + + +# Also register to the Autograd dispatch key, so this decomp can run above autograd. +# native_batch_norm needs to decompose into other ops before autograd. +@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.cudnn_batch_norm) +@out_wrapper("out0", "out1", "out2", "out3") +def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + # Cudnn return running mean and variance when training is True + if training: + return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + input.new_zeros((0,), dtype=torch.uint8), + ) + + +def _broadcast_batch_norm_backward(x, broadcast_mask): + for axis, mask in enumerate(broadcast_mask): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): + x = x.unsqueeze(axis) + return x + + +@register_decomposition(aten.batch_norm_backward.default) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], + reserve: Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + + +@register_decomposition(aten.native_batch_norm_backward.default) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_dtype = input.dtype + if weight is not None: + weight_dtype = weight.dtype + else: + 
weight_dtype = input_dtype + computation_dtype = utils.get_computation_dtype(input.dtype) + ( + grad_out_cast, + input_cast, + weight_cast, + running_mean_cast, + running_var_cast, + save_mean_cast, + save_invstd_cast, + ) = ( + x.to(computation_dtype) if x is not None else x + for x in ( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + ) + ) + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(list(input_shape)) / input_shape[axis] + mean = save_mean_cast + invstd = save_invstd_cast + if train: + assert mean is not None and invstd is not None + + else: + assert running_mean_cast is not None and running_var_cast is not None + mean = running_mean_cast + invstd = torch.rsqrt(running_var_cast + eps) + + broadcast_mask: list[int] = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: list[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] + dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator] + + grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask) + proj_scale = _broadcast_batch_norm_backward( + torch.mul(dot_p * norm, invstd * invstd), # type: ignore[operator] + broadcast_mask, + ) + + if weight_cast is None: + grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] + else: + grad_scale = _broadcast_batch_norm_backward( + invstd * weight_cast, broadcast_mask + ) + + if train: + proj = (input_cast - mean) * proj_scale # type: ignore[operator] + grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out_cast * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + else: + grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp + + return ( + grad_input.to(input_dtype), + _maybe_cast(grad_weight, weight_dtype), + _maybe_cast(grad_bias, weight_dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_batch_norm_backward.out) +def native_batch_norm_backward_out( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + result = native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +@register_decomposition(aten.miopen_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def miopen_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: 
Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten.cudnn_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def cudnn_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, + reserveSpace: Tensor, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten._adaptive_avg_pool2d) +@out_wrapper() +@pw_cast_for_opmath +def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]): + # Preconditions + device = input.device + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4), + lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", + ) + for d in input.shape[-2:]: + torch._check( + d != 0, + lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}.", + ) + + # Optimisation (we should also do this in the kernel implementation) + if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0: + stride = tuple(i // o for i, o in zip(shape[-2:], output_size)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride) + ) + return torch.nn.functional.avg_pool2d(input, kernel, stride) + + def start_index(a, b, c): + return torch.div(a * c, b, rounding_mode="trunc") + + def end_index(a, b, c): + return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc") + + def compute_idx(in_size, out_size): + orange = torch.arange(out_size, device=device, dtype=torch.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. 
the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = torch.arange(maxlength, device=device, dtype=torch.int64) + idx = i0.unsqueeze(-1) + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + # TODO make minimum accept scalars + maxval = torch.scalar_tensor( + in_size - 1, dtype=idx.dtype, device=idx.device + ) + idx = torch.minimum(idx, maxval) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + # length is not None if it's constant, otherwise we'll need to compute it + idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2]) + idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1]) + + vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw] + # Shortcut for the simpler case + if not adaptive_h and not adaptive_w: + return torch.mean(vals, dim=(-3, -1)) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, IntLike): + return vals, length + else: + # zero-out the things we didn't really want to select + assert dim < 0 + # hack + mask = range_max >= length.unsqueeze(-1) + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + vals = torch.masked_fill(vals, mask, 0.0) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + vals, length_h = maybe_mask( + vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2 + ) + vals, length_w = maybe_mask( + vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1 + ) + + # We unroll the sum as we assume that the kernels are going to be small + ret = None + for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])): + if ret is None: + ret = vals[..., i, :, j] + else: + ret = ret + vals[..., i, :, j] + return ret / (length_h * length_w) + + +def _max_unpoolnd( + self: TensorLike, indices: TensorLike, output_size: list[int], dim: int +): + # If the input tensors self and indices came from max_pool call as + # required by the documentation, this operation is deterministic + # because that ensures that if there are two entries in `indices` + # tensor that are equal, the corresponding values in `self` are also + # equal. If this condition is not satisfied, the operation is + # non-deterministic as one of the different values in `self` 'wins'. 
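+ # The code below flattens the leading (batch, channel) dims and offsets each
+ # slice's spatial indices by flat_nc_index * prod(output_size), so a single
+ # _unsafe_index_put can scatter every value into the zero-initialized output.
+ # Round-trip sketch (illustrative only, assuming a 4x4 input pooled with a
+ # 2x2 window):
+ # >>> x = torch.arange(16.0).reshape(1, 1, 4, 4)
+ # >>> pooled, idx = torch.nn.functional.max_pool2d(x, 2, return_indices=True)
+ # >>> torch.nn.functional.max_unpool2d(pooled, idx, 2).shape
+ # torch.Size([1, 1, 4, 4])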
+ utils.alert_not_deterministic(f"max_unpooling{dim}d_forward_out") + nc = reduce(operator.mul, self.shape[:-dim]) + hw = reduce(operator.mul, output_size) + indices_nc_shape = [1] * self.ndim + indices_nc_shape[:-dim] = self.shape[:-dim] + indices_flat = ( + indices + aten.arange(nc, device=self.device).view(indices_nc_shape) * hw + ).reshape(-1) + + output = self.new_zeros(list(self.shape[:-dim]) + list(output_size)) + return aten._unsafe_index_put( + output.reshape(-1), [indices_flat], self.reshape(-1), accumulate=False + ).view(output.shape) + + +@register_decomposition(aten.max_unpool2d) +@out_wrapper() +def max_unpool2d( + self: TensorLike, + indices: TensorLike, + output_size: list[int], +): + torch._check( + indices.dtype == torch.int64, + lambda: f"elements in indices should be type int64 but got: {indices.dtype}", + ) + torch._check( + len(output_size) == 2, + lambda: ( + f"There should be exactly two elements (height, width) in output_size, " + f"but got {len(output_size)} elements." + ), + ) + + torch._check( + self.ndim in (3, 4), + lambda: ( + f"Input to max_unpooling2d should be a 3d or 4d Tensor, " + f"but got a tensor with {self.ndim} dimensions." + ), + ) + torch._check( + self.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({self.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, self.ndim): + torch._check( + self.size(i) > 0, + lambda: ( + f"max_unpooling2d(): " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {self.shape} with dimension {i} being empty." + ), + ) + + return _max_unpoolnd(self, indices, output_size, 2) + + +@register_decomposition(aten.max_unpool3d) +@out_wrapper() +def max_unpool3d( + input: TensorLike, + indices: TensorLike, + output_size: list[int], + stride: list[int], + padding: list[int], +): + torch._check( + indices.dtype == torch.int64, lambda: "elements in indices should be type int64" + ) + torch._check( + input.ndim in (4, 5), + lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.", + ) + torch._check( + len(output_size) == 3, + lambda: ( + f"There should be exactly three elements (depth, height, width) in output_size, " + f"but got {len(output_size)} elements." + ), + ) + torch._check( + len(stride) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.", + ) + torch._check( + len(padding) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.", + ) + torch._check( + input.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({input.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, input.ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"max_unpooling3d(): " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {input.shape} with dimension {i} being empty." 
+ ), + ) + + torch._check( + stride[0] > 0 and stride[1] > 0 and stride[2] > 0, + lambda: f"strides should be greater than zero, but got stride: {stride}", + ) + + return _max_unpoolnd(input, indices, output_size, 3) + + +@register_decomposition(aten.index_add_) +def index_add_( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha) + + +@register_decomposition(aten.index_add) +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +def _index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + inplace: bool, + alpha: NumberType = 1, +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + index_size = index.size(0) if index.ndim == 1 else 1 + tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1 + torch._check( + tensor_size == index_size, + lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}", + ) + if alpha != 1: + python_type = utils.dtype_to_type(x.dtype) + torch._check( + python_type == bool + or utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + tensor = tensor * alpha + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor, accumulate=True) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +@register_decomposition(aten.pad_sequence.default) +@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def pad_sequence(sequences, batch_first=False, padding_value=0.0): + torch._check(len(sequences) > 0, lambda: "received an empty list of sequences") + sequences_size = len(sequences) + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max(x.size(0) for x in sequences) + if batch_first: + out_dims = (sequences_size, max_len) + else: + out_dims = (max_len, sequences_size) + out_dims = out_dims + trailing_dims + out = sequences[0].new_full(out_dims, padding_value) + dim_paddings = (0, 0) * len(trailing_dims) + for i in range(sequences_size): + currseq = sequences[i] + row = aten.constant_pad_nd( + currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value + ) + if batch_first: + out = aten.select_scatter(out, row, dim=0, index=i) + else: + out = aten.select_scatter(out, row, dim=1, index=i) + return out + + +@register_decomposition(aten.index_copy_) +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=True) + + +@register_decomposition(aten.index_copy) +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=False) + + +def _index_copy( + x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars 
as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + index = index.unsqueeze(0) if index.ndim == 0 else index + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +# nb: Should use acc_t, not op_math +@register_decomposition(aten.log_sigmoid_forward) +@out_wrapper("output", "buffer") +@pw_cast_for_opmath +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda or self.is_xpu: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +@register_decomposition(aten.uniform) +@out_wrapper() +def uniform( + x: Tensor, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + generator: Optional[torch.Generator] = None, +): + return prims._uniform_helper( + x.shape, + low=sym_float(low), + high=sym_float(high), + dtype=x.dtype, + device=x.device, + generator=generator, + ) + + +@register_decomposition(aten.uniform_) +def uniform_(self, low=0, high=1, generator=None): + return self.copy_(uniform(self, low, high, generator)) + + +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + torch._check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + torch._check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(scale_factors) == spatial_dimensions, lambda: "") + output_size = [] + for i, s in enumerate(scale_factors): + if int(s) == s: + output_size.append(input_size[i + 2] * int(s)) + else: + output_size.append(sym_int(input_size[i + 2] * s)) + return output_size + torch._check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + +@register_decomposition(aten.upsample_nearest1d.vec) +@register_decomposition(aten.upsample_nearest2d.vec) +@register_decomposition(aten.upsample_nearest3d.vec) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_vec( + input: Tensor, + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales) + + +@register_decomposition(aten._upsample_nearest_exact1d.vec) +@register_decomposition(aten._upsample_nearest_exact2d.vec) +@register_decomposition(aten._upsample_nearest_exact3d.vec) 
+@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact_vec( + input: Tensor, + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales, exact=True) + + +def _compute_upsample_nearest_indices(input, output_size, scales, exact=False): + # For each dim in output_size, compute the set of input indices used + # to produce the upsampled output. + indices = [] + num_spatial_dims = len(output_size) + offset = 0.5 if exact else 0.0 + + for d in range(num_spatial_dims): + # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp + # + # Indices are computed as following: + # scale = isize / osize + # Case: exact=False + # input_index = floor(output_index * scale) + # Same as OpenCV INTER_NEAREST + # + # Case: exact=False + # index_f32 = (output_index + 0.5) * scale - 0.5 + # input_index = round(index_f32) + # Same as Pillow and Scikit-Image/Scipy ndi.zoom + osize = output_size[d] + isize = input.shape[-num_spatial_dims + d] + scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize + + output_indices = torch.arange(osize, dtype=torch.float32, device=input.device) + input_indices = ((output_indices + offset) * scale).to(torch.int64) + for _ in range(num_spatial_dims - 1 - d): + input_indices = input_indices.unsqueeze(-1) + indices.append(input_indices) + return indices + + +@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out]) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest1d( + input: Tensor, + output_size: list[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales]) + + +@register_decomposition( + [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out] +) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest_exact1d( + input: Tensor, + output_size: list[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales], exact=True) + + +@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out]) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest2d( + input: Tensor, + output_size: list[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w]) + + +@register_decomposition( + 
[aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out] +) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def _upsample_nearest_exact2d( + input: Tensor, + output_size: list[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True) + + +@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out]) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest3d( + input: Tensor, + output_size: list[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w]) + + +@register_decomposition( + [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out] +) +@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def _upsample_nearest_exact3d( + input: Tensor, + output_size: list[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest( + input, output_size, [scales_d, scales_h, scales_w], exact=True + ) + + +@pw_cast_for_opmath +def _upsample_nearest( + input: Tensor, + output_size: list[int], + scales: list[Optional[float]], + exact: bool = False, +) -> Tensor: + spatial_indices = _compute_upsample_nearest_indices( + input, output_size, scales, exact=exact + ) + + indices = [None, None] + spatial_indices + result = aten._unsafe_index(input, indices) + + if result.ndim == 4: + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + n_channels = input.shape[1] + if input.device.type == "cuda" and n_channels < 4: + memory_format = torch.contiguous_format + + result = result.contiguous(memory_format=memory_format) + return result + + +def gather_params(params, has_biases, has_projections): + if has_biases and has_projections: + group_size = 5 + elif has_biases: + group_size = 4 + elif has_projections: + group_size = 3 + else: + group_size = 2 + + assert len(params) % group_size == 0, len(params) + return [ + tuple(params[i : i + group_size]) for i in range(0, len(params), group_size) + ] + + +def params_hiddens(params, hiddens, i, bidirectional): + if bidirectional: + cur_params, cur_hidden = params[2 * i], hiddens[2 * i] + bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1] + else: + cur_params, cur_hidden = params[i], hiddens[i] + bidir_params, bidir_hidden = None, None + + return cur_params, cur_hidden, bidir_params, bidir_hidden + + +def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens): + assert last_batch_size > batch_size + hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size)) + return cur_hidden.narrow(0, 0, batch_size) + + +def update_hidden_for_packed_reverse( + 
cur_hidden, last_batch_size, batch_size, inp_hidden +): + if last_batch_size == batch_size: + return cur_hidden + assert last_batch_size < batch_size + return torch.concat( + ( + cur_hidden, + inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size), + ) + ) + + +def one_layer_rnn_data( + inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False +): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + step_output = [] + hiddens: list[torch.Tensor] = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + cur_hidden = hidden.narrow(0, 0, last_batch_size) + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + for inp in split_inp: + i = inp.shape[0] + + if last_batch_size == i: + pass # don't update cur_hidden + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + elif reverse: + cur_hidden = update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, i, hidden + ) + else: + cur_hidden = update_hidden_for_packed( + cur_hidden, last_batch_size, i, hiddens + ) + + cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + last_batch_size = i + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + else: + hiddens.append(cur_hidden) + hiddens.reverse() + + out = torch.cat(step_output, 0) + hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden + return out, hidden_out + + +def rnn_cell(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def rnn_cell_data(nonlinearity): + def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + i = F.linear(i, ih_weight, ih_bias) + return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i) + + return inner + + +def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + cur_hidden = hidden.unsqueeze(0) + step_output = [] + for i in precomputed_input: + cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias) + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, cur_hidden.squeeze(0) + + +def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + w0 = params[0] + w1 = params[1] + if has_biases: + w2 = params[2] + w3 = params[3] + else: + w2 = torch.zeros(w0.size()) + w3 = torch.zeros(w1.size()) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + batch_sizes: list[int] = [] + mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2 + hidden_size = hx.size(2) + num_layers = 1 + + # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here + bidirectional = False + batch_first = False + + train = False + # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here. 
+ # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous(); + inp = inp.contiguous() + hx = hx.contiguous() + cx = cx.contiguous() + outputs = torch.ops.aten.mkldnn_rnn_layer.default( + inp, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ) + y, hy, cy = outputs[0], outputs[1], outputs[2] + return y, (hy.squeeze(0), cy.squeeze(0)) + + +def _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, +): + input = input.transpose(0, 1) if batch_first else input + final_hiddens = [] + + for i in range(num_layers): + cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens( + params, hidden, i, bidirectional + ) + dropout = dropout if (train and num_layers < i - 1) else 0.0 + fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases) + final_hiddens.append(fwd_hidden) + + if bidirectional: + bwd_inp, bwd_hidden = layer_fn( + input, bidir_hidden, bidir_params, has_biases, reverse=True + ) + final_hiddens.append(bwd_hidden) + + if bidirectional: + input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined] + else: + input = fwd_inp + + if dropout != 0 and train and i < num_layers - 1: + input = torch.dropout(input, dropout, train=True) + + input = input.transpose(0, 1) if batch_first else input + return input, final_hiddens + + +@register_decomposition(aten.rnn_tanh.input) +@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd) +def rnn_tanh_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.input) +@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.input.py_impl(DispatchKey.Autograd) +def rnn_relu_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_relu.data) +@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.data.py_impl(DispatchKey.Autograd) +def rnn_relu_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.relu), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_tanh.data) +@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd) 
+@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd) +def rnn_tanh_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial( + one_layer_rnn_data, + batch_sizes=batch_sizes, + hidden_fn=rnn_cell_data(torch.tanh), + ), + ) + return out, torch.stack(final_hiddens, 0) + + +def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim): + gates = F.linear(hx, hh_weight, hh_bias) + inp + chunked_gates = gates.chunk(4, chunk_dim) + in_gate = chunked_gates[0].sigmoid() + forget_gate = chunked_gates[1].sigmoid() + cell_gate = chunked_gates[2].tanh() + out_gate = chunked_gates[3].sigmoid() + cy = forget_gate * cx + (in_gate * cell_gate) + hy = out_gate * cy.tanh() + hy = hy if hr_weight is None else F.linear(hy, hr_weight, None) + + return hy, cy + + +def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + hx = hidden[0].unsqueeze(0) + cx = hidden[1].unsqueeze(0) + + precomputed_input = F.linear(inp, ih_weight, ih_bias) + precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input + step_output = [] + for inp in precomputed_input: + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2) + step_output.append(hx) + + if reverse: + step_output.reverse() + + out = torch.cat(step_output, 0) + + return out, (hx.squeeze(1), cx.squeeze(1)) + + +def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + hr_weight = ( + params[4] if len(params) == 5 else params[2] if len(params) == 3 else None + ) + + step_output = [] + hiddens = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + + orig_hx = hidden[0] + orig_cx = hidden[1] + hx, cx = ( + orig_hx.narrow(0, 0, last_batch_size), + orig_cx.narrow(0, 0, last_batch_size), + ) + + for inp in split_inp: + i = inp.shape[0] + inp = F.linear(inp, ih_weight, ih_bias) + + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + if i < last_batch_size: + hiddens.append( + ( + hx.narrow(0, i, last_batch_size - i), + cx.narrow(0, i, last_batch_size - i), + ) + ) + hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i) + + # this will only happen when reverse=True + if i > last_batch_size: + hx = torch.concat( + (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + cx = torch.concat( + (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0 + ) + + hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1) + last_batch_size = i + step_output.append(hx) + + if reverse: + step_output.reverse() + hidden_out = (hx, cx) + else: + hiddens.append((hx, cx)) + hiddens.reverse() + hidden0, hidden1 = zip(*hiddens) + hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0) + + out = torch.cat(step_output, 0) + return out, hidden_out + + +def 
select_one_layer_lstm_function(input, hx, params): + r"""Check whether we could use decompose lstm with mkldnn_rnn_layer. + All the below conditions need to be met: + * ``torch._C._get_mkldnn_enabled()`` returns ``True``. + * All the input args are on CPU. + * The dtypes of args are either torch.float or torch.bfloat16. + * Inference. + * ``has_projections`` returns ``False``. + + Args: + * input: the input sequence to LSTM + * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM + * params: the weight and bias tensors of LSTM + """ + + def use_mkldnn(input, hx, params): + if not torch._C._get_mkldnn_enabled(): + return False + + tensors = [input] + list(hx) + list(chain.from_iterable(params)) + devices = {t.device for t in tensors} + if len(devices) != 1: + return False + + device = devices.pop() + if device != torch.device("cpu"): + return False + # With autocast, possible to have mixed dtype here + dtypes = {t.dtype for t in tensors} + for dtype in dtypes: + if dtype not in [torch.float, torch.bfloat16]: + return False + + if input.requires_grad: + return False + + has_projections = hx[0].size(2) != hx[1].size(2) + if has_projections: + return False + + return True + + # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm + # will expand over the seq_len dim + if use_mkldnn(input, hx, params): + return mkldnn_one_layer_lstm + else: + return one_layer_lstm + + +@register_decomposition(aten.lstm.input) +@aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.input.py_impl(DispatchKey.Autograd) +def lstm_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + layer_fn = select_one_layer_lstm_function(input, hx, params) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +@register_decomposition(aten.lstm.data) +@aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.lstm.data.py_impl(DispatchKey.Autograd) +def lstm_data_impl( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + assert len(hx) == 2, "lstm expects two hidden states" + params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2)) + hidden = list(zip(hx[0], hx[1])) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_lstm_data, batch_sizes=batch_sizes), + ) + final_hiddens = list(zip(*final_hiddens)) + return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0) + + +def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = inp.chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias): + chunked_igates = 
F.linear(inp, ih_weight, ih_bias).chunk(3, 1) + chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1) + reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid() + input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid() + new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh() + return (cur_hidden - new_gate) * input_gate + new_gate + + +@register_decomposition(aten.gru.data) +@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.data.py_impl(DispatchKey.Autograd) +def gru_impl_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.gru.input) +@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.gru.input.py_impl(DispatchKey.Autograd) +def gru_impl( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hx.unbind(0), + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + partial(one_layer_rnn, hidden_fn=gru_cell), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten._upsample_bilinear2d_aa.vec) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bilinear2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten._upsample_bicubic2d_aa.vec) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_bicubic2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + +@register_decomposition(aten.upsample_bilinear2d.vec) +@register_decomposition(aten.upsample_trilinear3d.vec) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_linear_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = scale_factors if scale_factors else [None] * len(osize) + return _upsample_linear(input, osize, align_corners, scales) + + 
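+# [Editor's note, not part of the upstream file] The decompositions registered in
+# this module are normally consumed through the public torch._decomp helpers rather
+# than called directly. A minimal usage sketch, assuming only the documented
+# get_decompositions API and torch.fx.experimental.proxy_tensor.make_fx:
+#
+#     from torch._decomp import get_decompositions
+#     from torch.fx.experimental.proxy_tensor import make_fx
+#
+#     table = get_decompositions([aten.gru.input, aten.upsample_bilinear2d.vec])
+#
+#     def f(x):
+#         return torch.nn.functional.interpolate(x, size=(16, 16), mode="bilinear")
+#
+#     gm = make_fx(f, decomposition_table=table)(torch.randn(1, 3, 8, 8))
+#     # gm's graph now contains the decomposed aten ops from this file instead of
+#     # the original upsample_bilinear2d.vec call.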
+@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out]) +@out_wrapper() +def upsample_linear1d( + input: Tensor, + output_size: list[int], + align_corners: bool, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_w]) + + +@register_decomposition( + [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out] +) +@aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def upsample_bilinear2d( + input: Tensor, + output_size: list[int], + align_corners: bool, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w]) + + +@register_decomposition( + [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out] +) +@out_wrapper() +def upsample_trilinear3d( + input: Tensor, + output_size: list[int], + align_corners: bool, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_linear( + input, output_size, align_corners, [scales_d, scales_h, scales_w] + ) + + +def _compute_scale(in_size, out_size, align_corners, scale=None): + if align_corners: + return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0 + else: + return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size + + +def _compute_source_index(scale, dst_index, align_corners): + if align_corners: + return scale * dst_index + else: + return scale * (dst_index + 0.5) - 0.5 + + +def _sum_tensors_uint8( + src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor +) -> Tensor: + output = _sum_tensors( + s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights) + ) + (1 << (weights_precision - 1)) + output = output >> weights_precision + return torch.clamp(output, 0, 255).to(torch.uint8) + + +def _compute_weight_precision(weights: TensorSequenceType) -> Tensor: + max_weight = torch.stack(weights).max() + max_weight_precision = 22 + precisions = torch.arange(max_weight_precision, device=max_weight.device) + values = 0.5 + max_weight * (1 << (precisions + 1)) + mask = values >= (1 << 15) + return max_weight_precision - mask.sum() + + +@pw_cast_for_opmath +def _upsample_linear( + input: Tensor, + output_size: list[int], + align_corners: bool, + scales: list[Optional[float]], +) -> Tensor: + # get dimensions of original image + n_channels = input.shape[1] + inp_sizes = input.shape[2:] + n_dims = len(inp_sizes) + + _, dtype = utils.elementwise_dtypes( + input, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + def get_values(inp_size, out_size, scales, nsqueeze): + # First Calculate scaling factor + scale_factor = _compute_scale(inp_size, out_size, align_corners, scales) + # We have to create arange with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(out_size, device=input.device).to(dtype=dtype) + + x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0) + x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze)) + x = x_f32.to(torch.int64) + xp1 = (x + 1).clamp(max=inp_size - 1) + return x_f32, x, xp1 + + values = [ + get_values(inp_size, out_size, scales, n_dims - 1 - i) + for i, (inp_size, out_size, scales) in enumerate( + zip(inp_sizes, output_size, scales) + ) + ] + xs_f32, xs, xp1s = list(zip(*values)) + + vs = [] + for a in 
product(*[[0, 1]] * n_dims): + idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)] + v = aten._unsafe_index(input, idx) + v = _maybe_convert_to_dtype(v, dtype) + vs.append(v) + + for i in reversed(range(n_dims)): + xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype) + vs = [ + # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha + v1 + torch.mul(v2 - v1, xscale) + for v1, v2 in zip(vs[::2], vs[1::2]) + ] + + assert len(vs) == 1 + result = vs[0] + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + if input.device.type == "cuda" and n_channels < 16: + memory_format = torch.contiguous_format + + assert isinstance(result, torch.Tensor) + + result = result.contiguous(memory_format=memory_format) + + if not input.is_floating_point(): + result = result.round() + + return result + + +# We should be applying decompositions after all transformations +@register_decomposition(aten.is_same_size.default) +def is_same_size(a: Tensor, b: Tensor) -> bool: + return a.shape == b.shape + + +@register_decomposition([aten._reshape_alias, aten._unsafe_view]) +@out_wrapper() +def _reshape_alias(x, shape, *args): + return aten.view(x, shape) + + +@register_decomposition([aten._unsafe_index]) +def _unsafe_index(x, indices): + return aten.index(x, indices) + + +@register_decomposition([aten._unsafe_index_put]) +def _unsafe_index_put(x, indices, value, accumulate=False): + return aten.index_put(x, indices, value, accumulate) + + +@register_decomposition([aten._unsafe_masked_index]) +def _unsafe_masked_index(x, mask, indices, fill): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(x.numel() == 0): + meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) + return x.new_full(meta_result.shape, fill) + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=0, max=x.size(i) - 1) + + return aten._unsafe_index(x, indices).masked_fill(~mask, fill) + + +@register_decomposition([aten._unsafe_masked_index_put_accumulate]) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + for index in indices: + if index is not None: + torch._check( + index.dtype in [torch.long, torch.int], + lambda: "tensors used as indices must be long or int tensors", + ) + + torch._check( + mask.dtype == torch.bool, + lambda: "tensors used as masks must be bool tensors", + ) + + if x.numel() == 0: + return x.clone() + + for i in range(len(indices)): + index = indices[i] + if index is not None: + indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1) + + masked_value = values.masked_fill(~mask, 0) + return aten._unsafe_index_put(x, indices, masked_value, accumulate=True) + + +def _nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> tuple[Tensor, Tensor]: + # self can be [N, C] or [C] + # target can be [N] or [] + + n_dims = self.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + if weight is not None: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + 
shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + self = self * w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + # target can be [N, 1] or [1] + + result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = self.new_full((), 0.0) + return result, total_weight + + if weight is not None: + w = w.expand(self.shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(self) + + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +@register_decomposition(aten.nll_loss_forward) +@out_wrapper("output", "total_weight") +def nll_loss_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> tuple[Tensor, Tensor]: + assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D" + assert target.dim() <= 1, ( + "0D or 1D target tensor expected, multi-target not supported" + ) + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or (self.shape[0] == target.shape[0]), ( + f"size mismatch (got input: {self.shape}, target: {target.shape})" + ) + + n_classes = self.shape[-1] + + assert weight is None or (weight.dim() == 1 and weight.numel() == n_classes), ( + f"weight tensor should be defined either for all {n_classes} classes or no classes " + f"but got weight tensor of shape: {weight.shape}" + ) + + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +@register_decomposition(aten.nll_loss2d_forward) +@out_wrapper("output", "total_weight") +def nll_loss2d_forward( + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, +) -> tuple[Tensor, Tensor]: + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +# These are adapted from aten/src/ATen/native/UpSample.h, wich is based on +# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor: + return ((A + 2) * x - (A + 3)) * x * x + 1 + + +def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor: + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A + + +def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType: + A = -0.75 + + if t.device == torch.device("cpu"): + tt1 = torch.stack([t, 1.0 - t], dim=0) + tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0) + w03 = _upsample_cubic_convolution2(tt2, A) + w12 = _upsample_cubic_convolution1(tt1, A) + w0, w3 = torch.unbind(w03, dim=0) + w1, w2 = torch.unbind(w12, dim=0) + return w0, w1, w2, w3 + else: + return ( + _upsample_cubic_convolution2(t + 1.0, A), + _upsample_cubic_convolution1(t, A), + _upsample_cubic_convolution1(1.0 - t, A), + _upsample_cubic_convolution2(2.0 - t, A), + ) + + +def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor: + coeffs2 = _upsample_get_cubic_coefficients(ts) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2)) + + +# Need this instead of just sum() to keep mypy happy +def _sum_tensors(ts: Iterable[Tensor]) -> Tensor: + return 
reduce(torch.add, ts) + + +def _linspace_from_neg_one( + num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device +): + if num_steps <= 1: + return torch.tensor(0, device=device, dtype=dtype) + + a = ((num_steps - 1) / num_steps) if not align_corners else 1 + return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype) + + +def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated + # corresponding to each individual tensor: grid_x, grid_y, grid_one + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1) + grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0) + return grid_x + grid_y + grid_one + + +def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1) + grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1) + grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1) + grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device) + + # this is just a temporary hack and we should use torch.stack here once #104480 is merged + grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0) + grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0) + grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0) + grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0) + return grid_x + grid_y + grid_z + grid_one + + +def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool): + n, _, h, w = size + base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners) + # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3) + # We do manually a matrix multiplication which is faster than mm() + # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2) + grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, h, w, 2) + + +def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool): + n, _, d, h, w = size + base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) + # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4) + # We do manually a matrix multiplication which is faster than mm() + # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3) + grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, d, h, w, 3) + + +@register_decomposition(aten.affine_grid_generator) +@out_wrapper() +@pw_cast_for_opmath +def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool): + torch._check( + len(size) in (4, 5), + lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", + ) + if len(size) == 4: + return _affine_grid_generator_4d(theta, size, align_corners=align_corners) + else: + return 
_affine_grid_generator_5d(theta, size, align_corners=align_corners) + + +def _grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, + _expand_grid: bool = True, +) -> Tensor: + # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to + # optionally expand the input grid for performance reasons. + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + + torch._check( + interpolation_mode in (0, 1, 2), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" + ) + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iH, iW = a.shape + _, oH, oW, two = grid.shape + assert two == 2 + + if _expand_grid: + # Let's expand grid to [N, C, oH, oW, 2] + # This allows to generate a single triton cuda kernel instead of two kernels. 
+ # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW + # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW + # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW + grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2) + + def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: + return torch.logical_and( + 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH)) + ) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) + + def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: + cond = in_bounds_cond(xs, ys) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oH, oW) + for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) + ) + + def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, w_ = clip(ix, iy, w) + return a[N_idx, C_idx, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nw, iy_nw = ix.floor(), iy.floor() + ix_ne, iy_ne = ix_nw + 1, iy_nw + ix_sw, iy_sw = ix_nw, iy_nw + 1 + ix_se, iy_se = ix_ne, iy_sw + + w_nw = (ix_se - ix) * (iy_se - iy) + w_ne = (ix - ix_sw) * (iy_sw - iy) + w_sw = (ix_ne - ix) * (iy - iy_ne) + w_se = (ix - ix_nw) * (iy - iy_nw) + + return _sum_tensors( + get_summand(ix, iy, w) + for (ix, iy, w) in ( + (ix_nw, iy_nw, w_nw), + (ix_ne, iy_ne, w_ne), + (ix_sw, iy_sw, w_sw), + (ix_se, iy_se, w_se), + ) + ) + elif interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nearest = ix.round() + iy_nearest = iy.round() + + return get_summand(ix_nearest, iy_nearest, 1) + else: # interpolation_mode == 2, Bicubic + ix = unnormalize(x, iW) + iy = unnormalize(y, iH) + + ix_nw = ix.floor() + iy_nw = iy.floor() + + tx = ix - ix_nw + ty = iy - iy_nw + + if not _expand_grid: + tx = tx.unsqueeze(1) + ty = ty.unsqueeze(1) + + def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor: + x = compute_coordinates(ix, iW) + y = compute_coordinates(iy, iH) + return get_summand(x, y, 1) + + def get_coeff(ofs: int) -> Tensor: + iy_ofs = iy_nw + (ofs - 1) + cs = ( + get_value_bounded(ix_nw - 1, iy_ofs), + get_value_bounded(ix_nw, iy_ofs), + get_value_bounded(ix_nw + 1, iy_ofs), + get_value_bounded(ix_nw + 2, iy_ofs), + ) + return _upsample_cubic_interp1d(cs, tx) + + coeffs = tuple(get_coeff(ofs) for ofs in range(4)) + return _upsample_cubic_interp1d(coeffs, ty) + + +@register_decomposition(aten.grid_sampler_2d) +@out_wrapper() +@pw_cast_for_opmath +def grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> Tensor: + return _grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + +@register_decomposition(aten.mv) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def mv(self, vec): + torch._check( + self.dim() == 2 and vec.dim() == 1, + lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", + ) 
+ torch._check( + self.size(1) == vec.size(0), + lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", + ) + return (self * vec).sum(dim=1) + + +@register_decomposition(aten.binary_cross_entropy_with_logits) +@out_wrapper() +def binary_cross_entropy_with_logits( + self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value +): + if pos_weight is not None: + log_weight = (pos_weight - 1) * target + 1 + loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) + else: + loss = (1 - target) * self - F.logsigmoid(self) + + if weight is not None: + loss = loss * weight + + return apply_loss_reduction(loss, reduction) + + +def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool: + # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp + + t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not (t1.ndim >= 3 and t2.ndim <= 2): + return False + if t2.requires_grad and not is_out: + return True + if tensor1.ndim == 2: + return False + if guard_size_oblivious(t1.numel() == 0): + return True + + t1_shape = t1.shape + t1_stride = t1.stride() + + # Check the contiguous, we can skip the dim with size of 1 + # as aten: https://github.com/pytorch/pytorch/blob/e201460f8aa1510b4c4686627d57b69756c4b916/aten/src/ATen/TensorGeometry.cpp#L17 + expected_stride = [1] + for size in reversed(t1_shape[1:]): + expected_stride.append(size * expected_stride[-1]) + return all( + guard_size_oblivious(size == 1) or left == right + for left, right, size in zip( + t1_stride, list(reversed(expected_stride)), t1_shape + ) + ) + + +@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd) +@out_wrapper(pass_is_out=True) +def matmul(tensor1, tensor2, *, is_out=False): + dim_tensor1 = tensor1.dim() + dim_tensor2 = tensor2.dim() + assert dim_tensor1 != 0 and dim_tensor2 != 0 + if dim_tensor1 == 1 and dim_tensor2 == 1: + return torch.dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return torch.mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0) + elif dim_tensor1 == 2 and dim_tensor2 == 2: + return torch.mm(tensor1, tensor2) + elif should_fold(tensor1, tensor2, is_out): + # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || + # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) + # and some condition on the strides is fulfilled + + # optimization: use mm instead of bmm by folding the batch of the larger tensor + # into its leading matrix dimension + transpose = dim_tensor2 > dim_tensor1 + t1 = tensor2.mT if transpose else tensor1 + t2 = ( + tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1) + ) + # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2) + # and t1 and t2 are matmul-compatible + + # Why not t1.view(-1, sizes_1[-1])? + # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. + # This can happen in e.g. [3, 5, 0] @ [0, 0]. + sizes_1 = t1.shape + output_shape = list(sizes_1[:-1]) + folded_dim1 = reduce(operator.mul, output_shape) + + # Readjust output_shape if we are multiplying by a matrix + t2_is_matrix = t2.dim() == 2 + if t2_is_matrix: + output_shape.append(t2.shape[1]) + + # This will almost always be a view. 
+ # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) + if t2_is_matrix: + # This copies if we perform a 2D @ 3D and the first tensor requires_grad + # See should_fold native/LinearAlgebra.cpp for why. + output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape) + return output.mT.contiguous() if transpose else output + else: + return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape) + + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1.size(-2) if dim_tensor1 > 1 else 1 + m1 = tensor1.size(-1) + batch_tensor1 = tensor1.shape[:-2] + m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) + p = tensor2.size(-1) if dim_tensor2 > 1 else 1 + + batch_tensor2: list[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2.size(i)) + + # Same optimization for the gradients as that in should_fold + # If we're going to broadcast, we force it to go through the should_fold branch + if ( + dim_tensor1 == 3 + and dim_tensor2 == 3 + and batch_tensor1[0] != batch_tensor2[0] + ): + if batch_tensor1[0] == 1 and tensor1.requires_grad: + return matmul(tensor1.squeeze(0), tensor2) + if batch_tensor2[0] == 1 and tensor2.requires_grad: + return matmul(tensor1, tensor2.squeeze(0)) + + # expand the batch portion (i.e. cut off matrix dimensions and expand rest) + expand_batch_portion = list( + torch.broadcast_shapes(batch_tensor1, batch_tensor2) + ) + + tensor1_expand_size = expand_batch_portion + [n, m1] + + expand_batch_product = prod(expand_batch_portion) + + # HACK: We need reshape with symint support + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 + ) + + vector_rhs = dim_tensor2 == 1 + if vector_rhs: + tensor2_expand_size = expand_batch_portion + [m2] + tensor2_expanded = ( + tensor2.expand(tensor2_expand_size) + .reshape(expand_batch_product, m2) + .unsqueeze(2) + ) + else: + tensor2_expand_size = expand_batch_portion + [m2, p] + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p + ) + + output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + if vector_rhs: + return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape) + else: + return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) + else: + torch._check(False, lambda: "both arguments to matmul need to be at least 1D") + + +@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out]) +@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_default( + input: Tensor, + output_size: tuple[int, int], + align_corners: bool, + scale_h: Optional[float] = None, + scale_w: Optional[float] = None, +) -> Tensor: + # get dimensions of original image + _, _, in_h, in_w = input.shape + + # Calculate horizontal and vertical scaling factor + h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h) + w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w) + + _, dtype = utils.elementwise_dtypes( + input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + # We have to create arange with 
int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(output_size[0], device=input.device).to(dtype=dtype) + j = torch.arange(output_size[1], device=input.device).to(dtype=dtype) + + x_float = _compute_source_index(w_scale_factor, j, align_corners) + y_float = _compute_source_index(h_scale_factor, i, align_corners) + y_float = y_float.unsqueeze(-1) + + x = x_float.floor() + y = y_float.floor() + + # We should also clamp xscale/yscale + # See guard_index_and_lambda in UpSample.h + yscale = (y_float - y).clamp(0.0, 1.0) + xscale = (x_float - x).clamp(0.0, 1.0) + x = x.to(torch.int64) + y = y.to(torch.int64) + + iys_ofs = (y - 1, y, y + 1, y + 2) + ixs_ofs = (x - 1, x, x + 1, x + 2) + + weights_x = _upsample_get_cubic_coefficients(xscale) + weights_y = _upsample_get_cubic_coefficients(yscale) + + weights_precision_x, weights_precision_y = None, None + if input.dtype == torch.uint8: + weights_precision_x = _compute_weight_precision(weights_x) + weights_precision_y = _compute_weight_precision(weights_y) + + weights_x = [ + (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_x + ] + weights_y = [ + (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_y + ] + + def load_bounded(ys, xs): + y_idx = torch.clamp(ys, 0, in_h - 1) + x_idx = torch.clamp(xs, 0, in_w - 1) + v = aten._unsafe_index(input, [None, None, y_idx, x_idx]) + return v + + def get_x_interp(y): + src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs) + if input.dtype == torch.uint8: + assert weights_precision_x is not None + return _sum_tensors_uint8(src_x, weights_x, weights_precision_x) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x)) + + src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs) + if input.dtype == torch.uint8: + assert weights_precision_y is not None + result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y) + else: + result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y)) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.upsample_bicubic2d.vec) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_vec( + a: Tensor, + output_size: Optional[tuple[int, int]], + align_corners: bool, + scale_factors: Optional[tuple[float, float]] = None, +) -> Tensor: + torch._check( + bool(output_size) + bool(scale_factors) == 1, + lambda: "Must specify exactly one of output_size and scale_factors.", + ) + if output_size is None: + assert scale_factors is not None + output_size = cast( + tuple[int, int], + tuple( + sym_int(sym_float(w) * scale) + for w, scale in zip(a.shape[2:], scale_factors) + ), + ) + scale_h, scale_w = scale_factors if scale_factors else (None, None) + return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) + + +@register_decomposition(aten.reflection_pad1d) +@register_decomposition(aten.reflection_pad2d) +@register_decomposition(aten.reflection_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - 
dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +@register_decomposition(aten.replication_pad1d) +@register_decomposition(aten.replication_pad2d) +@register_decomposition(aten.replication_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +def _reflection_or_replication_pad( + a: Tensor, + padding: tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], +) -> Tensor: + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: list[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.reflection_pad1d_backward) +@register_decomposition(aten.reflection_pad2d_backward) +@register_decomposition(aten.reflection_pad3d_backward) +@out_wrapper("grad_input") +def _reflection_pad_backward(grad_output, x, padding): + dim = len(padding) // 2 + + dhw = [h - 1 for h in x.shape[-dim:]] + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + indices = [] + for i in range(x.ndim): + view_shape = [1] * x.ndim + view_shape[i] = -1 + indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape)) + + b = indices[:-dim] + xyz = indices[-dim:] + + def index_range_condition(index_range): + i, lb, ub = index_range + return torch.logical_and(i >= lb, i <= ub) + + # Areas after reflection: + # + # top-left | top | top-right + # ----------------------------------------- + # left | center | right + # ----------------------------------------- + # bottom-left | bottom | bottom-right + # + # The center area is the original matrix. Other areas are reflections. + + center = [xyz[i] + padding_left[i] for i in range(dim)] + left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] + right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] + + # Accumulate gradients from different areas + # If some of the padding is negative, center load is not always valid + range_c = [ + (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim) + ] + cond = functools.reduce( + aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)] + ) + grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0) + + def accumulate(grad, out, index_ranges): + # If the upper bound is less than the lower bound, we can get rid of one accumulation. + # This happens when the padding size is zero. 
+ for i in range(dim): + upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] + if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: + return grad + + cond = functools.reduce( + aten.logical_and, + [index_range_condition(index_range) for index_range in index_ranges], + ) + g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0) + return grad + g + + for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): + if area == tuple([0] * dim): + # center, this is already done. + continue + + outs = [] + index_ranges = [] + + for i in range(dim): + if area[i] == 0: + out = center[i] + index_range = range_c[i] + elif area[i] == -1: + out = left_reflect[i] + index_range = (xyz[i], 1, padding_left[i]) + elif area[i] == 1: + out = right_reflect[i] + index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) + + outs.append(out) # type: ignore[possibly-undefined] + index_ranges.append(index_range) # type: ignore[possibly-undefined] + + grad = accumulate(grad, outs, index_ranges) + + return grad + + +@register_decomposition(aten.aminmax) +@out_wrapper("min", "max") +def aminmax(self, *, dim=None, keepdim=False): + amin = torch.amin(self, dim=dim, keepdim=keepdim) + amax = torch.amax(self, dim=dim, keepdim=keepdim) + return amin, amax + + +@register_decomposition(aten.nansum) +@out_wrapper() +def nansum(self, dim=None, keepdim=False, *, dtype=None): + return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) + + +@register_decomposition([aten.arange.default, aten.arange.out]) +@out_wrapper() +def arange_default( + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition([aten.arange.start]) +def arange_start( + start: NumberType, + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition(out_dtype) +def out_dtype_decomp(*args, **kwargs): + from torch._higher_order_ops.out_dtype import out_dtype_dense + + return out_dtype_dense(*args, **kwargs) + + +@register_decomposition(aten.multi_margin_loss) +@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: NumberType = 1, + margin: NumberType = 1, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + input = torch.atleast_2d(input) + target = torch.atleast_1d(target) + nframe = input.shape[0] + dim = input.shape[1] + torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported") + torch._check( + input.ndim == 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}", + ) + torch._check( + target.ndim == 1 and target.numel() == nframe, + lambda: f"inconsistent target size, expected {nframe} but got {target.shape}", + ) + if weight is not None: + weight = torch.atleast_1d(weight) + torch._check( + weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr] + lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr] + ) + 
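+    # [Editor's note, not part of the upstream file] The block below evaluates the
+    # standard multi-class margin loss. Per sample with scores x and label y:
+    #     loss(x, y) = (1 / dim) * sum_{i != y} max(0, margin - x[y] + x[i]) ** p
+    # optionally scaled by weight[y]: x[y] is gathered, z = margin - x[y] + x is
+    # clamped at 0 (and squared when p == 2), the i == y column is zeroed, and the
+    # result is reduced according to `reduction`.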
target = target.unsqueeze(1) + u = torch.gather(input, dim=1, index=target) + z = margin - u + input + z = z.clamp_min(0) + z = z if p == 1 else z * z + if weight is not None: + z = z * weight[target] + idx = torch.arange(dim, device=input.device) + z = torch.where(idx != target, z, 0) + if reduction == Reduction.MEAN.value: + return z.mean() + elif reduction == Reduction.SUM.value: + return z.sum() / z.shape[1] + else: + return z.mean(dim=1) + + +@register_decomposition(aten.multilabel_margin_loss_forward) +@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd) +@out_wrapper("output", "is_target") +def multilabel_margin_loss_forward( + input: Tensor, + target: Tensor, + reduction: int, +) -> tuple[Tensor, Tensor]: + orig_input_shape = input.shape + orig_target_shape = target.shape + input = torch.atleast_2d(input) + target = torch.atleast_2d(target) + dim = input.shape[1] + torch._check( + len(orig_input_shape) <= 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}", + ) + torch._check( + len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape, + lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}", + ) + # ignores labels after the first -1, detects when -1 is not present + idx = torch.arange(dim, device=target.device) + is_end = target == -1 + end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True) + # target indices + target_mask = idx < end_idx + # masks target to be able to use gather, which doesn't allow -1 + tidx0 = torch.where(target_mask, target, 0) + u = torch.gather(input, dim=-1, index=tidx0) + # is_target + tidx1 = torch.where(target_mask, target, -1) + is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1) + # loss + z = 1.0 - u.T.unsqueeze(dim=-1) + input + z = z.clamp_min(0) + z = z / dim + # masks loss + z = torch.where(is_target, 0, z) + # reduction + if reduction == Reduction.MEAN.value: + z = z.sum(dim=(0, -1)).mean() + elif reduction == Reduction.SUM.value: + z = z.sum() + else: + z = z.sum(dim=(0, -1)) + # result + is_target = is_target.to(input.dtype).reshape(orig_target_shape) + return z, is_target + + +# scaled_dot_product_attention used to be decomposed in pre-autograd, given that +# it calls _scaled_dot_product_attention_math and +# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd +# kernel. As a result it's decomposed into ops with finer granularity. +# However recent PRs (#103826 #105131 #115913) added new logic in +# scaled_dot_product_attention and now it calls +# _scaled_dot_product_flash_attention_for_cpu in export path. This results +# in _scaled_dot_product_flash_attention_for_cpu showing up in export result. +# This decomposition ensures scaled_dot_product_attention is still decomposed +# the same way as before, i.e., going through +# _scaled_dot_product_attention_math. Notice that this decomp rule should be +# excluded by inductor. 
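+# [Editor's note, not part of the upstream file] A minimal sketch of how a consumer
+# such as inductor can drop this rule from its decomposition table, assuming the
+# torch._decomp helpers get_decompositions / remove_decompositions:
+#
+#     from torch._decomp import get_decompositions, remove_decompositions
+#
+#     table = get_decompositions([aten._scaled_dot_product_flash_attention_for_cpu])
+#     remove_decompositions(table, [aten._scaled_dot_product_flash_attention_for_cpu])
+#     # `table` no longer maps the CPU flash-attention op to this decomposition.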
+@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default) +def scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> tuple[Tensor, Tensor]: + torch._check( + torch.is_floating_point(query), + lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}", + ) + torch._check( + query.dim() == 4 and key.dim() == 4 and value.dim() == 4, + lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}", + ) + torch._check( + dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}" + ) + torch._check( + query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3], + lambda: "q, k, v should have the same head size", + ) + + output, attn = aten._scaled_dot_product_attention_math.default( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + dropout_mask=None, + scale=scale, + ) + # Why this change? + # In pre-dispatch export scaled_dot_product_attention is executed via + # * flash_attention. + # flash_attention allocates the output tensor as (N, H, L, E) (see PR #134656) + # assume x: [N, H, L, E] is the output of sdpa + # In MHA code, this output is then permuted via (2, 0, 1, 3) to get + # a (L, N, H, E) dim tensor + # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via + # x = x.view(L * N, H * E) + # During pre-autograd dispatch the call to contiguous is not traced because + # the flash_attention output after the x.permute is already contiguous, + # so the view on it is valid + # However, during 2nd stage export, post-dispatch, we run the _math variant + # instead of flash* to get the decomposition. The _math variant returns + # x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) returns + # x: [L, N, H, E], and without converting this to a contiguous tensor + # the subsequent view is not valid and the export fails. + # The solution is to keep the tensor returned by the decomp in + # exactly the same view as the *flash* variant.
+ + # Really the invariant you want to maintain is: + # pre-dispatch op-output and its decomposed representation must + # return tensor with same view and dims + output = ( + output.permute(2, 0, 1, 3) + .contiguous(memory_format=torch.contiguous_format) + .permute(1, 2, 0, 3) + ) + return output, attn + + +def register_inplace(aten_op, outplace_op): + @register_decomposition(aten_op) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +@register_decomposition([aten.baddbmm]) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def baddbmm(self, batch1, batch2, beta=1, alpha=1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + result = torch.bmm(batch1, batch2) + if not isinstance(alpha, numbers.Number) or alpha != 1: + result = result * alpha + if beta == 0: + return result + if not isinstance(beta, numbers.Number) or beta != 1: + self = self * beta + return self + result + + +@register_decomposition(aten.floor_divide) +@out_wrapper() +def floor_divide(self, other): + return torch.div(self, other, rounding_mode="floor") + + +@register_decomposition(aten.sym_numel) +def sym_numel(t): + return functools.reduce(operator.mul, t.shape, 1) + + +@register_decomposition([aten.sum.default, aten.sum.out]) +def sum_default( + self: Tensor, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + return aten.sum.dim_IntList(self, [], dtype=dtype) + else: + return aten.sum.IntList_out(self, [], dtype=dtype, out=out) + + +@register_decomposition([aten.squeeze.default, aten.squeeze.dim]) +def squeeze_default(self: Tensor, dim: Optional[int] = None): + # handle a scalar directly + if not isinstance(self, torch.Tensor): + return self + # perform squeeze + if dim is None: + return aten.squeeze.dims(self, list(range(self.dim()))) + else: + return aten.squeeze.dims(self, [dim]) + + +@register_decomposition(torch.ops.aten._weight_norm_interface) +def _weight_norm_interface(v, g, dim=0): + # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm + + +@register_decomposition(aten.isin) +@out_wrapper() +def isin(elements, test_elements, *, assume_unique=False, invert=False): + # handle when either elements or test_elements are Scalars (they can't both be) + if not isinstance(elements, torch.Tensor): + elements = torch.tensor(elements, device=test_elements.device) + if not isinstance(test_elements, torch.Tensor): + if invert: + return torch.ne(elements, test_elements) + else: + return torch.eq(elements, test_elements) + + if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145): + return isin_default(elements, test_elements, invert=invert) + else: + return isin_sorting( + elements, test_elements, assume_unique=assume_unique, invert=invert + ) + + +@register_decomposition(aten.bernoulli.default) +def bernoulli( + self: torch.Tensor, + *, + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + if generator is None: + raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device) + else: + raw_p = torch.rand( + self.size(), + generator=generator, 
+ dtype=torch.float32, + device=self.device, + ) + p = (raw_p < self).to(self.dtype) + return p + + +def isin_default(elements, test_elements, *, invert=False): + if elements.numel() == 0: + return torch.empty_like(elements, dtype=torch.bool) + expanded_elem_shape = elements.shape + (1,) * test_elements.ndim + x = elements.view(expanded_elem_shape) + dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + res = (x == test_elements).any(dim=dim) + return ~res if invert else res + + +def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False): + elements_flat = elements.flatten() + test_elements_flat = test_elements.flatten() + if assume_unique: + # This is the same as the aten implementation. For + # assume_unique=False, we cannot use unique() here, so we use a + # version with searchsorted instead. + all_elements = torch.cat([elements_flat, test_elements_flat]) + sorted_elements, sorted_order = torch.sort(all_elements, stable=True) + + duplicate_mask = sorted_elements[1:] == sorted_elements[:-1] + duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False) + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = torch.empty_like(duplicate_mask) + mask = mask.index_copy(0, sorted_order, duplicate_mask) + + return mask[0 : elements.numel()] + else: + sorted_test_elements, _ = torch.sort(test_elements_flat) + idx = torch.searchsorted(sorted_test_elements, elements_flat) + test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0) + cmp = sorted_test_elements[test_idx] == elements_flat + cmp = cmp.logical_not() if invert else cmp + return cmp.reshape(elements.shape) + + +@register_decomposition(aten.take) +@out_wrapper() +def take(self, index): + flattened = self.reshape(-1) + return flattened[index] + + +@register_decomposition(aten.resize_as) +def resize_as(self, other, memory_format=None): + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + memory_format = suggest_memory_format(other) + return aten.resize(self, other.shape, memory_format=memory_format) + + +register_inplace(aten.addbmm_, aten.addbmm) +register_inplace(aten.addmm_, aten.addmm) +register_inplace(aten.addmv_, aten.addmv) +register_inplace(aten.baddbmm_, aten.baddbmm) +register_inplace(aten.fill_, aten.fill) +register_inplace(aten.gelu_, aten.gelu) +register_inplace(aten.hardswish_, aten.hardswish) +register_inplace(aten.hardtanh_, aten.hardtanh) +register_inplace(aten.hardsigmoid_, aten.hardsigmoid) +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.index_put_, aten.index_put) +register_inplace(aten.index_reduce_, aten.index_reduce) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) +register_inplace(aten.leaky_relu_, aten.leaky_relu) +register_inplace(aten.logit_, aten.logit) +register_inplace(aten.relu_, aten.relu) +register_inplace(aten.renorm_, aten.renorm) +register_inplace(aten.round_, aten.round) +register_inplace(aten.scatter_, aten.scatter) +register_inplace(aten.scatter_add_, aten.scatter_add) +register_inplace(aten.scatter_reduce_, aten.scatter_reduce) +register_inplace(aten.silu_, aten.silu) diff --git a/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py b/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py new file mode 100644 index 
0000000000000000000000000000000000000000..7b308fa2cf138682104b81d113bddc593af7c3e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_jvp.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +from typing import Callable, Optional + +import torch +import torch._decomp +from torch import Tensor +from torch._prims_common.wrappers import _maybe_remove_out_wrapper + + +decomposition_table = torch._decomp.decomposition_table +decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {} +register_decomposition = torch._decomp.register_decomposition +aten = torch.ops.aten + +# NOTE: [forward-mode AD decompositions mechanism] +# +# The mechanism is in VariableType, +# IF any inputs have forward grad +# AND there is no forward AD formula implemented +# AND the functions are actually differentiable +# run the decomposition +# See run_jit_decomposition_with_args_for_jvp +# We currently use python decompositions that we torchscript. +# +# Note that we would be building the backward graph at the decomposed level +# too, but that is OK, because we would've errored out otherwise anyway. +# +# TODO: The mechanism we are using to register decompositions doesn't +# seem to be exclusively used for jvp. So open question here is whether +# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things. +# If that is the case, we may go down the decomposition path unexpectedly +# (and possibly produce an unintelligible error) vs erroring out earlier and +# printing that the forward AD formula is not implemented. +# +# The solution to this may be to have an explicitly white list control when +# to enable the decomposition. + + +def maybe_register_decomposition(op): + def decorator(f): + try: + return register_decomposition(op)(f) + except Exception: + return f + + return decorator + + +# Functions where we need a special decomposition for jvp but there's another version that +# should be used more generally (ex. for jvp we need to recompute the mean and variance for +# the backwards of a normalization function. Without jvp, it should use the saved value) +decomposition_table_for_jvp = {} + + +def register_decomposition_for_jvp(fn): + return register_decomposition(fn, registry=decomposition_table_for_jvp) + + +def _register_jit_decomposition_for_jvp(decomp, use_python=False): + if decomp in decomposition_table_for_jvp: + decomposition_table_used = decomposition_table_for_jvp + elif decomp in decomposition_table: + decomposition_table_used = decomposition_table + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + decomp_fn = decomposition_table_used[decomp] + + # `out_wrapper` extends a decompositions signature with + # an `out` parameter. However jit will use the unwrapped function's + # signature instead so we need to unwrap here to prevent an error + decomp_fn = _maybe_remove_out_wrapper(decomp_fn) + + if use_python: + decomp_fn = torch.jit.ignore(decomp_fn) + sig = inspect.signature(decomp_fn) + + # Create a string wrapping the function from the signature + # example output: + # def wrapped_decomp(x: torch.Tensor, y: int, z: int): + # return decomp_fn(x, y, z) + # Thanks copilot! 
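As an aside, the string-building step described in the comment above can be reproduced outside of TorchScript with plain `inspect`. The sketch below is not part of the diff; `decomp_fn` is a hypothetical stand-in (real decompositions carry type annotations so the generated wrapper can be scripted), and it only shows what the generated `wrapped_decomp` source ends up looking like.

```python
# Minimal sketch, assuming a toy decomposition function; mirrors the
# signature-to-source technique used by _register_jit_decomposition_for_jvp.
import inspect

def decomp_fn(x, y, z=1):  # hypothetical stand-in for a real decomposition
    return x + y * z

sig = inspect.signature(decomp_fn)
param_def = [str(p) for p in sig.parameters.values()]  # keeps defaults, e.g. "z=1"
param_use = list(sig.parameters.keys())                 # bare names for the call

f_str = (
    f"def wrapped_decomp({', '.join(param_def)}):\n"
    f"    return decomp_fn({', '.join(param_use)})\n"
)
print(f_str)
# def wrapped_decomp(x, y, z=1):
#     return decomp_fn(x, y, z)
```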
+ def get_function_def(sig): + param_def = [f"{param_str}" for param_str in sig.parameters.values()] + param_use = [f"{param_str}" for param_str in sig.parameters.keys()] + + return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" + + f_str = get_function_def(sig) + graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph + else: + graph = torch.jit.script(decomp_fn).graph + torch.jit._register_decomposition(decomp, graph) + + +# The only decompositions here are temporary or hacks for the purposes of jvp + + +# TODO: do these also belong here? +@maybe_register_decomposition(aten.trace.default) +def trace(self: Tensor) -> Tensor: + return torch.sum(torch.diag(self)) + + +@maybe_register_decomposition(aten.log_sigmoid_forward.default) +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda or self.is_xpu: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +def recompute_mean_var( + input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool +): + # for most norm decompositions, it will be the same as the core version except for here. + # We recompute the mean and variance so that they track gradients through input + + mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim) + var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim) + eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside + eps = eps.detach() + rstd = 1 / torch.sqrt(var + eps) + return mean, rstd + + +@register_decomposition_for_jvp(aten.native_layer_norm_backward) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices = list(range(axis, input_ndim)) + outer_dim_indices = list(range(0, axis)) + + N = 1 + for i in inner_dims: + N *= i + M = 1 + for i in outer_dims: + M *= i + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape), + input.new_zeros(input_shape[axis:]), + input.new_zeros(input_shape[axis:]), + ) + + mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True) + + x_hat = (input - mean_) * rstd_ + if weight is not None: + grad_x_hat = grad_out * weight + else: + grad_x_hat = grad_out + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + inner = a - b - c3 + + if output_mask[0]: + d_input: Optional[Tensor] = (rstd_ / N) * inner + else: + d_input = torch.zeros_like(input) # should be None but doesn't work with vjp + + if output_mask[1] and weight is not None: + if len(outer_dim_indices) > 0: + d_weight: Optional[Tensor] = torch.sum( + grad_out * x_hat, outer_dim_indices, False + ) + else: + d_weight = grad_out * x_hat + elif weight is not None: + d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + d_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2] and bias is not None: + if len(outer_dim_indices) > 0: + d_bias: 
Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False) + else: + d_bias = grad_out.clone() + elif bias is not None: + d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp + else: + d_bias = torch.zeros(()) # should be None but doesn't work with vjp + + return (d_input, d_weight, d_bias) + + +def prod(x: list[int]): + r = 1 + for i in x: + r *= i + return r + + +@register_decomposition_for_jvp(aten.native_batch_norm_backward) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type] + mean = save_mean + invstd = save_invstd + if train: + assert save_mean is not None and save_invstd is not None, ( + "when train=True, save_mean and save_invstd are required" + ) + + reduciton_dims = [0] + list(range(2, input.dim())) + assert invstd is not None # for typing + mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False) + else: + assert running_mean is not None and running_var is not None + mean = running_mean + invstd = torch.rsqrt(running_var + eps) + + assert invstd is not None and mean is not None + + broadcast_mask = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: list[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out, reduction_axes) + dot_p = torch.sum(grad_out * (input - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) + + if weight is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 + else: + grad_scale = torch.reshape(invstd * weight, broadcast_mask) + + if train: + proj = (input - mean) * proj_scale + grad_input = ((grad_out - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + elif weight is not None: + grad_weight = torch.zeros_like( + weight + ) # should be None but doesn't work with vjp + else: + grad_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = torch.zeros_like( + grad_output_sum + ) # should be None but doesn't work with vjp + + return (grad_input, grad_weight, grad_bias) + + +@register_decomposition_for_jvp(aten.batch_norm_backward) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + update: bool, + eps: float, + output_mask: list[bool], + reserve: Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + update, + eps, + output_mask, + ) + + +_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True) 
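As an aside, the eps-recovery algebra used by `recompute_mean_var` above (recover `eps` from the saved `rstd` via `1/rstd**2 - var`, then rebuild `rstd` from the freshly computed variance so it tracks gradients through the input) can be checked numerically with public torch ops. The snippet below is a minimal sketch for illustration, not part of the diff.

```python
# Minimal numerical check of the eps-recovery trick in recompute_mean_var:
# given rstd = 1/sqrt(var + eps), eps can be recovered as 1/rstd**2 - var.
import torch

x = torch.randn(4, 8, dtype=torch.float64)
eps = 1e-5
var = x.var(dim=1, unbiased=False, keepdim=True)
rstd = 1.0 / torch.sqrt(var + eps)          # what the forward pass would have saved

recovered_eps = 1.0 / rstd.pow(2) - var     # same algebra as the decomposition
new_rstd = 1.0 / torch.sqrt(var + recovered_eps)

assert torch.allclose(torch.full_like(recovered_eps, eps), recovered_eps)
assert torch.allclose(new_rstd, rstd)
```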
+_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default) diff --git a/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py b/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6b540c5b1b8671a7354e8f2c8f79ce8f44af56 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_decomp/decompositions_for_rng.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from collections import defaultdict +from typing import Callable + +import torch +import torch._decomp as decomp +from torch._decomp import get_decompositions +from torch._ops import OpOverload + + +aten = torch.ops.aten + +rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict) + + +def register_rng_decomposition(aten_op): + return decomp.register_decomposition(aten_op, rng_decompositions) + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." + ) + + +# TODO - We have to register many more distributions here, and also higher level +# ops like dropout which have fused implementation and can hide the rand inside. +@register_rng_decomposition(aten.rand) +def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False): + if device and device.type != "cuda": + throw_on_non_cuda(device) + seed, offset = PhiloxStateTracker.get_state_as_tuple() + dtype = dtype or torch.float32 + out, offset_jump = torch.ops.rngprims.philox_rand( + shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +@register_rng_decomposition(aten.rand_like) +def rand_like( + x: torch.Tensor, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, +): + device = device or x.device + if device.type != "cuda": + throw_on_non_cuda(device) + dtype = dtype or x.dtype + seed, offset = PhiloxStateTracker.get_state_as_tuple() + out, offset_jump = torch.ops.rngprims.philox_rand( + x.shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +class PhiloxState: + """ + Represents a PhiloxRngState - (seed, offset) where offset = base_offset + + relative_offset. seed and base_offset basically point to the rng state just + before tracing starts. relative offset tracks the totally consumed offset at + trace time. 
+ """ + + def __init__(self) -> None: + self.reset() + + def reset(self): + self.seed = torch.tensor(()) + self.base_offset = torch.tensor(()) + self.relative_offset = 0 + self.offset_advanced_alteast_once = False + + def validate_state(self): + assert self.seed.numel() != 0 and self.base_offset.numel() != 0 + + def advance_offset(self, consumed_offset): + self.offset_advanced_alteast_once = True + self.relative_offset = self.relative_offset + consumed_offset + + def set_state(self, seed, base_offset, relative_offset=0): + self.seed = seed + self.base_offset = base_offset + self.relative_offset = relative_offset + + def get_state_as_tuple(self): + self.validate_state() + return (self.seed, self.base_offset + self.relative_offset) + + def get_state_as_tensor(self): + # Only needed because we override get_rng_state. + self.validate_state() + return torch.stack([self.seed, self.base_offset + self.relative_offset]) + + def set_state_from_tensor(self, state): + # Only needed because we override set_rng_state. + self.seed, self.base_offset = torch.unbind(state) + self.relative_offset = 0 + + +class PhiloxStateTracker: + """ + Singleton class to track the philox rng state during AOT Autograd tracing. + For each aot tracing instance, AOT Autograd resets this tracker and keeps + track of both forward and backward offsets. At runtime, we only care about + the total consumed forward and backward offsets. For dynamic shapes, these + offsets are a function of input shapes. Therefore, the AOT generated graphs + have additional outputs that compute total consumed forward and backward + offsets. + """ + + running_state: PhiloxState + fwd_state: PhiloxState + bwd_state: PhiloxState + + def __enter__(self): + PhiloxStateTracker.reset() + return self + + def __exit__(self, exc_type, exc_cal, exc_tb): + PhiloxStateTracker.reset() + + @classmethod + def reset(cls): + cls.running_state = PhiloxState() + cls.fwd_state = PhiloxState() + cls.bwd_state = PhiloxState() + + @classmethod + def mark_beginning_of_forward(cls): + # Tells the tracker to use fwd_state as the running state + cls.running_state = cls.fwd_state + + @classmethod + def mark_beginning_of_backward(cls): + # Tells the tracker to use bwd_state as the running state + cls.running_state = cls.bwd_state + + @classmethod + def record_state(cls, seed, offset, mode): + # Records the seed and offset tensors. These tensors are used to invoke + # the philox_rand functional primitives. + if mode == "forward": + cls.fwd_state.set_state(seed, offset) + cls.mark_beginning_of_forward() + else: + assert mode == "backward" + cls.bwd_state.set_state(seed, offset) + + @classmethod + def get_state_as_tensor(cls): + # The only reason this exists is because we override get_rng_state and + # set_rng_state during tracing. get_rng_state expects a tensor output, + # so return (seed, offset) tuple upset other parts of the program like + # ctx.saved_tensors. + + # A bad consequence is that if user saves and restores rng state, we + # have little bit of ugliness in the generated code, where we first + # concat the (seed, offset) to create a tensor for get_rng_state, and + # then split it back to get (seed, offset) tuple in set_rng_state. + + # TODO: Investigate if there is be a better way to wrap the tuple in a + # false Tensor object, and then desugar it later on. 
+ return cls.running_state.get_state_as_tensor() + + @classmethod + def get_state_as_tuple(cls): + return cls.running_state.get_state_as_tuple() + + @classmethod + def set_state_from_tensor(cls, x): + # This is only needed because we override set_rng_state. Look at the + # comment in get_state_from_tensor method. + cls.running_state.set_state_from_tensor(x) + + @classmethod + def advance_offset(cls, consumed_offset): + cls.running_state.advance_offset(consumed_offset) + + @classmethod + def get_current_relative_offset(cls): + return cls.running_state.relative_offset + + @staticmethod + def multiple_of_4(offset): + # torch cuda rng state offset must be a multiple of 4. For inductor, as + # we sum up all the numel, the result might not be a multiple of 4. This + # method achieves that. + return (offset + 3) // 4 * 4 + + @classmethod + def get_updated_fwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.fwd_state.offset_advanced_alteast_once: + return cls.fwd_state.base_offset + return cls.multiple_of_4( + cls.fwd_state.base_offset + cls.fwd_state.relative_offset + ) + + @classmethod + def get_updated_bwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.bwd_state.offset_advanced_alteast_once: + return cls.bwd_state.base_offset + return cls.multiple_of_4( + cls.bwd_state.base_offset + cls.bwd_state.relative_offset + ) + + +# Adding more decompositions which eventually use rand_like inside decomps. +# Adding these in rng_decompositions ensures the functionalization of rand_like +# ops used in these decomps. The list is copied from inductor codebase, which +# uses it for similar purpose. +# +# Caution - These decomps do not have same accuracy as that of eager. However, +# we can't just disable them with a config flag like fallback_random, because +# for functionalization of rng ops, we have to decompose these ops. 
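As an aside, the offset bookkeeping above reduces to adding the consumed relative offset to the base offset and rounding the result up to a multiple of 4. The toy sketch below (not part of the diff) uses plain integers in place of the seed/offset tensors to show the arithmetic behind `multiple_of_4` and `get_updated_fwd_offset`.

```python
# Toy walk-through of the Philox offset rounding described above.
def multiple_of_4(offset):
    # same rounding rule as PhiloxStateTracker.multiple_of_4
    return (offset + 3) // 4 * 4

base_offset = 16          # offset captured before tracing started
consumed = 10             # relative offset advanced by philox_rand calls during tracing
updated = multiple_of_4(base_offset + consumed)
print(updated)            # 28 -- rounded up so the CUDA RNG offset stays a multiple of 4
```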
+extra_random_decomps = get_decompositions( + [ + aten.cauchy, + aten.cauchy_, + aten.exponential, + aten.exponential_, + aten.geometric, + aten.geometric_, + aten.native_dropout, + aten.normal, + aten.normal_, + aten.normal_functional, + aten.log_normal, + aten.log_normal_, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.uniform_, + ] +) +register_extra_random_decomp = functools.partial( + decomp.register_decomposition, registry=extra_random_decomps +) + + +@register_extra_random_decomp([aten.bernoulli_]) +def bernoulli_(self, p=0.5): + if self.device == torch.device("cpu"): + return NotImplemented + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) + + +@register_extra_random_decomp([aten.bernoulli.p]) +def bernoulli_p(self, p=0.5, *, generator=None): + if self.device == torch.device("cpu"): + return NotImplemented + assert generator is None + return torch.rand_like(self, dtype=torch.float32) < p + + +rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/_dispatch/__init__.py b/phivenv/Lib/site-packages/torch/_dispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e327341a55e14788044439370ee0f09f84988833 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a65113d51a82921a93711b0c7b53c0e00106c3b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dispatch/__pycache__/python.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dispatch/python.py b/phivenv/Lib/site-packages/torch/_dispatch/python.py new file mode 100644 index 0000000000000000000000000000000000000000..3a043bdf9657638cbf68ecb6c397b4d8669ebd34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dispatch/python.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import itertools +import unittest.mock +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Callable, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._C +import torch._ops +import torch.utils._python_dispatch +import torch.utils._pytree as pytree +from torch._C import DispatchKey + + +__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] + +no_python_dispatcher = torch._C._DisablePythonDispatcher +enable_python_dispatcher = torch._C._EnablePythonDispatcher +enable_pre_dispatch = torch._C._EnablePreDispatch + +CROSSREF_FUNCTIONALIZE = False + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: + """ + Warning: the set of overloads this will report is very subtle. It is precisely + the set of torch.ops functions that have actually been accessed from Python + (e.g., we actually called torch.ops.aten.blah at some point. 
This is DIFFERENT + from the set of registered operators, which will in general be a larger set, + as this would include all operators which we ran C++ static initializers or + Python operator registration on. This does not eagerly populate the list on + torch.ops.aten; this list is lazy! + + In other words, this is good for traversing over everything that has an + OpOverload object allocated in Python. We use it for cache invalidation, but + don't rely on this list being complete. + + Note that even if we did report all C++ registered overloads, this isn't guaranteed + to be complete either, as a subsequent lazy load of a library which triggers more + registrations could add more things to the set. + """ + for ns in torch.ops: + packets = getattr(torch.ops, ns) + for op_name in packets: + packet = getattr(packets, op_name) + for overload in packet: + yield getattr(packet, overload) + + +@contextmanager +def suspend_functionalization(): + f_tls = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + f_rv = torch._C._functionalization_reapply_views_tls() + if f_tls: + torch._disable_functionalization() + try: + yield + finally: + if f_tls: + torch._enable_functionalization(reapply_views=f_rv) + + +def check_tensor_metadata_matches(nv, rv, desc): + assert callable(desc) + assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" + same_strides, idx = torch._prims_common.check_significant_strides( + nv, rv, only_cuda=False + ) + assert same_strides, ( + f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + ) + + +def check_metadata_matches(n, r, desc): + assert callable(desc) + n_vals, _n_spec = pytree.tree_flatten(n) + r_vals, _r_spec = pytree.tree_flatten(r) + # TODO: test the specs match; empirically sometimes we have a tuple + # on one side and a list on the other + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + +def _fmt(a: object) -> object: + if isinstance(a, torch.Tensor): + return Lit( + f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" + ) + else: + return a + + +def make_crossref_functionalize( + op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey +) -> Union[Callable[_P, _T], DispatchKey]: + from torch._subclasses.fake_tensor import FakeTensorMode + + # This case is pretty weird, suppress it for now + if op == torch.ops.aten.lift_fresh.default: + return final_key + + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: + fake_mode = FakeTensorMode() + + def fakeify_defun(t): + if isinstance(t, torch.Tensor): + if torch._is_functional_tensor(t): + r = torch._from_functional_tensor(t) + # NB: This assumes that the inner tensor sizes/strides match + # the outer tensor sizes/strides. 
This doesn't necessarily have to + # be the case, see discussion at + # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 + assert t.size() == r.size() + assert t.stride() == r.stride() + else: + r = t + # TODO: suppress guards + return fake_mode.from_tensor(r) + return t + + def maybe_detach(t): + if isinstance(t, torch.Tensor): + return t.detach() + else: + return t + + # TODO: This probably does the wrong thing if you're running other + # substantive modes with the normal op outside here + with ( + torch.utils._python_dispatch._disable_current_modes(), + suspend_functionalization(), + ): + f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map( + maybe_detach, (f_args, f_kwargs) + ) + with fake_mode: + f_r = op(*f_args, **f_kwargs) + r = op._op_dk(final_key, *args, **kwargs) + + def desc(): + fmt_args = ", ".join( + itertools.chain( + (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), + ( + f"{k}={pytree.tree_map(_fmt, v)}" + for k, v in orig_f_kwargs.items() + ), + ) + ) + return f"{op}({fmt_args})" + + check_metadata_matches(f_r, r, desc) + return r + + return handler + + +# NB: enabling this is slow, don't do it in a hot loop. This is purely +# for debugging purposes. +@contextmanager +def enable_crossref_functionalize(): + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) + try: + with ( + enable_python_dispatcher(), + unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True), + ): + yield + finally: + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__init__.py b/phivenv/Lib/site-packages/torch/_dynamo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fea41183e7b21799adf0104e0ad1c5021ca232 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/__init__.py @@ -0,0 +1,161 @@ +""" +TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. +TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python +bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of +PyTorch operations into an FX Graph which is then just-in-time compiled with a customizable backend. +It creates this FX Graph through bytecode analysis and is designed to mix Python execution with +compiled backends to get the best of both worlds: usability and performance. This allows it to +seamlessly optimize PyTorch programs, including those using modern Python features. +""" + +import torch + +from . 
import config, convert_frame, eval_frame, resume_execution +from .backends.registry import list_backends, lookup_backend, register_backend +from .callback import callback_handler, on_compile_end, on_compile_start +from .code_context import code_context +from .convert_frame import replay +from .decorators import ( + allow_in_graph, + assume_constant_result, + disable, + disallow_in_graph, + dont_skip_tracing, + forbid_in_graph, + graph_break, + mark_dynamic, + mark_static, + mark_static_address, + maybe_mark_dynamic, + nonstrict_trace, + patch_dynamo_config, + run, + set_stance, + skip_frame, + substitute_in_graph, +) +from .eval_frame import ( + _reset_guarded_backend_cache, + explain, + export, + is_dynamo_supported, + is_inductor_supported, + optimize, + optimize_assert, + OptimizedModule, + reset_code, +) +from .external_utils import is_compiling +from .mutation_guard import GenerationTracker +from .pgo import reset_code_state +from .symbolic_convert import TensorifyState +from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count + + +# Register polyfill functions +from .polyfills import loader as _ # usort: skip # noqa: F401 + + +__all__ = [ + "allow_in_graph", + "assume_constant_result", + "disallow_in_graph", + "dont_skip_tracing", + "forbid_in_graph", + "substitute_in_graph", + "graph_break", + "mark_dynamic", + "maybe_mark_dynamic", + "mark_static", + "mark_static_address", + "nonstrict_trace", + "optimize", + "optimize_assert", + "patch_dynamo_config", + "skip_frame", + "export", + "explain", + "run", + "replay", + "disable", + "set_stance", + "reset", + "OptimizedModule", + "is_compiling", + "register_backend", + "list_backends", + "lookup_backend", + "config", +] + +# allowlist this for weights_only load of NJTs +torch.serialization.add_safe_globals([torch._dynamo.decorators._DimRange]) + +if torch.manual_seed is torch.random.manual_seed: + import torch.jit._builtins + + # Wrap manual_seed with the disable decorator. + # Can't do it at its implementation due to dependency issues. + torch.manual_seed = torch._disable_dynamo(torch.manual_seed) + # Add the new manual_seed to the builtin registry. + torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed") + + +def reset() -> None: + """ + Clear all compile caches and restore initial state. This function is intended + to reset Dynamo's state *as if* you had started a fresh process invocation, which + makes it good for testing scenarios where you want to behave as if you started + a new process. It does NOT affect any file system caches. + + NB: this does NOT reset logging state. Don't use this to test logging + initialization/reinitialization. 
+ """ + # TODO: https://github.com/pytorch/pytorch/issues/139200 + import logging + + log = logging.getLogger(__name__) + log.info("torch._dynamo.reset") + with convert_frame.compile_lock: + reset_code_caches() + convert_frame.input_codes.clear() + reset_code_state() + convert_frame.output_codes.clear() + orig_code_map.clear() + guard_failures.clear() + graph_break_reasons.clear() + resume_execution.ContinueExecutionCache.cache.clear() + _reset_guarded_backend_cache() + reset_frame_count() + torch._dynamo.compiled_autograd.reset() + convert_frame.FRAME_COUNTER = 0 + convert_frame.FRAME_COMPILE_COUNTER.clear() + callback_handler.clear() + GenerationTracker.clear() + TensorifyState.clear() + torch._dynamo.utils.warn_once_cache.clear() + torch._dynamo.utils.user_obj_id_to_weakref.clear() + torch._C._autograd._saved_tensors_hooks_set_tracing(False) + + +def reset_code_caches() -> None: + """ + Clears in-memory code cache, which is what stores compiled products. This + resets less state than :func:`reset` and is mostly only used for testing + purposes. + """ + # TODO: https://github.com/pytorch/pytorch/issues/139200 + import logging + + log = logging.getLogger(__name__) + log.info("torch._dynamo.reset_code_caches") + """Clear compile caches that are keyed by code objects""" + with convert_frame.compile_lock: + reset_code_state() + for weak_code in ( + convert_frame.input_codes.seen + convert_frame.output_codes.seen + ): + code = weak_code() + if code: + reset_code(code) + code_context.clear() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc64d429d8d32fb95f38492cfff4b1d27e32aaf7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..385d077e0f1f592380efc038972980c22db9f2c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b72b94af29ce651e0c033f6c64007b4f0370a3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81fc806b71bbd5a2136c82dab0acaac45c548aa6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3675b17d5c861a6251e1122659fd569d3e2003d Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/cache_size.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..176bffc411f89512c36392a240a23bae30272363 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/callback.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbb98fffd7011822e24b405e2319f08d4165e79b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/code_context.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af533c02341632e7078112fbfe7f9200815e1fd2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/codegen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04e78b0202bc3f75b2005221cac08cd369638dda Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdcaff80324652704df0adedef6f26390ac97004 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/comptime.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24bd4e57ad1ea6d66635f5bd3ab414fb48d79b05 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc3eb09bcd422d0dc92d719eb198fa0f1897feb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c321f266dfc8aa0fa92a184f7321508d733b1776 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/create_parameter_op.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f4ae64bc247d4222de22f01d64108bf354500048 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..134adc4ce844f7e620a34a612e1d068a9cffd2ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1664462fd0adaf07b1236f4f8287d2add3511c53 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/decorators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd44260359e36c9104336e410722072b47eaa5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/device_interface.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04e9fd53f5981a886cebedd672035d5a7a25a7ba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d8cdf132be1eaabd17df8a68a2ca55150e3bd4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9df4e7936dda02705c46a93531c3e5eddff78a8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/exc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f03e028fa1c12db242a6c4b34eca7fd3e4424933 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/external_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0742e8ff47c9aa79dea694f3f2d1257c0a660dbf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66dd96639b4ebc7aadc51894192ad2dcd532bd22 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_break_hints.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb6d728b8b20892b7ba900bd61f34468e883e4d0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_deduplication.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e79f247f06f7f22bd5b50644e38e0f9f1c2695a5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_region_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df68c497faa6aed01490232f26d7e5a43c399c0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/graph_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4849c9a68fdde5462c8491254d6515fe031b8fd0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/guards.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d80b56e52cebfd577b35f4006af90c9c5923acf8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b84ae0c838b9cba26e4ef952ed8edf629f5a49c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/logging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1934a62e60caca00c820832f90e269a7f3a9a40a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/metrics_context.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40560bb5598b04774816b621fcf910dd69dbfd6f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f947f3a62423fa2f90c2fecd7cf8b648912516b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/output_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/package.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/package.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c837e6670d5b18ddf867d57906edfa492c843e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/package.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/pgo.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/pgo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f52b1e7d8cbd1e33d87ea06fcd1b186491ae7f02 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/pgo.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d24553c376b94b70c07cdd0d617e2dda8896a7e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/precompile_context.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ece32c5cd8e02f49b3e969ef6e17558c55541a3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b30f5bb607063705c7440b4228c8652ddad26526 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/replay_record.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..679fab61f58ec47f7365e8b9204af1ff2780d436 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c271954792298bc62816a4f2a3fc1040a0c77952 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/side_effects.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55fd4b66412733b4329b9018eea17c7e3fe534d5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/source.cpython-39.pyc 
differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baac98d48ec1de8a0c670f83b25c5a3450a55df7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8c488bd51f4bfd4ea8f0dc9b5d3f73ff565df3a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_case.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fb04a0eb270428d073e2ad524e540b3c73b8be0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f5d8d46f780aff871ee2987e407c8c0c6e880c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a96a5574f9a6dfd589382b3c9261a85634bf8f67 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/testing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469e54aceba2f80dbf15d00bbf6d29bcb96f5170 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/__pycache__/types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py b/phivenv/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a50db84cd3a11a2cd5c72e6cfd0ca07283d81e3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -0,0 +1,244 @@ +"""trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist: +if you make_fx trace through this call, we will not actually trace into fn; instead, +we will directly insert it as a call_function to fn in the graph. +(Unlike make_fx, Dynamo WILL inline into fn.) +You can think of this as a one off allow_in_graph equivalent for proxy tensor tracing. + +Because proxy tensor tracing does not actually run the function, there are +requirements on the behavior of fn. 
We are still figuring it out, but here is the current state: + +1) fn SHOULD only take a single argument, which must be a tensor +2) fn MUST return a new tensor with the same metadata as the original tensor + (e.g., zeros_like(input) is a permissible implementation of fn). + This is verified via an extra assert that is inserted into the traced graph. +3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors + participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state) +These requirements stem from the requirement that we need to continue performing proxy tensor tracing, +which assumes accurate fake tensor metadata, without actually running fn. +In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns. + +Note that tensors / Python state are allowed to be mutated. +This is relaxed constraint is not always sound, but it is sound for backward tracing with fake +tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete +tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python). + +The intended use case for this function is to allow AOTAutograd to defer complex +backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves +the function call as is in the graph, and only when we Dynamo through the backward graph in +compiled autograd do we inline into the function. +""" + +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses import FakeTensorMode +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.overrides import TorchFunctionMode +from torch.utils._python_dispatch import _get_current_dispatch_mode +from torch.utils._pytree import tree_map_only + + +Tensor = torch.Tensor + + +__all__ = ["trace_wrapped"] + + +if not torch._running_with_deploy(): + # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore + + @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] + def zeros_and_scatter( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) + + @zeros_and_scatter.register_fake # type: ignore[misc] + def _( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + return vals.new_empty(shape) + + @zeros_and_scatter.register_vmap # type: ignore[misc] + def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == 
value.shape + expanded_indices.append(idx) + + out = torch.ops.flex_lib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None + + +class ModIndex(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x: Tensor, indices: list[Tensor]) -> Tensor: + return torch.ops.aten.index(x, indices) + + @staticmethod + def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None: + x, indices = inputs + ctx.save_for_backward(*indices) + ctx.input_shape = x.shape + + @staticmethod + def backward(ctx, gradOut): # type: ignore[no-untyped-def] + indices = ctx.saved_tensors + return ( + torch.ops.flex_lib.zeros_and_scatter( + ctx.input_shape, + indices, + gradOut, + ), + None, + ) + + +mod_index = ModIndex.apply + + +class TransformGetItemToIndex(TorchFunctionMode): + # This is needed since we want to support calling + # A[q_idx], where q_idx is a scalar tensor in score_mod. + # Today, when q_idx is a scalar tensor, we implicitly convert it to a python + # scalar and create a view. We do not want that behavior in this case, so we + # use this torchfunctionmode to override that behavior for score_mod + # wherever we're running it. + def __torch_function__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + if func == torch.Tensor.__getitem__: + index_args = pytree.tree_leaves(args[1]) + if all(isinstance(x, torch.Tensor) for x in index_args): + return mod_index(args[0], index_args) + return func(*args, **(kwargs or {})) + + +def trace_wrapped(*args: Any, **kwargs: Any) -> Any: + with torch.no_grad(): + return _trace_wrapped_op(*args, **kwargs) + + +class TraceWrapped(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("trace_wrapped") + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return super().__call__(*args, **kwargs) + + +# TODO(jansel): need to ensure this does not get DCEed +_trace_wrapped_op = TraceWrapped() + + +def _assert_meta( + grad: torch.Tensor, + size: tuple[int, ...], + stride: tuple[int, ...], + dtype: torch.dtype, +) -> torch.Tensor: + assert grad.size() == size, "size mismatch" + assert grad.stride() == stride, "stride mismatch" + assert grad.dtype == dtype, "dtype mismatch" + return grad + + +@_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) +def inner_trace( + mode: ProxyTorchDispatchMode, + *args: Any, + bw_state: Optional[BackwardState] = None, + **kwargs: Any, +) -> Any: + def self_invoke(*args: Any, **dyn_kwargs: Any) -> Any: + with torch.no_grad(): + return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) + + def unwrap_proxies(x: Any) -> Any: + if isinstance(x, torch.Tensor): + return mode.tracer.unwrap_proxy(x) # type: ignore[union-attr] + if isinstance(x, (list, tuple)): + return type(x)(map(unwrap_proxies, x)) + if x is None: + return None + raise AssertionError(f"unhandled type: {type(x)}") + + proxy_kwargs = {} + if bw_state is not None: + assert isinstance(bw_state, BackwardState) and bw_state.proxy is not None + proxy_kwargs["bw_state"] = bw_state.proxy + out_proxy = mode.tracer.create_proxy( + "call_function", + self_invoke, + unwrap_proxies(args), + proxy_kwargs, + name="trace_wrapped", + ) + + if args[0] is None: + grad = args[1] # module backward hooks + else: + grad = args[0] # other backward hooks + grad = tree_map_only(torch.Tensor, torch.empty_like, grad) + track_tensor_tree(grad, out_proxy, constant=None, tracer=mode.tracer) + return grad + + 
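The module docstring above spells out the contract for `fn` (a single tensor in, a fresh tensor with matching metadata out). The following sketch is an illustrative aside, not part of the diffed file: it shows how a `make_fx` trace is expected to record the wrapped hook as one opaque `call_function` node instead of tracing into it. The hook name `my_hook` is a made-up placeholder.

```python
# Illustrative sketch only; assumes trace_wrapped behaves as described in the
# module docstring above. `my_hook` is a hypothetical hook function.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped


def my_hook(grad):
    # Must return a new tensor with the same metadata as its input.
    return torch.zeros_like(grad)


def f(x):
    # Equivalent to my_hook(x), but make_fx does not trace into my_hook;
    # the call is kept as a single "trace_wrapped" call_function node.
    return trace_wrapped(x, fn=my_hook)


gm = make_fx(f)(torch.randn(3))
print(gm.graph)  # the hook appears as one opaque node, to be inlined later by Dynamo
```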
+@_trace_wrapped_op.py_impl(FakeTensorMode) +def inner_fake(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("This op should never be invoked here") + + +@_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def _trace_wrapped_op_dense(*args: Any, fn: Any, **kwargs: Any) -> Any: + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return fn(*args, **kwargs) + + +_trace_wrapped_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_trace_wrapped_op, deferred_error=True) +) + + +@_trace_wrapped_op.py_functionalize_impl +def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any: + unwrapped_args = ctx.unwrap_tensors(args) + with ctx.redispatch_to_next(): + return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs)) + + +def autograd_function_backward_rewritten(original_backward: Any) -> Any: + def new_backward(ctx: Any, *grads: Any) -> Any: + grads = [g.contiguous() for g in grads] + return original_backward(ctx, *grads) + + return new_backward diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__init__.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47656276d83a209e3d2c3b5b73863c19fc839c9e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3d5224f5b39d8890bb5677acaa584c90d123685 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fddf668af1d20567c0dc4e3151b26e13ee5acdfe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3c5b60f46233bbc3d92c5fc6748b4496e06c20 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43b0f8f87f02d3fa4f8e0f20a25633fefe1efb25 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5340eabd040541a934fe2c168644764284ea5662 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd3c7e013749ce4d58bcacee9d5d860f3795e45 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e177c908cb58bad9e5532626dc32fa8e035ab906 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebf4fded64acc12b24dc531dded1e1009b3cba9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06d80e3704d742916f53c72a992ab2f279278e68 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f192ca182b68563d4aa52883d7999041a7bb521 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/common.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/common.py new file mode 100644 index 0000000000000000000000000000000000000000..a52d3dd3f712684a9eed251719aba67ccc2e64d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/common.py @@ -0,0 +1,167 @@ +# mypy: ignore-errors + +""" +This module provides common utilities and base classes for TorchDynamo backends. + +Key components: +- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends +- Backend utilities for handling: + - Fake tensor conversion + - Device/dtype detection from inputs + - Memory efficient fusion + - Graph flattening + - Common compiler configurations + +The utilities here are used by various backend implementations to handle +common operations and provide consistent behavior across different backends. +AOT autograd functionality is particularly important as it enables ahead-of-time +optimization of both forward and backward passes. 
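As a rough illustration of how the `aot_autograd` helper described above is meant to be consumed (an aside, not part of the diffed file), a backend can be assembled from forward/backward compiler callables and handed to `torch.compile`. The `passthrough` compiler below is a hypothetical placeholder with the same shape as the `boxed_nop` compiler used by the `aot_eager` backend later in this patch.

```python
# Sketch only: wiring the aot_autograd helper into torch.compile.
# `passthrough` is a hypothetical compiler that runs the FX graph unchanged.
import torch
from torch._dynamo.backends.common import aot_autograd


def passthrough(gm: torch.fx.GraphModule, example_inputs):
    def run(args):
        return torch.fx.Interpreter(gm).boxed_run(args)

    run._boxed_call = True  # tell AOTAutograd this callable takes a boxed arg list
    return run


my_backend = aot_autograd(fw_compiler=passthrough, bw_compiler=passthrough)


@torch.compile(backend=my_backend)
def f(x):
    return torch.sin(x) * 2


f(torch.randn(4, requires_grad=True)).sum().backward()
```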
+""" + +import contextlib +import functools +import logging +from unittest.mock import patch + +import torch +from torch._dynamo import disable +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._dynamo.utils import counters, defake, flatten_graph_inputs +from torch._functorch.aot_autograd import ( + aot_module_simplified, + SerializableAOTDispatchCompiler, +) +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + + +class AotAutograd: + def __init__(self, **kwargs) -> None: + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs): + if kwargs: + log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): + return flatten_graph_inputs( + gm, + example_inputs, + self, + ) + + # Hack to get around circular import problems with aot_eager_decomp_partition + if callable(self.kwargs.get("decompositions")): + self.kwargs["decompositions"] = self.kwargs["decompositions"]() + + # NB: dont delete counter increment + counters["aot_autograd"]["total"] += 1 + use_fallback = False + + if use_fallback: + log.debug("Unable to use AOT Autograd because graph has mutation") + counters["aot_autograd"]["not_ok"] += 1 + return gm + + def wrap_bw_compiler(bw_compiler_fn): + def _wrapped_bw_compiler(*args, **kwargs): + # Note [Wrapping bw_compiler in disable] + # The two disables here: + # - stop TorchDynamo from trying to compile the bw_compiler function itself + # - stop TorchDynamo from trying to compile our the generated backwards pass bw_compiler produces + return disable( + disable( + bw_compiler_fn, reason="do not trace backward compiler function" + )(*args, **kwargs), + reason="do not trace generated backwards pass", + ) + + return _wrapped_bw_compiler + + bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"] + + if isinstance(bw_compiler, SerializableAOTDispatchCompiler): + bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn) + else: + bw_compiler = wrap_bw_compiler(bw_compiler) + + self.kwargs["bw_compiler"] = bw_compiler + self.kwargs["inference_compiler"] = ( + self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"] + ) + + from functorch.compile import nop + from torch._inductor.debug import enable_aot_logging + + # debug asserts slow down compile time noticeably, + # So only default them on when the aot_eager backend is used. + if self.kwargs.get("fw_compiler", None) == nop: + patch_config = patch("functorch.compile.config.debug_assert", True) + else: + patch_config = contextlib.nullcontext() + + try: + # NB: NOT cloned! 
+ with enable_aot_logging(), patch_config: + cg = aot_module_simplified(gm, example_inputs, **self.kwargs) + counters["aot_autograd"]["ok"] += 1 + return disable(cg, reason="do not trace AOT-compiled graph") + except TensorifyScalarRestartAnalysis: + raise + except Exception: + counters["aot_autograd"]["not_ok"] += 1 + raise + + +def aot_autograd(**kwargs) -> AotAutograd: + return AotAutograd(**kwargs) + + +def mem_efficient_fusion_kwargs(use_decomps): + from functorch.compile import ( + default_decompositions, + min_cut_rematerialization_partition, + ts_compile, + ) + + kwargs = { + # these are taken from memory_efficient_fusion() + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + } + + if use_decomps: + kwargs["decompositions"] = default_decompositions + + return kwargs + + +def fake_tensor_unsupported(fn): + """ + Decorator for backends that need real inputs. We swap out fake + tensors for zero tensors. + """ + + @functools.wraps(fn) + def wrapper(model, inputs, **kwargs): + with _disable_current_modes(): + inputs = list(map(defake, inputs)) + return fn(model, inputs, **kwargs) + + return wrapper + + +def device_from_inputs(example_inputs) -> torch.device: + for x in example_inputs: + if hasattr(x, "device"): + return x.device + + +def dtype_from_inputs(example_inputs) -> torch.dtype: + for x in example_inputs: + if hasattr(x, "dtype"): + return x.dtype diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..192cdd6f02c2e68bf111c7d53ceaa58a2f6e34b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/cudagraphs.py @@ -0,0 +1,279 @@ +# mypy: ignore-errors + +""" +This module implements CUDA graphs support for TorchDynamo backends. + +CUDA graphs allow for capturing and replaying GPU operations, which can significantly +reduce CPU overhead in GPU-accelerated PyTorch models. This module provides: + +- CUDA graph creation and management for both forward and backward passes +- Input mutation detection and handling +- Device compatibility checking +- Stack trace management for debugging +- Integration with TorchInductor's cudagraph trees + +The backend supports two main modes: +1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization +2. 
cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking + +Key components: +- CudagraphsBackend: Main backend class for CUDA graph integration +- Mutation detection utilities to ensure graph safety +- Device mapping and compatibility checks +- Stack trace collection for debugging +""" + +import functools +from collections import defaultdict +from typing import Optional + +import torch +from torch._dynamo import config +from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.backends.debugging import boxed_nop +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + check_multiple_devices_or_any_cpu_nodes, + format_default_skip_message, + get_mutation_stack_trace, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + get_first_incompatible_cudagraph_node, + num_fw_fixed_arguments, + output_node, +) +from torch.multiprocessing.reductions import StorageWeakRef + +from .registry import register_backend + + +def find_input_mutations(g): + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + inputs = defaultdict(set) + input_idx = 0 + mutated_inputs = set() + for n in g.nodes: + if n.op == "placeholder": + if isinstance(meta_fk(n.meta), torch.Tensor): + inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) + input_idx += 1 + elif n.op == "call_function": + if not hasattr(n.target, "_schema"): + continue + + schema = n.target._schema + for i, arg in enumerate(schema.arguments): + if i < len(n.args): + argument = n.args[i] + else: + if arg.name not in n.kwargs: + continue + argument = n.kwargs[arg.name] + mut_arg = False + if arg.alias_info: + if arg.alias_info.is_write: + mut_arg = True + if mut_arg: + # TODO: not correct for args that contain tensors in a struct + # like list + mutated_inputs |= inputs[ + StorageWeakRef(meta_fk(argument.meta)._typed_storage()) + ] + + # TODO: error on unrecognized nodes + return mutated_inputs + + +def get_device_node_mapping(gm: torch.fx.GraphModule): + device_node_mapping: dict[torch.device, torch.fx.Node] = {} + for n in gm.graph.nodes: + t = n.meta.get("val", None) + if isinstance(t, torch.Tensor) and t.device not in device_node_mapping: + device_node_mapping[t.device] = n + return device_node_mapping + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + aot_model: torch.fx.GraphModule, num_fixed +) -> Optional[str]: + mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed)) + if not mutation_indices: + return None + + placeholders = get_placeholder_info(aot_model.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + +def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]: + if not config.cudagraph_backend_support_input_mutation: + if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor( + aot_model, num_fixed + ): + return mut_skip + + if skip := check_multiple_devices_or_any_cpu_nodes( + get_device_node_mapping(aot_model) + ): + return skip + + if node := get_first_incompatible_cudagraph_node(aot_model): + return format_default_skip_message(f"incompatible op ({node.name})") + + return None + + +def get_device_index(gm) -> int: + device = next(iter(get_device_node_mapping(gm))) + assert device.type == "cuda" + return device.index + + +def get_stack_traces(gm) -> list[Optional[str]]: + output = output_node(gm) + assert len(output.args) == 1 + return [ + (arg.stack_trace if 
isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + + +def cudagraphs(dynamo_model, dynamo_inputs): + from torch._inductor.cudagraph_trees import cudagraphify_impl + + do_cudagraphs = BoxedBool(True) + boxed_device_index = BoxedDeviceIndex(None) + + def forward_cudagraphs(aot_model, aot_inputs, is_inference=False): + interp = boxed_nop(aot_model, aot_inputs) + fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) + if skip_msg := check_for_skip(aot_model, fixed): + BoxedBool.disable(do_cudagraphs) + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {skip_msg}" + ) + return interp + + boxed_device_index.set(get_device_index(aot_model)) + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=boxed_device_index.value, + is_backward=False, + is_inference=False, + stack_traces=get_stack_traces(aot_model), + placeholders=get_placeholder_info(aot_model.graph), + mutated_input_idxs=find_input_mutations(aot_model.graph), + ) + out._boxed_call = True + return out + + def backward_cudagraphs(aot_model, aot_inputs): + interp = boxed_nop(aot_model, aot_inputs) + if not do_cudagraphs: + return aot_model + + fixed = count_tangents(aot_model) + if skip_msg := check_for_skip(aot_model, fixed): + log_cudagraph_skip_and_bump_counter( + "skipping cudagraphs due to %s", skip_msg + ) + + # See [Backward Generation Handling] + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_device_index.value, create_if_none_exists=False + ) + assert manager is not None + + def fn(inputs): + manager.set_to_running_backward() + return aot_model(inputs) + + fn._boxed_call = True + return fn + + out = cudagraphify_impl( + interp, + aot_inputs, + range(fixed), + device_index=get_device_index(aot_model), + is_backward=True, + is_inference=False, + stack_traces=get_stack_traces(aot_model), + placeholders=get_placeholder_info(aot_model.graph), + mutated_input_idxs=find_input_mutations(aot_model.graph), + ) + out._boxed_call = True + return out + + aot_cudagraphs = aot_autograd( + fw_compiler=forward_cudagraphs, + bw_compiler=backward_cudagraphs, + inference_compiler=functools.partial(forward_cudagraphs, is_inference=True), + keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation, + ) + return aot_cudagraphs(dynamo_model, dynamo_inputs) + + +class CudagraphsBackend: + compiler_name = "cudagraphs" + + @staticmethod + def reset(): + from torch._inductor.cudagraph_trees import reset_cudagraph_trees + + reset_cudagraph_trees() + + @staticmethod + def __call__(model, inputs): + return cudagraphs(model, inputs) + + +# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful +# for debugging and can serve as a perf baseline. 
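For orientation (an aside, not part of the diffed file): the backend registered just below is selected by name through `torch.compile`. It only applies to CUDA models; per `check_for_skip` above, graphs containing CPU nodes or input mutation are skipped and run without CUDA graphs.

```python
# Sketch: selecting the cudagraphs backend by name. Assumes a CUDA device is
# available; incompatible graphs fall back per the skip checks above.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8)
).cuda()

compiled = torch.compile(model, backend="cudagraphs")
out = compiled(torch.randn(32, 64, device="cuda"))
out.sum().backward()
```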
+register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend()) + + +def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True): + """This isn't registered as a backend, but is used in some benchmarks""" + assert isinstance(inputs, (list, tuple)) + if copy_inputs: + static_inputs = [torch.zeros_like(x) for x in inputs] + else: + static_inputs = list(inputs) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + model(*inputs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + static_outputs = model(*static_inputs) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + def run(*new_inputs): + assert len(static_inputs) == len(new_inputs) + if copy_inputs: + for dst, src in zip(static_inputs, new_inputs): + dst.copy_(src) + graph.replay() + if copy_outputs: + return [x.clone() for x in static_outputs] + else: + return static_outputs + + return run diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/debugging.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/debugging.py new file mode 100644 index 0000000000000000000000000000000000000000..cb398219a18339ae88c4b4a8eff75915882df2f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/debugging.py @@ -0,0 +1,470 @@ +# mypy: ignore-errors + +""" +This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot +compilation and execution issues. It includes: + +Key Debugging Backends: +- eager: Simple pass-through backend that runs models in eager mode +- eager_noexcept: Similar to eager but with additional exception handling +- eager_debug: Adds schema validation checks for custom operators +- aot_eager: Uses AOT Autograd with nop compiler for debugging +- aot_eager_decomp_partition: Uses TorchInductor decompositions for debugging +- torchscript: Compiles using TorchScript for debugging JIT-related issues + +Testing and Development Tools: +- Backends for inducing specific errors (compile/runtime/accuracy) +- ExplainOutput class for detailed graph compilation analysis +- Utilities for cross-referencing and mode management +- Tools for graph detail inspection and break reason analysis + +These backends are primarily used for: +1. Debugging graph breaks and compilation failures +2. Testing error handling and recovery mechanisms +3. Analyzing performance bottlenecks +4. 
Validating operator schemas and decompositions +""" + +import dataclasses +import functools +import logging +from importlib import import_module +from typing import Any, Optional + +import torch +from functorch.compile import min_cut_rematerialization_partition +from torch import _guards +from torch._functorch import config as functorch_config +from torch._functorch.compilers import ts_compile + +from .common import aot_autograd +from .registry import register_debug_backend as register_backend + + +log = logging.getLogger(__name__) + + +@register_backend +def eager(gm, fake_tensor_inputs, **kwargs): + if kwargs: + log.warning("eager backend ignoring extra kwargs %s", kwargs) + return gm.forward + + +def make_eager_backend_with_torch_function_mode(mode): + return make_eager_backend_with_torch_function_modes([mode]) + + +def make_eager_backend_with_torch_function_modes(modes): + """Used to trace HOPs (cond and while) for eager execution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + from contextlib import ExitStack + + def fn(gm, fake_tensor_inputs, **kwargs): + stack = ExitStack() + for mode in modes: + stack.enter_context(mode) + + result = gm.forward + stack.close() + return result + + return fn + + +@register_backend +def eager_noexcept(gm, fake_tensor_inputs, **kwargs): + if kwargs: + log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs) + + # This backend is intended to check that dynamo-generated GraphModules + # do not cause errors. + def inner(*args): + try: + return gm(*args) + except Exception as e: + raise torch._dynamo.exc.TorchDynamoException( + "Unexpected exception when running generated GraphModule" + ) from e + + return inner + + +@register_backend +def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs): + if kwargs: + log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs) + + from torch.fx.experimental.proxy_tensor import make_fx + + def runnable_gm(*args): + return torch.fx.Interpreter(gm).run(*args) + + pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) + pre_dispatch_gm.print_readable() + + return pre_dispatch_gm + + +@register_backend +def eager_debug(gm, fake_tensor_inputs, **kwargs): + if kwargs: + log.warning("eager_debug backend ignoring extra kwargs %s", kwargs) + + from torch._subclasses.schema_check_mode import SchemaCheckMode + + # We could add more debugging bits here. + # Right now, this backend can be used to check for and error on + # custom dispatcher ops that have incorrect schemas. 
+ def inner(*args): + with SchemaCheckMode(): + return torch.fx.Interpreter(gm).run(*args) + + return inner + + +@register_backend(name="ts") +def torchscript(gm, fake_tensor_inputs): + return torch.jit.script(gm) + + +# used boxed call to discard inputs when they are no longer needed +def boxed_nop(fx_g, example_inputs): + def run(args): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def boxed_nop_with_mode(fx_g, example_inputs, *, mode): + def run(args): + with mode: + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None): + def run(args): + with torch._subclasses.CrossRefFakeMode(ignore_op_fn): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def ignore_builtins(op: torch._ops.OpOverload) -> bool: + return op.namespace in ("aten", "prims", "prim") + + +def get_nop_func(): + if not torch._functorch.config.fake_tensor_crossref: + return boxed_nop + elif torch._functorch.config.fake_tensor_crossref == "all": + return fake_crossref_boxed_nop + else: + assert torch._functorch.config.fake_tensor_crossref == "custom_ops" + return functools.partial(fake_crossref_boxed_nop, ignore_op_fn=ignore_builtins) + + +# Useful for debugging purpose +# aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. +def aot_eager( + gm, + fake_tensor_inputs, + fw_compiler=None, + bw_compiler=None, + **kwargs, +): + return aot_autograd( + fw_compiler=fw_compiler or boxed_nop, + bw_compiler=bw_compiler or boxed_nop, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + )(gm, fake_tensor_inputs, **kwargs) + + +register_backend(name="aot_eager", compiler_fn=aot_eager) + +aot_eager_default_partitioner = aot_autograd( + fw_compiler=boxed_nop, keep_inference_input_mutations=True +) +register_backend( + name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner +) + + +# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs +# inductor problems. 
+# aot_eager_decomp_partition just replaces the inductor compiler with nop to help +# isolate inductor vs aot_eager errors +def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): + if kwargs: + log.warning( + "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs + ) + + from torch._inductor.compiler_bisector import CompilerBisector + + config_patches = {"unlift_effect_tokens": True} + if bisect_changes := CompilerBisector.get_config_change( + "aot_eager_decomp_partition" + ): + config_patches.update(bisect_changes) + + with functorch_config.patch(config_patches): + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=get_nop_func(), + bw_compiler=get_nop_func(), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition +) + + +# aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition, +# except that it takes a TorchDispatchMode mode and run the fw/bw in the mode +def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg): + return aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + bw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + "torch._inductor.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), + )(gm, fake_tensor_inputs) + + +register_backend( + name="aot_eager_decomp_partition_with_mode", + compiler_fn=aot_eager_decomp_partition_with_mode, +) + + +def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs): + # if the config is set, respect it, otherwise only test custom_ops. + # custom_op bad metas always manifest as an error whereas aten will only sometimes. + # by default, use the less noisy option + config_val = ( + "custom_ops" + if not functorch_config.fake_tensor_crossref + else functorch_config.fake_tensor_crossref + ) + with functorch_config.patch(fake_tensor_crossref=config_val): + return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs) + + +register_backend( + name="aot_eager_decomp_partition_crossref", + compiler_fn=aot_eager_decomp_partition_crossref, +) + + +# AOT Autograd with torchscript backend. Default partitioner. +# aot_ts uses torchscript backend. We can use this with both nnc and nvfuser +# by using the relevant fuser with torch.jit.fuser(...) 
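As a quick aside (not part of the diffed file), these debugging backends are typically swapped in by name to bisect a failure: if a model misbehaves under `inductor` but not under `aot_eager` or `aot_eager_decomp_partition`, the problem most likely lives in Inductor rather than in Dynamo or AOTAutograd.

```python
# Sketch: swapping debug backends by name to localize a torch.compile failure.
# "inductor" additionally assumes a working Inductor toolchain on the machine.
import torch
import torch._dynamo


def fn(x):
    return torch.relu(x).sum()


x = torch.randn(8)

for backend in ("eager", "aot_eager", "aot_eager_decomp_partition", "inductor"):
    torch._dynamo.reset()  # drop previously compiled artifacts between runs
    out = torch.compile(fn, backend=backend)(x)
    print(backend, out.item())
```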
+aot_ts = aot_autograd(fw_compiler=ts_compile) +register_backend(name="aot_ts", compiler_fn=aot_ts) + +# These buggy backends are used for inducing bugs so that we can test +# our repro extraction / minifier scripts + + +class ReluCompileError(Exception): + pass + + +class TestingOnlyCompileError(Exception): + pass + + +@register_backend +def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise ReluCompileError + return gm + + +@register_backend +def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "ReluRuntimeError") + gm.recompile() + return gm + + +@register_backend +def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm + + +@register_backend +def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): + # Require at least one non-trivial thing in the graph, + # see https://github.com/pytorch/pytorch/issues/102898 + for node in gm.graph.nodes: + if node.op == "call_function": + break + else: + return gm + for t in example_inputs: + if not t.is_leaf: + raise TestingOnlyCompileError + return gm + + +@dataclasses.dataclass +class ExplainOutput: + """ + This is the output of :func:`torch._dynamo.explain()` + There is no reason to create this class directly. + """ + + graphs: list[torch.fx.GraphModule] + graph_count: int + graph_break_count: int + break_reasons: list[ + Any + ] # Type is GraphCompileReason but doesn't matter for this purpose + op_count: int + ops_per_graph: Optional[list[torch.fx.Node]] = None + out_guards: Optional[list[_guards.Guard]] = None + compile_times: Optional[str] = None + + def __str__(self) -> str: + output = f"Graph Count: {self.graph_count}\n" + output += f"Graph Break Count: {self.graph_break_count}\n" + output += f"Op Count: {self.op_count}\n" + + output += "Break Reasons:\n" + for idx, break_reason in enumerate(self.break_reasons): + output += f" Break Reason {idx + 1}:\n" + output += f" Reason: {break_reason.reason}\n" + output += " User Stack:\n" + for frame_summary in break_reason.user_stack: + output += f" {frame_summary}\n" + + if self.ops_per_graph is not None: + output += "Ops per Graph:\n" + for idx, ops in enumerate(self.ops_per_graph): + output += f" Ops {idx + 1}:\n" + for op in ops: + output += f" {op}\n" + + if self.out_guards is not None: + output += "Out Guards:\n" + for i, guard in enumerate(self.out_guards): + output += f" Guard {i + 1}:\n" + output += f" {str(guard)}" + + if self.compile_times is not None: + output += f"Compile Times: {self.compile_times}\n" + return output + + +def _explain_graph_detail( + gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons +): + """ + This function is a utility which processes a torch.fx.GraphModule and + accumulates information about its ops, graph breaks, and other details. It + is intended to be used by the ExplainWithBackend class and + `torch._dynamo.explain()` to provide details from Dynamo's graph capture. + + Parameters: + gm (torch.fx.GraphModule): The GraphModule to be processed. + graphs (list): A list that accumulates all the GraphModules processed. + op_count (int): The total count of operations in all GraphModules processed so far. 
+ ops_per_graph (list): A list that accumulates the operations of each GraphModule. + break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule. + + Returns: + tuple: A tuple containing the processed GraphModule, the updated lists of graphs, + operations per graph, and break reasons, and the updated operation count. + """ + graphs.append(gm) + ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] + op_count += len(ops) + ops_per_graph.append(ops) + if gm.compile_subgraph_reason.graph_break: + break_reasons.append(gm.compile_subgraph_reason) + + return gm, graphs, op_count, ops_per_graph, break_reasons + + +class ExplainWithBackend: + """ + This class is intended to be used as a backend for `torch.compile`. It is + composable with other backends. When used in this way, it accumulates + information about graph breaks, ops, and other info and provides a string + representation summarizing this information. + + Attributes: + backend (str): The name of the backend to use for optimization. + graphs (list): A list of the graphs captured by TorchDynamo. + op_count (int): The total number of operations in all optimized graphs. + break_reasons (list): A list of graph break reasons with stack traces. + + Example Usage: + def fn(x): + x = torch.sigmoid(x) + return x + + torch._dynamo.reset() + eb = ExplainWithBackend("inductor") + optimized_fn = torch.compile(fn, backend=eb) + result = optimized_fn(torch.randn(5)) + print(eb.output()) + """ + + def __init__(self, backend) -> None: + from .registry import lookup_backend + + self.backend = lookup_backend(backend) + self.graphs = [] + self.op_count = 0 + self.break_reasons = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( + gm, self.graphs, self.op_count, [], self.break_reasons + ) + return self.backend(gm, example_inputs) + + def output(self) -> ExplainOutput: + graph_count = len(self.graphs) + output = ExplainOutput( + self.graphs, + graph_count, + graph_count - 1, + self.break_reasons, + self.op_count, + ) + + return output diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/distributed.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..abe22f9419f685d5e2ad4a9452554ad4e8be11f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/distributed.py @@ -0,0 +1,578 @@ +# mypy: ignore-errors + +""" +This module implements distributed training optimizations for TorchDynamo backends. + +It provides functionality to optimize models wrapped in DistributedDataParallel (DDP) +by intelligently splitting compiled graphs to align with DDP's gradient synchronization +boundaries. Key features include: + +- Graph partitioning based on parameter bucket sizes +- Optimization of allreduce operations for distributed training +- Support for parameter ignoring and buffer handling +- Submodule compilation and management +- Debugging utilities for distributed training + +The main component is the DDPOptimizer class, which handles graph splitting and +recompilation to enable efficient distributed training while maintaining the benefits +of compilation. 
+""" + +import logging +import traceback +from dataclasses import dataclass, field +from typing import Any, Optional +from unittest import mock + +import torch +from torch import fx +from torch._dynamo.output_graph import GraphCompileReason +from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.node import Node + + +# Regular log messages should go through 'log'. +# ddp_graph_log is a separate artifact logger reserved for dumping graphs. +# See docs/source/logging.rst for more info. +log = logging.getLogger(__name__) +ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs") + + +def args_str(args): + # a debug helper + if torch.is_tensor(args): + return f"T[{args.shape}]" + elif isinstance(args, tuple): + return f"tuple({', '.join([args_str(x) for x in args])})" + elif isinstance(args, list): + return f"list({', '.join([args_str(x) for x in args])})" + else: + return str(args) + + +@dataclass +class Bucket: + size: int = 0 + params: list[str] = field(default_factory=list) + nodes: list[fx.Node] = field(default_factory=list) + + # param_ids is just used for unit testing + param_ids: list = field(default_factory=list) + + # keep track of any buckets that were extended for logging purposes + opcount_increased_to_capture_external_output: int = 0 + paramsize_before_opcount_increase: int = 0 + + +def bucket_has_external_output(bucket: Bucket) -> bool: + nodes_in_bucket = set() + # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards + # so we don't reverse it here + for node in bucket.nodes: + # assume node.op != output, since those are filtered in the original iteration + nodes_in_bucket.add(node) + for user in node.users: + if user not in nodes_in_bucket: + return True + return False + + +def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int): + headers = ("Index", "Size (b)", "Param Names") + rows = [] + extended_buckets = [] + for idx, bucket in enumerate(reversed(buckets)): + if len(bucket.params) > 0: + rows.append((idx, bucket.size, bucket.params[0])) + rows.extend((None, None, param) for param in bucket.params[1:]) + if bucket.opcount_increased_to_capture_external_output > 0: + extended_buckets.append( + ( + idx, + bucket.opcount_increased_to_capture_external_output, + bucket.size - bucket.paramsize_before_opcount_increase, + ) + ) + + if len(rows): + log.info( + "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.", + bucket_bytes_cap, + len(buckets), + ) + + if len(extended_buckets): + log.warning( + "Some buckets were extended beyond their requested parameter capacities" + " in order to ensure each subgraph has an output node, required for fx graph partitioning." + " This can be the case when a subgraph would have only contained nodes performing inplace mutation," + " and returning no logical outputs. This should not be a problem, unless it results in too few graph" + " partitions for optimal DDP performance." 
+ ) + + try: + from tabulate import tabulate + + log.debug( + "\nDDPOptimizer produced the following bucket assignments:\n%s", + tabulate(rows, headers=headers, tablefmt="simple_grid"), + ) + + if len(extended_buckets): + log.warning( + "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s", + tabulate( + extended_buckets, + headers=("Index", "Extra Ops", "Extra Param Size (b)"), + tablefmt="simple_grid", + ), + ) + except ImportError: + log.debug( + "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information." + ) + else: + log.debug("DDPOptimizer captured no parameters and did not split this graph.") + + +def has_higher_order_op(gm): + # Check if there is a higher order op in the graph + for node in gm.graph.nodes: + if node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if isinstance(maybe_param, torch.fx.GraphModule): + return True + return False + + +def propagate_metadata(orig_gm, split_gm) -> None: + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384 + module.meta = orig_gm.meta + module._param_name_to_source = orig_gm._param_name_to_source + + +def propagate_dynamo_source(orig_gm, split_gm) -> None: + name_to_dynamo_source = {} + for node in orig_gm.graph.find_nodes(op="placeholder"): + name_to_dynamo_source[node.name] = node._dynamo_source + + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + for node in module.graph.find_nodes(op="placeholder"): + # non-placeholder in original_gm may become placeholder in submodules + node._dynamo_source = name_to_dynamo_source.get(node.name, None) + + +# compile each of the partitioned submodules using the user-provided compiler +class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__(self, module, compiler, fake_mode) -> None: + super().__init__(module) + self.compiler = compiler + self.fake_mode = fake_mode + + def compile_submod(self, input_mod, args, kwargs): + """ + Compile the submodule, + using a wrapper to make sure its output is always a tuple, + which is required by AotAutograd based compilers + """ + assert len(kwargs) == 0, "We assume only args for these modules" + + class WrapperModule(torch.nn.Module): + def __init__(self, submod, unwrap_singleton_tuple) -> None: + super().__init__() + self.submod = submod + self.unwrap_singleton_tuple = unwrap_singleton_tuple + + def forward(self, *args): + x = self.submod(*args) + # TODO(whc) + # for some reason the isinstance check is necessary if I split one node per submod + # - even though I supposedly wrapped the output in a tuple in those cases, the real + # compiled module was still returning a tensor + if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)): + return x[0] + return x + + unwrap_singleton_tuple = False + for sn in input_mod.graph.nodes: + if sn.op == "output": + if not isinstance(sn.args[0], tuple): + unwrap_singleton_tuple = True + sn.args = (sn.args,) + + input_mod.recompile() + input_mod.compile_subgraph_reason = GraphCompileReason( + "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." + " Set `torch._dynamo.config.optimize_ddp = False` to disable.", + [ + # it's close to useless to get a real stacktrace here, and quite verbose. 
+ traceback.FrameSummary(__file__, 0, DDPOptimizer), + ], + ) + + wrapper = WrapperModule( + self.compiler(input_mod, args), + unwrap_singleton_tuple, + ) + return wrapper + + # Note: + # + # The way distributed works today around fake tensors can be somewhat confusing. + # Some of these codepaths are shared in both runtime, and compile time. The presence + # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. + # + # A few things to keep in mind: + # + # 1) We invoke `compile_submod` with a real module. The output of that gets stored + # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. + # + # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the + # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. + # + # 3) Fake tensors should always be around during compile time. + # + # 4) Fake tensors should never be around at runtime. + # + # 5) We end up with a compilation mode that takes a real submodule and fake tensors, + # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd] + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + new_args = [] + assert self.fake_mode + for arg in args: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor + ): + new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode)) + else: + new_args.append(arg) + + log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args)) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + if n.op == "call_module": + real_mod = self.fetch_attr(n.target) + if self.fake_mode: + curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode) + else: + curr_submod = real_mod + + ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph) + + # When calling the compiler on the submod, inputs (new_args) are expected to + # be FakeTensors already since Dynamo would have made them FakeTensors in the + # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors, + # since this wrapping happens during compilation + + # Note: Returning Fake Tensors on First AOT Autograd Call + # + # Inductor will optimize strides of outputs when it deems it profitable. + # For instance, converting to channels last. When we split the graph here + # into multiple inductor compilations, we need to make sure that the + # output strides of one compilation is appropriately passed to the subsequent + # compilations. However, the mapping from inductor output to dynamo output + # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing, + # subclass handling, etc. In order to replay all this logic we set a flag such that + # the first invocation of inductor in aot_autograd will return Fake Tensors with + # appropriate strides. Then, all of aot autograd's runtime logic is replayed. + # This gives us the appropriately strided outputs here which will reflect runtime strides. 
+ + class FakeifyFirstAOTInvocationGuard: + def __init__(self) -> None: + self.tc = torch._guards.TracingContext.try_get() + assert self.tc + torch._guards.TracingContext.try_get().fakify_first_call = True + + def __del__(self) -> None: + self.tc.fakify_first_call = False + + # For aot_eager and other backends, tracing context is not set + has_tracing_context = torch._guards.TracingContext.try_get() is not None + if has_tracing_context: + g = FakeifyFirstAOTInvocationGuard() # noqa: F841 + + from torch._dynamo.utils import counters + + init = counters["aot_autograd"]["total"] + compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs) + + # TODO - better way of doing this? + # Only aot autograd handles fakifying first call + invoked_aot_autograd = init != counters["aot_autograd"]["total"] + + # We update the original (outer) graph with a call into the compiled module + # instead of the uncompiled one. + self.module.delete_submodule(n.target) + n.target = "compiled_" + n.target + self.module.add_submodule(n.target, compiled_submod_real) + + # Finally, we have to produce inputs for use compiling the next submodule, + # and these need to be FakeTensors, so we execute the module under fake_mode + # Because parameters are not fake we patch fake tensor mode to allow non fake inputs + with ( + self.fake_mode, + mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True), + ): + if has_tracing_context and invoked_aot_autograd: + out = compiled_submod_real(*new_args, **kwargs) + # output should be fake or subclass + assert all( + (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor) + for t in (out if isinstance(out, (list, tuple)) else [out]) + ) + return out + else: + return curr_submod(*new_args, **kwargs) + else: + # placeholder or output nodes don't need to get compiled, just executed + return getattr(self, n.op)(n.target, new_args, kwargs) + + +class DDPOptimizer: + """Note [DDPOptimizer] + DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP), + breaking the dynamo graph into chunks to compile separately, with the breaks aligning to + the boundaries of gradient-allreduce buckets chosen by DDP. + + Background/Motivation + - DDP uses allreduce collectives to synchronize partial gradients computed on different workers + - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce + - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready + at around the same time during backward and thus can share the same allreduce efficiently + - Allreduces must overlap with backward compute for optimal training performance + - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which + operates when individual grads become 'ready' + - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the + autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole + fused backward function executes, preventing any overlap of compute and communication + + Algorithm + - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse + this graph in reverse order to determine the true order that gradients will become ready during backward. 
+ - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started + and a graph break introduced + - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together + into an outer module that is returned to the user + + Notes + - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP, + and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does + in eager. + - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently + produce splits that do not necessarily align with the buckets used by DDP. This should result in performance + degradation approaching the baseline case where graph-splits are not used, but not worse. + - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the + subgraphs being compiled + - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers + left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are + also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP, + it is not catastrophic but could impact performance by choosing sub-optimal bucket splits. + - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients, + and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by + DDPOptimizer) + + Debugging + - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb. + - In many cases, the log messages are helpful (they show bucket size assignments)- + just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'. + - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model + in a single process (or with torchrun, in multiple processes) + + Args: + bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be + set to match the equivalent parameter on the original DDP module. + + backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph. + + first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP + special-cases the first bucket size since it is sometimes optimal to start a small allreduce early. 
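As an aside (not part of the diffed file), DDPOptimizer is not constructed directly by users; it is engaged implicitly when a DDP-wrapped module is compiled with `torch._dynamo.config.optimize_ddp` enabled (its default), roughly as sketched below. Process-group setup is assumed to have been done by a launcher such as torchrun and is elided here.

```python
# Sketch only: DDPOptimizer is applied implicitly when compiling a DDP module.
# Assumes init_process_group / device placement were handled by the launcher.
import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

assert dist.is_initialized()  # assumed: torchrun or equivalent already set this up

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 10)
).cuda()
ddp_model = DDP(model, bucket_cap_mb=25)

# With optimize_ddp enabled, graph breaks are aligned to the 25 MB gradient
# buckets chosen above, as described in the docstring.
torch._dynamo.config.optimize_ddp = True
compiled = torch.compile(ddp_model, backend="inductor")
loss = compiled(torch.randn(64, 512, device="cuda")).sum()
loss.backward()
```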
+ + """ + + def __init__( + self, + bucket_bytes_cap: int, + backend_compile_fn, + first_bucket_cap: Optional[int] = None, + ) -> None: + if first_bucket_cap is not None: + self.first_bucket_cap = first_bucket_cap + elif torch.distributed.is_available(): + # this constant comes from C10D lib which is not always built + self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES + else: + self.first_bucket_cap = bucket_bytes_cap + + self.bucket_bytes_cap = bucket_bytes_cap + assert self.first_bucket_cap <= self.bucket_bytes_cap, ( + "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP" + ) + + self.backend_compile_fn = backend_compile_fn + + def _ignore_parameter(self, parameter): + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored + + def add_param(self, bucket, param, name): + bucket.size += param.untyped_storage().nbytes() + bucket.params.append(name) + bucket.param_ids.append(id(param)) + + def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix): + processed_modules.add(mod) + for name, param in mod.named_parameters(): + if param.requires_grad and not self._ignore_parameter(param): + self.add_param(bucket, param, f"{prefix}_{name}") + + def add_param_args(self, bucket, node): + for arg in node.args: + if not isinstance(arg, torch.fx.node.Node): + continue + if arg.op != "placeholder": + continue + param = arg.meta["example_value"] + if ( + isinstance(param, torch.nn.Parameter) + and param.requires_grad + and not self._ignore_parameter(param) + ): + self.add_param(bucket, param, arg.target) + + def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]): + """ + Implements graph splitting, first determining a set of of buckets by counting + parameter sizes in reverse graph order, then invoking the user/backend compiler + to compile each subgraph. Finally, stiches compiled graphs into one graphmodule + and returns its callable. 
+ """ + # 1: compute the partition map according to DDP bucket logic + buckets = [Bucket()] # (size, param_names) + processed_modules = set() + for node in reversed(gm.graph.nodes): + if node.op in ("output", "placeholder"): + continue + + if ( + buckets[0].size >= self.bucket_bytes_cap + or len(buckets) == 1 + and buckets[0].size >= self.first_bucket_cap + ): + if bucket_has_external_output(buckets[0]): + buckets.insert(0, Bucket()) + else: + # continue building this bucket past the point of filling its parameter capacity, + # to increase chances it contains at least one node that is either a global output or + # passed as input to a subsequent graph + + if buckets[0].opcount_increased_to_capture_external_output == 0: + buckets[0].paramsize_before_opcount_increase = buckets[0].size + buckets[0].opcount_increased_to_capture_external_output += 1 + + if node.op == "call_function": + self.add_param_args(buckets[0], node) + + elif node.op == "call_module": + target_mod = gm.get_submodule(node.target) + if target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], processed_modules, node.target + ) + elif node.op == "call_method": + if isinstance(node.args[0].target, str): + target_mod = None + try: + target_mod = gm.get_submodule(node.args[0].target) + except AttributeError: + pass + if target_mod is not None and target_mod not in processed_modules: + self.add_module_params_to_bucket( + target_mod, buckets[0], processed_modules, node.target + ) + # This handles situations like tmp = torch.mm(x, self.weight.t()) + # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None + # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None + self.add_param_args(buckets[0], node) + + elif node.op == "get_attr": + maybe_param = getattr(gm, node.target) + if ( + isinstance(maybe_param, torch.nn.Parameter) + and maybe_param.requires_grad + and not self._ignore_parameter(maybe_param) + ): + self.add_param(buckets[0], maybe_param, node.target) + + # All nodes have to be mapped to a bucket, even if they don't have their own params + # Ignored params still end up in buckets, we just don't count them towards the capacity + buckets[0].nodes.append(node) + + if len(buckets) > 1 and buckets[0].size == 0: + # we collected a small preamble graph with ops that don't include parameters, fuse it back + buckets[1].nodes.extend(buckets[0].nodes) + assert len(buckets[0].params) == 0, "Params should be empty if size is 0" + del buckets[0] + + # stash buckets for testing/debugging purposes + self.buckets = buckets + pretty_print_buckets(buckets, self.bucket_bytes_cap) + + if len(buckets) == 1: + # bypass split/fuse logic if there is only one bucket + return self.backend_compile_fn(gm, example_inputs) + + # 2: partition the graphmodule according to bucket capacity + partition_map = {} + for idx, b in enumerate(buckets): + for node in b.nodes: + partition_map[node] = idx + + split_gm = fx.passes.split_module.split_module( + gm, None, lambda node: partition_map[node] + ) + + # See note [Assumption on Dynamo Metadata] + propagate_dynamo_source(gm, split_gm) + propagate_metadata(gm, split_gm) + + debug_str = ( + f"\n---orig graph---\n{gm.graph}\n" + + f"\n---split graph---\n{split_gm.graph}\n" + ) + for name, module in split_gm.named_modules(): + if "." 
not in name and len(name): + # only print the submod graphs, not their children + debug_str += f"\n---{name} graph---\n{module.graph}\n" + debug_str += "\n---------------\n" + ddp_graph_log.debug(debug_str) + + trace_structured( + "optimize_ddp_split_graph", + payload_fn=lambda: split_gm.print_readable(print_output=False), + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + trace_structured( + "optimize_ddp_split_child", + lambda: {"name": name}, + payload_fn=lambda: module.print_readable(print_output=False), + ) + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + + submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode) + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + submod_compiler.run(*example_inputs) + split_gm.recompile() + + ddp_graph_log.debug( + "\n---final graph---\n%s\n---------------\n", split_gm.graph + ) + return split_gm diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/inductor.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/inductor.py new file mode 100644 index 0000000000000000000000000000000000000000..141e15206c6c595ce01bc6d14ade77cf562b66be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/inductor.py @@ -0,0 +1,25 @@ +# mypy: ignore-errors + +""" +This module provides the TorchInductor backend integration for TorchDynamo. + +TorchInductor is a compiler backend that generates optimized code for both CPU and GPU. +This module lazily imports and registers the TorchInductor compiler to avoid loading it +into memory when it is not being used. This helps reduce memory overhead when using +other backends. + +The inductor backend can be used with torch.compile(): + model = torch.compile(model, backend="inductor") +""" + +from torch._dynamo import register_backend +from torch._dynamo.utils import dynamo_timed + + +@register_backend +def inductor(*args, **kwargs): + with dynamo_timed("inductor_import", log_pt2_compile_event=True): + # do import here to avoid loading inductor into memory when it is not used + from torch._inductor.compile_fx import compile_fx + + return compile_fx(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/onnxrt.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/onnxrt.py new file mode 100644 index 0000000000000000000000000000000000000000..afe4d9ddae67ac6424a38c7178bfd025cfca05f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/onnxrt.py @@ -0,0 +1,38 @@ +# mypy: ignore-errors + +# This backend is maintained by ONNX team. To direct issues +# to the right people, please tag related GitHub issues with `module: onnx`. +# +# Maintainers' Github IDs: wschin, xadupre +from torch.onnx._internal.onnxruntime import ( + is_onnxrt_backend_supported, + torch_compile_backend, +) + +from .registry import register_backend + + +def has_onnxruntime(): + # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() + return is_onnxrt_backend_supported() + + +if is_onnxrt_backend_supported(): + register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +else: + + def information_displaying_backend(*args, **kwargs): + raise ImportError( + "onnxrt is not registered as a backend. " + "Please make sure all dependencies such as " + "numpy, onnx, onnxscript, and onnxruntime-training are installed. 
" + "Suggested procedure to fix dependency problem:\n" + " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" + " (2) Open a new python terminal.\n" + " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" + " (4) If it returns `True`, then you can use `onnxrt` backend.\n" + " (5) If it returns `False`, please execute the package importing section in " + "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." + ) + + register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/registry.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..82f9de994649c540a8d77c3956c591dbe5831d3f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/registry.py @@ -0,0 +1,185 @@ +# mypy: ignore-errors + +""" +This module implements TorchDynamo's backend registry system for managing compiler backends. + +The registry provides a centralized way to register, discover and manage different compiler +backends that can be used with torch.compile(). It handles: + +- Backend registration and discovery through decorators and entry points +- Lazy loading of backend implementations +- Lookup and validation of backend names +- Categorization of backends using tags (debug, experimental, etc.) + +Key components: +- CompilerFn: Type for backend compiler functions that transform FX graphs +- _BACKENDS: Registry mapping backend names to entry points +- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions + +Example usage: + @register_backend + def my_compiler(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn + + # Use registered backend + torch.compile(model, backend="my_compiler") + +The registry also supports discovering backends through setuptools entry points +in the "torch_dynamo_backends" group. Example: +``` +setup.py +--- +from setuptools import setup + +setup( + name='my_torch_backend', + version='0.1', + packages=['my_torch_backend'], + entry_points={ + 'torch_dynamo_backends': [ + # name = path to entry point of backend implementation + 'my_compiler = my_torch_backend.compiler:my_compiler_function', + ], + }, +) +``` +``` +my_torch_backend/compiler.py +--- +def my_compiler_function(fx_graph, example_inputs): + # Transform FX graph into optimized implementation + return compiled_fn +``` +Using `my_compiler` backend: +``` +import torch + +model = ... # Your PyTorch model +optimized_model = torch.compile(model, backend="my_compiler") +``` +""" + +import functools +import logging +import sys +from collections.abc import Sequence +from importlib.metadata import EntryPoint +from typing import Callable, Optional, Protocol + +import torch +from torch import fx + + +log = logging.getLogger(__name__) + + +class CompiledFn(Protocol): + def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... + + +CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn] + +_BACKENDS: dict[str, Optional[EntryPoint]] = {} +_COMPILER_FNS: dict[str, CompilerFn] = {} + + +def register_backend( + compiler_fn: Optional[CompilerFn] = None, + name: Optional[str] = None, + tags: Sequence[str] = (), +): + """ + Decorator to add a given compiler to the registry to allow calling + `torch.compile` with string shorthand. 
Note: for projects not + imported by default, it might be easier to pass a function directly + as a backend and not use a string. + + Args: + compiler_fn: Callable taking a FX graph and fake tensor inputs + name: Optional name, defaults to `compiler_fn.__name__` + tags: Optional set of string tags to categorize backend with + """ + if compiler_fn is None: + # @register_backend(name="") syntax + return functools.partial(register_backend, name=name, tags=tags) + assert callable(compiler_fn) + name = name or compiler_fn.__name__ + assert name not in _COMPILER_FNS, f"duplicate name: {name}" + if compiler_fn not in _BACKENDS: + _BACKENDS[name] = None + _COMPILER_FNS[name] = compiler_fn + compiler_fn._tags = tuple(tags) + return compiler_fn + + +register_debug_backend = functools.partial(register_backend, tags=("debug",)) +register_experimental_backend = functools.partial( + register_backend, tags=("experimental",) +) + + +def lookup_backend(compiler_fn): + """Expand backend strings to functions""" + if isinstance(compiler_fn, str): + if compiler_fn not in _BACKENDS: + _lazy_import() + if compiler_fn not in _BACKENDS: + from ..exc import InvalidBackend + + raise InvalidBackend(name=compiler_fn) + + if compiler_fn not in _COMPILER_FNS: + entry_point = _BACKENDS[compiler_fn] + register_backend(compiler_fn=entry_point.load(), name=compiler_fn) + compiler_fn = _COMPILER_FNS[compiler_fn] + return compiler_fn + + +def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: + """ + Return valid strings that can be passed to: + + torch.compile(..., backend="name") + """ + _lazy_import() + exclude_tags = set(exclude_tags or ()) + + backends = [ + name + for name in _BACKENDS.keys() + if name not in _COMPILER_FNS + or not exclude_tags.intersection(_COMPILER_FNS[name]._tags) + ] + return sorted(backends) + + +@functools.cache +def _lazy_import(): + from .. 
import backends + from ..utils import import_submodule + + import_submodule(backends) + + from ..repro.after_dynamo import dynamo_minifier_backend + + assert dynamo_minifier_backend is not None + + _discover_entrypoint_backends() + + +@functools.cache +def _discover_entrypoint_backends(): + # importing here so it will pick up the mocked version in test_backends.py + from importlib.metadata import entry_points + + group_name = "torch_dynamo_backends" + if sys.version_info < (3, 10): + eps = entry_points() + eps = eps[group_name] if group_name in eps else [] + eps = {ep.name: ep for ep in eps} + else: + eps = entry_points(group=group_name) + eps = {name: eps[name] for name in eps.names} + for backend_name in eps: + _BACKENDS[backend_name] = eps[backend_name] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/tensorrt.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2ba60cdeb0f6581e049f670088269919fa0fa5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/tensorrt.py @@ -0,0 +1,14 @@ +# mypy: ignore-errors + +# import torch # type: ignore[import] +# from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import] +# from .registry import register_backend # type: ignore[import] + +""" +Placeholder for TensorRT backend for dynamo via torch-tensorrt +""" + +# @register_backend +# def tensorrt(gm, example_inputs): +# import torch_tensorrt # type: ignore[import] +# pass diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/torchxla.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/torchxla.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea061ee2291d1412ce60cf406cdb2db74e77bae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/torchxla.py @@ -0,0 +1,47 @@ +# mypy: ignore-errors + +import logging + +from functorch.compile import make_boxed_func + +from ..backends.common import aot_autograd +from .registry import register_backend, register_experimental_backend + + +log = logging.getLogger(__name__) + + +@register_experimental_backend +def openxla_eval(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=False) + + +def openxla_eval_boxed(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=True) + + +def xla_backend_helper(model, fake_tensor_inputs, boxed=False): + try: + import torch_xla.core.dynamo_bridge as bridge + except ImportError as e: + raise ImportError( + "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla" + ) from e + + compiled_graph = None + + def fwd(*args): + nonlocal model + nonlocal compiled_graph + if compiled_graph is None: + compiled_graph = bridge.extract_compiled_graph(model, args) + del model + return compiled_graph(*args) + + return make_boxed_func(fwd) if boxed else fwd + + +openxla = aot_autograd( + fw_compiler=openxla_eval_boxed, +) +register_backend(name="openxla", compiler_fn=openxla) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/backends/tvm.py b/phivenv/Lib/site-packages/torch/_dynamo/backends/tvm.py new file mode 100644 index 0000000000000000000000000000000000000000..85baceab1b0f2e2733fa450cbb9d9845d3a59b08 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/backends/tvm.py @@ -0,0 +1,212 @@ +# mypy: ignore-errors + +""" +This module provides TVM backend integration for TorchDynamo. 
+ +Apache TVM is a deep learning compiler framework that can optimize and execute +models on various hardware backends. This module enables: + +- Compilation of PyTorch models to TVM's computation graphs +- Multiple scheduling options: + - Default scheduler + - Auto-scheduler for automatic optimization + - Meta-schedule for evolutionary search-based tuning +- Hardware-specific optimizations: + - CUDA GPU support + - CPU support with LLVM targeting and architecture-specific tuning + - Automatic detection of CPU capabilities (AVX2, AVX512) +- Tensor conversion utilities between PyTorch and TVM formats +- Configurable optimization levels and tuning trials + +The backend can be used with torch.compile(): + model = torch.compile(model, backend="tvm") +""" + +import functools +import importlib +import logging +import os +import sys +import tempfile +from types import MappingProxyType +from typing import Optional + +import torch + +from .common import device_from_inputs, fake_tensor_unsupported +from .registry import register_backend + + +log = logging.getLogger(__name__) + + +@register_backend +@fake_tensor_unsupported +def tvm( + gm, + example_inputs, + *, + options: Optional[MappingProxyType] = MappingProxyType( + {"scheduler": None, "trials": 20000, "opt_level": 3} + ), +): + import tvm # type: ignore[import] + from tvm import relay # type: ignore[import] + from tvm.contrib import graph_executor # type: ignore[import] + + jit_mod = torch.jit.trace(gm, example_inputs) + device = device_from_inputs(example_inputs) + shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] + example_outputs = gm(*example_inputs) + if len(example_outputs) == 0: + log.warning("Explicitly fall back to eager due to zero output") + return gm.forward + mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + scheduler = options.get("scheduler", None) + if scheduler is None: + scheduler = os.environ.get("TVM_SCHEDULER", None) + + trials = options.get("trials", 20000) + opt_level = options.get("opt_level", 3) + + if scheduler == "auto_scheduler": + from tvm import auto_scheduler + + log_file = tempfile.NamedTemporaryFile() + + if not os.path.exists(log_file): + tasks, task_weights = auto_scheduler.extract_tasks( + mod["main"], params, target + ) + if len(tasks) != 0: + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + if not os.path.exists(log_file): + assert trials > 0 + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + early_stopping=2000, + ) + try: + tuner.tune(tune_option) + except Exception: + if os.path.exists(log_file): + os.unlink(log_file) + raise + + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext( + opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True} + ): + lib = relay.build(mod, target=target, params=params) + elif scheduler == "meta_schedule": + from tvm import meta_schedule as ms + + with tempfile.TemporaryDirectory() as work_dir: + if device.type != "cuda": + # meta_schedule needs num-cores to be specified + # here we use the maximum core count + target = tvm.target.Target( + f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}" + ) + # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch + # once USE_PT_TVMDSOOP is updated and turned on 
by default in TVM. + assert trials > 0 + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, + work_dir=work_dir, + max_trials_global=trials, + num_trials_per_iter=64, + params=params, + strategy="evolutionary", + opt_level=opt_level, + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, + target=target, + params=params, + opt_level=opt_level, + ) + elif scheduler == "default" or not scheduler: + # no autotuning + with tvm.transform.PassContext(opt_level=opt_level): + lib = relay.build(mod, target=target, params=params) + else: + raise NotImplementedError( + "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " + "There are three available options: default, auto_scheduler and meta_schedule." + ) + m = graph_executor.GraphModule(lib["default"](dev)) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. + return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if torch_tensor.dtype == torch.bool: + # same reason as above, fallback to numpy conversion which + # could introduce data copy overhead + return tvm.nd.array(torch_tensor.cpu().numpy()) + return tvm.nd.from_dlpack(torch_tensor) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + shape_info, _ = m.get_input_info() + active_inputs = {name for name, _ in shape_info.items()} + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + inp_name = f"inp_{idx}" + if inp_name not in active_inputs: + log.warning( + "input %s skipped as not found in tvm's runtime library", + inp_name, + ) + continue + m.set_input( + inp_name, + to_tvm_tensor(arg), + ) + m.run() + return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())] + + return exec_tvm + + +tvm_meta_schedule = functools.partial(tvm, scheduler="meta_schedule") +tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler") + + +def has_tvm(): + try: + importlib.import_module("tvm") + return True + except ImportError: + return False + + +@functools.cache +def llvm_target(): + if sys.platform == "linux": + cpuinfo = open("/proc/cpuinfo").read() + if "avx512" in cpuinfo: + return "llvm -mcpu=skylake-avx512" + elif "avx2" in cpuinfo: + return "llvm -mcpu=core-avx2" + return "llvm" diff --git a/phivenv/Lib/site-packages/torch/_dynamo/bytecode_analysis.py b/phivenv/Lib/site-packages/torch/_dynamo/bytecode_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..8480294a24a2f56c96e5156cbb6ff4650a6485de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/bytecode_analysis.py @@ -0,0 +1,260 @@ +# mypy: allow-untyped-defs + +""" +This module provides utilities for analyzing and optimizing Python bytecode. +Key functionality includes: +- Dead code elimination +- Jump instruction optimization +- Stack size analysis and verification +- Live variable analysis +- Line number propagation and cleanup +- Exception table handling for Python 3.11+ + +The utilities in this module are used to analyze and transform bytecode +for better performance while maintaining correct semantics. 
+""" + +import bisect +import dataclasses +import dis +import sys +from typing import Any, Union + + +TERMINAL_OPCODES = { + dis.opmap["RETURN_VALUE"], + dis.opmap["JUMP_FORWARD"], + dis.opmap["RAISE_VARARGS"], + # TODO(jansel): double check exception handling +} +TERMINAL_OPCODES.add(dis.opmap["RERAISE"]) +if sys.version_info >= (3, 11): + TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"]) + TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) +else: + TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) +if sys.version_info >= (3, 12): + TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"]) +if sys.version_info >= (3, 13): + TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"]) +JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs) +JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES} +HASLOCAL = set(dis.haslocal) +HASFREE = set(dis.hasfree) + +stack_effect = dis.stack_effect + + +def get_indexof(insts): + """ + Get a mapping from instruction memory address to index in instruction list. + Additionally checks that each instruction only appears once in the list. + """ + indexof = {} + for i, inst in enumerate(insts): + assert inst not in indexof + indexof[inst] = i + return indexof + + +def remove_dead_code(instructions): + """Dead code elimination""" + indexof = get_indexof(instructions) + live_code = set() + + def find_live_code(start): + for i in range(start, len(instructions)): + if i in live_code: + return + live_code.add(i) + inst = instructions[i] + if inst.exn_tab_entry: + find_live_code(indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + find_live_code(indexof[inst.target]) + if inst.opcode in TERMINAL_OPCODES: + return + + find_live_code(0) + + # change exception table entries if start/end instructions are dead + # assumes that exception table entries have been propagated, + # e.g. with bytecode_transformation.propagate_inst_exn_table_entries, + # and that instructions with an exn_tab_entry lies within its start/end. 
+ if sys.version_info >= (3, 11): + live_idx = sorted(live_code) + for i, inst in enumerate(instructions): + if i in live_code and inst.exn_tab_entry: + # find leftmost live instruction >= start + start_idx = bisect.bisect_left( + live_idx, indexof[inst.exn_tab_entry.start] + ) + assert start_idx < len(live_idx) + # find rightmost live instruction <= end + end_idx = ( + bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1 + ) + assert end_idx >= 0 + assert live_idx[start_idx] <= i <= live_idx[end_idx] + inst.exn_tab_entry.start = instructions[live_idx[start_idx]] + inst.exn_tab_entry.end = instructions[live_idx[end_idx]] + + return [inst for i, inst in enumerate(instructions) if i in live_code] + + +def remove_pointless_jumps(instructions): + """Eliminate jumps to the next instruction""" + pointless_jumps = { + id(a) + for a, b in zip(instructions, instructions[1:]) + if a.opname == "JUMP_ABSOLUTE" and a.target is b + } + return [inst for inst in instructions if id(inst) not in pointless_jumps] + + +def propagate_line_nums(instructions): + """Ensure every instruction has line number set in case some are removed""" + cur_line_no = None + + def populate_line_num(inst): + nonlocal cur_line_no + if inst.starts_line: + cur_line_no = inst.starts_line + + inst.starts_line = cur_line_no + + for inst in instructions: + populate_line_num(inst) + + +def remove_extra_line_nums(instructions): + """Remove extra starts line properties before packing bytecode""" + + cur_line_no = None + + def remove_line_num(inst): + nonlocal cur_line_no + if inst.starts_line is None: + return + elif inst.starts_line == cur_line_no: + inst.starts_line = None + else: + cur_line_no = inst.starts_line + + for inst in instructions: + remove_line_num(inst) + + +@dataclasses.dataclass +class ReadsWrites: + reads: set[Any] + writes: set[Any] + visited: set[Any] + + +def livevars_analysis(instructions, instruction): + indexof = get_indexof(instructions) + must = ReadsWrites(set(), set(), set()) + may = ReadsWrites(set(), set(), set()) + + def walk(state, start): + if start in state.visited: + return + state.visited.add(start) + + for i in range(start, len(instructions)): + inst = instructions[i] + if inst.opcode in HASLOCAL or inst.opcode in HASFREE: + if "LOAD" in inst.opname or "DELETE" in inst.opname: + if inst.argval not in must.writes: + state.reads.add(inst.argval) + elif "STORE" in inst.opname: + state.writes.add(inst.argval) + elif inst.opname == "MAKE_CELL": + pass + else: + raise NotImplementedError(f"unhandled {inst.opname}") + if inst.exn_tab_entry: + walk(may, indexof[inst.exn_tab_entry.target]) + if inst.opcode in JUMP_OPCODES: + walk(may, indexof[inst.target]) + state = may + if inst.opcode in TERMINAL_OPCODES: + return + + walk(must, indexof[instruction]) + return must.reads | may.reads + + +@dataclasses.dataclass +class FixedPointBox: + value: bool = True + + +@dataclasses.dataclass +class StackSize: + low: Union[int, float] + high: Union[int, float] + fixed_point: FixedPointBox + + def zero(self): + self.low = 0 + self.high = 0 + self.fixed_point.value = False + + def offset_of(self, other, n): + prior = (self.low, self.high) + self.low = min(self.low, other.low + n) + self.high = max(self.high, other.high + n) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + def exn_tab_jump(self, depth): + prior = (self.low, self.high) + self.low = min(self.low, depth) + self.high = max(self.high, depth) + if (self.low, self.high) != prior: + self.fixed_point.value = False + + +def 
stacksize_analysis(instructions) -> Union[int, float]: + assert instructions + fixed_point = FixedPointBox() + stack_sizes = { + inst: StackSize(float("inf"), float("-inf"), fixed_point) + for inst in instructions + } + stack_sizes[instructions[0]].zero() + + for _ in range(100): + if fixed_point.value: + break + fixed_point.value = True + + for inst, next_inst in zip(instructions, instructions[1:] + [None]): + stack_size = stack_sizes[inst] + if inst.opcode not in TERMINAL_OPCODES: + assert next_inst is not None, f"missing next inst: {inst}" + eff = stack_effect(inst.opcode, inst.arg, jump=False) + stack_sizes[next_inst].offset_of(stack_size, eff) + if inst.opcode in JUMP_OPCODES: + stack_sizes[inst.target].offset_of( + stack_size, stack_effect(inst.opcode, inst.arg, jump=True) + ) + if inst.exn_tab_entry: + # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + # on why depth is computed this way. + depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 + stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) + + if False: + for inst in instructions: + stack_size = stack_sizes[inst] + print(stack_size.low, stack_size.high, inst) + + low = min(x.low for x in stack_sizes.values()) + high = max(x.high for x in stack_sizes.values()) + + assert fixed_point.value, "failed to reach fixed point" + assert low >= 0 + return high diff --git a/phivenv/Lib/site-packages/torch/_dynamo/bytecode_transformation.py b/phivenv/Lib/site-packages/torch/_dynamo/bytecode_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..75c8b75b7756ea592ba15d534d1d0ead2ce55e3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/bytecode_transformation.py @@ -0,0 +1,1660 @@ +# mypy: allow-untyped-defs + +""" +This module provides utilities for analyzing, transforming and manipulating Python bytecode. +It includes functionality for: +- Converting between different bytecode formats and versions +- Virtualizing jumps and managing jump targets +- Handling exception tables and their entries +- Managing instruction offsets and extended arguments +- Providing a clean API for bytecode modification and transformation +- Supporting Python version-specific bytecode features +- Generating bytecode from template functions + +The module is designed to work across different Python versions (3.7+) and handles +version-specific bytecode differences transparently. 
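+
+A minimal usage sketch (round-tripping a function's bytecode through the
+``transform_code_object`` helper defined in this module; the callback receives
+the mutable instruction list plus a dict of code fields, and a no-op callback
+leaves the bytecode semantically unchanged):
+
+    def noop(instructions, code_options):
+        # mutate `instructions` in place here to rewrite the bytecode
+        pass
+
+    def f(x):
+        return x + 1
+
+    new_code = transform_code_object(f.__code__, noop)  # a fresh types.CodeType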
+""" + +import copy +import dataclasses +import dis +import functools +import itertools +import sys +import types +import uuid +from collections.abc import Iterator, Sequence +from typing import Any, Callable, cast, Optional, Union + +from ..utils._backport_slots import dataclass_slots +from .bytecode_analysis import ( + get_indexof, + propagate_line_nums, + remove_extra_line_nums, + stacksize_analysis, +) +from .utils import is_safe_constant + + +@dataclass_slots +@dataclasses.dataclass +class InstructionExnTabEntry: + start: "Instruction" + end: "Instruction" + target: "Instruction" + depth: int + lasti: bool + + def __repr__(self) -> str: + return ( + f"InstructionExnTabEntry(start={self.start.short_inst_repr()}, " + f"end={self.end.short_inst_repr()}, " + f"target={self.target.short_inst_repr()}, " + f"depth={self.depth}, lasti={self.lasti})" + ) + + def __eq__(self, o) -> bool: + return ( + self.start is o.start + and self.end is o.end + and self.target is o.target + and self.depth == o.depth + and self.lasti == o.lasti + ) + + +@dataclass_slots +@dataclasses.dataclass +class Instruction: + """A mutable version of dis.Instruction""" + + opcode: int + opname: str + arg: Optional[int] + argval: Any + offset: Optional[int] = None + starts_line: Optional[int] = None + is_jump_target: bool = False + positions: Optional["dis.Positions"] = None + # extra fields to make modification easier: + target: Optional["Instruction"] = None + exn_tab_entry: Optional[InstructionExnTabEntry] = None + argrepr: Optional[str] = None + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other) -> bool: + return id(self) == id(other) + + def short_inst_repr(self) -> str: + return f"Instruction(opname={self.opname}, offset={self.offset})" + + def copy_positions(self, other: "Instruction") -> None: + self.starts_line = other.starts_line + self.positions = other.positions + + +if sys.version_info >= (3, 13): + + def convert_instruction(i: dis.Instruction) -> Instruction: + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + i.line_number, + i.is_jump_target, + i.positions, + ) + +elif sys.version_info >= (3, 11): + + def convert_instruction(i: dis.Instruction) -> Instruction: + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + i.starts_line, + i.is_jump_target, + i.positions, + ) + +else: + + def convert_instruction(i: dis.Instruction) -> Instruction: + return Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + i.starts_line, + i.is_jump_target, + None, + ) + + +class _NotProvided: + def __repr__(self) -> str: + return "_NotProvided" + + +if sys.version_info >= (3, 12): + + def inst_has_op_bits(name): + return name in ("LOAD_ATTR", "LOAD_GLOBAL", "LOAD_SUPER_ATTR") + +elif sys.version_info >= (3, 11): + + def inst_has_op_bits(name): + return name == "LOAD_GLOBAL" + +else: + + def inst_has_op_bits(name): + return False + + +def create_instruction( + name, *, arg=None, argval=_NotProvided, target=None +) -> Instruction: + """ + At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. + This is to prevent ambiguity, e.g. does + create_instruction("LOAD_CONST", 5) + mean load the constant at co_consts[5], or load the constant 5? + + If `arg` is not provided, it will be computed during assembly from + `argval` or `target`. + + Bits in the args of instructions LOAD_GLOBAL, LOAD_ATTR (3.12+), and LOAD_SUPER_ATTR + modify the behavior of the instruction. 
In this case, we allow both `arg` + and `argval` to be set. The value of `arg` here is expected to be the value of + the op bits and the true value of `arg` will be computed during assembly. + If `arg` is not set, the bits are assumed to be 0. + """ + + # allow for instructions with op bits to have both arg and argval specified + if inst_has_op_bits(name): + if target is not None: + raise RuntimeError("target cannot be specified for instruction") + if arg is None: + arg = 0 + else: + cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None) + if cnt > 1: + raise RuntimeError( + "only one of arg, argval, and target can be not None/_NotProvided" + ) + if arg is not None and not isinstance(arg, int): + raise RuntimeError("instruction arg must be int or None") + return Instruction( + opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target + ) + + +# Python 3.11 remaps +def create_jump_absolute(target) -> Instruction: + inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" + return create_instruction(inst, target=target) + + +def create_load_const(val, checked=True) -> Instruction: + """ + In general we should only create `LOAD_CONST` for immutable objects, but + sometimes it's convenient _and safe_ for Dynamo create `LOAD_CONST` for + mutable objects. In such cases, use `checked=False`. + """ + if checked: + assert is_safe_constant(val), f"unsafe constant {val}" + return create_instruction("LOAD_CONST", argval=val) + + +def create_dup_top() -> Instruction: + if sys.version_info >= (3, 11): + return create_instruction("COPY", arg=1) + return create_instruction("DUP_TOP") + + +def create_rot_n(n) -> list[Instruction]: + """ + Returns a "simple" sequence of instructions that rotates TOS to the n-th + position in the stack. For Python < 3.11, returns a single ROT_* + instruction. If no such instruction exists, an error is raised and the + caller is expected to generate an equivalent sequence of instructions. + For Python >= 3.11, any rotation can be expressed as a simple sequence of + swaps. + """ + if n <= 1: + # don't rotate + return [] + + if sys.version_info >= (3, 11): + # rotate can be expressed as a sequence of swap operations + # e.g. rotate 3 is equivalent to swap 3, swap 2 + return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] + + # ensure desired rotate function exists + if sys.version_info < (3, 10) and n >= 5: + raise AttributeError(f"rotate {n} not supported for Python < 3.10") + + if n <= 4: + return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] + return [create_instruction("ROT_N", arg=n)] + + +def add_push_null( + inst_or_insts: Union[Instruction, list[Instruction]], +) -> list[Instruction]: + """ + Appends or prepends a PUSH_NULL instruction to `inst_or_insts`, + depending on Python version. Used when you know that + `inst_or_insts` generates a callable that will be called. + + NOTE: Assumes `inst_or_insts` is a single instruction or sequence of + instructions that pushes exactly 1 object to the stack that is to + be called. It is important that you include ALL instructions that + construct the callable - not just the first instruction/a prefix. + + Will attempt to use the NULL push bit for instructions + with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR). + In this case, instructions WILL be modified. 
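+
+    A minimal sketch of the common case:
+
+        insts = [create_instruction("LOAD_GLOBAL", argval="print")]
+        insts = add_push_null(insts)
+        # On 3.11+ this sets the push-null op bit on the LOAD_GLOBAL arg instead
+        # of emitting a separate PUSH_NULL; on Python < 3.11 the instructions are
+        # returned unchanged, since no NULL is needed there.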
+ """ + if isinstance(inst_or_insts, Instruction): + insts = [inst_or_insts] + else: + insts = inst_or_insts + + def inst_has_bit_set(idx): + assert insts[idx].arg is not None + return insts[idx].arg & 1 == 1 + + def set_inst_bit(idx): + assert insts[idx].arg is not None + insts[idx].arg |= 1 + + if sys.version_info >= (3, 13): + # In 3.13, NULL follows the callable + if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1): + # All insts with op bits have the push_null bit as the last one. + # Only set the bit if it hasn't been set - otherwise, we need + # to add another PUSH_NULL. + set_inst_bit(-1) + else: + insts = insts + [create_instruction("PUSH_NULL")] + elif sys.version_info >= (3, 12): + # LOAD_ATTR/LOAD_SUPER_ATTR at the end + # We assume that `insts` will only load 1 object, so + # LOAD_GLOBAL at the end doesn't need to be checked + if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1): + set_inst_bit(-1) + elif insts[0].opname == "LOAD_GLOBAL" and not inst_has_bit_set(0): + set_inst_bit(0) + else: + insts = [create_instruction("PUSH_NULL")] + insts + elif sys.version_info >= (3, 11): + # 3.11 introduced NULL preceding callable + if inst_has_op_bits(insts[0].opname) and not inst_has_bit_set(0): + set_inst_bit(0) + else: + insts = [create_instruction("PUSH_NULL")] + insts + return insts + + +def add_push_null_call_function_ex( + inst_or_insts: Union[Instruction, list[Instruction]], +) -> list[Instruction]: + """Like add_push_null, but the low bit of LOAD_ATTR/LOAD_SUPER_ATTR + is not set, due to an expected CALL_FUNCTION_EX instruction. + """ + if isinstance(inst_or_insts, Instruction): + insts = [inst_or_insts] + else: + insts = inst_or_insts + + if sys.version_info < (3, 11): + return insts + + idx = -1 if sys.version_info >= (3, 13) else 0 + if insts[idx].opname == "LOAD_GLOBAL": + assert insts[idx].arg is not None + if insts[idx].arg & 1 == 0: # type: ignore[operator] + insts[idx].arg |= 1 # type: ignore[operator] + return insts + + if sys.version_info >= (3, 13): + insts = insts + [create_instruction("PUSH_NULL")] + else: + insts = [create_instruction("PUSH_NULL")] + insts + + return insts + + +def create_call_function(nargs, push_null) -> list[Instruction]: + """ + Creates a sequence of instructions that makes a function call. + + `push_null` is used in Python 3.11+ only. It is used in codegen when + a function call is intended to be made with the NULL + fn convention, + and we know that the NULL has not been pushed yet. We will push a + NULL and rotate it to the correct position immediately before making + the function call. + + `push_null` should be True if no NULL is pushed for the callable. + Conversely, `push_null` should be False if a NULL was pushed for the callable. + Prefer using `push_null=False` when possible since we will not need to rotate + NULL to the right place, which is less efficient. + + Generally, you should codegen a function by using `add_push_null` then + `create_call_function` with `push_null=False`. 
+ + Example of when to set push_null False: + + insts = [ + create_instruction("LOAD_GLOBAL", argval="torch"), + create_instruction("LOAD_ATTR", argval="nn"), + create_instruction("LOAD_ATTR", argval="functional"), + create_instruction("LOAD_ATTR", argval="relu"), + ] + insts = add_push_null(insts) + insts.append(create_instruction("LOAD_FAST", argval="x")) + insts.extend(create_call_function(1, False)) + + Example of when to set push_null True: + + insts = [create_instruction("LOAD_FAST", x)] + for should_wrap, wrapper_name in wrappers: + if should_wrap: + insts.extend([ + create_instruction("LOAD_GLOBAL", argval="wrapper1"), + create_instruction("SWAP", arg=2), + *create_call_function(1, True), + ) + """ + if sys.version_info >= (3, 11): + output = [] + if push_null: + output.append(create_instruction("PUSH_NULL")) + # 3.13 swapped NULL and callable + rots = nargs + 1 if sys.version_info >= (3, 13) else nargs + 2 + output.extend(create_rot_n(rots)) + if sys.version_info < (3, 12): + output.append(create_instruction("PRECALL", arg=nargs)) + output.append(create_instruction("CALL", arg=nargs)) + return output + return [create_instruction("CALL_FUNCTION", arg=nargs)] + + +def create_call_method(nargs) -> list[Instruction]: + if sys.version_info >= (3, 12): + return [create_instruction("CALL", arg=nargs)] + if sys.version_info >= (3, 11): + return [ + create_instruction("PRECALL", arg=nargs), + create_instruction("CALL", arg=nargs), + ] + return [create_instruction("CALL_METHOD", arg=nargs)] + + +def create_load_method(name) -> Instruction: + if sys.version_info >= (3, 12): + # in 3.12, create a LOAD_ATTR instruction with the low bit set + return create_instruction("LOAD_ATTR", arg=1, argval=name) + return create_instruction("LOAD_METHOD", argval=name) + + +def create_setup_with(target) -> Instruction: + opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH" + return create_instruction(opname, target=target) + + +def create_swap(n) -> list[Instruction]: + if sys.version_info >= (3, 11): + return [create_instruction("SWAP", arg=n)] + # in Python < 3.11, SWAP is a macro that expands to multiple instructions + if n == 1: + return [] + """ + e.g. 
swap "a" and "b" in this stack: + 0 a 1 2 3 b + 0 a [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] -1 + 0 a [1 2 3 b] b + 0 b a [1 2 3 b] + 0 b a [1 2 3 b] [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] -1 + 0 b [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] reverse + 0 b [a 3 2 1] None + 0 b [a 3 2 1] + 0 b 1 2 3 a + """ + return [ + create_instruction("BUILD_LIST", arg=n - 1), + create_instruction("DUP_TOP"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("BINARY_SUBSCR"), + create_instruction("ROT_THREE"), + create_instruction("DUP_TOP"), + create_instruction("ROT_THREE"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("STORE_SUBSCR"), + create_instruction("DUP_TOP"), + create_load_method("reverse"), + *create_call_method(0), + create_instruction("POP_TOP"), + create_instruction("UNPACK_SEQUENCE", arg=n - 1), + ] + + +def lnotab_writer( + lineno: int, byteno: int = 0 +) -> tuple[list[int], Callable[[int, int], None]]: + """ + Used to create typing.CodeType.co_lnotab + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table if Python < 3.10 + """ + assert sys.version_info < (3, 10) + lnotab: list[int] = [] + + def update(lineno_new, byteno_new): + nonlocal byteno, lineno + while byteno_new != byteno or lineno_new != lineno: + byte_offset = max(0, min(byteno_new - byteno, 255)) + line_offset = max(-128, min(lineno_new - lineno, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno += byte_offset + lineno += line_offset + lnotab.extend((byte_offset, line_offset & 0xFF)) + + return lnotab, update + + +def linetable_310_writer(first_lineno): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt + This is the internal format of the line number table for Python 3.10 + """ + assert sys.version_info >= (3, 10) and sys.version_info < (3, 11) + linetable: list[int] = [] + lineno = first_lineno + lineno_delta = 0 + byteno = 0 + + def _update(byteno_delta, lineno_delta): + while byteno_delta != 0 or lineno_delta != 0: + byte_offset = max(0, min(byteno_delta, 254)) + line_offset = max(-127, min(lineno_delta, 127)) + assert byte_offset != 0 or line_offset != 0 + byteno_delta -= byte_offset + lineno_delta -= line_offset + linetable.extend((byte_offset, line_offset & 0xFF)) + + def update(lineno_new, byteno_new): + nonlocal lineno, lineno_delta, byteno + byteno_delta = byteno_new - byteno + byteno = byteno_new + _update(byteno_delta, lineno_delta) + lineno_delta = lineno_new - lineno + lineno = lineno_new + + def end(total_bytes): + _update(total_bytes - byteno, lineno_delta) + + return linetable, update, end + + +def encode_varint(n: int) -> list[int]: + """ + 6-bit chunk encoding of an unsigned integer + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b[-1] |= 64 + b.append(n & 63) + n >>= 6 + return b + + +def linetable_311_writer(first_lineno: int): + """ + Used to create typing.CodeType.co_linetable + See https://github.com/python/cpython/blob/3.11/Objects/locations.md + This is the internal format of the line number table for Python 3.11 + """ + assert sys.version_info >= (3, 11) + linetable = [] + lineno = first_lineno + + def update(positions: "dis.Positions", inst_size): + nonlocal lineno + lineno_new = positions.lineno if positions else None + + def _update(delta, size): 
+ assert 0 < size <= 8 + # first byte - use 13 (no column info) is positions is + # malformed, otherwise use 14 (long form) + other_varints: tuple[int, ...] = () + if ( + positions + and positions.lineno is not None + and positions.end_lineno is not None + and positions.col_offset is not None + and positions.end_col_offset is not None + ): + linetable.append(0b1_1110_000 + size - 1) + # for whatever reason, column offset needs `+ 1` + # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603 + other_varints = ( + positions.end_lineno - positions.lineno, + positions.col_offset + 1, + positions.end_col_offset + 1, + ) + else: + linetable.append(0b1_1101_000 + size - 1) + # encode signed int + if delta < 0: + delta = ((-delta) << 1) | 1 + else: + delta <<= 1 + # encode unsigned int + linetable.extend(encode_varint(delta)) + for n in other_varints: + linetable.extend(encode_varint(n)) + + if lineno_new is None: + lineno_delta = 0 + else: + lineno_delta = lineno_new - lineno + lineno = lineno_new + while inst_size > 8: + _update(lineno_delta, 8) + inst_size -= 8 + _update(lineno_delta, inst_size) + + return linetable, update + + +@dataclass_slots +@dataclasses.dataclass +class ExceptionTableEntry: + start: int + end: int + target: int + depth: int + lasti: bool + + +def encode_exception_table_varint(n: int) -> list[int]: + """ + Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse. + """ + assert n >= 0 + b = [n & 63] + n >>= 6 + while n > 0: + b.append(n & 63) + n >>= 6 + b.reverse() + for i in range(len(b) - 1): + b[i] |= 64 + return b + + +def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int: + """ + Inverse of `encode_exception_table_varint`. + """ + b = next(bytes_iter) + val = b & 63 + while b & 64: + val <<= 6 + b = next(bytes_iter) + val |= b & 63 + return val + + +def check_exception_table(tab: list[ExceptionTableEntry]) -> None: + """ + Verifies that a list of ExceptionTableEntries will make a well-formed + jump table: entries are non-empty, sorted, and do not overlap. + """ + for i in range(len(tab) - 1): + assert ( + tab[i].start <= tab[i].end + and tab[i].end < tab[i + 1].start + and tab[i + 1].start <= tab[i + 1].end + ) + + +def parse_exception_table(exntab: bytes) -> list[ExceptionTableEntry]: + """ + Parse the exception table according to + https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + exntab_iter = iter(exntab) + tab = [] + try: + while True: + start = decode_exception_table_varint(exntab_iter) * 2 + length = decode_exception_table_varint(exntab_iter) * 2 + end = start + length - 2 + target = decode_exception_table_varint(exntab_iter) * 2 + dl = decode_exception_table_varint(exntab_iter) + depth = dl >> 1 + lasti = bool(dl & 1) + tab.append(ExceptionTableEntry(start, end, target, depth, lasti)) + except StopIteration: + check_exception_table(tab) + return tab + + +def assemble_exception_table(tab: list[ExceptionTableEntry]) -> bytes: + """ + Inverse of parse_exception_table - encodes list of exception + table entries into bytes. 
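+
+    A worked example (one entry covering bytecode offsets 0-10, handler at
+    offset 12, depth 1, lasti False):
+
+        entry = ExceptionTableEntry(start=0, end=10, target=12, depth=1, lasti=False)
+        assert assemble_exception_table([entry]) == bytes([0x80, 0x06, 0x06, 0x02])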
+ """ + b = [] + for entry in tab: + first_entry = encode_exception_table_varint(entry.start // 2) + first_entry[0] |= 1 << 7 + b.extend(first_entry) + length = entry.end - entry.start + 2 + b.extend(encode_exception_table_varint(length // 2)) + b.extend(encode_exception_table_varint(entry.target // 2)) + dl = (entry.depth << 1) + entry.lasti + b.extend(encode_exception_table_varint(dl)) + return bytes(b) + + +def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, bytes]: + """Do the opposite of dis.get_instructions()""" + code: list[int] = [] + if sys.version_info >= (3, 11): + lnotab, update_lineno = linetable_311_writer(firstlineno) + num_ext = 0 + for i, inst in enumerate(instructions): + if inst.opname == "EXTENDED_ARG": + inst_size = 1 + num_ext += 1 + # copy positions from the actual instruction + for j in (1, 2, 3): + if instructions[i + j].opname != "EXTENDED_ARG": + inst.positions = instructions[i + j].positions + break + else: + inst_size = instruction_size(inst) // 2 + num_ext + num_ext = 0 + update_lineno(inst.positions, inst_size) + num_ext = 0 + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + for _ in range(instruction_size(inst) // 2 - 1): + code.extend((0, 0)) + else: + if sys.version_info < (3, 10): + lnotab, update_lineno = lnotab_writer(firstlineno) + else: + lnotab, update_lineno, end = linetable_310_writer(firstlineno) + + for inst in instructions: + if inst.starts_line is not None: + update_lineno(inst.starts_line, len(code)) + arg = inst.arg or 0 + code.extend((inst.opcode, arg & 0xFF)) + + if sys.version_info >= (3, 10): + end(len(code)) + + return bytes(code), bytes(lnotab) + + +def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int): + """ + Get the instruction located at a given offset, accounting for EXTENDED_ARGs + """ + for n in (0, 2, 4, 6): + if offset_to_inst[offset + n].opcode != dis.EXTENDED_ARG: + return offset_to_inst[offset + n] + return None + + +def virtualize_jumps(instructions) -> None: + """Replace jump targets with pointers to make editing easier""" + jump_targets = {inst.offset: inst for inst in instructions} + + for inst in instructions: + if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: + inst.target = _get_instruction_by_offset(jump_targets, inst.argval) + + +_REL_JUMPS = set(dis.hasjrel) + + +def flip_jump_direction(instruction: Instruction) -> None: + if sys.version_info < (3, 11): + raise RuntimeError("Cannot flip jump direction in Python < 3.11") + if "FORWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("FORWARD", "BACKWARD") + elif "BACKWARD" in instruction.opname: + instruction.opname = instruction.opname.replace("BACKWARD", "FORWARD") + else: + raise AttributeError("Instruction is not a forward or backward jump") + instruction.opcode = dis.opmap[instruction.opname] + assert instruction.opcode in _REL_JUMPS + + +def _get_instruction_front(instructions: list[Instruction], idx: int): + """ + i.e. get the first EXTENDED_ARG instruction (if any) when targeting + instructions[idx] with a jump. 
+ """ + target = instructions[idx] + for offset in (1, 2, 3): + if idx >= offset and instructions[idx - offset].opcode == dis.EXTENDED_ARG: + target = instructions[idx - offset] + else: + break + return target + + +def devirtualize_jumps(instructions): + """Fill in args for virtualized jump target after instructions may have moved""" + jumps = set(dis.hasjabs).union(set(dis.hasjrel)) + + # check for negative jump args and fix them + for inst in instructions: + if inst.opcode in jumps: + if inst.opcode not in dis.hasjabs: + if inst.target.offset < inst.offset: + if sys.version_info < (3, 11): + raise RuntimeError("Got negative jump offset for Python < 3.11") + # forward jumps become backward + if "FORWARD" in inst.opname: + flip_jump_direction(inst) + else: + # backward jumps become forward + if sys.version_info >= (3, 11) and "BACKWARD" in inst.opname: + flip_jump_direction(inst) + + # jump instruction size may have changed due to flips + update_offsets(instructions) + indexof = get_indexof(instructions) + + # compute jump instruction arg + for inst in instructions: + if inst.opcode in jumps: + target = _get_instruction_front(instructions, indexof[inst.target]) + if inst.opcode in dis.hasjabs: + if sys.version_info < (3, 10): + inst.arg = target.offset + elif sys.version_info < (3, 11): + # `arg` is expected to be bytecode offset, whereas `offset` is byte offset. + # Divide since bytecode is 2 bytes large. + inst.arg = int(target.offset / 2) + else: + raise RuntimeError("Python 3.11+ should not have absolute jumps") + else: # relative jump + # byte offset between target and next instruction + inst.arg = abs( + int(target.offset - inst.offset - instruction_size(inst)) + ) + if sys.version_info >= (3, 10): + # see bytecode size comment in the absolute jump case above + inst.arg //= 2 + inst.argval = target.offset + inst.argrepr = f"to {target.offset}" + + +def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]): + """Replace exception table entries with pointers to make editing easier""" + exn_tab = parse_exception_table(exn_tab_bytes) + offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} + offsets = sorted(offset_to_inst.keys()) + end_offset_idx = 0 + exn_tab_iter = iter(exn_tab) + try: + + def step(): + nonlocal end_offset_idx + entry = next(exn_tab_iter) + # find rightmost offset <= entry.end, since entry.end may not be + # an actual instruction, e.g. if the end instruction is LOAD_GLOBAL, + # which takes more than 2 bytes, then entry.end points to the end + # of the LOAD_GLOBAL instruction, not the beginning. 
+ while ( + end_offset_idx < len(offsets) and offsets[end_offset_idx] <= entry.end + ): + end_offset_idx += 1 + assert end_offset_idx > 0 + end_offset = offsets[end_offset_idx - 1] + inst_entry = InstructionExnTabEntry( + _get_instruction_by_offset(offset_to_inst, entry.start), + _get_instruction_by_offset(offset_to_inst, end_offset), + _get_instruction_by_offset(offset_to_inst, entry.target), + entry.depth, + entry.lasti, + ) + return entry, inst_entry + + entry, inst_entry = step() + for inst in instructions: + while inst.offset > entry.end: + entry, inst_entry = step() + if inst.offset >= entry.start: + inst.exn_tab_entry = copy.copy(inst_entry) + except StopIteration: + pass + + +def compute_exception_table( + instructions: list[Instruction], +) -> list[ExceptionTableEntry]: + """Compute exception table in list format from instructions with exn_tab_entries""" + exn_dict: dict[tuple[int, int], tuple[int, int, bool]] = {} + indexof = get_indexof(instructions) + + for inst in instructions: + if inst.exn_tab_entry: + # account for prefixed EXTENDED_ARGS + start = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.start] + ).offset + # point to the last 2 bytes of the end instruction + end = ( + cast(int, inst.exn_tab_entry.end.offset) + + instruction_size(inst.exn_tab_entry.end) + - 2 + ) + target = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.target] + ).offset + key = (start, end) + val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) + if key in exn_dict: + assert exn_dict[key] == val + exn_dict[key] = val + + # Dynamo may construct nested exception table entries for convenience, + # but Python expects exception table entries to not overlap. + # NOTE: below, "keys" refer to old instruction entries' starts and ends, + # and "entries" refer to the generated exception table entries. + + # Sort keys by increasing start, then decreasing end + keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1])) + # smallest byte that the next exception table entry can start at + nexti = 0 + # stack of current nested keys + key_stack: list[tuple[int, int]] = [] + exn_tab: list[ExceptionTableEntry] = [] + + def pop(): + """ + Pop the key_stack and append an exception table entry if possible. + """ + nonlocal nexti + if key_stack: + key = key_stack.pop() + if nexti <= key[1]: + exn_tab.append( + ExceptionTableEntry(max(key[0], nexti), key[1], *exn_dict[key]) + ) + nexti = key[1] + 2 + + for key in keys_sorted: + # pop keys that are no longer nested over the current key + while key_stack and key_stack[-1][1] < key[0]: + pop() + if key_stack: + # create an entry covering to the current key, if possible + assert key_stack[-1][0] <= key[0] <= key[1] <= key_stack[-1][1] + left = max(nexti, key_stack[-1][0]) + if left < key[0]: + exn_tab.append( + ExceptionTableEntry(left, key[0] - 2, *exn_dict[key_stack[-1]]) + ) + nexti = key[0] + key_stack.append(key) + while key_stack: + pop() + check_exception_table(exn_tab) + return exn_tab + + +def check_inst_exn_tab_entries_nested( + tab: list[InstructionExnTabEntry], indexof +) -> None: + """ + Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, + i.e. no entries partially overlap. + "Properly sorted" means entries are sorted by increasing starts, then + decreasing ends. 
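+
+    Example (as (start_index, end_index) pairs): [(0, 10), (0, 6), (2, 4), (7, 9)]
+    is properly sorted and nested; [(0, 5), (3, 8)] is not, because the two
+    entries partially overlap.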
+ """ + entry_stack: list[tuple[int, int]] = [] + for entry in tab: + key = (indexof[entry.start], indexof[entry.end]) + while entry_stack and entry_stack[-1][1] < key[0]: + entry_stack.pop() + if entry_stack: + assert entry_stack[-1][0] <= key[0] <= key[1] <= entry_stack[-1][1] + entry_stack.append(key) + + +def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None: + """ + Copies exception table entries to all instructions in an entry's range. + Supports nested exception table entries. + """ + indexof = get_indexof(instructions) + entries: dict[tuple[int, int], InstructionExnTabEntry] = {} + for inst in instructions: + if inst.exn_tab_entry: + key = ( + indexof[inst.exn_tab_entry.start], + indexof[inst.exn_tab_entry.end], + ) + if key in entries: + assert inst.exn_tab_entry == entries[key] + entries[key] = inst.exn_tab_entry + sorted_entries = [ + entries[key] for key in sorted(entries.keys(), key=lambda t: (t[0], -t[1])) + ] + check_inst_exn_tab_entries_nested(sorted_entries, indexof) + # Propagation of nested entries works since nested entries come later + # in sorted order. + for entry in sorted_entries: + for i in range(indexof[entry.start], indexof[entry.end] + 1): + instructions[i].exn_tab_entry = copy.copy(entry) + + +def check_inst_exn_tab_entries_valid(instructions: list[Instruction]): + """ + Checks that exn_tab_entries of instructions are valid. + An entry's start, end, and target must be in instructions. + Instructions with an exn_tab_entry are located within + the entry's start and end instructions. + Instructions do not share exn_tab_entries. + + Implicitly checks for no duplicate instructions. + """ + indexof = get_indexof(instructions) + exn_tab_entry_set = set() + for i, inst in enumerate(instructions): + if inst.exn_tab_entry: + assert sys.version_info >= (3, 11) + assert id(inst.exn_tab_entry) not in exn_tab_entry_set + exn_tab_entry_set.add(id(inst.exn_tab_entry)) + entry = inst.exn_tab_entry + assert entry.start in indexof + assert entry.end in indexof + assert entry.target in indexof + assert indexof[entry.start] <= i <= indexof[entry.end] + + +def strip_extended_args(instructions: list[Instruction]) -> None: + instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG] + + +# Overwrites old_inst with a sequence of new instructions. +# This is necessary in order to preserve jump targets to the old +# instruction, exception table entries, and positions. +# Returns the modified sequence of instructions (including the modified +# old instruction!) that can be manipulated elsewhere. 
+def overwrite_instruction(old_inst, new_insts): + # update old_inst.exnt_tab_entry.end if necessary + if ( + old_inst.exn_tab_entry + and old_inst.exn_tab_entry.end is old_inst + and len(new_insts) > 1 + ): + old_inst.exn_tab_entry.end = new_insts[-1] + # preserve exception table entries and positions + for inst in new_insts[1:]: + inst.exn_tab_entry = copy.copy(old_inst.exn_tab_entry) + inst.positions = old_inst.positions + # modify old_inst in-place to preserve jump target + old_inst.opcode = new_insts[0].opcode + old_inst.opname = new_insts[0].opname + old_inst.arg = new_insts[0].arg + old_inst.argval = new_insts[0].argval + old_inst.target = new_insts[0].target + return [old_inst] + new_insts[1:] + + +def remove_load_call_method(instructions: list[Instruction]) -> list[Instruction]: + """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it""" + assert sys.version_info < (3, 11) + rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"} + for inst in instructions: + if inst.opname in rewrites: + inst.opname = rewrites[inst.opname] + inst.opcode = dis.opmap[inst.opname] + return instructions + + +def remove_jump_if_none(instructions: list[Instruction]) -> None: + new_insts = [] + for inst in instructions: + if "_NONE" in inst.opname: + is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname)) + # need both argval and arg set correctly now (not later) + is_op.argval = is_op.arg + + if sys.version_info < (3, 12): + jump_op = create_instruction( + ( + "POP_JUMP_FORWARD_IF_TRUE" + if "FORWARD" in inst.opname + else "POP_JUMP_BACKWARD_IF_TRUE" + ), + target=inst.target, + ) + else: + jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target) + + replace_insts = [ + create_instruction("LOAD_CONST", argval=None), + is_op, + jump_op, + ] + new_insts.extend(overwrite_instruction(inst, replace_insts)) + else: + new_insts.append(inst) + instructions[:] = new_insts + + +def remove_binary_store_slice(instructions: list[Instruction]) -> None: + new_insts = [] + for inst in instructions: + new_insts.append(inst) + if inst.opname in ("BINARY_SLICE", "STORE_SLICE"): + # new instruction + subscr_inst = create_instruction(inst.opname.replace("SLICE", "SUBSCR")) + if inst.exn_tab_entry and inst.exn_tab_entry.end is inst: + inst.exn_tab_entry.end = subscr_inst + subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry) + subscr_inst.positions = inst.positions + # modify inst in-place to preserve jump target + inst.opcode = dis.opmap["BUILD_SLICE"] + inst.opname = "BUILD_SLICE" + inst.arg = 2 + inst.argval = 2 + new_insts.append(subscr_inst) + instructions[:] = new_insts + + +FUSED_INSTS = { + "LOAD_FAST_LOAD_FAST": ("LOAD_FAST", "LOAD_FAST"), + "STORE_FAST_STORE_FAST": ("STORE_FAST", "STORE_FAST"), + "STORE_FAST_LOAD_FAST": ("STORE_FAST", "LOAD_FAST"), +} + + +def remove_fused_load_store(instructions: list[Instruction]) -> None: + new_insts = [] + for inst in instructions: + if inst.opname in FUSED_INSTS: + inst0, inst1 = FUSED_INSTS[inst.opname] + argval0, argval1 = inst.argval + + replace_insts = [ + create_instruction(inst0, argval=argval0), + create_instruction(inst1, argval=argval1), + ] + new_insts.extend(overwrite_instruction(inst, replace_insts)) + else: + new_insts.append(inst) + instructions[:] = new_insts + + +def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None: + """convert super() with no args into explicit arg form""" + cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ()) + if not 
len(code.co_varnames): + # A function with no argument cannot contain a valid "super()" call + return + output = [] + for idx, inst in enumerate(instructions): + output.append(inst) + if inst.opname == "LOAD_GLOBAL" and inst.argval == "super": + nexti = instructions[idx + 1] + if nexti.arg == 0 and ( + (sys.version_info >= (3, 12) and nexti.opname == "CALL") + or ( + sys.version_info >= (3, 11) + and sys.version_info < (3, 12) + and nexti.opname == "PRECALL" + ) + or (sys.version_info < (3, 11) and nexti.opname == "CALL_FUNCTION") + ): + assert "__class__" in cell_and_free + output.append(create_instruction("LOAD_DEREF", argval="__class__")) + first_var = code.co_varnames[0] + if first_var in cell_and_free: + output.append(create_instruction("LOAD_DEREF", argval=first_var)) + else: + output.append(create_instruction("LOAD_FAST", argval=first_var)) + nexti.arg = 2 + nexti.argval = 2 + if nexti.opname == "PRECALL": + # also update the following CALL instruction + call_inst = instructions[idx + 2] + call_inst.arg = 2 + call_inst.argval = 2 + + instructions[:] = output + + +def fix_extended_args(instructions: list[Instruction]) -> int: + """Fill in correct argvals for EXTENDED_ARG ops""" + output: list[Instruction] = [] + + def maybe_pop_n(n): + for _ in range(n): + if output and output[-1].opcode == dis.EXTENDED_ARG: + output.pop() + + for inst in instructions: + if inst.opcode == dis.EXTENDED_ARG: + # Leave this instruction alone for now so we never shrink code + inst.arg = 0 + elif inst.arg and inst.arg > 0xFFFFFF: + maybe_pop_n(3) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 24)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFFFF: + maybe_pop_n(2) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + elif inst.arg and inst.arg > 0xFF: + maybe_pop_n(1) + output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) + output.append(inst) + + added = len(output) - len(instructions) + assert added >= 0 + instructions[:] = output + return added + + +def instruction_size(inst) -> int: + import torch + + if sys.version_info >= (3, 11): + return 2 * (torch._C._dynamo.eval_frame.py_opcode_caches[inst.opcode] + 1) + return 2 + + +def check_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + assert inst.offset == offset + offset += instruction_size(inst) + + +def update_offsets(instructions) -> None: + offset = 0 + for inst in instructions: + inst.offset = offset + offset += instruction_size(inst) + + +def debug_bytes(*args) -> str: + index = range(max(map(len, args))) + result = [ + " ".join(f"{x:03}" for x in arg) + for arg in [index] + + list(args) + + [[int(a != b) for a, b in zip(args[-1], args[-2])]] + ] + + return "bytes mismatch\n" + "\n".join(result) + + +def debug_checks(code): + """Make sure our assembler produces same bytes as we start with""" + dode = transform_code_object(code, lambda x, y: None, safe=True) + assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) + assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab) + + +HAS_LOCAL = set(dis.haslocal) +HAS_NAME = set(dis.hasname) +HAS_FREE = set(dis.hasfree) +HAS_CONST = set(dis.hasconst) + + +def get_const_index(code_options, val) -> int: + for i, v in enumerate(code_options["co_consts"]): + # NOTE: stronger 
comparison is required, since we have + # examples where two values compare equal but have + # different semantic meaning in some cases, e.g. + # 0.0 == -0.0 but have different effects in torch.copysign. + if val is v: + return i + code_options["co_consts"] += (val,) + return len(code_options["co_consts"]) - 1 + + +def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None): + # compute instruction arg from argval if arg is not provided + names = {name: idx for idx, name in enumerate(code_options["co_names"])} + + def get_name_index(name) -> int: + try: + idx = names[name] + except KeyError: + # Add a missing item to co_names + idx = names[name] = len(names) + code_options["co_names"] = (*code_options["co_names"], name) + assert len(code_options["co_names"]) == len(names) + return idx + + if sys.version_info < (3, 11): + assert varname_from_oparg is None + varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])} + freenames = { + name: idx + for idx, name in enumerate( + code_options["co_cellvars"] + code_options["co_freevars"] + ) + } + else: + assert callable(varname_from_oparg) + allnames = {} + for idx in itertools.count(): + try: + name = varname_from_oparg(idx) + allnames[name] = idx + except IndexError: + break + varnames = {name: allnames[name] for name in code_options["co_varnames"]} + freenames = { + name: allnames[name] + for name in code_options["co_cellvars"] + code_options["co_freevars"] + } + for i in range(len(instructions)): + + def should_compute_arg(): + # argval is prioritized over arg + return instructions[i].argval is not _NotProvided + + if instructions[i].opname == "LOAD_GLOBAL": + # 3.11 LOAD_GLOBAL requires both arg and argval - see create_instruction + assert instructions[i].argval is not _NotProvided + if sys.version_info >= (3, 11): + assert instructions[i].arg is not None + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( + cast(int, instructions[i].arg) % 2 + ) + else: + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opname == "LOAD_ATTR": + # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL + assert instructions[i].argval is not _NotProvided + if sys.version_info >= (3, 12): + assert instructions[i].arg is not None + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( + cast(int, instructions[i].arg) % 2 + ) + else: + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opname == "LOAD_SUPER_ATTR": + assert instructions[i].arg is not None + assert instructions[i].argval is not _NotProvided + # Copy low bit, force second bit on for explicit super (the "+ 2") + instructions[i].arg = ( + (get_name_index(instructions[i].argval) << 2) + + (cast(int, instructions[i].arg) % 2) + + 2 + ) + elif instructions[i].opname in FUSED_INSTS: + assert sys.version_info >= (3, 13) + assert isinstance(instructions[i].argval, tuple) + assert len(instructions[i].argval) == 2 + arg_tuple = tuple( + varnames[name] if name in varnames else freenames[name] + for name in instructions[i].argval + ) + instructions[i].arg = (arg_tuple[0] << 4) + (arg_tuple[1] & 15) + elif instructions[i].opcode in HAS_LOCAL: + if should_compute_arg(): + if ( + sys.version_info >= (3, 13) + and instructions[i].argval not in varnames + ): + # instructions like LOAD_FAST used for both local and free vars + instructions[i].arg = freenames[instructions[i].argval] + else: + instructions[i].arg = varnames[instructions[i].argval] + elif 
instructions[i].opcode in HAS_NAME: + if should_compute_arg(): + instructions[i].arg = get_name_index(instructions[i].argval) + elif instructions[i].opcode in HAS_FREE: + if should_compute_arg(): + instructions[i].arg = freenames[instructions[i].argval] + elif instructions[i].opcode in HAS_CONST: + # NOTE: only update argval if arg is not provided. This assumes + # that any additions to co_consts are appended. + if instructions[i].arg is None: + # cannot use a dictionary since consts may not be hashable + idx = get_const_index(code_options, instructions[i].argval) + assert idx >= 0 + instructions[i].arg = idx + + +def clear_instruction_args(instructions): + # Clear the instruction arg for instructions that have argvals. + # Useful for using dis'd bytecode within generated bytecode. + for inst in instructions: + if ( + inst.argval is not _NotProvided + and ( + inst.opcode in HAS_LOCAL + or inst.opcode in HAS_NAME + or inst.opcode in HAS_FREE + or inst.opcode in HAS_CONST + ) + and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") + ): + inst.arg = None + + +@functools.lru_cache +def get_code_keys() -> list[str]: + # Python 3.11 changes to code keys are not fully documented. + # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 + # for new format. + keys = ["co_argcount"] + keys.append("co_posonlyargcount") + keys.extend( + [ + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + ] + ) + if sys.version_info >= (3, 11): + keys.append("co_qualname") + keys.append("co_firstlineno") + if sys.version_info >= (3, 10): + keys.append("co_linetable") + else: + keys.append("co_lnotab") + if sys.version_info >= (3, 11): + # not documented, but introduced in https://github.com/python/cpython/issues/84403 + keys.append("co_exceptiontable") + keys.extend( + [ + "co_freevars", + "co_cellvars", + ] + ) + return keys + + +def transform_code_object(code, transformations, safe=False) -> types.CodeType: + keys = get_code_keys() + code_options = {k: getattr(code, k) for k in keys} + assert len(code_options["co_varnames"]) == code_options["co_nlocals"] + + instructions = cleaned_instructions(code, safe) + propagate_line_nums(instructions) + + transformations(instructions, code_options) + return clean_and_assemble_instructions(instructions, keys, code_options)[1] + + +def clean_and_assemble_instructions( + instructions: list[Instruction], keys: list[str], code_options: dict[str, Any] +) -> tuple[list[Instruction], types.CodeType]: + # also implicitly checks for no duplicate instructions + check_inst_exn_tab_entries_valid(instructions) + + code_options["co_nlocals"] = len(code_options["co_varnames"]) + varname_from_oparg = None + if sys.version_info >= (3, 11): + # temporary code object with updated names + tmp_code = types.CodeType(*[code_options[k] for k in keys]) + varname_from_oparg = tmp_code._varname_from_oparg # type: ignore[attr-defined] + fix_vars(instructions, code_options, varname_from_oparg=varname_from_oparg) + + dirty = True + while dirty: + update_offsets(instructions) + devirtualize_jumps(instructions) + # this pass might change offsets, if so we need to try again + dirty = bool(fix_extended_args(instructions)) + + remove_extra_line_nums(instructions) + bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) + if sys.version_info < (3, 10): + code_options["co_lnotab"] = lnotab + else: + code_options["co_linetable"] = lnotab + + 
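+    # Offsets are stable at this point (the loop above re-ran until
+    # fix_extended_args stopped inserting instructions), so the assembled
+    # bytecode, the computed stack size, and (on 3.11+) the exception table
+    # can be written back into code_options before building the new code object.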
code_options["co_code"] = bytecode + code_options["co_stacksize"] = stacksize_analysis(instructions) + assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { + "co_posonlyargcount" + } + if sys.version_info >= (3, 11): + code_options["co_exceptiontable"] = assemble_exception_table( + compute_exception_table(instructions) + ) + + return instructions, types.CodeType(*[code_options[k] for k in keys]) + + +def populate_kw_names_argval(instructions, consts): + for inst in instructions: + if inst.opname == "KW_NAMES": + inst.argval = consts[inst.arg] + + +# If safe=True, we do not make any bytecode modifications. +# Mainly used for debugging bytecode_transformation (see debug_checks) +def cleaned_instructions(code, safe=False) -> list[Instruction]: + instructions = _cached_cleaned_instructions(code, safe) + # We have a lot of code that implicitly mutates the instruction array. We + # could do better here by making the copies explicit when necessary. + return _clone_instructions(instructions) + + +# Copy an instructions array, making sure to remap the individual instruction targets. +def _clone_instructions(instructions): + # This is super hot and this is the fastest way to do this (tried copy.copy + # and dataclasses.replace). + copied = [ + Instruction( + i.opcode, + i.opname, + i.arg, + i.argval, + i.offset, + i.starts_line, + i.is_jump_target, + i.positions, + i.target, + i.exn_tab_entry, + i.argrepr, + ) + for i in instructions + ] + + remap = dict(zip(instructions, copied)) + # Handle `None` in the remapper so we don't need an extra `if`. + remap[None] = None + + for i in copied: + i.target = remap[i.target] + if entry := i.exn_tab_entry: + i.exn_tab_entry = InstructionExnTabEntry( + remap[entry.start], + remap[entry.end], + remap[entry.target], + entry.depth, + entry.lasti, + ) + return copied + + +@functools.lru_cache +def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: + instructions = list(map(convert_instruction, dis.get_instructions(code))) + check_offsets(instructions) + if sys.version_info >= (3, 11): + populate_kw_names_argval(instructions, code.co_consts) + virtualize_exception_table(code.co_exceptiontable, instructions) + virtualize_jumps(instructions) + strip_extended_args(instructions) + if not safe: + if sys.version_info < (3, 11): + remove_load_call_method(instructions) + if sys.version_info < (3, 12): + explicit_super(code, instructions) + if sys.version_info >= (3, 11): + remove_jump_if_none(instructions) + if sys.version_info >= (3, 12): + remove_binary_store_slice(instructions) + if sys.version_info >= (3, 13): + remove_fused_load_store(instructions) + if sys.version_info >= (3, 11): + update_offsets(instructions) + devirtualize_jumps(instructions) + return instructions + + +_unique_id_counter = itertools.count() + + +def unique_id(name, with_uuid=False) -> str: + ret = f"{name}_{next(_unique_id_counter)}" + if with_uuid: + ret += f"_{uuid.uuid4()}".replace("-", "_") + return ret + + +def is_generator(code: types.CodeType) -> bool: + co_generator = 0x20 + return (code.co_flags & co_generator) > 0 + + +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): + """Generates bytecode from a template function `fn` for use in + dynamo bytecode generation. + + For example, we can generate Python-version-independent bytecode + for looping through a dictionary and copying the values to a new dictionary. 
+ + def template(d1, d2): + for k, v in d1.items(): + d2[k] = v + + + or a try block: + + def template(): + try: + dummy1 + except: + dummy2 + raise + dummy3 + + Args: + fn: a function template to generate bytecode from + varname_map: a mapping of `fn`'s varnames to new names. This + map will be applied to the generated bytecode's varnames. + For example, local variables in `fn` can be replaced with + new names that are generated by `OutputGraph.new_var`. + noreturn: remove all RETURN_* bytecodes and replace them with a jump + to the end of the bytecode. NOTE: any items pushed to the stack + for return WILL remain on the stack! Append a POP_TOP if you don't want + that item to be present. + noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). + """ + insts = cleaned_instructions(fn.__code__) + clear_instruction_args(insts) + + if noprefix: + for i, inst in enumerate(insts): + if inst.opname == "RESUME": + insts = insts[i + 1 :] + break + + for inst in insts: + # If we don't reset starts_line, then the generated + # bytecode's line number will be based on fn's. + inst.starts_line = None + inst.positions = None + if varname_map and inst.argval in varname_map: + inst.argval = varname_map[inst.argval] + + if noreturn: + if sys.version_info >= (3, 12): + # replace RETURN_CONST with LOAD_CONST RETURN_VALUE + new_insts = [] + for inst in insts: + if inst.opname == "RETURN_CONST": + inst.opcode = dis.opmap["LOAD_CONST"] + inst.opname = "LOAD_CONST" + new_insts.append(inst) + # no need to propagate target/exn table + new_insts.append(create_instruction("RETURN_VALUE")) + else: + new_insts.append(inst) + insts = new_insts + + returns = [] + for inst in insts: + if inst.opname == "RETURN_VALUE": + returns.append(inst) + + if len(returns) == 1 and returns[0] is insts[-1]: + # only 1 return at the end - just pop it + insts.pop(-1) + elif len(returns) > 0: + # create jump target - if the last inst is a return, + # we can replace it with a NOP and make that the jump target. + if insts[-1] is returns[-1]: + insts[-1].opname = "NOP" + insts[-1].opcode = dis.opmap["NOP"] + insts[-1].arg = None + insts[-1].argval = _NotProvided + returns.pop(-1) + else: + insts.append(create_instruction("NOP")) + + # replace returns with jumps + for inst in returns: + # don't replace inst with new instruction + # due to targeting/exn table/etc. + jump_inst = create_jump_absolute(insts[-1]) + inst.opname = jump_inst.opname + inst.opcode = jump_inst.opcode + inst.arg = jump_inst.arg + inst.argval = jump_inst.argval + inst.target = jump_inst.target + + return insts diff --git a/phivenv/Lib/site-packages/torch/_dynamo/cache_size.py b/phivenv/Lib/site-packages/torch/_dynamo/cache_size.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed831a97b8c73e6e5001d19d4b6aede15a7a411 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/cache_size.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +import logging +import weakref +from dataclasses import dataclass + +from torch._guards import CompileId + +from . import config +from .types import DynamoFrameType + + +log = logging.getLogger(__name__) +""" +[Note on cache size limit] + +Background - TorchDynamo cache is a linked list. Each cache entry is a +(guard_manager, out_code, next pointer). These are stored on the f_code's co_extra +scratch space. When a frame is invoked, we walk this linked list and run +guard_manager in each cache_entry to decide if the frame needs recompilation. 
If none
+of the guard_managers returns True, we recompile and add a new entry. To ensure we
+don't end up recompiling infinitely, we put limits on the cache size.
+
+There are two limits:
+1) recompile_limit
+2) accumulated_recompile_limit
+
+Earlier we had only one limit - the maximum number of entries in one cache line
+(now represented by (2) above). So why do we need two limits? In general, we
+want the cache limit to be a small number (e.g. 8 or even lower), so that
+frames which recompile too often fall back to eager quickly. However, there is
+another problem that prevents us from lowering recompile_limit: ID_MATCH'd
+guards. Today, we put ID_MATCH guards on nn modules if there is a graph break.
+This means the same code object recompiles many times, because the ID_MATCH
+guard fails for different instances of the nn module - a common pattern in how
+models are authored. This forces us to keep recompile_limit high.
+
+We resolve this by introducing the two limits. The first limit (1) caps the
+number of cache entries that have an ID_MATCH'd guard for a given nn module
+instance. The second limit (2) is a safeguard on the maximum number of
+compilations for a code object. What is the limit for a code object that has no
+ID_MATCH guard? For such code objects, we use (1) as the cache size limit.
+
+Let's take an example to see how these limits help. Suppose we have 16
+instances of an nn module and we ID_MATCH on the self object. Further, suppose
+the inputs to these functions have varying batch sizes, each leading to one
+recompilation. In total there will be 32 recompilations, and therefore 32 cache
+entries on the forward code object. With only the old single limit, the cache
+size limit would have to be >= 32 to capture all of these recompilations. Now
+suppose there is a separate function in the same program which is very dynamic
+and unsuitable for compilation. Such a function would need to undergo 32
+compilations to burst the cache and fall back to eager. That is far too many -
+we want compilation-unfriendly functions to fall back sooner.
+
+With the new scheme, we can set (1) recompile_limit = 2 and (2)
+accumulated_recompile_limit = 32. Each ID_MATCH'd object can then have at most
+two cache entries, and the maximum number of cache entries (irrespective of the
+ID_MATCH'd object) is 32. This covers the forward code object with its 32
+recompilations, while the other, compilation-unfriendly function bursts its
+cache after just 2 recompilations. In this manner, the two limits resolve the
+tension described above.
+"""
+
+
+@dataclass
+class CacheSizeRelevantForFrame:
+    """
+    We track the number of cache entries that have the same id_match objects as
+    the given frame.
+
+    TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count -
+    https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this
+    could be useful for debugging as well.
+    """
+
+    # Total number of CacheEntry objects in the Dynamo linked list
+    num_cache_entries: int = 0
+
+    # Number of CacheEntry objects having the same ID_MATCH'd objects as the given frame.
+ num_cache_entries_with_same_id_matched_objs: int = 0 + + def will_compilation_exceed(self, limit: int) -> bool: + # Checks if a compilation will exceed the given limit (that's why >=). + return ( + self.will_compilation_exceed_accumulated_limit() + or self.will_compilation_exceed_specific_limit(limit) + ) + + def will_compilation_exceed_accumulated_limit(self) -> bool: + return self.num_cache_entries >= config.accumulated_recompile_limit + + def will_compilation_exceed_specific_limit(self, limit: int) -> bool: + return self.num_cache_entries_with_same_id_matched_objs >= limit + + +def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): + obj = frame.f_locals.get(local_name, None) + weak_id = None + try: + weak_id = weakref.ref(obj) + except TypeError: + pass # cannot weakref bool object + return weak_id + + +def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: + """ + Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones + in frame.f_locals. + """ + if not cache_entry: + return False + + for ( + local_name, + weakref_from_cache_entry, + ) in cache_entry.guard_manager.id_matched_objs.items(): + if weakref_from_cache_entry() is not None: + weakref_from_frame = _get_weakref_from_f_locals(frame, local_name) + if weakref_from_frame is not weakref_from_cache_entry: + return False + + # Also covers the case where no ID_MATCH objects are saved in frame.f_locals + return True + + +def compute_cache_size( + frame: DynamoFrameType, cache_entry +) -> CacheSizeRelevantForFrame: + # Walk the linked list to calculate the cache size + num_cache_entries = 0 + num_cache_entries_with_same_id_matched_objs = 0 + + while cache_entry: + num_cache_entries += 1 + # Track the number of cache entries having same ID_MATCH'd objects as + # that of frame.f_locals. This will be used later to compare against the + # recompile_limit. + if _has_same_id_matched_objs(frame, cache_entry): + num_cache_entries_with_same_id_matched_objs += 1 + cache_entry = cache_entry.next + + return CacheSizeRelevantForFrame( + num_cache_entries, num_cache_entries_with_same_id_matched_objs + ) + + +def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool: + """ + If the frame (earlier parsed by compute_cache_size) has more than 1 cache + entry with same ID_MATCH'd objects, then its a recompilation. + """ + # Note that you can have multiple entries in the cache but still not a + # recompile, e.g., you can have 64 nn module instances, each one having an + # ID_MATCH guard, and each one having just 1 cache entry in the cache. In + # this case, we can have 64 entries in the cache, but no recompilation + # because there is only one entry for each id_matched_obj. + return cache_size.will_compilation_exceed(1) + + +def exceeds_recompile_limit( + cache_size: CacheSizeRelevantForFrame, compile_id: CompileId +) -> tuple[bool, str]: + """ + Checks if we are exceeding the cache size limit. + """ + if cache_size.will_compilation_exceed_accumulated_limit(): + return True, "accumulated_recompile_limit" + if cache_size.will_compilation_exceed_specific_limit(config.recompile_limit): + return True, "recompile_limit" + # NOTE this check is needed in the case that the frame's cache doesn't grow + # and we keep recompiling. This can happen if the guard guard_manager becomes invalidated, + # e.g. due to guarded objects being freed. 
This technically makes the + # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the + # check in case we have a better fix in the future. + assert compile_id.frame_compile_id is not None + if compile_id.frame_compile_id >= config.accumulated_recompile_limit: + return True, "accumulated_recompile_limit" + return False, "" diff --git a/phivenv/Lib/site-packages/torch/_dynamo/callback.py b/phivenv/Lib/site-packages/torch/_dynamo/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..6922f05b93950b189420a9c316ccd9f4f059480f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/callback.py @@ -0,0 +1,171 @@ +""" +This module provides callback management functionality for TorchDynamo's compilation process. + +It implements a thread-safe system for registering, managing and executing callbacks that run +at the start and end of TorchDynamo compilations. Key features include: + +- Registration and deregistration of compilation callbacks +- Thread-safe callback handling with proper locking mechanisms +- Prevention of duplicate callback execution when configured +- Decorator utilities for easy callback registration +- Context manager for controlled callback lifecycle + +The module centers around the CompilationCallbackHandler class which maintains separate +lists for start and end callbacks, manages their execution order, and ensures thread-safety. +Utility decorators @on_compile_start and @on_compile_end provide a convenient way to +register compilation hooks. + +Example usage: + @on_compile_start + def my_start_callback(): + print("Starting compilation") + + @on_compile_end + def my_end_callback(): + print("Compilation complete") +""" + +import enum +import threading +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass, field # noqa: F811 +from typing import Any, Callable + + +class CallbackTrigger(enum.Enum): + # most common case, dynamo attempts to trace a new frame + DYNAMO = 1 + # backward compilation can be deferred to runtime + LAZY_BACKWARD = 2 + # some backends autotune at runtime + TRITON_AUTOTUNING = 3 + # cudagraphs record at runtime + CUDAGRAPH_RECORDING = 4 + + +@dataclass +class CallbackArgs: + callback_trigger: CallbackTrigger + compile_id: str + + +@dataclass +class CompilationCallbackHandler: + start_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list) + end_callbacks: list[Callable[[CallbackArgs], None]] = field(default_factory=list) + + __pending_callbacks_counter: int = field(default=0, init=False, repr=False) + __pending_callbacks_counter_lock: threading.Lock = field( + default_factory=threading.Lock, init=False, repr=False + ) + + def register_start_callback( + self, callback: Callable[[CallbackArgs], None] + ) -> Callable[[CallbackArgs], None]: + """ + Register a callback function to be called when the compilation starts. + + Args: + - callback (Callable): The callback function to register. + """ + self.start_callbacks.append(callback) + return callback + + def register_end_callback( + self, callback: Callable[[CallbackArgs], None] + ) -> Callable[[CallbackArgs], None]: + """ + Register a callback function to be called when the compilation ends. + + Args: + - callback (Callable): The callback function to register. + """ + self.end_callbacks.append(callback) + return callback + + def remove_start_callback(self, callback: Callable[[CallbackArgs], None]) -> None: + """ + Remove a registered start callback function. 
+ + Args: + - callback (Callable): The callback function to remove. + """ + self.start_callbacks.remove(callback) + + def remove_end_callback(self, callback: Callable[[CallbackArgs], None]) -> None: + """ + Remove a registered end callback function. + + Args: + - callback (Callable): The callback function to remove. + """ + self.end_callbacks.remove(callback) + + def run_start_callbacks(self, args: CallbackArgs) -> None: + """ + Execute all registered start callbacks. + """ + for callback in self.start_callbacks: + callback(args) + + def run_end_callbacks(self, args: CallbackArgs) -> None: + """ + Execute all registered end callbacks. + """ + for callback in self.end_callbacks: + callback(args) + + @contextmanager + def install_callbacks( + self, trigger: CallbackTrigger, compile_id: str + ) -> Generator[None, Any, Any]: + """ + Context manager to install the callbacks and run them when the context is exited. + """ + args = CallbackArgs(trigger, compile_id) + try: + with self.__pending_callbacks_counter_lock: + if self.__pending_callbacks_counter == 0: + self.run_start_callbacks(args) + self.__pending_callbacks_counter += 1 + yield + finally: + with self.__pending_callbacks_counter_lock: + assert self.__pending_callbacks_counter > 0, ( + "Pending callbacks counter cannot become negative." + ) + if self.__pending_callbacks_counter == 1: + self.run_end_callbacks(args) + self.__pending_callbacks_counter -= 1 + + def clear(self) -> None: + """ + Clear all registered callbacks. + """ + self.start_callbacks.clear() + self.end_callbacks.clear() + assert self.__pending_callbacks_counter == 0 + + +callback_handler = CompilationCallbackHandler() + + +def on_compile_start( + callback: Callable[[CallbackArgs], None], +) -> Callable[[CallbackArgs], None]: + """ + Decorator to register a callback function for the start of the compilation. + """ + callback_handler.register_start_callback(callback) + return callback + + +def on_compile_end( + callback: Callable[[CallbackArgs], None], +) -> Callable[[CallbackArgs], None]: + """ + Decorator to register a callback function for the end of the compilation. + """ + callback_handler.register_end_callback(callback) + return callback diff --git a/phivenv/Lib/site-packages/torch/_dynamo/code_context.py b/phivenv/Lib/site-packages/torch/_dynamo/code_context.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1d95aa8e6bac0d32f0f609cfd6351f311fca7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/code_context.py @@ -0,0 +1,60 @@ +""" +This module provides thread-safe code context management for TorchDynamo using weak references. + +The CodeContextDict class maintains a mapping between Python code objects and their associated +context data, using weak references to automatically clean up entries when code objects are +garbage collected. This prevents memory leaks while allowing context data to be associated +with code objects throughout their lifecycle. + +Key features: +- Thread-safe context storage and retrieval +- Automatic cleanup using weak references +- Safe context management for Python code objects +- Memory-leak prevention + +Example usage: + code_obj = compile('x = 1', '', 'exec') + + # Store context + context = code_context.get_context(code_obj) + context['metadata'] = {'optimized': True} + + # Retrieve context + if code_context.has_context(code_obj): + ctx = code_context.get_context(code_obj) + # Use context data... 
+ + # Remove context + ctx = code_context.pop_context(code_obj) +""" + +import types +from typing import Any + +from .utils import ExactWeakKeyDictionary + + +class CodeContextDict: + def __init__(self) -> None: + self.code_context: ExactWeakKeyDictionary = ExactWeakKeyDictionary() + + def has_context(self, code: types.CodeType) -> bool: + return code in self.code_context + + def get_context(self, code: types.CodeType) -> dict[str, Any]: + ctx = self.code_context.get(code) + if ctx is None: + ctx = {} + self.code_context[code] = ctx + return ctx + + def pop_context(self, code: types.CodeType) -> dict[str, Any]: + ctx = self.get_context(code) + self.code_context._remove_id(id(code)) + return ctx + + def clear(self) -> None: + self.code_context.clear() + + +code_context: CodeContextDict = CodeContextDict() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/codegen.py b/phivenv/Lib/site-packages/torch/_dynamo/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..50609848994b5d139a65465b440fa67a369e32b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/codegen.py @@ -0,0 +1,692 @@ +# mypy: allow-untyped-defs + +""" +This module provides utilities for generating Python bytecode in PyTorch's Dynamo system. +It includes functionality for: +- Constructing bytecode sequences for Python operations +- Managing stack operations and variable tracking +- Handling graph outputs and their conversions +- Supporting different Python versions (3.11+, 3.12+, 3.13+) +- Converting high-level operations to low-level bytecode instructions +- Managing constant loading and attribute access +- Supporting function creation and closure handling +""" + +import collections +import dataclasses +import re +import sys +import types +from collections import Counter +from typing import Optional, TYPE_CHECKING, Union + +import torch.nn +from torch.utils._ordered_set import OrderedSet + +from . 
import config, graph_break_hints, utils +from .bytecode_transformation import ( + add_push_null, + add_push_null_call_function_ex, + create_call_function, + create_call_method, + create_dup_top, + create_instruction, + create_load_const, + create_load_method, + create_rot_n, + Instruction, +) +from .exc import IncorrectUsage, unimplemented_v2 +from .source import AttrSource, ChainedSource, DictGetItemSource, Source +from .utils import is_safe_constant, rot_n_helper +from .variables.base import ValueMutationExisting, VariableTracker +from .variables.functions import ( + ContextlibContextManagerLocalGeneratorObjectVariable, + LocalGeneratorObjectVariable, +) +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .variables.torch_function import TensorWithTFOverrideVariable + + +if TYPE_CHECKING: + from .symbolic_convert import InstructionTranslatorBase + + +@dataclasses.dataclass +class GraphOutputEntry: + index: int + variable: VariableTracker + + +class PyCodegen: + """ + Helper class uses for constructing Python bytecode + """ + + def __init__( + self, + tx: "InstructionTranslatorBase", + root: Optional[torch.nn.Module] = None, + graph_output_var: Optional[str] = None, + tempvars=None, + overridden_sources=None, + ) -> None: + self.root = root + self.top_of_stack: Optional[Union[VariableTracker, Source]] = None + self.uses: Counter[Union[VariableTracker, Source]] = collections.Counter() + self.graph_outputs: dict[int, GraphOutputEntry] = {} + self._output: list[Instruction] = [] + # This determines which VariableTracker/Source should be stored as + # locals, and maps the VariableTracker/Source to the local variable + # name. Note that it could map to None initially, in which case we'll + # overwrite it to map to real temporary names via `add_cache`. + self.tempvars = tempvars or {} + self.tx = tx + self.graph_output_var = graph_output_var + self.code_options = self.tx.output.code_options + self.cell_and_freevars = self.tx.cell_and_freevars + self.new_var = self.tx.output.new_var + self.value_from_source: bool = True + # This serves as a way for codegen to use a different source; we need + # this because sometimes we can't easily modify the original source + # without affecting other components, e.g., guards. + self.overridden_sources: dict[Source, Source] = overridden_sources or {} + + def restore_stack(self, stack_values, *, value_from_source=True): + prev = self.value_from_source + self.value_from_source &= value_from_source + try: + self.foreach(stack_values) + finally: + self.value_from_source = prev + + def graph_output_vars(self): + return [x.variable for x in self.graph_outputs.values()] + + def call_reconstruct(self, value): + res = value.reconstruct(self) + assert res is None, f"reconstruct!=None {value}" + + def add_push_null(self, gen_fn, call_function_ex=False): + """ + `gen_fn` generates instructions via PyCodegen methods + that push a single callable to the stack. + + `add_push_null` pushes a NULL to the stack before or after the + instructions generated by `gen_fn`, depending on Python version. + + Will attempt to use the NULL push bit for instructions + with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR). + """ + old_len = len(self._output) + if sys.version_info < (3, 13): + # gen_fn may DUP_TOP instead if TOS is not cleared. 
+ # Will cause problems since NULL will be pushed right + # before the generated instructions in <= 3.12 + self.clear_tos() + gen_fn() + # inplace modify self._output + added_insts = self._output[old_len:] + del self._output[old_len:] + if call_function_ex: + self._output.extend(add_push_null_call_function_ex(added_insts)) + else: + self._output.extend(add_push_null(added_insts)) + if sys.version_info >= (3, 13): + # NULL will be at top of stack + self.clear_tos() + + def __call__(self, value, allow_cache=True): + """ + Generate code such that top-of-stack (TOS) is set to value. + + `allow_cache` controls the behavior in the following manner. `value` can + either be a VariableTracker or a Source. + + If `value` is a `Source`, `allow_cache` must be True (invariant asserted + below). If the source was reconstructed earlier, we will reuse the + generated code by loading from top of stack or tempvars. + + If `value` is a `VariableTracker`, we have the following cases: + + 1) `allow_cache=True` + a) If the value.source is not None, we will emit the code based on + `value.source` to handle aliasing. + b) If value.source is None (example reconstructing a local list + returned by the compiled function), we will reconstruct the variable + tracker (w/o any source) to emit bytecode that generates a new + python object. + + In both cases of value.source being None or not, if the value was + reconstructed earlier, we will reuse the generated code by loading from + top of stack or tempvars. + + 2) `allow_cache=False` - This is a special case (allow_cache defaults to + True). + a) If the value.source is not None, we reconstruct the variable + tracker and emit a new python object. You might wonder what about + aliasing? The place where we use this config also has the followup + code where the original python object is assigned to this new python + value to handle aliasing (check side_effects.py and search for + allow_cache=False). + + b) If value.source is None, this is not allowed. TODO - assert this. + + Notable effects: + 1. `self.top_of_stack` will be set to `value`, if we don't codegen + `value` based on source. + 2. `self.uses[value]` will increment, unless (a). we codegen via + `top_of_stack` or cached `tempvars`, or (b). `value` has special VT + types like `NNModuleVariable`, etc. + """ + if isinstance(value, Source): + # If the source needs to be overridden, use the new one. 
+ source = self.overridden_sources.get(value, value) + assert allow_cache is True, "allow_cache must be True for Source" + if self.top_of_stack is value: + self._output.append(create_dup_top()) + return + + if self.tempvars.get(source) is not None: + self._output.append(self.create_load(self.tempvars[source])) + self.top_of_stack = source + return + + self.uses[source] += 1 + try: + self.call_reconstruct(source) + except NotImplementedError: + unimplemented_v2( + gb_type="Reconstruction failure: source.reconstruct not implemented", + context=str(source), + explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + if source in self.tempvars: + self._output.append(create_dup_top()) + self.add_cache(source) + self.top_of_stack = source + + return + + assert isinstance(value, VariableTracker) + output = self._output + graph_outputs = self.graph_outputs + + if allow_cache: + if self.top_of_stack is value: + output.append(create_dup_top()) + return + + if self.tempvars.get(value) is not None: + output.append(self.create_load(self.tempvars[value])) + self.top_of_stack = value + return + + if value.is_realized() and isinstance( + value, ContextlibContextManagerLocalGeneratorObjectVariable + ): + raise IncorrectUsage( + "NYI: Returning a @contextmanager object from a torch.compile function" + ) + + # Dynamo normally prefers codegen from source to account for aliasing. + if ( + value.source is not None + and allow_cache + and not ( + value.is_realized() and isinstance(value, LocalGeneratorObjectVariable) + ) + ): + # There's a corner case for export: for instance, if the computation + # graph is just identity on an input tensor, Dynamo would just emit + # a `LOAD_FAST` from the input source, rather than generating an + # identity FX graph. + # + # However, export wants to maximize graph capture; in the case + # above, export _wants to_ obtain an identity FX graph (despite it + # appears unnecessarily expensive for `torch.compile`), so we have + # the following option to override Dynamo's preference for codegen + # from source. Moreover, this option applies recursively, for cases + # like input tensor being returned in a new dictionary. + # + # And why the `ValueMutationExisting` check? Not sure, so leaving it + # to keep the old behavior, as when `value_from_source` was + # introduced. TODO sort out the invariants among side effect, + # codegen and export. + if ( + isinstance(value.mutation_type, ValueMutationExisting) + or self.value_from_source + ): + return self(value.source) + + if value.is_python_constant() and is_safe_constant(value.as_python_constant()): + output.append(self.create_load_const(value.as_python_constant())) + elif isinstance(value, TensorWithTFOverrideVariable): + graph_outputs_key = self.add_graph_output(value) + + self.add_push_null( + lambda: self.load_import_from(utils.__name__, "to_subclass") + ) + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append( + self.create_load_global( + value.global_mangled_class_name(self.tx), add=True + ) + ) + output.extend(create_call_function(2, False)) + elif ( + isinstance(value, SymNodeVariable) + and value.python_type() == float + and not self.tx.export + ): + # This is a little unusual; force the output convention to be a + # Tensor here. 
Don't do this for export because this is + # apparently load bearing for export tests (but I am a bit + # doubtful it actually works in the real world) + # NB: It works to add_graph_output on a computed expression + # as_tensor here, because we memoize as_tensor calls on + # SymNodeVariable! + graph_outputs_key = self.add_graph_output( + value.as_tensor(self.tx, torch.float64) + ) + + def gen_fn(): + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append(self.create_load_attr("item")) + + self.add_push_null(gen_fn) + output.extend(create_call_function(0, False)) + elif isinstance( + value, + ( + TensorVariable, + SymNodeVariable, + UnspecializedPythonVariable, + NumpyNdarrayVariable, + ), + ): + graph_outputs_key = self.add_graph_output(value) + + if isinstance(value, NumpyNdarrayVariable): + self.add_push_null( + lambda: self.load_import_from(utils.__name__, "to_numpy_helper") + ) + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.extend(create_call_function(1, False)) + elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: + + def gen_fn(): + self.load_graph_output(graph_outputs[graph_outputs_key].index) + output.append(self.create_load_attr("item")) + + self.add_push_null(gen_fn) + output.extend(create_call_function(0, False)) + else: + self.load_graph_output(graph_outputs[graph_outputs_key].index) + elif isinstance(value, NNModuleVariable): + parts = value.module_key.split(".") + if parts[0] in self.code_options["co_varnames"]: + output.append(self.create_load(parts[0])) + parts = parts[1:] + else: + assert self.root is not None + output.append(self.create_load_const_unchecked(self.root)) + for part in parts: + output.append(self.create_load_attr(part)) + else: + self.uses[value] += 1 + try: + self.call_reconstruct(value) + except NotImplementedError: + unimplemented_v2( + gb_type="Reconstruction failure", + context=str(value), + explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", + hints=[ + "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable " + "that Dynamo cannot reconstruct, then remove it from the return statement.", + *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK, + "Report an issue to PyTorch if you need reconstrtuction support. 
Note that objects that don't have " + "reconstruction rules may be fundamentally unreconstructable.", + ], + ) + if allow_cache and value in self.tempvars: + self._output.append(create_dup_top()) + self.add_cache(value) + + self.top_of_stack = value + + def add_graph_output(self, value): + graph_outputs_key = id(value.as_proxy()) + if graph_outputs_key not in self.graph_outputs: + self.graph_outputs[graph_outputs_key] = GraphOutputEntry( + len(self.graph_outputs), value + ) + return graph_outputs_key + + def load_graph_output(self, index): + output = self._output + output.append(self.create_load(self.graph_output_var)) + output.append(self.create_load_const(index)) + output.append(self.create_binary_subscr()) + + def add_cache(self, value): + var = self.new_var() + self.tempvars[value] = var + self._output.append(self.create_store(var)) + + def foreach(self, items): + for i in items: + self(i) + + def create_binary_subscr(self) -> Instruction: + return create_instruction("BINARY_SUBSCR") + + def setup_globally_cached(self, name, value): + """Store value in a new global""" + name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) + f_globals = self.tx.f_globals + if name in f_globals: + assert id(f_globals[name]) == id(value) + else: + f_globals[name] = value + return [self.create_load_global(name, add=True)] + + def clear_tos(self): + self.top_of_stack = None + + def append_output(self, inst): + assert isinstance(inst, Instruction) + self._output.append(inst) + self.clear_tos() + + def extend_output(self, insts): + assert all(isinstance(x, Instruction) for x in insts) + self._output.extend(insts) + self.clear_tos() + + def get_instructions(self) -> list[Instruction]: + return self._output + + def create_load(self, name) -> Instruction: + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("LOAD_FAST", argval=name) + + def create_load_closure(self, name) -> Instruction: + assert name in self.cell_and_freevars() + inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" + return create_instruction(inst_name, argval=name) + + def create_load_deref(self, name) -> Instruction: + assert name in self.cell_and_freevars() + return create_instruction("LOAD_DEREF", argval=name) + + def create_store(self, name) -> Instruction: + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("STORE_FAST", argval=name) + + def create_store_deref(self, name) -> Instruction: + assert name in self.cell_and_freevars() + return create_instruction("STORE_DEREF", argval=name) + + def create_load_global(self, name, add=False) -> Instruction: + if add: + self.tx.output.update_co_names(name) + assert name in self.code_options["co_names"], f"{name} not in co_names" + return create_instruction("LOAD_GLOBAL", argval=name) + + def create_load_const(self, value) -> Instruction: + return create_load_const(value) + + def create_load_const_unchecked(self, value) -> Instruction: + return create_load_const(value, checked=False) + + def load_method(self, name): + self.tx.output.update_co_names(name) + self.append_output(create_load_method(name)) + + def call_method(self, nargs): + self.extend_output(create_call_method(nargs)) + + def create_load_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("LOAD_ATTR", argval=name) + + def load_attr(self, name): + self.append_output(self.create_load_attr(name)) + + def create_load_attrs(self, names): + 
return [self.create_load_attr(name) for name in names.split(".")] + + def create_store_attr(self, name) -> Instruction: + if name not in self.code_options["co_names"]: + self.code_options["co_names"] += (name,) + return create_instruction("STORE_ATTR", argval=name) + + def store_attr(self, name): + self.append_output(self.create_store_attr(name)) + + def load_function_name(self, fn_name, push_null, num_on_stack=0): + """Load the global fn_name on the stack num_on_stack down""" + output = [] + if push_null and sys.version_info >= (3, 11): + output.extend(add_push_null(self.create_load_global(fn_name, add=True))) + if num_on_stack > 0: + output.extend( + [ + *self.rot_n(num_on_stack + 2), + *self.rot_n(num_on_stack + 2), + ] + ) + else: + output.extend( + [ + self.create_load_global(fn_name, add=True), + *self.rot_n(num_on_stack + 1), + ] + ) + return output + + def rot_n(self, n): + try: + return create_rot_n(n) + except AttributeError: + # desired rotate bytecode doesn't exist, generate equivalent bytecode + return [ + create_instruction("BUILD_TUPLE", arg=n), + self.create_load_const_unchecked(rot_n_helper(n)), + *create_rot_n(2), + create_instruction("CALL_FUNCTION_EX", arg=0), + create_instruction("UNPACK_SEQUENCE", arg=n), + ] + + def pop_top(self): + self.append_output(create_instruction("POP_TOP")) + + def call_function(self, nargs: int, push_null: bool): + self.extend_output(create_call_function(nargs, push_null=push_null)) + + def dup_top(self): + self.append_output(create_dup_top()) + + def store(self, varname): + self.append_output(self.create_store(varname)) + + def load_deref(self, varname): + self.append_output(self.create_load_deref(varname)) + + def make_function_with_closure( + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 + ): + freevars = code.co_freevars + assert freevars + output = self._output + + def gen_fn(): + # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars` + # requires that in the generated bytecode, these cells would keep + # their original local names, which we ensure via + # `CellVariable.local_name`. + for var in freevars: + assert var in self.cell_and_freevars() + output.append(self.create_load_closure(var)) + output.append(create_instruction("BUILD_TUPLE", arg=len(freevars))) + output.append(self.create_load_const(code)) + if sys.version_info < (3, 11): + output.append(self.create_load_const(fn_name)) + if sys.version_info >= (3, 13): + output.extend( + [ + create_instruction("MAKE_FUNCTION"), + create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08), + ] + ) + else: + output.append(create_instruction("MAKE_FUNCTION", arg=0x08)) + + if push_null and sys.version_info >= (3, 11): + self.add_push_null(gen_fn) + output.extend(self.rot_n(num_on_stack + 2)) + output.extend(self.rot_n(num_on_stack + 2)) + else: + gen_fn() + output.extend(self.rot_n(num_on_stack + 1)) + self.clear_tos() + + def create_load_python_module(self, mod) -> Instruction: + """ + Generate a LOAD_GLOBAL instruction to fetch a given python module. + """ + output = self.tx.output + global_scope = output.global_scope + name = re.sub(r"^.*[.]", "", mod.__name__) + if global_scope.get(name, None) is mod: + return self.create_load_global(name, add=True) + prefix = f"___module_{name}" + global_name = self.tx.output.install_global_by_id(prefix, mod) + return self.create_load_global(global_name, add=True) + + def mark_source_temp(self, source: Source) -> None: + """ + Mark a source as a temp variable, so that it can be reused. 
+ """ + if source not in self.tempvars: + self.tempvars[source] = None + + def make_call_generated_code(self, fn_name: str) -> None: + """Call the generated code function stored in fn_name""" + self.extend_output(self.load_function_name(fn_name, True)) + + graphargs = self.tx.output.graphargs + + seen_sources: OrderedSet[Source] = OrderedSet() + + def collect_temp_source(source): + if source in seen_sources: + # This source is used at least twice, so it can be reused + self.mark_source_temp(source) + # Dont trace source further. This prevents us from marking too + # many nodes as temp sources. + return + + seen_sources.add(source) + + if isinstance(source, ChainedSource): + collect_temp_source(source.base) + + if isinstance(source, DictGetItemSource) and isinstance( + source.index, Source + ): + collect_temp_source(source.index) + + # Collect all the sources that are used more than once, so that we can + # generate tmp variables in the generated pre-graph bytecode. This + # essentially implements CSE. + for arg in graphargs: + if arg.source is not None: + collect_temp_source(arg.source) + + cm_var = None + if config.record_runtime_overhead: + # Record the pregraph bytecode start + self.add_push_null( + lambda: self.load_import_from( + utils.__name__, "record_pregraph_bytecode_enter" + ) + ) + self.extend_output(create_call_function(0, False)) + cm_var = self.new_var() + self.store(cm_var) + + for arg in graphargs: + if arg.pass_arg_as_tensor: + self.add_push_null( + lambda: self.extend_output( + [ + self.create_load_python_module(torch), + self.create_load_attr("_as_tensor_fullprec"), + ] + ) + ) + self.call_reconstruct(arg) + self.extend_output(create_call_function(1, False)) + else: + self.call_reconstruct(arg) + + if config.record_runtime_overhead: + # Record the pregraph bytecode end + self.add_push_null( + lambda: self.load_import_from( + utils.__name__, "record_pregraph_bytecode_exit" + ) + ) + assert cm_var is not None + self.extend_output([self.create_load(cm_var)]) + self.extend_output(create_call_function(1, False)) + self.pop_top() + + self.extend_output(create_call_function(len(graphargs), False)) + + def load_import_from(self, module_name, object_name) -> None: + source = AttrSource(self.tx.import_source(module_name), object_name) + # Note: This approach is somewhat aggressive because typically, a source is marked + # as a tempvar only when it is used more than once. In this case, we're marking it + # as a tempvar without performing that analysis. However, this is a simple solution, + # and in many cases, load imports are reused multiple times. 
+ self.mark_source_temp(source) + self(source) + + def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]: + if sys.version_info >= (3, 13): + output = create_call_function(nargs, push_null) + assert output[-1].opname == "CALL" + output.insert(-1, self.create_load_const(kw_names)) + output[-1] = create_instruction("CALL_KW", arg=nargs) + return output + elif sys.version_info >= (3, 11): + output = create_call_function(nargs, push_null) + if sys.version_info >= (3, 12): + idx = -1 + expected_inst = "CALL" + else: + idx = -2 + expected_inst = "PRECALL" + assert output[idx].opname == expected_inst + kw_names_inst = create_instruction("KW_NAMES", argval=kw_names) + output.insert(idx, kw_names_inst) + return output + return [ + self.create_load_const(kw_names), + create_instruction("CALL_FUNCTION_KW", arg=nargs), + ] + + def create_delete(self, value) -> Instruction: + return create_instruction("DELETE_FAST", argval=value) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/compiled_autograd.py b/phivenv/Lib/site-packages/torch/_dynamo/compiled_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..ffaf715041f46235100eade4cefb1109cfbf3527 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/compiled_autograd.py @@ -0,0 +1,1544 @@ +# mypy: allow-untyped-defs + +""" +Provides functionality for compiling PyTorch's autograd (automatic differentiation) system. + +This module implements compiled autograd, which traces and optimizes backward pass +computations at runtime. The key components are: + +- AutogradCompilerInstance: Traces and compiles autograd graphs using FX +- Context managers (_enable/_disable): Control when compiled autograd is active +- Utility functions: Support graph manipulation, tensor operations, and hooks + +Compiled autograd can significantly improve backward pass performance by removing +Python overhead and enabling additional optimizations. It works by capturing +backward computations into an FX graph that can be compiled and optimized, +while maintaining the same semantics as eager mode autograd. 
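+
+A minimal usage sketch (illustrative only; `model` and `data` are placeholders).
+The recommended entry point is the config flag rather than the internal
+context managers:
+
+    import torch
+    torch._dynamo.config.compiled_autograd = True
+    compiled_model = torch.compile(model)
+    loss = compiled_model(data).sum()
+    loss.backward()  # the backward computation is captured and compiled as an FX graph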
+""" + +import contextlib +import functools +import itertools +import operator +import time +from collections import Counter, defaultdict +from typing import Optional, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.external_utils import ( + call_accumulate_grad, + call_backward, + call_hook, + FakeCompiledAutogradEngine, + unwrap_maybe_dynamic_int, +) +from torch._dynamo.source import GetItemSource, LocalSource +from torch._dynamo.utils import ( + counters, + get_chromium_event_logger, + lazy_format_graph_code, + set_locals_to_steal, +) +from torch._functorch._aot_autograd.runtime_wrappers import ( + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, +) +from torch._guards import compile_context, CompileContext, CompileId +from torch._logging import getArtifactLogger, trace_structured +from torch._prims_common import clone_preserve_strides +from torch._subclasses import FakeTensorMode +from torch.fx import GraphModule +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import ( + decompose, + disable_autocast_cache, + disable_proxy_modes_tracing, + fetch_object_proxy, + ProxyTorchDispatchMode, + PythonKeyTracer, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch.fx.traceback import preserve_node_meta, set_stack_trace +from torch.utils._ordered_set import OrderedSet +from torch.utils._traceback import CapturedTraceback + + +if TYPE_CHECKING: + from torch.fx.proxy import Proxy + + +TURN_OFF_MSG = """You can turn off compiled autograd by either: +1. Moving the unsupported autograd call outside of the torch.compile'd region. +2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager. +3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call. +4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program.""" + +compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") +verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") + + +def snapshot_verbose_logging_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled( + "compiled_autograd_verbose" + ) + + +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs + + +def maybe_clone(x): + if x is not None: + return clone_preserve_strides(x) + return x + + +def extract_bw_module(CompiledFunction): + if isinstance( + CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo + ): + return CompiledFunction._lazy_backward_info.bw_module + elif isinstance( + CompiledFunction._lazy_backward_info, CachedAutogradLazyBackwardCompileInfo + ): + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + return CompiledFunction._lazy_backward_info.bw_module_fn() + else: + raise AssertionError( + "Unexpected Lazy Backward Compilation Info Type. Please file an issue." + ) + + +# Note: [Anomaly Mode Semantics in Compiled Autograd] +# In the eager autograd engine, anomaly mode is able to detect NaNs +# after each node. This is useful, because the executed code with +# and without anomaly mode are the same. So assuming determinism, +# a NaN in regular mode should also happen in anomaly mode. +# +# With torch.compile, following eager semantics would require inserting +# runtime asserts to check for NaNs, which could prevent some fusions. 
+# This results in different code being run with and without anomaly mode. +# So different semantics are needed, this implementation below will check +# for NaNs at the end of the autograd call, instead of after each node +class NaNChecker: + def __init__(self, accumulate_grad: bool): + self.accumulate_grad = accumulate_grad + self.params_indices: list[int] = [] + self.params_to_check: dict[str, torch.Tensor] = {} + self.output_names: list[str] = [] + + def prep_with_graph(self, graph: torch.fx.Graph): + inputs_node = next(iter(graph.nodes)) + acc_grad_nodes = graph.find_nodes( + op="call_function", target=call_accumulate_grad + ) + output_nodes = graph.find_nodes(op="output")[0].args[0] + assert self.accumulate_grad == bool( + acc_grad_nodes + ) and self.accumulate_grad == (not output_nodes) + + for node in acc_grad_nodes: + param_node = node.args[0] + # AccumulateGrad always saves a reference to the param + # so Compiled Autograd will always lift the param and + # this should always be true + assert ( + param_node.target == operator.getitem + and param_node.args[0] is inputs_node # type: ignore[possibly-undefined] + and isinstance(param_node.args[1], int) + ) + self.params_indices.append(param_node.args[1]) + + self.output_names = [node.name for node in output_nodes] + + def prep_with_inputs(self, inputs: tuple[torch.Tensor]): + if not self.accumulate_grad: + # Using .grad, nothing to prep + return + + # Using .backward, we must check existing grads on params if any + for idx in self.params_indices: + grad = inputs[idx].grad + if grad is not None: + assert not torch.isnan(grad).any(), ( + f"Compiled autograd running under anomaly mode with inputs[{idx}] already " + "having NaN gradient. This is not supported. {TURN_OFF_MSG}" + ) + + self.params_to_check[f"inputs[{idx}]"] = inputs[idx] + + def check(self, out: tuple[torch.Tensor]): + if self.accumulate_grad: + # Using .backward, graph outputs are empty + assert not out + nan_params: list[str] = [] + for inputs_str, param in self.params_to_check.items(): + assert param.grad is not None # not true for autograd.grad + if torch.isnan(param.grad).any(): + nan_params.append(inputs_str) + + if nan_params: + raise RuntimeError( + f"Compiled Autograd returned NaN gradients for parameters: {','.join(nan_params)}." + ) + else: + # Using .grad, graph outputs are grads + nan_grads: list[str] = [] + for i, grad in enumerate(out): + if torch.isnan(grad).any(): + nan_grads.append(self.output_names[i]) + + if nan_grads: + raise RuntimeError( + f"Compiled Autograd returned NaN gradients for output nodes: {','.join(nan_grads)}." + ) + + +# We lazily bind "functional backward" variants for PyTorch built-in autograd +# nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0 +# Each "functional backward" is bound the first time the node's apply_with_saved +# function is called. It's possible to avoid lazy binding and instead bind +# all of this upfront (perhaps at import time) via codegen changes. 
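+# A rough sketch of that flow (the node name and `mul_backward_fn` are
+# hypothetical, shown for illustration only):
+#
+#     name = ops.add("MulBackward0", mul_backward_fn,
+#                    is_custom_function=False, is_traceable=True)
+#     # AutogradCompilerInstance.apply_functional later proxies a call to
+#     # ops.get(name)(grads, *args) into the graph being traced.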
+class OpNamespace: + def __init__(self): + self.custom_function_name_counter: Counter[str] = Counter() + + def add(self, name, fn, is_custom_function, is_traceable): + if is_custom_function: + name = "CppNode" + name + count = self.custom_function_name_counter[name] + self.custom_function_name_counter[name] += 1 + name = f"{name}{count}" + + assert not hasattr(self, name) + result = Op(name, fn, is_custom_function) + if is_traceable: + setattr(self, name, torch._dynamo.allow_in_graph(result)) + else: + # C++ autograd function was not marked as traceable + # Dynamo can't dry run it at compile time, so must fallback to eager + @torch._dynamo.disable + def run_non_traceable_cpp_in_eager(*args, **kwargs): + return result(*args, **kwargs) + + setattr(self, name, run_non_traceable_cpp_in_eager) + return name + + def get(self, name): + return getattr(self, name) + + +class Op: + def __init__(self, name, fn, is_custom_function): + self.fn = fn + self.is_custom_function = is_custom_function + self.__name__ = name + self.__module__ = "torch._dynamo.compiled_autograd.ops" + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): + return self.__module__ + "." + self.__name__ + + +ops = OpNamespace() + + +_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"] +_impure_targets = OrderedSet( + [ + call_hook, + call_backward, + FakeCompiledAutogradEngine._exec_final_callbacks_stub, + call_accumulate_grad, + ] +) + +COMPILE_COUNTER = itertools.count() + + +def make_compile_context(compiled_autograd_id): + return compile_context( + CompileContext( + CompileId( + compiled_autograd_id=compiled_autograd_id, + frame_id=None, + frame_compile_id=None, + ) + ) + ) + + +class AutogradCompilerInstance: + def __init__(self, compiler_fn) -> None: + self.compiler_fn = compiler_fn + self.stack = contextlib.ExitStack() + self.close = self.stack.close + self.shape_env = ShapeEnv() + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.fx_tracer = PythonKeyTracer() + self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") + self.hooks_proxy: Optional[Proxy] = None + + def wrap_fake(self, x, source): + assert isinstance(x, torch.Tensor) + return self.fake_tensor_mode.from_tensor(x, source=source) + + @staticmethod + def source(name, idx) -> GetItemSource: + return GetItemSource(LocalSource(name), idx) + + def begin_capture( + self, + inputs: list[torch.Tensor], + sizes: list[int], + scalars: list[Union[int, float]], + origins: list[list[tuple[int, str]]], + accumulate_grad: bool, + check_nans: bool, + ): + counters["compiled_autograd"]["captures"] += 1 + self.id = next(COMPILE_COUNTER) + self.aot_id_counter: dict[int, int] = defaultdict(int) + self.compile_context = make_compile_context(self.id) + self.compile_context.__enter__() + self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None + self.start_time_ns = time.time_ns() + get_chromium_event_logger().log_event_start( + "compiled_autograd", + self.start_time_ns, + {"graph_id": self.id}, + log_pt2_compile_event=True, + ) + self.fx_tracer.root = torch.nn.Module() + self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) + self.fx_tracer.tensor_attrs = {} + self.symnode_proxy_lookup = {} + ( + args_proxy, + self.sizes_proxy, + self.scalars_proxy, + self.hooks_proxy, + self.packed_data_proxy, + ) = ( + self.fx_tracer.create_proxy("placeholder", name, (), {}) + for name in 
_graph_placeholders + ) + + self.stack.enter_context(preserve_node_meta()) + inputs_origins, sizes_origins, scalars_origins = origins + + # tensor inputs to fake tensors + x = inputs[0] # mypy will complain about unbound x + try: + for idx, x in enumerate(inputs): + inputs[idx] = self.wrap_fake(x, self.source("inputs", idx)) + except Exception as e: + raise NotImplementedError( + f"Found tensor of type {type(x)}, which is not supported by FakeTensorMode. {TURN_OFF_MSG}" + ) from e + self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins) + + # size inputs to symints + sizes = [ + self.shape_env.create_unspecified_symint_and_symbol( + val, + self.source("sizes", idx), + DimDynamic.DYNAMIC, + ) + for idx, val in enumerate(sizes) + ] + + # We want to mark every size as dynamic, but since there's no way to + # mark a primitive `int` as dynamic, we need to wrap it in a tensor. + # In the graph, we unwrap it with `unwrap_maybe_dynamic_int` back into a primitive. + proxies = [self.sizes_proxy[i] for i in range(len(sizes))] # type: ignore[index] + for i, symint in enumerate(sizes): + proxies[i] = self.fx_tracer.create_proxy( + "call_function", + unwrap_maybe_dynamic_int, + (proxies[i],), + {}, + ) + self.symnode_proxy_lookup[symint.node] = proxies[i] + proxies = self.bind_objects_to_proxies(sizes, proxies, sizes_origins) + + for idx, val in enumerate(scalars): + source = self.source("scalars", idx) + if isinstance(val, int): + scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol( + val, + source, + DimDynamic.DYNAMIC, + ) + elif isinstance(val, float): + scalars[idx] = self.shape_env.create_symfloatnode( + self.shape_env.create_unspecified_symbol( + val, + source=source, + dynamic_dim=DimDynamic.DYNAMIC, + ), + hint=val, + source=source, + ) + else: + raise AssertionError("Unexpected scalar type: ", type(val)) + self.bind_objects_to_proxies(scalars, self.scalars_proxy, scalars_origins) + for i, symval in enumerate(scalars): + self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr] + + # TODO(jansel): are all these modes needed? + self.stack.enter_context(decompose({})) + self.stack.enter_context(self.fake_tensor_mode) + self.stack.enter_context(self.proxy_mode) + self.stack.enter_context(disable_autocast_cache()) + # Needed to make sure we don't accidentally specialize any symbols + assert self.fake_tensor_mode.shape_env is not None + env = self.fake_tensor_mode.shape_env + self.stack.enter_context( + torch.fx.experimental.symbolic_shapes._suppress_guards(env) + ) + return ( + str(CompileContext.current_compile_id()), + inputs, + sizes, + scalars, + ) + + def log_compile_reasons( + self, + compile_reasons: list[str], + ): + assert compile_reasons + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "compiled_autograd_compile_reasons", + "encoding": "json", + }, + payload_fn=lambda: compile_reasons, + ) + + def proxy_call_aot_backward( + self, + pinputs, + psaved_tensors, + saved_tensors, + pctx, + ctx, + maybe_backward_state_idx, + ): + # The AOTBackward call consists of three things: the prologue, the + # backward graph, and the epilogue. + # Our strategy is: + # - allow_in_graph the prologue (in the CA graph and Dynamo graph), + # - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it + # - trace directly through the epilogue. Anything that gets baked in is + # constant metadata (for example, metadata about the number of outputs, or removing + # RNG arguments or effect tokens). 
+ # If Dynamo graph capture were better, then we could add a node for the prologue + # into the CA graph and have Dynamo trace into it. + + psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()] + + # NOTE: we should only close over constants + CompiledFunction = ctx._forward_cls + bw_module = extract_bw_module(CompiledFunction) + metadata = CompiledFunction.metadata + maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata + aot_id = CompiledFunction._aot_id + del CompiledFunction + + if torch.is_grad_enabled(): + for output_alias_info in metadata.output_info: + if output_alias_info.requires_grad: + raise RuntimeError( + "torch.compile does not currently support higher order gradients." + ) + + @torch._dynamo.allow_in_graph # type: ignore[misc] + def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): + out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( + ctx_saved_tensors, + ctx_symints, + metadata, + maybe_subclass_metadata, + *flat_args, + ) + return out + + pgrads = self.fx_tracer.create_proxy( + kind="call_function", + target=call_aot_bwd_prologue, + args=( + psaved_tensors, + psymints, + *pinputs, + ), + kwargs={}, + ) + + pbackward_state = None + if maybe_backward_state_idx is not None: + pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index] + + # Copy-paste the AOT backward graph into the compiled autograd graph + def copy_paste_aot_backward_graph(): + def num_inputs(graph): + num_args = 0 + for node in graph.nodes: + if node.op == "placeholder": + num_args += 1 + continue + else: + break + return num_args + + # set up the proxy inputs to bw_module + # the calling convention is: [*symints, *args (primals and tangents), backward_state] + num_args = num_inputs(bw_module.graph) + pall_args = [ + pgrads[i] for i in range(num_args - int(pbackward_state is not None)) + ] + # replace the symints with our symints + symints = ctx._get_compiled_autograd_symints() + assert len(symints) == len(ctx.symints) + psymints = [self.to_proxy(e) for e in symints] + pall_args[: len(symints)] = psymints + # Add backward_state + if pbackward_state is not None: + pall_args.append(pbackward_state) + + # run over all nodes of the aot_backward graph. + # copy and paste them all into the compiled autograd graph. 
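+ # `value_remap` maps each node of bw_module.graph to its copy in the compiled
+ # autograd graph, so that node_copy can rewrite argument references as nodes
+ # are pasted in.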
+ args_idx = 0 + value_remap = {} + poutputs: Optional[list[torch.fx.Proxy]] = None + + # names of nodes must appear only once in the fx.Graph + # dedup AOT backwards that appear multiple times + deduped_aot_id = str(aot_id) + if self.aot_id_counter[aot_id]: + deduped_aot_id += f"_{self.aot_id_counter[aot_id]}" + self.aot_id_counter[aot_id] += 1 + + def make_unique(node_name): + # make it both informative and unique + return f"aot{deduped_aot_id}_{node_name}" + + for node in bw_module.graph.nodes: + if node.op == "placeholder": + ph = pall_args[args_idx].node + ph.name = make_unique(node.name) + value_remap[node] = ph + args_idx += 1 + elif node.op == "output": + assert len(node.args) == 1 + poutputs = [ + torch.fx.Proxy(value_remap[n], self.fx_tracer) + if isinstance(n, torch.fx.Node) + else n + for n in node.args[0] + ] + elif node.op == "get_attr": + name = node.target + qualname = self.fx_tracer.get_fresh_qualname(name) + setattr(self.fx_tracer.root, qualname, getattr(bw_module, name)) + result = self.fx_tracer.create_node("get_attr", qualname, (), {}) + result.name = make_unique(node.name) + value_remap[node] = result + elif node.op == "call_function": + if node.target == torch.ops.aten.view.default: + # this aot bwd graph is being lazily compiled + # we must manually apply the view_to_reshape post grad pass + # since it was already applied to the aot fwd, and baked into the gradients + node.target = torch.ops.aten.reshape.default + result = self.fx_tracer.graph.node_copy( + node, lambda n: value_remap[n] + ) + result.name = make_unique(node.name) + value_remap[node] = result + elif node.op == "call_module": + name = node.target + qualname = self.fx_tracer.get_fresh_qualname(name) + setattr(self.fx_tracer.root, qualname, getattr(bw_module, name)) + result = self.fx_tracer.graph.node_copy( + node, lambda n: value_remap[n] + ) + result.target = qualname + value_remap[node] = result + else: + raise AssertionError("shouldn't get here") + + assert poutputs is not None + + # In general we don't know what the shapes of the outputs are, so allocate + # some dummy sizes for them. 
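+ # The dummy tensors below only stand in for the AOT backward outputs while
+ # tracing; real values are produced when the compiled graph runs. A sentinel
+ # shape is used so stray dummies are easy to spot (see also allocate_dummy).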
+ def dummy(): + with disable_proxy_modes_tracing(): + return torch.zeros(0, 0, 0, 0, 123) + + outputs = [ + dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs + ] + self.bind_objects_to_proxies(outputs, poutputs) + return outputs + + outputs = copy_paste_aot_backward_graph() + + def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args): + @torch._dynamo.allow_in_graph + def make_subclass(*unwrapped_args): + return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + + punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args) + + poutput = self.fx_tracer.create_proxy( + kind="call_function", + target=make_subclass, + args=tuple(punwrapped_args), + kwargs={}, + ) + + output = self.allocate_dummy() + self.bind_objects_to_proxies([output], [poutput]) + return output + + results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional( + metadata, + maybe_subclass_metadata, + outputs, + make_subclass_override=proxy_subclass_constructor, + ) + presults = pytree.tree_map(self.to_proxy, results) + return presults + + def proxy_call_backward( + self, + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ctx: torch.autograd.function.BackwardCFunction, + maybe_backward_state_idx: Optional[int], + ): + assert self.hooks_proxy is not None + pctx = self.hooks_proxy[backward_idx] # type: ignore[index] + pinputs = self.to_proxy(inputs) + psaved_tensors = self.to_proxy(saved_tensors) + if hasattr(ctx._forward_cls, "_aot_id"): # type: ignore[attr-defined] + # AOT backward + proxies = self.proxy_call_aot_backward( + pinputs, + psaved_tensors, + saved_tensors, + pctx, + ctx, + maybe_backward_state_idx, + ) + else: + proxies = self.fx_tracer.create_proxy( + kind="call_function", + target=call_backward, + args=( + pctx, + psaved_tensors, + *pinputs, + ), + kwargs={}, + ) + assert proxies is not None + + with disable_proxy_modes_tracing(): + # create fake Tensors + grad_ins: list[Optional[torch.Tensor]] = [] + for idx, output_metadata in enumerate(output_metadatas): + if output_metadata is None or proxies[idx] is None: + grad_ins.append(None) + continue + + layout, device, dtype, size = output_metadata + grad_ins.append( + torch.empty(size=size, dtype=dtype, layout=layout, device=device) + ) + self.bind_objects_to_proxies(grad_ins, proxies) + return tuple(grad_ins) + + def call_copy_slices_prologue( + self, + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, + ): + args = ( + inputs, + self.to_proxy(base_sizes), + self.to_proxy(base_strides), + self.to_proxy(base_storage_offset), + self.to_proxy(view_sizes), + self.to_proxy(view_strides), + self.to_proxy(view_storage_offset), + ) + return self.proxy_call(copy_slices_prologue, args, [None] * 3) + + def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice): + return self.proxy_call( + copy_slices_epilogue, + (needs_input_grad, result, res, grad_slice), + [None] * len(needs_input_grad), + ) + + def allocate_dummy(self): + with disable_proxy_modes_tracing(): + # Weird quantity so it's easy to grep + return torch.zeros([0, 123456789]) + + def bind_function(self, fn_name, fn, is_custom_function, is_traceable): + """Binds ops.fn_name = fn""" + return ops.add(fn_name, fn, is_custom_function, is_traceable) + + def apply_functional(self, fn_name, grads, args, output_metadata): + """Proxies a call to ops.fn_name(grads, *args) into the graph""" + op = ops.get(fn_name) + return self.proxy_call(op, 
(grads, *args), output_metadata) + + def proxy_call(self, fn, args, output_metadata): + """Proxies a call to fn(*args) into the graph""" + flat_args, _ = pytree.tree_flatten(args) + proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) + proxy_out = self.fx_tracer.create_proxy( + "call_function", fn, args=proxy_args, kwargs={} + ) + result = [self.allocate_dummy() for _ in output_metadata] + self.bind_objects_to_proxies(result, [proxy_out[i] for i in range(len(result))]) + return result + + def validate_outputs(self, _, outputs, args, output_metadata): + """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" + op = ops.get("validate_outputs") + proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) + new_proxy_outputs = self.fx_tracer.create_proxy( + "call_function", op, args=proxy_args, kwargs={} + ) + assert len(output_metadata) == len(outputs) + self.bind_objects_to_proxies(outputs, new_proxy_outputs) + return outputs + + def accumulate(self, old_var, new_var): + old_var_proxy = self.to_proxy(old_var) + new_var_proxy = self.to_proxy(new_var) + proxy_out = self.fx_tracer.create_proxy( + "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={} + ) + result = self.allocate_dummy() + self.bind_objects_to_proxies([result], [proxy_out]) + return result + + def accumulate_grad(self, variable, grad, has_post_hooks): + self.fx_tracer.create_proxy( + "call_function", + call_accumulate_grad, + args=( + self.to_proxy(variable), + self.to_proxy(grad), + has_post_hooks, + ), + kwargs={}, + ) + + def proxy_call_hook(self, hook, *args, **kwargs): + return self.fx_tracer.create_proxy( + "call_function", + call_hook, + ( + hook, + *[self.to_proxy(x) for x in args], + ), + kwargs, + ) + + def unpack_hook(self, hook_id, data_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + data = self.packed_data_proxy[data_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + data, + hook_type="unpack_hook", + ) + out = self.allocate_dummy() + self.bind_objects_to_proxies([out], [proxy]) + return out + + def tensor_pre_hook(self, inputs, hook_id, i: int): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + inputs[i], + hook_type="tensor_pre_hook", + ) + with disable_proxy_modes_tracing(): + inputs[i] = maybe_clone(inputs[i]) + self.bind_objects_to_proxies([inputs[i]], [proxy]) + return inputs + + def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int): + proxy = self.fx_tracer.create_proxy( + "call_function", + torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks, + (hook_id, self.to_proxy(inputs[i])), + {}, + ) + with disable_proxy_modes_tracing(): + inputs[i] = maybe_clone(inputs[i]) + self.bind_objects_to_proxies([inputs[i]], [proxy]) + return inputs + + def pre_hook(self, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + inputs, + hook_type="pre_hook", + ) + with disable_proxy_modes_tracing(): + inputs = [maybe_clone(x) for x in inputs] + self.bind_objects_to_proxies(inputs, proxies) + return inputs + + def post_hook(self, outputs, inputs, hook_id): + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxies = self.proxy_call_hook( + hook, + outputs, + inputs, + hook_type="post_hook", + ) + with disable_proxy_modes_tracing(): + 
outputs = [maybe_clone(x) for x in outputs] + self.bind_objects_to_proxies(outputs, proxies) + return outputs + + def post_acc_grad_hook(self, input, hook_id): + assert isinstance(input, torch.Tensor) + assert self.hooks_proxy is not None + hook = self.hooks_proxy[hook_id] # type: ignore[index] + proxy = self.proxy_call_hook( + hook, + input, + hook_type="post_acc_grad_hook", + ) + with disable_proxy_modes_tracing(): + input = [maybe_clone(input)] + self.bind_objects_to_proxies(input, [proxy]) + return input + + # Note: [Compiled autograd and cudagraphs] + # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. + # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph + # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the + # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. + def move_graph_nodes_to_cuda(self, graph) -> list[int]: + to_move: dict[int, torch.fx.Node] = {} + has_cuda_inputs = False + nodes = list(graph.nodes) + assert nodes[0].target == "inputs" + inputs = nodes[0] + inputs_users = list(inputs.users.keys()) + # input access nodes should immediately follow placeholder nodes + first_getitem_idx = len(_graph_placeholders) + assert nodes[first_getitem_idx] == inputs_users[0] + last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 + assert nodes[last_getitem_idx] == inputs_users[-1] + # getitem nodes on inputs + for i, node in enumerate(inputs_users): + if not has_cuda_inputs and node.meta["val"].device.type == "cuda": + has_cuda_inputs = True + continue + + is_cpu = node.meta["val"].device.type == "cpu" + is_scalar = len(node.meta["val"].size()) == 0 + if is_cpu and is_scalar: + node_users = list(node.users.keys()) + # We can only move the cpu scalar if it is not exposed to user code. 
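+ # "Not exposed" here means every user of the node is an aten/prims overload
+ # or a built-in (non custom-function) Op; a user hook or custom autograd
+ # function could observe the device change, so those nodes are left alone.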
+ if all( + ( + isinstance(user.target, torch._ops.OpOverload) + and user.target.namespace in ("prims", "aten") + ) + or ( + isinstance(user.target, Op) + and not user.target.is_custom_function + ) + for user in node_users + ): + # all users are prims/aten, can move safely + to_move[i] = node + + # only move cpu scalars to cuda if there were cuda activations in this graph, + # this is to handle the case where cudagraphs is enabled on a cpu-only graph + if has_cuda_inputs: + for node in to_move.values(): + verbose_log.debug("Moving node %s from cpu to cuda", node) + node.meta["val"] = node.meta["val"].cuda() + + # return runtime indices we need to move to cuda + return list(to_move.keys()) + + return [] + + def is_sym_node(self, node): + return ( + isinstance(node, torch.fx.Node) + and node.op == "call_function" + and node.target + in [torch.ops.aten.sym_size.int, torch.ops.aten.sym_numel.default] + ) + + def dce(self): + # Most of these removed nodes would have been removed during Dynamo and AOTDispatch + # Remove some of these nodes earlier to improve compilation speed + + # Dynamo guards will error instead of creating aliasing guards unless we unpack them in the graph + unpack_nodes: OrderedSet[torch.fx.Node] = OrderedSet() + for i, node in enumerate(self.fx_tracer.graph.find_nodes(op="placeholder")): + unpack_nodes.update(node.users.keys()) + assert i == len(_graph_placeholders) - 1 + + def is_impure(node): + if node in unpack_nodes or ( + node.op == "call_function" and node.target in _impure_targets + ): + return True + return node.is_impure() + + before = len(self.fx_tracer.graph.nodes) + self.fx_tracer.graph.eliminate_dead_code(is_impure) + after = len(self.fx_tracer.graph.nodes) + verbose_log.debug("DCE removed %d nodes", before - after) + + def remove_unused_sizes(self): + used_sizes = [] + unused_sizes = [] + + # seek placeholder, should be at nodes[1] + it = iter(self.fx_tracer.graph.nodes) + next(it) + sizes_node = next(it) + assert sizes_node.name == "sizes" + + for getitem_node in sizes_node.users.keys(): + assert getitem_node.target == operator.getitem + if getitem_node.users: + used_sizes.append(getitem_node) + else: + # remove from the graph + unused_sizes.append(getitem_node) + + used_sizes_idx: set[int] = set() + for used in used_sizes: + assert isinstance(used.args, tuple) + assert used.args[0] == sizes_node + assert isinstance(used.args[1], int) + next_size_idx = len(used_sizes_idx) + # used later reindex the runtime sizes arg + used_sizes_idx.add(used.args[1]) + # reindex the graph + used.args = (used.args[0], next_size_idx) + + for unused in unused_sizes: + self.fx_tracer.graph.erase_node(unused) + + return used_sizes_idx + + def create_graph_module(self, id): + return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) + + def end_capture(self, outputs): + self.fx_tracer.create_proxy( + "call_function", + FakeCompiledAutogradEngine._exec_final_callbacks_stub, + (), + {}, + ) + self.stack.close() + self.fx_tracer.create_node( + "output", + "output", + (self.fx_tracer.create_arg(self.to_proxy(outputs)),), + {}, + ) + runtime_inputs_to_move: list[int] = [] + if snapshot_cudagraph_enabled(): + runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) + + # We traced using dummy tensors. Delete all the metadata of the dummy tensors. + # It's probably better to refactor this class to use a different tracer + # than the make_fx tracer, but that is a larger change. 
+ for node in self.fx_tracer.graph.nodes: + for field in ["tensor_meta", "example_value", "val"]: + if field in node.meta: + del node.meta[field] + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "compiled_autograd_graph_pre_reordering", + "encoding": "string", + }, + payload_fn=lambda: GraphModule( + self.fx_tracer.root, + self.fx_tracer.graph, + f"CompiledAutograd{self.id}PreReordering", + ).print_readable(print_output=False), + ) + self.delay_unpack_hook_nodes() + self.reorder_tensor_pre_hook_nodes() + self.reorder_pre_hook_nodes_to_schedule_asap() + self.reorder_accumulate_grad_nodes() + self.reorder_pre_hook_nodes_to_mimic_eager() + self.reorder_post_acc_grad_hook_nodes() + self.reorder_post_hook_nodes() + # TODO(yf225): work around: remove dead codes like `sym_size` and `sym_numel` which are not used downstream. e.g. + # ``` + # sym_numel_default = torch.ops.aten.sym_numel.default(sum_109); sum_109 = None + # eq_115 = 16 == sym_numel_default; sym_numel_default = eq_115 = None + # sym_size_int_39 = torch.ops.aten.sym_size.int(getitem_112, 1); getitem_112 = None + # eq_116 = 16 == sym_size_int_39; eq_116 = None + # eq_117 = 16 == sym_size_int_39; sym_size_int_39 = eq_117 = None + # ``` + # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and + # should prevent these ops from going into the CA graph. + self.dce() + if self.nan_checker: + self.nan_checker.prep_with_graph(self.fx_tracer.graph) + + # keep only sizes that are actually used in the graph + used_sizes_idx = self.remove_unused_sizes() + + graph = self.create_graph_module(f"CompiledAutograd{self.id}") + set_locals_to_steal(graph, ["inputs"]) + lazy_graph_code = lazy_format_graph_code( + "Compiled autograd graph", + graph, + include_device=True, + include_stride=True, + colored=True, + ) + compiled_autograd_log.info("%s", lazy_graph_code) + verbose_log.debug("%s", lazy_graph_code) + trace_structured( + "compiled_autograd_graph", + payload_fn=lambda: graph.print_readable(print_output=False), + ) + + def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): + global in_compiled_autograd_region + try: + in_compiled_autograd_region = True + + if self.nan_checker: + self.nan_checker.prep_with_inputs(inputs) + + filtered_sizes = [] + for idx, integer in enumerate(sizes): + if idx in used_sizes_idx: + # can't create negative size + if integer > 0: + filtered_sizes.append(torch.empty(0, integer)) + torch._dynamo.maybe_mark_dynamic(filtered_sizes[-1], 1) + else: + filtered_sizes.append(integer) + + for i in runtime_inputs_to_move: + inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) + + with _disable(), make_compile_context(self.id): + out = compiled_fn( + inputs, filtered_sizes, scalars, hooks, packed_inputs + ) + if self.nan_checker: + self.nan_checker.check(out) + return out + finally: + in_compiled_autograd_region = False + + get_chromium_event_logger().log_event_end( + "compiled_autograd", + time.time_ns(), + {"graph_id": self.id}, + self.start_time_ns, + log_pt2_compile_event=True, + ) + self.compile_context.__exit__(None, None, None) + return runtime_wrapper, self.compiler_fn(graph) + + @staticmethod + def get_all_nodes(args): + # filter out non-Node args, like None + nodes = [n for n in args if type(n) is torch.fx.Node] + return nodes + + @staticmethod + def is_placeholder(node): + if node.op == "placeholder" or ( + node.op == "call_function" + and node.target == operator.getitem + and node.args[0].op == "placeholder" + ): + return True + 
return False + + def reorder_accumulate_grad_nodes(self): + """ + Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of + the graph. This differs from eager mode, which schedules them as soon as possible. This + pass attempts to reorder the graph to mimic eager behavior. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_accumulate_grad + ): + param_node, grad_node = node.args[0], node.args[1] + getitem_node = None + if grad_node.target == operator.getitem: + getitem_node = grad_node + grad_node = getitem_node.args[0] + + arg = max([param_node, grad_node]) # last arg + if arg is not node.prev and not self.is_placeholder(arg): + arg.append(node) + if getitem_node is not None: + arg.append(getitem_node) + + def delay_unpack_hook_nodes(self): + """ + We can delay unpack hooks until they are needed, even later than in the eager autograd engine. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "unpack_hook": + continue + + first_user = min(node.users) + first_user.prepend(node) + + def reorder_tensor_pre_hook_nodes(self): + """ + Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed + to the end of the graph. This differs from eager mode, which schedules + them as soon as possible. This pass attempts to reorder the graph to + mimic eager behavior. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "tensor_pre_hook": + continue + + getitem_node = node.args[0] + input_node = node.args[1] # tensor_pre_hook handle only one grad tensor + + if input_node is not node.prev and not self.is_placeholder(input_node): + input_node.append(getitem_node) + getitem_node.append(node) + + def reorder_pre_hook_nodes_to_schedule_asap(self): + """ + In this function, we schedule the pre hooks as soon as possible. This + does not match eager behavior (schedule pre hook right before its + registered node), but it can make acc grad be scheduled properly when + the pre hooks are registered to them. After reordering acc grad node, we + will reorder the pre hooks again to mimic eager behavior. + """ + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "pre_hook": + continue + + getitem_node = node.args[0] + # pre_hook handle a tuple of grad tensors + input_nodes = self.get_all_nodes(node.args[1]) + + to_remove = [] + to_append = [] + hook_block = [node] # contain the hook and hook args getitem + for n in input_nodes: + if n.op == "call_function" and n.target == operator.getitem: + to_append.append(n.args[0]) + to_remove.append(n) + hook_block.append(n) + for a, b in zip(to_remove, to_append): + input_nodes.remove(a) + input_nodes.append(b) + + arg = max(input_nodes) # last input + if arg is not node.prev and not self.is_placeholder(arg): + arg.append(getitem_node) + for n in hook_block: + getitem_node.append(n) + + def reorder_pre_hook_nodes_to_mimic_eager(self): + """ + Usage of AOTAutograd causes all the pre_hook nodes to get pushed to the + end of the graph. This differs from eager mode, which schedules them + right before their registered node execution. This pass attempts to + reorder the graph to mimic eager behavior. 
+ """ + pre_hooks = [] + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "pre_hook": + continue + pre_hooks.append(node) + + for node in reversed(pre_hooks): + hook_getitem_node = node.args[0] + + users = list(node.users.keys()) + if len(users) == 0: + continue + + # users are all getitem ops and they are used by same registered node + assert all( + user.op == "call_function" and user.target == operator.getitem + for user in users + ) + registered_node = next(iter(users[0].users.keys())) + + if registered_node is not node.next: + registered_node.prepend(hook_getitem_node) + registered_node.prepend(node) + for getitem in users: + registered_node.prepend(getitem) + + def reorder_post_acc_grad_hook_nodes(self): + """ + Usage of AOTAutograd causes all the post_acc_grad_hook nodes to get + pushed to the end of the graph. This differs from eager mode, which + schedules them as soon as possible. This pass attempts to reorder the + graph to mimic eager behavior. + """ + post_acc_grad_hooks = [] + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "post_acc_grad_hook": + continue + post_acc_grad_hooks.append(node) + + # nodes in post_acc_grad_hooks are in topo order. For hooks registered + # to same node, we should keep their relative order + for node in reversed(post_acc_grad_hooks): + getitem_node = node.args[0] + param_node = node.args[1] # post_acc_grad_hook handle one param + + # find the corresponding acc_grad node + acc_grad_node = None + for n in list(param_node.users.keys()): + if n.op == "call_function" and n.target == call_accumulate_grad: + acc_grad_node = n + break + + assert acc_grad_node is not None, ( + "post_acc_grad_hook must have corresponding acc grad node" + ) + + # append post_acc_grad_hook after acc_grad node + acc_grad_node.append(getitem_node) + getitem_node.append(node) + + def reorder_post_hook_nodes(self): + """ + Usage of AOTAutograd causes all the post_hook nodes to get pushed to the + end of the graph. This differs from eager mode, which schedules them as + soon as possible. This pass attempts to reorder the graph to mimic eager + behavior. 
+ """ + post_hooks = [] + for node in self.fx_tracer.graph.find_nodes( + op="call_function", target=call_hook + ): + if node.kwargs.get("hook_type", None) != "post_hook": + continue + post_hooks.append(node) + + for node in reversed(post_hooks): + getitem_node = node.args[0] + output_nodes = node.args[1] + input_nodes = node.args[2] + + if len(output_nodes) > 0: + continue + + input_nodes_and_users = [] + input_nodes_and_users.extend(list(input_nodes)) + for input_node in input_nodes: + input_nodes_and_users.extend( + user + for user in list(input_node.users.keys()) + if not ( + user.op == "call_function" + and user.target == call_hook + and node.kwargs.get("hook_type", None) == "post_hook" + ) + ) + + arg = max(input_nodes_and_users) # last input users + if arg.op == "call_function" and arg.target == call_accumulate_grad: + param_node = arg.args[0] + post_acc_grad_hook_node = None + for n in list(param_node.users.keys()): + if ( + n.op == "call_function" + and n.target == call_hook + and n.kwargs.get("hook_type", None) == "post_acc_grad_hook" + ): + post_acc_grad_hook_node = n + + if post_acc_grad_hook_node is not None: + post_acc_grad_hook_node.append(getitem_node) + getitem_node.append(node) + continue + + if arg is not node.prev and not self.is_placeholder(arg): + arg.append(getitem_node) + getitem_node.append(node) + + def to_proxy(self, t): + if t is None: + return None + if isinstance(t, list): + return [self.to_proxy(x) for x in t] + if isinstance(t, tuple): + return tuple(self.to_proxy(x) for x in t) + if isinstance(t, (torch.SymInt, torch.SymFloat)): + return self.symnode_proxy_lookup[t.node] + if not isinstance(t, torch.Tensor): + # constant types like device, dtype, str + return t + proxy_tensor = fetch_object_proxy(self.fx_tracer, t) + assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) + return proxy_tensor.proxy + + def bind_objects_to_proxies( + self, objects, proxies, origins: Optional[list[tuple[int, str]]] = None + ): + if isinstance(proxies, torch.fx.Proxy): + if origins: + assert len(origins) == len(objects) + bound_proxies = [] + for i in range(len(objects)): + nodecall_index, node_name = origins[i] + self.set_node_origin(node_name, nodecall_index, None) + bound_proxies.append(proxies[i]) # type: ignore[index] + proxies = bound_proxies + else: + proxies = [proxies[i] for i in range(len(objects))] # type: ignore[index] + + assert len(objects) == len(proxies) + track_tensor_tree(objects, proxies, constant=None, tracer=self.fx_tracer) + return proxies + + def bind_backward_state(self, index: int): + assert self.hooks_proxy is not None + proxy = self.hooks_proxy[index] # type: ignore[index] + bw_state = BackwardState() + track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer) + return bw_state + + def set_node_origin( + self, + node_name: str, + nodecall_index: int, + pyobj: Optional[torch.autograd.Function], + ): + maybe_aot_id = "" + if pyobj is not None: + forward_cls = pyobj._forward_cls # type: ignore[attr-defined] + if hasattr(forward_cls, "_aot_id"): + # backward was created by AOT Dispatcher + if forward_cls._lazy_backward_info is None: + raise RuntimeError( + """This compiled backward function was saved by AOTAutogradCache, which does not support + compiled autograd. 
Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`.""" + ) + maybe_aot_id = forward_cls._aot_id + new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})" + raw_stack_trace = CapturedTraceback.extract().format()[-1] + new_stack_trace = raw_stack_trace.replace( + "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code + ) + set_stack_trace(new_stack_trace) + + +# state of the autograd engine dispatch, kept in sync by enable/disable context managers +compiled_autograd_enabled = False + +# global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" +compiled_autograd_enabled_force_eager = False + +# global flag to check if we are processing graphs produced from a compiled autograd graph +in_compiled_autograd_region = False + +active_disable_ctx = False + +depth = 0 + + +@contextlib.contextmanager +def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True): + # The entrypoint to enable CA. + # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather + # than using this context manager directly. If you are torch.compiling the corresponding + # forward pass, make sure they are wrapped under this context as well. + # + # Example: + # def train(model, inputs, target): + # compiled_model = torch.compile(model) + # pred = compiled_model(data) + # loss = compute_loss(pred, target) + # loss.backward() + # + # with _enable(compiler_fn): + # train(model, inputs, target) + # + # Inputs: + # - compiler_fn: The wrapper that will consume the compiled autograd graph, e.g. `torch.compile` + # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic. + # This doesn't affect the dynamic configuration of the compilation wrapper. + + if not ignore_active_disable_ctx and active_disable_ctx: + yield + else: + if dynamic: + assert type(dynamic) is bool + + from torch._dynamo import eval_frame + + if eval_frame._stance.stance == "force_eager": + # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd + # to fall back to eager as well. 
+ global compiled_autograd_enabled_force_eager + compiled_autograd_enabled_force_eager = True + try: + yield + finally: + compiled_autograd_enabled_force_eager = False + else: + # we need to import this, because user might not have imported it if they directly use this context manager + # we need to lazily import it, because of circular dependencies + import torch._inductor.cudagraph_trees + + ( + prior_compiler, + prior_dynamic, + ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler( + functools.partial(AutogradCompilerInstance, compiler_fn), dynamic + ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) # type:ignore[arg-type] + global compiled_autograd_enabled + compiled_autograd_enabled = True + global depth + prior_depth = depth + depth += 1 + try: + with torch.autograd.set_multithreading_enabled(False): + yield + finally: + if not prior_compiler: + compiled_autograd_enabled = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler( + prior_compiler, prior_dynamic + ) + depth -= 1 + assert depth == prior_depth, ( + "Nested Compiled Autograd Contexts must return before their parent context" + ) + + +@contextlib.contextmanager +def _disable(): + ( + prior_compiler, + prior_dynamic, + ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) + global compiled_autograd_enabled + compiled_autograd_enabled = False + global active_disable_ctx + if not active_disable_ctx: + active_disable_ctx = True + try: + yield + finally: + if prior_compiler: + compiled_autograd_enabled = True + active_disable_ctx = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler( + prior_compiler, prior_dynamic + ) + + +# return to starting state of a new process +def reset() -> None: + global compiled_autograd_enabled + compiled_autograd_enabled = False + assert not in_compiled_autograd_region + torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) + torch._C._dynamo.compiled_autograd.set_verbose_logger(None) + torch._C._dynamo.compiled_autograd.clear_cache() + global COMPILE_COUNTER + COMPILE_COUNTER = itertools.count() + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. +def copy_slices_prologue( + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, +): + grad = inputs[0] + result = grad.new_empty_strided(base_sizes, base_strides) + assert grad is not None + result.copy_(grad) + offset = view_storage_offset - base_storage_offset + grad_slice = result.as_strided(view_sizes, view_strides, offset) + return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)] + + +# Reimplementation of part of CopySlices::apply in Python. +# The shared code is really similar so we're not going to try to deduplicate. 
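+# Rough shape of how the two halves fit together (a sketch, not a verbatim call
+# site; the autograd engine drives this from C++):
+#
+#     result, grad_slice, grad_slice_clone = copy_slices_prologue(
+#         inputs, base_sizes, base_strides, base_storage_offset,
+#         view_sizes, view_strides, view_storage_offset)
+#     res = ...  # run the wrapped backward node on grad_slice_clone
+#     grad_inputs = copy_slices_epilogue(needs_input_grad, result, res, grad_slice)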
+def copy_slices_epilogue(needs_input_grad, result, res, grad_slice): + grad_inputs = [None] * len(needs_input_grad) + for i in range(len(needs_input_grad)): + if needs_input_grad[i]: + if res[i] is None: + continue + if i == 0: + grad_slice.copy_(res[i]) + grad_inputs[i] = result + else: + grad_inputs[i] = res[i] + return grad_inputs diff --git a/phivenv/Lib/site-packages/torch/_dynamo/comptime.py b/phivenv/Lib/site-packages/torch/_dynamo/comptime.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee4d81890a2815d9a38bb0e795f5e31deb2399d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/comptime.py @@ -0,0 +1,429 @@ +# mypy: allow-untyped-defs + +""" +This module provides the public comptime interface to TorchDynamo, enabling users to execute +arbitrary Python code during symbolic evaluation of their programs. + +The comptime interface allows inspection and modification of TorchDynamo's compilation +process while it is running. This can be useful for: + +- Debugging compilation issues +- Inspecting intermediate state +- Adding custom guards or graph breaks +- Analyzing symbolic shapes and values + +Example usage: + + import torch + from torch._dynamo.comptime import comptime + + def my_model(x): + # Print the compile-time known information about x + comptime.print(x) + + # Print the current FX graph being constructed + comptime.print_graph() + + # Force a value to be treated as static + if comptime(lambda ctx: ctx.get_local("x").is_dynamic()): + comptime.force_static(x) + + # Add a manual graph break + comptime.graph_break() + +Note: While this API provides significant flexibility, it intentionally avoids +exposing internal implementation details of TorchDynamo to maintain compatibility +across versions. +""" + +import builtins +import dis +import time +import traceback +from typing import Optional, Union + +import torch +from torch.fx.experimental.symbolic_shapes import free_symbols + +from .exc import unimplemented_v2 +from .variables import CellVariable +from .variables.constant import ConstantVariable +from .variables.tensor import SymNodeVariable + + +class ComptimeVar: + """ + A ComptimeVar represents a Python value, at some particular point + in time, in the Python code we are symbolically evaluating with + torchdynamo. This must be distinguished from a runtime value, as + at compile-time there are some properties of the variable we + do not know (for example, if the ComptimeVar represents a Tensor, + we only know metadata about the tensor; we do NOT know what the + actual data in the Tensor is.) + """ + + def __init__(self, v) -> None: + self.__variable = v + + def as_proxy(self): + """ + Returns an fx.Proxy (or tuple/list of fx.Proxy) representing + this variable in the FX graph we are assembling to pass + to the user compiler. + + This method only works for variables we actually track in + the FX graph, aka Tensors (and ints, if you are compiling + with dynamic shapes). In particular, if you have a list + or tuple of tensors, you will get a list/tuple of proxies + (not a single proxy representing the entire list/tuple). + """ + return self.__variable.as_proxy() + + def is_proxy(self): + """ + Returns True if as_proxy() would succeed. + """ + return self.__variable.is_proxy() + + def as_fake(self): + """ + Returns a "fake" value (either a FakeTensor or a SymInt) + representing the variable in question. This only works + for variables that denote Tensor or int. 
You can use + this to query metadata; e.g., v.as_fake().size(0) will + tell you the compile-time known size of the tensor. + + WARNING: Do NOT mutate the returned tensor. + """ + return self.__variable.as_proxy().node.meta["example_value"] + + def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: + """ + Returns the size of the tensor (if dim is None) or the size + at the dimension dim. The returned size may be a SymInt. + """ + return self.as_fake().size(dim) + + def python_type(self): + """ + Returns what type(v) would have returned for the variable + at compile time. + """ + return self.__variable.python_type() + + def as_python_constant(self): + """ + Returns the Python value this variable would have, but only if it is + completely known at compile-time (e.g., it is constant). + + WARNING: Do NOT mutate the returned constant. The returned constant + may or may not correspond to the actual value this variable may take + on at runtime; for example, if the variable in question is a constant + list, we may return a copy of that list. + """ + return self.__variable.as_python_constant() + + def is_python_constant(self): + """ + Returns True if as_python_constant would succeed. + """ + return self.__variable.is_python_constant() + + def is_dynamic(self): + if isinstance(self.__variable, SymNodeVariable): + fs = free_symbols(self.__variable.sym_num) + return bool(fs) + return False + + def force_static(self): + """ + Forces that a value is static, inducing a guard on its specific value + """ + if isinstance(self.__variable, SymNodeVariable): + self.__variable.evaluate_expr() + elif isinstance(self.__variable, ConstantVariable): + # TODO: Maybe complain if this isn't a int/bool/float variable + pass + else: + raise AssertionError( + f"cannot force {self.__variable} ({type(self.__variable)}) static" + ) + + def _i_will_not_complain_if_bc_breaks_VariableTracker(self): + """ + Returns the internal data structure VariableTracker that Dynamo uses + to represent variables at compile time. There are no BC guarantees on + this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on + it. + """ + return self.__variable + + def __repr__(self) -> str: + return self.__variable.debug_repr() + + # TODO: API for adding a custom guard + + +class ComptimeContext: + """ + This context class provides access to a public API for Dynamo's internals. + If there is something here you would find useful that is missing, please + file a feature request at https://github.com/pytorch/pytorch/ + """ + + def __init__(self, tx) -> None: + self.__tx = tx + + def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: + """ + Retrieve the compile-time known information about a local. + """ + tx = self.__get_tx(stacklevel) + var = tx.symbolic_locals[name] + + # Auto-dereference when accessing cell locals in python. + if isinstance(var, CellVariable): + return ComptimeVar(tx.output.side_effects.load_cell(var)) + + return ComptimeVar(var) + + def graph_break(self, msg="ComptimeContext.graph_break"): + """ + Manually trigger a graph break + """ + unimplemented_v2( + gb_type="ComptimeContext graph break", + context=msg, + explanation=f"Manually triggered ComptimeContext graph break with message {msg}.", + hints=[], + ) + + def graph(self): + """ + Retrieve the partially constructed FX graph that would be + passed to the user compiler after compilation. 
+ """ + return self.__tx.output.graph + + def assert_static(self, val): + """ + Asserts that the int is static (and not dynamic, per dynamic shapes) + """ + assert not val.is_dynamic(), ( + "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" + ) + + def print_graph(self, *, verbose=True, file=None): + """ + Print the partially constructed FX graph that would be passed + to the user compiler after compilation. + """ + print( + self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file + ) + + def parent(self): + return ComptimeContext(self.__tx.parent) + + def __get_tx(self, stacklevel): + tx = self.__tx + for _ in range(stacklevel): + tx = tx.parent + return tx + + def print(self, val, *, file=None): + print(repr(val), file=file) + + def print_disas(self, *, file=None, stacklevel=0): + """ + Print the current series of opcodes being executed (not including + parent frames), including where you are in the particular opcode + stream. + """ + tx = self.__get_tx(stacklevel) + print( + dis.Bytecode( + tx.f_code, + current_offset=tx.instructions[tx.instruction_pointer].offset, + ).dis(), + file=file, + ) + + def print_value_stack(self, *, file=None, stacklevel=0): + """ + Print the current Python value stack. Note that this is NOT the same + as the traceback; use print_bt() to print that. Note that at + stacklevel=0, this will typically be empty, as comptime cannot + currently be used in an expression context where there would be + intermediates on the stack. If you would find this useful, please + file a bug at https://github.com/pytorch/pytorch/ + + NB: Stack grows downwards in our print + """ + tx = self.__get_tx(stacklevel) + for s in tx.stack: + print(f"- {s.debug_repr()}", file=file) + + def print_locals(self, *, file=None, stacklevel=0): + """ + Print all of the locals available in the current context. + By default this view is very limited; you can get more information + about any individual local using get_local(). + """ + tx = self.__get_tx(stacklevel) + for k, v in tx.symbolic_locals.items(): + print(f"{k} = {v.debug_repr()}", file=file) + + def print_bt(self, *, file=None, stacklevel=0): + """ + Print the user code backtrace, starting at the beginning of the + frame Dynamo started evaluating. Note that this MAY NOT go all + the way to the torch.compile invocation, as we may have done + a graph break and are compiling an intermediate frame as the + starting point. If you think the other behavior would be better, + file a bug at https://github.com/pytorch/pytorch/ + """ + stack = [] + tx = self.__get_tx(stacklevel) + while tx is not None: + stack.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + print( + "".join(traceback.StackSummary.from_list(reversed(stack)).format()), + file=file, + ) + + def print_guards(self, *, file=None): + """ + Print the currently installed guards for the Dynamo context. + This does NOT include guards associated with variables that + may or may not be installed in the future if those variables + are used. + """ + # TODO: improve print format, current guard format is extremely + # verbose + print( + "\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)), + file=file, + ) + + def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): + """ + Returns the internal data structure InstructionTranslator that Dynamo + uses to track state of symbolic evaluation. There are no BC + guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if + you rely on it. 
+ """ + return self.__tx + + def sleep(self, sec): + time.sleep(sec) + + +class _Comptime: + @staticmethod + def __call__(fn, fallback_fn=lambda: None): + """fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise""" + fallback_fn() + + # Convenience wrappers that are more compact to use + + @staticmethod + def graph_break(): + comptime(lambda ctx: ctx.graph_break()) + + @staticmethod + def print(e): + comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e)) + + @staticmethod + def print_graph(): + comptime(lambda ctx: ctx.print_graph()) + + @staticmethod + def print_disas(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_disas( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_value_stack(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + # This is a more useful variant of print_value_stack that can be used + # in an expression context; e.g., x + print_value_stack_and_return(y + z), + # you will see x on the stack prior to the addition operation + @staticmethod + def print_value_stack_and_return(e, *, stacklevel=0): + comptime( + lambda ctx: ctx.print_value_stack( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + return e + + @staticmethod + def print_locals(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_locals( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_bt(*, stacklevel=0): + comptime( + lambda ctx: ctx.print_bt( + stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 + ) + ) + + @staticmethod + def print_guards(): + comptime(lambda ctx: ctx.print_guards()) + + @staticmethod + def assert_static(val): + comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) + + @staticmethod + def force_static(val): + comptime(lambda ctx: ctx.get_local("val").force_static()) + + @staticmethod + def breakpoint(): + """ + Like pdb breakpoint(), but drop into pdb whenever this line + of code is compiled by dynamo. Use it by putting + this in your model code:: + + from torch._dynamo.comptime import comptime + + comptime.breakpoint() + + And then, inside pdb, you can access 'ctx' to query things + about the compilation context:: + + (Pdb) !ctx.print_bt() + (Pdb) !ctx.print_locals() + (Pdb) p ctx.get_local("attention").as_fake() + """ + + def inner(inner_ctx): + ctx = inner_ctx.parent() # noqa: F841 + builtins.breakpoint() + + comptime(inner) + + @staticmethod + def sleep(sec): + comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant())) + + +comptime = _Comptime() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/config.py b/phivenv/Lib/site-packages/torch/_dynamo/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a48994c7c22021f905e11a58ae3fee9c52df03 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/config.py @@ -0,0 +1,631 @@ +# mypy: allow-untyped-defs + +""" +Configuration module for TorchDynamo compiler and optimization settings. 
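
Most of the settings below are plain module attributes. As an illustrative
sketch (``recompile_limit`` is just one example flag from this file), they can
be assigned directly or overridden temporarily via ``patch``::

    import torch._dynamo.config as dynamo_config

    dynamo_config.recompile_limit = 16            # process-wide assignment

    with dynamo_config.patch(recompile_limit=4):  # scoped override
        ...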
+ +This module contains various configuration flags and settings that control TorchDynamo's +behavior, including: +- Runtime behavior flags (e.g., guard settings, specialization options) +- Debugging and development options +- Performance tuning parameters +- Feature toggles for experimental features +""" + +import getpass +import os +import sys +import tempfile +from os.path import abspath, dirname +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union + +from torch._environment import is_fbcode +from torch.utils._config_module import Config, get_tristate_env, install_config_module + + +# to configure logging for dynamo, aot, and inductor +# use the following API in the torch._logging module +# torch._logging.set_logs(dynamo=, aot=, inductor) +# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity) +# see this design doc for more detailed info +# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# +# the name of a file to write the logs to +# [@compile_ignored: debug] +log_file_name: Optional[str] = None + +# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors +verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1" + +# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend +verify_correctness = False + +# need this many ops to create an FX graph (deprecated: not used) +minimum_call_count = 1 + +# turn on/off DCE pass (deprecated: always true) +dead_code_elimination = True + +# disable (for a function) when cache reaches this size + +# controls the maximum number of cache entries with a guard on same ID_MATCH'd +# object. It also controls the maximum size of cache entries if they don't have +# any ID_MATCH'd guards. +# [@compile_ignored: runtime_behaviour] +recompile_limit = 8 + +# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps +accumulated_recompile_limit = 256 + +# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit (deprecated: does not do anything) +skip_code_recursive_on_recompile_limit_hit = True + +# raise a hard error if cache limit is hit. If you are on a model where you +# know you've sized the cache correctly, this can help detect problems when +# you regress guards/specialization. This works best when recompile_limit = 1. +# This flag is incompatible with: suppress_errors. +# [@compile_ignored: runtime_behaviour] +fail_on_recompile_limit_hit = False + +cache_size_limit: int = Config(alias="torch._dynamo.config.recompile_limit") +accumulated_cache_size_limit: int = Config( + alias="torch._dynamo.config.accumulated_recompile_limit" +) + +# (deprecated: does not do anything) +skip_code_recursive_on_cache_limit_hit: bool = Config( + alias="torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit" +) +fail_on_cache_limit_hit: bool = Config( + alias="torch._dynamo.config.fail_on_recompile_limit_hit" +) + +# whether or not to specialize on int inputs. This only has an effect with +# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int +# inputs. Note that assume_static_by_default will also cause ints to get +# specialized, so this is mostly useful for export, where we want inputs +# to be dynamic, but accesses to ints should NOT get promoted into inputs. +specialize_int = False + +# Whether or not to specialize on float inputs. 
Dynamo will always promote +# float inputs into Tensor inputs, but at the moment, backends inconsistently +# support codegen on float (this is to be fixed). +specialize_float = False + +# legacy config, does nothing now! +dynamic_shapes = True + +use_lazy_graph_module = ( + os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1" +) + +# This is a temporarily flag, which changes the behavior of dynamic_shapes=True. +# When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic. +# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API +# see [Note - on the state of mark_dynamic] +assume_static_by_default = True + +# This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction +# with assume_static_by_default=True. +# With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail +# any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic. +automatic_dynamic_shapes = True + +# Valid options: "dynamic", "unbacked" +automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic" + +# This flag changes how the shapes of parameters are treated. +# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic +# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static, +# while the shapes of torch.Tensor are assumed to be dynamic. +force_parameter_static_shapes = True + +# This flag ensures that the shapes of a nn module are always assumed to be static +# If the flag is set to True, then the shapes of a nn.module are assumed to be static +# If the flag is set to False, then the shapes of a nn.module can be dynamic +force_nn_module_property_static_shapes = True + +# Typically, if you mark_dynamic a dimension, we will error if the dimension +# actually ended up getting specialized. This knob changes the behavior so +# that we don't error at all. This is helpful for our CI where I'm using a +# heuristic to mark batch dimensions as dynamic and the heuristic may get it +# wrong. +allow_ignore_mark_dynamic = False + +# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) +guard_nn_modules = True + +# Uses CPython internal dictionary tags to detect mutation. There is some +# overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag. +# guard_nn_modules unspecializes the nn module instance and adds guard for each +# relevant member of the nn modules. On the other hand, +# guard_nn_modules_using_dict_tags specializes on each nn module instance but +# uses low overhead dict version matching to detect mutations, obviating the +# need to guard on members of the nn modules. With +# guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required +# but kept around for debugging and discussing unspecializing nn module +# variables. +# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules) +# once we have reached stability for the guard_nn_modules_using_dict_tags. +guard_nn_modules_using_dict_tags = True + +# Flag to enable preparation for graph freezing, so that the named parameters and +# buffers are passed as params_flat in tracing context by AOT autograd. +# Non-Inductor backends can use this list for graph freezing. 
+prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1" + +# NOTE this has been deprecated, it does nothing now. +traceable_tensor_subclasses: set[type[Any]] = set() + +# If a tensor subclass is put into this set, Dynamo will model its instasnces in +# a very conservative and limited way (most likely causing lots of graph breaks +# if one apply tensor ops on these instances). This is useful if you encounter +# internal compiler errors from Dynamo which are caused by tensor subclasses, +# and you are willing to tolerate potential graph breaks rather than hard error. +nontraceable_tensor_subclasses: set[type[Any]] = set() + +# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. +# This is a good way to get your model to work one way or another, but you may +# lose optimization opportunities this way. Devs, if your benchmark model is failing +# this way, you should figure out why instead of suppressing it. +# This flag is incompatible with: fail_on_recompile_limit_hit. +suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) + +# Record and write an execution record of the current frame to a file +# if an exception is encountered +# @compile_ignored[debug] +replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1" + +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True + +# Disable dynamo +disable = os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1" + +# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo +cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False) + +# legacy config, does nothing now! +skipfiles_inline_module_allowlist: dict[Any, Any] = {} + +# If a string representing a PyTorch module is in this ignorelist, +# the `allowed_functions.is_allowed` function will not consider it +# when creating a list of PyTorch functions that will appear in +# FX IR. +allowed_functions_module_string_ignorelist = { + "torch.distributions", + "torch.testing", + "torch._refs", + "torch._prims", + "torch._decomp", +} + +# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"} +# None - Minifier is switched off +# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails +# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails +# [@compile_ignored: debug] +repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None) + +# Compiler compilation debug info +# 1: Dumps the original graph out to repro.py if compilation fails +# 2: Dumps a minifier_launcher.py if compilation fails. +# 3: Always dumps a minifier_launcher.py. Good for segfaults. +# 4: Dumps a minifier_launcher.py if the accuracy fails. +# [@compile_ignored: debug] +repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2)) + +# By default, we try to detect accuracy failure by running both forward +# and backward of a torchdynamo produced graph (if you are using repro_after +# 'dynamo'). This setting forces us to only test the forward graph and +# not the backward graph. 
This can be helpful if you're trying to debug +# an inference only problem, but the minifier seems to be choking on the +# backwards step +# TODO: Detect this situation automatically so the user doesn't need +# to manually configure this +# [@compile_ignored: debug] +repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1" + +# The tolerance we should use when testing if a compiled graph +# has diverged so that we should treat it as an accuracy failure +# [@compile_ignored: debug] +repro_tolerance = 1e-3 + + +# Whether to ignore non-floating point values when checking accuracy. +# Checking accuracy of non-floating point values such as boolean tensors +# can lead to false positives. +# [@compile_ignored: debug] +repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1" + +# If True, when testing if two models are the same, we will test them against +# a third fp64 reference and only report a problem if the RMSE relative to the +# fp64 is greater. However, this will use more memory; you may disable this +# if memory usage is too high. +# [@compile_ignored: runtime_behaviour] +same_two_models_use_fp64 = True + +# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type. +# When this flag is set to False, we introduce a graph break instead of capturing. +# This requires dynamic_shapes to be True. +capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1" + +# Not all backends support operators that have dynamic output shape (e.g., +# nonzero, unique). When this flag is set to False, we introduce a graph +# break instead of capturing. This requires dynamic_shapes to be True. +# If you set this to True, you probably also want capture_scalar_outputs +# (these are separated for historical reasons). +capture_dynamic_output_shape_ops = ( + os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1" +) + +# hybrid backed unbacked symints +prefer_deferred_runtime_asserts_over_guards = False + +# For complex dynamic shapes guards that we're unable to specify with dynamo/export's +# range constraints + dims + derived dims language, we raise constraint violation +# errors or specialize by default. If set to True, this flag avoids crashing/specialization, +# and allows complex guards as runtime assertions in the graph. +allow_complex_guards_as_runtime_asserts = False + +# By default, dynamo will treat all ints as backed SymInts, which means (1) it +# will wait to see the int change over multiple runs before generalizing and +# (2) it will still always 0/1 specialize an int. When true, this knob +# forces dynamo to treat _length_per_key and _offset_per_key on +# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that +# they (1) generalize immediately and (2) unsoundly never compare equal to +# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently +# compile this code; however, this can be useful for export. +force_unspec_int_unbacked_size_like_on_torchrec_kjt = False + +# Currently, Dynamo will always specialize on int members of NN module. +# However, there could be cases where this is undesirable, e.g., when tracking +# step count leading to constant recompilation and eventually eager fallback. +# Setting this flag to True will allow int members to be potentially unspecialized +# through dynamic shape mechanism. +# Defaults to False for BC. +allow_unspec_int_on_nn_module = False + +# Specify how to optimize a compiled DDP module. 
The flag accepts a boolean +# value or a string. There are 3 modes. +# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically +# split model graph into pieces to match DDP bucket sizes to allow DDP +# comm/compute overlap. +# 2. "python_reducer" (experimental): this optimization requires the usage +# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer +# and use the Python reducer to allow compiled_autograd to trace the +# communication and allow comm/compute overlap without graph-breaks. +# 3. "no_optimization" (or False): Dynamo won't split the model graph, nor +# will Python reducer be used. With this mode, there will be no graph-breaks +# and the original DDP C++ reducer will be used. There will no comm/compute +# overlap. This mode CANNOT be used with compiled_autograd. +# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be +# specified with a boolean value. True is using ddp_optimizer and False is +# no optimization. +optimize_ddp: Union[ + bool, + Literal[ + "ddp_optimizer", + "python_reducer", + "python_reducer_without_compiled_forward", + "no_optimization", + ], +] = True + +# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph. +# In some cases those asserts could be performance costly +# E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync. +# Setting this to True keeps them hinting to symbolic shapes engine, +# but not be emitted in the graph. +do_not_emit_runtime_asserts: bool = ( + os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1" +) + +# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS +skip_torchrec = True + +# Don't apply most trace_rules.py rules +dont_skip_tracing = False + +# No longer used +optimize_ddp_lazy_compile = False + +# Whether to skip guarding on FSDP-managed modules +skip_fsdp_guards = True +# Whether to apply torch._dynamo.disable() to FSDP2 hooks. +# Defaults to True. If Traceable FSDP2 is used, set this to False. +skip_fsdp_hooks = True + +# Make dynamo skip guarding on hooks on nn modules +# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them, +# dynamo will not notice and will execute whichever version you first compiled. +skip_nnmodule_hook_guards = True + +# Make dynamo skip no tensor aliasing guard on parameters +# Note: unsafe: if you compile a function with different parameters as inputs, +# and then later pass on the same parameter as two inputs, dynamo will not +# notice and lead to incorrect result. +skip_no_tensor_aliasing_guards_on_parameters = True + +# Considers a tensor immutable if it is one of the values of a dictionary, and +# the dictionary tag is same across invocation calls. +skip_tensor_guards_with_matching_dict_tags = True + +# If True, raises exception if TorchDynamo is called with a context manager +raise_on_ctx_manager_usage = True + +# If True, raise when aot autograd is unsafe to use +raise_on_unsafe_aot_autograd = False + +# This flag is ignored and maintained for backwards compatibility. +error_on_nested_jit_trace = True + +# If true, error with a better message if we symbolically trace over a +# dynamo-optimized function. If false, silently suppress dynamo. +error_on_nested_fx_trace = True + +# Disables graph breaking on rnn. YMMV with backends. +allow_rnn = False + +# If true, enables feature that captures PyTorch sparsity in the +# exported FX graph. 
This flag should become the default eventually +# and be removed, but currently provides a way to fall back to old +# graph breaking behavior. +capture_sparse_compute = False if is_fbcode() else True + +# If true, error if we try to compile a function that has +# been seen before. +# [@compile_ignored: runtime_behaviour] +error_on_recompile = False + +# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything) +report_guard_failures = True + +# [@compile_ignored: debug] root folder of the project +base_dir = dirname(dirname(dirname(abspath(__file__)))) + +# Trace through NumPy or graphbreak +trace_numpy = True + +# Default NumPy dtypes when tracing with torch.compile +# We default to 64bits. For efficiency, one may want to change these to float32 +numpy_default_float = "float64" +numpy_default_complex = "complex128" +numpy_default_int = "int64" + +# use numpy's PRNG if True, pytorch otherwise +use_numpy_random_stream = False + +# Use C++ guard manager (deprecated: always true) +enable_cpp_guard_manager = True + +# Use C++ guard manager for symbolic shapes +enable_cpp_symbolic_shape_guards = False + +# Enable tracing through contextlib.contextmanager +enable_trace_contextlib = True + +# Enable tracing through unittest +enable_trace_unittest = False + +# Enable tracing generator functions lazily. If False, Dynamo will exhaust +# generators upon first execution. And if True, the generator will be accessed lazily +enable_faithful_generator_behavior = True + +# Inline inbuilt nn modules +inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated] + default=True, + justknob="pytorch/compiler:inline_inbuilt_nn_modules", +) + +# Install "free" tensor variables (globals, non-locals, nn module attributes) +# as graph attributes. This is useful for export, as it +# produces a consistent number of inputs to the graph. +install_free_tensors = False + +# Use C++ FrameLocalsMapping (raw array view of Python frame fastlocals) (deprecated: always True) +enable_cpp_framelocals_guard_eval = True + +# Whether to automatically find and replace identical graph +# regions with a call to invoke_subgraph +use_graph_deduplication = False + +# Whether to track nodes for deduplication (testing only) +# This flag is ignored if use_graph_deduplication is True +track_nodes_for_deduplication = False + +# Whether to lint the graph after each region is replaced +# (Debug) +graph_deduplication_lint = False + +# Issues a warning in Python 3.13.0 for possibly slower guard evaluation and +# instructs user to attempt using 3.13.1+, where the CPython bug is fixed. +# Should be disabled in dynamo-wrapped tests since some tests check that no warnings are issued. +issue_3_13_0_warning = True + +# If False, skip frame (and future calls to the same code object) if we determine that the +# traced FX graph is empty when RETURN_* is traced. +allow_empty_graphs = False + +# When set, total compile time instruction count is recorded using +# torch._dynamo.utilsCompileTimeInstructionCounter. 
+record_compile_time_instruction_count = False + + +def default_debug_dir_root(): + # [@compile_ignored: debug] + DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" + if DEBUG_DIR_VAR_NAME in os.environ: + return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug") + elif is_fbcode(): + return os.path.join( + tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug" + ) + else: + return os.path.join(os.getcwd(), "torch_compile_debug") + + +# [@compile_ignored: debug] +debug_dir_root = default_debug_dir_root() + +# [@compile_ignored: debug] +_save_config_ignore = { + "repro_after", + "repro_level", + # workaround: "cannot pickle PyCapsule" + "constant_functions", + # workaround: "cannot pickle module" + "skipfiles_inline_module_allowlist", +} + +# for backend="cudagraphs", mutations on input be sent to the cudagraph backend +# or replayed in aot_autograd epilogue. default is False because mutation on inputs +# can prevent cudagraphing. +cudagraph_backend_keep_input_mutation = False + +# enable cudagraph support for mutated inputs from prior cudagraph pool +cudagraph_backend_support_input_mutation = False + +# When True, only ops that have the torch.Tag.pt2_compliant tag +# will be allowed into the graph; all other ops will be disallowed +# and will fall back to eager-mode PyTorch. Useful to ensure +# correctness of custom ops. +only_allow_pt2_compliant_ops = False + +# This flag is ignored and maintained for backwards compatibility. +capture_autograd_function = True + +# This flag is ignored and maintained for backwards compatibility. +capture_func_transforms = True + +# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode). +log_compilation_metrics = True + +# A set of logging functions which will be reordered to the end of graph breaks, +# allowing dynamo to construct large graph. Note that there are some +# limitations to this, such as how it does not correctly print objects that were +# mutated after the print statement. +reorderable_logging_functions: set[Callable[[Any], None]] = set() + +# A set of methods that will be ignored while tracing, +# to prevent graph breaks. +# Add logging.Logger. to ignore all calls for method, +# or logger. to ignore calls for method from this logger instance only. +ignore_logger_methods: set[Callable[..., Any]] = set() + +# simulates what would happen if we didn't have support for BUILD_SET opcode, +# used for testing +inject_BUILD_SET_unimplemented_TESTING_ONLY = False + +_autograd_backward_strict_mode_banned_ops = [ + "layout", + "is_neg", + "is_conj", + "is_pinned", +] + +_autograd_backward_strict_mode_conditional_banned_ops = [ + "stride", + "storage_offset", + "is_contiguous", +] + +# Enables caching of dispatches to fake tensors. +fake_tensor_cache_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1" +) + +# Enables cross checking between the fake tensor cache and dispatch. +fake_tensor_cache_crosscheck_enabled = ( + os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1" +) + +# Disables inference mode for fake tensor prop during compilation. At runtime, +# the inference_mode is still respected. +fake_tensor_disable_inference_mode = True + +# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile(). +# Note: AOTAutograd will still trace and partition an AOT backward graph local to that +# compiled region. 
But AOTAutograd traces without knowledge of backward hooks which are +# coordinated by the Autograd engine, and under the hood, it uses the torch.autograd.grad +# API, so it cannot capture gradient accumulation operations (AccumulateGrad). +# +# Compiled Autograd will trace all autograd operations as seen by the Autograd engine. +# This flag will also lift certain restrictions during the forward trace such as +# registering backward hooks on tensors contained within the compiled region. +compiled_autograd = False + +# Overrides torch.compile() kwargs for Compiled Autograd: +compiled_autograd_kwargs_override: dict[str, Any] = {} + +# Enables use of collectives *during* compilation to synchronize behavior +# across ranks. Today, this is used solely to modify automatic_dynamic_shapes +# behavior, making it so that we infer that if an input is dynamic by +# inspecting whether or not its input size varies across ranks. Because +# this synchronization uses collectives, all ranks must run compilation at +# the same time; ranks must not diverge with graph breaks. This can be most +# reliably achieved by ensuring PT2 only is run on SPMD programs. If this +# invariant is inviolated, you will likely deadlock NCCL and encounter a +# NCCL timeout. +enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1" + +# Enables a local, filesystem "profile" which can be used for automatic +# dynamic decisions, analogous to profile-guided optimization. This config +# ONLY has an effect if torch.compiler.config.workflow_id is specified, +# which specifies the name of the profile we will save/load. +# +# The idea is that if we observe that a particular input is dynamic over +# multiple iterations on one run, we can save a profile with this information +# so the next time we run we can just make it dynamic the first time around, +# skipping an unnecessary static compilation. The profile can be soundly +# stale, if it is wrong, it just means we may make more things dynamic than +# was actually necessary (NB: this /can/ cause a failure if making something +# dynamic causes the compiler to stop working because you tickled a latent +# bug.) +# +# The profile is ONLY guaranteed to work if the user source code is 100% +# unchanged. Applying the profile if there are user code changes is only +# best effort otherwise. In particular, we identify particular code objects +# by filename, line number and name of their function, so adding/removing newlines +# will typically cause cache misses. We continuously update the profile, +# so if we only discover something is dynamic on the second run, we will update +# the profile for subsequent runs. +automatic_dynamic_local_pgo: bool = Config( + justknob="pytorch/remote_cache:enable_local_automatic_dynamic_pgo", + env_name_force="TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO", + default=True, +) + +# Like above, but using remote cache +automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env( + "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO" +) + +# temporary config to kill later +_unsafe_skip_fsdp_module_guards = ( + os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1" +) + +# Run GC at the end of compilation +run_gc_after_compile = Config( # type: ignore[var-annotated] + default=True, + justknob="pytorch/compiler:enable_run_gc_after_compile", + env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE", +) + +# Takes the function/module decorated with torch.compile and passes it through a +# wrapper. 
This ensures that nn.module hooks are also compiled in the same frame. +wrap_top_frame = False + +# Flag to record runtime overhead in profile traces. Used for pre-graph bytecode +# and AOTAutograd runtime wrapper. +record_runtime_overhead = True + +# HACK: this is for testing custom ops profiling only +_custom_ops_profile: Optional[Any] = None + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + def _make_closure_patcher(**changes): ... + + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/convert_frame.py b/phivenv/Lib/site-packages/torch/_dynamo/convert_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..d114ab8dd449433eadbe7fc82aac31d1c401e5fa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/convert_frame.py @@ -0,0 +1,1503 @@ +# mypy: allow-untyped-decorators + +""" +This module implements TorchDynamo's core frame conversion functionality, transforming Python +frames into FX graphs. It handles: + +- Frame analysis and bytecode transformation +- Guard creation and management for dynamic behaviors +- Cache management for recompilation +- Error handling and fallback mechanisms + +Key classes: +- ConvertFrame: Main entry point for frame conversion with error handling +- ConvertFrameAssert: Implements core frame to graph conversion logic +- Tracker: Tracks input/output code objects during conversion +- CatchErrorsWrapper: Provides error handling and suppression logic + +The conversion process preserves program semantics while enabling optimizations +through torch.compile() and related systems. +""" + +from __future__ import annotations + +import collections +import contextlib +import cProfile +import dis +import functools +import gc +import itertools +import logging +import os +import pstats +import random +import subprocess +import sys +import threading +import time +import traceback +import typing +import weakref +from pathlib import Path +from types import CellType, CodeType, FunctionType, ModuleType +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec +from weakref import ReferenceType + +import torch +import torch._logging +from torch._C._dynamo.guards import GlobalStateGuard +from torch._dynamo.callback import CallbackTrigger +from torch._dynamo.distributed import get_compile_pg +from torch._dynamo.symbolic_convert import TensorifyState +from torch._guards import compile_context, CompileContext, CompileId, tracing +from torch._logging import structured +from torch._utils_internal import ( + compile_time_strobelight_meta, + justknobs_check, + maybe_upload_prof_stats_to_manifold, + signpost_event, +) +from torch.fx._lazy_graph_module import _use_lazy_graph_module +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + GuardOnDataDependentSymNode, +) +from torch.fx.graph_module import _forward_from_src as original_forward_from_src +from torch.monitor import _WaitCounter +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils._python_dispatch import ( + _disable_current_modes, + is_in_torch_dispatch_mode, +) +from torch.utils._traceback import CapturedTraceback, format_traceback_short + +from . 
import config, decorators, exc, graph_break_hints, trace_rules +from .bytecode_analysis import remove_dead_code, remove_pointless_jumps +from .bytecode_transformation import ( + check_inst_exn_tab_entries_valid, + Instruction, + is_generator, + propagate_inst_exn_table_entries, + transform_code_object, +) +from .cache_size import ( + CacheSizeRelevantForFrame, + compute_cache_size, + exceeds_recompile_limit, + is_recompilation, +) +from .eval_frame import ( + always_optimize_code_objects, + dynamo_tls, + skip_code, + TorchPatcher, +) +from .exc import ( + augment_exc_message, + BackendCompilerFailed, + FailOnRecompileLimitHit, + format_error_msg, + InternalTorchDynamoError, + PackageError, + RecompileLimitExceeded, + ShortenTraceback, + SkipCodeRecursiveException, + TorchRuntimeError, + UncapturedHigherOrderOpError, + unimplemented_v2, + Unsupported, +) +from .guards import ( + CheckFunctionManager, + get_and_maybe_log_recompilation_reasons, + GuardedCode, +) +from .hooks import Hooks +from .pgo import log_frame_dynamic_whitelist, put_code_state +from .replay_record import ExecutionRecord +from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX +from .symbolic_convert import ( + DistributedState, + ExceptionStack, + InstructionTranslator, + LocalState, + SpeculationLog, +) +from .trace_rules import is_numpy +from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code +from .utils import ( + chromium_event_timed, + CleanupManager, + CompileTimeInstructionCounter, + counters, + dynamo_timed, + format_bytecode, + gen_record_file_name, + get_metrics_context, + increment_frame, + is_namedtuple, + istype, + LazyString, + maybe_disable_inference_mode, + maybe_disable_inference_mode_for_fake_prop, + orig_code_map, + reset_graph_break_dup_checker, + setup_compile_debug, + to_int_us, + troubleshooting_url, + write_record_to_file, +) +from .variables.torch_function import torch_function_mode_stack_state_mgr + + +np: Optional[ModuleType] +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if typing.TYPE_CHECKING: + from .backends.registry import CompilerFn + from .package import CompilePackage + from .repro.after_dynamo import WrapBackendDebug + from .types import BytecodeHook, CacheEntry, DynamoFrameType + from .variables.builder import FrameStateSizeEntry + + +log = logging.getLogger(__name__) +bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode") +graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") + + +compile_lock = threading.RLock() + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +class TODO_UNKNOWN: + pass + + +class Tracker: + def __init__(self) -> None: + self.seen: list[ReferenceType[CodeType]] = [] + self.seen_ids: set[int] = set() + + def add(self, strong_obj: CodeType) -> None: + idx = id(strong_obj) + if idx not in self.seen_ids: + obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx)) + self.seen.append(obj) + self.seen_ids.add(idx) + + def __contains__(self, item: CodeType) -> bool: + return id(item) in self.seen_ids + + def clear(self) -> None: + self.seen.clear() + self.seen_ids.clear() + + +input_codes = Tracker() +output_codes = Tracker() + +initial_global_state: Optional[GlobalStateGuard] = None + + +@functools.wraps(original_forward_from_src) +def fx_forward_from_src_skip_result( + src: str, globals: dict[str, Any], co_fields: Optional[dict[str, str]] = None +) -> FunctionType: + # we monkey patch FX to prevent infinite loop of trying to convert + # our generated code + 
result = original_forward_from_src(src, globals, co_fields) + skip_code(result.__code__) + return result + + +def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Context manager to: + 1) Save/restore torch.is_grad_enabled() state + 2) Save/restore python random state + 3) Save/restore torch random state + 4) Monkey patch torch.fx.graph_module._forward_from_src + """ + + @functools.wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + guards = GlobalStateGuard() + prior_grad_mode = torch.is_grad_enabled() + + # Just in case we get left in a bad dispatch state we want to restore + # it. This can happen because the dispatch bits aren't a true + # stack/counter - so we can't just increment/decrement them as we enter + # and leave. + with ( + torch._C._PreserveDispatchKeyGuard(), + maybe_disable_inference_mode(), + maybe_disable_inference_mode_for_fake_prop(), + ): + prior_inference_mode = torch.is_inference_mode_enabled() + prior_deterministic = torch.are_deterministic_algorithms_enabled() + prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled() + prior_mobile_allocator_state = ( + torch._C._is_default_mobile_cpu_allocator_set() + ) + py_rng_state = random.getstate() + prior_dtype = torch.get_default_dtype() + torch_rng_state = torch.random.get_rng_state() + cuda_rng_state = None + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + allow_tf32 = torch._C._get_cublas_allow_tf32() + prior_fwd_from_src = torch.fx.graph_module._forward_from_src + torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result + cleanup = setup_compile_debug() + exit_stack = contextlib.ExitStack() + exit_stack.enter_context( + torch.fx._symbolic_trace._maybe_revert_all_patches() + ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) + try: + return fn(*args, **kwargs) + finally: + cleanup.close() + assert torch._C._len_torch_function_stack() == 0, ( + "Torch function mode stack state changed while dynamo tracing, please report a bug" + ) + exit_stack.close() + torch._C._set_grad_enabled(prior_grad_mode) + torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) + torch.use_deterministic_algorithms( + prior_deterministic, warn_only=prior_warn_only + ) + random.setstate(py_rng_state) + torch.random.set_rng_state(torch_rng_state) + torch.set_default_dtype(prior_dtype) + curr_mobile_allocator_state = ( + torch._C._is_default_mobile_cpu_allocator_set() + ) + if prior_mobile_allocator_state != curr_mobile_allocator_state: + torch._C._unset_default_mobile_cpu_allocator() + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + torch._C._set_cublas_allow_tf32(allow_tf32) + torch.fx.graph_module._forward_from_src = prior_fwd_from_src + assert guards.check(), ( + f"Global {guards.reason()}state changed while dynamo tracing, please report a bug" + ) + + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + return _fn + + +@TorchPatcher.suppress_torch_distributed_warnings +def has_tensor_in_frame(frame: DynamoFrameType) -> bool: + """Check if the frame has torch.* related bits""" + # Check if the function was decorated using torch._dynamo.optimize + if frame.f_code in always_optimize_code_objects: + return True + + # Check if there is global import of torch.* + for co_name in frame.f_code.co_names: + if co_name in frame.f_globals: + obj = frame.f_globals[co_name] + if isinstance(obj, ModuleType) and ( + obj.__name__.startswith("torch.") or obj is torch + ): + return True + # ... 
or a global import of numpy.* + if np and config.trace_numpy and (obj is np or is_numpy(obj)): + return True + + seen_ids: dict[int, bool] = {} + + def has_tensor(obj: object) -> bool: + """Recursively check if the obj has a tensor""" + obj_id = id(obj) + if obj_id in seen_ids: + return seen_ids[obj_id] + seen_ids[obj_id] = False + + if isinstance(obj, (torch.Tensor, torch.nn.Module)) or ( + istype(obj, type) and issubclass(obj, torch.nn.Module) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif ( + config.trace_numpy + and np + and (istype(obj, np.ndarray) or isinstance(obj, np.generic)) + ): + seen_ids[obj_id] = True + return seen_ids[obj_id] + elif istype(obj, (list, tuple)): + seen_ids[obj_id] = any(has_tensor(v) for v in obj) + return seen_ids[obj_id] + elif istype(obj, dict): + # Some packages like pytest can be updated during runtime. So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any(has_tensor(v) for v in values) + return seen_ids[obj_id] + elif istype(obj, (str, int, float, type(None), bool)): + seen_ids[obj_id] = False + return seen_ids[obj_id] + elif is_namedtuple(obj) and hasattr(obj, "_fields"): + seen_ids[obj_id] = any(has_tensor(getattr(obj, v)) for v in obj._fields) + return seen_ids[obj_id] + else: + # if config.debug: + # print( + # f"Assuming that object of type {type(obj)} does not have a tensor" + # ) + return False + + # Check if the passed arguments are of type Tensor + for value in frame.f_locals.values(): + if has_tensor(value): + return True + + log.debug( + "skipping because no torch.* %s \ + %s %s", + frame.f_code.co_name, + frame.f_code.co_filename, + frame.f_code.co_firstlineno, + ) + + return False + + +def exception_handler( + e: Exception, + code: CodeType, + frame: Optional[DynamoFrameType] = None, + export: bool = False, +) -> None: + record_filename = None + if hasattr(e, "exec_record"): + record_filename = gen_record_file_name(e, code) + write_record_to_file(record_filename, e.exec_record) + e.record_filename = record_filename # type: ignore[attr-defined] + + augment_exc_message(e, export=export) + + +FRAME_COUNTER = 0 +FRAME_COMPILE_COUNTER: typing.Counter[Union[int, FrameStateSizeEntry]] = ( + collections.Counter() +) + + +def maybe_cprofile(func: Callable[_P, _T]) -> Callable[_P, _T]: + if config.cprofile: + return cprofile_wrapper(func) + return func + + +def cprofile_wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + trace_id = CompileContext.current_trace_id() + assert trace_id, "Trace id is None" + profile_path = Path( + f"/tmp/{func.__name__}_{str(trace_id).replace('/', '_')}.profile" + ) + prof = cProfile.Profile() + try: + prof.enable() + start_ts = time.time() + retval = prof.runcall(func, *args, **kwargs) + profile_latency = time.time() - start_ts + prof.disable() + except ValueError: + log.exception("failed to enable cProfile") + profile_latency = 0 + retval = func(*args, **kwargs) + log.warning( + "### Cprofile for %s trace id [%s] took %.3f seconds ###", + func.__name__, + trace_id, + profile_latency, + ) + ps = pstats.Stats(prof) + try: + prof.dump_stats(profile_path) + except OSError: + log.exception("Cannot write to %s", profile_path) + log.warning("Raw profile at %s", profile_path) + svg_path = profile_path.with_suffix(".svg") + try: + gprof2dot_process = subprocess.Popen( + [ + "gprof2dot", + "-f", + "pstats", + 
"--node-label=total-time-percentage", + "--node-label=self-time-percentage", + "--node-label=total-time", + str(profile_path), + ], + stdout=subprocess.PIPE, + ) + subprocess.check_call( + ["dot", "-Tsvg", "-o", str(svg_path)], + stdin=gprof2dot_process.stdout, + ) + log.warning("Generated SVG from profile at %s", svg_path) + except FileNotFoundError: + log.warning( + "Failed to generate SVG from profile -- dumping stats instead." + "Try installing gprof2dot and dot for a better visualization" + ) + ps.sort_stats(pstats.SortKey.TIME).print_stats(20) + ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) + + if manifold_link := maybe_upload_prof_stats_to_manifold( + str(profile_path) + ): # fb-only + torch._logging.trace_structured( + "link", + lambda: {"name": "cprofile_manifold_url", "url": manifold_link}, + ) + return retval + + return profile_wrapper + + +class ConvertFrameAssert: + def __init__( + self, + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, + ) -> None: + # assert export_constraints is None + reset_graph_break_dup_checker() + self._torchdynamo_orig_callable = compiler_fn + self._one_graph = one_graph + self._export = export + self._export_constraints = export_constraints + self._package = package + + @property + def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: + return lambda backend: convert_frame_assert( + backend, + self._one_graph, + self._export, + self._export_constraints, + ) + + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: dict[str, Union[int, FrameStateSizeEntry]], + *, + skip: int = 0, + ) -> ConvertFrameReturn: + increment_frame() + + code = frame.f_code + + cache_size = compute_cache_size(frame, cache_entry) + input_codes.add(code) + if code in output_codes: + return ConvertFrameReturn() + if ( + os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") + and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name + ): + return ConvertFrameReturn() + if code.co_name == "" and code.co_filename.endswith( + ( + "transformers/file_utils.py", + "transformers/utils/generic.py", + "diffusers/utils/outputs.py", + ) + ): + # not needed, but cleans up torchbench error stats + return ConvertFrameReturn() + if code.co_name == "__setattr__": + # setattr could be tricky to handle generally, + # but also not likely useful to compile- skip the whole frame + return ConvertFrameReturn() + if code.co_name == "__init__" and code.co_filename.startswith( + os.path.dirname(torch.optim.__file__) + ): + # optimizer support is still incomplete see + # test_state_dict in test/dynamo/test_optimizers.py + return ConvertFrameReturn() + + # Check if the frame is generated by an exec builtin call + # TODO - Running exec generated frame seems propagates f_globals to the + # next frames. + if code.co_name == "" and code.co_filename == "": + return ConvertFrameReturn() + + if ( + code.co_name == "" + and code.co_filename == "" + and not bool(frame.f_builtins) + ): + # namedtuple subclass constructor. Empty builtins cause issue with + # len keyword in LIST_LEN guard. 
+ return ConvertFrameReturn() + + if is_generator(code): + unimplemented_v2( + gb_type="Attempt to trace generator", + context="", + explanation="Generators cannot be compiled directly with `torch.compile`.", + hints=[ + "Call a generator from inside of a non-generator Python function and " + "compile that function instead.", + *graph_break_hints.FUNDAMENTAL, + ], + ) + + if not has_tensor_in_frame(frame): + return ConvertFrameReturn() + + # skip tracing non-recursive disabled functions + # detect if the previous frame (non-convert_frame) is a non-recursive disable wrapper + prev_frame = sys._getframe() + while ( + prev_frame + and "torch/_dynamo/convert_frame.py" in prev_frame.f_code.co_filename + ): + prev_frame = prev_frame.f_back # type: ignore[assignment] + if ( + prev_frame + and prev_frame.f_code is decorators._nonrecursive_disable_wrapper_code + ): + return ConvertFrameReturn(apply_to_code=False) + + global initial_global_state + initial_global_state = GlobalStateGuard() + + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + assert isinstance(frame_id, int) + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compiled_autograd_id = None + if prior := CompileContext.current_compile_id(): + compiled_autograd_id = prior.compiled_autograd_id + compile_id = CompileId( + compiled_autograd_id=compiled_autograd_id, + frame_id=frame_id, + frame_compile_id=frame_compile_id, + ) + + signpost_event( + "dynamo", + "_convert_frame_assert._compile", + { + "co_name": code.co_name, + "frame_id": frame_id, + "compile_id": str(compile_id), + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + }, + ) + + # Record traced frames, skipping Dynamo generated ones. + if not code.co_name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX): + info = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}" + dynamo_tls.traced_frame_infos.append(info) + + with compile_context(CompileContext(compile_id)): + return _compile( + frame.f_code, + frame.f_globals, + frame.f_locals, + frame.f_builtins, + frame.closure, + self._torchdynamo_orig_callable, + self._one_graph, + self._export, + self._export_constraints, + hooks, + cache_entry, + cache_size, + frame, + frame_state=frame_state, + compile_id=compile_id, + skip=skip + 1, + package=self._package, + ) + + +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, +) -> ConvertFrameAssert: + """Fully convert a frame into an FX graph""" + return ConvertFrameAssert( + compiler_fn, one_graph, export, export_constraints, package + ) + + +from collections import OrderedDict + +from torch.utils.hooks import RemovableHandle + + +if typing.TYPE_CHECKING: + from .output_graph import OutputGraph + +# we have to use `OrderedDict` to make `RemovableHandle` work. +_bytecode_hooks: dict[int, BytecodeHook] = OrderedDict() + + +def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: + """Register hooks for bytecode generated by Dynamo. The hook can do some + logging, as well as return a new code object to be used. Please refer + to `BytecodeHook` for the hook signature. 
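    A minimal illustrative sketch (``log_hook`` is a hypothetical callback):
    the hook receives the original and the Dynamo-transformed code objects and
    may return a replacement code object, or ``None`` to keep the transformed
    one::

        import dis

        def log_hook(original_code, new_code):
            dis.dis(new_code)   # inspect the generated bytecode
            return None         # keep Dynamo's output unchanged

        handle = register_bytecode_hook(log_hook)
        # ... later, deregister the hook
        handle.remove()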
+ """ + handle = RemovableHandle(_bytecode_hooks) + _bytecode_hooks[handle.id] = hook + return handle + + +def _compile( + code: CodeType, + globals: dict[str, object], + locals: dict[str, object], + builtins: dict[str, object], + closure: tuple[CellType], + compiler_fn: CompilerFn, + one_graph: bool, + export: bool, + export_constraints: Optional[typing.Never], + hooks: Hooks, + cache_entry: Optional[CacheEntry], + cache_size: CacheSizeRelevantForFrame, + frame: Optional[DynamoFrameType] = None, + frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None, + *, + compile_id: CompileId, + skip: int = 0, + package: Optional[CompilePackage] = None, +) -> ConvertFrameReturn: + from torch.fx.experimental.validator import ( + bisect, + BisectValidationException, + translation_validation_enabled, + ValidationException, + ) + + # Only nonlocal defs here please! + # Time spent compiling this frame before restarting or failing analysis + dynamo_time_before_restart: float = 0.0 + output: Optional[OutputGraph] = None + tracer: Optional[InstructionTranslator] = None + + tf_mode_stack: list[torch.overrides.TorchFunctionMode] = ( + torch.overrides._get_current_function_mode_stack() + ) + + @preserve_global_state + def transform( + instructions: list[Instruction], code_options: dict[str, object] + ) -> None: + nonlocal output + nonlocal tracer + speculation_log.restart() # type: ignore[has-type] + exn_vt_stack = ExceptionStack() + tracer = InstructionTranslator( + instructions, + code, + locals, + globals, + builtins, + closure, + tf_mode_stack, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + frame_state=frame_state, + speculation_log=speculation_log, # type: ignore[has-type] + exn_vt_stack=exn_vt_stack, + distributed_state=distributed_state, # type: ignore[has-type] + package=package, + ) + + try: + tracer.output.mark_bytecode_tracing_start() + with tracing(tracer.output.tracing_context), tracer.set_current_tx(): + tracer.run() + except exc.UnspecializeRestartAnalysis: + speculation_log.clear() # type: ignore[has-type] + raise + except ( + exc.SpeculationRestartAnalysis, + exc.TensorifyScalarRestartAnalysis, + exc.SkipFrame, + ): + raise + except Exception: + if translation_validation_enabled(): + bisect(tracer.output.shape_env) + raise + finally: + tracer.output.call_cleanup_hooks() + + output = tracer.output + assert output is not None + assert output.output_instructions + instructions[:] = output.output_instructions + code_options.update(output.code_options) + propagate_inst_exn_table_entries(instructions) + check_inst_exn_tab_entries_valid(instructions) + instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) + + @compile_time_strobelight_meta(phase_name="compile_inner") + def compile_inner( + code: CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[list[Instruction], dict[str, Any]], Any], + ) -> ConvertFrameReturn: + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._dynamo.callback_handler.install_callbacks( + CallbackTrigger.DYNAMO, str(CompileContext.current_compile_id()) + ) + ) + stack.enter_context(CompileTimeInstructionCounter.record()) + return _compile_inner(code, one_graph, hooks, transform) + + return ( + ConvertFrameReturn() + ) # dead, but see https://github.com/python/mypy/issues/7577 + + @maybe_cprofile + def _compile_inner( + code: CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[list[Instruction], dict[str, Any]], Any], + ) -> ConvertFrameReturn: + nonlocal 
dynamo_time_before_restart + last_attempt_start_time = start_time = time.time() + + def log_bytecode( + prefix: str, name: str, filename: str, line_no: int, code: CodeType + ) -> None: + if bytecode_log.isEnabledFor(logging.DEBUG): + bytecode_log.debug( + format_bytecode(prefix, name, filename, line_no, code) + ) + + log_bytecode( + "ORIGINAL BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + code, + ) + + out_code = None + for attempt in itertools.count(): + CompileContext.get().attempt = attempt + try: + with dynamo_timed( + f"compile_attempt_{attempt}", log_pt2_compile_event=True + ): + out_code = transform_code_object(code, transform) + break + except exc.RestartAnalysis as e: + if not isinstance(e, exc.TensorifyScalarRestartAnalysis): + TensorifyState.clear() + log.info( + "Restarting analysis due to %s", + LazyString(format_traceback_short, e.__traceback__), + ) + # If restart reason is None just log the type of the exception + restart_reasons.add(e.restart_reason or str(type(e))) + # We now have a new "last attempt", reset the clock + last_attempt_start_time = time.time() + if attempt > 100: + unimplemented_v2( + gb_type="Excessive RestartAnalysis() calls", + context="", + explanation="Dynamo attempted to trace the same frame 100+ times. " + "Giving up on compiling as the compile time tradeoff is likely not " + "worth the performance gain.", + hints=[], + ) + except exc.SkipFrame as e: + if not isinstance(e, exc.TensorifyScalarRestartAnalysis): + TensorifyState.clear() + log.debug( + "Skipping frame %s %s \ + %s %s", + e, + code.co_name, + code.co_filename, + code.co_firstlineno, + ) + if one_graph: + log.debug("No graph captured with one_graph=True") + return ConvertFrameReturn() + + assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type] + "compiler collective wasn't run before compilation completed" + ) + + assert out_code is not None + log_bytecode( + "MODIFIED BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + out_code, + ) + + for hook in _bytecode_hooks.values(): + hook_output = hook(code, out_code) + if hook_output is not None: + out_code = hook_output + + orig_code_map[out_code] = code + output_codes.add(out_code) + dynamo_time_before_restart = last_attempt_start_time - start_time + assert output is not None + + # Tests for new code objects. + # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c + # Only test once the code object is created. + # They are not tested during runtime. 
+ + def count_args(code: CodeType) -> int: + import inspect + + return ( + code.co_argcount + + code.co_kwonlyargcount + + bool(code.co_flags & inspect.CO_VARARGS) + + bool(code.co_flags & inspect.CO_VARKEYWORDS) + ) + + assert out_code is not None + + total_argcount_old = count_args(code) + total_argcount_new = count_args(out_code) + msg = "arg mismatch: " + msg += f"old code object has args {code.co_varnames[:total_argcount_old]}, " + msg += f"new code object has args {out_code.co_varnames[:total_argcount_new]}" + assert ( + code.co_varnames[:total_argcount_old] + == out_code.co_varnames[:total_argcount_new] + ), msg + + msg = "free var mismatch: " + msg += f"old code object has free var {code.co_freevars}, " + msg += f"new code object has free var {out_code.co_freevars}" + assert code.co_freevars == out_code.co_freevars, msg + + msg = "cell var mismatch: " + msg += f"old code object has cell var {code.co_cellvars}, " + msg += f"new code object has cell var {out_code.co_cellvars}" + assert code.co_cellvars == out_code.co_cellvars, msg + + # Skipping Dynamo on a frame without any extracted graph. + # This does not affect eager functionality. But this is necessary + # for export for cases where Dynamo-reconstructed bytecode can create + # new function frames, confusing export in thinking that there + # are extra graphs now. + + if output.export and output.is_empty_graph(): + return ConvertFrameReturn() + + assert output.guards is not None + CleanupManager.instance[out_code] = output.cleanups + nonlocal cache_entry + with dynamo_timed("build_guards", log_pt2_compile_event=True): + check_fn = CheckFunctionManager( + code, + output, + cache_entry, + hooks.guard_fail_fn if hooks else None, + hooks.guard_filter_fn if hooks else None, + guards_serialization_mode="save" if package else None, + ) + + if package is not None: + assert check_fn.guards_state is not None + package.add_guarded_code(check_fn.guards_state, out_code) + + compile_id_str = str(compile_id) if compile_id is not None else "Unknown" + annotation_str = "Torch-Compiled Region: " + compile_id_str + guarded_code = GuardedCode( + out_code, + check_fn.guard_manager, # type: ignore[arg-type] + compile_id, + annotation_str, + ) + + if not output.is_empty_graph() and hooks.guard_export_fn is not None: + # We should not run the guard_export_fn when Dynamo does not + # generate any graph. This can happen in export when TorchDynamo + # generated bytecode has some reconstruction logic for mutated + # variables which can trigger TorchDynamo on the children frames but + # they are benign and do not generate any new graphs. 
+ hooks.guard_export_fn(output.guards) + + return wrap_guarded_code(guarded_code) + + metrics_context = get_metrics_context() + code_context = ( + package.code_context(code) if package is not None else contextlib.nullcontext() + ) + with ( + _use_lazy_graph_module(config.use_lazy_graph_module), + compile_context(CompileContext(compile_id)), + chromium_event_timed( + "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True + ), + _WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(), + metrics_context, + dynamo_timed( + "_compile.compile_inner", + phase_name="entire_frame_compile", + dynamo_compile_column_us="dynamo_cumulative_compile_time_us", + ), + code_context, + ): + restart_reasons: set[str] = set() + # This is shared across restarts + speculation_log = SpeculationLog() + if compile_pg := get_compile_pg(): + distributed_state = DistributedState(compile_pg, LocalState()) + else: + distributed_state = None + + # Check recompilations + recompile_reason: Optional[str] = None + if is_recompilation(cache_size) and frame: + reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame) + recompile_reason = ( + "Unable to find recompilation reasons" if not reasons else reasons[0] + ) + metrics_context.update_outer({"recompile_reason": recompile_reason}) + + exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id) + if exceeded: + + def format_func_info(code: CodeType) -> str: + return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})" + + log.warning( + "torch._dynamo hit config.%s (%s)\n" + " function: %s\n" + " last reason: %s\n" + 'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n' + "To diagnose recompilation issues, see %s.", + limit_type, + getattr(config, limit_type), + format_func_info(code), + recompile_reason, + troubleshooting_url, + ) + if config.fail_on_recompile_limit_hit: + raise FailOnRecompileLimitHit( + f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure" + ) + elif one_graph: + raise FailOnRecompileLimitHit( + f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade " + "performance due to the compilation overhead of each recompilation. To monitor " + "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider " + "increasing torch._dynamo.config.cache_size_limit to an appropriate value." + ) + elif justknobs_check( + "pytorch/compiler:skip_code_recursive_on_recompile_limit_hit" + ): + raise RecompileLimitExceeded(f"{limit_type} reached") + else: + # do not recursively skip frames + unimplemented_v2( + gb_type="Dynamo cache limit exceeded", + context=f"Limit type: {limit_type}", + explanation="Dynamo attempted to recompile the code object too many times, " + f"exceeding the {limit_type} cache size limit." 
+ "Giving up on compiling as the compile time tradeoff is likely not " + "worth the performance gain.", + hints=[], + ) + + log.debug( + "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", + code.co_name, + code.co_filename, + code.co_firstlineno, + skip + 2, + # -2: omit current frame, omit contextlib decorator + "".join(CapturedTraceback.extract(skip=2 + skip).format()), + ) + # -4: -2 as above, plus trace_structured frames + # + # NB: the frame looks like this: + # + # # handled by skip argument + # torch/_dynamo/convert_frame.py:1069 in catch_errors + # torch/_dynamo/convert_frame.py:910 in _convert_frame + # torch/_dynamo/convert_frame.py:464 in _convert_frame_assert + # torch/_utils_internal.py:70 in wrapper_function + # + # # 2 current frame and context lib + # env/lib/python3.10/contextlib.py:79 in inner + # torch/_dynamo/convert_frame.py:776 in _compile + # + # # 2 extra here + # torch/_logging/_internal.py:1064 in trace_structured + # torch/_dynamo/convert_frame.py:780 in + convert_frame_intern = structured.intern_string(__file__) + # Initialize the ChromiumEventLogger on start + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, + ) + start_time_ns = time.time_ns() + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None + torch._dynamo.utils.ReinplaceCounters.clear() + guarded_code = None + try: + guarded_code = compile_inner(code, one_graph, hooks, transform) + + # NB: We only put_code_state in success case. Success case here + # does include graph breaks; specifically, if a graph break still + # resulted in a partially compiled graph, we WILL return here. An + # Unsupported exception will only bubble to the top level if we + # are unable to compile the frame at all. In this case, there's + # no point in uploading the code state, because we will always + # fail exactly the same way even without the update. (It's useful + # to upload for graph break though, because this can prevent + # extra graph break compilations.) + put_code_state() + log_frame_dynamic_whitelist(code) + + return guarded_code + except Exception as e: + # NB: e's msg is mutated here to add user stack, but we DON'T want + # that stack in the Scuba logged fail_reason. So we grab the fail + # info here and add it to the metrics context below. 
+ fail_type = type(e).__qualname__ + fail_reason = str(e) + exception_handler(e, code, frame, export=export) + # NB: this is the post-mutation exception + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_error", + "encoding": "string", + }, + payload_fn=lambda: traceback.format_exc(), + ) + fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( + e, compile_id + ) + if isinstance( + e, + ( + Unsupported, + TorchRuntimeError, + BackendCompilerFailed, + AssertionError, + ConstraintViolationError, + GuardOnDataDependentSymNode, + ValidationException, + UncapturedHigherOrderOpError, + BisectValidationException, + ShortenTraceback, + PackageError, + ), + ): + raise + else: + # Rewrap for clarity + raise InternalTorchDynamoError( + f"{type(e).__qualname__}: {str(e)}" + ).with_traceback(e.__traceback__) from None + finally: + # === WARNING WARNING WARNING === + # If you commit a bug here, it will suppress writing to + # dynamo_compile table, and we will not have telemetry. + # Be extra careful when making changes here! + + if torch._dynamo.config.run_gc_after_compile: + with dynamo_timed("gc", dynamo_compile_column_us="gc_time_us"): + log.info("run_gc_after_compile: running gc") + gc.collect(1) + + if tracer: + tracer.output.local_scope = {} + tracer.f_locals = {} + + from .utils import curr_frame + + frame_key = str(curr_frame) + if fail_reason is None and output is not None: + guard_count = len(output.guards) + shape_env_guard_count = len(output.shape_env.guards) + graph_op_count = output.count_calls() + graph_node_count = len(output.graph.nodes) + graph_input_count = len(output.placeholders) + non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} + compliant_custom_ops = { + op.__qualname__ for op in output.compliant_custom_ops + } + torch._dynamo.utils.ReinplaceCounters.log() + else: + guard_count = None + shape_env_guard_count = None + graph_op_count = None + graph_node_count = None + graph_input_count = None + non_compliant_ops = set({}) + compliant_custom_ops = set({}) + restart_reasons = set() + # If compilation failed, the entire time is wasted + dynamo_time_before_restart = (time.time_ns() - start_time_ns) / 1e9 + + metrics = { + "frame_key": frame_key, + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, + "accumulated_cache_size": cache_size.num_cache_entries, + "guard_count": guard_count, + "shape_env_guard_count": shape_env_guard_count, + "graph_op_count": graph_op_count, + "graph_node_count": graph_node_count, + "graph_input_count": graph_input_count, + "fail_type": fail_type, + "fail_reason": fail_reason, + "fail_user_frame_filename": fail_user_frame_filename, + "fail_user_frame_lineno": fail_user_frame_lineno, + "non_compliant_ops": non_compliant_ops, + "compliant_custom_ops": compliant_custom_ops, + "restart_reasons": restart_reasons, + "dynamo_time_before_restart_s": dynamo_time_before_restart, + "has_guarded_code": guarded_code is not None, + "config_suppress_errors": config.suppress_errors, + "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, + "specialize_float": config.specialize_float, + "is_forward": True, + "dynamo_compile_time_before_restart_us": to_int_us( + dynamo_time_before_restart + ), + } + # TODO: replace with CompileEventLogger.compilation_metrics + # There are some columns here not in PT2 Compile Events + # so we need to slightly change it + 
metrics_context.update_outer(metrics) + # === END WARNING WARNING WARNING === + + +class ConvertFrame: + def __init__( + self, + compiler_fn: CompilerFn, + hooks: Hooks, + package: Optional[CompilePackage] = None, + ) -> None: + self._torchdynamo_orig_callable = compiler_fn + self._inner_convert = convert_frame_assert( + compiler_fn, one_graph=False, package=package + ) + self._hooks = hooks + + @property + def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]: + return lambda backend: convert_frame(backend, self._hooks) + + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: dict[str, Union[int, FrameStateSizeEntry]], + skip: int = 0, + ) -> ConvertFrameReturn: + input_codes.add(frame.f_code) + counters["frames"]["total"] += 1 + try: + result = self._inner_convert( + frame, cache_entry, hooks, frame_state, skip=skip + 1 + ) + counters["frames"]["ok"] += 1 + return result + except Exception as e: + # These two exception types are "soft" failure, in the sense that + # we know this is due to something we didn't implement all the + # way, scare the user less about it. That being said, if you + # are trying to understand why a graph break happened, it's still + # important to have this information, so offer it. + # + # NB: NotImplementedError used to be on this list, but actually + # it is impossible for it to reach here, as it is converted into + # InternalTorchDynamoError. This behavior seemed reasonable + # to me (ezyang, Aug 2023) so I kept it, but maybe at some point + # someone wanted these to also get suppressed. If so, you'll + # need to make these exceptions not get wrapped + + # We intentionally don't want to suppress error here. + if isinstance(e, UncapturedHigherOrderOpError): + raise + + soft_fail = isinstance(e, Unsupported) + + # This is a soft failure. In the sense, the code path reaches here + # when we do not support graph breaks on bytecodes like LOAD_ATTR, + # BUILD_SET etc. In such case, we can fallback to eager without + # scaring users. + if soft_fail and graph_break_log.isEnabledFor(logging.DEBUG): + # Log this message in the graph break. Also use the string + # "skip: " to tell that the whole frame is falling back to + # eager. + if hasattr(e, "compile_id") and hasattr(e, "real_stack"): + with compile_context(CompileContext(e.compile_id)): # type: ignore[attr-defined] + user_stack = e.real_stack + user_stack_formatted = "".join( + traceback.format_list(user_stack) + ) + user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}" + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}", + ) + graph_break_log.debug( + user_stack_trace, + exc_info=True, + ) + + if not config.suppress_errors and not soft_fail: + raise + + # Suppress the error. NB: It's very important to do the + # suppression logging HERE, where the actual suppression + # happens. Previously it was somewhere else and so it was + # possible to accidentally not log at all. 
+ record_filename = getattr(e, "record_filename", None) + code = frame.f_code + error_msg = format_error_msg(e, code, record_filename, frame) + + if soft_fail: + log.info(error_msg, exc_info=True) + else: + log.warning(error_msg, exc_info=True) + + if isinstance(e, SkipCodeRecursiveException): + return ConvertFrameReturn( + frame_exec_strategy=FrameExecStrategy( + FrameAction.SKIP, FrameAction.SKIP + ) + ) + elif isinstance(e, RecompileLimitExceeded): + return ConvertFrameReturn( + frame_exec_strategy=FrameExecStrategy( + FrameAction.RUN_ONLY, FrameAction.RUN_ONLY + ) + ) + + return ConvertFrameReturn() + + +def convert_frame( + compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None +) -> ConvertFrame: + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + return ConvertFrame(compiler_fn, hooks, package=package) + + +# TODO mlazos: add support for same args, or record them +def replay(filename: str) -> None: + from .backends.debugging import eager + + original_replay_val = config.replay_record_enabled + config.replay_record_enabled = False + with open(filename, "rb") as in_file: + record = ExecutionRecord.load(in_file) + record.globals = dict(itertools.chain(record.globals.items(), globals().items())) + + try: + _compile( + record.code, + record.globals, + record.locals, + record.builtins, + record.closure, + compiler_fn=eager, + one_graph=False, + export=False, + export_constraints=None, + hooks=Hooks(), + cache_size=CacheSizeRelevantForFrame(0, 0), + cache_entry=None, + frame=None, + frame_state={}, + compile_id=CompileId(frame_id=42, frame_compile_id=999), + ) + finally: + config.replay_record_enabled = original_replay_val + + +def first_real_inst_idx(code: CodeType) -> int: + if sys.version_info < (3, 11): + return 0 + for inst in dis.get_instructions(code): + if inst.opname == "RESUME": + return inst.offset // 2 + raise RuntimeError("RESUME instruction not found in code") + + +class ConvertFrameProtocol(typing.Protocol): + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + hooks: Hooks, + frame_state: dict[str, Union[int, FrameStateSizeEntry]], + *, + skip: int = 0, + ) -> ConvertFrameReturn: ... 
+ + +class CatchErrorsWrapper: + def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback + self.hooks = hooks + + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + frame_state: dict[str, Union[int, FrameStateSizeEntry]], + ) -> ConvertFrameReturn: + assert frame_state is not None + + input_codes.add(frame.f_code) + + is_skipfile = trace_rules.check(frame.f_code) + if sys.version_info >= (3, 13): + has_started_execution = frame.f_lasti > first_real_inst_idx(frame.f_code) + else: + has_started_execution = frame.f_lasti >= first_real_inst_idx(frame.f_code) + if ( + # TODO: the first condition is not covered by any test + has_started_execution + or is_skipfile + or config.disable + or ( + is_in_torch_dispatch_mode(include_infra_modes=False) + and not getattr(self._torchdynamo_orig_callable, "_export", False) + ) + ): + if log.isEnabledFor(logging.DEBUG): + if has_started_execution: + skip_reason = "traced frame already" + elif trace_rules.check(frame.f_code): + skip_reason = "in skipfiles" + elif is_in_torch_dispatch_mode(include_infra_modes=False): + skip_reason = "non-infra torch dispatch mode present, this is not supported today in torch.compile" + else: + skip_reason = "dynamo tracing is disabled" + + log.debug( + "skipping: %s (reason: %s, file: %s)", + frame.f_code.co_name, + skip_reason, + frame.f_code.co_filename, + ) + return ConvertFrameReturn() + + if frame.f_code.co_filename == "" and frame.f_code.co_name == "__new__": + # nametuple constructor + return ConvertFrameReturn() + if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer": + ddp_module = DistributedDataParallel._get_active_ddp_module() + if ddp_module: + with compile_lock: + from torch._dynamo.backends.distributed import DDPOptimizer + + ddp_optimizer = DDPOptimizer( + bucket_bytes_cap=ddp_module.bucket_bytes_cap, + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, # type: ignore[attr-defined] + ) + assert hasattr( + self._torchdynamo_orig_callable, "_clone_with_backend" + ), ( + "DDPOptimizer only supports callback fns that know how to clone themselves." + ) + hijacked_callback = ( + self._torchdynamo_orig_callable._clone_with_backend( + ddp_optimizer.compile_fn, + ) + ) + return hijacked_callback( + frame, cache_entry, self.hooks, frame_state + ) + + with compile_lock, _disable_current_modes(): + # skip=1: skip this frame + return self._torchdynamo_orig_callable( + frame, cache_entry, self.hooks, frame_state, skip=1 + ) + + +def catch_errors_wrapper( + callback: ConvertFrameProtocol, hooks: Hooks +) -> CatchErrorsWrapper: + return CatchErrorsWrapper(callback, hooks) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/create_parameter_op.py b/phivenv/Lib/site-packages/torch/_dynamo/create_parameter_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ab23dbc71876d2b99f5d63c7aab82be675a3c96b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/create_parameter_op.py @@ -0,0 +1,68 @@ +import threading +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import torch + + +# See [Note: Metadata mutation in proxy tracing] for why sacrificial parameter mutates +# metadata during proxy tracing and we should remove the sacrificial parameter logic. +doc = """ +This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly +with AOTAutograd. 
We instead create a placeholder torch.nn.Parameter before the graph, which +becomes a graph arg and has no storage backing it. At the point in the graph where the parameter +actually should be created we mutate this sacrificial placeholder into it. This allows gradients +to flow into the parameter as if it were an input to the graph (which is the only thing we are +allowed to compute gradients on). +""".strip() + + +class TracableCreateParameter(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter: + assert not tensor.requires_grad + return placeholder.set_(tensor) + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor) -> tuple[None, torch.Tensor]: + grad = grad_outputs[0] + return None, grad # grad flows to placeholder + + +def tracable_create_parameter( + tensor: torch.Tensor, placeholder: torch.nn.Parameter +) -> torch.nn.Parameter: + with torch.set_grad_enabled(placeholder.requires_grad): + out = TracableCreateParameter.apply(tensor, placeholder) + return out + + +def new_parameter_placeholder( + size: tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool +) -> torch.nn.Parameter: + """Create a placeholder to be passed to the above functions""" + result = torch.nn.Parameter( + torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad + ) + # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor. + # Allocating a zero tensor would causes assert failures in autograd. + result.untyped_storage().resize_(0) + return result + + +_TLS = threading.local() + + +@contextmanager +def do_not_convert_to_tracable_parameter() -> Generator[bool, None, None]: + old_flag = getattr(_TLS, "convert_tracable_parameter", True) + _TLS.convert_tracable_parameter = False + try: + yield False + finally: + _TLS.convert_tracable_parameter = old_flag + + +def can_convert_to_tracable_parameter() -> bool: + return getattr(_TLS, "convert_tracable_parameter", True) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/current_scope_id.py b/phivenv/Lib/site-packages/torch/_dynamo/current_scope_id.py new file mode 100644 index 0000000000000000000000000000000000000000..7398e74d0fffc76ef40462ee6b5edebadf2dd4ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/current_scope_id.py @@ -0,0 +1,42 @@ +""" +Provides thread-local scope identification for SubgraphTracer instances. + +This module implements a thread-safe mechanism for tracking nested tracing contexts, +which is essential when multiple SubgraphTracer instances are active. The scope ID +helps identify which tracer context is currently active when direct access to the +InstructionTranslator is difficult. + +Key components: +- Thread-local scope ID storage (_current_scope_id) +- Getter function (current_scope_id) to safely access the current scope +- Context manager (enter_new_scope) for managing nested scope transitions + +The scope ID increments when entering a new context and decrements when exiting, +allowing proper tracking of nested tracing operations across different threads. +""" + +import contextlib +import threading +from collections.abc import Generator + + +# Global variable to identify which SubgraphTracer we are in. +# It is sometimes difficult to find an InstructionTranslator to use. 
+_current_scope_id = threading.local() + + +def current_scope_id() -> int: + global _current_scope_id + if not hasattr(_current_scope_id, "value"): + _current_scope_id.value = 1 + return _current_scope_id.value + + +@contextlib.contextmanager +def enter_new_scope() -> Generator[None, None, None]: + global _current_scope_id + try: + _current_scope_id.value = current_scope_id() + 1 + yield + finally: + _current_scope_id.value = current_scope_id() - 1 diff --git a/phivenv/Lib/site-packages/torch/_dynamo/debug_utils.py b/phivenv/Lib/site-packages/torch/_dynamo/debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de4946c37dbf042169c40398a095757a38fa7794 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/debug_utils.py @@ -0,0 +1,896 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" + +""" +Debug utilities for TorchDynamo compilation and execution. + +This module provides various debugging tools and utilities for TorchDynamo, including: + +- Minification support for reducing test cases while preserving bugs +- Input/output handling via InputReader and InputWriter for reproducible testing +- Accuracy checking between original and compiled models +- Neural network module string conversion via NNModuleToString +- Profiling tools and system information collection +- Buck build system integration for Meta-internal testing + +Key classes: +- InputReader/InputWriter: Handle serialization of model inputs/outputs +- NNModuleToString: Converts nn.Modules to string representations +- BuckTargetWriter: Manages Buck build system integration +""" + +import atexit +import copy +import cProfile +import functools +import getpass +import inspect +import itertools +import logging +import os +import re +import subprocess +import sys +import tempfile +import textwrap +from collections import Counter +from importlib import import_module +from typing import Any, Callable, Optional, TypeVar + +import torch +import torch._prims_common as utils +import torch._subclasses.meta_utils +from torch import Tensor +from torch._dynamo.testing import rand_strided +from torch._prims_common import is_float_dtype +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._content_store import ContentStoreReader, ContentStoreWriter + +from . 
import config +from .utils import clone_inputs, get_debug_dir + + +log = logging.getLogger(__name__) + +T = TypeVar("T") + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +if use_buck: + import libfb.py.build_info + + +extra_deps = [] +extra_imports = "" +if use_buck: + extra_deps = [ + "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", + "//caffe2/torch/fb/sparsenn:sparsenn_operators", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops", + ] + cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined] + extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) + + +BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] + + +class BuckTargetWriter: + def __init__(self, filename): + self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) + self.target = self.py_file.replace(".py", "") + + # Get main_module path from fbcode + self.path = f"{self.subdir.replace('/', '.')}.{self.target}" + self.path = self.path[self.path.find("fbcode.") :] + self.path = self.path[7:] + + # Get cmd line path + tmp = self.subdir + tmp = tmp[tmp.find("fbcode/") :][7:] + self.cmd_line_path = f"//{tmp}:{self.target}" + + def build(self): + extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) + return textwrap.dedent( + f""" +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name="{self.target}", + srcs = ["{self.py_file}"], + compile = False, + deps = [ + "//caffe2:torch", + "//caffe2:libtorch", + "//caffe2/functorch:functorch", + "//triton:triton", + "{cur_target}", + ], + cpp_deps = [ +{extra_cpp_deps} + ], + main_module = "{self.path}", + par_style = "xar", +) +""" + ) + + def write(self, print_msg=True): + target_file = os.path.join(self.subdir, "TARGETS") + with open(target_file, "w") as fd: + fd.write(self.build()) + # log.warning("Wrote isolation TARGETS file at %s", target_file) + cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] + if print_msg: + log.warning( + "Found an example that reproduces the error. Run this cmd to repro - %s", + " ".join(cmd_split), + ) + return cmd_split + + +def minifier_dir(): + path = os.path.join(get_debug_dir(), "minifier") + if path is None: + path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + return path + + +MAX_CONSTANT_NUMEL_INLINE = 4 + + +class NNModuleToString: + safe_reprs = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.Softmax, + torch.nn.ReLU, + torch.nn.GELU, + torch.nn.Identity, + torch.nn.MaxPool2d, + torch.nn.Embedding, + torch.nn.Tanh, + torch.nn.ConvTranspose1d, + torch.nn.GLU, + torch.nn.LSTM, + torch.nn.Flatten, + torch.nn.AdaptiveAvgPool2d, + ] + + @staticmethod + def can_convert_to_string(gm): + cant_convert = set() + for _, module in gm.named_children(): + if type(module) not in NNModuleToString.safe_reprs: + cant_convert.add(module) + + if len(cant_convert) > 0: + log.warning("We have not tested reprs of some modules - %s", cant_convert) + # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct. 
+ return True + + @staticmethod + def convert(gm): + from torch.nn.modules.module import _addindent + + tab = " " * 4 + + model_str = textwrap.dedent( + """ + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + """ + ) + + for module_name, module in gm.named_children(): + module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in gm._buffers.items(): + if buffer is None: + continue + # Serialize full data for small buffers + if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE: + from torch._tensor_str import PRINT_OPTS + + assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE + tensor_str = repr(buffer) + elif torch.is_floating_point(buffer): + tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" + else: + tensor_str = ( + f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" + ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" + model_str += ( + f"{tab * 2}self.register_buffer('{buffer_name}', {tensor_str})\n" + ) + + for param_name, param in gm._parameters.items(): + if param is None: + continue + maybe_device = "" + if param.is_cuda: + maybe_device = ', device="cuda"' + tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))" + model_str += f"{tab * 2}self.{param_name} = {tensor_str}\n" + + # TODO - Keep this code for now. But, I don't think we will need this. + # attrs = dir(gm) + # for attr in attrs: + # if "_tensor_constant" in attr: + # val = getattr(gm, attr) + # model_str += f" {attr} = {val!r}\n" + + model_str += f"{_addindent(gm.code, 4)}\n" + return model_str + + +@functools.cache # subprocess is expensive +def _cuda_system_info_comment(): + if not torch.cuda.is_available(): + return "# torch.cuda.is_available()==False, no GPU info collected\n" + + model_str = "# CUDA Info: \n" + try: + cuda_version_out = subprocess.check_output(["nvcc", "--version"]) + cuda_version_lines = cuda_version_out.decode().split("\n") + comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]]) + model_str += f"{comment}\n" + except (FileNotFoundError, subprocess.CalledProcessError): + model_str += "# nvcc not found\n" + + gpu_names = Counter( + torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) + ) + + model_str += "# GPU Hardware Info: \n" + for name, count in gpu_names.items(): + model_str += f"# {name} : {count} \n" + model_str += "\n" + return model_str + + +def generate_env_vars_string(*, stable_output=False): + """ + Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton. 
+ """ + if stable_output: + return "# env var omitted due to stable_output=True" + + allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"] + skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"] + + def filter(key): + return any(string in key for string in allow_list) and key not in skip_list + + config_lines = [ + f"os.environ['{key}'] = '{value}'" + for key, value in os.environ.items() + if filter(key) + ] + config_string = "\n".join(config_lines) + return f"""\ +import os +{config_string} + """ + + +def generate_config_string(*, stable_output=False): + import torch._functorch.config + import torch._inductor.config + + if stable_output: + return "# config omitted due to stable_output=True" + + experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined] + return f"""\ +import torch._dynamo.config +import torch._inductor.config +import torch._functorch.config +import torch.fx.experimental._config +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +{torch._functorch.config.codegen_config()} +{experimental_config} +""" + + +def get_minifier_repro_path(): + return os.path.join(minifier_dir(), "minifier_launcher.py") + + +def helper_for_dump_minify(contents): + minified_repro_path = get_minifier_repro_path() + log.warning("Writing minified repro to:\n%s", minified_repro_path) + + if use_buck: + BuckTargetWriter(minified_repro_path).write() + try: + with open(minified_repro_path, "w") as fd: + fd.write(contents) + + except OSError as e: + log.exception("") + raise NotImplementedError("Could not write to {minified_repro_path}") from e + + +class AccuracyError(Exception): + pass + + +def clone_inputs_retaining_gradness(example_inputs): + """ + This clone inputs is different from utils clone_input. In case of minifier, + all the tensors are leaf tensors while creating a new graph. So, we set the + requires_grad field w/o checking the leafness of the tensor. + """ + cloned_inputs = clone_inputs(example_inputs) + for idx in range(len(example_inputs)): + if isinstance(cloned_inputs[idx], torch.Tensor): + cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) + return cloned_inputs + + +def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): + """ + Runs a forward and possibly backward iteration for a given mod and args. + + When disable_clone is True, we will use args as-is without cloning. + This is higher fidelity but we may destroy the args in the process. + """ + from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass + + gm = copy.deepcopy(gm) + if not disable_clone: + args = clone_inputs_retaining_gradness(args) + + if hasattr(gm, "zero_grad"): + gm.zero_grad(True) + + # TorchInductor returned callable expects lists. So, may need a boxed calling convention. + out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args) + + if only_fwd: + return out + if requires_bwd_pass(out): + loss = reduce_to_scalar_loss(out) + loss.backward() + return collect_results(gm, out, None, args) + + +def same_two_models( + gm, + opt_gm, + example_inputs, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + """ + Check two models have same accuracy. + + require_fp64: if True, raise an error if we unable to calculate the fp64 reference + ignore_non_fp: if True, do not compare outputs which are not floating point. 
This + is mostly useful for the minifier (which wants to avoid quantizing floating point + error into integer/boolean error) + """ + from .utils import same + + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) + + fp64_ref = None + if config.same_two_models_use_fp64: + try: + fp64_model, fp64_examples = cast_to_fp64( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) + except Exception: + if require_fp64: + raise RuntimeError( # noqa: B904 + "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False" + ) + log.warning("Could not generate fp64 outputs") + + try: + res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) + except Exception: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return True. + log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph." + ) + return True + + passing = same( + ref, + res, + fp64_ref, + tol=config.repro_tolerance, + equal_nan=True, + ignore_non_fp=ignore_non_fp, + ) + return passing + + +def cast_dtype_args_to_fp64(model): + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.prims.convert_element_type.default + ): + assert len(node.args) == 2 + if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: + node.args = (node.args[0], torch.float64) + if node.op == "call_function": + dtype = node.kwargs.get("dtype") + if dtype is not None and is_float_dtype(dtype): + new_kwargs = dict(node.kwargs) + new_kwargs["dtype"] = torch.float64 + node.kwargs = new_kwargs + + model.graph.lint() + model.recompile() + return model + + +def cast_to(dtype, model, inputs): + from torch.utils._pytree import tree_map + + model = model.to(dtype) + if dtype == torch.float64: + # If casting to fp64 for accuracy comparison, we need to + # replace dtype arguments embedded in the graph with fp64 + model = cast_dtype_args_to_fp64(model) + + inputs = tree_map( + lambda x: x.to(dtype) + if isinstance(x, torch.Tensor) and x.is_floating_point() + else x, + inputs, + ) + return model, inputs + + +def cast_to_fp64(model, inputs): + return cast_to(torch.float64, model, inputs) + + +def backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): + try: + compiled_gm = compiler_fn( + copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) + ) + return not same_two_models( + gm, + compiled_gm, + example_inputs, + only_fwd, + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + except Exception: + # This means that the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return False. + log.exception( + "While minifying the program in accuracy minification mode, " + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph" + ) + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO SUPPORT CODE +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# Helper functions for computing what the default values of tensor +# values should be. 
These all coincide with factory functions, e.g., torch.empty + + +def _stride_or_default( + stride: Optional["torch._prims_common.StrideType"], + *, + shape: "torch._prims_common.ShapeType", +) -> "torch._prims_common.StrideType": + return stride if stride is not None else utils.make_contiguous_strides_for(shape) + + +def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]: + return lambda x: x if x is not None else d + + +_dtype_or_default = _mk_defaulter(torch.float32) +_device_or_default = _mk_defaulter(torch.device("cpu")) +_storage_offset_or_default = _mk_defaulter(0) +_requires_grad_or_default = _mk_defaulter(False) +_is_leaf_or_default = _mk_defaulter(False) + + +class NopInputReader: + def __init__(self) -> None: + self.total = 0 + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + self.total += 1 + + def tensor(self, *args, **kwargs): + pass + + def symint(self, *args, **kwargs): + pass + + +# TODO: Support bundling the entire repro into a zip file for ease of +# transferring around +class InputReader: + def __init__(self, save_dir=None, *, pbar=None): + # If None, we will generate random data instead. It's important + # to natively support this use case as it will allow people to + # share repros without including the real data, if the problem + # reproduces even on random data. + if save_dir is None: + log.warning("no save_dir specified, will generate random data") + self.store = ContentStoreReader(save_dir) if save_dir is not None else None + self.args = [] + self.pbar = pbar + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + if self.pbar is not None: + self.pbar.update(1) + device = _device_or_default(device) + dtype_hint = _dtype_or_default(dtype_hint) + if self.store is not None and storage_hash is not None: + try: + storage = self.store.read_storage(storage_hash) + except FileNotFoundError: + pass + else: + if device != storage.device: + log.warning("device mismatch: %s != %s", device, storage.device) + # TODO: transfer it to the right device? But failing this + # way would be very mysterious! Would have been better + # not to store device in the serialized format... + return storage + log.warning("could not load %s, generating random data instead", storage_hash) + shape = (nbytes // dtype_hint.itemsize,) + stride = _stride_or_default(None, shape=shape) + return rand_strided(shape, stride, dtype_hint, device).untyped_storage() + + def tensor( + self, + storage, + shape, + stride=None, + *, + storage_offset=None, + dtype=None, + requires_grad=None, + is_leaf=None, + **metadata, + ): + stride = _stride_or_default(stride, shape=shape) + storage_offset = _storage_offset_or_default(storage_offset) + dtype = _dtype_or_default(dtype) + is_leaf = _is_leaf_or_default(is_leaf) + requires_grad = _requires_grad_or_default(requires_grad) + t = torch.tensor( + [], dtype=dtype, device=storage.device, requires_grad=requires_grad + ) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + if not is_leaf: + # Fake up some autograd history in a very naughty way + with torch.enable_grad(): + t = t.clone(memory_format=torch.preserve_format) + with torch.no_grad(): + t.set_(storage, storage_offset, shape, stride) + assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf + torch._utils.set_tensor_metadata(t, metadata) + self.args.append(t) + return t # for BC + + def symint(self, val): + self.args.append(val) + return val # for BC + + +# Here is our writer strategy: +# 1. We will stream all of the inputs to disk +# 2. 
You can now deterministically randomize the inputs, or reload +# the inputs from disk +# 3. You can YOLO run the script without the inputs, in which case +# we'll fill the inputs with random data and pray. This is the +# legacy behavior, but it's also useful if you want to find out +# if we're so broken even random inputs trigger it +# 4. We could offer an in process "check if the randomized thing +# works too" but this is delicate so we don't do it + + +class InputWriter: + def __init__(self, save_dir, *, stable_hash=False): + self._lines = [] + # TODO: consider ensuring tensor and storage counters line up? + self.storage_counter = itertools.count() + self.save_dir = save_dir + self.store = ( + ContentStoreWriter(save_dir, stable_hash=stable_hash) + if save_dir is not None + else None + ) + self.seen_storages = {} + + def lines(self): + r = [ + "def load_args(reader):", + ] + r.extend(f" {l}" for l in self._lines) + # In case we need to change the internal format of load_args + # in an FC-breaking way + r.append("load_args._version = 0") + return r + + # Storages are untyped, but we need to initialize them with data if + # we don't have the real data, so we give a hint saying what kind + # of initialization may be appropriate + # + # If we had a FakeTensor, device_hint tells us what device should be + def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: + ws = StorageWeakRef(untyped_storage) + v = self.seen_storages.get(ws) + if v is not None: + return v + v = f"buf{next(self.storage_counter)}" + maybe_dtype_hint = "" + if _dtype_or_default(None) != _dtype_or_default(dtype_hint): + maybe_dtype_hint = f", dtype_hint={dtype_hint!r}" + # TODO: being optional on device is kind of pointless as the default + # is CPU but most repros we care about are CUDA + maybe_device = "" + device = untyped_storage.device + if device.type == "meta": + assert device_hint is not None + device = device_hint + if _device_or_default(None) != device: + maybe_device = f", device={device!r}" + nbytes = untyped_storage.nbytes() + storage_hash = None + if self.store is not None and untyped_storage.device.type != "meta": + storage_hash = self.store.write_storage(untyped_storage) + self._lines.append( + f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})" + ) + self.seen_storages[ws] = v + return v + + def tensor(self, name, t) -> None: + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + + storage = self.storage( + t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device + ) + args = [] + # NB: this is positional, must come first + if not statically_known_true( + sym_eq(_stride_or_default(None, shape=t.shape), t.stride()) + ): + args.append(str(tuple(t.stride()))) + if _dtype_or_default(None) != t.dtype: + args.append(f"dtype={t.dtype!r}") + if not statically_known_true( + _storage_offset_or_default(None) == t.storage_offset() + ): + args.append(f"storage_offset={t.storage_offset()!r}") + tensor_metadata = torch._utils.get_tensor_metadata(t) + if tensor_metadata: + args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items()) + if _requires_grad_or_default(None) != t.requires_grad: + args.append(f"requires_grad={t.requires_grad!r}") + is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t) + if _is_leaf_or_default(None) != is_leaf: + args.append(f"is_leaf={is_leaf!r}") + self._lines.append( + "reader.tensor(" + + ", ".join([storage, str(tuple(t.shape)), *args]) + + f") # {name}" + ) + + def unsupported(self, name, arg): + # 
NB: Try hard not to /print/ a tensor, that will be very slow + self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}") + # Best effort dump as much useful stuff we can lol, in case you want + # to repair the repro + if isinstance(arg, (list, tuple)): + self._lines.append('"""') + for i, a in enumerate(arg): + name_i = f"{name}[{i}]" + if isinstance(a, torch.Tensor): + self.tensor(name_i, a) + elif isinstance(a, (int, torch.SymInt)): + self.symint(name_i, a) + else: + self.unsupported(name_i, a) + self._lines.append('"""') + + # write out that the arg was filtered out as it is constant + def const(self, name) -> None: + self._lines.append( + f"reader.const({name!r}) # {name}, filtered out during compilation" + ) + + # TODO: this doesn't actually symint atm + def symint(self, name, val) -> None: + if isinstance(val, torch.SymInt): + val = val.node.hint + self._lines.append(f"reader.symint({val!r}) # {name}") + + +def aot_graph_input_parser( + func: Callable[[list[Tensor]], list[Tensor]], + device: str = "cuda", + sym_shapes: Optional[dict[str, int]] = None, + default_sym_shape: Optional[int] = None, +) -> dict[str, Any]: + """ + Takes in a function which has been printed with print_readable() and constructs kwargs to run it. + + Handles Tensor inputs, Symints, and a graph module which might have tensor constants. + + Consider a function `forward` defined as follows: + + def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",): + _tensor_constant0: "i64[4190]" = self._tensor_constant0 + # Further implementation + + kwargs = aot_graph_input_parser(forward) + forward(**kwargs) + """ + + from torch.utils._dtype_abbrs import dtype_abbrs + + dtype_map = {value: key for key, value in dtype_abbrs.items()} + dtype_pattern = "|".join(dtype_abbrs.values()) + + # Extracting the source code from the function + source = inspect.getsource(func) + + # Regular expressions + tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)" + tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]" + sym_shape_regex = r"Sym\((s\d+)\)" + + class TensorContainer: + "Container for tensors as attributes" + + # Dictionary for tensors from annotations + kwargs: dict[str, Any] = {} + + sym_shapes = sym_shapes or {} + + def get_sym_int(symint): + torch._check( + symint in sym_shapes or default_sym_shape is not None, + lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", + ) + return sym_shapes.get(symint, default_sym_shape) + + def gen_tensor(shape, dtype) -> Tensor: + # Resolve symbolic shapes to concrete values + resolved_shape = [] + dynamic_dims = [] + for i, dim in enumerate(shape): + dim = dim.strip() + if "s" in dim: + s = get_sym_int(dim) + resolved_shape.append(s) + dynamic_dims.append(i) + else: + if dim: + resolved_shape.append(int(dim)) + + constructor = torch.randn if dtype.is_floating_point else torch.zeros + out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg] + for d in dynamic_dims: + torch._dynamo.mark_dynamic(out, d) + return out + + # Parse function annotations for tensor generation + annotations = func.__annotations__ + for param, annotation in annotations.items(): + # Skip 'return' annotation + if param == "return": + continue + + match = re.search(tensor_regex, annotation) + if match: + data_type, shape_str = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + kwargs[param] = gen_tensor(shape, dtype) 
+ + match = re.search(sym_shape_regex, annotation) + if match: + kwargs[param] = get_sym_int(match.group(1)) + + if "self" in inspect.signature(func).parameters: + container = TensorContainer() + kwargs["self"] = container + for match in re.finditer(tensor_assignment_regex, source): + attr_name, data_type, shape_str, _ = match.groups() + shape = tuple(shape_str.split(",")) + dtype = dtype_map[data_type] + setattr(container, attr_name, gen_tensor(shape, dtype)) + + return kwargs + + +def profile_to_file(filename: str) -> Callable[[T], T]: + """ + Decorator to cProfile a given function and save the result to disk on process exit. + + Args: + filename: filename to save profile to + """ + prof = cProfile.Profile() + filename = os.path.abspath(os.path.expanduser(filename)) + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + prof.enable() + try: + return fn(*args, **kwargs) + finally: + prof.disable() + + return wrapper + + def save_it(): + prof.dump_stats(filename) + sys.stderr.write( + textwrap.dedent( + f"""\ + Wrote profile to {filename}, view with: + + snakeviz {filename} + + """ + ) + ) + + atexit.register(save_it) + return decorator diff --git a/phivenv/Lib/site-packages/torch/_dynamo/decorators.py b/phivenv/Lib/site-packages/torch/_dynamo/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..1223dc4e7f289a68334d3790c689a27278e0d092 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/decorators.py @@ -0,0 +1,876 @@ +# mypy: allow-untyped-defs +# ruff: noqa: TCH004 + +""" +This module provides decorators and utilities for controlling TorchDynamo's behavior during compilation. +""" + +import functools +import inspect +import weakref +from dataclasses import dataclass +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch.utils._contextlib import _DecoratorContextManager +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from . import trace_rules, variables +from .comptime import comptime +from .eval_frame import ( + _set_stance, + DisableContext, + DynamoStance, + innermost_fn, + RunOnlyContext, + skip_code, +) +from .exc import IncorrectUsage +from .external_utils import ( + _dynamo_config_patch_proxy_dunder_call, + get_nonrecursive_disable_wrapper, + is_compiling, +) +from .utils import is_function + + +if TYPE_CHECKING: + from types import FunctionType + + from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_eval_frame, + set_guard_complete_hook, + set_guard_error_hook, + unsupported, + ) + + from .variables import VariableTracker +else: + for name in dir(torch._C._dynamo.eval_frame): + if name.startswith("__"): + continue + globals()[name] = getattr(torch._C._dynamo.eval_frame, name) + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def run(fn=None): + """Don't do any dynamic compiles, just use prior optimizations""" + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return RunOnlyContext()(fn) + return RunOnlyContext() + + +def disable(fn=None, recursive=True, *, reason=None, wrapping=True): + """ + Decorator to disable TorchDynamo + + If recursive=True, Dynamo is completely skipped on the decorated function + frame as well as the recursively invoked functions. + + If recursive=False, Dynamo skips frames associated with the function code, + but still process recursively invoked frames. 
+ + If reason is provided, it will be printed when Dynamo attempts to trace the disabled function. + """ + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + return DisableContext(msg=reason, wrapping=wrapping)(fn) + return DisableContext(msg=reason, wrapping=wrapping) + else: + + def wrap(fn): + fn = innermost_fn(fn) + assert callable(fn) + + nonrecursive_disable_wrapper = get_nonrecursive_disable_wrapper(fn) + nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined] + nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined] + nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + return nonrecursive_disable_wrapper + + if fn is None: + return wrap + return wrap(fn) + + +_nonrecursive_disable_wrapper_code = disable(lambda: None, recursive=False).__code__ # type: ignore[attr-defined] +skip_code(_nonrecursive_disable_wrapper_code) + + +def skip(fn=None): + """ + Skip frames associated with the function code, but still process recursively + invoked frames + """ + if fn is None: + return skip + fn = innermost_fn(fn) + assert callable(fn) + skip_code(fn.__code__) + fn._torchdynamo_disable = True + return fn + + +class set_stance(_DecoratorContextManager): + """ + Decorator, context manager, function to set the current stance of the compiler. + + Stances documented in corresponding function in torch/compiler/__init__.py + """ + + _dynamo_forbidden = True + + def __init__( + self, + stance: str = "default", + *, + skip_guard_eval_unsafe: bool = False, + force_backend=None, + ) -> None: + if force_backend is not None and stance != "default": + raise RuntimeError("non-default stance cannot have force_backend set") + + self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend) + self.prev = _set_stance(self.stance) + + def __call__(self, fn): + _set_stance(self.prev) + wrapper = super().__call__(fn) + # forbid wrapper in graph + wrapper._dynamo_forbidden = True # type: ignore[attr-defined] + return wrapper + + def __enter__(self): + _set_stance(self.stance) + + def __exit__(self, exc_type, exc_val, exc_tb): + _set_stance(self.prev) + + def clone(self): + return self.__class__(self.stance.stance, force_backend=self.stance.backend) + + +def assume_constant_result(fn): + fn._dynamo_marked_constant = True + return fn + + +def allow_in_graph(fn): + """ + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation + + WARNING: this API can be a footgun, please read the documentation carefully. + """ + if isinstance(fn, (list, tuple)): + return [allow_in_graph(x) for x in fn] + assert callable(fn), "allow_in_graph expects a callable" + if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable: + fn_id = id(fn) + trace_rules._disallowed_callable_ids.remove(fn_id) + trace_rules._allowed_callable_ids.add(fn_id) + + # Avoid id reuse which creates subtle bugs. + def deregister(): + trace_rules._allowed_callable_ids.remove(fn_id) + + weakref.finalize(fn, deregister) + return fn + + +def nonstrict_trace(traceable_fn): + # Like `allow_in_graph`, but with the following enhancements/differences: + # + # 1. Supports user-defined class as inputs, as long as the class has been + # registered with pytree. + # 2. 
Reads to global/captured tensors forces the underlying graph to treat + # those tensors as constant, and we _assume_ they will not be updated. This + # is similar to FX tracing. + # 3. In the resulting Dynamo graph, the call to a `nonstrict_trace`-ed function + # will be represented as a call to `torch._higher_order_ops.flat_apply`, + # which takes in the `nonstrict_trace`-ed function and pytree-flattened + # inputs. + # 4. Only the returned function is traceable, and the original function will + # not be. Moreover, `nonstrict_trace` can be used inside a `torch.compile` + # region. + # + # NOTE: like `allow_in_graph`, aliasing information is neither preserved + # between inputs themselves, nor between inputs and outputs. + assert callable(traceable_fn), "nonstrict_trace expects a callable" + + @functools.wraps(traceable_fn) + def wrapped(*args, **kwargs): + return traceable_fn(*args, **kwargs) + + wrapped_id = id(wrapped) + + # This line allows us to reuse much of the `allow_in_graph` impl. + trace_rules._allowed_callable_ids.add(wrapped_id) + + # This line allows us to diverge the impl from `allow_in_graph`. + trace_rules._nonstrict_trace_callable_ids.add(wrapped_id) + + # Avoid id reuse which creates subtle bugs. + def deregister(): + trace_rules._allowed_callable_ids.remove(wrapped_id) + trace_rules._nonstrict_trace_callable_ids.remove(wrapped_id) + + weakref.finalize(wrapped, deregister) + + return wrapped + + +def _disallow_in_graph_helper(throw_if_not_allowed): + def inner(fn): + if isinstance(fn, (list, tuple)): + return [disallow_in_graph(x) for x in fn] + assert callable(fn), "disallow_in_graph expects a callable" + if ( + throw_if_not_allowed + and trace_rules.lookup_callable(fn) + != variables.TorchInGraphFunctionVariable + and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable + ): + raise IncorrectUsage( + "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). " + "Allowed callables means callables that TorchDynamo puts as-is in the extracted graph." + ) + trace_rules._allowed_callable_ids.remove(id(fn)) + trace_rules._nonstrict_trace_callable_ids.remove(id(fn)) + trace_rules._disallowed_callable_ids.add(id(fn)) + return fn + + return inner + + +def disallow_in_graph(fn): + """ + Customize which functions TorchDynamo will exclude in the generated + graph and force a graph break on. + :: + + torch._dynamo.disallow_in_graph(torch.sub) + + + @torch._dynamo.optimize(...) + def fn(a): + x = torch.add(x, 1) + x = torch.sub(x, 1) + x = torch.add(x, 1) + return x + + + fn(...) + + Will break the graph on `torch.sub`, and give two graphs each with a + single `torch.add()` op. + """ + return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn) + + +@_disallow_in_graph_helper(throw_if_not_allowed=False) +def graph_break(msg=""): + """Force a graph break""" + + +# NOTE: primarily used for internal debugging purposes! +@_disallow_in_graph_helper(throw_if_not_allowed=False) +def skip_frame(msg=""): + """Force a skipped frame""" + + +def forbid_in_graph(fn): + """ + Customize which functions TorchDynamo will assert are not present while tracing. + + If you want a graph break on this function instead, use disallow_in_graph. + TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust + documentation would not be amiss. 
+ """ + if isinstance(fn, (list, tuple)): + return [forbid_in_graph(x) for x in fn] + assert callable(fn), "forbid_in_graph applies only to callables" + fn._dynamo_forbidden = True + return fn + + +def substitute_in_graph( + original_fn: Callable[_P, _R], + *, + can_constant_fold_through: bool = False, + skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. + skip_signature_check (bool, optional): Whether to skip the signature check between the + original function and the polyfill handler. Defaults to ``False``. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers") + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): + raise TypeError( + f"substitute_in_graph expects a function but got {type(original_fn)!r}" + ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) + + def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: + if not is_function(traceable_fn): + raise TypeError( + f"@substitute_in_graph(...) 
expects a function but got {type(traceable_fn)!r}" + ) + + if not skip_signature_check: + try: + original_sig = inspect.signature(original_fn) + except ValueError: + pass + else: + traceable_sig = inspect.signature(traceable_fn) + + def sig_ident(sig): + # Ignore annotations for parameters and return type + return ( + tuple( + p.name + for p in sig.parameters.values() + if ( + p.kind + not in { + p.KEYWORD_ONLY, + # the name of *args and **kwargs is not important + p.VAR_POSITIONAL, + p.VAR_KEYWORD, + } + ) + ), + { + p.name + for p in sig.parameters.values() + if p.kind == p.KEYWORD_ONLY + }, + { + p.name: p.default + for p in sig.parameters.values() + # the name of *args and **kwargs is not important + if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD} + }, + ) + + wildcard_sig = inspect.signature(lambda *args, **kwargs: None) + + if ( + sig_ident(original_sig) != sig_ident(traceable_sig) + and sig_ident(original_sig) != sig_ident(wildcard_sig) + and sig_ident(traceable_sig) != sig_ident(wildcard_sig) + ): + raise TypeError( + f"Signature mismatch between {original_fn} and {traceable_fn}: " + f"{original_sig} != {traceable_sig}" + ) + + from torch._dynamo.guards import GuardBuilder + from torch._dynamo.trace_rules import ( + _polyfilled_function_ids, + get_torch_obj_rule_map, + ) + from torch._dynamo.variables import PolyfilledFunctionVariable + from torch._dynamo.variables.builder import VariableBuilder + + id_dispatch_map = VariableBuilder._id_dispatch() + if id(original_fn) in id_dispatch_map: + raise ValueError( + f"Duplicate dispatch rule for {original_fn}: " + "already registered in VariableBuilder's id dispatch map" + ) + + if id(original_fn) in _polyfilled_function_ids: + raise ValueError(f"Duplicate polyfilled object {original_fn}") + + rule_map: dict[Any, type[VariableTracker]] = get_torch_obj_rule_map() + if original_fn in rule_map: + raise ValueError( + f"Duplicate object {original_fn} with different rules: " + f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}" + ) + + polyfill_handlers: dict[Callable[..., Any], FunctionType] + polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() + if original_fn in polyfill_handlers: + raise ValueError( + f"Duplicate polyfill handlers for {original_fn}: " + f"already handled by {polyfill_handlers[original_fn]}" + ) + + # Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a + # C++ function. 
+ @functools.wraps(traceable_fn) + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: + return original_fn(*args, **kwargs) + + def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: + return PolyfilledFunctionVariable( + value, + source=self.source, + **self.install_guards(GuardBuilder.FUNCTION_MATCH), + ) + + id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn + _polyfilled_function_ids.add(id(original_fn)) + _polyfilled_function_ids.add(id(wrapped)) + rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable + polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment] + + wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined] + + return wrapped # type: ignore[return-value] + + return wrapper + + +# Helper function to flatten a tensor subclass and apply a function to +# all inner tensors that match the outer dim. Used to reduce duplication +# across the various marking APIs. +def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): + assert is_traceable_wrapper_subclass(t) + + attrs, _ctx = t.__tensor_flatten__() + assert isinstance(t, torch.Tensor) + for attr in attrs: + inner = getattr(t, attr) + if inner.dim() == t.dim(): + func(inner, *args, **kwargs) + + +@dataclass(frozen=True) +class _DimRange: + """ + This represents an dimension of a tensor and the corresponding + min and max values it can take. Don't create this + class directly; instead, use :func:`mark_dynamic`. + """ + + dim: int + min: int + max: int + + +@forbid_in_graph +def mark_unbacked(t, index, strict=False, specialize_on=None): + """ + Mark a tensor as having an unbacked dim. This changes the semantics of operations, + we will always report the size does not equal zero/one, we will turn asserts + on this index into runtime asserts, and if you try to get the real value we will + raise an exception. In other words, we will treat this dimension as if it was + data dependent (we do not know anything about its value.) + + For historical reasons, by default if an unbacked dim is specialized, we will + happily specialize it and continue. If you want to error in these cases, pass + strict=True. 
+ """ + # You could have copied the mark_dynamic behavior but I'm not convinced + # it's what you want + assert not is_traceable_wrapper_subclass(t), "not implemented yet" + + if isinstance(index, int): + if strict: + if not hasattr(t, "_dynamo_strict_unbacked_indices"): + t._dynamo_strict_unbacked_indices = set() + t._dynamo_strict_unbacked_indices.add(index) + return + + if not hasattr(t, "_specialized_on"): + t._specialize_on = {} + + if not hasattr(t, "_dynamo_unbacked_indices"): + t._dynamo_unbacked_indices = set() + + # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: + # TypeError: 'Attribute' object does not support item assignment + if isinstance(t._specialize_on, dict): + t._specialize_on[index] = specialize_on if specialize_on is not None else [] + + t._dynamo_unbacked_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_unbacked(t, i) + + +@forbid_in_graph +def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): + """ + Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. + + [Note - on the state of mark_dynamic] + + The behavior of having a dynamic dimension on a tensor is governed by a few factors: + + 1) torch._dynamo.config dynamic_shapes True or False. + a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work. + a) dynamic_shapes=False - This config will raise an exception when used in conjunction with + mark_dynamic. We will eventually support this. + + 2) If the dimension is fully constrained - as in, it does not allow more than a single value + in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export), + we will raise an error + + 3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded + range of shapes, in eager we will pass it through, but export will raise an error. + + 4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made + before torch.compile. + + 5) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by + multiple specialized compilations in addition to a single generic compilation. NB: For now we only support + per dimension specialization, or in other words we do not generate a cross product of specializations. + At runtime, we will dispatch to a specialized compiled region if the input matches the specialization criteria. + + For example: + mark_dynamic(..., specialize_on=[ + lambda x: x == 8, + lambda x: x == 16 + ]) + + This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 8 or 16 + at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate + 2-8x speedups depending on the specific specialization and model architecture. + """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim( + mark_dynamic, t, index, min=min, max=max + ) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_dynamic_indices"): + t._dynamo_dynamic_indices = set() + t._dynamo_dynamic_range = set() + + if not hasattr(t, "_specialize_on"): + t._specialize_on = {} + + # TODO(voz): Should we bounds check? 
+ t._dynamo_dynamic_indices.add(index) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) + + # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: + # TypeError: 'Attribute' object does not support item assignment + if isinstance(t._specialize_on, dict): + t._specialize_on[index] = specialize_on if specialize_on is not None else [] + + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_dynamic(t, i, min=min, max=max) + mark_dynamic(t, i, min=min, max=max, specialize_on=specialize_on) + + +@forbid_in_graph +def maybe_mark_dynamic(t, index): + """ + Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this + dimension ends up getting specialized, don't error). + """ + if is_traceable_wrapper_subclass(t): + # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_weak_dynamic_indices"): + t._dynamo_weak_dynamic_indices = set() + # TODO(voz): Should we bounds check? + t._dynamo_weak_dynamic_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + maybe_mark_dynamic(t, i) + + +def mark_static(t, index=None): + """ + Mark a tensor as having a static dim or mark a nn module class as static. + + For tensors + =========== + This will prevent us from attempting to compile it dynamically + when dynamic=True; this can improve trace-time performance. + + This has lower precedence than mark_dynamic. + + Unlike mark_dynamic, this can be done inside a graph, in which case it + induces specialization on the tensor. + + For nn.Module classes + ===================== + For static nn.Module classes, TorchDynamo assumes that the module instance + attributes will not be modified after compilation. This will ensure that + TorchDynamo keeps integer attributes CONSTANT and not symints. + + From TorchDynamo implementation side, the instances of static-marked + nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable, + which have the same properties. + + Note that we still have to guard on the attributes, because different + instances of the nn.Module can have different values of the attributes. The + key point here is that the attributes are static. + """ + if is_compiling(): + if index is None: + for s in t.size(): + comptime.force_static(s) + else: + comptime.force_static(t.size(index)) + return + + if is_traceable_wrapper_subclass(t): + # default behavior: mirror mark_static() on all inner tensors with same dim as t + # TODO: Make this configurable via a supported public API + _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index) + + if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module): + t._dynamo_marked_static = True + return t + + if not isinstance(t, torch.Tensor): + raise TypeError( + f"mark_static expects a tensor/nn.Module class but received {type(t)}" + ) + + if isinstance(index, int): + if not hasattr(t, "_dynamo_static_indices"): + t._dynamo_static_indices = set() # type: ignore[attr-defined] + # TODO(voz): Should we bounds check? 
+ t._dynamo_static_indices.add(index) # type: ignore[attr-defined] + elif index is None: + for i in range(t.dim()): + mark_static(t, i) + else: + assert isinstance(index, (list, tuple)) + for i in index: + mark_static(t, i) + + +@forbid_in_graph +def mark_static_address(t, guard=True): + """ + Marks an input tensor whose data_ptr will not change across multiple calls + to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation + is not needed for this input. The data_ptr will be guarded if guard=True. Note: + Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called. + """ + if not isinstance(t, torch.Tensor): + raise TypeError(f"mark_static_address expects a tensor but received {type(t)}") + + if guard: + t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined] + else: + t._dynamo_static_input_type = "unguarded" # type: ignore[attr-defined] + + +# One day, Dynamo will support tracing into einops directly (no allow_in_graph needed) +# Note that PyTorch supports multiple versions of einops, so when that day comes, +# we still need to be really careful about version matches. +def _allow_in_graph_einops(): + import einops + + try: + # requires einops > 0.6.1, torch >= 2.0 + from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401 + _ops_were_registered_in_torchdynamo, + ) + + # einops > 0.6.1 will call the op registration logic as it is imported. + except ImportError: + # einops <= 0.6.1 + allow_in_graph(einops.rearrange) + allow_in_graph(einops.reduce) + if hasattr(einops, "repeat"): + allow_in_graph(einops.repeat) # available since einops 0.2.0 + if hasattr(einops, "einsum"): + allow_in_graph(einops.einsum) # available since einops 0.5.0 + if hasattr(einops, "pack"): + allow_in_graph(einops.pack) # available since einops 0.6.0 + if hasattr(einops, "unpack"): + allow_in_graph(einops.unpack) # available since einops 0.6.0 + + +# Note: this carefully avoids eagerly import einops. +trace_rules.add_module_init_func("einops", _allow_in_graph_einops) + + +# Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators +# created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch. +class DynamoConfigPatchProxy: + def __init__(self, config_patch): + self.config_patch = config_patch + + @property + def changes(self): + return self.config_patch.changes + + # Decorator implementation that simply sets up `self` as a context manager. + # Placed in external_utils so that we can trace through it. + __call__ = _dynamo_config_patch_proxy_dunder_call + + def __enter__(self): + return self.config_patch.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self.config_patch.__exit__(exc_type, exc_val, exc_tb) + + +# Criteria for patchable config: +# - Config values must be constants (i.e. int, float, str, bool, None). +# - in particular, NO list, set, dict. +# - Traceable config patches are only useful for configs that change dynamo behavior +# from symbolic_convert and below. +# - e.g. patching recompile_limit won't really do anything. +# - For patching configs that affect Dynamo behavior above symbolic_convert, +# ensure that Dynamo behaves soundly even if tracing is done with different config. +# - e.g. be careful if patching guard-related configs as configs may have changed +# between guard creation and evaluation. 
+_allowed_config_patches = ( + "verbose", + "verify_correctness", + "rewrite_assert_with_torch_assert", + "capture_scalar_outputs", + "allow_unspec_int_on_nn_module", + "skip_torchrec", + "dont_skip_tracing", +) + +from . import config + + +for name in _allowed_config_patches: + assert hasattr(config, name), "nonexistent config" +del config + + +def _patch_dynamo_config_check(changes: dict[str, Any]): + for k, v in changes.items(): + if k not in _allowed_config_patches: + raise ValueError( + f"patch_dynamo_config does not support patching config {k}" + ) + if not torch._dynamo.utils.is_safe_constant(v): + raise ValueError( + f"patch_dynamo_config does not support patching config {k} " + f"with non-safe-constant value {v}" + ) + + +# TODO: also implement nonrecursive patch_dynamo_config/dont_skip_tracing. +# Unlike config.patch, we also need to accept tuple as input in order to +# deal with context manager reconstruction. +def patch_dynamo_config( + arg1: Optional[Union[str, dict[str, Any], tuple[tuple[str, Any], ...]]] = None, + arg2: Any = None, + **kwargs: Any, +) -> DynamoConfigPatchProxy: + """ + A wrapper around torch._dynamo.config.patch that can be traced by Dynamo to + temporarily change config values DURING tracing. + + See _allowed_config_patches for the list of allowed config patches. + + Arguments are the same as with torch._dynamo.config.patch. + + Can be used as a decorator or a context manager. + + User code SHOULD NOT MODIFY the return value of this function. + + WARNING: changing Dynamo config during tracing can lead to unpredictable tracing behavior! + Proceed only as advised! + """ + if isinstance(arg1, tuple): + arg1 = dict(arg1) + config_patch = torch._dynamo.config.patch(arg1, arg2, **kwargs) + _patch_dynamo_config_check(config_patch.changes) + # check for valid patching using config_patch.changes + return DynamoConfigPatchProxy(config_patch) + + +def dont_skip_tracing(fn=None): + """ + Context manager/decorator to trace into functions intentionally marked by developers to be skipped + when tracing. + + This decorator will also apply to recursively invoked functions. + """ + ctx = patch_dynamo_config(dont_skip_tracing=True) + if fn: + return ctx(fn) + return ctx diff --git a/phivenv/Lib/site-packages/torch/_dynamo/device_interface.py b/phivenv/Lib/site-packages/torch/_dynamo/device_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e020f6a2f61f27a005e3c5a662c198be0ba8f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/device_interface.py @@ -0,0 +1,515 @@ +# mypy: allow-untyped-defs + +""" +Device abstraction layer for TorchDynamo and Inductor backends. + +This module provides a unified interface for different hardware backends (CUDA, XPU, +CPU, MPS) through a common device interface. Key components include: + +- DeviceInterface: Base class defining the common API for all device types +- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface +- Device registration system for managing available backends +- Worker APIs for multi-processing scenarios +- Stream and event management across different devices +- Device property caching for worker processes + +The abstraction layer enables device-agnostic code in TorchDynamo while allowing +specialized implementations for each hardware backend's unique features. 
+""" + +import inspect +import time +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch + + +get_cuda_stream: Optional[Callable[[int], int]] +if torch.cuda._is_compiled(): + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream +else: + get_cuda_stream = None + +# Recording the device properties in the main process but used in worker process. +caching_worker_device_properties: dict[str, Any] = {} +caching_worker_current_devices: dict[str, int] = {} + + +class DeviceInterface: + """ + This is a simple device runtime interface for Inductor. It enables custom + backends to be integrated with Inductor in a device-agnostic semantic. + """ + + class device: + def __new__(cls, device: torch.types.Device): + raise NotImplementedError + + class Event: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." + ) + + class Stream: + def __new__(cls, *args, **kwargs): + raise NotImplementedError( + "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." + ) + + class Worker: + """ + Worker API to query device properties that will work in multi processing + workers that cannot use the GPU APIs (due to processing fork() and + initialization time issues). Properties are recorded in the main process + before we fork the workers. + """ + + @staticmethod + def set_device(device: int): + raise NotImplementedError + + @staticmethod + def current_device() -> int: + raise NotImplementedError + + @staticmethod + def get_device_properties(device: torch.types.Device = None): + raise NotImplementedError + + @staticmethod + def current_device(): + raise NotImplementedError + + @staticmethod + def set_device(device: torch.types.Device): + raise NotImplementedError + + @staticmethod + def maybe_exchange_device(device: int) -> int: + raise NotImplementedError + + @staticmethod + def exchange_device(device: int) -> int: + raise NotImplementedError + + @staticmethod + def device_count(): + raise NotImplementedError + + @staticmethod + def is_available() -> bool: + raise NotImplementedError + + @staticmethod + def stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def current_stream(): + raise NotImplementedError + + @staticmethod + def set_stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): + raise NotImplementedError + + @staticmethod + def get_raw_stream(device_idx: int) -> int: + raise NotImplementedError + + @staticmethod + def synchronize(device: torch.types.Device = None): + raise NotImplementedError + + @classmethod + def get_device_properties(cls, device: torch.types.Device = None): + return cls.Worker.get_device_properties(device) + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + raise NotImplementedError + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): + raise NotImplementedError + + @classmethod + def is_dtype_supported( + cls, dtype: torch.dtype, including_emulation: bool = False + ) -> bool: + return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation) + + @staticmethod + def memory_allocated(device: torch.types.Device = None) -> int: + raise NotImplementedError + + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + """ + Returns 
True if the device has Triton support, False otherwise, even if + the appropriate Triton backend is not available. + """ + return False + + @classmethod + def raise_if_triton_unavailable(cls, device: torch.types.Device = None) -> None: + """ + Raises a `RuntimeError` with the appropriate human-readable instructions + to resolve the issue if Triton is not available for the given device, or + the default device if `device` is `None`. + + The caller should ensure the presence of the 'triton' package before + calling this method. + """ + if not cls.is_triton_capable(): + raise RuntimeError("This device is not capable of supporting Triton") + + +class DeviceGuard: + """ + This class provides a context manager for device switching. This is a stripped + down version of torch.{device_name}.device. + + The context manager changes the current device to the given device index + on entering the context and restores the original device on exiting. + The device is switched using the provided device interface. + """ + + def __init__( + self, device_interface: type[DeviceInterface], index: Optional[int] + ) -> None: + self.device_interface = device_interface + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + if self.idx is not None: + self.prev_idx = self.device_interface.exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + if self.idx is not None: + self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) + return False + + +class CudaInterface(DeviceInterface): + device = torch.cuda.device # type: ignore[assignment] + + # register Event and Stream class into the backend interface + # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream + Event = torch.cuda.Event # type: ignore[assignment] + Stream = torch.cuda.Stream # type: ignore[assignment] + + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["cuda"] = device + + @staticmethod + def current_device() -> int: + if "cuda" in caching_worker_current_devices: + return caching_worker_current_devices["cuda"] + return torch.cuda.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None): + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "cuda" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = CudaInterface.Worker.current_device() + + if "cuda" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["cuda"] = device_prop + + return caching_worker_device_properties["cuda"][device] + + current_device = staticmethod(torch.cuda.current_device) + set_device = staticmethod(torch.cuda.set_device) + device_count = staticmethod(torch.cuda.device_count) + stream = staticmethod(torch.cuda.stream) # type: ignore[assignment] + current_stream = staticmethod(torch.cuda.current_stream) + set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment] + _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment] + synchronize = staticmethod(torch.cuda.synchronize) + get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] + get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] + exchange_device = staticmethod(torch.cuda._exchange_device) # type: 
ignore[arg-type] + maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.cuda.memory_allocated) + is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] + + # Can be mock patched by @patch decorator. + @staticmethod + def is_available() -> bool: + return torch.cuda.is_available() + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + if torch.version.hip is None: + major, min = torch.cuda.get_device_capability(device) + return major * 10 + min + else: + return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0] + + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + return ( + torch.version.hip is not None + or torch.cuda.get_device_properties(device).major >= 7 + ) + + @staticmethod + def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: + from torch._inductor.exc import GPUTooOldForTriton + + if not CudaInterface.is_triton_capable(device): + device_props = torch.cuda.get_device_properties(device) + raise GPUTooOldForTriton(device_props, inspect.currentframe()) + + import triton.backends + + if torch.version.hip is not None: + if "amd" not in triton.backends.backends: + raise RuntimeError("triton not built with the 'amd' backend") + elif "nvidia" not in triton.backends.backends: + raise RuntimeError("triton not built with the 'nvidia' backend") + + +get_xpu_stream: Optional[Callable[[int], int]] +if torch.xpu._is_compiled(): + from torch._C import _xpu_getCurrentRawStream as get_xpu_stream +else: + get_xpu_stream = None + + +class XpuInterface(DeviceInterface): + device = torch.xpu.device # type: ignore[assignment] + Event = torch.xpu.Event # type: ignore[assignment] + Stream = torch.xpu.Stream # type: ignore[assignment] + + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["xpu"] = device + + @staticmethod + def current_device() -> int: + if "xpu" in caching_worker_current_devices: + return caching_worker_current_devices["xpu"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None): + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "xpu" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = XpuInterface.Worker.current_device() + + if "xpu" not in caching_worker_device_properties: + device_prop = [ + torch.xpu.get_device_properties(i) + for i in range(torch.xpu.device_count()) + ] + caching_worker_device_properties["xpu"] = device_prop + + return caching_worker_device_properties["xpu"][device] + + current_device = staticmethod(torch.xpu.current_device) + set_device = staticmethod(torch.xpu.set_device) + device_count = staticmethod(torch.xpu.device_count) + stream = staticmethod(torch.xpu.stream) # type: ignore[assignment] + current_stream = staticmethod(torch.xpu.current_stream) + set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment] + _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment] + synchronize = staticmethod(torch.xpu.synchronize) + get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment] + get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] + exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] + 
maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.xpu.memory_allocated) + + # Can be mock patched by @patch decorator. + @staticmethod + def is_available() -> bool: + return torch.xpu.is_available() + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + cc = torch.xpu.get_device_capability(device) + return cc + + @staticmethod + def is_bf16_supported(including_emulation: bool = False) -> bool: + return torch.xpu.is_bf16_supported() + + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + return True + + @staticmethod + def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None: + import triton.backends + + if "intel" not in triton.backends.backends: + raise RuntimeError("triton not built with the 'intel' backend") + + +@dataclass +class CpuDeviceProperties: + multi_processor_count: int + + +class CpuInterface(DeviceInterface): + class Event(torch.Event): + def __init__(self, enable_timing=True): + self.time = 0.0 + + def elapsed_time(self, end_event) -> float: + return (end_event.time - self.time) * 1000 + + def record(self, stream=None): + self.time = time.perf_counter() + + class Worker: + @staticmethod + def get_device_properties(device: torch.types.Device = None): + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + return CpuDeviceProperties(cpu_count) + + @staticmethod + def is_available() -> bool: + return True + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): + return True + + @staticmethod + def get_compute_capability(device: torch.types.Device = None) -> str: + return "" + + @staticmethod + def get_raw_stream(device_idx) -> int: + return 0 + + @staticmethod + def current_device(): + return 0 + + @staticmethod + def synchronize(device: torch.types.Device = None): + pass + + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + return True + + @staticmethod + def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: + import triton.backends + + if "cpu" not in triton.backends.backends: + raise RuntimeError("triton not built with the 'cpu' backend") + + +class MpsInterface(DeviceInterface): + @staticmethod + def is_bf16_supported(including_emulation: bool = False) -> bool: + return torch.backends.mps.is_macos_or_newer(14, 0) + + @classmethod + def is_dtype_supported( + cls, dtype: torch.dtype, including_emulation: bool = False + ) -> bool: + if dtype in [torch.float64, torch.complex128]: + return False + return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation) + + @staticmethod + def is_available() -> bool: + return torch.backends.mps.is_available() + + @staticmethod + def current_device(): + return 0 + + @staticmethod + def get_compute_capability(device: torch.types.Device = None) -> str: + return "" + + @staticmethod + def synchronize(device: torch.types.Device = None): + torch.mps.synchronize() + + class Worker: + @staticmethod + def get_device_properties(device: torch.types.Device = None): + return {} + + @staticmethod + def current_device(): + return 0 + + +device_interfaces: dict[str, type[DeviceInterface]] = {} +_device_initialized = False + + +def register_interface_for_device( + device: Union[str, torch.device], device_interface: type[DeviceInterface] +): + if isinstance(device, torch.device): + device = device.type + device_interfaces[device] = device_interface + + +def get_interface_for_device(device: 
Union[str, torch.device]) -> type[DeviceInterface]: + if isinstance(device, torch.device): + device = device.type + if not _device_initialized: + init_device_reg() + if device in device_interfaces: + return device_interfaces[device] + raise NotImplementedError(f"No interface for device {device}") + + +def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterface]]]: + if not _device_initialized: + init_device_reg() + return device_interfaces.items() + + +def init_device_reg(): + global _device_initialized + register_interface_for_device("cuda", CudaInterface) + for i in range(torch.cuda.device_count()): + register_interface_for_device(f"cuda:{i}", CudaInterface) + + register_interface_for_device("xpu", XpuInterface) + for i in range(torch.xpu.device_count()): + register_interface_for_device(f"xpu:{i}", XpuInterface) + + register_interface_for_device("cpu", CpuInterface) + register_interface_for_device("mps", MpsInterface) + + _device_initialized = True diff --git a/phivenv/Lib/site-packages/torch/_dynamo/distributed.py b/phivenv/Lib/site-packages/torch/_dynamo/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a86005f3f4c4cdbb24d044da74b5fbffd2f292b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/distributed.py @@ -0,0 +1,54 @@ +""" +Manages process groups for distributed compilation in TorchDynamo. + +This module handles the initialization and management of process groups used for +distributed compilation. Key features: + +- Lazy initialization of compilation process groups +- Only creates groups when distributed mode is enabled and available +- Integrates with compiler_collectives configuration setting +- Provides a single global process group for compilation coordination + +The process group is created only when needed and if the distributed environment +is properly initialized, making it safe to import and use this module even in +non-distributed scenarios. +""" + +from typing import Optional + +import torch.distributed as dist + +from . import config + + +_COMPILE_PG: Optional[dist.ProcessGroup] = None +_GUARD_PG: Optional[dist.ProcessGroup] = None + + +def get_compile_pg() -> Optional[dist.ProcessGroup]: + if ( + config.enable_compiler_collectives + and dist.is_available() + and dist.is_initialized() + ): + global _COMPILE_PG + if _COMPILE_PG is None: + # , timeout=datetime.timedelta(seconds=2) + _COMPILE_PG = dist.distributed_c10d._new_group_with_tag( + pg_tag="pt2_compile_pg" + ) + return _COMPILE_PG + + return None + + +# NB: Unlike get_compile_pg, this is only called when guard collectives were +# explicitly requested +def get_guard_pg() -> Optional[dist.ProcessGroup]: + if dist.is_available() and dist.is_initialized(): + global _GUARD_PG + if _GUARD_PG is None: + _GUARD_PG = dist.distributed_c10d._new_group_with_tag(pg_tag="pt2_guard_pg") + return _GUARD_PG + + return None diff --git a/phivenv/Lib/site-packages/torch/_dynamo/eval_frame.py b/phivenv/Lib/site-packages/torch/_dynamo/eval_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb3cb8574e5c95896c7d73850bbbaf804470792 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/eval_frame.py @@ -0,0 +1,2166 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" + +""" +This module implements the core frame evaluation handler for TorchDynamo's compilation system. +The eval frame handler intercepts Python bytecode execution at runtime to enable dynamic +compilation and optimization of PyTorch code. 
+ +Key components defined here: +- Frame evaluation handlers that intercept and analyze Python execution frames +- Guards management for tracking dependencies and invalidating compiled code +- Optimization contexts and decorators (optimize, run_once, disable, etc.) +- Export functionality for saving optimized graphs +- Backend compiler integrations and callback management + +Functions in this file are responsible for modifying the eval frame handler at RUNTIME. +Therefore, all functions in this file are hot and performance-critical. Functions that +only execute at compile time should be placed in torch._dynamo.convert_frame. + +The eval frame handler is the core mechanism that enables TorchDynamo to dynamically +intercept, analyze and optimize PyTorch code during execution. It works by registering +a custom frame evaluation function that gets called for every Python frame, allowing +us to detect PyTorch operations and trigger compilation as needed. +""" + +from __future__ import annotations + +import atexit +import contextlib +import functools +import inspect +import logging +import os +import sys +import sysconfig +import textwrap +import threading +import traceback +import types +import warnings +import weakref +from dataclasses import dataclass +from enum import Enum +from os.path import dirname, join +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union +from unittest.mock import patch + +import sympy + +import torch +import torch.fx +import torch.utils._pytree as pytree +import torch.utils.checkpoint +from torch import _guards + +# see discussion at https://github.com/pytorch/pytorch/issues/120699 +from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_code_exec_strategy, + set_eval_frame, + set_guard_complete_hook, + set_guard_error_hook, + set_skip_guard_eval_unsafe, + unsupported, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.types import ConvertFrameReturn, FrameAction, FrameExecStrategy +from torch._export.utils import _compiling_state_context +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch._utils_internal import justknobs_check, log_export_usage +from torch.export.dynamic_shapes import ( + _combine_args, + _DimHint, + _DimHintType, + _IntWrapper, + _process_dynamic_shapes, + _RelaxedConstraint, + Constraint, +) +from torch.fx import GraphModule +from torch.fx.experimental._dynamism import ( + clone_and_convert_to_meta, + track_dynamism_across_examples, +) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + ShapeEnv, + StatelessSymbolicContext, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from . 
import config, convert_frame, distributed, external_utils, trace_rules, utils +from .backends.registry import CompilerFn, lookup_backend +from .code_context import code_context +from .exc import ( + CondOpArgsMismatchError, + ShortenTraceback, + Unsupported, + UserError, + UserErrorType, +) +from .hooks import Hooks +from .mutation_guard import install_generation_tagging_init +from .utils import common_constant_types, compile_times + + +if TYPE_CHECKING: + from torch._subclasses import fake_tensor + + from .types import CacheEntry, DynamoCallback + + +log = logging.getLogger(__name__) + + +always_optimize_code_objects = utils.ExactWeakKeyDictionary() +null_context = contextlib.nullcontext + + +# See https://github.com/python/typing/pull/240 +class Unset(Enum): + token = 0 + + +cached_backends: dict[int, CompilerFn] = {} + +unset = Unset.token + + +def _maybe_set_eval_frame(callback: DynamoCallback): + # A wrapper on set_eval_frame that is guarded by a Justknob. + # Users can disable torchDynamo by setting the JK to False. + if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): + torch._dynamo.utils.warn_once( + "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame" + ) + return callback + else: + return set_eval_frame(callback) + + +@dataclass +class DynamoStance: + stance: str = "default" + skip_guard_eval_unsafe: bool = False + backend: Union[str, Callable[..., Any], None] = None + + +_stance = DynamoStance() + + +def _set_stance(stance: DynamoStance) -> DynamoStance: + global _stance + + from torch._C._dynamo.eval_frame import get_eval_frame_callback + + callback = get_eval_frame_callback() + + if callback is not False and callback is not None: + raise RuntimeError("attempted to set_stance in a torch.compile region") + + prior = _stance + _stance = stance + return prior + + +_set_stance._dynamo_forbidden = True # type: ignore[attr-defined] + +_EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None + + +def get_example_inputs(key) -> list[Any]: + global _EXAMPLE_INPUTS + if _EXAMPLE_INPUTS is None: + _EXAMPLE_INPUTS = {} + + if key not in _EXAMPLE_INPUTS: + _EXAMPLE_INPUTS[key] = [] + + return _EXAMPLE_INPUTS[key] + + +def _callback_from_stance(callback): + if _stance.stance == "default": + # force_backend + if _stance.backend is not None and callback not in (False, None): + callback = _create_wrapped_callback(get_compiler_fn(_stance.backend)) + + return callback + elif _stance.stance == "eager_then_compile": + if callback not in (False, None): + return _create_delayed_compile_callback(callback, _stance.stance) + return callback + elif _stance.stance == "aot_eager_then_compile": + if callback not in (False, None): + return _create_delayed_compile_callback(callback, _stance.stance) + return callback + elif _stance.stance == "force_eager": + # disable + return None + elif _stance.stance == "eager_on_recompile": + # run mode + return False + elif _stance.stance == "fail_on_recompile": + if callback in (False, None): + return callback + + def fail_callback(frame, *args, **kwargs): + if trace_rules.check(frame.f_code): + return ConvertFrameReturn() + raise RuntimeError( + "Detected recompile when torch.compile stance is 'fail_on_recompile'" + ) + + # to prevent cache miss due to different callback + fail_callback._torchdynamo_orig_callable = callback # type: ignore[attr-defined] + + return fail_callback + else: + raise RuntimeError(f"invalid torch.compile stance '{_stance}'") + + +def _create_wrapped_callback(compiler_fn): + hooks = Hooks() 
+ return convert_frame.catch_errors_wrapper( + convert_frame.convert_frame( # type: ignore[arg-type] + compiler_fn, + hooks, + ), + hooks, + ) + + +def _get_or_add_example_inputs(frame): + key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno) + example_inputs = get_example_inputs(key) + + if len(example_inputs) < 2: + example_inputs.append(clone_and_convert_to_meta(frame.f_locals)) + + return example_inputs + + +def _create_delayed_compile_callback(callback, stance): + def callback_fn(*args, **kwargs): + frame = args[0] + example_inputs = _get_or_add_example_inputs(frame) + + if len(example_inputs) == 1: + if stance == "eager_then_compile": + return ConvertFrameReturn( + frame_exec_strategy=FrameExecStrategy( + FrameAction.DEFAULT, FrameAction.DEFAULT + ) + ) + elif stance == "aot_eager_then_compile": + aot_eager_fn = get_compiler_fn("aot_eager") + return _create_wrapped_callback(aot_eager_fn)(*args, **kwargs) + + dynamism = track_dynamism_across_examples(example_inputs) + code_context.get_context(frame.f_code)["dynamism"] = dynamism + compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable + return _create_wrapped_callback(compiler_fn)(*args, **kwargs) + + return callback_fn + + +def _is_skip_guard_eval_unsafe_stance(): + return _stance.skip_guard_eval_unsafe + + +def _reset_guarded_backend_cache(): + global cached_backends + for backend in cached_backends.values(): + if hasattr(backend, "reset"): + backend.reset() + cached_backends.clear() + + +DONT_WRAP_FILES = { + # For tracing into fx modules + inspect.getsourcefile(GraphModule), + join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"), +} + + +def _debug_get_cache_entry_list( + code: Union[types.CodeType, Callable[..., Any]], +) -> list[CacheEntry]: + """ + Given a code object or a callable object, retrieve the cache entries + stored in this code. + """ + if callable(code): + code = code.__code__ + return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code) + + +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. + """ + + _torchdynamo_orig_callable: Callable[..., Any] + get_compiler_config: Callable[[], Any] + + _opt_mod_attributes = { + "_orig_mod", + "dynamo_ctx", + "_torchdynamo_orig_callable", + "get_compiler_config", + "forward", + "_forward", + "__dict__", + "named_children_walk", + "_super_module_initialized", + } + + def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: + # NOTE: this must go first, because attribute reads/writes of `self` + # uses `_orig_mod`, and sometimes users override `Module.__init__` to + # do attribute reads/writes on `self`. 
+ # + # We also can't use regular setattr because `super().__setattr__` will + # complain for module value before `super().__init__()` + object.__setattr__(self, "_orig_mod", mod) + self._super_module_initialized = False + super().__init__() + self._super_module_initialized = True + + # Installs the params/buffer + self._orig_mod = mod # `super().__setattr__` will register this module + self.dynamo_ctx = dynamo_ctx + self._initialize() + self.training = self._orig_mod.training + + def _initialize(self): + # Do this stuff in constructor to lower overhead slightly + if isinstance(self.dynamo_ctx, DisableContext): + # No need to check trace rules + self.forward = self.dynamo_ctx(self._orig_mod.__call__) + elif config.wrap_top_frame or ( + isinstance(self._orig_mod.forward, types.MethodType) + and ( + trace_rules.check(self._orig_mod.forward) + or getattr(self._orig_mod, "_is_fsdp_managed_module", False) + ) + ): + # This may be a torch.nn.* instance in trace_rules.py which + # won't trigger a frame evaluation workaround to add an extra + # frame we can capture + self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod)) + else: + # Invoke hooks outside of dynamo then pickup the inner frame + self.forward = self.dynamo_ctx(self._orig_mod.__call__) + + if hasattr(self._orig_mod, "_initialize_hook"): + self._forward = self.forward + self.forward = self._call_lazy_check + + def __call__(self, *args, **kwargs): + if torch.nn.modules.module._has_any_global_hook(): + warnings.warn( + "Using `torch.compile(module)` when there are global hooks on " + "modules (e.g., from `register_module_forward_hook`); this will" + " cause the hooks to fire an extra time for the " + "`OptimizedModule` created by `torch.compile(module)`. If this " + "causes undesired behavior, please try using `module.compile()`" + ", or use the per-module hooks instead", + stacklevel=2, + ) + return super().__call__(*args, **kwargs) + + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + + def __getstate__(self): + state = dict(self.__dict__) + state.pop("forward", None) + state.pop("__call__", None) + return state + + def __setstate__(self, state): + self.__dict__ = state + self._initialize() + + @property + def training(self): + return self._orig_mod.training + + @training.setter + def training(self, value): + # Ignore the `training` mutation in `super().__init__()`, since that's + # setting the default on `nn.Module`, but we are mirroring the + # `training` attr in `self._orig_mod`. 
+ if self._super_module_initialized: + self._orig_mod.training = value + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def __setattr__(self, name, val) -> None: + # Allow patching over class attributes + if hasattr(type(self), name): + return super().__setattr__(name, val) + + if name in OptimizedModule._opt_mod_attributes: + return super().__setattr__(name, val) + return setattr(self._orig_mod, name, val) + + def __delattr__(self, name): + # This mirrors `__setattr__` + if hasattr(type(self), name): + return super().__delattr__(name) + + if name in OptimizedModule._opt_mod_attributes: + return super().__delattr__(name) + return delattr(self._orig_mod, name) + + def _call_lazy_check(self, *args, **kwargs): + if ( + hasattr(self._orig_mod, "_initialize_hook") + and hasattr(self._orig_mod, "_infer_parameters") + and callable(self._orig_mod._infer_parameters) + ): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) + return self._forward(*args, **kwargs) + + def __dir__(self): + orig_mod_attrs = self._orig_mod.__dir__() + return orig_mod_attrs + [ + attr for attr in super().__dir__() if attr not in orig_mod_attrs + ] + + +def remove_from_cache(f): + """ + Make sure f.__code__ is not cached to force a recompile + """ + if isinstance(f, types.CodeType): + reset_code(f) + elif hasattr(f, "__code__"): + reset_code(f.__code__) + elif hasattr(getattr(f, "forward", None), "__code__"): + reset_code(f.forward.__code__) + else: + from . import reset # type: ignore[attr-defined] + + reset() + log.warning("could not determine __code__ for %s", f) + + +def nothing(): + pass + + +def always_false(): + return False + + +def innermost_fn(fn): + """ + In case of nesting of _TorchDynamoContext calls, find the innermost + function. TorchDynamo caches on fn.__code__ object, so its necessary to find + the innermost function to pass on the optimize, run, disable etc. + """ + unaltered_fn = fn + while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): + unaltered_fn = unaltered_fn._torchdynamo_orig_callable + assert callable(unaltered_fn), ( + f"A callable function is expected, but {type(unaltered_fn)} is provided." + ) + return unaltered_fn + + +def make_set_enable_dynamic(enable: bool): + assert isinstance(enable, bool) + if enable: + # Assume everything is dynamic by default + return config._make_closure_patcher(assume_static_by_default=False) + else: + return config._make_closure_patcher( + automatic_dynamic_shapes=False, assume_static_by_default=True + ) + + +# A thread local storage that serves to store information as Dynamo traces +# through a user provided function. +class DynamoTLS(threading.local): + # Each string is a summary of a frame Dynamo attempted to trace, stored in + # temporal order. + traced_frame_infos: list[str] = [] + + +dynamo_tls = DynamoTLS() + + +def clear_dynamo_tls(): + dynamo_tls.traced_frame_infos.clear() + + +@atexit.register +def _log_traced_frames(): + """ + At program exit, log all of the frames Dynamo has attempted to trace from, + excluding the continuation frames generated by Dynamo. 
+ """ + msg = "\n".join(dynamo_tls.traced_frame_infos) + msg = textwrap.indent(msg, " * ") + msg = f"TorchDynamo attempted to trace the following frames: [\n{msg}\n]" + log.info(msg) + + +def guard_collectives_hook(guard_eval_result): + import torch.distributed as dist + from torch._dynamo.utils import dynamo_timed + + # guard_eval_result == True ==> cache hit + if pg := distributed.get_guard_pg(): + with dynamo_timed( + "guard_collective", log_pt2_compile_event=True, log_waitcounter=True + ): + log.info("guard_collective %s", guard_eval_result) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "guard_collective", + "encoding": "string", + }, + payload_fn=lambda: str(guard_eval_result), + ) + # TODO: a bit awkward to time, this isn't inside of the dynamo compile region + all_results = [None] * pg.size() + dist.all_gather_object(all_results, guard_eval_result, group=pg) + # True = everyone hit, OK to run + # False = someone missed, force recompile everywhere + res = all(all_results) + log.info("guard_collective %s -> %s", guard_eval_result, res) + return res + return guard_eval_result + + +_not_set = object() + + +class _TorchDynamoContext: + def __init__( + self, + callback: DynamoCallback, + on_enter=nothing, + backend_ctx_ctor=null_context, + patch_fn=nothing, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + package=None, + ) -> None: + super().__init__() + assert callable(callback) or callback is False or callback is None + self.callback: DynamoCallback = callback + self._backend_ctx_ctor = backend_ctx_ctor + self.prior: Union[Unset, DynamoCallback] = unset + self.first_ctx = first_ctx + self.export = export + self._dynamic = dynamic + self.compiler_config = compiler_config + self.cleanup_fns: list[Callable[[], Any]] = [] + self.enter_exit_hooks = [] + self._package = package + patch_fn() + + # Save the backends so that we can reset them during torch._dynamo.reset + backend = innermost_fn(callback) + cached_backends.setdefault(id(backend), backend) + + if dynamic is not None: + self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) + + if on_enter is not nothing: + # this case is not common + def call_on_enter(): + on_enter() + return nothing + + self.enter_exit_hooks.append(call_on_enter) + + if backend_ctx_ctor is not contextlib.nullcontext: + # this case is not common + def call_backend_ctx(): + ctx = backend_ctx_ctor() + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_backend_ctx) + + def __enter__(self): + if config.raise_on_ctx_manager_usage: + raise RuntimeError( + "torch._dynamo.optimize(...) is used with a context manager. " + "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html " + "to use torch._dynamo.optimize(...) as an annotation/decorator. 
" + ) + self.prior = set_eval_frame(None) + self.cleanup_fns = [enter() for enter in self.enter_exit_hooks] + self.prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( + _is_skip_guard_eval_unsafe_stance() + ) + _maybe_set_eval_frame(_callback_from_stance(self.callback)) + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.prior is not unset + set_eval_frame(None) + set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe) + for cleanup in self.cleanup_fns: + cleanup() + self.cleanup_fns.clear() + _maybe_set_eval_frame(_callback_from_stance(self.prior)) + self.prior = unset + + def __call__(self, fn): + # public api for compiler config/options + def get_compiler_config(): + return self.compiler_config + + fn = innermost_fn(fn) + + # add context containing GraphModule to any GraphModule forward functions + if isinstance(fn, GraphModule): + # add context containing GraphModule to any GraphModule forward functions + code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = ( + weakref.ref(fn) + ) + + # Optimize the forward method of torch.nn.Module object + if isinstance(fn, torch.nn.Module): + mod = fn + new_mod = OptimizedModule(mod, self) + # Save the function pointer to find the original callable while nesting + # of decorators. + new_mod._torchdynamo_orig_callable = mod.forward + + # when compiling torch.nn.Module, + # provide public api OptimizedModule.get_compiler_config() + assert not hasattr(new_mod, "get_compiler_config") + new_mod.get_compiler_config = get_compiler_config + + return new_mod + + if inspect.isclass(fn): + # User has wrapped the class with compile/disable decorator. Apply + # disable to init/call method. + cls_obj = fn + cls_obj.__call__ = self(cls_obj.__call__) + if issubclass(cls_obj, torch.nn.Module): + # NN module variable tracker directly inlines the _call_impl. + cls_obj._call_impl = self(cls_obj._call_impl) + return cls_obj + + assert callable(fn), ( + f"A callable function is expected, but {type(fn)} is provided." + ) + + try: + filename = inspect.getsourcefile(fn) + except TypeError: + filename = None + if config.wrap_top_frame or ( + (filename is None or trace_rules.check(fn)) + and ( + getattr(fn, "__name__", "") + not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"] + ) + and filename not in DONT_WRAP_FILES + ): + # call to a builtin without a frame for us to capture + fn = external_utils.wrap_inline(fn) + + def do_nothing(*arg, **kwargs): + pass + + if hasattr(self, "callback"): + callback = self.callback + else: + callback = do_nothing + + is_jit_tracing = torch._C._is_tracing + is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing + + @functools.wraps(fn) + def compile_wrapper(*args, **kwargs): + prior = set_eval_frame(None) + try: + if is_fx_tracing(): + if config.error_on_nested_fx_trace: + raise RuntimeError( + "Detected that you are using FX to symbolically trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + + if is_jit_tracing(): + raise RuntimeError( + "Detected that you are using FX to torch.jit.trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + + cleanups = [enter() for enter in self.enter_exit_hooks] + prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( + _is_skip_guard_eval_unsafe_stance() + ) + + # Ensure that if an assertion occurs after graph pushes + # something onto the DynamicLayerStack then we pop it off (the + # constructed graph code isn't guarded with try/finally). 
+ # + # This used to be a context but putting a `with` here is a noticeable + # perf regression (#126293) + saved_dynamic_layer_stack_depth = ( + torch._C._functorch.get_dynamic_layer_stack_depth() + ) + _maybe_set_eval_frame(_callback_from_stance(callback)) + + try: + return fn(*args, **kwargs) + except Unsupported as e: + if config.verbose: + raise + # strip internal tracebacks from causes + cur_exn: BaseException = e + while cur_exn.__cause__ is not None: + cur_exn.__cause__.with_traceback(None) + cur_exn = cur_exn.__cause__ + raise e.with_traceback(None) from e.__cause__ # User compiler error + except ShortenTraceback as e: + # Failures in the backend likely don't have useful + # data in the TorchDynamo frames, so we strip them out. + raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 + finally: + # Restore the dynamic layer stack depth if necessary. + set_eval_frame(None) + torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( + saved_dynamic_layer_stack_depth + ) + + set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe) + for cleanup in cleanups: + cleanup() + finally: + _maybe_set_eval_frame(prior) + + # hooks to properly handle inlining + compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined] + + # Save the function pointer to find the original callable while nesting + # of decorators. + compile_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + + # when compiling user function instead of nn.Module + # provide public api _fn.get_compiler_config() + assert not hasattr(compile_wrapper, "get_compiler_config") + compile_wrapper.get_compiler_config = get_compiler_config # type: ignore[attr-defined] + + # If the function is called using torch._dynamo.optimize decorator, we + # should prevent any type of skipping. + if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please wrap the relevant code into a function and optimize the + wrapper function. 
+ + >> class CallableClass: + >> def __init__(self) -> None: + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function and other code, wrap that up in a function + + >> def wrapper_fn(x): + >> y = mod(x) + >> return y.sum() + + and then optimize the wrapper_fn + + >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn) + """ + ) + ) + always_optimize_code_objects[fn.__code__] = True + + return compile_wrapper + + +class OptimizeContext(_TorchDynamoContext): + def __init__( + self, + callback, + backend_ctx_ctor, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, + package=None, + ) -> None: + def on_enter(): + install_generation_tagging_init() + + super().__init__( + callback=callback, + on_enter=on_enter, + backend_ctx_ctor=backend_ctx_ctor, + patch_fn=TorchPatcher.patch, + first_ctx=first_ctx, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + package=package, + ) + + if config.compiled_autograd: + _dynamic = self._dynamic + if _dynamic is None: + _dynamic = not torch._dynamo.config.assume_static_by_default + + def call_compiled_autograd(): + assert rebuild_ctx is not None + compiler_fn = rebuild_ctx() + ctx = torch._dynamo.compiled_autograd._enable( + compiler_fn, dynamic=_dynamic, ignore_active_disable_ctx=False + ) + ctx.__enter__() + return functools.partial(ctx.__exit__, None, None, None) + + self.enter_exit_hooks.append(call_compiled_autograd) + + def __reduce__(self): + return ( + self.__class__, + (self.callback, self._backend_ctx_ctor, self.first_ctx), + { + "export": self.export, + "dynamic": self._dynamic, + "compiler_config": self.compiler_config, + }, + ) + + +class RunOnlyContext(_TorchDynamoContext): + def __init__(self) -> None: + # cudagraph trees relies on generation increment + def on_enter(): + torch._dynamo.mutation_guard.GenerationTracker.generation += 1 + + super().__init__(callback=False, on_enter=on_enter) + + def __reduce__(self): + return (self.__class__, ()) + + +class DisableContext(_TorchDynamoContext): + def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None: + super().__init__(callback=None) + self.msg = msg + self.wrapping = wrapping + + def __call__(self, fn): + # Earlier this code was in the base class _TorchDynamoContext. But we + # moved it here to have better code organization. For disable, we just + # want the callback to be None. We don't have to check trace_rules or + # create any wrapper. + fn = innermost_fn(fn) + + if isinstance(fn, torch.nn.Module): + mod = fn + new_mod = OptimizedModule(mod, self) + new_mod._torchdynamo_orig_callable = mod.forward + return new_mod + + if isinstance(fn, type): + # User has wrapped the class with compile/disable decorator. Apply + # disable to init/call method. + cls_obj = fn + # Disable on init is useful for reconstruction of bytecodes where we + # want to prevent Dynamo from tracing into the init function. Check + # test_reconstruction in test_model_output.py. + cls_obj.__init__ = self(cls_obj.__init__) # type: ignore[misc] + cls_obj.__call__ = self(cls_obj.__call__) + if issubclass(cls_obj, torch.nn.Module): + # NN module variable tracker directly inlines the _call_impl. Disable it. 
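+ # Illustrative usage (hypothetical class name):
+ #   @torch._dynamo.disable
+ #   class MyModule(torch.nn.Module): ...
+ # __init__, __call__ and _call_impl are all wrapped, so Dynamo never
+ # traces into instances of the decorated class.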
+ cls_obj._call_impl = self(cls_obj._call_impl) + return cls_obj + + assert callable(fn), ( + f"A callable function is expected, but {type(fn)} is provided." + ) + + def _fn(*args, **kwargs): + prior = set_eval_frame(None) + try: + _maybe_set_eval_frame(_callback_from_stance(self.callback)) + try: + return fn(*args, **kwargs) + finally: + set_eval_frame(None) + finally: + _maybe_set_eval_frame(prior) + + # Under some circumstances (e.g. precompile) we can end up calling @disable + # decorator in generated bytecode and trigger recompile. This is due to the + # fact that the old callback from torch.compile() is still active and under + # this circumstance we will trigger a failure with set_stance("fail_on_recompile"). + # Therefore we want to skip calling into any frame in this case. + if self.wrapping: + _fn = functools.wraps(fn)(_fn) + + _fn._torchdynamo_disable = True # type: ignore[attr-defined] + _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined] + + # Save the function pointer to find the original callable while nesting + # of decorators. + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + + return _fn + + def __reduce__(self): + return (self.__class__, ()) + + +def _optimize_catch_errors( + compile_fn, + hooks: Hooks, + backend_ctx_ctor=null_context, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx=None, + package=None, +): + return OptimizeContext( + convert_frame.catch_errors_wrapper(compile_fn, hooks), + backend_ctx_ctor=backend_ctx_ctor, + first_ctx=True, + export=export, + dynamic=dynamic, + compiler_config=compiler_config, + rebuild_ctx=rebuild_ctx, + package=package, + ) + + +def get_compiler_fn(compiler_fn): + from .repro.after_dynamo import wrap_backend_debug + + if hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name + elif isinstance(compiler_fn, str): + compiler_str = compiler_fn + else: + compiler_str = None + compiler_fn = lookup_backend(compiler_fn) + return wrap_backend_debug(compiler_fn, compiler_str) + + +class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] + def __call__(self, fn): + assert callable(fn), ( + f"A callable function is expected, but {type(fn)} is provided." + ) + return fn + + +def check_if_dynamo_supported(): + if sys.version_info >= (3, 14): + raise RuntimeError("Python 3.14+ not yet supported for torch.compile") + elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( + 3, + 13, + 3, + ): + raise RuntimeError( + "torch.compile is not supported on Python < 3.13.3 built with GIL disabled. " + "Please use Python 3.13.3+." + ) + + +def is_dynamo_supported(): + try: + check_if_dynamo_supported() + return True + except Exception: + return False + + +def check_if_inductor_supported(): + check_if_dynamo_supported() + + +def is_inductor_supported(): + try: + check_if_inductor_supported() + return True + except Exception: + return False + + +def check_for_incompatible_configs(): + # Some of the configs should be mutually exclusive + assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), ( + "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." + ) + + +def optimize(*args, **kwargs): + def rebuild_ctx(): + ca_kwargs_override = config.compiled_autograd_kwargs_override + if ca_kwargs_override: + # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs + # is more complicated, we will add it in the future when needed. 
+ assert set(ca_kwargs_override.keys()) == {"fullgraph"}, ( + f"Only `fullgraph` kwarg override is supported for now, but got {ca_kwargs_override.keys()}" + ) + kwargs["nopython"] = ca_kwargs_override["fullgraph"] + return optimize(*args, **kwargs) + + return _optimize(rebuild_ctx, *args, **kwargs) + + +def _optimize( + rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + guard_fail_fn=None, + guard_filter_fn=None, + disable=False, + dynamic=None, + package=None, +) -> Union[OptimizeContext, _NullDecorator]: + """ + The main entrypoint of TorchDynamo. Do graph capture and call + backend() to optimize extracted graphs. + + Args: + backend: One of the two things: + - Either, a function/callable taking a torch.fx.GraphModule and + example_inputs and returning a python callable that runs the + graph faster. + One can also provide additional context for the backend, like + torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. + See AOTAutogradMemoryEfficientFusionWithContext for the usage. + - Or, a string backend name in `torch._dynamo.list_backends()` + nopython: If True, graph breaks will be errors and there will + be a single whole-program graph. + disable: If True, turn this decorator into a no-op + dynamic: If True, upfront compile as dynamic a kernel as possible. If False, + disable all dynamic shapes support (always specialize). If None, automatically + detect when sizes vary and generate dynamic kernels upon recompile. + + Example Usage:: + + @torch._dynamo.optimize() + def toy_example(a, b): ... + """ + check_if_dynamo_supported() + check_for_incompatible_configs() + # Note: The hooks object could be global instead of passed around, *however* that would make + # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. + # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same + # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an + # easier to understand UX at the cost of a little more plumbing on our end. + hooks = Hooks( + guard_export_fn=guard_export_fn, + guard_fail_fn=guard_fail_fn, + guard_filter_fn=guard_filter_fn, + ) + torch._C._log_api_usage_once("torch._dynamo.optimize") + if ( + disable + or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1" + or (not justknobs_check("pytorch/compiler:enable_dynamo")) + ): + return _NullDecorator() + + if nopython: + return optimize_assert( + backend, + dynamic=dynamic, + hooks=hooks, + rebuild_ctx=rebuild_ctx, + package=package, + ) + + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + # The backend function is stashed in the callable returned by + # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can + # be used by eval_frame.c to insert a guard on the backend. 
+ return _optimize_catch_errors( + convert_frame.convert_frame(backend, hooks=hooks, package=package), + hooks, + backend_ctx_ctor, + dynamic=dynamic, + compiler_config=( + backend.get_compiler_config() + if hasattr(backend, "get_compiler_config") + else None + ), + rebuild_ctx=rebuild_ctx, + package=package, + ) + + +# TODO(voz): Consider making "explain" output alongside a run / part of a run +@patch("torch._dynamo.symbolic_convert.explain", True) +def explain(f, *extra_args, **extra_kwargs): + def inner(*args, **kwargs): + # TODO(voz): Do we want a decorator for this? + from . import reset # type: ignore[attr-defined] + + reset() + + graphs: list[torch.fx.GraphModule] = [] + break_reasons: list[Any] = [] + op_count: int = 0 + ops_per_graph: list[torch.fx.Node] = [] + out_guards: list[_guards.Guard] = [] + + def dynamo_graph_accumulating_compiler( + gm: torch.fx.GraphModule, example_inputs + ): + from .backends.debugging import _explain_graph_detail + + nonlocal graphs + nonlocal op_count + nonlocal ops_per_graph + nonlocal break_reasons + + gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail( + gm, graphs, op_count, ops_per_graph, break_reasons + ) + + return gm.forward + + def guard_export_print(guards): + nonlocal out_guards + out_guards.extend(guards) + + opt_f = optimize( + dynamo_graph_accumulating_compiler, + nopython=False, + guard_export_fn=guard_export_print, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. + opt_f(*args, **kwargs) + + graph_count = len(graphs) + graph_break_count = graph_count - 1 + compile_time = compile_times(repr="str") + + # TODO(voz): Do we want a decorator for this? + reset() + from .backends.debugging import ExplainOutput + + return ExplainOutput( + graphs, + graph_count, + graph_break_count, + break_reasons, + op_count, + ops_per_graph, + out_guards, + compile_time, + ) + + if extra_args or extra_kwargs: + warnings.warn( + "explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " + "If you don't migrate, we may break your explain call in the future if your user defined kwargs " + "conflict with future kwargs added to explain(f).", + FutureWarning, + stacklevel=2, + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +class FlattenInputOutputSignature(torch.fx.Transformer): + def __init__( + self, + m: torch.fx.GraphModule, + flat_args: tuple[Any], + matched_input_elements_positions: list[int], + flat_results: list[Any], + matched_output_elements_positions: list[int], + example_fake_inputs: list[torch.Tensor], + flat_args_dynamic_dims: list[set[int]], + fake_mode: Optional[fake_tensor.FakeTensorMode] = None, + ) -> None: + super().__init__(m) + + assert len(flat_args_dynamic_dims) == len(flat_args) + matched_input_elements_to_fake = { + val: example_fake_inputs[ix] + for ix, val in enumerate(matched_input_elements_positions) + } + + self.new_args = [] + for i in range(0, len(flat_args)): + arg = super().placeholder(f"arg{i}", (), {}) + if i in matched_input_elements_to_fake: + arg.node.meta["val"] = matched_input_elements_to_fake[i] + else: + # Fill node.meta["val"] with faketensor from the input, + # if it's not found in matched_input_elements_positions + if fake_mode is not None and isinstance(flat_args[i], torch.Tensor): + # TODO(zhxchen17) Also preserve all the user constraints here. 
+ arg.node.meta["val"] = fake_mode.from_tensor( + flat_args[i], + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[ + ( + DimDynamic.DYNAMIC + if d in flat_args_dynamic_dims[i] + else DimDynamic.STATIC + ) + for d in range(len(flat_args[i].shape)) + ], + constraint_sizes=[None] * len(flat_args[i].shape), + ), + ) + elif isinstance(flat_args[i], _IntWrapper): + arg.node.meta["val"] = flat_args[i].val + else: + arg.node.meta["val"] = flat_args[i] + + self.new_args.append(arg) + self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions) + self.matched_output_elements_positions = matched_output_elements_positions + self.flat_results = flat_results + + def placeholder(self, target, args, kwargs): + arg = next(self.old_args_gen) + if "val" in self.current_node.meta: + arg.node.meta["val"] = self.current_node.meta["val"] + if "tensor_dict" in self.current_node.meta: + arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"] + if "example_value" in self.current_node.meta: + # NB: intentionally do not use set_example_value + arg.node.meta["example_value"] = self.current_node.meta["example_value"] + if "unbacked_bindings" in self.current_node.meta: + arg.node.meta["unbacked_bindings"] = self.current_node.meta[ + "unbacked_bindings" + ] + return arg + + def output(self, target, args, kwargs): + dynamo_result_flat = args[0] + lookup = [*dynamo_result_flat, *self.new_args] + new_results_flat = [] + for i in range(len(self.flat_results)): + if self.matched_output_elements_positions[i] is not None: + new_results_flat.append( + lookup[self.matched_output_elements_positions[i]] + ) + else: + const_val = self.flat_results[i] + assert isinstance(const_val, tuple(common_constant_types)) + new_results_flat.append(const_val) + return super().output(target, (new_results_flat,), {}) + + def run_node(self, n): + self.current_node = n + result_proxy = super().run_node(n) + if "val" in self.current_node.meta: + result_proxy.node.meta["val"] = self.current_node.meta["val"] + if "example_value" in self.current_node.meta: + # NB: intentionally do not use set_example_value + result_proxy.node.meta["example_value"] = self.current_node.meta[ + "example_value" + ] + if "unbacked_bindings" in self.current_node.meta: + result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[ + "unbacked_bindings" + ] + if self.current_node.op != "output": + result_proxy.node._rename( + getattr(self.current_node, "name", result_proxy.node.name) + ) + return result_proxy + + def transform(self): + result_gm = super().transform() + if "dynamo_flat_name_to_original_fqn" in self.module.meta: # type: ignore[operator] + result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # type: ignore[index] + "dynamo_flat_name_to_original_fqn" # type: ignore[index] + ] + if "dynamo_compile_id" in self.module.meta: # type: ignore[operator] + result_gm.meta["dynamo_compile_id"] = self.module.meta["dynamo_compile_id"] # type: ignore[index] + return result_gm + + +class ExportResult(NamedTuple): + graph_module: torch.fx.GraphModule + guards: _guards.GuardsSet + # NB: Do not add new fields without overriding __iter__; people are + # destructuring so it is BC-breaking + + +# NOTE: this function only supports graphs created by Dynamo's OutputGraph module +def check_signature_rewritable(graph): + input_errors = [] + for node in graph.graph.find_nodes(op="placeholder"): + # set in OutputGraph._call_user_compiler + assert hasattr(node, "_dynamo_source") + assert hasattr(graph, 
"_source_to_user_stacks") + + source = node._dynamo_source + user_stacks = graph._source_to_user_stacks.get(source) + if user_stacks is None: + continue + assert len(user_stacks) > 0 + # In some cases we may not have a useful stack. Look for a + # useful stack + stack = None + for s in user_stacks: + if len(s) == 0: + continue + stack = s + break + if stack is None: + msg = f"{source.name()}, a closed over free variable" + else: + tb = "".join(traceback.format_list(stack)) + extra = "" + if len(user_stacks) > 1: + extra = f"(elided {len(user_stacks) - 1} more accesses)" + msg = f"{source.name()}, accessed at:\n{tb}{extra}" + # TODO: option to print ALL of the stack traces at once + input_errors.append(msg) + + if input_errors: + raise UserError( + UserErrorType.INVALID_INPUT, + "Cannot export model which references tensors that are neither " + "buffers/parameters/constants nor are direct inputs. For each tensor, if you'd " + "like this tensor to be an explicit input, add it as a dummy argument " + "to the top-level model definition you are exporting; if you would " + "like its value to be embedded as an exported constant, wrap its access " + "in a function marked with @assume_constant_result.\n\n" + + "\n\n".join(input_errors), + ) + + +def rewrite_signature( + f_sig, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_output, + dynamo_traced_result, + flat_args_dynamic_dims, +): + orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) + + def check_user_input_output(flat_values, error_type): + supported_types = [ + torch.Tensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + torch._C.ScriptObject, + _IntWrapper, + ] + list(common_constant_types) + + def is_supported_type(val): + return isinstance(val, tuple(supported_types)) + + value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" + # We only check that the outputs are not None. Inputs can be None. + for v in flat_values: + if not is_supported_type(v): + if error_type == UserErrorType.INVALID_INPUT and v is None: + continue + + raise UserError( + error_type, + f"It looks like one of the {value_type}s with type `{type(v)}` " + "is not supported or pytree-flattenable. \n" + f"Exported graphs {value_type}s can only contain the " + f"following supported types: {supported_types}. \n" + "If you are using a custom class object, " + "please register a pytree_flatten/unflatten function " + "using `torch.utils._pytree.register_pytree_node` or " + "`torch.export.register_dataclass`.", + ) + + check_user_input_output(flat_args, UserErrorType.INVALID_INPUT) + flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) + check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) + + def check_optional_input_and_error(f_sig: inspect.Signature): + # Check if function has optional input. 
+ for name, param in f_sig.parameters.items(): + if param.default is not inspect.Parameter.empty: + from torch._dynamo.exc import Unsupported + + log.error( + "Parameter %s is optional with a default value of %s", + name, + param.default, + ) + raise Unsupported( + "Tracing through optional input is not supported yet", + case_name="optional_input", + ) + + def produce_matching(debug_type, sources, candidates): + matched_elements_positions: list[Optional[int]] = [] + dict_of_source_vals = {} + for i, val in enumerate(sources): + dict_of_source_vals[id(val)] = i + + for i, val in enumerate(candidates): + if isinstance(val, tuple(common_constant_types)): + matched_elements_positions.append(None) + elif id(val) not in dict_of_source_vals: + if debug_type == "inputs": + check_optional_input_and_error(f_sig) + raise AssertionError( + f"Unexpectedly found a {type(val)} in the {debug_type}.\n" + 'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"', + ) + else: + matched_elements_positions.append(dict_of_source_vals[id(val)]) + + return matched_elements_positions + + matched_input_elements_positions = produce_matching( + "inputs", flat_args, graph_captured_input + ) + + assert graph_captured_output is not None + matched_output_elements_positions = produce_matching( + "outputs", list(graph_captured_output) + flat_args, flat_results_traced + ) + + new_graph = FlattenInputOutputSignature( + graph, + flat_args, + matched_input_elements_positions, + flat_results_traced, + matched_output_elements_positions, + example_fake_inputs, + flat_args_dynamic_dims, + fake_mode, + ).transform() + + # Make dynamo graph to have same input/output spec as user code + def argument_names(f_sig, args, kwargs) -> list[str]: + def signature_to_fullargspec(sig: inspect.Signature): + # Get a list of Parameter objects from the Signature object + params = list(sig.parameters.values()) + # Separate positional arguments, keyword-only arguments and varargs/varkw + args = [ + p.name + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwonlyargs = [ + p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY + ] + varargs = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), + None, + ) + varkw = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), + None, + ) + # Get default values for positional arguments and keyword-only arguments + defaults = tuple( + p.default + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is not inspect.Parameter.empty + ) + kwonlydefaults = { + p.name: p.default + for p in params + if p.kind == inspect.Parameter.KEYWORD_ONLY + and p.default is not inspect.Parameter.empty + } + # Get annotations for parameters and return value + annotations = {} + if sig.return_annotation: + annotations = {"return": sig.return_annotation} + for parameter in params: + annotations[parameter.name] = parameter.annotation + # Return a FullArgSpec object with the extracted attributes + return inspect.FullArgSpec( + args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations + ) + + fullargspec = signature_to_fullargspec(f_sig) + + # 1. Map `args` 1-to-1 to positional arguments in original signature. + input_strs = fullargspec.args[: len(args)] + + if len(args) > len(fullargspec.args): + # 2. If there are more arguments left in `args`, they map to varargs in original + # signature. Assign names as {varargs}_0, {varargs}_1, ... 
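+ # e.g. (illustrative) `def f(a, *args)` called with three positional
+ # arguments yields input names ["a", "args_0", "args_1"].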
+ assert fullargspec.varargs is not None, "More arguments than expected" + input_strs += [ + f"{fullargspec.varargs}_{i}" + for i in range(0, len(args) - len(input_strs)) + ] + elif len(args) < len(fullargspec.args): + # 3. If there are fewer arguments in `args` than `fullargspec.args`, + # it implies these are arguments either with default values, or provided in + # `kwargs`. The former can be safely ignored. Because Dynamo.export does not + # export them as part of the function signature. The latter will be handled + # in the next step. + for unprovided_arg in fullargspec.args[ + len(args) : -len(fullargspec.defaults or []) + ]: + assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" + + # 4. Keyword arguments provided in `kwargs`. + input_strs += list(kwargs.keys()) + + # 5. Keyword-only arguments with default values if not provided are not exported + # as part of the function signature. + for kwonly_arg in fullargspec.kwonlyargs: + kwonlydefaults = fullargspec.kwonlydefaults or {} + assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, ( + f"Missing keyword only argument {kwonly_arg}" + ) + + return input_strs + + new_graph.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + argument_names(f_sig, orig_args, orig_kwargs), + in_spec, + out_spec_traced, + ) + ) + new_graph.recompile() + return new_graph + + +def export( + f: Callable[..., Any], + *extra_args, + aten_graph: bool = False, + pre_dispatch: bool = False, + decomposition_table: Optional[ + dict[torch._ops.OpOverload, Callable[..., Any]] + ] = None, + tracing_mode: str = "symbolic", + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + specialize_float: bool = True, + assume_static_by_default: bool = False, + same_signature: bool = True, + disable_constraint_solver: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + _log_export_usage: bool = True, + constraints: Optional[list[Constraint]] = None, + **extra_kwargs, +) -> Callable[..., ExportResult]: + """ + Export an input function f to a format that can be executed outside of PyTorch using the FX graph. + + Args: + f (callable): A PyTorch function to be exported. + + aten_graph (bool): If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. Default is False. + + pre_dispatch (bool): If True, exports a graph with ATen operators, + but before any logic in the PyTorch dispatcher has run. + This can be useful if you want to apply further transformations on a graph before running it + through autograd, autocast, or any other functionalities that are integrated into the dispatcher. + This flag is only valid if aten_graph=True is set. + Default is False. + + decomposition_table (dict): A dictionary that maps operators to their decomposition functions. + Required if aten_graph or tracing_mode is specified. Default is None. + + tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic". + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. 
+ + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + same_signature (bool): If True, rewrite the returned graph's signature to be the same as f. + + disable_constraint_solver (bool): Whether the dim constraint solver must be disabled. + + Returns: + A function that given args and kwargs, returns a tuple of (graph, guards) + Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. + Guards: The guards we accumulated during tracing f above + + Raises: + AssertionError: If decomposition_table is specified without setting aten_graph=True, + or if graph breaks during tracing in export. + + AssertionError: If Dynamo input and output is not consistent with traced input/output. + + Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. + """ + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_dynamo"}) + + # Deal with "local variable referenced before assignment" + _f = f + _specialize_float = specialize_float + _assume_static_by_default = assume_static_by_default + _constraints = constraints + + def inner(*args, **kwargs): + if not _constraints: + combined_args = _combine_args(_f, args, kwargs) + constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) + else: + constraints = _constraints + + f = _f + specialize_float = _specialize_float + assume_static_by_default = _assume_static_by_default + check_if_dynamo_supported() + torch._C._log_api_usage_once("torch._dynamo.export") + if decomposition_table is not None: + assert aten_graph, ( + "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True" + ) + if pre_dispatch: + assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" + f = innermost_fn(f) + call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f + original_signature = inspect.signature(call_to_inspect) + graph = None + out_guards = None + graph_captured_input = None + graph_captured_result: Optional[tuple[torch.Tensor, ...]] = None + fake_mode = None + result_traced = None + + def guard_export_print(guards: _guards.GuardsSet): + nonlocal out_guards + assert out_guards is None, ( + "whole graph export entails exactly one guard export" + ) + out_guards = guards + + example_inputs = [] + + def dynamo_normalization_capturing_compiler( + gm: torch.fx.GraphModule, inner_example_inputs + ): + nonlocal graph + assert graph is None, ( + "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." + ) + graph = gm + + nonlocal fake_mode, example_inputs + # NB: do NOT pass inner_example_inputs here, we are detecting the + # Dynamo allocated fake mode, which should be DISTINCT from a + # potential outer ambient fake mode which the user provided. 
+ # example_inputs is always the user specified inputs, so they + # would have the wrong fake mode attached to them + fake_mode = _guards.detect_fake_mode() + example_inputs = inner_example_inputs + + def result_capturing_wrapper(*graph_inputs): + nonlocal graph_captured_result + nonlocal graph_captured_input + + graph_captured_input = graph_inputs + assert graph is not None + + named_parameters = dict(graph.named_parameters(remove_duplicate=False)) + named_buffers = dict(graph.named_buffers(remove_duplicate=False)) + + ambient_fake_mode = ( + _guards.detect_fake_mode(graph_inputs) + if _guards.detect_fake_mode(graph_inputs) is not None + else fake_mode + ) + + # We reran fake tensor propagation, but we didn't do + # anything with the resulting unbacked SymInts. Drop them + # from the pending list. + # NB: this is wrong if graph_captured_result has + # data-dependent output size! + ignore_fresh_unbacked = null_context() + assert ambient_fake_mode is not None + if shape_env := ambient_fake_mode.shape_env: + ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() + + with ( + ambient_fake_mode, + enable_python_dispatcher(), + ignore_fresh_unbacked, + ): + params_and_buffers = { + **named_parameters, + **named_buffers, + } + fake_params_buffers = {} + + for name, value in params_and_buffers.items(): + fake_params_buffers[name] = ambient_fake_mode.from_tensor( + value, static_shapes=True + ) + + def fakify_with_ambient(path, t): + if isinstance(t, torch.Tensor): + return ambient_fake_mode.from_tensor(t, static_shapes=True) + elif isinstance(t, _IntWrapper): + if ( + t.dynamism is not None + and isinstance(t.dynamism, _DimHint) + and t.dynamism.type + in ( + _DimHintType.DYNAMIC, + _DimHintType.AUTO, + ) + ): # type: ignore[union-attr] + from torch._export.non_strict_utils import ( + key_path_to_source, + ) + + source = key_path_to_source(path) + symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] + t.val, source, DimDynamic.DYNAMIC + ) + return symint + else: + return t.val + else: + return t + + fake_graph_inputs = pytree.tree_map_with_path( + fakify_with_ambient, graph_inputs + ) + graph_captured_result = torch.func.functional_call( + graph, fake_params_buffers, fake_graph_inputs + ) + + return graph_captured_result + + return result_capturing_wrapper + + # Note: This is needed by rewrite_signature. We need to put it before + # optimize_assert since user program may mutate the inputs. + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + remove_from_cache(f) + constraint_violation_error = None + if tracing_mode != "symbolic": + assume_static_by_default = True + with ( + config.patch( + specialize_int=True, + specialize_float=specialize_float, + assume_static_by_default=assume_static_by_default, + automatic_dynamic_shapes=False, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ), + _compiling_state_context(), + ): + opt_f = optimize_assert( + dynamo_normalization_capturing_compiler, + hooks=Hooks( + guard_export_fn=guard_export_print, + guard_fail_fn=None, + ), + export=True, + export_constraints=constraints, + )(f) + # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject. 
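+ # Calling opt_f below drives the actual Dynamo trace; the compiler and
+ # guard-export callbacks installed above populate `graph` and
+ # `out_guards` as side effects.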
+ try: + result_traced = opt_f(*args, **kwargs) + except ConstraintViolationError as e: + constraint_violation_error = e + remove_from_cache(f) + + if ( + not disable_constraint_solver + and (shape_env := getattr(fake_mode, "shape_env", None)) is not None + and (dim_constraints := shape_env.dim_constraints) is not None + and not isinstance( + call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) + and not trace_rules.check(call_to_inspect) + ): + dim_constraints.solve() + forced_specializations = dim_constraints.forced_specializations() + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, + constraint_violation_error, + forced_specializations, + ) + if constraint_violation_error: + constraint_violation_error.args = ( + constraint_violation_error.args[0] + msg, + ) + else: + if forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + else: + log.info( + "Summary of dimension constraints:%s", + msg, + ) + + # Error if we have any constraints on static values + for k in shape_env.var_to_range.keys(): + if isinstance(k, sympy.Integer): + constraint_violation_error = ConstraintViolationError( + f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" + "It appears that you're trying to set a constraint on a " + f"value which we evaluated to have a static value of {k}. " + 'Set TORCH_LOGS="+export" for more information.' + ) + if constraint_violation_error: + raise constraint_violation_error + + if graph is None: + assert same_signature, ( + "Failed to produce a graph during tracing as no tensor operations were found and same_signature is False." + ) + # If the module does not contain any tensor computation, we would create a graph with inputs and outputs. + # To be consistent with the graph traced by dynano, `graph` will have only tensor inputs as placeholders + # and tensor outputs as output nodes. non-tensor inputs and outputs will be added when rewriting signature. + # We will also construct the `example_inputs`, `graph_captured_input`, and `graph_captured_result` corresponding + # to `graph`. 
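+ # The placeholders/output built below mirror only the tensor inputs;
+ # non-tensor inputs and outputs are reattached later by rewrite_signature.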
+ example_inputs = [] + graph_captured_input = () + graph_captured_result = () + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=ShapeEnv(), export=True + ) + if out_guards is None: + out_guards = _guards.GuardsSet() + assert out_guards is not None # suppress mypy error + parameter_names = list(original_signature.parameters.keys()) + fx_graph = torch.fx.Graph() + for i, name in enumerate(parameter_names): + if torch.is_tensor(flat_args[i]): + node = fx_graph.placeholder(name) + node.meta["val"] = fake_mode.from_tensor( + flat_args[i], static_shapes=True + ) + graph_captured_input = graph_captured_input + (flat_args[i],) + example_inputs.append(flat_args[i]) + fx_graph.output(graph_captured_result) + module = torch.nn.Module() + graph = torch.fx.GraphModule(module, fx_graph) + log.info( + "Failed to capture a graph during tracing as no tensor operations were found.:\n\n%s", + graph.print_readable(print_output=False, colored=True), + ) + else: + assert out_guards is not None, "Failed to produce guards during tracing" + assert fake_mode is not None + + log.info( + "Dynamo captured graph:\n\n%s", + graph.print_readable(print_output=False, colored=True), + ) + + # This check need to happened before aten_graph + # because placeholder's _source_node attribute is not preserved by make_fx + if same_signature: + check_signature_rewritable(graph) + + # NB: This is mostly hitting the cache; Dynamo already converted these + example_fake_inputs = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + if aten_graph: + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type] + + with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode: + try: + graph = make_fx( + graph_with_interpreter, + decomposition_table=decomposition_table, + tracing_mode="real", + _allow_non_fake_inputs=True, + pre_dispatch=pre_dispatch, + _allow_fake_constant=False, + )(*example_fake_inputs) + except CondOpArgsMismatchError as e: + # Wrap the internal error to the user-facing error + raise UserError( # noqa: B904 + UserErrorType.DYNAMIC_CONTROL_FLOW, + str(e), + case_name="cond_operands", + ) + + assert graph is not None + for node in graph.graph.find_nodes(op="get_attr"): + if isinstance(getattr(graph, node.target), torch.Tensor): # type: ignore[arg-type] + node.meta["val"] = fake_mode.from_tensor( + getattr(graph, node.target), # type: ignore[arg-type] + static_shapes=True, + ) + + if same_signature: + flat_args_dynamic_dims = [ + { + c.dim + for c in (constraints or ()) + if ( + c.t_id == id(x) + and not isinstance(c, _RelaxedConstraint) + and c.constraint_range.vr.lower != c.constraint_range.vr.upper + ) + } + for x in flat_args + ] + graph = rewrite_signature( + original_signature, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_result, + result_traced, # type: ignore[possibly-undefined] + flat_args_dynamic_dims, + ) + return ExportResult(graph, out_guards) # type: ignore[arg-type] + + if extra_args or extra_kwargs: + warnings.warn( + "export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. 
" + "If you don't migrate, we may break your export call in the future if your user defined kwargs " + "conflict with future kwargs added to export(f).", + FutureWarning, + stacklevel=2, + ) + return inner(*extra_args, **extra_kwargs) + else: + return inner + + +def optimize_assert(*args, **kwargs): + if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None: + # called from optimize + rebuild_ctx = kwargs["rebuild_ctx"] + del kwargs["rebuild_ctx"] + else: + + def rebuild_ctx(): + return optimize_assert(*args, **kwargs) + + return _optimize_assert(rebuild_ctx, *args, **kwargs) + + +def _optimize_assert( + rebuild_ctx: Callable[[], OptimizeContext], + backend, + *, + hooks=Hooks(None, None, None), + export=False, + export_constraints=None, + dynamic=None, + package=None, +): + """ + The same as `torch._dynamo.optimize(backend, nopython=True)` + """ + backend = get_compiler_fn(backend) + + # Find if backend has any extra context manager + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) + + return _optimize_catch_errors( + convert_frame.convert_frame_assert( + backend, + export=export, + export_constraints=export_constraints, + package=package, + ), + hooks, + backend_ctx_ctor, + export=export, + dynamic=dynamic, + rebuild_ctx=rebuild_ctx, + package=package, + ) + + +class TorchPatcher: + @staticmethod + @functools.cache + def patch(): + # A better way to disable the following would be decorate the source + # functions with @torch._disable_dynamo. However, this causes issues + # with torch.deploy internally. + from .decorators import disable + + torch.jit.trace = disable( + torch.jit.trace, reason="tracing into TorchScript not fully supported" + ) + torch.jit.trace_module = disable( + torch.jit.trace_module, + reason="tracing into TorchScript not fully supported", + ) + torch.jit._get_trace_graph = disable( + torch.jit._get_trace_graph, + reason="tracing into TorchScript not fully supported", + ) + torch.fx._symbolic_trace.Tracer.trace = disable( + torch.fx._symbolic_trace.Tracer.trace, + reason="tracing into FX not fully supported", + ) + torch.distributions.Distribution.set_default_validate_args(False) + + from torch.optim import ( + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + ) + + optimizer_modules = { + adadelta, + adagrad, + adam, + adamax, + adamw, + asgd, + lbfgs, + nadam, + radam, + rmsprop, + rprop, + sgd, + sparse_adam, + } + + for opt_mod in optimizer_modules: + opt_name = opt_mod.__name__.split(".")[-1] + fused_fn_name = f"_fused_{opt_name}" + + if hasattr(opt_mod, fused_fn_name): + setattr( + opt_mod, + fused_fn_name, + disable( + getattr(opt_mod, fused_fn_name), + reason="don't trace into fused optimizer", + ), + ) + + optimizer_classes = [ + opt + for opt in torch.optim.__dict__.values() + if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer) + ] + + # Note: we don't support sparsity or tracing through backwards + excluded_optimizer_classes = { + torch.optim.SparseAdam, + torch.optim.LBFGS, + } + + for opt in optimizer_classes: + if opt in excluded_optimizer_classes: + opt.step = disable( + opt.step, reason=f"optimizer {opt} step not supported" + ) + + if hasattr(opt, "_init_group"): + opt._init_group = disable( + opt._init_group, reason=f"optimizer {opt} _init_group not supported" + ) + + @staticmethod + def suppress_torch_distributed_warnings(fn): + def inner_fn(*args, **kwargs): + warnings.filterwarnings( + "ignore", category=UserWarning, 
module="torch.distributed" + ) + return fn(*args, **kwargs) + + return inner_fn + + +def skip_code(code: types.CodeType): + set_code_exec_strategy( + code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/exc.py b/phivenv/Lib/site-packages/torch/_dynamo/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ff3c9f1d3d0af3cb49bc1d96355a3bd87f3797 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/exc.py @@ -0,0 +1,740 @@ +from __future__ import annotations + + +"""Exception handling and error reporting for TorchDynamo. + +This module provides a comprehensive set of exception classes and utilities for error +handling in TorchDynamo. It includes: + +Base Exceptions: + - TorchDynamoException: Base class for all TorchDynamo-specific exceptions + - Various specialized subclasses for different error scenarios + +User Error Handling: + - UserError: Exceptions for user-facing errors in TorchDynamo usage + - UserErrorType: Enumeration of different categories of user errors + - Formatted error messages with debugging information + +Observed Exceptions: + - Classes for handling exceptions observed during tracing + - Special handling for StopIteration, LookupError, etc. + - Exception state management during compilation + +Error Formatting: + - Stack trace filtering and formatting + - Error message augmentation + - Debugging utilities for error reporting +""" + +import logging +import os +import re +import textwrap +import typing +from enum import auto, Enum +from traceback import extract_stack, format_exc, format_list, StackSummary +from typing import Any, NoReturn, Optional, TYPE_CHECKING + +import torch._guards + +from . import config +from .utils import counters + + +if TYPE_CHECKING: + import types + + from torch._guards import CompileId + + from .symbolic_convert import InstructionTranslatorBase + from .types import DynamoFrameType + + +def exportdb_error_message(case_name: str) -> str: + return ( + "For more information about this error, see: " + + "https://pytorch.org/docs/main/generated/exportdb/index.html#" + + case_name.replace("_", "-") + ) + + +log = logging.getLogger(__name__) +graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") + + +class TorchDynamoException(RuntimeError): + pass + + +class InternalTorchDynamoError(TorchDynamoException): + pass + + +class RestartAnalysis(TorchDynamoException): + restart_reason: Optional[str] + + def __init__(self, *args: Any, restart_reason: Optional[str] = None) -> None: + self.restart_reason = restart_reason + super().__init__(*args) + + +class SpeculationRestartAnalysis(RestartAnalysis): + pass + + +class UnspecializeRestartAnalysis(RestartAnalysis): + pass + + +class CompileCollectiveRestartAnalysis(RestartAnalysis): + pass + + +class TensorifyScalarRestartAnalysis(RestartAnalysis): + pass + + +class SkipFrame(TorchDynamoException): + pass + + +class TorchRuntimeError(TorchDynamoException): + pass + + +class InvalidBackend(TorchDynamoException): + def __init__(self, name: str) -> None: + super().__init__( + f"Invalid backend: {name!r}, see `torch._dynamo.list_backends()` for available backends." + ) + + +class ResetRequired(TorchDynamoException): + def __init__(self) -> None: + super().__init__( + textwrap.dedent( + """ + Must call `torch._dynamo.reset()` before changing backends. Detected two calls to + `torch.compile()` with a different backend compiler arguments. 
+ """ + ) + ) + + +class ShortenTraceback(TorchDynamoException): + def __init__( + self, *args: Any, first_useful_frame: Optional[types.FrameType], **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + self.first_useful_frame = first_useful_frame + + def remove_dynamo_frames(self) -> typing.Self: + tb = self.__traceback__ + if self.first_useful_frame is None or tb is None or config.verbose: + return self + while tb.tb_frame is not self.first_useful_frame: + tb = tb.tb_next + assert tb is not None, "internal error, please report a bug" + return self.with_traceback(tb) + + +class BackendCompilerFailed(ShortenTraceback): + def __init__( + self, + backend_fn: Any, + inner_exception: Exception, + first_useful_frame: Optional[types.FrameType], + ) -> None: + self.backend_name = getattr(backend_fn, "__name__", "?") + self.inner_exception = inner_exception + msg = f"backend={self.backend_name!r} raised:\n{type(inner_exception).__name__}: {inner_exception}" + super().__init__(msg, first_useful_frame=first_useful_frame) + + +class Unsupported(TorchDynamoException): + def __init__(self, msg: str, *, case_name: Optional[str] = None) -> None: + super().__init__(msg) + self.real_stack = torch._guards.TracingContext.extract_stack() + self.msg = msg + self.category: Optional[str] = None + self.add_to_stats() + self.case_name: Optional[str] = case_name + + def remove_from_stats(self) -> None: + assert self.category is not None + counters[self.category][self.msg] -= 1 + if counters[self.category][self.msg] <= 0: + del counters[self.category][self.msg] + + def add_to_stats(self, category: str = "unimplemented") -> None: + self.category = category + counters[category][self.msg] += 1 + + +class UnknownPropertiesDuringBackwardTrace(Unsupported): + pass + + +class RecompileError(TorchDynamoException): + pass + + +class ArgsMismatchError(Unsupported): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class AttributeMutationError(Unsupported): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class InfiniteGeneratorError(Unsupported): + # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class SideEffectsError(Unsupported): + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class CondOpArgsMismatchError(ArgsMismatchError): + """ + Internal error from cond() due to arguments mismatch. + """ + + def __init__(self, msg: str) -> None: + super().__init__(msg) + + +class UserErrorType(Enum): + DYNAMIC_CONTROL_FLOW = auto() + ANTI_PATTERN = auto() + STANDARD_LIBRARY = auto() + CONSTRAINT_VIOLATION = auto() + DYNAMIC_DIM = auto() + INVALID_INPUT = auto() + INVALID_OUTPUT = auto() + UNSUPPORTED_ALIASED_MUTATED_DYNAMIC_INPUTS = auto() + + +class UserError(Unsupported): + def __init__( + self, error_type: UserErrorType, msg: str, case_name: Optional[str] = None + ) -> None: + """ + Type of errors that would be valid in Eager, but not supported in TorchDynamo. + The error message should tell user about next actions. + + error_type: Type of user error + msg: Actionable error message + case_name: (Optional) Unique name (snake case) for the usage example in exportdb. 
+ """ + if case_name is not None: + assert isinstance(case_name, str) + if msg.endswith("."): + msg += " " + else: + msg += "\n" + msg += exportdb_error_message(case_name) + super().__init__(msg) + self.error_type = error_type + self.message = msg + + +class SkipCodeRecursiveException(TorchDynamoException): + pass + + +class RecompileLimitExceeded(Unsupported): + pass + + +class UnsafeScriptObjectError(TorchDynamoException): + pass + + +class UncapturedHigherOrderOpError(TorchDynamoException): + pass + + +class IncorrectUsage(Exception): + pass + + +# TODO: I'm a little uncertain about what error classification we should have +# for this. This is potentially a user error, but regressions in +# specialization in PyTorch proper could also trigger this problem +class FailOnRecompileLimitHit(Exception): + pass + + +class PackageError(TorchDynamoException): + pass + + +class ObservedException(TorchDynamoException): + # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions. + pass + + +class ObservedUserStopIteration(ObservedException): + # An UserStopIteration exception observed during the Dynamo tracing (e.g Dynamo tracing __next__) + value: Optional[Any] + + # Reference `StopIteration_init` in CPython + # https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L568-L584 + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__("unhandled `raise StopIteration`") + if len(args) > 0: + self.value = args[0] + else: + self.value = None + + +class ObservedLookupError(ObservedException): + # A LookupError exception to be raised from inside Dynamo tracing. This can happen on __getitem__ + pass + + +class ObservedIndexError(ObservedLookupError): + # An IndexError exception to be raised from inside Dynamo tracing. This can happen on list __getitem__ + pass + + +class ObservedKeyError(ObservedLookupError): + # A KeyError exception to be raised from inside Dynamo tracing. This can happen on dict __getitem__ + pass + + +class ObservedGeneratorExit(ObservedException): + pass + + +class ObservedAttributeError(ObservedException): + # An AttributeError exception to be raised from inside Dynamo tracing. This can happen on user defined object __getattr__ + pass + + +class ObservedRuntimeError(ObservedException): + # A RuntimeError exception to be raised from inside Dynamo tracing. This can happen on generator.throw(..) method + pass + + +class ObservedNotImplementedError(ObservedException): + pass + + +class ObservedTypeError(ObservedException): + # A TypeError exception to be raised from inside Dynamo tracing. This can happen on generator.send(..) 
method + pass + + +observed_exception_map = { + StopIteration: ObservedUserStopIteration, + LookupError: ObservedLookupError, + IndexError: ObservedIndexError, + GeneratorExit: ObservedGeneratorExit, + KeyError: ObservedKeyError, + AttributeError: ObservedAttributeError, + RuntimeError: ObservedRuntimeError, + NotImplementedError: ObservedNotImplementedError, + TypeError: ObservedTypeError, +} + + +def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]: + if exc_type not in observed_exception_map: + name = getattr(exc_type, "__name__", str(exc_type)) + observed_exception_map[exc_type] = type( + f"Observed{name}Error", (ObservedException,), {} + ) + return observed_exception_map[exc_type] + + +def raise_observed_exception( + exc_type: type[Exception], + tx: InstructionTranslatorBase, + *, + args: Optional[list[Any]] = None, + kwargs: Optional[dict[str, Any]] = None, +) -> NoReturn: + from .variables import BuiltinVariable + + # CPython here raises an exception. Since there is no python code, we have to manually setup the exception + # stack and raise the exception. + exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] + tx.exn_vt_stack.set_current_exception(exception_vt) + raise observed_exception_map[exc_type] + + +def handle_observed_exception(tx: Any) -> None: + # This is essentially exception handling code, equivalent of this pseudo code + # + # try: + # ... somebody raising StopIteration + # except StopIteration + # pass + # + # If this was going through the python code, we would have called exception_handler method, but FOR_ITER + # handles the exception completely in CPython. For example for 3.11, the resulting bytecode is + # + # + # 6 46 LOAD_GLOBAL 2 (StopIteration) + # 58 RAISE_VARARGS 1 + # >> 60 PUSH_EXC_INFO + + # 7 62 LOAD_GLOBAL 2 (StopIteration) + # 74 CHECK_EXC_MATCH + # 76 POP_JUMP_FORWARD_IF_FALSE 3 (to 84) + # 78 POP_TOP + + # 8 80 POP_EXCEPT + # + + # Fortunately this translates to a simple pop from the exn_vt_stack + tx.exn_vt_stack.clear_current_exception() + + +# These exceptions are ok to fallback to eager/graph_break. +exceptions_allowed_to_be_fallback = ( + torch._subclasses.fake_tensor.DataDependentOutputException, + torch._subclasses.fake_tensor.DynamicOutputShapeException, + torch._subclasses.fake_tensor.UnsupportedOperatorException, + torch._subclasses.fake_tensor.UnsupportedFakeTensorException, + torch._subclasses.fake_tensor.UnsupportedMutationAliasingException, +) + + +def unimplemented_with_warning( + e: Exception, code: types.CodeType, msg: str +) -> NoReturn: + # This function calls unimplemented internally and eventually graph breaks + # or falls to eager. unimplemented itself does not print any user warnings, + # i.e., its very silent. This helper function is intended when an error is + # encountered in the torch.compile stack which is worth showing as warning + # to the user. For example, if AOT Autograd backend fails with a fake tensor + # exception, its ok to fallback to eager but not silently. Here, we can use + # this function to log the message and the stack trace. 
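+    #
+    # Illustrative (hypothetical) call-site sketch; `backend`, `gm`,
+    # `example_inputs` and `frame` are placeholder names, not code from this file:
+    #
+    #     try:
+    #         compiled_fn = backend(gm, example_inputs)
+    #     except exceptions_allowed_to_be_fallback as e:
+    #         unimplemented_with_warning(
+    #             e,
+    #             frame.f_code,
+    #             "Backend compiler raised a fake tensor exception; falling back to eager.",
+    #         )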
+ graph_break_msg = format_error_msg_verbose(e, code) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: graph_break_msg, + ) + graph_breaks_log.debug("%s", graph_break_msg) + log.warning(msg) + unimplemented(msg, from_exc=e) + + +_NOTHING = object() + + +def unimplemented( + msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None +) -> NoReturn: + assert msg != os.environ.get("BREAK", False) + if from_exc is not _NOTHING: + raise Unsupported(msg, case_name=case_name) from from_exc + raise Unsupported(msg, case_name=case_name) + + +def unimplemented_v2_with_warning( + e: Exception, + code: types.CodeType, + gb_type: str, + context: str, + explanation: str, + hints: list[str], +) -> NoReturn: + # This function calls unimplemented internally and eventually graph breaks + # or falls to eager. unimplemented itself does not print any user warnings, + # i.e., its very silent. This helper function is intended when an error is + # encountered in the torch.compile stack which is worth showing as warning + # to the user. For example, if AOT Autograd backend fails with a fake tensor + # exception, its ok to fallback to eager but not silently. Here, we can use + # this function to log the message and the stack trace. + graph_break_msg = format_error_msg_verbose(e, code) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: graph_break_msg, + ) + graph_breaks_log.debug("%s", graph_break_msg) + unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True) + + +def format_graph_break_message( + gb_type: str, + context: str, + explanation: str, + hints: list[str], +) -> str: + explanation = textwrap.indent(explanation, " ").lstrip() + hints_str = "\n".join( + " Hint: " + textwrap.indent(hint, " ").lstrip() for hint in hints + ) + context = textwrap.indent(context, " ").lstrip() + + msg = f"""\ +{gb_type} + Explanation: {explanation} +{hints_str} + + Developer debug context: {context} +""" + return msg + + +# TODO replace old unimplemented later +def unimplemented_v2( + gb_type: str, + context: str, + explanation: str, + hints: list[str], + *, + from_exc: Any = _NOTHING, + log_warning: bool = False, +) -> NoReturn: + """ + Called within dynamo to cause a graph break. + Args: + gb_type: Context-free graph break type. It should be a short string without any + information specific to the tracing context (i.e. no dynamically-generated strings) + context: Developer context for the graph break. It can contain tracing context/dynamic strings. + explanation: User-facing context-dependent explanation for the graph break. Can be dynamic. + hints: List of user-facing hints for the graph break. 
+ """ + + msg = format_graph_break_message(gb_type, context, explanation, hints) + + # Temporarily disabling the generation of the weblinks in error message + + # documentation_link = get_gbid_documentation_link(gb_type) + # msg += f"\n For more details about this graph break, please visit: {documentation_link}" + + if log_warning: + log.warning(msg) + if from_exc is not _NOTHING: + raise Unsupported(msg) from from_exc + raise Unsupported(msg) + + +def warning(msg: str) -> None: + counters["warnings"][msg] += 1 + assert msg != os.environ.get("BREAK", False) + + +# KeyError has special handling for its args +# see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details +class KeyErrorMsg: + def __init__(self, value: Any) -> None: + self.value = value + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return self.__str__() + + +def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None: + import traceback + + exc.innermost_user_frame_summary = None # type: ignore[attr-defined] + + real_stack = get_real_stack(exc) + if real_stack is not None and len(real_stack) > 0: + exc.innermost_user_frame_summary = real_stack[-1] # type: ignore[attr-defined] + msg += f"\nfrom user code:\n {''.join(traceback.format_list(real_stack))}" + + if config.replay_record_enabled and hasattr(exc, "record_filename"): + msg += ( + f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\ + torch._dynamo.replay('{exc.record_filename}').\n" + ) + + if not config.verbose and hasattr(exc, "real_stack"): + msg += ( + "\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace " + "(please do this especially if you're reporting a bug to PyTorch). " + 'For even more developer context, set TORCH_LOGS="+dynamo"\n' + ) + + if hasattr(exc, "inner_exception") and hasattr( + exc.inner_exception, "minifier_path" + ): + if hasattr(exc.inner_exception, "buck_command"): + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + f"this buck command to find the smallest traced graph " + f"which reproduces this error: {exc.inner_exception.buck_command}\n" + ) + else: + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. 
Run " + "this script to find the smallest traced graph which reproduces this error.\n" + ) + + old_msg = "" if len(exc.args) == 0 else str(exc.args[0]) + + if isinstance(exc, KeyError): + exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:] + else: + new_msg = old_msg + msg + exc.args = (new_msg,) + exc.args[1:] + + +def get_exc_message( + e: Exception, compile_id: CompileId +) -> tuple[Optional[str], Optional[int]]: + filename = None + lineno = None + if e.innermost_user_frame_summary is not None: # type: ignore[attr-defined] + filename = e.innermost_user_frame_summary.filename # type: ignore[attr-defined] + lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined] + e.compile_id = compile_id # type: ignore[attr-defined] + return filename, lineno + + +def get_stack_above_dynamo() -> StackSummary: + return filter_stack(extract_stack()) + + +def get_real_stack( + exc: Exception, frame: Optional[DynamoFrameType] = None +) -> Optional[StackSummary]: + real_stack = getattr(exc, "real_stack", None) + if real_stack is None: + return None + + # NB: it's possible for real_stack to be []; we still attempt to + # report a stack anyway because the stack_above_dynamo may still + # be useful for debugging + + if frame is not None: + # NB: frame is PyInterpreterFrame on Python 3.11 and later, + # not a TRUE frame object. You can't actually feed it + # to traceback because it doesn't have enough information. + # To solve this problem, we technically should just materialize + # the frame, the same way _PyFrame_GetFrameObject would do + # (but we cannot actually do this, because this populates + # frame_obj field, which default eval frame doesn't like). + # + # Fortunately, in this case, we can hack it: there's no need + # to actually use the truly top frame, we can just extract + # from where we are right now and rely on filter_stack to + # get rid of all the dynamo frames. For ease of testing + # we apply this behavior to ALL Python versions + stack_above_dynamo = get_stack_above_dynamo() + else: + stack_above_dynamo = StackSummary() + + return StackSummary.from_list(stack_above_dynamo + real_stack) + + +# filter out all frames after entering dynamo +def filter_stack(stack: StackSummary) -> StackSummary: + user_stack = StackSummary() + for frame in stack: + if frame.filename is None: + continue + if "convert_frame" in frame.filename: + break + if "eval_frame" in frame.filename or ( + frame.line and "torch._dynamo.optimize(" in frame.line + ): + continue + user_stack.append(frame) + + return user_stack + + +def remove_resume_prefix(name: str) -> Optional[str]: + from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX + + match = re.match(f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_(\\w+)_at_\\d+", name) + if match: + return match.group(1) + return None + + +def collapse_resume_frames(stack: StackSummary) -> StackSummary: + """ + When we graph break, we create a resume function and make a regular Python call + to it, which gets intercepted by Dynamo. This behavior is normally shown in the + traceback, which can be confusing to a user. So we can filter out resume frames + for better traceback clarity. + + Example: + File "..." line 3, in f + + File "..." line 5, in torch_dynamo_resume_in_f_at_80 + + File "..." line 10, in torch_dynamo_resume_in_f_at_120 + + + becomes + File "..." 
line 10, in f + + """ + + new_stack = StackSummary() + for frame in stack: + if frame.filename is None: + continue + name = remove_resume_prefix(frame.name) + if new_stack and name and new_stack[-1].name == name: + new_stack[-1] = frame + frame.name = name + else: + new_stack.append(frame) + + return new_stack + + +def format_error_msg_verbose( + exc: Exception, + code: types.CodeType, + record_filename: Optional[str] = None, + frame: Optional[DynamoFrameType] = None, +) -> str: + msg = ( + f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n" + ) + msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" + msg += format_exc() + real_stack = get_real_stack(exc, frame) + if real_stack is not None: + msg += ( + "\n" + + "=" * 10 + + " The above exception occurred while processing the following code " + + "=" * 10 + + "\n\n" + ) + msg += "".join(format_list(real_stack)) + msg += "\n" + msg += "=" * 10 + + return msg + + +def format_error_msg( + exc: Exception, + code: types.CodeType, + record_filename: Optional[str] = None, + frame: Optional[DynamoFrameType] = None, +) -> str: + if config.verbose: + return format_error_msg_verbose(exc, code, record_filename, frame) + return f"WON'T CONVERT {code.co_name} {code.co_filename}\ + line {code.co_firstlineno} \ndue to: \n{format_exc()}" diff --git a/phivenv/Lib/site-packages/torch/_dynamo/external_utils.py b/phivenv/Lib/site-packages/torch/_dynamo/external_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1dac6182398ff6851cf615b01a9af82702107016 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/external_utils.py @@ -0,0 +1,229 @@ +# This module contains functions that *will be allowed* by dynamo + +""" +This module contains utility functions that are explicitly allowed to be called during +TorchDynamo compilation. These functions are carefully vetted to ensure they work +correctly within the TorchDynamo tracing and compilation process. + +Key functionality groups: + +- Compilation State: + Functions for checking compilation state (is_compiling) + +- Function Wrapping: + Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with + TorchDynamo compilation + +- Autograd Hooks: + Functions and classes for handling autograd hooks and backward passes + (call_hook, FakeBackwardCFunction, etc.) + +- Tensor Operations: + Utility functions for tensor operations and transformations +""" + +import functools +import warnings +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import deprecated, ParamSpec + +import torch +import torch.utils._pytree as pytree + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +if TYPE_CHECKING: + # TorchScript does not support `@deprecated` + # This is a workaround to avoid breaking TorchScript + @deprecated( + "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", + category=FutureWarning, + ) + def is_compiling() -> bool: + return torch.compiler.is_compiling() + +else: + + def is_compiling() -> bool: + """ + Indicates whether we are tracing/compiling with torch.compile() or torch.export(). + """ + # NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced + # and return true. torch.compiler.is_compiling() is skipped and will return false. 
+ return torch.compiler.is_compiling() + + +def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Create an extra frame around fn that is not in skipfiles. + """ + + @functools.wraps(fn) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: + return fn(*args, **kwargs) + + return inner + + +def call_hook( + hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any +) -> torch.Tensor: + """ + Used by compiled autograd to handle hook returning None. + """ + result = hook(*args) + if result is None: + return args[0] + elif kwargs.get("hook_type") == "post_acc_grad_hook": + raise RuntimeError("Tensor post accumulate grad hooks should return None.") + return result + + +def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]: + r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function + from ``torch.Tensor``s to ``torch.Tensor``s. + """ + if not np: + return f + + @functools.wraps(f) + def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree: + args, kwargs = pytree.tree_map_only( + torch.Tensor, lambda x: x.numpy(), (args, kwargs) + ) + out = f(*args, **kwargs) + return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out) + + return wrap + + +class FakeBackwardCFunction: + def __init__( + self, + real: torch.autograd.function.BackwardCFunction, + saved_tensors: list[torch.Tensor], + ) -> None: + self.real = real + self.saved_tensors = saved_tensors + + def __getattr__(self, name: str) -> Any: + if name == "saved_variables": + warnings.warn( + "'saved_variables' is deprecated; use 'saved_tensors'", + DeprecationWarning, + ) + return self.saved_tensors + + return getattr(self.real, name) + + +def call_backward( + backward_c_function: torch.autograd.function.BackwardCFunction, + saved_tensors: list[torch.Tensor], + *args: Any, +) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + fake = FakeBackwardCFunction(backward_c_function, saved_tensors) + grads = fake._forward_cls.backward(fake, *args) # type: ignore[attr-defined] + + if not isinstance(grads, tuple): + grads = (grads,) + + return grads + + +def normalize_as_list(x: Any) -> list[Any]: + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +def untyped_storage_size(x: torch.Tensor) -> int: + return x.untyped_storage().size() + + +class FakeCompiledAutogradEngine: + @staticmethod + def queue_callback( + final_callbacks: list[Callable[[], None]], cb: Callable[[], None] + ) -> None: + final_callbacks.append(cb) + + @staticmethod + def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None: + i = 0 + while i < len(final_callbacks): + cb = final_callbacks[i] + cb() + i += 1 + final_callbacks.clear() + + @staticmethod + def _exec_final_callbacks_stub() -> None: + pass + + +def call_hook_from_backward_state( + *args: Any, bw_state: Any, hook_name: str, **kwargs: Any +) -> Any: + return getattr(bw_state, hook_name)(*args, **kwargs) + + +def call_module_hooks_from_backward_state( + _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str +) -> Any: + module = getattr(bw_state, module_name) + hooks = getattr(bw_state, hooks_name) + for hook in hooks: + new_result = hook(module, result, *args) + if new_result is not None: + result = new_result + return result + + +# used for torch._dynamo.disable(recursive=False) +def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]: + # wrap function to get the right error message + # this function is in external_utils so 
that convert_frame doesn't skip it. + @functools.wraps(fn) + def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + return fn(*args, **kwargs) + + return nonrecursive_disable_wrapper + + +def _dynamo_config_patch_proxy_dunder_call( + self: Any, func: Callable[_P, _R] +) -> Callable[_P, _R]: + @functools.wraps(func) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: + with self: + return func(*args, **kwargs) + + return inner + + +# Use only on ints marked dynamic via torch.empty(0, integer) +# Currently only way to mark ints as dynamic: https://github.com/pytorch/pytorch/issues/129623 +def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int: + if isinstance(x, torch.Tensor): + # x.size() is expected to be [0, dynamic_int] + return x.size(1) + return x + + +def call_accumulate_grad( + variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool +) -> None: + updated_grad = torch._dynamo.compiled_autograd.ops.AccumulateGrad( # type: ignore[attr-defined] + [grad], variable, variable.grad, has_post_hooks + ) + variable.grad = updated_grad[0] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/funcname_cache.py b/phivenv/Lib/site-packages/torch/_dynamo/funcname_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..a09b9bc07e046c187cd76dfa4f8fd8a7a05cfed8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/funcname_cache.py @@ -0,0 +1,75 @@ +""" +This module provides functionality for caching and looking up fully qualified function +and class names from Python source files by line number. + +It uses Python's tokenize module to parse source files and tracks function/class +definitions along with their nesting to build fully qualified names (e.g. 'class.method' +or 'module.function'). The results are cached in a two-level dictionary mapping: + + filename -> (line_number -> fully_qualified_name) + +Example usage: + name = get_funcname("myfile.py", 42) # Returns name of function/class at line 42 + clearcache() # Clear the cache if file contents have changed + +The parsing is done lazily when a file is first accessed. Invalid Python files or +IO errors are handled gracefully by returning empty cache entries. +""" + +import tokenize +from typing import Optional + + +cache: dict[str, dict[int, str]] = {} + + +def clearcache() -> None: + cache.clear() + + +def _add_file(filename: str) -> None: + try: + with tokenize.open(filename) as f: + tokens = list(tokenize.generate_tokens(f.readline)) + except (OSError, tokenize.TokenError): + cache[filename] = {} + return + + # NOTE: undefined behavior if file is not valid Python source, + # since tokenize will have undefined behavior. + result: dict[int, str] = {} + # current full funcname, e.g. xxx.yyy.zzz + cur_name = "" + cur_indent = 0 + significant_indents: list[int] = [] + + for i, token in enumerate(tokens): + if token.type == tokenize.INDENT: + cur_indent += 1 + elif token.type == tokenize.DEDENT: + cur_indent -= 1 + # possible end of function or class + if significant_indents and cur_indent == significant_indents[-1]: + significant_indents.pop() + # pop the last name + cur_name = cur_name.rpartition(".")[0] + elif ( + token.type == tokenize.NAME + and i + 1 < len(tokens) + and tokens[i + 1].type == tokenize.NAME + and (token.string == "class" or token.string == "def") + ): + # name of class/function always follows class/def token + significant_indents.append(cur_indent) + if cur_name: + cur_name += "." 
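+            # For example (illustrative), given source like
+            #     class Foo:
+            #         def bar(self): ...
+            # the `class` line maps to "Foo" and the `def` line maps to "Foo.bar".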
+ cur_name += tokens[i + 1].string + result[token.start[0]] = cur_name + + cache[filename] = result + + +def get_funcname(filename: str, lineno: int) -> Optional[str]: + if filename not in cache: + _add_file(filename) + return cache[filename].get(lineno, None) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/graph_break_hints.py b/phivenv/Lib/site-packages/torch/_dynamo/graph_break_hints.py new file mode 100644 index 0000000000000000000000000000000000000000..37f70dc436d420ab91600a4ea2495f96a23bef31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/graph_break_hints.py @@ -0,0 +1,26 @@ +USER_ERROR = [ + "Dynamo has detected that tracing the code will result in an error when running in eager. " + "Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.", +] +DYNAMO_BUG = [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch.", +] +DIFFICULT = [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance.", +] +FUNDAMENTAL = [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through " + "your code. Consider finding a workaround.", +] +SUPPORTABLE = [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you " + "encounter this graph break often and it is causing performance issues.", +] +CAUSED_BY_EARLIER_GRAPH_BREAK = [ + "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.", +] +INFERENCE_MODE = [ + "Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. " + "This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead " + "because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used.", +] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/graph_break_registry.json b/phivenv/Lib/site-packages/torch/_dynamo/graph_break_registry.json new file mode 100644 index 0000000000000000000000000000000000000000..6991267f4d510c139c88325a7d44483eeff95f8c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/graph_break_registry.json @@ -0,0 +1,2141 @@ +{ + "GB0000": [ + { + "Gb_type": "All __torch_function__ overrides returned NotImplemented due to TypeError from user code", + "Context": "fn={fn}, args={args}, kwargs={kwargs}", + "Explanation": "All __torch_function__ overrides for for function {fn} returned NotImplemented", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0001": [ + { + "Gb_type": "Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + "Context": "{self}.as_subclass({cls})", + "Explanation": "Currently not supported", + "Hints": [ + "Avoid this call or move it outside `torch.compile` regione", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0002": [ + { + "Gb_type": "Assertion failed on symbolic shapes", + "Context": "str(sym_expr)", + "Explanation": "", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. 
Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0003": [ + { + "Gb_type": "Attempt to trace generator", + "Context": "", + "Explanation": "Generators cannot be compiled directly with `torch.compile`.", + "Hints": [ + "Call a generator from inside of a non-generator Python function and ", + "compile that function instead.", + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0004": [ + { + "Gb_type": "Attempted super().__delattr__() on an object without mutation tracking", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo needs to track mutations on an object before `super().__delattr__` can be used on it. But the object ({self.objvar}) doesn't have attribute mutation tracking enabled.", + "Hints": [ + "Ensure the object is tracked by Dynamo's side effect system.", + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0005": [ + { + "Gb_type": "Attempted to a str() method implemented in C/C++", + "Context": "", + "Explanation": "{type(arg.value)} has a C/C++ based str method. This is not supported.", + "Hints": [ + "Write the str method in Python" + ] + } + ], + "GB0006": [ + { + "Gb_type": "Attempted to call a super() attribute that is not a function or method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not know how to trace the call `super().{name}()` because `super().{name}` is not a function or method attribute.", + "Hints": [ + "Ensure the attribute accessed via `super()` is a standard method or function." + ] + } + ], + "GB0007": [ + { + "Gb_type": "Attempted to call function marked as skipped", + "Context": "module: {module_name}, qualname: {qualname}, skip reason: {reason}", + "Explanation": "explanation", + "Hints": [] + } + ], + "GB0008": [ + { + "Gb_type": "Attempted to inline function marked as skipped", + "Context": "qualname: {fn_qualname}, name: {func.get_name()}, filename: `{func.get_filename()}`, skip reason: {result.reason}", + "Explanation": "Dynamo developers have intentionally marked that the function `{fn_qualname}` should not be traced.", + "Hints": [] + } + ], + "GB0009": [ + { + "Gb_type": "Attempted to inline function marked as skipped (SkipFunctionVariable)", + "Context": "Attempted to inline a SkipFunctionVariable {func}", + "Explanation": "Attempted to inline a function that was previously determined to be marked as intentionally skipped.", + "Hints": [] + } + ], + "GB0010": [ + { + "Gb_type": "Attempted to read a deleted variable", + "Context": "item: {item}, name: {name}", + "Explanation": "", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0011": [ + { + "Gb_type": "Attempted to read undefined local variable", + "Context": "LOAD_FAST {name}", + "Explanation": "Could not find a local variable with name `{name}`", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0012": [ + { + "Gb_type": "Attempted to read undefined local variable (implicit)", + "Context": "LOAD_FAST {name}", + "Explanation": "Could not find an implicit local variable with name `{name}`", + "Hints": [ + "This happens in dict/list comprehensions", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0013": [ + { + "Gb_type": "Attempted to represent unregistered RemovableHandle", + "Context": "", + "Explanation": "Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, which is not supported. This happens because the RemovableHandle was created in another frame.", + "Hints": [] + } + ], + "GB0014": [ + { + "Gb_type": "Attempted to wrap RNN, GRU, or LSTM", + "Context": "str(value)", + "Explanation": "Dynamo does not support RNN, GRU, or LSTM.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0015": [ + { + "Gb_type": "Attempted to wrap sparse Tensor", + "Context": "", + "Explanation": "torch.compile does not support sparse Tensors", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0016": [ + { + "Gb_type": "Attempted to wrap strided NestedTensor", + "Context": "", + "Explanation": "torch.compile does not support strided NestedTensor", + "Hints": [] + } + ], + "GB0017": [ + { + "Gb_type": "Attempted to wrap torch._higher_order_ops.invoke_subgraph", + "Context": "", + "Explanation": "Directly using invoke_subgraph is not supported. Use mark_compile_region", + "Hints": [] + } + ], + "GB0018": [ + { + "Gb_type": "Attempted to wrap unbacked SymInt", + "Context": "", + "Explanation": "Unbacked SymInt input is not supported yet.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0019": [ + { + "Gb_type": "AutogradFunctionContextVariable escaped Dynamo-traced region", + "Context": "", + "Explanation": "We cannot reconstruct a torch.autograd.Function's context object.", + "Hints": [] + } + ], + "GB0020": [ + { + "Gb_type": "BUILD_STRING key conflict", + "Context": "format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}", + "Explanation": "Failed to build format string due to key conflict", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0021": [ + { + "Gb_type": "BUILD_STRING type error", + "Context": "str(part)", + "Explanation": "Format string part type is not correct - expected constant or format string.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0022": [ + { + "Gb_type": "Bad import result", + "Context": "typestr(value)", + "Explanation": "Import result is not a Python module.", + "Hints": [] + } + ], + "GB0023": [ + { + "Gb_type": "Builtin `operator.*` comparison with constant `self` failed", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "\"Failed to compare {self} with {other}, because {other} is not a Python constant or its mutation check fails.\"", + "Hints": [] + } + ], + "GB0024": [ + { + "Gb_type": "CLEANUP_THROW with StopIteration", + "Context": "", + "Explanation": "Received StopIteration when handling generator.throw/close. This is not supported.", + "Hints": [] + } + ], + "GB0025": [ + { + "Gb_type": "Call to `torch._dynamo.graph_break()`", + "Context": "Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + "Explanation": "User-inserted graph break. Message: {graph_break_msg}", + "Hints": [ + "Remove the `torch._dynamo.graph_break()` call." + ] + } + ], + "GB0026": [ + { + "Gb_type": "Calling subclass default constructor with more than tensor argument", + "Context": "{self.value}(args={args}, kwargs={kwargs})", + "Explanation": "Currently not supported", + "Hints": [ + "Avoid this constructor call or move it outside ", + "`torch.compile` regione", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0027": [ + { + "Gb_type": "Cannot check Tensor object identity without its fake value", + "Context": "str(fake_tensor)", + "Explanation": "TensorVariable is missing a fake example_value.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0028": [ + { + "Gb_type": "Caught non-Exception value", + "Context": "str(exc_instance)", + "Explanation": "Except expects to receive an object of Exception type but received {exc_instance}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0029": [ + { + "Gb_type": "Compilation of intermediate hooks requires compiled autograd", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo must be in compiled_autograd to register hooks.", + "Hints": [] + } + ], + "GB0030": [ + { + "Gb_type": "ComptimeContext graph break", + "Context": "msg", + "Explanation": "Manually triggered ComptimeContext graph break with message {msg}.", + "Hints": [] + } + ], + "GB0031": [ + { + "Gb_type": "Custom __getattribute__ in nn.Module attribute access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo does not support checking key existence on `nn.Module` instances that have a custom `__getattribute__` method defined.", + "Hints": [ + "Avoid defining `__getattribute__` in your module.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
+ ] + } + ], + "GB0032": [ + { + "Gb_type": "Custom __getattribute__ in nn.Module dict key check", + "Context": "has_key_in_generic_dict {self} {key}", + "Explanation": "Dynamo does not support checking key existence on `nn.Module` instances that have a custom `__getattribute__` method defined.", + "Hints": [ + "Avoid defining `__getattribute__` in your module.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0033": [ + { + "Gb_type": "Data dependent operator", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}` has a non-Tensor output whose value is dependent on the data of Tensor inputs.", + "Hints": [] + } + ], + "GB0034": [ + { + "Gb_type": "Data-dependent assertion failed (cannot compile partial graph)", + "Context": "value: {value}", + "Explanation": "Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph.", + "Hints": [ + "Use `torch._assert()` to raise a hard AssertionError when the check fails. ", + "This error will propagate back the user code ", + "that called the compiled function (i.e. Dynamo will not trace any exception handling).", + "Remove the assert statement.", + "Move the assert statement outside of any context managers in order to graph break with ", + "partial graph compilation (if fullgraph=False).", + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0035": [ + { + "Gb_type": "Data-dependent branching with non-constant __bool__", + "Context": "method: {x}, result: {result}", + "Explanation": "Attempted to perform data-dependent branching on a user-defined object with a __bool__ method that did not return a constant.", + "Hints": [] + } + ], + "GB0036": [ + { + "Gb_type": "Dynamic shape operator", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}`'s output shape depends on input Tensor data.", + "Hints": [ + "Enable tracing of dynamic shape operators with ", + "`torch._dynamo.config.capture_dynamic_output_shape_ops = True`" + ] + } + ], + "GB0037": [ + { + "Gb_type": "Dynamic shape operator (no meta kernel)", + "Context": "str(cause.func)", + "Explanation": "Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes", + "Hints": [ + "Please report an issue to PyTorch" + ] + } + ], + "GB0038": [ + { + "Gb_type": "Dynamic slicing with Tensor arguments", + "Context": "SliceVariable start: {start}, stop: {stop}, step: {step}", + "Explanation": "Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0039": [ + { + "Gb_type": "Dynamo cache limit exceeded", + "Context": "Limit type: {limit_type}", + "Explanation": "Dynamo attempted to recompile the code object too many times, exceeding the {limit_type} cache size limit.Giving up on compiling as the compile time tradeoff is likely not worth the performance gain.", + "Hints": [] + } + ], + "GB0040": [ + { + "Gb_type": "Encountered aliasing during higher order op tracing", + "Context": "context", + "Explanation": "Higher order ops do not support aliasing. 
Found in {source_target.name()}", + "Hints": [ + "Consider using the debug context to change user code to avoid aliasing.", + "Please open an issue." + ] + } + ], + "GB0041": [ + { + "Gb_type": "Encountered input mutation during higher order op tracing", + "Context": "context", + "Explanation": "Higher order ops do not support input mutation. Found in {source_target.name()}", + "Hints": [ + "Consider using the debug context to change user code to avoid mutation.", + "Please open an issue." + ] + } + ], + "GB0042": [ + { + "Gb_type": "Encountered non user function variable during invoke_subgraph HOP tracing", + "Context": "str(fn_vt)", + "Explanation": "invoke_subgraph does not support non user function variable", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0043": [ + { + "Gb_type": "Encountered non-PT2-compliant op", + "Context": "", + "Explanation": "msg + + err_epilogue", + "Hints": [] + } + ], + "GB0044": [ + { + "Gb_type": "Encountered strided NestedTensor in automatic dynamic dim determination", + "Context": "", + "Explanation": "torch.compile does not support strided NestedTensor", + "Hints": [] + } + ], + "GB0045": [ + { + "Gb_type": "Encountered tensor.is_inference() during tracing", + "Context": "", + "Explanation": "tensor.is_inference() is not supported", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0046": [ + { + "Gb_type": "Encountered torch.is_inference_mode_enabled during tracing", + "Context": "", + "Explanation": "torch.is_inference_mode_enabled() is not supported", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0047": [ + { + "Gb_type": "Encountered unconverted argument when attempting to inline", + "Context": "func: {func}, arg: {v}", + "Explanation": "An argument to an inlined function was not successfully converted to a VariableTracker.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0048": [ + { + "Gb_type": "Error getting associated real value", + "Context": "call_id {self}", + "Explanation": "Dynamo encountered an error while trying to get the associated real value.", + "Hints": [] + } + ], + "GB0049": [ + { + "Gb_type": "Error when attempting to resolve op packet", + "Context": "", + "Explanation": "str(e)", + "Hints": [] + } + ], + "GB0050": [ + { + "Gb_type": "Exception with bad expected type", + "Context": "str(expected_exc_types)", + "Explanation": "`except ...` has unsupported type {expected_exc_types}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0051": [ + { + "Gb_type": "Exception with non-type expectation", + "Context": "str(expected_type)", + "Explanation": "`except ...` expects a non-type: {expected_type}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0052": [ + { + "Gb_type": "Excessive RestartAnalysis() calls", + "Context": "", + "Explanation": "Dynamo attempted to trace the same frame 100+ times. Giving up on compiling as the compile time tradeoff is likely not worth the performance gain.", + "Hints": [] + } + ], + "GB0053": [ + { + "Gb_type": "FSDP with use_orig_params=False", + "Context": "", + "Explanation": "Dynamo only supports FSDP with use_orig_params=True", + "Hints": [] + } + ], + "GB0054": [ + { + "Gb_type": "Failed to construct Enum variable", + "Context": "value: {value_vt}, allowed enum values: {list(cls_type)}", + "Explanation": "Attempted to construct an Enum value that is non-constant (e.g. int, string) or is not an acceptable value for the Enum. Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0055": [ + { + "Gb_type": "Failed to convert args/kwargs to proxy", + "Context": "call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", + "Explanation": "Missing `as_proxy()` implementation for some arg/kwarg.", + "Hints": [] + } + ], + "GB0056": [ + { + "Gb_type": "Failed to mutate tensor data attribute", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "Dyanmo only supports mutating `.data` of tensor created outside `torch.compile` region", + "Hints": [ + "Don't mutate `.data` on this tensor, or move ", + "the mutation out of `torch.compile` region" + ] + } + ], + "GB0057": [ + { + "Gb_type": "Failed to raise exception", + "Context": "str(exc)", + "Explanation": "Attempted to raise a non-Exception type/value.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0058": [ + { + "Gb_type": "Failed to set tensor attribute", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "Dyanmo doesn't support setting these tensor attributes", + "Hints": [ + "Don't mutate attribute '{name}' on tensors, or ", + "move the mutation out of `torch.compile` region" + ] + } + ], + "GB0059": [ + { + "Gb_type": "Failed to trace builtin operator", + "Context": "builtin {fn.__name__} {arg_types} {has_kwargs}", + "Explanation": "Dynamo does not know how to trace builtin operator `{fn.__name__}` with argument types {real_arg_types} (has_kwargs {has_kwargs})", + "Hints": [ + "Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. ", + "Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), ", + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch." + ] + } + ], + "GB0060": [ + { + "Gb_type": "Failed to trace unittest method", + "Context": "function: unittest.TestCase.{name}", + "Explanation": "Dynamo does not know how to trace unittest method `{name}` ", + "Hints": [ + "Avoid calling `TestCase.{name}`. ", + "Please report an issue to PyTorch." 
+ ] + } + ], + "GB0061": [ + { + "Gb_type": "Failed to unpack object for BUILD_LIST_UNPACK", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK bytecode (`[*x, *y, ...]`).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0062": [ + { + "Gb_type": "Failed to unpack object for UNPACK_EX", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0063": [ + { + "Gb_type": "Failed to unpack object for UNPACK_SEQUENCE", + "Context": "str(seq)", + "Explanation": "{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode (i.e. `a, b, c = d`).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0064": [ + { + "Gb_type": "Fake tensor propagation exception", + "Context": "str(e.reason)", + "Explanation": "msg", + "Hints": [] + } + ], + "GB0065": [ + { + "Gb_type": "Graph break in inlined function", + "Context": "", + "Explanation": "Graph breaks in an inlined call are not supported.", + "Hints": [] + } + ], + "GB0066": [ + { + "Gb_type": "Graph break under GenericContextWrappingVariable", + "Context": "Active generic context managers: {self.active_generic_context_managers}", + "Explanation": "Attempted to graph break in an active context manager(s) that doesn't support graph breaking.", + "Hints": [ + "Move the offending context manager(s) to outside the compiled region.", + "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one." + ] + } + ], + "GB0067": [ + { + "Gb_type": "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0068": [ + { + "Gb_type": "Illegal method invocation in strict mode", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo currently does not support this method ({name}) invocation in strict mode.", + "Hints": [] + } + ], + "GB0069": [ + { + "Gb_type": "Import failure", + "Context": "module_name: {module_name}, fromlist: {fromlist}, level={level}", + "Explanation": "Failure when attempting to import.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0070": [ + { + "Gb_type": "Indexing list with non-scalar tensor", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Attempted to index list-like object with tensor with > 1 element.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0071": [ + { + "Gb_type": "Inline attempt with __self__", + "Context": "str(func)", + "Explanation": "Attempted to inline a function with the `__self__` attribute. Dynamo is expected to decompose method calls into function calls with a `self` argument.", + "Hints": [] + } + ], + "GB0072": [ + { + "Gb_type": "Inplace op on input tensor", + "Context": "", + "Explanation": "Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + "Hints": [ + "Ensure you do not modify input tensor in place.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0073": [ + { + "Gb_type": "Invoking an nn.Module inside a HigherOrderOperator", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0074": [ + { + "Gb_type": "Invoking an nn.Module inside a higher order operator", + "Context": "Higher order op name: {self.source_target}", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0075": [ + { + "Gb_type": "LOAD_BUILD_CLASS bytecode not supported", + "Context": "", + "Explanation": "Dynamo does not support tracing classes that are defined in the compiled region.", + "Hints": [ + "Move the class definition out of the compiled region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0076": [ + { + "Gb_type": "LOAD_FAST_CHECK on uninitialized variable", + "Context": "inst.argval", + "Explanation": "Attempted to load uninitialized local variable {inst.argval}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0077": [ + { + "Gb_type": "Length mismatch when unpacking object for UNPACK_SEQUENCE", + "Context": "expected length: {inst.argval}, actual: {len(val)}", + "Explanation": "{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode (i.e. `a, b, c = d`) with unexpected length.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0078": [ + { + "Gb_type": "Limitation of `nonstrict_trace", + "Context": "{self}", + "Explanation": "msg", + "Hints": [ + "make sure definition of {fn_name} is outside ", + "`torch.compile` region" + ] + } + ], + "GB0079": [ + { + "Gb_type": "Missing CALL_INTRINSIC_1 handler", + "Context": "CALL_INTRINSIC_1 operand: {inst.argval}", + "Explanation": "No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0080": [ + { + "Gb_type": "Missing FakeTensor example value", + "Context": "str(node)", + "Explanation": "`FakeTensor` example value was required for {node} but not available.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." 
+ ] + } + ], + "GB0081": [ + { + "Gb_type": "Missing attribute when running call_method node", + "Context": "", + "Explanation": "make_error_message(\"attribute not defined\")", + "Hints": [] + } + ], + "GB0082": [ + { + "Gb_type": "Missing bytecode handler", + "Context": "{opname} with args {args}", + "Explanation": "Dynamo does not know how to handle the bytecode instruction `{opname}`.", + "Hints": [ + "Do not trace code that produces the `{opname}` bytecode instruction ", + "(see https://docs.python.org/3/library/dis.html for bytecode semantics).", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0083": [ + { + "Gb_type": "Module-level backwards hooks require compiled autograd.", + "Context": "", + "Explanation": "", + "Hints": [ + "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True." + ] + } + ], + "GB0084": [ + { + "Gb_type": "Non-constant attribute given to `super().__delattr__()`", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo requires the attribute name passed to `super().__delattr__(...)` to be a constant (string).", + "Hints": [ + "Ensure the attribute name is a string literal or a constant variable." + ] + } + ], + "GB0085": [ + { + "Gb_type": "Non-function or method in subclass of torch.autograd.Function", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo requires the `forward` attribute of a `torch.autograd.Function` subclass to be a standard Python function or method. Found type `{type(fn).__name__}` instead.", + "Hints": [ + "Ensure the `forward` method is defined as a regular ", + "function or instance method." + ] + } + ], + "GB0086": [ + { + "Gb_type": "Not a Python constant", + "Context": "guard_as_python_constant {self}", + "Explanation": "Failed to convert {self} into a Python constant.", + "Hints": [] + } + ], + "GB0087": [ + { + "Gb_type": "NotImplementedError/UnsupportedFakeTensorException when running FX node", + "Context": "", + "Explanation": "make_error_message(e)", + "Hints": [] + } + ], + "GB0088": [ + { + "Gb_type": "Observed exception", + "Context": "str(raised_exception)", + "Explanation": "observed_exn_gb_explanation", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0089": [ + { + "Gb_type": "Observed exception (EXCEPT_HANDLER)", + "Context": "str(raised_exception)", + "Explanation": "observed_exn_gb_explanation + \" This graph break is unexpected.\"", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0090": [ + { + "Gb_type": "Operator does not support running with fake tensors", + "Context": "unsupported operator: {cause.func}", + "Explanation": "", + "Hints": [ + "{import_suggestion}see ", + "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0", + " for how to fix" + ] + } + ], + "GB0091": [ + { + "Gb_type": "Read uninitialized cell", + "Context": "str(cellvar)", + "Explanation": "Attempted to read a cell variable that has not been populated yet.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. 
Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0092": [ + { + "Gb_type": "Reconstruction failure", + "Context": "str(value)", + "Explanation": "Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", + "Hints": [ + "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable ", + "that Dynamo cannot reconstruct, then remove it from the return statement.", + "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have ", + "reconstruction rules may be fundamentally unreconstructable.", + "This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one." + ] + } + ], + "GB0093": [ + { + "Gb_type": "Reconstruction failure: source.reconstruct not implemented", + "Context": "str(source)", + "Explanation": "Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0094": [ + { + "Gb_type": "SEND with bad type", + "Context": "TOS type: {typestr(tos)}", + "Explanation": "Attempted to SEND with unsupported type {typestr(tos)}.", + "Hints": [] + } + ], + "GB0095": [ + { + "Gb_type": "Set Exception object `__traceback__` attribute to not-`None`", + "Context": "call_setattr {self} {name}", + "Explanation": "Dynamo does not support setting the attribute '__traceback__' on tracked exception objects to anything other than None.", + "Hints": [ + "Avoid setting '__traceback__' on exception objects ", + "within traced code, or set it to None." + ] + } + ], + "GB0096": [ + { + "Gb_type": "Should not compile partial graph (STORE_ATTR)", + "Context": "", + "Explanation": "Dynamo has determined when encountering an unsupported STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.", + "Hints": [] + } + ], + "GB0097": [ + { + "Gb_type": "Side effect on existing deque with limited maxlen", + "Context": "", + "Explanation": "This is not supported.", + "Hints": [ + "Don't use a deque with `maxlen` specified." + ] + } + ], + "GB0098": [ + { + "Gb_type": "Skip calling `torch.compiler.disable()`d function", + "Context": "str(self.value)", + "Explanation": "Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable` (reason: {msg})", + "Hints": [ + "Remove the `torch.compiler.disable` call" + ] + } + ], + "GB0099": [ + { + "Gb_type": "Skip inlining `torch.compiler.disable()`d function", + "Context": "str(func.get_function())", + "Explanation": "Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable` (reason: {msg})", + "Hints": [ + "Remove the `torch.compiler.disable` call" + ] + } + ], + "GB0100": [ + { + "Gb_type": "Storing Tensor hook handle in globals", + "Context": "name", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0101": [ + { + "Gb_type": "Storing Tensor hook handle in globals (inline call)", + "Context": "inst.argval", + "Explanation": "This is not supported.", + "Hints": [] + } + ], + "GB0102": [ + { + "Gb_type": "Strict mode banned op", + "Context": "var_getattr {self} {name}", + "Explanation": "Getattr invocation '{name}' in strict mode is not supported.", + "Hints": [ + "Remove `{name}` from the list of banned ops by ", + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`." 
+ ] + } + ], + "GB0103": [ + { + "Gb_type": "Tensor subclass overridden method call", + "Context": "{name}", + "Explanation": "`torch.compile` currently can't trace this", + "Hints": [ + "Avoid calling {name} of tensor subclass in torch.compile region", + "Renaming method `{name}` of type {self.class_type}", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0104": [ + { + "Gb_type": "Tensor with grad_fn()", + "Context": "var_getattr {self} grad_fn", + "Explanation": "Dynamo does not support tracing tensors with a grad_fn directly.", + "Hints": [] + } + ], + "GB0105": [ + { + "Gb_type": "Tensor.numpy() with trace_numpy=False", + "Context": "call_method {self} numpy", + "Explanation": "`Tensor.numpy()` was called, but the `trace_numpy` configuration was manually disabled.", + "Hints": [ + "Set `torch._dynamo.config.trace_numpy = True` to allow ", + "Dynamo to trace through NumPy." + ] + } + ], + "GB0106": [ + { + "Gb_type": "Tensor.numpy() without NumPy installed", + "Context": "call_method {self} numpy", + "Explanation": "`Tensor.numpy()` was called, but the NumPy library is not available in the current environment.", + "Hints": [ + "Ensure NumPy is installed in your Python environment." + ] + } + ], + "GB0107": [ + { + "Gb_type": "Tensor.random_ op", + "Context": "Tensor.{name}(args={args}, kwargs={kwargs})", + "Explanation": "This is currently not supported.", + "Hints": [ + "Use the out-of-place version of this op", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0108": [ + { + "Gb_type": "Tensor.retain_grad() with AOTDispatcher", + "Context": "var_getattr {self} retain_grad", + "Explanation": "`Tensor.retain_grad()` does not work with AOTDispatcher.", + "Hints": [] + } + ], + "GB0109": [ + { + "Gb_type": "Tensor.tolist() with non-integer tensor", + "Context": "call_method {self} to_list", + "Explanation": "Dynamo currently does not support tracing `tolist()` on non-integer tensors.", + "Hints": [ + "Ensure the input tensor to `tolist()` is an integer ", + "type (e.g., int8, int16, int32, int64)." + ] + } + ], + "GB0110": [ + { + "Gb_type": "Tensor.uniform_ op called with `from` keyword", + "Context": "Tensor.{name}(args={args}, kwargs={kwargs})", + "Explanation": "This is currently not supported.", + "Hints": [ + "Avoid using the `from` keyword.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0111": [ + { + "Gb_type": "TypeError from user code", + "Context": "call_function({self.value}, {args}, {kwargs})", + "Explanation": "msg", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0112": [ + { + "Gb_type": "TypeError when making fake tensor call", + "Context": "TypeError {node.target}: {cause}", + "Explanation": "", + "Hints": [] + } + ], + "GB0113": [ + { + "Gb_type": "Unable to resolve super getattr", + "Context": "", + "Explanation": "Dynamo failed to trace attribute `{name}` accessed via `super()` (for type `{self.typevar}` and object `{self.objvar}`) because the resolved attribute type is not supported.", + "Hints": [ + "Ensure the attribute exists in the parent class.", + "Check the arguments passed to `super()`." + ] + } + ], + "GB0114": [ + { + "Gb_type": "Unexpected failure during itertools.accumulate() iteration", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0115": [ + { + "Gb_type": "Unexpected failure during itertools.groupby() iteration", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Unexpected failure in invoking function during groupby", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0116": [ + { + "Gb_type": "Unexpected type in sourceless builder", + "Context": "{value_type.__module__}.{value_type.__qualname__}", + "Explanation": "SourcelessBuilder.create does not know how to wrap {value_type}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0117": [ + { + "Gb_type": "Unhandled args for method", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo encountered an error while calling the method `{name}`.", + "Hints": [] + } + ], + "GB0118": [ + { + "Gb_type": "Unimplemented next() call", + "Context": "next({self})", + "Explanation": "This abstract method must be implemented", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0119": [ + { + "Gb_type": "Uninitialized nn.Module", + "Context": "typestr(value)", + "Explanation": "Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", + "Hints": [ + "Ensure your nn.Module instance has called `super().__init__()`.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0120": [ + { + "Gb_type": "Unreachable sub-generator code", + "Context": "", + "Explanation": "Should only be encountered while implementing generator support.", + "Hints": [] + } + ], + "GB0121": [ + { + "Gb_type": "UnspecializedNNModuleVariable missing method", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing method {name} of nn.Module {self.value}", + "Hints": [ + "Dynamo does not really define unspecialized nn.Module very well.", + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0122": [ + { + "Gb_type": "Unsupported SourceType", + "Context": "MutationType.__init__ {self} {typ}", + "Explanation": "Dynamo does not support the type `{typ}`", + "Hints": [ + "This branch is not supposed to be reachable.", + "This is likely to be a Dynamo bug. 
Please report an issue to PyTorch." + ] + } + ], + "GB0123": [ + { + "Gb_type": "Unsupported Tensor.backward() call", + "Context": "call_method {self} backward {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.backward()`.", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], + "GB0124": [ + { + "Gb_type": "Unsupported Tensor.item() call with capture_scalar_outputs=False", + "Context": "call_method {self} item {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.", + "Hints": [ + "Set `torch._dynamo.config.capture_scalar_outputs = True` ", + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` ", + "to include these operations in the captured graph." + ] + } + ], + "GB0125": [ + { + "Gb_type": "Unsupported Tensor.requires_grad_() call", + "Context": "call_method {self} requires_grad_", + "Explanation": "Dynamo does not support changes to a Tensor's `requires_grad` through calling `requires_grad_()`.", + "Hints": [] + } + ], + "GB0126": [ + { + "Gb_type": "Unsupported Tensor.resize_() call", + "Context": "call_method {self} resize_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.resize_()`.", + "Hints": [] + } + ], + "GB0127": [ + { + "Gb_type": "Unsupported Tensor.resize_as_() call", + "Context": "call_method {self} resize_as_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.resize_as_()`.", + "Hints": [] + } + ], + "GB0128": [ + { + "Gb_type": "Unsupported Tensor.set_() call", + "Context": "call_method {self} set_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.set_()` overloads that include more than one argument.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0129": [ + { + "Gb_type": "Unsupported Tensor.sparse_resize_() call", + "Context": "call_method {self} sparse_resize_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + "Hints": [] + } + ], + "GB0130": [ + { + "Gb_type": "Unsupported Tensor.sparse_resize_and_clear_() call", + "Context": "call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + "Explanation": "Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + "Hints": [] + } + ], + "GB0131": [ + { + "Gb_type": "Unsupported __setitem__/__setattr__ inline attempt", + "Context": "code name: {code.co_name}, args: {args}", + "Explanation": "Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.", + "Hints": [] + } + ], + "GB0132": [ + { + "Gb_type": "Unsupported `func` in itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to get the function to use for itertools.accumulate. itertools.accumulate expects the `func` as the second argument or as a keyword argument.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0133": [ + { + "Gb_type": "Unsupported arguments for itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.accumulate with args: {args} and kwargs: {kwargs}. itertools.accumulate expects an iterable, an optional binary function for accumulation, and an optional initial value to set the starting state.", + "Hints": [ + "Make sure the arguments to itertools.accumulate are correct.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0134": [ + { + "Gb_type": "Unsupported arguments for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.groupby with args: {args} and kwargs: {kwargs}. itertools.groupby expects an iterable to group and an optional key function to determine groupings.", + "Hints": [ + "Make sure the arguments to itertools.groupby are correct.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0135": [ + { + "Gb_type": "Unsupported attribute assignment on Exception object", + "Context": "call_setattr {self} {name}", + "Explanation": "Dynamo does not support setting the attribute '{name}' on tracked exception objects. Only `__context__`, `__cause__`, `__suppress_context__`, and `__traceback__` are supported.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0136": [ + { + "Gb_type": "Unsupported attribute for range() object", + "Context": "var_getattr {self} {name}", + "Explanation": "Expected attribute to be one of {','.join(fields)} but got {name}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0137": [ + { + "Gb_type": "Unsupported attribute for slice() object", + "Context": "var_getattr {self} {name}", + "Explanation": "Expected attribute to be one of {','.join(fields)} but got {name}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0138": [ + { + "Gb_type": "Unsupported autograd.Function context `save_for_backward`", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo requires the `saved_tensors` attribute to be initialized on the `autograd.Function` context object.", + "Hints": [ + "Ensure that the `saved_tensors` attribute is properly ", + "initialized before calling `save_for_backward`. ", + "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`." + ] + } + ], + "GB0139": [ + { + "Gb_type": "Unsupported autograd.Function context method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not support calling the method `{name}` on `autograd.Function` context objects. 
Supported methods are `__setattr__`, `save_for_backward` and `mark_non_differentiable`.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0140": [ + { + "Gb_type": "Unsupported autograd.Function method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo does not support calling the method `{name}` directly on the `torch.autograd.Function` instance. Supported methods include `apply`, `backward`, static methods, and class methods.", + "Hints": [ + "Ensure the method is decorated with `@staticmethod` ", + "or `@classmethod` if it's meant to be called on the class." + ] + } + ], + "GB0141": [ + { + "Gb_type": "Unsupported call_id() without source", + "Context": "call_id {self}", + "Explanation": "call_id() not supported for sourceless TensorVariable.", + "Hints": [] + } + ], + "GB0142": [ + { + "Gb_type": "Unsupported context manager", + "Context": "Attempted SETUP_WITH/BEFORE_WITH on {ctx}", + "Explanation": "Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.", + "Hints": [ + "Avoid using the unsupported context manager.", + "If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then ", + "it may be the case that it was created outside the compiled region, which Dynamo does not support. ", + "Supported context managers can cross graph break boundaries only if they are local non-closure ", + "variables, or are intermediate values.", + "File an issue to PyTorch. Simple context managers can potentially be supported, ", + "but note that context managers can't be supported in general" + ] + } + ], + "GB0143": [ + { + "Gb_type": "Unsupported conversion for slice assignment", + "Context": "call_method {self} {name} {args}", + "Explanation": "Missing dynamo support for converting {value} into a list for slice assignment.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0144": [ + { + "Gb_type": "Unsupported custom jvp", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `torch.autograd.Function` subclasses that define a custom `jvp` method.", + "Hints": [ + "Remove the custom `jvp` method if possible.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0145": [ + { + "Gb_type": "Unsupported custom vjp", + "Context": "call_apply {self} {args} {kwargs}", + "Explanation": "Dynamo does not support tracing `torch.autograd.Function` subclasses that define a custom `vjp` method.", + "Hints": [ + "Remove the custom `vjp` method if possible.", + "Use standard `backward` instead if applicable.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0146": [ + { + "Gb_type": "Unsupported event method", + "Context": "str(name)", + "Explanation": "Dynamo doesn't support tracing the {method_name} method. We currently support wait, record, synchronize, and query.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. 
Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0147": [ + { + "Gb_type": "Unsupported function call", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [ + "Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch." + ] + } + ], + "GB0148": [ + { + "Gb_type": "Unsupported function call (delayed)", + "Context": "source: {self.source}", + "Explanation": "Dynamo determined that a graph break should occur when calling `{self.source.name()}`. Reason: {self.msg}", + "Hints": [] + } + ], + "GB0149": [ + { + "Gb_type": "Unsupported functorch tracing attempt", + "Context": "", + "Explanation": "msg", + "Hints": [] + } + ], + "GB0150": [ + { + "Gb_type": "Unsupported hasattr call", + "Context": "call_obj_hasattr {self} {name}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [ + "Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0151": [ + { + "Gb_type": "Unsupported inspect call", + "Context": "inspect_parameter_names {self}", + "Explanation": "Dynamo does not know how to trace the function `{self.debug_repr()}`", + "Hints": [] + } + ], + "GB0152": [ + { + "Gb_type": "Unsupported key type for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace itertools.groupby with key type: {str(type(key))}. We only support grouping keys that are constants (int, float, str, etc.)", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0153": [ + { + "Gb_type": "Unsupported key type for nn.Module.__getitem__", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support getitem on `nn.Module` with non-constant key.", + "Hints": [] + } + ], + "GB0154": [ + { + "Gb_type": "Unsupported kwargs for itertools.accumulate", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'initial', 'func', but got {','.join(set(kwargs.keys()) - {'initial', 'func'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0155": [ + { + "Gb_type": "Unsupported kwargs for itertools.groupby", + "Context": "call_function {self} {args} {kwargs}", + "Explanation": "Expected kwargs: 'key', but got {','.join(set(kwargs.keys()) - {'key'})}", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0156": [ + { + "Gb_type": "Unsupported method call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + "Hints": [] + } + ], + "GB0157": [ + { + "Gb_type": "Unsupported ndarray attribute access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo currently does not support tracing `ndarray.{name}`.", + "Hints": [] + } + ], + "GB0158": [ + { + "Gb_type": "Unsupported ndarray method call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "`ndarray.{name}()` is not modelled in `torch._numpy`.", + "Hints": [] + } + ], + "GB0159": [ + { + "Gb_type": "Unsupported ndarray.__version__ access", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo currently does not support tracing `ndarray.{name}`.", + "Hints": [] + } + ], + "GB0160": [ + { + "Gb_type": "Unsupported next() call", + "Context": "next({self})", + "Explanation": "Dynamo does not know how to trace calling `next()` on variable `{self}`.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0161": [ + { + "Gb_type": "Unsupported nn.Module attribute type", + "Context": "nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", + "Explanation": "Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", + "Hints": [ + "Refactor your code so that `{name}` (type `{typestr(subobj)}`) is not an attribute of `{typestr(base)}`", + "Currently supported attribute types are methods, classmethods, staticmethods, ", + "properties, constants, and tensors.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0162": [ + { + "Gb_type": "Unsupported super().__init__() call", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Dynamo encountered a super().__init__() call on {objvar} that resolved to a `torch.nn.Module.__init__()` call that we cannot trace.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0163": [ + { + "Gb_type": "Unsupported tensor subclass attribute access", + "Context": "{name}", + "Explanation": "`torch.compile` currently can't trace this", + "Hints": [ + "Avoid accessing {name} of tensor subclass in torch.compile region", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0164": [ + { + "Gb_type": "Unsupported tensor subclass overridden attribute access", + "Context": "{name}", + "Explanation": "`torch.compile` only support tracing certain types of overridden tensor subclass attributes", + "Hints": [ + "Avoid accessing {name} of tensor subclass in torch.compile region", + "Renaming attribute `{name}` of type {self.class_type}", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
+ ] + } + ], + "GB0165": [ + { + "Gb_type": "Unsupported torch._C._ImperativeEngine method", + "Context": "call_method {self} {name}", + "Explanation": "Dynamo only supports the `queue_callback` method on a torch._C._ImperativeEngine instance, but found: `{name}`.", + "Hints": [] + } + ], + "GB0166": [ + { + "Gb_type": "Unsupported torch._C._ImperativeEngine.queue_callback()", + "Context": "call_method {self} {name}", + "Explanation": "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True.", + "Hints": [] + } + ], + "GB0167": [ + { + "Gb_type": "Variadic function call with bad args/kwargs type", + "Context": "args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", + "Explanation": "Expected args to be a list and kwargs to be a dict", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0168": [ + { + "Gb_type": "Variadic function call with bad flags", + "Context": "flags: {inst.argval}", + "Explanation": "Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0169": [ + { + "Gb_type": "Write to immutable cell", + "Context": "cellvar: {cellvar}, value: {value}", + "Explanation": "Dynamo doesn't support writing to immutable/sourceless cell variables.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0170": [ + { + "Gb_type": "_gb_type", + "Context": "attempted to jump with {value}", + "Explanation": "_explanation", + "Hints": [] + } + ], + "GB0171": [ + { + "Gb_type": "assert with non-string message", + "Context": "str(args)", + "Explanation": "Dynamo only supports asserts with string messages", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0172": [ + { + "Gb_type": "async_op=True for distributed collectives", + "Context": "{self.fn}, args={args}, kwargs={kwargs}", + "Explanation": "`torch.compile` doesn't support `async_op=True for {self.fn}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0173": [ + { + "Gb_type": "backward_state does not support export", + "Context": "", + "Explanation": "Compiled autograd doesn't work with `torch.export`.", + "Hints": [] + } + ], + "GB0174": [ + { + "Gb_type": "bad args to builtin cast()", + "Context": "got args {args} {kwargs}", + "Explanation": "Dynamo expects exactly 2 args to builtin cast().", + "Hints": [ + "Ensure your call to cast() has exactly 2 arguments." + ] + } + ], + "GB0175": [ + { + "Gb_type": "builtin isinstance() cannot determine type of argument", + "Context": "isinstance({arg}, {isinstance_type})", + "Explanation": "Dynamo doesn't have a rule to determine the type of argument {arg}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." 
+ ] + } + ], + "GB0176": [ + { + "Gb_type": "call_id() without associated real value", + "Context": "call_id {self}", + "Explanation": "Dynamo could not find an associated real value for the tensor.", + "Hints": [] + } + ], + "GB0177": [ + { + "Gb_type": "can't handle functions not implemented in python ", + "Context": "{fn}", + "Explanation": "Dynamo can only handle functions defined in python", + "Hints": [ + "Move usage of this function out of `torch.compile` region", + "Avoid using `tensor.is_inference()` and `torch.is_inference_mode_enabled()` in your compile code. This is primarily used in conjunction with `torch.inference_mode`. Consider using `torch.no_grad` instead because `torch.no_grad` leads to same improvements as `inference_mode` when `torch.compile` is used." + ] + } + ], + "GB0178": [ + { + "Gb_type": "constant fold exception", + "Context": "attempted to run function {fn} with arguments {args}", + "Explanation": "Encountered exception when attempting to constant fold.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0179": [ + { + "Gb_type": "copy.deepcopy()", + "Context": "copy.deepcopy({x})", + "Explanation": "Dynamo does not support copy.deepcopy()", + "Hints": [ + "Avoid calling copy.deepcopy()", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0180": [ + { + "Gb_type": "dataclass fields failure", + "Context": "obj: {obj}; variable type: {type(obj)}", + "Explanation": "Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.", + "Hints": [] + } + ], + "GB0181": [ + { + "Gb_type": "dtype mismatch between tensor and its gradient", + "Context": "tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", + "Explanation": "Inconsistent dtype between tensor and its gradient. This can happen in FSDP and crashes meta tensor creation.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0182": [ + { + "Gb_type": "failed to broadcast when attempting Tensor comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo was unable to broad cast the arguments {left}, {right} when attempting to trace the comparison op {op.__name__}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0183": [ + { + "Gb_type": "failed to call dict.fromkeys()", + "Context": "{user_cls.__name__}.fromkeys(): {args} {kwargs}", + "Explanation": "Failed to call {user_cls.__name__}.fromkeys() because arguments could not be automatically converted to a list, or some dict key is not hashable.", + "Hints": [ + "Manually convert the argument to a list.", + "Ensure all keys are hashable." + ] + } + ], + "GB0184": [ + { + "Gb_type": "failed to call str() on user defined object", + "Context": "str(arg)", + "Explanation": "User defined object has no __str__ or __repr__ method", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0185": [ + { + "Gb_type": "failed to convert numpy.ndarray to Tensor", + "Context": "str(value)", + "Explanation": "Exception encountered when attempting to convert numpy.ndarray to Tensor", + "Hints": [] + } + ], + "GB0186": [ + { + "Gb_type": "functools.partial() with non-literal keyword", + "Context": "non-literal keyword: {k}", + "Explanation": "functools.partial() expects literal/string keywords", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0187": [ + { + "Gb_type": "functools.wraps", + "Context": "{fn}", + "Explanation": "`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0188": [ + { + "Gb_type": "getattr with no source", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo does not know how to access an attribute on an `nn.Module` instance that lacks a source. This is usually an internal error in Dynamo.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0189": [ + { + "Gb_type": "getattr() on nn.Module with pending mutation", + "Context": "getattr({obj}, {name}, {default})", + "Explanation": "Intentionally graph breaking on getattr() on a nn.Module with a pending mutation", + "Hints": [] + } + ], + "GB0190": [ + { + "Gb_type": "getattr() with non-constant name argument", + "Context": "getattr({obj}, {name_var}, {default})", + "Explanation": "getattr() with non-constant name argument is not supported", + "Hints": [ + "Ensure the name argument of getattr() is a string" + ] + } + ], + "GB0191": [ + { + "Gb_type": "id() with unsupported args", + "Context": "str(args)", + "Explanation": "Dynamo doesn't know how to trace id() call with args {args}", + "Hints": [ + "Supported args are Tensors, and functions/nn.Modules/user-defined objects ", + "from outside the compiled region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0192": [ + { + "Gb_type": "input iterator to itertools.cycle has too many items", + "Context": "next({self})", + "Explanation": "Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}", + "Hints": [] + } + ], + "GB0193": [ + { + "Gb_type": "invalid call to builtin op handler", + "Context": "invalid args to {self_handler}: {args} {kwargs}", + "Explanation": "Encountered TypeError when trying to handle op {fn.__name__}", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0194": [ + { + "Gb_type": "isinstance() called on user defined object with C extensions", + "Context": "isinstance({arg}, {isinstance_type})", + "Explanation": "User-defined object with C extensions can have torch.Tensor attributes; intentionally graph breaking.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
+ ] + } + ], + "GB0195": [ + { + "Gb_type": "issubclass() with non-constant arguments", + "Context": "issubclass({left_ty}, {right_ty})", + "Explanation": "issubclass() with non-constant arguments not supported.", + "Hints": [ + "Make sure your arguments are types.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0196": [ + { + "Gb_type": "key not found in dict", + "Context": "Key {arg.value}", + "Explanation": "msg", + "Hints": [ + "Check if the key exists in the dictionary before accessing it.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0197": [ + { + "Gb_type": "list elements are pointing to the list itself", + "Context": "", + "Explanation": "Dynamo does not support lists whose items reference to itself", + "Hints": [ + "Avoid using self referential list" + ] + } + ], + "GB0198": [ + { + "Gb_type": "mapping proxy affected by dictionary mutation", + "Context": "Source: {self.source}, Dict mutation detected", + "Explanation": "msg", + "Hints": [ + "Avoid modifying dictionaries that might be referenced by mapping proxy objects", + "Or avoid using the mapping proxy objects after modifying its underlying dictionary" + ] + } + ], + "GB0199": [ + { + "Gb_type": "mapping proxy cannot be reconstructed", + "Context": "Source: {self.source}", + "Explanation": "msg", + "Hints": [ + "Use a mapping proxy constructed in the same `torch.compile` region.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0200": [ + { + "Gb_type": "missing BUILD_SET handler", + "Context": "", + "Explanation": "Missing BUILD_SET bytecode handler (for testing purposes).", + "Hints": [] + } + ], + "GB0201": [ + { + "Gb_type": "namedtuple construction", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "`torch.compile` only support certain input types for namedtuple", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0202": [ + { + "Gb_type": "non-const argument in nn.Module method", + "Context": "call_method: {self} {name} {args} {kwargs}", + "Explanation": "Dynamo does not support calling method `{name}` of ``nn.Module`` {module} with non-constant arguments.", + "Hints": [] + } + ], + "GB0203": [ + { + "Gb_type": "non-const keys in dict_keys", + "Context": "non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", + "Explanation": "Dynamo expects dict_keys keys to be constants.", + "Hints": [ + "Ensure your dict_keys keys are constants (e.g. int, float, strings)" + ] + } + ], + "GB0204": [ + { + "Gb_type": "non-const keys in mappingproxy", + "Context": "non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", + "Explanation": "Dynamo expects mappingproxy keys to be constants.", + "Hints": [ + "Ensure your mappingproxy keys are constants (e.g. 
int, float, strings)" + ] + } + ], + "GB0205": [ + { + "Gb_type": "proxy not set", + "Context": "as_proxy {self}", + "Explanation": "Dynamo requires the autograd.Function context to be initialized with a proxy.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0206": [ + { + "Gb_type": "setattr() on Tensor.requires_grad", + "Context": "setattr({obj}, {name}, {val})", + "Explanation": "setattr() on Tensor.requires_grad not supported. Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in the middle of the graph, which AOTAutograd does not currently know how to handle.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0207": [ + { + "Gb_type": "sort with non-constant keys", + "Context": "str(first_non_constant_key)", + "Explanation": "Cannot perform sort with non-constant key. First non-constant key type: {python_type}. Most notably, we cannot sort with Tensor or SymInt keys, but we can sort ints.", + "Hints": [ + "Use something else as the key." + ] + } + ], + "GB0208": [ + { + "Gb_type": "torch.* op returned non-Tensor", + "Context": "example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", + "Explanation": "torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", + "Hints": [] + } + ], + "GB0209": [ + { + "Gb_type": "torch.autograd._unsafe_preserve_version_counter escaped from compiled region", + "Context": "str(self)", + "Explanation": "Dynamo doesn't support compiling a region that returns a torch.autograd._unsafe_preserve_version_counter context manager.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0210": [ + { + "Gb_type": "torch.distributed package is not available!", + "Context": "", + "Explanation": "The PyTorch package doesn't include torch.distributed when building from source.", + "Hints": [ + "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." + ] + } + ], + "GB0211": [ + { + "Gb_type": "torch.nn.Module with a non-function custom __getattr__", + "Context": "var_getattr {self} {name}", + "Explanation": "Dynamo detected a nn.Module object with a custom `__getattr__` method, but this method is not a standard Python function (e.g., it might be implemented in C/C++). Dynamo cannot currently trace into such non-standard `__getattr__` methods.", + "Hints": [ + "Avoid using objects with non-standard __getattr__ methods ", + "within the compiled region. If possible, implement ", + "__getattr__ as a standard Python function.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0212": [ + { + "Gb_type": "torch.profiler object escaped from compiled region", + "Context": "str(self)", + "Explanation": "Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
+ ] + } + ], + "GB0213": [ + { + "Gb_type": "unimplemented builtin op on tensor arguments", + "Context": "partial tensor op: {self} {args} {kwargs}", + "Explanation": "Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0214": [ + { + "Gb_type": "unsupported SymNode comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo does not support the comparison op {op.__name__} with SymNode arguments {left}, {right}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0215": [ + { + "Gb_type": "unsupported Tensor comparison op", + "Context": "{op.__name__}({left}, {right})", + "Explanation": "Dynamo does not support the comparison op {op.__name__} with Tensor arguments {left}, {right}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0216": [ + { + "Gb_type": "unsupported grid type for triton hop check_grid", + "Context": "grid type = {type(grid)}", + "Explanation": "`torch.compile` only supports list-like grid for check_grid", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0217": [ + { + "Gb_type": "unsupported hasattr operation", + "Context": "Class {self.user_cls}", + "Explanation": "msg", + "Hints": [ + "Consider using a regular dictionary instead", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0218": [ + { + "Gb_type": "unsupported index(Tensor)", + "Context": "", + "Explanation": "Dynamo does not support tracing builtin index() on a Tensor", + "Hints": [] + } + ] +} diff --git a/phivenv/Lib/site-packages/torch/_dynamo/graph_deduplication.py b/phivenv/Lib/site-packages/torch/_dynamo/graph_deduplication.py new file mode 100644 index 0000000000000000000000000000000000000000..b786d910ec1feb7f0857b35c5d1c2f0b92c12041 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/graph_deduplication.py @@ -0,0 +1,439 @@ +""" +This module implements graph deduplication functionality for TorchDynamo's optimization pipeline. +Graph deduplication identifies identical subgraphs in the computational graph and merges them +to reduce redundancy and improve performance. The process involves analyzing regions of the graph, +identifying structurally equivalent regions, and replacing them with a single shared implementation. +This optimization is particularly effective for models with repeated patterns or similar computational +structures across different parts of the network. 
+""" + +import logging +import operator +from collections import defaultdict +from collections.abc import Generator, Iterable +from typing import Optional + +import torch +import torch.fx +from torch._dynamo import config +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from .graph_region_tracker import Node, Region +from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique + + +# Represents an index into the region +# to select a node and then +# an index into that node's +# flattened arguments +UsageIndex = tuple[int, int] + +log = logging.getLogger(__name__) + +last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None + + +def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def] + """ + This is the main entry point for applying the graph deduplication pass. \ +Deduplication occurs in two phases: + 1. Subgraph creation: + Subgraph creation works by taking one representative region from each region \ +group and creating a subgraph from it, which will then be used to replace all regions \ +in the group. This is implemented by first copying all nodes of the region to the new \ +subgraph and then finding all inputs which are not within the region and creating placeholders \ +for them. For the outputs, all regions in a region group need to be scanned to ensure the \ +largest set of outputs is found, and then an output node is created which returns \ +a tuple of all outputs. + + 2. Graph replacement: + To replace each region with the extracted subgraph, the node index in the region \ +and argument index within the node's flattened args and kwargs are recorded once during \ +subgraph creation. This allows us to determine which (external to the region) nodes and \ +in which order these nodes are passed as inputs. For the outputs, getitem nodes are created \ +for each output, and all nodes in the region with external outputs are replaced by the proper \ +getitem node. Finally, all original nodes are erased (there should be no uses of these \ +left in the graph). + +The deduplication mutates the output_graph argument in place. + +Returns a mapping of nodes to their subgraph output replacement node to remap outputs +when they are created in output_graph. + """ + + duplicated_region_groups = output_graph.region_tracker.get_identical_regions( + output_graph.graph + ) + node_to_mutated_arg_positions = ( + output_graph.region_tracker.node_to_mutated_arg_positions + ) + node_to_additional_deps = _populate_additional_deps( + output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions + ) + + sub_gms: dict[str, torch.fx.GraphModule] = {} + + for region_group in duplicated_region_groups: + inds_with_external_users = _get_all_output_indices(region_group) + region = region_group[0] + ( + subgraph, + external_node_usages, + ) = _create_subgraph(region, inds_with_external_users) + + # Ignore regions with no args for now, could they possibly be evaluated at compile time? 
+ if not list(external_node_usages): + continue + + sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph) + subgraph_name = output_graph.install_subgraph("subgraph", sub_gm) + sub_gms[subgraph_name] = sub_gm + with output_graph.graph.inserting_before(): + get_subgraph_node = output_graph.graph.create_node( + "get_attr", subgraph_name, (), {} + ) + + for region in region_group: + _replace_region_with_subgraph( + output_graph.graph, + region, + get_subgraph_node, + external_node_usages, + inds_with_external_users, + subgraph_name, + node_to_additional_deps, + node_to_mutated_arg_positions, + ) + + # This is to expose the updated node_to_additional_deps to tests + global last_node_to_additional_deps + last_node_to_additional_deps = node_to_additional_deps + + _stable_topological_sort( + output_graph.graph, + node_to_additional_deps, + ) + return sub_gms + + +def _replace_region_with_subgraph( + graph: torch.fx.Graph, + region: Region, + get_subgraph_node: Node, + external_node_usages: Iterable[OrderedSet[UsageIndex]], + inds_with_external_users: list[int], + subgraph_name: str, + node_to_additional_deps: dict[Node, OrderedSet[Node]], + node_to_mutated_arg_positions: dict[Node, OrderedSet[int]], +) -> None: + sub_args = [] + for usages in external_node_usages: + node_ind, usage_ind = next(iter(usages)) + node = region[node_ind] + flattened_args_kwargs = _get_flat_args(node, {}) + for user_ind, node_usage_ind in usages: + user = region[user_ind] + if user in node_to_mutated_arg_positions: + if node_usage_ind in node_to_mutated_arg_positions[user]: + log.debug( + "NYI: Failed to substitute region %s due to mutation", region + ) + return + sub_args.append(flattened_args_kwargs[usage_ind]) + + # Input/Output aliasing not supported in HOPs today + # Note: we should use the nodes in the original graph (the region here) + # because we use the original traced example values for this check + if _has_aliasing(region, sub_args, inds_with_external_users): + return + + invoke_args = (get_subgraph_node, subgraph_name, *sub_args) + + invoke_subgraph_node = graph.create_node( + "call_function", + torch.ops.higher_order.invoke_subgraph, + invoke_args, # type: ignore[arg-type] + {}, + ) + for ind, external_user_ind in enumerate(inds_with_external_users): + node = region[external_user_ind] + subgraph_output = graph.create_node( + "call_function", operator.getitem, (invoke_subgraph_node, ind), {} + ) + node.replace_all_uses_with(subgraph_output, propagate_meta=True) + + # Erase in reverse topological order + for node in reversed(region): + graph.erase_node(node) + # Remove any nodes with additional deps + # This is safe; we've guaranteed that there is + # no input mutation, so all additional deps + # will be internal to the subgraph + node_to_additional_deps.pop(node, None) + for deps in node_to_additional_deps.values(): + try: + deps.remove(node) + deps.add(invoke_subgraph_node) + except KeyError: + pass + + if config.graph_deduplication_lint: + print(_detect_cycles(graph, node_to_additional_deps)) + _stable_topological_sort(graph, node_to_additional_deps) + graph.lint() + + +def _get_external_inputs( + region: Region, +) -> dict[Node, OrderedSet[UsageIndex]]: + external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet) + region_unique = set(region) + for node_ind, node in enumerate(region): + flattened_args_kwargs = _get_flat_args(node, {}) + for arg_ind, in_node in enumerate(flattened_args_kwargs): + if isinstance(in_node, Node) and in_node not in region_unique: + # in_node may 
occur in multiple nodes' flat_args + # track this so we can check if the arg is mutated + # Previously, we only needed to track one occurrence + # to be able to map that node to a placeholder + external_node_to_usages[in_node].add((node_ind, arg_ind)) + + return external_node_to_usages + + +def _get_all_output_indices(regions: list[Region]) -> list[int]: + # Scan all regions to get the set of all possible output nodes indices in the region + # perhaps we can record this information during region creation for more efficiency? + inds_with_external_users: set[int] = set() + for region in regions: + _get_inds_with_external_users(region, inds_with_external_users) + + return sorted(inds_with_external_users) + + +def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None: + for ind, node in enumerate(region): + for user in node.users: + if user not in region: + if ind not in inds_unique: + inds_unique.add(ind) + + +def _copy_nodes_and_remap_inputs( + subgraph: torch.fx.Graph, region: Region +) -> list[OrderedSet[UsageIndex]]: + external_input_to_usages = _get_external_inputs(region) + external_node_usages = list[OrderedSet[UsageIndex]]() + region_to_subgraph_node = {} + for node, usage_indices in external_input_to_usages.items(): + placeholder = subgraph.placeholder(f"subgraph_input_{node.name}") + region_to_subgraph_node[node] = placeholder + external_node_usages.append(usage_indices) + + def map_arg(node: Node) -> Node: + if node in region_to_subgraph_node: + return region_to_subgraph_node[node] + else: + return node + + for node in region: + subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old)) + region_to_subgraph_node[node] = subgraph_node + + return external_node_usages + + +def _create_subgraph_outputs( + subgraph: torch.fx.Graph, inds_to_output: list[int] +) -> None: + node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")] + out_tup = tuple(node_list[ind] for ind in inds_to_output) + subgraph.output(out_tup) + + +def _create_subgraph( + region: Region, + inds_with_external_users: list[int], +) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]: + subgraph: torch.fx.Graph = torch.fx.Graph() + external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region) + _create_subgraph_outputs(subgraph, inds_with_external_users) + return subgraph, external_node_usages + + +def _stable_topological_sort( + graph: torch.fx.Graph, + node_to_additional_deps: dict[Node, OrderedSet[Node]], +) -> None: + # Nodes are in exactly one of these four collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = OrderedSet[Node]() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # - `outputs` are always at the end of the graph + outputs = OrderedSet[Node]() + + # The cursor indicates the last processed node so we can add new nodes + # after it. + cursor = None + while pending: + node = pending.pop() + + if node.target == "output": + outputs.add(node) + assert not node.users, "output nodes should have no users" + continue + + waiting_for = [ + x + for x in _get_flat_args_unique(node, node_to_additional_deps) + if x not in ready + ] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. 
+ waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + ready.update(outputs) + assert not waiting and len(ready) == len(graph.nodes) + + +def _populate_additional_deps( + graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]] +) -> dict[Node, OrderedSet[Node]]: + node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet) + _add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps) + _add_global_state_dependencies(graph, node_to_additional_deps) + return node_to_additional_deps + + +def _add_global_state_dependencies( + graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]] +) -> None: + import torch.amp + + all_nodes = list(graph.nodes) + + # These are targets of the nodes which need to stay in the same relative place in the graph + global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast} + all_nodes_dep_on: list[Node] = [] + + def prev_cur_nodes( + all_nodes: list[Node], + ) -> Generator[tuple[list[Node], Node], None, None]: + prev_nodes: list[Node] = [] + next_nodes = list(reversed(all_nodes)) + + while next_nodes: + cur_node = next_nodes.pop() + yield prev_nodes, cur_node + prev_nodes.append(cur_node) + + for prev_nodes, cur_node in prev_cur_nodes(all_nodes): + args_unique = _get_flat_args_unique(cur_node, {}) + new_deps = [n for n in all_nodes_dep_on if n not in args_unique] + + if new_deps: + additional_deps = node_to_additional_deps[cur_node] + additional_deps.update(new_deps) + + if cur_node.target in global_state_targets: + additional_deps = node_to_additional_deps[cur_node] + additional_deps.update(n for n in prev_nodes if n not in args_unique) + all_nodes_dep_on.append(cur_node) + + +def _add_mutation_dependencies( + node_to_mutated_arg_positions: dict[Node, OrderedSet[int]], + node_to_additional_deps: dict[Node, OrderedSet[Node]], +) -> None: + for node, indices in node_to_mutated_arg_positions.items(): + flat_args_kwargs = _get_flat_args(node, {}) + + # for all mutated args, + # add dependency on usages which occur after node to ensure + # node will always be ordered before them + # also add node as a dependency on usages which + # occur before node to ensure node is ordered after them + for index in indices: + mutated_arg = flat_args_kwargs[index] + for user in mutated_arg.users: + if user is node: + continue + elif user < node: + node_to_additional_deps[node].add(user) + elif user > node: + node_to_additional_deps[user].add(node) + + +def _has_aliasing( + region: Region, inputs: list[Node], inds_with_external_users: list[int] +) -> bool: + input_storages: dict[StorageWeakRef, Node] = dict() + + for node in inputs: + example_value = node.meta["example_value"] + if isinstance(example_value, torch.Tensor): + storage = StorageWeakRef(example_value._typed_storage()) + if storage in input_storages: + # input-input aliasing + log.debug( + "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s, %s", + region, + input_storages[storage], + node, + ) + return True + input_storages[storage] = node + + output_storages: dict[StorageWeakRef, Node] = dict() + for i in inds_with_external_users: + out_node = region[i] + if out_node: + example_value = out_node.meta["example_value"] + assert not 
isinstance(example_value, list) + if isinstance(example_value, torch.Tensor): + storage = StorageWeakRef(example_value._typed_storage()) + if storage in output_storages: + # output-output aliasing + log.debug( + "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s", + region, + output_storages[storage], + out_node, + ) + return True + output_storages[storage] = out_node + + intersected_storages = input_storages.keys() & output_storages.keys() + if len(intersected_storages) > 0: + # input-output aliasing + aliased = [ + (input_storages[s], output_storages[s]) for s in intersected_storages + ] + aliased = ", ".join([f"{i} and {o}" for i, o in aliased]) + log.debug( + "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s", + region, + aliased, + ) + return True + + return False diff --git a/phivenv/Lib/site-packages/torch/_dynamo/graph_region_tracker.py b/phivenv/Lib/site-packages/torch/_dynamo/graph_region_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..1b53ceb70bbccec1632caa8ad24688ed30780807 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/graph_region_tracker.py @@ -0,0 +1,465 @@ +""" +This module provides functionality for tracking and managing regions in computational graphs. +It supports graph optimization by identifying and grouping similar regions based on their +structure and behavior. The module implements algorithms for: + +1. Tracking nodes and their relationships in the computational graph +2. Identifying identical or similar regions across the graph +3. Managing graph regions for optimization purposes +4. Supporting deduplication and other graph transformation passes + +The core functionality revolves around the GraphRegionTracker class which maintains +mappings between nodes and their duplicates, enabling efficient graph analysis and +optimization operations. 
+""" + +import copyreg +import io +import logging +import math +import operator +import pickle +from collections import defaultdict, deque +from dataclasses import fields +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar + +import torch._logging +import torch.fx +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_flatten + +from .graph_utils import _get_flat_args_unique + + +T = TypeVar("T") + + +if TYPE_CHECKING: + from .symbolic_convert import InstructionTranslatorBase + + +Node = torch.fx.Node +Region = list[Node] +IdenticalNodes = list[Node] +GlobalStateKey = tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool] + +log = logging.getLogger(__name__) +graph_expansion_log = torch._logging.getArtifactLogger( + __name__, "graph_region_expansion" +) + + +def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def] + graph_expansion_log.debug(msg, *args) + + +def _extract_tensor_metadata_for_node_hash( + x: torch.Tensor, +) -> tuple[Callable[[T], T], tuple[Any, ...]]: + from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key + + out = [] + metadata = extract_tensor_metadata_for_cache_key(x) + for field in fields(metadata): + out.append(getattr(metadata, field.name)) + + return (_ident, tuple(out)) + + +class NodeHashException(Exception): + pass + + +class InputPickler(pickle.Pickler): + def __init__(self) -> None: + from torch._inductor.codecache import _ident + + stream = io.BytesIO() + self._stream = stream + super().__init__(stream) + self.dispatch_table = copyreg.dispatch_table.copy() + self.dispatch_table.update( + { + FakeTensor: _extract_tensor_metadata_for_node_hash, + torch.SymInt: lambda x: (_ident, (str(x),)), + torch.SymBool: lambda x: (_ident, (str(x),)), + torch.SymFloat: lambda x: (_ident, (str(x),)), + } + ) + self.fast = True + + def dumps(self, obj: Any) -> bytes: + """ + Pickle an object and return a byte string. 
+ """ + try: + self.dump(obj) + return self._stream.getvalue() + except (TypeError, AttributeError) as e: + raise NodeHashException from e + finally: + self._stream.seek(0) + self._stream.truncate(0) + + +def _extract_args(arg: Any) -> Any: + if isinstance(arg, Node): + return arg.meta.get("example_value") + elif isinstance(arg, (torch.Tensor, int)): + return arg + else: + return None + + +def _normalize_args( + node: Node, +) -> tuple[tuple[str, ...], tuple[Optional[Any], ...]]: + flat_args, _ = tree_flatten(node.args) + sorted_kwargs = sorted(node.kwargs.items(), key=operator.itemgetter(0)) + sorted_keys = tuple(sorted(node.kwargs.keys())) + flat_kwargs, _ = tree_flatten(sorted_kwargs) + all_args = flat_args + flat_kwargs + return (sorted_keys, tuple(_extract_args(arg) for arg in all_args)) + + +def get_global_state_key() -> GlobalStateKey: + return ( + torch.is_grad_enabled(), + torch.is_inference_mode_enabled(), + torch.get_num_threads(), + torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), + torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), + torch.get_default_dtype(), + torch.are_deterministic_algorithms_enabled(), + torch._C._get_cublas_allow_tf32(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch._C._autograd._saved_tensors_hooks_is_enabled(), # type: ignore[attr-defined] + ) + + +# This is typical BFS with the caveat +# that a node's children need to be explicitly +# added with the add_children() method +# The flow is yield a node and check if it's valid for all regions +# if not valid, discard and continue onto the next node +# Note: this iterates backward through the graph by looking at args/kwargs +# of a node +class BackwardBfsArgIter: + def __init__(self, origin: Node) -> None: + self._cur: Optional[Node] = origin + self._queue: deque[Optional[Node]] = deque() + + @staticmethod + def create(origin: Node) -> "BackwardBfsArgIter": + it = BackwardBfsArgIter(origin) + it.add_children(origin) + # pop the origin node, since it is the origin of + # the region and does not need to be considered for addition + assert it.next() + return it + + def next(self) -> Optional[Node]: + ret = self._cur + if not self._queue: + self._cur = None + else: + self._cur = self._queue.popleft() + return ret + + def peek(self) -> Optional[Node]: + return self._cur + + def add_children(self, node: Node) -> None: + flat_args = _get_flat_args_unique(node, {}) + for arg in flat_args: + if isinstance(arg, Node): + self._append(arg) + + def _append(self, arg: Node) -> None: + if self._cur is None: + self._cur = arg + else: + self._queue.append(arg) + + def __str__(self) -> str: + return f"BackwardBfsArgIter(cur={self._cur}, queue={self._queue})" + + +class GraphRegionTracker: + """ + GraphRegionTracker tracks each node added to the output graph and generates a key based on the source location, + instruction pointer, input shapes, and global state at the time the node is inserted into the graph. Nodes with + the same key are grouped together in a list of identical nodes (the value of node_to_duplicates). 
+ + hash_to_duplicates: Dict[str, IdenticalNodes] - A dictionary mapping the key to a list of identical nodes + node_to_duplicates: Dict[Node, IdenticalNodes] - A dictionary mapping a node to the list of identical nodes it belongs to + input_pickler: InputPickler - An instance of InputPickler used to generate a node hash + """ + + def __init__(self) -> None: + self.hash_to_duplicates: dict[str, IdenticalNodes] = defaultdict(list) + self.node_to_duplicates: dict[Node, IdenticalNodes] = {} + # Note: position is in flattened args/kwargs list + self.node_to_mutated_arg_positions: dict[Node, OrderedSet[int]] = {} + self.input_pickler = InputPickler() + + def _hash_node( + self, filename: str, lineno: int, instruction_pointer: Optional[int], node: Node + ) -> str: + from torch._inductor.codecache import sha256_hash + + key = ( + get_global_state_key(), + filename, + lineno, + instruction_pointer, + _normalize_args(node), + ) + return sha256_hash(self.input_pickler.dumps(key)) + + def _is_identical(self, n0: Node, n1: Node) -> bool: + return ( + n0 in self.node_to_duplicates + and n1 in self.node_to_duplicates + and self.node_to_duplicates[n0] is self.node_to_duplicates[n1] + and n0 is not n1 + ) + + def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None: + """ + The main entry point for tracking a node. This function will hash the node argument and group + nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries + to track the new node. + """ + try: + duplicates = self.hash_to_duplicates[ + self._hash_node( + tx.f_code.co_filename, tx.lineno, tx.instruction_pointer, node + ) + ] + duplicates.append(node) + self.node_to_duplicates[node] = duplicates + except NodeHashException as e: + log.debug("Unable to hash node %s with exception %s", node, e) + + def track_node_mutations( + self, + node: Node, + flat_args_kwargs: list[Any], + id_to_initial_version: dict[int, int], + ) -> None: + """ + This function tracks which argument positions are mutated by the given node. Subgraph HOP does not support + input mutations today so we will skip regions which have inputs that are mutated. + """ + mutated_arg_positions = OrderedSet[int]() + for i, arg in enumerate(flat_args_kwargs): + val_id = id(arg) + if ( + val_id in id_to_initial_version + and id_to_initial_version[val_id] != arg._version + ): + mutated_arg_positions.add(i) + + if mutated_arg_positions: + self.node_to_mutated_arg_positions[node] = mutated_arg_positions + + def add_node_mutation( + self, + node: Node, + arg_pos: int, + ) -> None: + if node in self.node_to_mutated_arg_positions: + self.node_to_mutated_arg_positions[node].add(arg_pos) + else: + self.node_to_mutated_arg_positions[node] = OrderedSet([arg_pos]) + + def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: + """ + This function is responsible for extracting the largest regions of identical nodes from the given graph. + **Note**: This function assumes the nodes that have been tracked with track_node are in the provided graph argument. + + The algorithm proceeds as follows: + The nodes tracked via track_node above are organized into region groups. The initial region groups look like this: + [[IdenticalNode1], [IdenticalNode2], [IdenticalNode3]] and each sublist is called a region. 
For each region group + (starting at the topologically latest region group), the inner regions are gradually expanded one node at a time from + the flattened args and kwargs of the node in each region, provided that for all regions in the group, the nodes being + added are also identical (i.e. have the same key computed by track_node). This is checked by verifying that the two + nodes have the same identical node list in node_to_duplicates. + """ + topological_ranking = {node: i for i, node in enumerate(graph.nodes)} + region_groups_with_rank = [] + # needed to detect if replacing a region will create cycles + node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph) + + # Create region groups; a region group is a group + # of regions that are all identical. In this initial state, + # each region in the group is a single node, and we discard + # groups that are only a single region. + # We track the topological ranking to start with groups later in the graph; + # the reason for this is that we will necessarily create the largest groups first. + for group in self.hash_to_duplicates.values(): + if len(group) > 1: + region_group = [] + min_rank = math.inf + for node in group: + # some nodes aren't in the topo ranking? + if node in topological_ranking: + min_rank = min(min_rank, topological_ranking[node]) + region_group.append([node]) + + if len(region_group) > 1: + region_groups_with_rank.append((region_group, min_rank)) + + region_groups_with_rank.sort(key=lambda rg: -rg[1]) + region_groups = [rg for rg, _ in region_groups_with_rank] + + # We start from regions later in the graph and expand them earlier; + # as a result, we will create the largest regions first and they won't + # overlap. + seen_nodes: set[Node] = set() + for region_group in region_groups: + fully_expand_region_group( + region_group, + seen_nodes, + node_to_recursive_ancestors, + self._is_identical, + ) + # sort topologically + for region in region_group: + region.sort(key=lambda n: topological_ranking[n]) + + return [ + region_group for region_group in region_groups if len(region_group[0]) > 1 + ] + + def __str__(self) -> str: + return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})" + + +class RegionWrapper: + """Holds state for regions e.g. 
ancestors and new candidate nodes for consideration""" + + def __init__( + self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]] + ) -> None: + assert len(region) == 1, "all regions should start with one node" + node = region[0] + self.node_to_recursive_ancestors = node_to_recursive_ancestors + self.iter = BackwardBfsArgIter.create(node) + self.nodes_unique = OrderedSet([node]) + self.ancestors = set(node_to_recursive_ancestors[node]) + self.region = region + + def next_candidate(self) -> Optional[Node]: + return self.iter.next() + + def will_inclusion_create_cycle(self, node: Node) -> bool: + external_users = [user for user in node.users if user not in self.nodes_unique] + for user in external_users: + if user in self.ancestors: + return True + + return False + + def add(self, node: Node) -> None: + self.nodes_unique.add(node) + self.region.append(node) + self.iter.add_children(node) + self.ancestors.update(self.node_to_recursive_ancestors[node]) + + +def fully_expand_region_group( + regions: list[Region], + seen_nodes: set[Node], + node_to_recursive_ancestors: dict[Node, set[Node]], + is_identical_fn: Callable[[Node, Node], bool], +) -> None: + debug_log("--------------------------------------------------") + debug_log("expanding new region group: %s", regions) + + # All regions should start with 1 node + assert all(len(region) == 1 for region in regions) + region_wrappers = [ + RegionWrapper(region, node_to_recursive_ancestors) for region in regions + ] + + nodes_to_add = OrderedSet[Node]() + current_node = region_wrappers[0].next_candidate() + + # No children + if current_node is None: + return + + # Loop incrementally adding new nodes to each region + # regions are only expanded if the node to add is valid + # for ALL regions + while current_node: + add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle( + current_node + ) + nodes_to_add.clear() + nodes_to_add.add(current_node) + for region_wrapper in region_wrappers[1:]: + candidate = region_wrapper.next_candidate() + + debug_log("--------------------") + debug_log( + "considering candidate: %s, cur_node: %s", candidate, current_node + ) + + if not candidate or not add_to_all_regions: + add_to_all_regions = False + continue + + debug_log( + "candidate in previously claimed nodes?: %s", candidate in seen_nodes + ) + debug_log("is_identical: %s", is_identical_fn(candidate, current_node)) + + add_to_all_regions &= ( + candidate not in seen_nodes + and candidate not in nodes_to_add + and candidate.op != "placeholder" + and is_identical_fn(candidate, current_node) + and not region_wrapper.will_inclusion_create_cycle(candidate) + ) + nodes_to_add.add(candidate) + + debug_log(f"add_to_all_regions: {add_to_all_regions}") + debug_log("--------------------") + + if add_to_all_regions: + assert len(region_wrappers) == len(nodes_to_add), ( + "Number of nodes to add must equal the number of regions" + ) + for region_wrapper, node in zip(region_wrappers, nodes_to_add): + region_wrapper.add(node) + debug_log("adding %s's children", node) + debug_log("%s %s", node.args, list(node.kwargs.items())) + seen_nodes.add(node) + + current_node = region_wrappers[0].next_candidate() + + # Ensure regions are sorted in topological order + for region in regions: + region.reverse() + + debug_log("end expand new region group: %s", regions) + debug_log("--------------------------------------------------") + + +def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]: + node_to_recursive_ancestors: 
dict[Node, set[Node]] = {} + for node in graph.nodes: + node_to_recursive_ancestors[node] = set() + for node in graph.nodes: + all_args = _get_flat_args_unique(node, {}) + for arg in all_args: + if isinstance(arg, Node): + node_to_recursive_ancestors[node].update( + node_to_recursive_ancestors[arg] + ) + node_to_recursive_ancestors[node].add(arg) + return node_to_recursive_ancestors diff --git a/phivenv/Lib/site-packages/torch/_dynamo/graph_utils.py b/phivenv/Lib/site-packages/torch/_dynamo/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..225a23d87d8c7b12c4519255756090a0b2caf7c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/graph_utils.py @@ -0,0 +1,77 @@ +from collections import deque +from typing import Any + +from torch.fx import Graph, map_arg, Node +from torch.utils._ordered_set import OrderedSet + + +# flattens with support for slices +# Note: a better way to do this would +# be register/unregister slices as pytree nodes +# but there is no unregister API in the pytorch +# pytree impl +def _get_flat_args( + node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] +) -> list[Node]: + args = list[Any]() + map_arg((node.args, node.kwargs), args.append) + if node in node_to_additional_deps: + args.extend(node_to_additional_deps[node]) + return args + + +def _get_flat_args_unique( + node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] +) -> OrderedSet[Node]: + args = OrderedSet[Node]() + map_arg((node.args, node.kwargs), args.add) + if node in node_to_additional_deps: + args.update(node_to_additional_deps[node]) + return args + + +def _detect_cycles( + graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]] +) -> str: + current_path: deque[Node] = deque() + current_path_set: set[Node] = set() + pending: deque[tuple[Node, Node]] = deque() + + def add_to_current_path(node: Node) -> None: + current_path.append(node) + current_path_set.add(node) + + def pop_current_path() -> None: + node = current_path.pop() + current_path_set.remove(node) + + def current_path_head() -> Node: + return current_path[-1] + + for origin in graph.find_nodes(op="output"): + current_path.clear() + current_path_set.clear() + add_to_current_path(origin) + for child in _get_flat_args_unique(origin, node_to_additional_deps): + pending.append((child, origin)) + + while pending: + cur_node, parent = pending.pop() + + # handle backtracking + while current_path and current_path_head() != parent: + pop_current_path() + + if not isinstance(cur_node, Node): + continue + + if cur_node in current_path_set: + current_path.append(cur_node) + return f"cycle detected in path: {current_path}" + + add_to_current_path(cur_node) + + for child in _get_flat_args_unique(cur_node, node_to_additional_deps): + pending.append((child, cur_node)) + + return "no cycle detected" diff --git a/phivenv/Lib/site-packages/torch/_dynamo/guards.py b/phivenv/Lib/site-packages/torch/_dynamo/guards.py new file mode 100644 index 0000000000000000000000000000000000000000..c396b9233aa5378b15385de106a100e2ead23831 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/guards.py @@ -0,0 +1,3633 @@ +# mypy: allow-untyped-defs + +""" +Core guard system for Dynamo that detects when compiled code needs to be recompiled due to +changes in program state. Guards are conditions that must remain true for previously-compiled +code to be valid for reuse. 
+ +This module provides the infrastructure for creating, managing and checking guards, including: +- Guard creation and composition +- Guard state management and invalidation +- Guard checking and failure handling +- Utilities for guard optimization and debugging +- Integration with Dynamo's compilation caching + +The guard system is critical for Dynamo's ability to efficiently reuse compiled code while +maintaining correctness by detecting when recompilation is necessary due to changes in +program state, tensor properties, or control flow. +""" + +from __future__ import annotations + +import ast +import builtins +import collections +import dataclasses +import enum +import functools +import importlib +import inspect +import io +import logging +import math +import pickle +import sys +import textwrap +import types +import warnings +import weakref +from contextlib import contextmanager +from copy import deepcopy +from inspect import currentframe +from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union +from weakref import ReferenceType + +import torch +import torch.overrides +import torch.utils._device +from torch._C._dynamo.eval_frame import code_framelocals_names +from torch._C._dynamo.guards import ( + check_obj_id, + check_type_id, + dict_version, + DictGuardManager, + install_no_tensor_aliasing_guard, + install_object_aliasing_guard, + install_storage_overlapping_guard, + install_symbolic_shape_guard, + profile_guard_manager, + RootGuardManager, +) +from torch._dynamo.source import ( + get_global_source_name, + get_local_source_name, + IndexedSource, + is_from_flatten_script_object_source, + is_from_local_source, + is_from_optimizer_source, + TensorProperty, + TensorPropertySource, +) +from torch._dynamo.utils import CompileEventLogger, get_metrics_context +from torch._guards import ( + CompileContext, + CompileId, + DuplicateInputs, + Guard, + GuardBuilderBase, + GuardEnvExpr, + GuardSource, + Source, + StorageOverlap, +) +from torch._logging import structured +from torch._utils_internal import justknobs_check +from torch.fx.experimental.symbolic_shapes import ( + _CppShapeGuardsHelper, + _ShapeGuardsHelper, + EqualityConstraint, + is_symbolic, + SYMPY_INTERP, +) +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._traceback import format_frame, report_compile_source_on_error +from torch.utils.weak import TensorWeakRef + +from . 
import config, convert_frame, exc +from .eval_frame import set_guard_error_hook +from .source import ( + AttrProxySource, + AttrSource, + CallFunctionNoArgsSource, + CallMethodItemSource, + ChainedSource, + ConstantSource, + ConstDictKeySource, + DataclassFieldsSource, + DefaultsSource, + DictGetItemSource, + DictSubclassGetItemSource, + FlattenScriptObjectSource, + FloatTensorSource, + FSDPNNModuleSource, + GenericAttrSource, + GetItemSource, + GlobalSource, + GlobalStateSource, + GlobalWeakRefSource, + GradSource, + ListGetItemSource, + LocalSource, + NNModuleSource, + NumpyTensorSource, + OptimizerSource, + ScriptObjectQualifiedNameSource, + ShapeEnvSource, + SubclassAttrListSource, + TorchFunctionModeStackSource, + TupleIteratorGetItemSource, + TypeSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, + UnspecializedParamBufferSource, + WeakRefCallSource, +) +from .types import ( # noqa: F401 + CacheEntry, + DynamoFrameType, + ExtraState, + GuardedCode, + GuardFail, + GuardFilterEntry, + GuardFn, +) +from .utils import ( + builtin_dict_keys, + common_constant_types, + dataclass_fields, + dict_keys, + get_custom_getattr, + get_torch_function_mode_stack, + get_torch_function_mode_stack_at, + guard_failures, + istype, + key_is_id, + key_to_id, + normalize_range_iter, + orig_code_map, + tensor_always_has_static_shape, + tuple_iterator_getitem, + tuple_iterator_len, + unpatched_nn_module_getattr, + verify_guard_fn_signature, +) + + +guard_manager_testing_hook_fn: Optional[Callable[[Any, Any], Any]] = None + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from sympy import Symbol + + from torch._dynamo.output_graph import OutputGraphGuardsState + + +log = logging.getLogger(__name__) +guards_log = torch._logging.getArtifactLogger(__name__, "guards") +recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") +recompiles_verbose_log = torch._logging.getArtifactLogger( + __name__, "recompiles_verbose" +) +verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") + + +class GuardManagerWrapper: + """ + A helper class that contains the root guard manager. An instance of this + class is stored in the Dynamo cache entry, so that the cache entry can + access the RootGuardManager stored in the "root" attribute and directly call + the check_nopybind from C++. + """ + + def __init__(self, root=None): + if root is None: + self.root = RootGuardManager() + else: + self.root = root + + self.diff_guard_root = None + self.closure_vars = None + self.args = None + self.code_parts = [] + self.verbose_code_parts = None + self.global_scope = None + self.guard_fail_fn = None + self.cache_entry = None + self.extra_state = None + self.id_matched_objs = {} + self.no_tensor_aliasing_sources = [] + + self.printed_relational_guards = set() + + self.diff_guard_sources: OrderedSet[str] = OrderedSet() + + @contextmanager + def _preserve_printed_relational_guards(self): + self.printed_relational_guards = set() + try: + yield + finally: + self.printed_relational_guards = set() + + def collect_diff_guard_sources(self): + # At the time of finalize, we have only marked guard managers with + # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal + # and collect all the nodes in the tree (branches) that lead to tensor + # guards. + + # After a recompilation, some of guard managers will have a fail_count > + # 0, so we collect them as well. 
Later on, we accumulate the diff guard + # sources for all the guard managers. + + def visit_dict_manager(node): + is_diff_guard_node = ( + node.get_source() in self.diff_guard_sources or node.fail_count() > 0 + ) + for idx, (key_mgr, val_mgr) in sorted( + node.get_key_value_managers().items() + ): + is_diff_guard_node |= visit(key_mgr) | visit(val_mgr) + + if is_diff_guard_node: + self.diff_guard_sources.add(node.get_source()) + + return is_diff_guard_node + + def visit_manager(node): + assert not isinstance(node, DictGuardManager) + + is_diff_guard_node = ( + node.get_source() in self.diff_guard_sources or node.fail_count() > 0 + ) + for child_mgr in node.get_child_managers(): + is_diff_guard_node |= visit(child_mgr) + + if is_diff_guard_node: + self.diff_guard_sources.add(node.get_source()) + + return is_diff_guard_node + + def visit(node): + if node is None: + return False + if isinstance(node, DictGuardManager): + return visit_dict_manager(node) + return visit_manager(node) + + visit(self.root) + + return self.diff_guard_sources + + def finalize(self): + self.collect_diff_guard_sources() + self.populate_diff_guard_manager() + + def populate_diff_guard_manager(self): + self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) + + # Ensure that the C++ side points to the updated diff guard manager. + # When a new GuardManagerWrapper is created, it does not have a + # cache_entry attribute, so it relies on the CacheEntry constructor to + # set the diff_guard_root in C++. But once it is saved in the Dynamo + # cache, the C++ side adds a cache_entry attribute. On recompiles, this + # cache_entry is visible, so we update the C++ side to point to the + # updated guard manager. + if self.cache_entry: + self.cache_entry.update_diff_guard_root_manager() + + def clone_with_chosen_sources(self, chosen_sources): + def filter_fn(node_mgr): + return node_mgr.get_source() in chosen_sources + + return self.root.clone_manager(filter_fn) + + def get_guard_lines(self, guard): + guard_name = guard.__class__.__name__ + parts = guard.verbose_code_parts() + parts = [guard_name + ": " + part for part in parts] + return parts + + def get_manager_line(self, guard_manager, accessor_str=None): + source = guard_manager.get_source() + t = guard_manager.__class__.__name__ + s = t + ": source=" + source + if accessor_str: + s += ", " + accessor_str + return s + + def construct_dict_manager_string(self, mgr, body): + for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): + body.writeline(f"KeyValueManager pair at index={idx}") + with body.indent(): + if key_mgr: + body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}") + self.construct_manager_string(key_mgr, body) + + if val_mgr: + body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") + self.construct_manager_string(val_mgr, body) + + def construct_manager_string(self, mgr, body): + with body.indent(): + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if guard not in self.printed_relational_guards: + self.printed_relational_guards.add(guard) + body.writelines(self.get_guard_lines(guard)) + else: + body.writelines( + [ + guard.__class__.__name__, + ] + ) + else: + body.writelines(self.get_guard_lines(guard)) + + # This works for both DictGuardManager and SubclassedDictGuardManager + if isinstance(mgr, DictGuardManager): + self.construct_dict_manager_string(mgr, body) + + # General case of GuardManager/RootGuardManager + for 
accessor, child_mgr in zip( + mgr.get_accessors(), mgr.get_child_managers() + ): + body.writeline( + self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}") + ) + self.construct_manager_string(child_mgr, body) + + def __str__(self): + from torch._inductor.utils import IndentedBuffer + + class IndentedBufferWithPrefix(IndentedBuffer): + def prefix(self): + return "| " * (self._indent * self.tabwidth) + + def writeline(self, line, skip_prefix=False): + if skip_prefix: + super().writeline(line) + else: + super().writeline("+- " + line) + + with self._preserve_printed_relational_guards(): + body = IndentedBufferWithPrefix() + body.tabwidth = 1 + body.writeline("", skip_prefix=True) + body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True) + body.writeline("RootGuardManager") + self.construct_manager_string(self.root, body) + if hasattr(self.root, "get_epilogue_lambda_guards"): + for guard in self.root.get_epilogue_lambda_guards(): + body.writelines(self.get_guard_lines(guard)) + return body.getvalue() + + def check(self, x): + # Only needed for debugging purposes. + return self.root.check(x) + + def check_verbose(self, x): + # Only needed for debugging purposes. + return self.root.check_verbose(x) + + def populate_code_parts_for_debugging(self): + # This should be called when the guard manager is fully populated + relational_guards_seen = set() + + def get_code_parts(leaf_guard): + code_parts = [] + for verbose_code_part in leaf_guard.verbose_code_parts(): + code_part = verbose_code_part.split("#")[0].rstrip() + code_parts.append(code_part) + return code_parts + + def visit(mgr): + nonlocal relational_guards_seen + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] + if guard not in relational_guards_seen: + self.code_parts.extend(get_code_parts(guard)) + relational_guards_seen.add(guard) + else: + self.code_parts.extend(get_code_parts(guard)) + + for child_mgr in mgr.get_child_managers(): + visit(child_mgr) + + visit(self.root) + + +def from_numpy(a): + # If not numpy array, piggy back on e.g. 
tensor guards to check type + # Re-enable torch function since we disable it on leaf guards + # we need it to properly construct the tensor if a default device is set + with torch.overrides._enable_torch_function(): + return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a + + +# For user stack printing +@functools.cache +def uninteresting_files(): + import torch._dynamo.external_utils + import torch._dynamo.polyfills + + mods = [torch._dynamo.external_utils, torch._dynamo.polyfills] + + from torch._dynamo.polyfills.loader import POLYFILLED_MODULES + + mods.extend(POLYFILLED_MODULES) + + return {inspect.getfile(m) for m in mods} + + +_CLOSURE_VARS: Optional[dict[str, object]] = None + + +def _get_closure_vars(): + global _CLOSURE_VARS + if _CLOSURE_VARS is None: + _CLOSURE_VARS = { + "___check_type_id": check_type_id, + "___check_obj_id": check_obj_id, + "___odict_getitem": collections.OrderedDict.__getitem__, + "___key_to_id": key_to_id, + "___dict_version": dict_version, + "___dict_contains": lambda a, b: dict.__contains__(b, a), + "___tuple_iterator_len": tuple_iterator_len, + "___normalize_range_iter": normalize_range_iter, + "___tuple_iterator_getitem": tuple_iterator_getitem, + "___dataclass_fields": dataclass_fields, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, + "__math_isnan": math.isnan, + "__numpy_isnan": None if np is None else np.isnan, + "inf": float("inf"), + "__load_module": importlib.import_module, + "utils_device": torch.utils._device, + "device": torch.device, + "___from_numpy": from_numpy, + "___as_tensor": torch._as_tensor_fullprec, + "torch": torch, + "inspect": inspect, + } + return _CLOSURE_VARS + + +def _ast_unparse(node: ast.AST) -> str: + return ast.unparse(node).replace("\n", "") + + +strip_function_call = torch._C._dynamo.strip_function_call + + +def get_verbose_code_part(code_part: str, guard: Guard) -> str: + extra = "" + if guard is not None: + if guard.user_stack: + for fs in reversed(guard.user_stack): + if fs.filename not in uninteresting_files(): + extra = f" # {format_frame(fs, line=True)}" + break + elif guard.stack: + summary = guard.stack.summary() + if len(summary) > 0: + extra = f" # {format_frame(summary[-1])}" + else: + extra = " # " + return f"{code_part:<60}{extra}" + + +def get_verbose_code_parts( + code_parts: Union[str | list[str]], guard: Guard +) -> list[str]: + if not isinstance(code_parts, list): + code_parts = [code_parts] + return [get_verbose_code_part(code_part, guard) for code_part in code_parts] + + +def convert_int_to_concrete_values(dim) -> Optional[int]: + if dim is None: + return None + if not is_symbolic(dim): + return dim + else: + assert isinstance(dim, torch.SymInt) + return dim.node.maybe_as_int() + + +def convert_to_concrete_values(size_or_stride): + return [convert_int_to_concrete_values(dim) for dim in size_or_stride] + + +def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys): + dispatch_key = ( + dispatch_keys | torch._C._dispatch_tls_local_include_set() + ) - torch._C._dispatch_tls_local_exclude_set() + dtype = value.dtype + device_index = value.device.index + requires_grad = value.requires_grad + guard_str = ( + f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, " + f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})" + ) + return guard_str + + +def get_key_index(dct, key): + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). 
In the C++ guards, we relied on PyDict_Next + # to traverse the dictionary, which uses the internal data structure and + # does not call the overridden keys method. + return list(builtin_dict_keys(dct)).index(key) + + +def get_key_index_source(source, index): + return f"list(dict.keys({source}))[{index}]" + + +def raise_local_type_error(obj: Any) -> NoReturn: + raise TypeError( + f"Type {type(obj)} for object {obj} cannot be saved " + + "into torch.compile() package since it's defined in local scope. " + + "Please define the class at global scope (top level of a module)." + ) + + +@dataclasses.dataclass(frozen=True) +class NNModuleAttrAccessorInfo: + # Represents where the attr name is present in the nn module attribute + # access + + # Indicates whether the attribute can be accessed via __dict__ + present_in_generic_dict: bool = False + + # Either the actual name or _parameters/_buffers/_modules + l1_key: Optional[str] = None + + # Actual parameter/buffer/submodule name + l2_key: Optional[str] = None + + +def getitem_on_dict_manager( + source, base_guard_manager, base_example_value, example_value, guard_manager_enum +): + base_source_name = source.base.name() + if isinstance(source.index, ConstDictKeySource): + index = source.index.index + else: + assert isinstance(base_example_value, dict) + index = get_key_index(base_example_value, source.index) + + key_source = get_key_index_source(base_source_name, index) + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on PyDict_Next + # to traverse the dictionary, which uses the internal data structure and + # does not call the overridden keys method. + key_example_value = list(builtin_dict_keys(base_example_value))[index] + if isinstance(key_example_value, (int, str)): + value_source = f"{base_source_name}[{key_example_value!r}]" + else: + value_source = f"{base_source_name}[{key_source}]" + if not isinstance(source.index, ConstDictKeySource): + # We have to insert a key manager guard here + # TODO - source debug string is probably wrong here. + base_guard_manager.get_key_manager( + index=index, + source=key_source, + example_value=source.index, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).add_equals_match_guard( + source.index, [f"{key_source} == {key_example_value!r}"] + ) + + return base_guard_manager.get_value_manager( + index=index, + source=value_source, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + + +def match_on_id_for_tensor(guard): + source = guard.originating_source + # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads + # to a new tensor every time and therefore id differs. 
+ if isinstance(source, NumpyTensorSource): + return False + + if guard.is_specialized_nn_module(): + return True + + return source.is_dict_key() and not isinstance(source, GradSource) + + +# The ready to eval generated code (possibly multiple parts) for a guard, plus +# the original guard object that created it for provenance +@dataclasses.dataclass +class GuardCodeList: + code_list: list[str] + guard: Guard + + +class GuardManagerType(enum.Enum): + GUARD_MANAGER = 1 + DICT_GUARD_MANAGER = 2 + + +@functools.cache +def code_framelocals_names_reversed_cached(code: types.CodeType): + return list(reversed(code_framelocals_names(code))) + + +class GuardBuilder(GuardBuilderBase): + def __init__( + self, + f_code: types.CodeType, + id_ref: Callable[[Any, str], str], + source_ref: Callable[[Source], str], + lookup_weakrefs: Callable[[object], ReferenceType[object]], + local_scope: dict[str, object], + global_scope: dict[str, object], + guard_manager: GuardManagerWrapper, + check_fn_manager: CheckFunctionManager, + serialization_mode: Optional[str] = None, + ): + self.f_code = f_code + self.id_ref = id_ref + self.source_ref = source_ref + self.lookup_weakrefs = lookup_weakrefs + self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope} + self.scope["__builtins__"] = builtins.__dict__.copy() + for ( + name, + package_module, + ) in torch.package.package_importer._package_imported_modules.items(): + name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + # Write the package module into the scope so that we can import it + self.scope["__builtins__"][name] = package_module + # Write the demangled name to the scope so that we can use it + self.scope[name] = package_module + self.guard_manager = guard_manager + + self.argnames: list[str] = [] + # Code is python expression strings generated for each guard + self.code: list[GuardCodeList] = [] + # shape_env_code is only used by builder and is used for + # shape env code. This exists only because we need to make sure + # shape env guards get run after tensor match guards (since the + # tensor match guards make sure we actually have tensors) + self.shape_env_code: list[GuardCodeList] = [] + + # Collect the guard managers and debug info to insert no tensor aliasing + # guards. + self.no_tensor_aliasing_names: list[str] = [] + self.no_tensor_aliasing_guard_managers: list[GuardManagerWrapper] = [] + + self.check_fn_manager: CheckFunctionManager = check_fn_manager + + # Collect the ids of dicts which need key order guarding. source_name is + # not sufficient because for nn modules, we can have different sources + # to access the same object - self._module["param"] is same as + # self.param. + self.key_order_guarded_dict_ids = set() + for source in self.check_fn_manager.output_graph.guard_on_key_order: + self.key_order_guarded_dict_ids.add(id(self.get(source.name()))) + + # Keep track of weak references of objects with ID_MATCH guard. This + # info is stored alongside optimized_code and guard_manager and is used to + # limit the number of cache entries with same ID_MATCH'd object. + self.id_matched_objs: dict[str, ReferenceType[object]] = {} + + # Save the guard managers to avoid repeatedly traversing sources. 
+ self._cached_guard_managers: dict[ + str, torch._C._dynamo.guards.GuardManager + ] = {} + self._cached_duplicate_input_guards: set[tuple[str, str]] = set() + self.serialization_mode = serialization_mode + + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): + dict_mgr = self.get_guard_manager(guard) + if isinstance(dict_mgr, DictGuardManager): + raise NotImplementedError( + "Not expecting a DictGuardManager. Seems like Dynamo incorrectly " + f"added the dict to tx.output.guard_on_key_order for {guard.name}" + ) + + # Iterate over the dicts and install a dict_getitem_manager. + dict_source = guard.originating_source.name() + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on PyDict_Next + # to traverse the dictionary, which uses the internal data structure and + # does not call the overridden keys method. + for key in builtin_dict_keys(example_value): + value = example_value[key] + value_source = DictGetItemSource(guard.originating_source, index=key) + guard_manager_enum = self.get_guard_manager_type( + value_source, example_value + ) + dict_mgr.dict_getitem_manager( + key=key, + source=f"{dict_source}[{key!r}]", + example_value=value, + guard_manager_enum=guard_manager_enum, + ) + + def guard_on_dict_keys_and_order(self, value, guard): + # Add key managers for the DictGuardManager. Then add either an + # ID_MATCH or EQUALS_MATCH guard on the key. + dict_mgr = self.get_guard_manager(guard) + if not isinstance(dict_mgr, DictGuardManager): + raise NotImplementedError( + "Expecting a DictGuardManager. Seems like Dynamo forgot " + f"to set the right guard manager enum for {guard.name}" + ) + assert isinstance(dict_mgr, DictGuardManager) + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on PyDict_Next + # to traverse the dictionary, which uses the internal data structure and + # does not call the overridden keys method. + for idx, key in enumerate(builtin_dict_keys(value)): + key_source = get_key_index_source(guard.name, idx) + key_manager = dict_mgr.get_key_manager( + index=idx, + source=key_source, + example_value=key, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + if key_is_id(key): + # Install ID_MATCH guard + id_val = self.id_ref(key, key_source) + key_manager.add_id_match_guard( + id_val, + get_verbose_code_parts( + f"__check_obj_id({key_source}, {id_val})", guard + ), + ) + else: + # Install EQUALS_MATCH guard + key_manager.add_equals_match_guard( + key, get_verbose_code_parts(f"{key_source} == {key!r}", guard) + ) + + @staticmethod + def _get_generic_dict_manager_example_value(example_value): + # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115, + # reported in https://github.com/python/cpython/issues/125608, + # fixed by https://github.com/python/cpython/pull/125611), we cannot take + # advantage of __dict__ versions to speed up guard checks. + if ( + config.issue_3_13_0_warning + and sys.version_info >= (3, 13) + and sys.version_info < (3, 13, 1) + ): + warnings.warn( + "Guards may run slower on Python 3.13.0. 
Consider upgrading to Python 3.13.1+.", + RuntimeWarning, + ) + return None + return example_value + + def getattr_on_nn_module( + self, + source, + base_guard_manager, + base_example_value, + example_value, + base_source_name, + source_name, + guard_manager_enum, + ): + """ + This tries to avoid calling the expensive nn module custom getattr method by + checking if the attribute is accessible via __dict__. For attributes that + are not accessible via __dict__ (like descriptors), we fallback to + PyObject_GetAttr. + + There are two cases that we optimize for + 1) attributes present directly in __dict__, e.g training. + 2) parameters/buffers/modules - they can be accessed via _parameters, + _buffers, _modules keys in __dict__. For example, mod.linear can be + accessed as mod.__dict__["_parameters"]["linear"] + + The most common and expensive case for nn module guards is of type + mod.submod1.submod2.submod3.training. We avoid the python getattr of nn + modules by going through the __dict__. + """ + + def getitem_on_dict_mgr( + mgr, key, source_name, base_example_value, example_value, guard_manager_enum + ): + if isinstance(mgr, DictGuardManager): + # Case where the user code relies on key order, e.g., + # named_parameters + index = get_key_index(base_example_value, key) + + # Install the key manager and add equals match guard + key_source = f"list(dict.keys({source_name}))[{index!r}]" + mgr.get_key_manager( + index=index, + source=key_source, + example_value=key, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).add_equals_match_guard(key, [f"{key_source} == {key!r}"]) + + # Install the value manager + return mgr.get_value_manager( + index=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + return mgr.dict_getitem_manager( + key=key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + + attr_name = source.member + mod_dict = base_example_value.__dict__ + + all_class_attribute_names: set[str] = set() + for x in inspect.getmro(base_example_value.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + accessor_info = NNModuleAttrAccessorInfo(False, None, None) + + if attr_name in mod_dict: + accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None) + elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]: + accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name) + elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]: + accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name) + elif ( + attr_name not in all_class_attribute_names + and "_modules" in mod_dict + and attr_name in mod_dict["_modules"] + ): + # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module. 
+ accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name) + + if not accessor_info.present_in_generic_dict: + # The attribute can be accessed by __getattribute__ call, so rely on + # PyObject_GetAttr + return base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + assert accessor_info.l1_key + l1_key = accessor_info.l1_key + l2_key = accessor_info.l2_key + + # Set source strings for debug info + mod_dict_source = f"{base_source_name}.__dict__" + l1_source_name = l2_source_name = None + l1_value = l2_value = None + l1_guard_manager_enum = l2_guard_manager_enum = None + if l2_key: + l1_source = AttrSource(source.base, l1_key) + l1_source_name = l1_source.name() + l1_value = mod_dict[l1_key] + # do not guard on key order for _parameters etc unless the user code + # actually needs the key order (e.g. calling named_parameters) + l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value) + + l2_source_name = source_name + l2_value = example_value + l2_guard_manager_enum = self.get_guard_manager_type( + source, example_value + ) + else: + l1_source_name = source_name + l1_value = example_value + l1_guard_manager_enum = self.get_guard_manager_type( + source, example_value + ) + + # Get __dict__ accessor. No need to guard on dict key order, so use base + # Guard Manager + mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager( + source=mod_dict_source, + example_value=self._get_generic_dict_manager_example_value(mod_dict), + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + l1_mgr = getitem_on_dict_mgr( + mgr=mod_generic_dict_manager, + key=l1_key, + source_name=l1_source_name, + base_example_value=mod_dict, + example_value=l1_value, + guard_manager_enum=l1_guard_manager_enum, + ) + + if l2_key: + return getitem_on_dict_mgr( + mgr=l1_mgr, + key=l2_key, + source_name=l2_source_name, + base_example_value=l1_value, + example_value=l2_value, + guard_manager_enum=l2_guard_manager_enum, + ) + return l1_mgr + + def requires_key_order_guarding(self, source): + source_name = source.name() + if source_name == "": + return False + obj_id = id(self.get(source_name)) + return obj_id in self.key_order_guarded_dict_ids + + def get_guard_manager_type(self, source, example_value): + guard_manager_enum = GuardManagerType.GUARD_MANAGER + if self.requires_key_order_guarding(source): + # Fix this if condition + if isinstance(example_value, dict_keys): + guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER + else: + assert isinstance(example_value, dict) + guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER + return guard_manager_enum + + def manager_guards_on_keys(self, mgr_enum): + return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER + + def get_global_guard_manager(self): + return self.guard_manager.root.globals_dict_manager( + f_globals=self.scope["G"], + source="G", + example_value=self.scope["G"], + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + def get_guard_manager_from_source(self, source): + root_guard_manager = self.guard_manager.root + + example_value = None + source_name = source.name() + + if source_name != "" and source_name in self._cached_guard_managers: + return self._cached_guard_managers[source_name] + + if source_name != "": + example_value = self.get(source_name) + + guard_manager_enum = self.get_guard_manager_type(source, example_value) + + # Get base manager related information + base_source_name = None + 
base_example_value = None + base_guard_manager = None + base_guard_manager_enum = GuardManagerType.GUARD_MANAGER + if isinstance(source, ChainedSource): + base_source_name = source.base.name() + base_example_value = self.get(base_source_name) + base_guard_manager = self.get_guard_manager_from_source(source.base) + base_guard_manager_enum = self.get_guard_manager_type( + source.base, base_example_value + ) + + # Use istype instead of isinstance to check for exact type of source. + if istype(source, LocalSource): + # Refer to index in the frame's localsplus directly. + # NOTE: name order for a code object doesn't change. + # NOTE: we need to find the LAST matching index because <= 3.10 contains + # duplicate names in the case of cells: a name can be both local and cell + # and will take up 2 slots of the frame's localsplus. The correct behavior + # is to refer to the cell, which has a higher index. + framelocals_names_reversed = code_framelocals_names_reversed_cached( + self.f_code + ) + framelocals_idx = ( + len(framelocals_names_reversed) + - framelocals_names_reversed.index(source.local_name) + - 1 + ) + out = root_guard_manager.framelocals_manager( + key=(source.local_name, framelocals_idx), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalSource): + # Global manager accepts a dict but it is not a DictGuardManager + # because globals dict is big and we typically guard on a very + # selected items on globals. + out = self.get_global_guard_manager().dict_getitem_manager( + key=source.global_name, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalWeakRefSource): + out = self.get_global_guard_manager().global_weakref_manager( + global_name=source.global_name, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GlobalStateSource): + # Don't do anything here. We guard on global state completely in + # C++. So just return the root mgr. 
+ return root_guard_manager + elif istype(source, ShapeEnvSource): + return root_guard_manager + elif istype(source, TypeSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.type_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype( + source, + ( + OptimizerSource, + NNModuleSource, + UnspecializedNNModuleSource, + UnspecializedBuiltinNNModuleSource, + FSDPNNModuleSource, + ), + ): + assert base_guard_manager # to make mypy happy + out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GradSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.grad_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GenericAttrSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.generic_getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, (AttrSource, UnspecializedParamBufferSource)): + assert base_guard_manager # to make mypy happy + + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + out = self.getattr_on_nn_module( + source, + base_guard_manager, + base_example_value, + example_value, + base_source_name, + source_name, + guard_manager_enum, + ) + else: + out = base_guard_manager.getattr_manager( + attr=source.member, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): + assert base_guard_manager # to make mypy happy + assert isinstance(base_example_value, (dict, collections.OrderedDict)) + if isinstance(base_guard_manager, DictGuardManager): + assert self.manager_guards_on_keys(base_guard_manager_enum) + out = getitem_on_dict_manager( + source, + base_guard_manager, + base_example_value, + example_value, + guard_manager_enum, + ) + else: + if isinstance(source.index, ConstDictKeySource): + raise RuntimeError( + "Expecting clean index here. 
Likely Dynamo forgot to mark" + " a dict as guard_on_key_order" + ) + out = base_guard_manager.dict_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, TensorPropertySource): + out = getattr( + base_guard_manager, + f"tensor_property_{source.prop.name.lower()}_manager", + )( + idx=source.idx, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, IndexedSource): + assert base_guard_manager # to make mypy happy + + out = base_guard_manager.indexed_manager( + idx=source.idx, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, ListGetItemSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.list_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, GetItemSource): + assert base_guard_manager # to make mypy happy + assert not isinstance( + base_example_value, (dict, collections.OrderedDict) + ), "Use DictGetItemSource" + if isinstance(base_example_value, list) and not source.index_is_slice: + out = base_guard_manager.list_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(base_example_value, tuple) and not source.index_is_slice: + out = base_guard_manager.tuple_getitem_manager( + key=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + index = source.index + if source.index_is_slice: + index = source.unpack_slice() + out = base_guard_manager.getitem_manager( + key=index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, DefaultsSource): + assert base_guard_manager # to make mypy happy + assert callable(base_example_value) + if not source.is_kw: + out = base_guard_manager.func_defaults_manager( + source=base_source_name, + example_value=base_example_value.__defaults__, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ).getitem_manager( + key=source.idx_key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + # kwdefauts is a dict, so use a DictGuardManager + kwdefaults = base_example_value.__kwdefaults__ + assert base_source_name is not None + kw_source = base_source_name + ".__kwdefaults__" + + # kwdefaults is a dict. No need to guard on dict order. 
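# As a concrete illustration of the two branches above (plain Python
# semantics, not specific to Dynamo): for a function such as
#   def f(a, b=1, *, c=2): ...
# positional defaults live in the tuple f.__defaults__ == (1,), so the non-kw
# branch indexes into it by position, while keyword-only defaults live in the
# dict f.__kwdefaults__ == {'c': 2}, which is why this branch looks the value
# up by key instead.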
+ dict_mgr = base_guard_manager.func_kwdefaults_manager( + source=kw_source, + example_value=kwdefaults, + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + assert not isinstance(dict_mgr, DictGuardManager) + + out = dict_mgr.dict_getitem_manager( + key=source.idx_key, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, NumpyTensorSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=from_numpy, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, SubclassAttrListSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.__tensor_flatten__()[0], + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, FlattenScriptObjectSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.__obj_flatten__(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, ScriptObjectQualifiedNameSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x._type().qualified_name(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, AttrProxySource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.get_base(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, CallMethodItemSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.item(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, FloatTensorSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: torch._as_tensor_fullprec(x), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, TupleIteratorGetItemSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.tuple_iterator_getitem_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif isinstance(source, ConstDictKeySource): + if not isinstance(base_guard_manager, DictGuardManager): + raise AssertionError( + "ConstDictKeySource can only work on DictGuardManager" + ) + out = base_guard_manager.get_key_manager( + index=source.index, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, WeakRefCallSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.weakref_call_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, CallFunctionNoArgsSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.call_function_no_args_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + elif istype(source, DataclassFieldsSource): + assert base_guard_manager + out = 
base_guard_manager.lambda_manager( + python_lambda=lambda x: dataclass_fields(x), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + raise AssertionError( + f"missing guard manager builder {source} - {source.name()}" + ) + + self._cached_guard_managers[source.name()] = out + return out + + def get_guard_manager(self, guard: Guard): + return self.get_guard_manager_from_source(guard.originating_source) + + def add_python_lambda_leaf_guard_to_root( + self, + code_parts, + verbose_code_parts, + closure_vars=None, + is_epilogue=True, + ): + if closure_vars is None: + closure_vars = _get_closure_vars() + # Adds a lambda leaf guard to the root guard manager. It wraps the + # code_parts in a function object which is then passed on to the leaf + # guard. + make_guard_fn_args = ", ".join(closure_vars.keys()) + _guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args) + out: dict[str, Any] = {} + globals_for_guard_fn = {"G": self.scope["G"]} + guards_log.debug("Python shape guard function:\n%s", pycode) + exec(pycode, globals_for_guard_fn, out) + guard_fn = out["___make_guard_fn"](*closure_vars.values()) + if is_epilogue: + # Epilogue guards are run after all the other guards have finished. + # If epilogue guards contain a getattr or getitem access, one of the + # other guards would fail preventing the epilogue guards to run. + self.guard_manager.root.add_epilogue_lambda_guard( + guard_fn, verbose_code_parts + ) + else: + self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts) + + # Warning: use this with care! This lets you access what the current + # value of the value you are guarding on is. You probably don't want + # to actually durably save this value though (because it's specific + # to this frame!) Instead, you should be reading out some property + # (like its type) which is what you permanently install into the + # guard code. + def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any: + if closure_vars is None: + closure_vars = _get_closure_vars() + return eval(name, self.scope, closure_vars) + + # Registers the usage of the source name referenced by the + # string (or stored in the Guard) as being guarded upon. It's important + # to call this before generating some code that makes use of 'guard', + # because without this call, we won't actually bind the variable + # you reference in the actual guard closure (oops!) 
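# In practice the guard methods below follow a pattern roughly like this
# (compare NOT_NONE_MATCH and EQUALS_MATCH further down):
#   ref = self.arg_ref(guard)               # bind the base variable name first
#   code = [f"{ref} is not None"]           # generated code may now reference it
#   self._set_guard_export_info(guard, code)
# which is why arg_ref must run before any code that mentions the variable is
# emitted.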
+ def arg_ref(self, guard: Union[str, Guard]) -> str: + name: str + if isinstance(guard, str): + name = guard + else: + name = guard.name + base = strip_function_call(name) + if base not in self.argnames: + is_valid = torch._C._dynamo.is_valid_var_name(base) + if is_valid: + if is_valid == 2: + log.warning("invalid var name: %s", guard) + self.argnames.append(base) + + return name + + def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): + attr_source = AttrSource(guard.originating_source, attr_name) + # Copy the stack info + new_guard = Guard( + attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack + ) + new_guard.create(self) + + # Note: the order of the guards in this file matters since we sort guards on the same object by lineno + def HASATTR(self, guard: Guard): + source = guard.originating_source + if isinstance(source, NNModuleSource): + source = source.base + assert isinstance(source, AttrSource), f"invalid source {guard.name}" + base_source = source.base + base = base_source.name() + attr = source.member + + ref = self.arg_ref(base) + val = hasattr(self.get(base), attr) + code = None + if val: + code = f"hasattr({ref}, {attr!r})" + else: + code = f"not hasattr({ref}, {attr!r})" + self._set_guard_export_info( + guard, [code], provided_guarded_object=self.get(base) + ) + + base_manager = self.get_guard_manager_from_source(base_source) + if val: + # Just install a getattr manager. GetAttrGuardAccessor itself + # acts as hasattr guard. + example_value = self.get(source.name()) + base_example_value = self.get(base) + guard_manager_enum = self.get_guard_manager_type(source, example_value) + + # if the base value is nn.Module, check if we can speedup the + # guard by going through __dict__ attrs. + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + return self.getattr_on_nn_module( + source, + base_manager, + base_example_value, + example_value, + base, + source.name(), + guard_manager_enum, + ) + else: + base_manager.getattr_manager( + attr=attr, + source=guard.name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) + else: + base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard)) + + def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: + assert attr is not None + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch.nn.Module) + + base_manager = self.get_guard_manager(guard) + + mod_dict_source = f"{guard.name}.__dict__" + mod_generic_dict_manager = base_manager.get_generic_dict_manager( + source=mod_dict_source, + example_value=self._get_generic_dict_manager_example_value(val.__dict__), + guard_manager_enum=GuardManagerType.GUARD_MANAGER, + ) + + code = f"not ___dict_contains({attr!r}, {ref}.__dict__)" + mod_generic_dict_manager.add_dict_contains_guard( + False, attr, get_verbose_code_parts(code, guard) + ) + + def TYPE_MATCH(self, guard: Guard) -> None: + # ___check_type_id is same as `id(type(x)) == y` + value = self.get(guard.name) + if isinstance(value, torch._subclasses.FakeTensor) and value.pytype: + t = value.pytype + else: + t = type(value) + + if self.serialization_mode == "save": + if t.__qualname__ != t.__name__: + raise_local_type_error(value) + + obj_id = self.id_ref(t, f"type({guard.name})") + code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_type_match_guard( + 
obj_id, get_verbose_code_parts(code, guard) + ) + + def DICT_VERSION(self, guard: Guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "DICT_VERSION guard cannot be serialized." + ) + # ___check_dict_version is same as `dict_version(x) == y` + ref = self.arg_ref(guard) + val = self.get(guard.name) + version = dict_version(self.get(guard.name)) + code = f"___dict_version({ref}) == {version}" + self._set_guard_export_info(guard, [code]) + + # TODO(anijain2305) - Delete this when DictGuardManager uses tags + # for dicts. + self.get_guard_manager(guard).add_dict_version_guard( + val, get_verbose_code_parts(code, guard) + ) + + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): + dict_ref = self.arg_ref(guard) + + maybe_not = "not " if invert else "" + code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})" + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_dict_contains_guard( + not invert, key, get_verbose_code_parts(code, guard) + ) + + def BOOL_MATCH(self, guard: Guard): + # checks val == True or val == False + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert istype(val, bool) + code = [f"{ref} == {val!r}"] + self._set_guard_export_info(guard, code) + + if val: + self.get_guard_manager(guard).add_true_match_guard( + get_verbose_code_parts(code, guard) + ) + else: + self.get_guard_manager(guard).add_false_match_guard( + get_verbose_code_parts(code, guard) + ) + + def NONE_MATCH(self, guard: Guard): + # checks `val is None` + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert val is None + code = [f"{ref} is None"] + self._set_guard_export_info(guard, code) + + self.get_guard_manager(guard).add_none_match_guard( + get_verbose_code_parts(code, guard) + ) + + def ID_MATCH(self, guard: Guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.") + # ___check_obj_id is same as `id(x) == y` + if isinstance(guard.originating_source, TypeSource): + # optional optimization to produce cleaner/faster guard code + return self.TYPE_MATCH( + Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH) # type: ignore[arg-type] + ) + + ref = self.arg_ref(guard) + val = self.get(guard.name) + id_val = self.id_ref(val, guard.name) + code = f"___check_obj_id({ref}, {id_val})" + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_id_match_guard( + id_val, get_verbose_code_parts(code, guard) + ) + + # Keep track of ID_MATCH'd objects. This will be used to modify the + # cache size logic + if isinstance(guard.originating_source, LocalSource): + # TODO(anijain2305) - This is currently restricted to nn.Module objects + # because many other ID_MATCH'd objects fail - like DeviceMesh. + # Increase the scope of ID_MATCH'd objects. 
+ if isinstance(val, torch.nn.Module): + local_name = guard.originating_source.local_name + weak_id = self.lookup_weakrefs(val) + if weak_id is not None: + self.id_matched_objs[local_name] = weak_id + + def NOT_NONE_MATCH(self, guard: Guard, value=None): + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch.Tensor) + code = f"{ref} is not None" + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) + + def DISPATCH_KEY_SET_MATCH(self, guard: Guard): + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch._C.DispatchKeySet) + code_parts = f"{ref}.raw_repr() == {val!r}.raw_repr()" + + self.get_guard_manager(guard).add_dispatch_key_set_guard( + val, get_verbose_code_parts(code_parts, guard) + ) + + def NAME_MATCH(self, guard: Guard): + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + + def DUAL_LEVEL(self, guard: Guard): + # Invalidate dual level if current dual level is different than the one + # in the fx graph + dual_level = self.check_fn_manager.output_graph.dual_level + code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] + self._set_guard_export_info(guard, [code]) + # TODO(anijain2305) - Consider this moving this guard to C++ + forward_ad = torch.autograd.forward_ad + + def fn(x): + return forward_ad._current_level == dual_level + + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) + + def FUNCTORCH_STACK_MATCH(self, guard: Guard): + # Invalidate functorch code if current level is different than + # the one when FX graph was generated + cis = self.check_fn_manager.output_graph.functorch_layers + states = [ci.get_state() for ci in cis] + code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] + self._set_guard_export_info(guard, code) + + # TODO(anijain2305) - Consider this moving this guard to C++ + compare_fn = torch._functorch.pyfunctorch.compare_functorch_state + + def fn(x): + return compare_fn(states) + + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) + + def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard): + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + def hooks_ids_fn(hooks): + if not are_inline_hooks(hooks): + return None + + pack_hook, unpack_hook = hooks + return tuple(map(id, hooks)) + + guard_hooks_ids = hooks_ids_fn(get_hooks()) + + code = [ + f"torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == {guard_hooks_ids}" + ] + self._set_guard_export_info(guard, code) + + def fn(x): + return guard_hooks_ids == hooks_ids_fn(get_hooks()) + + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) + + def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): + value = self.get(guard.name) + original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) + if hasattr(value, "__metadata_guard__"): + verify_guard_fn_signature(value) + + def metadata_checker(x): + return value.__metadata_guard__( + original_metadata, x.__tensor_flatten__()[1] + ) + + else: + + def metadata_checker(x): + return x.__tensor_flatten__()[1] == original_metadata + + global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" + self.get_guard_manager(guard).add_lambda_guard( + metadata_checker, 
get_verbose_code_parts(global_name, guard) + ) + + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): + ref = self.arg_ref(guard) + val = self.get(guard.name) + if np: + np_types: tuple[type[Any], ...] = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ) + else: + np_types = () + + ok_mutable_types = (list, set) + + ok_types = tuple( + common_constant_types + | { + type, + tuple, + frozenset, + slice, + range, + dict_keys, + torch.Size, + *np_types, + *ok_mutable_types, + } + ) + + if torch.distributed.is_available(): + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Replicate, + Shard, + ) + + ok_types = ok_types + ( + Shard, + Replicate, + Partial, + DeviceMesh, + _StridedShard, + ) + + from torch.export.dynamic_shapes import _IntWrapper + + ok_types = ok_types + (_IntWrapper,) + + import torch.utils._pytree as pytree + + assert istype(val, ok_types) or pytree.is_constant_class(type(val)), ( + f"Unexpected type {type(val)}" + ) + + # Special case for nan because float("nan") == float("nan") evaluates to False + if istype(val, float) and math.isnan(val): + self.TYPE_MATCH(guard) + code = [] + code.append(f"__math_isnan({ref})") + self._set_guard_export_info(guard, code) + + self.get_guard_manager(guard).add_lambda_guard( + _get_closure_vars()["__math_isnan"], + get_verbose_code_parts(code, guard), + ) + return + + # Python math library doesn't support complex nan, so we need to use numpy + if istype(val, complex) and np.isnan(val): + self.TYPE_MATCH(guard) + code = [] + code.append(f"__numpy_isnan({ref})") + self._set_guard_export_info(guard, code) + + self.get_guard_manager(guard).add_lambda_guard( + _get_closure_vars()["__numpy_isnan"], + get_verbose_code_parts(code, guard), + ) + return + + # Construct a debug string to put into the c++ equals match guard. + code = [f"{ref} == {val!r}"] + if istype(val, ok_mutable_types): + # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object + # is immutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the + # pointer equality check. + val = deepcopy(val) + + verbose_code_parts = get_verbose_code_parts(code, guard) + if recompile_hint: + verbose_code_parts = [ + f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts + ] + + self.get_guard_manager(guard).add_equals_match_guard(val, verbose_code_parts) + self._set_guard_export_info(guard, code) + return + + def CONSTANT_MATCH(self, guard: Guard): + val = self.get(guard.name) + if istype(val, bool): + self.BOOL_MATCH(guard) + elif val is None: + self.NONE_MATCH(guard) + elif istype(val, types.CodeType): + self.ID_MATCH(guard) + else: + self.EQUALS_MATCH(guard) + + def NN_MODULE(self, guard: Guard): + # don't support this in serialization because it uses unsupported ID_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "NN_MODULE guard cannot be serialized." 
+ ) + self.ID_MATCH(guard) + val = self.get(guard.name) + if hasattr(val, "training"): + assert istype(val.training, bool) + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + else: + exc.unimplemented_v2( + gb_type="Attempted to guard on uninitialized nn.Module", + context="", + explanation="Attempted to setup an NN_MODULE guard on uninitialized " + f"nn.Module subclass `{type(val)}`.", + hints=[ + "Ensure the `nn.Module` subclass instance has called `super().__init__()`.", + ], + ) + + def FUNCTION_MATCH(self, guard: Guard): + """things like torch.add and user defined functions""" + # don't support this in serialization because it uses unsupported ID_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "FUNCTION_MATCH guard cannot be serialized." + ) + return self.ID_MATCH(guard) + + def CLOSURE_MATCH(self, guard: Guard): + """matches a closure by __code__ id.""" + # don't support this in serialization because it uses unsupported FUNCTION_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "CLOSURE_MATCH guard cannot be serialized." + ) + val = self.get(guard.name) + # Strictly only want user-defined functions + if type(val) == types.FunctionType and hasattr(val, "__code__"): + self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) + self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) + else: + self.FUNCTION_MATCH(guard) + + def BUILTIN_MATCH(self, guard: Guard): + return self.FUNCTION_MATCH(guard) + + def SEQUENCE_LENGTH(self, guard): + # This guard is used to check length of PySequence objects like list, + # tuple, collections.deque etc + ref = self.arg_ref(guard) + value = self.get(guard.name) + + if not isinstance(value, dict): + # C++ DICT_LENGTH checks for type + self.TYPE_MATCH(guard) + + code = [] + if len(value) == 0: + code.append(f"not {ref}") + else: + code.append(f"len({ref}) == {len(value)}") + + self._set_guard_export_info(guard, code) + if isinstance(value, dict): + self.get_guard_manager(guard).add_dict_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) + else: + self.get_guard_manager(guard).add_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) + + def TUPLE_ITERATOR_LEN(self, guard): + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + code = [] + code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}") + self._set_guard_export_info(guard, code) + + t = type(value) + obj_id = self.id_ref(t, f"type({guard.name})") + + self.get_guard_manager(guard).add_tuple_iterator_length_guard( + tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) + ) + + def RANGE_ITERATOR_MATCH(self, guard): + ref = self.arg_ref(guard) + value = self.get(guard.name) + t = type(value) + + code = [] + normalized_range_iter = normalize_range_iter(value) + code.append(f"___normalize_range_iter({ref}) == {normalized_range_iter}") + self._set_guard_export_info(guard, code) + + t = type(value) + obj_id = self.id_ref(t, f"type({guard.name})") + + start, stop, step = normalized_range_iter + self.get_guard_manager(guard).add_range_iterator_match_guard( + start, stop, step, obj_id, get_verbose_code_parts(code, guard) + ) + + # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards + def DUPLICATE_INPUT(self, guard, source_b): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "DUPLICATE_INPUT guard cannot be serialized yet." 
+ ) + ref_a = self.arg_ref(guard) + ref_b = self.arg_ref(source_b.name()) + + if is_from_optimizer_source( + guard.originating_source + ) or is_from_optimizer_source(source_b): + return + + # Check that the guard has not been inserted already + key = (ref_a, ref_b) + if key in self._cached_duplicate_input_guards: + return + + self._cached_duplicate_input_guards.add((ref_a, ref_b)) + self._cached_duplicate_input_guards.add((ref_b, ref_a)) + + code = [f"{ref_b} is {ref_a}"] + self._set_guard_export_info(guard, code) + + install_object_aliasing_guard( + self.get_guard_manager(guard), + self.get_guard_manager_from_source(source_b), + get_verbose_code_parts(code, guard), + ) + + def WEAKREF_ALIVE(self, guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "WEAKREF_ALIVE guard cannot be serialized." + ) + code = [f"{self.arg_ref(guard)} is not None"] + + self._set_guard_export_info(guard, code) + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) + + def MAPPING_KEYS_CHECK(self, guard): + """Guard on the key order of types.MappingProxyType object""" + ref = self.arg_ref(guard) + value = self.get(guard.name) + + code = [] + code.append(f"list({ref}.keys()) == {list(value.keys())}") + self._set_guard_export_info(guard, code) + self.get_guard_manager(guard).add_mapping_keys_guard(value, code) + + def DICT_KEYS_MATCH(self, guard): + """Insert guard to check that the keys of a dict are same""" + ref = self.arg_ref(guard) + value = self.get(guard.name) + + if value is torch.utils._pytree.SUPPORTED_NODES: + # For SUPPORTED_NODES, we can guard on the dictionary version (PEP509). + self.DICT_VERSION(guard) + return + + self.SEQUENCE_LENGTH(guard) + + code = [] + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on PyDict_Next + # to traverse the dictionary, which uses the internal data structure and + # does not call the overridden keys method. + code.append(f"list(dict.keys({ref})) == {list(builtin_dict_keys(value))!r}") + self._set_guard_export_info(guard, code) + + if self.requires_key_order_guarding(guard.originating_source): + self.guard_on_dict_keys_and_order(value, guard) + else: + self.guard_on_dict_keys_and_ignore_order(value, guard) + + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): + """Special guard to skip guards on empty hooks. 
This is controlled by skip_nnmodule_hook_guards""" + if config.skip_nnmodule_hook_guards: + # This is unsafe if you add/remove a hook on nn module variable + return + self.SEQUENCE_LENGTH(guard) + + def GRAD_MODE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def TORCH_FUNCTION_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def FSDP_TRAINING_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DEFAULT_DEVICE(self, guard: Guard): + """Guard on CURRENT_DEVICE per torch.utils._device""" + assert guard.source is GuardSource.GLOBAL + + code = [ + f"utils_device.CURRENT_DEVICE == {self.check_fn_manager.output_graph.current_device!r}" + ] + self._set_guard_export_info(guard, code) + + self.get_guard_manager(guard).add_default_device_guard( + get_verbose_code_parts(code, guard) + ) + + def SHAPE_ENV(self, guard: Guard): + assert guard.name == "" + output_graph = self.check_fn_manager.output_graph + if self.serialization_mode == "load": + assert self.check_fn_manager.shape_code_parts is not None + shape_code_parts = self.check_fn_manager.shape_code_parts + python_code_parts = shape_code_parts.python_code_parts + verbose_code_parts = shape_code_parts.verbose_code_parts + if shape_code_parts.cpp_code_parts is not None: + cpp_code_parts = shape_code_parts.cpp_code_parts + python_fallback = shape_code_parts.python_fallback + else: + # Let's handle ShapeEnv guards. To do this, we will resolve + # shape variables to sources from tracked_fakes. This must happen after + # tensor checks. + # NB: self.output_graph can be None in the debug_nops tests + fs = output_graph.tracked_fakes + input_contexts = [a.symbolic_context for a in fs] + + def get_sources(t_id, dim): + # Looks up base sources mapped to a tensor id and uses them to create + # sources for the corresponding tensor dimension. + return [ + TensorPropertySource(source, TensorProperty.SIZE, dim) + for source in output_graph.tracked_fakes_id_to_source[t_id] + ] + + if output_graph.export_constraints: + names: dict[str, tuple[int, int]] = {} + source_pairs: list[tuple[Source, Source]] = [] + derived_equalities: list[ # type: ignore[type-arg] + tuple[Source, Union[Source, Symbol], Callable] + ] = [] + phantom_symbols: dict[str, Symbol] = {} + relaxed_sources: set[Source] = set() + for constraint in output_graph.export_constraints: + if constraint.t_id in output_graph.tracked_fakes_id_to_source: + torch.export.dynamic_shapes._process_equalities( + constraint, + get_sources, + output_graph.shape_env, + names, + source_pairs, + derived_equalities, + phantom_symbols, + relaxed_sources, + ) + else: + log.warning("Untracked tensor used in export constraints") + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + relaxed_sources=relaxed_sources, + warn_only=False, + ) + else: + equalities_inputs = None + + def _get_code_parts(langs): + return output_graph.shape_env.produce_guards_verbose( + [a.fake for a in fs], + [a.source for a in fs], + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + source_ref=self.source_ref, + # Export keeps static. 
+ ignore_static=(not self.check_fn_manager.output_graph.export), + langs=langs, + ) + + if config.enable_cpp_symbolic_shape_guards: + try: + # For exporting we need the python code parts + python_code_parts, verbose_code_parts, cpp_code_parts = ( + _get_code_parts(("python", "verbose_python", "cpp")) + ) + python_fallback = False + except OverflowError: + # Cannot use int64_t + python_fallback = True + python_code_parts, verbose_code_parts = _get_code_parts( + ("python", "verbose_python") + ) + else: + python_fallback = True + python_code_parts, verbose_code_parts = _get_code_parts( + ("python", "verbose_python") + ) + + # When exporting, we may work with the shape constraints some more in + # postprocessing, so don't freeze yet + if not self.check_fn_manager.output_graph.export: + output_graph.shape_env.freeze() + + if self.serialization_mode == "save": + # For SHAPE_ENV we want to skip serializing the entire ShapeEnv so instead + # we directly serialize the generated code here. + maybe_cpp_code_parts = locals().get("cpp_code_parts") + assert maybe_cpp_code_parts is None or isinstance( + maybe_cpp_code_parts, _CppShapeGuardsHelper + ) + maybe_shape_env_sources = ( + [] + if maybe_cpp_code_parts is None + else list(maybe_cpp_code_parts.source_to_symbol.keys()) + ) + self.check_fn_manager.shape_code_parts = ShapeCodeParts( + python_code_parts=python_code_parts, + verbose_code_parts=verbose_code_parts, + cpp_code_parts=maybe_cpp_code_parts, + python_fallback=python_fallback, + shape_env_sources=maybe_shape_env_sources, + ) + + for code in python_code_parts.exprs: + self._set_guard_export_info(guard, [code]) + + # Make ShapeEnv guards available for testing. + if compile_context := CompileContext.try_get(): + compile_context.shape_env_guards.extend(verbose_code_parts.exprs) + + int_source_to_symbol = [] + float_source_to_symbol = [] + + if not python_fallback: + assert cpp_code_parts # type: ignore[possibly-undefined] + code_parts, source_to_symbol = ( + cpp_code_parts.exprs, + cpp_code_parts.source_to_symbol, + ) + + if not code_parts: + return + + for source, symbol in source_to_symbol.items(): + if isinstance(source, ConstantSource): + python_fallback = True + else: + example_value = self.get( + source.name(), + closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, + ) + if isinstance(example_value, int): + int_source_to_symbol.append((source, symbol)) + elif isinstance(example_value, float): + float_source_to_symbol.append((source, symbol)) + else: + # SymInts/SymFloats go through python guard as we only support + # int64_t/double in C++ guards for now. 
+ python_fallback = True + + if not python_fallback: + import ctypes + + from torch._inductor.codecache import CppCodeCache + + assert cpp_code_parts # type: ignore[possibly-undefined] + code_parts, source_to_symbol = ( + cpp_code_parts.exprs, + cpp_code_parts.source_to_symbol, + ) + + source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol) + try: + guard_managers = [ + self.get_guard_manager_from_source(IndexedSource(source, i)) + for i, source in enumerate(source_to_symbol) + ] + + int_symbols_str = ", ".join( + f"{symbol} = int_values[{i}]" + for i, (_, symbol) in enumerate(int_source_to_symbol) + ) + float_symbols_str = ", ".join( + f"{symbol} = float_values[{i}]" + for i, (_, symbol) in enumerate(float_source_to_symbol) + ) + + if int_symbols_str: + int_symbols_str = f"int64_t {int_symbols_str};" + if float_symbols_str: + float_symbols_str = f"double {float_symbols_str};" + + func_str = textwrap.dedent( + f""" + #include + #include + #include + #include + + #if defined(_MSC_VER) + # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport) + #else + # define EXTERN_DLL_EXPORT extern "C" + #endif + + EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{ + {int_symbols_str} + {float_symbols_str} + return ({") && (".join(code_parts)}); + }} + """ + ) + guards_log.debug( + "C++ shape guard function: %s %s", + func_str, + verbose_code_parts.exprs, + ) + clib = CppCodeCache.load(func_str) + cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value + assert cguard + except torch._inductor.exc.InvalidCxxCompiler: + # No valid C++ compiler to compile the shape guard + pass + else: + install_symbolic_shape_guard( + guard_managers, + len(int_source_to_symbol), + len(float_source_to_symbol), + cguard, + clib, + verbose_code_parts.exprs, + ) + return + + # Install all the symbolic guards in one python lambda guard. These are run + # at the very end of the RootGuardManager via epilogue guards. + # TODO(anijain2305,williamwen42) - Consider moving this to C++. + if python_code_parts.exprs: + self.add_python_lambda_leaf_guard_to_root( + python_code_parts.exprs, + verbose_code_parts.exprs, + closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, + ) + + def TENSOR_MATCH(self, guard: Guard, value=None): + if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module(): + return + # For tensors that are part of the Dynamo extracted Fx graph module, an + # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these + # will be lifted as inputs and have a TENSOR_MATCH guard. + if match_on_id_for_tensor(guard): + self.ID_MATCH(guard) + else: + if isinstance(value, TensorWeakRef): + value = value() + + value = value if value is not None else self.get(guard.name) + + pytype = type(value) + dispatch_keys = torch._C._dispatch_keys(value) + if isinstance(value, torch._subclasses.FakeTensor): + if value.pytype is not None: + pytype = value.pytype + if value.dispatch_keys is not None: + dispatch_keys = value.dispatch_keys + + assert isinstance(value, torch.Tensor) + + if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter): + metrics_context = get_metrics_context() + metrics_context.increment("param_numel", value.numel()) + metrics_context.increment("param_bytes", value.nbytes) + metrics_context.increment("param_count", 1) + + tensor_name = self.arg_ref(guard) + # [Note - On Export Tensor Guards] + # + # In eager mode, tensor guards are evaluated through C++, in guards.cpp + # see [Note - On Eager Tensor Guards] for more info. 
+ # + # In export mode, we instead maintain parallel logic between C++ and python + # here, with an exception of checking the dispatch key - with the idea that a dispatch key + # is an entirely runtime notion that would make no sense to keep in an exported graph. + # + # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although + # not entirely true. + # For example, suppose one of the input tensors had the negative dispatch key. + # You should end up with a graph that is specialized for tensors that have a negative dispatch key. + # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated. + # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't + # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key. + # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported + # subset of keys during export. + # + # The list of tensor fields and calls we care about can be found in `terms` below. + # TODO(voz): We are missing storage offset in all our tensor guards? + code: list[str] = [] + if self.check_fn_manager.output_graph.export: + self.TYPE_MATCH(guard) + terms = [ + "dtype", + "device", + "requires_grad", + "ndimension()", + ] + + for term in terms: + real_value = self.get(tensor_name + "." + term) + if istype(real_value, (torch.device, torch.dtype)): + # copy pasted from EQUALS_MATCH + code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}") + else: + code.append(f"{tensor_name}.{term} == {real_value}") + else: + guard_manager = self.get_guard_manager(guard) + + # skip_no_tensor_aliasing_guards_on_parameters bring + # unsoundness. If you compile a function with two different + # parameters, but later on you pass on same tensor as two + # different outputs (aliasing), Dynamo will not detect this. + # But we deliberately take this soundness hit because this + # usecase is quite rare and there is substantial reduction in + # guard overhead. + # For numpy tensors, since those are ephemeral, we don't have to + # insert aliasing guards on them + if not ( + config.skip_no_tensor_aliasing_guards_on_parameters + and istype(value, torch.nn.Parameter) + ) and not isinstance(guard.originating_source, NumpyTensorSource): + # Keep track of all the tensor guard managers to insert + # NoAliasing check at the end. + self.no_tensor_aliasing_names.append(tensor_name) + self.no_tensor_aliasing_guard_managers.append(guard_manager) + + output_graph = self.check_fn_manager.output_graph + metadata = output_graph.input_source_to_sizes_strides[ + guard.originating_source + ] + size = convert_to_concrete_values(metadata["size"]) + stride = convert_to_concrete_values(metadata["stride"]) + + verbose_code_parts = get_verbose_code_parts( + get_tensor_guard_code_part( + value, tensor_name, size, stride, pytype, dispatch_keys + ), + guard, + ) + guard_manager.add_tensor_match_guard( + value, + size, + stride, + tensor_name, + verbose_code_parts, + pytype, + dispatch_keys, + ) + + # We consider TENSOR_MATCH guard to be important enough to be + # included in diff guard manager by default. + if not isinstance(value, torch.nn.Parameter): + self.guard_manager.diff_guard_sources.add(guard.name) + + # A frame is valid for reuse with dynamic dimensions if the new + # (user-requested) dynamic dimensions are a subset of the old + # (already compiled) dynamic dimensions. 
+ # + # It's a little non-obvious why you'd want this: in particular, + # if an already compiled frame matches all of the guards, why + # not just use it, why force a recompile? + # + # We force it for two reasons: + # + # - The user *required* us to compile with a new dynamic dimension, + # we should not ignore that and serve up the old, specialized + # frame. Listen to the user! + # + # - In fact, we are obligated to *raise an error* if we fail to + # make the requested dimension dynamic. If we don't + # recompile, we can't tell if that dimension can actually be + # made dynamic. + # + # If the new dynamic dims are a subset of the old, we already know + # we can make them dynamic (since we made them dynamic in old). + # This is slightly unsound, because maybe your input size is + # [s0, s0, s1] and so you can do it dynamic if you say dynamic + # dims {0, 1, 2} but you can't if you only do {0, 2} (because now + # the second s0 is specialized). But we're not entirely sure if + # this is a good idea anyway lol... (if you want to try removing + # this logic, be my guest! -- ezyang 2024) + # + assert guard.source is not None + static, _reason = tensor_always_has_static_shape( + value, is_tensor=True, tensor_source=guard.originating_source + ) + + if not static: + if hasattr(value, "_dynamo_dynamic_indices"): + dynamic_indices = value._dynamo_dynamic_indices + code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # noqa: B950 + code.append(code_part) + self.get_guard_manager(guard).add_dynamic_indices_guard( + dynamic_indices, get_verbose_code_parts(code_part, guard) + ) + # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of + # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled. + else: + code_part = ( + f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" + ) + code.append(code_part) + self.get_guard_manager(guard).add_no_hasattr_guard( + "_dynamo_dynamic_indices", + get_verbose_code_parts(code_part, guard), + ) + if len(code) > 0: + self._set_guard_export_info(guard, code) + + # A util that in the case of export, adds data onto guards + def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): + # WARNING: It is important that cur_frame/caller do NOT stay in + # the current frame, because they will keep things live longer + # than they should. See TestMisc.test_release_module_memory + cur_frame = currentframe() + assert cur_frame is not None + caller = cur_frame.f_back + del cur_frame + assert caller is not None + func_name = caller.f_code.co_name + del caller + # We use func_name for export, so might as well get a nice defensive check out of it + assert func_name in self.__class__.__dict__, ( + f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}" + ) + + # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD) + if provided_guarded_object is None: + name = guard.name + guarded_object = None if not name else self.get(name) + else: + guarded_object = provided_guarded_object + + guarded_object_type = ( + weakref.ref(type(guarded_object)) if guarded_object is not None else None + ) + obj_ref = None + # Not necessary to have weakref for Enum type, but there is a bug that + # makes hasattr(guarded_object.__class__, "__weakref__") return True. 
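# A nonzero __weakrefoffset__ means instances of the class actually carry a
# weakref slot: ordinary user-defined classes do, while builtins such as int
# and tuple report 0, so checking the offset sidesteps the hasattr false
# positive mentioned above.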
+ supports_weakref = ( + getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0 + ) + # See D64140537 for why we are checking for tuple. + if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)): + obj_ref = weakref.ref(guarded_object) + + guard.set_export_info( + func_name, + guarded_object_type, + code_list, + obj_ref, + ) + + +# Common Sub-Expression Elimination for Python expressions. +# +# There are 2 steps to this pass: +# 1. Count the frequency of each sub-expression (i.e. inner +# node in the AST tree) +# +# 2. Replace those that occur more than once by a fresh variable 'v'. +# 'v' will be defined in the 'preface' list (output argument to +# 'NodeTransformer') +# +# NB: the use of 'ast.unparse' while visiting the nodes makes this pass +# quadratic on the depth of the tree. +# +# NB: this pass creates a new variable for each AST node that is repeated +# more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c' +# and 'a.b' are also used 10 times. So, there will be a new variable for +# each of them. +class PyExprCSEPass: + # Maximum number of times a given expression can be used without being + # replaced by a fresh variable. + USE_THRESHOLD = 1 + + # Ad-Hoc: AST nodes this pass focuses on. + ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript) + + @dataclasses.dataclass + class Config: + expr_count: dict[str, int] + expr_to_name: dict[str, str] + + class ExprCounter(ast.NodeVisitor): + def __init__(self, config: PyExprCSEPass.Config) -> None: + self._config = config + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + self._config.expr_count[_ast_unparse(node)] += 1 + super().visit(node) + + class Replacer(ast.NodeTransformer): + def __init__( + self, + config: PyExprCSEPass.Config, + gen_name: Callable[[], str], + ) -> None: + super().__init__() + self._config = config + self._gen_name = gen_name + self.preface: list[str] = [] + + def visit(self, node: ast.AST) -> Any: + if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): + expr = _ast_unparse(node) + + # Replacement only occurs if a given expression is used more + # than once. + if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD: + if expr not in self._config.expr_to_name: + # Parent 'visit' is called so that we CSE the inner expressions first. + # + # The resulting expression is used as right-hand-side of the variable + # assignment. i.e. we are CSE-ing the children before the parents. + # + # Indexing still uses the old 'node', since that's what was counted + # by the 'NodeVisitor'. 
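# A small worked example of the whole pass (hypothetical input; USE_THRESHOLD
# is 1 and fresh names come out as _var0, _var1, ...): after counting the two
# expressions "a.b.c == 1" and "a.b.c == 2", both "a.b" and "a.b.c" are seen
# twice, so replacing "a.b.c == 1" emits the preface lines
#   _var0 = a.b
#   _var1 = _var0.c
# and returns the rewritten expression "_var1 == 1"; the inner expression is
# CSE'd before its parent, as the comment above describes.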
+ node_ = super().visit(node) + expr_ = _ast_unparse(node_) + var_name = self._gen_name() + self.preface.append(f"{var_name} = {expr_}") + self._config.expr_to_name[expr] = var_name + else: + var_name = self._config.expr_to_name[expr] + return ast.Name(var_name, ast.Load()) + + return super().visit(node) + + def __init__(self) -> None: + self._counter = 0 + self._config = self.Config( + expr_count=collections.defaultdict(lambda: 0), expr_to_name={} + ) + + def _new_var(self, prefix: str = "_var") -> str: + name = f"{prefix}{self._counter}" + self._counter += 1 + return name + + def count(self, exprs: list[str]) -> None: + counter = self.ExprCounter(self._config) + for e in exprs: + try: + counter.visit(ast.parse(e)) + except SyntaxError as ex: + log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e) + raise + + def replace(self, expr: str) -> tuple[list[str], str]: + replacer = self.Replacer(self._config, self._new_var) + new_node = replacer.visit(ast.parse(expr)) + return replacer.preface, _ast_unparse(new_node) + + +def must_add_nn_module_guards(guard): + # For config.guard_nn_modules=False, we can skip all the guards that + # originate from inside of nn module except for a few categories. + return ( + # Guard for defaults + isinstance(guard.originating_source, DefaultsSource) + # Guard using dict tags if the config flag is set + or ( + config.guard_nn_modules_using_dict_tags + and guard.create_fn is GuardBuilder.NN_MODULE + ) + ) + + +class DeletedGuardManagerWrapper(GuardManagerWrapper): + def __init__(self, reason): + super().__init__() + self.invalidation_reason = reason + + def populate_diff_guard_manager(self): + self.diff_guard_root = None + + +@dataclasses.dataclass +class ShapeCodeParts: + python_code_parts: _ShapeGuardsHelper + verbose_code_parts: _ShapeGuardsHelper + cpp_code_parts: Optional[_CppShapeGuardsHelper] + python_fallback: bool + shape_env_sources: list[Source] + + +@dataclasses.dataclass +class GuardsState: + output_graph: OutputGraphGuardsState + shape_code_parts: Optional[ShapeCodeParts] + + +class GuardsStatePickler(pickle.Pickler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fake_mode = torch._subclasses.FakeTensorMode() + self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() + + @classmethod + def _unpickle_module(cls, state): + mod = torch.nn.Module() + mod.__setstate__(state) + return mod + + @classmethod + def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw): + fake_mode = torch._subclasses.FakeTensorMode() + tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() + return tensor_converter.from_meta_and_device( + fake_mode, + meta_tensor, + device, + pytype, + torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw), + ) + + @classmethod + def _unpickle_traceable_wrapper_subclass( + cls, meta_tensor, device, pytype, dispatch_keys_raw, ctx, inner_data + ): + # Unpickle the inner tensor components. These could also be subclass instances. 
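# Sketch of the round trip implemented here, following the traceable wrapper
# subclass contract (__tensor_flatten__ / __tensor_unflatten__):
#   attrs, ctx = obj.__tensor_flatten__()           # inner attr names + metadata
#   inner = {a: getattr(obj, a) for a in attrs}     # inner tensors, possibly subclasses
#   rebuilt = type(obj).__tensor_unflatten__(inner, ctx, outer_size, outer_stride)
# reducer_override() below does the flatten half (swapping each tensor for a
# meta-device copy), and this classmethod does the unflatten half.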
+ inner_tensors = {} + for attr, unpickle_func, unpickle_func_args in inner_data: + inner_tensors[attr] = unpickle_func(*unpickle_func_args) + + outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride() + out = type(meta_tensor).__tensor_unflatten__( + inner_tensors, ctx, outer_size, outer_stride + ) + out.pytype = pytype + out.dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw) + return out + + @classmethod + def _unpickle_python_module(cls, alias: str): + return importlib.import_module(alias) + + @classmethod + def _unpickle_dispatch_key_set(cls, raw_repr: int): + return torch._C.DispatchKeySet.from_raw_repr(raw_repr) + + @classmethod + def _unpickle_functorch_interpreter(cls, json: bytes): + return torch._C._functorch.CInterpreter.deserialize(json) + + @classmethod + def _unpickle_mapping_proxy(cls, d): + return types.MappingProxyType(d) + + def reducer_override(self, obj): + import sympy + + if isinstance(obj, torch.Tensor) and obj.device.type != "meta": + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(obj): + # inner_data is a list of tuples of: + # (inner attr name, unpickle func, tuple of func inputs) + # This supports traceable wrapper subclass inner tensors. + inner_data = [] + attrs, ctx = obj.__tensor_flatten__() + # recursively call for inner tensor components + for attr in attrs: + inner = getattr(obj, attr) + func, args_tuple = self.reducer_override(inner) + inner_data.append((attr, func, args_tuple)) + + return type(self)._unpickle_traceable_wrapper_subclass, ( + torch.empty_like(obj, device="meta"), + obj.device, + type(obj), + torch._C._dispatch_keys(obj).raw_repr(), + ctx, + inner_data, + ) + + return type(self)._unpickle_tensor, ( + torch.empty_like(obj, device="meta"), + obj.device, + type(obj), + torch._C._dispatch_keys(obj).raw_repr(), + ) + + elif isinstance(obj, torch.nn.Module): + if type(obj).__qualname__ == type(obj).__name__: + return NotImplemented + if obj.__class__.__getstate__ == torch.nn.Module.__getstate__: + return type(self)._unpickle_module, (obj.__getstate__(),) + + elif inspect.ismodule(obj): + return type(self)._unpickle_python_module, (obj.__name__,) + + elif isinstance(obj, torch._C.DispatchKeySet): + return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),) + + elif isinstance(obj, torch._C._functorch.CInterpreter): + return type(self)._unpickle_functorch_interpreter, (obj.serialize(),) + + elif ( + inspect.isclass(obj) + and issubclass(obj, sympy.Function) + and hasattr(obj, "_torch_handler_name") + ): + assert hasattr(obj, "_torch_unpickler") + return obj._torch_unpickler, (obj._torch_handler_name,) + + elif isinstance(obj, torch.SymInt): + raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})") + + elif isinstance(obj, types.MappingProxyType): + return type(self)._unpickle_mapping_proxy, (obj.copy(),) + + if type(obj).__qualname__ != type(obj).__name__: + raise torch._dynamo.exc.PackageError( + f"Type {type(obj)} for object {obj} cannot be saved " + + "into torch.compile() package since it's defined in local scope. " + + "Please define the class at global scope (top level of a module)." 
+ ) + + return NotImplemented + + +def pickle_guards_state(state: GuardsState) -> bytes: + buf = io.BytesIO() + pickler = GuardsStatePickler(buf) + try: + pickler.dump(state) + except AttributeError as e: + raise torch._dynamo.exc.PackageError(str(e)) from e + return buf.getvalue() + + +# NB: Naively, you'd expect this to only be a function that produces +# the callable that constitutes the guard. However, there is some +# delicate handling for invalidating this check function when the +# locals/globals get invalidated, so there's some extra state +# we have to hold in this manager class. +class CheckFunctionManager: + def __init__( + self, + f_code, + output_graph=None, + cache_entry=None, + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, + guard_filter_fn: Optional[ + Callable[[list[GuardFilterEntry]], list[bool]] + ] = None, + guards_serialization_mode: Optional[str] = None, + shape_code_parts: Optional[ShapeCodeParts] = None, + ): + guards = output_graph.guards if output_graph else None + self._weakrefs: dict[int, ReferenceType[object]] = {} + + existing_diff_guard_sources = ( + update_diff_guard_managers_for_existing_cache_entries(cache_entry) + ) + self.output_graph = output_graph + + # Only used for serialization. + self.shape_code_parts = shape_code_parts + + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. + self.torch_function_mode_stack = ( + output_graph.torch_function_mode_stack if output_graph else None + ) + self.guards_serialization_mode = guards_serialization_mode + + if not justknobs_check("pytorch/compiler:guard_nn_modules"): + log.warning("guard_nn_modules is turned off using justknobs killswitch") + + sorted_guards = sorted(guards or (), key=Guard.sort_key) + builder, guard_manager = self.build_guards( + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + None if guard_filter_fn else self.guards_serialization_mode, + ) + + if guard_filter_fn: + + def make_guard_filter_entry(guard): + MISSING = object() + name = strip_local_scope(guard.name) + if name == "": + has_value = False + value = MISSING + else: + has_value = True + value = builder.get(guard.name) + is_global = get_global_source_name(guard.originating_source) is not None + guard_fn = guard.create_fn + if isinstance(guard_fn, functools.partial): + guard_fn = guard.create_fn.func + return GuardFilterEntry( + name=name, + has_value=has_value, + value=value, + guard_type=guard_fn.__name__, + derived_guard_types=tuple(guard.guard_types) + if guard.guard_types + else (), + is_global=is_global, + orig_guard=guard, + ) + + filter_results = guard_filter_fn( + [make_guard_filter_entry(guard) for guard in sorted_guards] + ) + assert len(filter_results) == len(sorted_guards) + assert all(type(x) == bool for x in filter_results) + sorted_guards = [ + guard for i, guard in enumerate(sorted_guards) if filter_results[i] + ] + # Redo the guards because filtering relies on the results from the last guard builder. + builder, guard_manager = self.build_guards( + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + self.guards_serialization_mode, + ) + + self.guard_manager = guard_manager + self.compile_check_fn(builder, sorted_guards, guard_fail_fn) + + # Keep track of weak references of objects with ID_MATCH guard. This + # info is stored alongside optimized_code and guard_manager and is used to + # limit the number of cache entries with same ID_MATCH'd object. 
+ # TODO(anijain2305) - Currently this information is stored as an attr on + # the guard_manager itself to avoid changing CacheEntry data structure in + # eval_frame.c. In future, we should probably replace guard_manager with a + # queryable data structure such that this information is already present + # in some form. + self.guard_manager.id_matched_objs = builder.id_matched_objs + + guards_log.debug("%s", self.guard_manager) + self.guard_manager.id_matched_objs = builder.id_matched_objs + + # Check that the guard returns True. False means that we will always + # recompile. + # TODO(anijain2305, ydwu4) - Skipping export because of following test + # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs + latency = 0.0 + if not output_graph.export and self.guards_serialization_mode != "load": + if not self.guard_manager.check(output_graph.local_scope): + reasons = get_guard_fail_reason_helper( + self.guard_manager, # type: ignore[arg-type] + output_graph.local_scope, + CompileContext.current_compile_id(), + ) + raise AssertionError(f"Guard check failed: {reasons}") + + if guard_manager_testing_hook_fn is not None: + guard_manager_testing_hook_fn( + self.guard_manager, output_graph.local_scope + ) + + # NB for developers: n_iters is chosen to be 1 to prevent excessive + # increase in compile time. We first do a cache flush to measure the + # guard latency more accurately. This cache flush is expensive. + # Note - If you are working on a guard optimization, it might be a + # good idea to increase this number for more stabiilty during + # development. + latency = profile_guard_manager( + self.guard_manager.root, output_graph.local_scope, 1 + ) + guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}") + # Note: We use `increment_toplevel` instead of `compilation_metric` + # here. This is because, in scenarios where `torch._dynamo.reset` + # is invoked, the same frame ID and compile ID may be reused during + # a new compilation cycle. This behavior causes issues with + # `compilation_metric`, as it expects the metric field to be empty. + # Ideally, we would overwrite the existing entry in such cases, but + # we currently lack an API to support overwriting metrics. However, + # since these situations are rare and typically impractical to + # account for, we simply increment at the toplevel instead. + CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) + + self.guards_state: Optional[bytes] = None + if self.guards_serialization_mode == "save": + used_global_vars = set() + used_local_vars = set() + + def prune_variable(source): + if name := get_global_source_name(source): + assert isinstance(name, str) + used_global_vars.add(name) + elif name := get_local_source_name(source): + assert isinstance(name, str) + used_local_vars.add(name) + + output_graph_guards_state = self.output_graph.dump_guards_state() + # Only serialize the global variables that are actually used in guards. 
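# For instance (a hypothetical guard set, assuming Dynamo's usual G[...] / L[...]
# naming for global and local sources): a guard rooted at G['some_fn'].__code__
# marks only the global 'some_fn' as used, and one rooted at L['x'].shape marks
# only the local 'x'; the dataclasses.replace(...) below then drops every other
# entry from global_scope and local_scope before the state is pickled.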
+ for guard in sorted_guards: + if isinstance(guard.originating_source, ShapeEnvSource): + assert self.shape_code_parts + for source in self.shape_code_parts.shape_env_sources: + prune_variable(source) + else: + prune_variable(guard.originating_source) + + for source in self.output_graph.guard_on_key_order: + prune_variable(source) + + def normalize_create_fn(x): + if isinstance(x, functools.partial): + + def _ref(x): + if isinstance(x, (TensorWeakRef, weakref.ref)): + return x() + return x + + new_args = tuple(_ref(a) for a in x.args) + new_keywords = {k: _ref(v) for k, v in x.keywords.items()} + return functools.partial(x.func, *new_args, **new_keywords) + + return x + + output_graph_guards_state = dataclasses.replace( + output_graph_guards_state, + local_scope={ + k: v + for k, v in output_graph_guards_state.local_scope.items() + if k in used_local_vars + }, + global_scope={ + k: v + for k, v in output_graph_guards_state.global_scope.items() + if k in used_global_vars + }, + _guards=torch._guards.GuardsSet( + { + dataclasses.replace( + guard, + obj_weakref=None, + guarded_class_weakref=None, + create_fn=normalize_create_fn(guard.create_fn), + ) + for guard in sorted_guards + } + ), + input_source_to_sizes_strides=pytree.tree_map( + convert_int_to_concrete_values, + output_graph_guards_state.input_source_to_sizes_strides, + ), + ) + guards_state = GuardsState( + output_graph=output_graph_guards_state, + shape_code_parts=self.shape_code_parts, + ) + self.guards_state = pickle_guards_state(guards_state) + + # TODO: don't do the string rep, do something more structured here + torch._logging.trace_structured( + "dynamo_cpp_guards_str", + payload_fn=lambda: f"{self.guard_manager}\nGuard latency = {latency:.2f} us", + ) + # NB - We have to very careful of cleaning up here. Because of the + # invalidate function, we can create a weakref finalizer that keeps + # `self` alive for very long. Sometimes by mistake, we can run + # invalidate for a type/object (check id_ref method) that Python can + # leak by design, preventing us from calling the finalizer. In that + # case, the `self` will be alive even though the cache entry will be + # deleted (check invalidate method), which can cause a memory leak, + # e.g., not setting output_graph = None can keep hold of nn_modules. + self._weakrefs.clear() + self.output_graph = None + + def build_guards( + self, + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + serialization_mode=None, + ): + guard_manager = GuardManagerWrapper() + guard_manager.diff_guard_sources = existing_diff_guard_sources + + w_builder = None + + def source_ref(source): + guard_source = source.guard_source() + if guard_source is GuardSource.CONSTANT: + # No need to track constants + return source.name() + assert w_builder + r_builder = w_builder() + assert r_builder is not None + return r_builder.arg_ref(source.name()) + + builder = GuardBuilder( + f_code, + self.id_ref, + source_ref, + self.lookup_weakrefs, + output_graph.local_scope, + output_graph.global_scope, + guard_manager, + self, + serialization_mode, + ) + + # Break retain cycle. See test_release_scope_memory + def cleanup_builder(weak_b): + b = weak_b() + if b: + b.scope = None + + # Break retain cycle. 
See test_release_input_memory + w_builder = weakref.ref(builder, cleanup_builder) + + guard_on_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) + + for guard in sorted_guards: + if ( + not guard_on_nn_modules + and guard.is_specialized_nn_module() + # Default func args must be guarded on. + # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API + and "__defaults__" not in guard.name + and "__kwdefaults__" not in guard.name + and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name) + ): + continue + + guard.create(builder) + return builder, guard_manager + + def compile_check_fn(self, builder, guards_out, guard_fail_fn): + # see parallel handling of ".0" / "___implicit0" in _eval_frame.c + largs = builder.argnames + largs += ["**___kwargs_ignored"] + + guards_log.debug("GUARDS:") + + code_parts = [] + verbose_code_parts = [] + structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] + + torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard( + self.torch_function_mode_stack + ) + + # Insert the global_state guard + self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) + + self.guard_manager.root.add_torch_function_mode_stack_guard( + self.torch_function_mode_stack, + ["___check_torch_function_mode_stack()"], + ) + # Clear references to torch_function modes held in the list + self.torch_function_mode_stack = None + + def add_code_part(code_part, guard, log_only=False): + verbose_code_part = get_verbose_code_part(code_part, guard) + guards_log.debug("%s", verbose_code_part) + + structured_guard_fns.append( + lambda: { + "code": code_part, + "stack": ( + structured.from_traceback(guard.stack.summary()) + if guard and guard.stack + else None + ), + "user_stack": ( + structured.from_traceback(guard.user_stack) + if guard and guard.user_stack + else None + ), + } + ) + + if verbose_guards_log.isEnabledFor(logging.DEBUG): + maybe_stack = "" + maybe_user_stack = "" + if guard is not None: + if guard.stack: + maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}" + if guard.user_stack: + maybe_user_stack = ( + f"\nUser stack:\n{''.join(guard.user_stack.format())}" + ) + verbose_guards_log.debug( + "Guard: %s%s%s", + code_part, + maybe_stack, + maybe_user_stack, + ) + + if not log_only: + code_parts.append(code_part) + verbose_code_parts.append(verbose_code_part) + + seen = set() + for gcl in builder.code: + for code in gcl.code_list: + if code not in seen: + # If Cpp guard manager is enabled, we don't need to add to + # code_parts. + add_code_part(code, gcl.guard, True) + seen.add(code) + + no_tensor_aliasing_names = builder.no_tensor_aliasing_names + check_tensors_fn = None + check_tensors_verbose_fn = None + + if len(no_tensor_aliasing_names) > 1: + # Install tensor aliasing guard. TENSOR_MATCH guards are already + # installed for cpp guard manager. + install_no_tensor_aliasing_guard( + builder.no_tensor_aliasing_guard_managers, + no_tensor_aliasing_names, + ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"], + ) + + aotautograd_guards: list[GuardEnvExpr] = ( + self.output_graph.aotautograd_guards if self.output_graph else [] + ) + + # TODO(anijain2305) - There is a duplicate logic in Dynamo to find + # aliased input tensors. So most probably we don't need this here. + # Revisit. 
+ for guard in aotautograd_guards: + if isinstance(guard, DuplicateInputs): + source_a = guard.input_source_a + source_b = guard.input_source_b + code_part = f"{source_a.name()} is {source_b.name()}" + install_object_aliasing_guard( + builder.get_guard_manager_from_source(source_a), + builder.get_guard_manager_from_source(source_b), + [code_part], + ) + add_code_part(code_part, None, True) + elif isinstance(guard, StorageOverlap): + overlapping_guard_managers = [ + builder.get_guard_manager_from_source(s) + for s in guard.overlapping_sources + ] + non_overlapping_guard_managers = [ + builder.get_guard_manager_from_source(s) + for s in guard.non_overlapping_sources + ] + code_part = ( + """check_overlapping(""" + f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """ + f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])""" + ) + install_storage_overlapping_guard( + overlapping_guard_managers, + non_overlapping_guard_managers, + [code_part], + ) + add_code_part(code_part, None, True) + else: + raise RuntimeError(f"Unknown GuardEnvExpr: {guard}") + + # TODO: the "guard" here is actually just the top level SHAPE_ENV + # which is useless. Get ShapeEnv to pass in more provenance. + for gcl in builder.shape_env_code: + for code in gcl.code_list: + # Shape env guards are already added for CPP guard manager in + # SHAPE_ENV implementation. + add_code_part(code, gcl.guard, True) + + # OK, all done generating guards + if structured_guard_fns: + torch._logging.trace_structured( + "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns] + ) + + global_state = convert_frame.initial_global_state + if global_state is None: + # we should only hit this case in NopTests() + global_state = convert_frame.GlobalStateGuard() + closure_vars = { + "___check_tensors": check_tensors_fn, + "___check_tensors_verbose": check_tensors_verbose_fn, + "___check_global_state": global_state.check, + "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn, + **SYMPY_INTERP, + **_get_closure_vars(), + } + + self.guard_manager.finalize() + + globals_for_guard_fn = {"G": builder.scope["G"]} + # Guard manager construction is complete. Ensure we did not miss to + # insert a guard in cpp guard manager. + assert len(code_parts) == 0 + + self.guard_manager.closure_vars = closure_vars + self.guard_manager.args = largs + self.guard_manager.populate_code_parts_for_debugging() + self.guard_manager.verbose_code_parts = verbose_code_parts + # Grab only G, but preserve "G" because guards access it as "G" + self.guard_manager.global_scope = globals_for_guard_fn + self.guard_manager.guard_fail_fn = guard_fail_fn + # will be populated by a non-owning reference to CacheEntry/ExtraState + # when the CacheEntry is constructed + self.guard_manager.cache_entry = None + self.guard_manager.extra_state = None + self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names + + def invalidate(self, obj_str): + # Some tests reveal that CheckFunctionManager has no attribute + # guard_manager, but this case should not be of any concern. + # This case doesn't seem easy to repro. 
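The ID_MATCH bookkeeping described above hinges on `weakref.finalize`: `id_ref` registers a finalizer on every ID_MATCH'd object so that, when that object is garbage collected, `invalidate` marks the corresponding cache line as dead. A minimal self-contained sketch of that pattern, assuming hypothetical `_Cache`/`_Guarded` stand-ins rather than the real Dynamo types:

    import functools
    import weakref

    class _Cache:
        def __init__(self) -> None:
            self.valid = True

        def invalidate(self, obj_str: str) -> None:
            # Stand-in for CheckFunctionManager.invalidate: drop the cache line.
            self.valid = False

    class _Guarded:
        pass

    cache = _Cache()
    obj = _Guarded()
    # No strong reference to the finalizer is needed; finalize objects stay
    # alive until they run or the interpreter exits.
    weakref.finalize(obj, functools.partial(cache.invalidate, obj_str="obj"))
    del obj  # on CPython the refcount hits zero and the finalizer fires here
    assert cache.valid is False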
+ if ( + hasattr(self, "guard_manager") + and not isinstance(self.guard_manager, DeletedGuardManagerWrapper) + and (cache_entry := self.guard_manager.cache_entry) is not None + and (extra_state := self.guard_manager.extra_state) is not None + ): + assert isinstance(cache_entry, CacheEntry) + assert isinstance(extra_state, ExtraState) + reason = f"Cache line invalidated because {obj_str} got deallocated" + deleted_guard_manager = DeletedGuardManagerWrapper(reason) + extra_state.invalidate(cache_entry, deleted_guard_manager) + self.guard_manager = deleted_guard_manager + + def id_ref(self, obj, obj_str): + """add a weakref, return the id""" + try: + if id(obj) not in self._weakrefs: + # We will clear the _weakrefs dict at the end of __init__ + # function, which will delete the callbacks as well. Therefore, + # we are using a finalizer which is kept alive. + self._weakrefs[id(obj)] = weakref.ref(obj) + weakref.finalize( + obj, functools.partial(self.invalidate, obj_str=obj_str) + ) + except TypeError: + pass # cannot weakref bool object + return id(obj) + + def lookup_weakrefs(self, obj): + """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" + if id(obj) in self._weakrefs: + return self._weakrefs[id(obj)] + return None + + +def build_guard_function(code_parts, closure_args) -> tuple[str, str]: + from torch._inductor.utils import IndentedBuffer + + csepass = PyExprCSEPass() + try: + csepass.count(code_parts) + + def replace(expr: str) -> tuple[list[str], str]: + return csepass.replace(expr) + except RecursionError: + # If we hit recursion limits during CSE analysis, fall back to a no-op replace function + # This can happen with extremely complex guard expressions + def replace(expr: str) -> tuple[list[str], str]: + return [], expr + + # Generate the inner body of the guard function. + # i.e. if-chain of the guard expressions. + guard_body = IndentedBuffer() + for expr in code_parts: + preface, expr = replace(expr) + guard_body.writelines(preface) + guard_body.writeline(f"if not ({expr}):") + with guard_body.indent(): + guard_body.writeline("return False") + + # Wrap the inner body into the actual guard function. + guard = IndentedBuffer() + guard.writeline("def guard(L):") + with guard.indent(): + guard.splice(guard_body) + guard.writeline("return True") + + # Wrap the whole guard function into another function + # with the closure variables. 
+ make_guard_fn = IndentedBuffer() + make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):") + with make_guard_fn.indent(): + make_guard_fn.splice(guard) + make_guard_fn.writeline("return guard") + + return guard_body.getvalue(), make_guard_fn.getvalue() + + +def is_recompiles_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles") + + +def is_recompiles_verbose_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") + + +# this will only be used if cpp guards are disabled +def make_torch_function_mode_stack_guard(initial_stack): + types = [type(x) for x in initial_stack] + + def check_torch_function_mode_stack(): + cur_stack = get_torch_function_mode_stack() + + if len(cur_stack) != len(types): + return False + + for ty, mode in zip(types, cur_stack): + if ty != type(mode): + return False + + return True + + return check_torch_function_mode_stack + + +def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): + global_scope = dict(guard_manager.global_scope) + ids_to_source = collections.defaultdict(list) + for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined] + global_scope["__compile_source__"] = tensor_source + tensor_id = id(eval(tensor_source, global_scope, scope)) + ids_to_source[tensor_id].append(tensor_source) + + duplicate_tensors = [ + f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1 + ] + + reason = ", ".join(duplicate_tensors) + return [f"Duplicate tensors found: {reason}"] + + +def strip_local_scope(s: str) -> str: + """ + Replace occurrences of L[...] with just the inner content. + Handles both single and double quotes. + + This is to generate user friendly recompilation messages. + """ + import re + + pattern = r"L\[\s*['\"](.*?)['\"]\s*\]" + return re.sub(pattern, r"\1", s) + + +def get_guard_fail_reason_helper( + guard_manager: GuardFn, + f_locals: dict[str, object], + compile_id: CompileId, +) -> str: + """ + Return the reason why `guard_manager` failed. + Updates `guard_failures` with the generated reason. + Only the first failed check of guard_manager is reported. + """ + scope = {"L": f_locals, "G": guard_manager.global_scope["G"]} + scope.update(guard_manager.closure_vars) + reasons: list[str] = [] + + no_tensor_aliasing_check_failed = False + + verbose_code_parts: list[str] = [] + guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] + # For test_export_with_map_cond, the check_verbose fail even without the + # C++ guard manager. We need to fix the issue to remove the comment. + # assert not guard_debug_info.result + if not guard_debug_info.result: + verbose_code_parts = guard_debug_info.verbose_code_parts + # verbose_code_parts is either the actual reason (e.g. in case of + # TENSOR_MATCH) or it could be a list of verbose_code_part that we + # passed to the leaf guard at construction time. If its a list, we + # walk through this list and find the guard that failed. This is + # very important for symbolic shape guards which are currently + # installed as a lambda guard and can encompass a long list of code_parts. 
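For readers following the guard-failure reporting path: the `strip_local_scope` helper defined above is what turns raw guard expressions into the user-facing names that appear in recompilation messages. A standalone sketch of the same substitution, with an illustrative `_strip_local_scope` name and a made-up sample expression:

    import re

    def _strip_local_scope(s: str) -> str:
        # Same substitution as strip_local_scope above: L['x'] / L["x"] -> x
        return re.sub(r"L\[\s*['\"](.*?)['\"]\s*\]", r"\1", s)

    # e.g. a failed shape guard rendered for the user:
    assert _strip_local_scope("L['x'].size()[0] == L[\"y\"].size()[0]") == "x.size()[0] == y.size()[0]"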
+ + if len(verbose_code_parts) == 1: + if "Duplicate tensor found" in verbose_code_parts[0]: + no_tensor_aliasing_check_failed = True + else: + reasons = verbose_code_parts + verbose_code_parts = [] + + if no_tensor_aliasing_check_failed: + reasons = recompilation_reason_for_no_tensor_aliasing_guard( + guard_manager, scope + ) + else: + for part in verbose_code_parts: + global_scope = dict(guard_manager.global_scope) + global_scope["__compile_source__"] = part + with report_compile_source_on_error(): + try: + fail_reason = eval(part, global_scope, scope) + except Exception: + if is_recompiles_verbose_enabled(): + continue + else: + raise + # Only ___check_tensors knows how to return a fancy fail reason; + # for everything else we just report the code that failed + + if isinstance(fail_reason, bool) and not fail_reason: + fail_reason = part + if isinstance(fail_reason, str): + reasons.append(fail_reason) + if not is_recompiles_verbose_enabled(): + break + + reason_str = f"{compile_id}: " + "; ".join(reasons) + return strip_local_scope(reason_str) + + +def get_guard_fail_reason( + guard_manager: GuardFn, + code: types.CodeType, + f_locals: dict[str, object], + compile_id: CompileId, +) -> str: + if isinstance(guard_manager, DeletedGuardManagerWrapper): + return f"{compile_id}: {guard_manager.invalidation_reason}" + reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id) + guard_failures[orig_code_map[code]].append(reason_str) + + try: + if guard_manager.guard_fail_fn is not None: + guard_manager.guard_fail_fn( + GuardFail(reason_str or "unknown reason", orig_code_map[code]) + ) + except Exception: + log.exception( + "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", + ) + + return reason_str + + +def get_and_maybe_log_recompilation_reasons( + cache_entry, frame: DynamoFrameType +) -> list[str]: + """ + Return the list of guard failure reasons using cache_entry. + Logs the recompilation reason if `recompiles` logging is enabled. + Raises a RecompileError if `config.error_on_recompile` is enabled. 
+ """ + reasons = [] + while cache_entry is not None: + reason = get_guard_fail_reason( + cache_entry.guard_manager, + cache_entry.code, + frame.f_locals, + cache_entry.compile_id, + ) + if reason: + reasons.append(reason) + cache_entry = cache_entry.next + + code = frame.f_code + + # at least one of "recompiles" or "recompiles_verbose" is enabled + do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled() + + if do_recompiles_log or config.error_on_recompile: + if is_recompiles_verbose_enabled(): + failures = "\n\n".join( + f"guard {i} failures:\n" + textwrap.indent(reason, "- ") + for i, reason in enumerate(reasons) + ) + else: + failures = textwrap.indent("\n".join(reasons), "- ") + guard_failure_details = ( + f"triggered by the following guard failure(s):\n{failures}" + ) + message = ( + f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n" + f"{textwrap.indent(guard_failure_details, ' ')}" + ) + if do_recompiles_log: + if is_recompiles_verbose_enabled(): + recompiles_verbose_log.debug(message) + else: + recompiles_log.debug(message) + if config.error_on_recompile: + raise exc.RecompileError(message) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "recompile_reasons", + "encoding": "json", + }, + payload_fn=lambda: reasons, + ) + + return reasons + + +def update_diff_guard_managers_for_existing_cache_entries(cache_entry): + first_cache_entry = cache_entry + + # On the first pass, go through the cache entries and accumulate the diff + # guard sources. Different guard managers can fail with different sources. + # So, we collect all of them first. + acc_diff_guard_sources = set() + while cache_entry is not None: + acc_diff_guard_sources.update( + cache_entry.guard_manager.collect_diff_guard_sources() + ) + cache_entry = cache_entry.next + + # On the second pass, set the diff_guard_sources for each cache line to the + # accumulated value. And the re-populate the diff guard manager. + cache_entry = first_cache_entry + while cache_entry is not None: + cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources + cache_entry.guard_manager.populate_diff_guard_manager() + cache_entry = cache_entry.next + + # return the accumulated sources to set up the new cache line. + return acc_diff_guard_sources + + +def guard_error_hook( + guard_manager: GuardFn, + code: types.CodeType, + f_locals: dict[str, object], + index: int, + last: bool, +): + print( + f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" + ) + print("lambda " + ", ".join(guard_manager.args) + ":") + print(" ", " and\n ".join(guard_manager.code_parts)) + + print(guard_manager) + + local_scope = {"L": f_locals, **guard_manager.closure_vars} + for guard in guard_manager.code_parts: + try: + eval(guard, guard_manager.global_scope, local_scope) + except: # noqa: B001,E722 + print(f"Malformed guard:\n{guard}") + + +set_guard_error_hook(guard_error_hook) + + +def unique(seq): + seen = set() + for x in seq: + if x not in seen: + yield x + seen.add(x) + + +def make_dupe_guard(obj_source, dupe_source): + # Note - we may end up in a situation where we invoke something like + # def fn(x, y) + # with fn(x, x) + # Prior to the addition of tracking to all relevant objects, we would handle this just fine by + # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. 
However, + # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here - + # In the fn(x, x) example call above look like a graph with a single input. + # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard. + + # Note - we may not have a source, that is fine, it just means we had an object that is safe to have + # leave unsourced - like a local list created and discharged entirely within a local scope. + if dupe_source and dupe_source != obj_source: + ser_source_is_local = is_from_local_source(dupe_source) + source_is_local = is_from_local_source(obj_source) + if is_from_flatten_script_object_source( + dupe_source + ) or is_from_flatten_script_object_source(obj_source): + raise exc.UnsafeScriptObjectError( + f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported." + f" Please do a clone for corresponding input." + ) + + # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently + # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here, + # so maybe we should do this refactor before we land this... + # TODO(voz): Combine local and global guard builders. + if ser_source_is_local == source_is_local: + # Note - this is a little aggressive - these being duplicate input does not always matter. + # However, this should always be a sound guard to add here. + return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source) + return None + + +def install_guard(*guards, skip=0): + """ + Add dynamo guards to the current tracing context. + + Args: + guards: guard(s) to add + skip: number of stack frames to ignore for debug stack trace + """ + from torch._guards import TracingContext + + collect_debug_stack = guards_log.isEnabledFor( + logging.DEBUG + ) or verbose_guards_log.isEnabledFor(logging.DEBUG) + add = TracingContext.get().guards_context.dynamo_guards.add + for guard in guards: + assert isinstance(guard, Guard) + add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/hooks.py b/phivenv/Lib/site-packages/torch/_dynamo/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..e41e71f2d1fc2afd4bc0d2692816184980cd6ca8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/hooks.py @@ -0,0 +1,24 @@ +"""Hook system for Dynamo's guard functionality. + +This module provides a way to register callback functions that are triggered during +guard-related operations. + +The Hooks class manages two types of hook functions: +- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input +- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input +These hooks enable customization of guard export and failure handling behaviors. 
+""" + +import dataclasses +from typing import Callable, Optional + +from torch._guards import GuardsSet + +from .types import GuardFail, GuardFilterEntry + + +@dataclasses.dataclass +class Hooks: + guard_export_fn: Optional[Callable[[GuardsSet], None]] = None + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None + guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None diff --git a/phivenv/Lib/site-packages/torch/_dynamo/logging.py b/phivenv/Lib/site-packages/torch/_dynamo/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..2447772f70fa7b5b350939f154b30033cef63c69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/logging.py @@ -0,0 +1,72 @@ +"""Logging utilities for Dynamo and Inductor. + +This module provides specialized logging functionality including: +- Step-based logging that prepends step numbers to log messages +- Progress bar management for compilation phases +- Centralized logger management for Dynamo and Inductor components + +The logging system helps track the progress of compilation phases and provides structured +logging output for debugging and monitoring. +""" + +import itertools +import logging +from typing import Any, Callable + +from torch.hub import _Faketqdm, tqdm + + +# Disable progress bar by default, not in dynamo config because otherwise get a circular import +disable_progress = True + + +# Return all loggers that torchdynamo/torchinductor is responsible for +def get_loggers() -> list[logging.Logger]: + return [ + logging.getLogger("torch.fx.experimental.symbolic_shapes"), + logging.getLogger("torch._dynamo"), + logging.getLogger("torch._inductor"), + ] + + +# Creates a logging function that logs a message with a step # prepended. +# get_step_logger should be lazily called (i.e. at runtime, not at module-load time) +# so that step numbers are initialized properly. e.g.: + +# @functools.cache +# def _step_logger(): +# return get_step_logger(logging.getLogger(...)) + +# def fn(): +# _step_logger()(logging.INFO, "msg") + +_step_counter = itertools.count(1) + +# Update num_steps if more phases are added: Dynamo, AOT, Backend +# This is very inductor centric +# _inductor.utils.has_triton() gives a circular import error here + +if not disable_progress: + try: + import triton # noqa: F401 + + num_steps = 3 + except ImportError: + num_steps = 2 + pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0) + + +def get_step_logger(logger: logging.Logger) -> Callable[..., None]: + if not disable_progress: + pbar.update(1) + if not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(f"{logger.name}") + + step = next(_step_counter) + + def log(level: int, msg: str, **kwargs: Any) -> None: + if "stacklevel" not in kwargs: + kwargs["stacklevel"] = 2 + logger.log(level, "Step %s: %s", step, msg, **kwargs) + + return log diff --git a/phivenv/Lib/site-packages/torch/_dynamo/metrics_context.py b/phivenv/Lib/site-packages/torch/_dynamo/metrics_context.py new file mode 100644 index 0000000000000000000000000000000000000000..7be00a988f7019b5fb25beb33017bfd64cc54916 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/metrics_context.py @@ -0,0 +1,229 @@ +"""Metrics collection and management system for Dynamo. + +This module provides context managers for gathering and reporting metrics during +compilation and runtime. 
+ +It includes two main components: +- MetricsContext: A context manager for collecting metrics during compilation, supporting + nested contexts and various metric types (counters, sets, key-value pairs) +- RuntimeMetricsContext: A specialized context for runtime metrics collection that doesn't + require explicit context management + +The metrics system enables comprehensive monitoring and analysis of both compilation and +execution performance. +""" + +import heapq +import logging +import time +from collections.abc import Iterator +from typing import Any, Callable, Optional +from typing_extensions import TypeAlias + + +log = logging.getLogger(__name__) + + +class TopN: + """ + Helper to record a list of metrics, keeping only the top N "most expensive" elements. + """ + + def __init__(self, at_most: int = 25): + self.at_most = at_most + self.heap: list[tuple[int, Any]] = [] + + def add(self, key: Any, val: int) -> None: + # Push if we haven't reached the max size, else push and pop the smallest + fn = heapq.heappush if len(self.heap) < self.at_most else heapq.heappushpop + fn(self.heap, (val, key)) + + def __len__(self) -> int: + return len(self.heap) + + def __iter__(self) -> Iterator[tuple[Any, int]]: + return ((key, val) for val, key in sorted(self.heap, reverse=True)) + + +OnExitType: TypeAlias = Callable[ + [int, int, dict[str, Any], Optional[type[BaseException]], Optional[BaseException]], + None, +] + + +class MetricsContext: + def __init__(self, on_exit: OnExitType): + """ + Use this class as a contextmanager to create a context under which to accumulate + a set of metrics, e.g., metrics gathered during a compilation. On exit of the + contextmanager, call the provided 'on_exit' function and pass a dictionary of + all metrics set during the lifetime of the contextmanager. + """ + self._on_exit = on_exit + self._metrics: dict[str, Any] = {} + self._start_time_ns: int = 0 + self._level: int = 0 + + def __enter__(self) -> "MetricsContext": + """ + Initialize metrics recording. + """ + if self._level == 0: + # In case of recursion, track at the outermost context. + self._metrics = {} + self._start_time_ns = time.time_ns() + + self._level += 1 + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + _traceback: Any, + ) -> None: + """ + At exit, call the provided on_exit function. + """ + self._level -= 1 + assert self._level >= 0 + if self._level == 0: + try: + end_time_ns = time.time_ns() + self._on_exit( + self._start_time_ns, end_time_ns, self._metrics, exc_type, exc_value + ) + except Exception: + log.exception("Unexpected exception logging compilation metrics") + + def in_progress(self) -> bool: + """ + True if we've entered the context. + """ + return self._level > 0 + + def increment(self, metric: str, value: int) -> None: + """ + Increment a metric by a given amount. + """ + if self._level == 0: + raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext") + if metric not in self._metrics: + self._metrics[metric] = 0 + self._metrics[metric] += value + + def set(self, metric: str, value: Any, overwrite: bool = False) -> None: + """ + Set a metric to a given value. Raises if the metric has been assigned previously + in the current context. 
+ """ + if self._level == 0: + raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") + if metric in self._metrics and not overwrite: + raise RuntimeError( + f"Metric '{metric}' has already been set in the current context" + ) + self._metrics[metric] = value + + def set_key_value(self, metric: str, key: str, value: Any) -> None: + """ + Treats a give metric as a dictionary and set the k and value within it. + Note that the metric must be a dictionary or not present. + + We allow this to be called multiple times (i.e. for features, it's not uncommon + for them to be used multiple times within a single compilation). + """ + if self._level == 0: + raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") + if metric not in self._metrics: + self._metrics[metric] = {} + self._metrics[metric][key] = value + + def update(self, values: dict[str, Any], overwrite: bool = False) -> None: + """ + Set multiple metrics directly. This method does NOT increment. Raises if any + metric has been assigned previously in the current context and overwrite is + not set to True. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + existing = self._metrics.keys() & values.keys() + if existing and not overwrite: + raise RuntimeError( + f"Metric(s) {existing} have already been set in the current context" + ) + self._metrics.update(values) + + def update_outer(self, values: dict[str, Any]) -> None: + """ + Update, but only when at the outermost context. + """ + if self._level == 0: + raise RuntimeError("Cannot update metrics outside of a MetricsContext") + if self._level == 1: + self.update(values) + + def add_to_set(self, metric: str, value: Any) -> None: + """ + Records a metric as a set() of values. + """ + if self._level == 0: + raise RuntimeError(f"Cannot add {metric} outside of a MetricsContext") + if metric not in self._metrics: + self._metrics[metric] = set() + self._metrics[metric].add(value) + + def add_top_n(self, metric: str, key: Any, val: int) -> None: + """ + Records a metric as a TopN set of values. + """ + if self._level == 0: + return + if metric not in self._metrics: + self._metrics[metric] = TopN() + self._metrics[metric].add(key, val) + + +class RuntimeMetricsContext: + def __init__(self, on_exit: OnExitType): + """ + Similar to MetricsContext, but used to gather the runtime metrics that are + decoupled from compilation, where there's not a natural place to insert a + context manager. + """ + self._on_exit = on_exit + self._metrics: dict[str, Any] = {} + self._start_time_ns: int = 0 + + def increment( + self, metric: str, value: int, extra: Optional[dict[str, Any]] = None + ) -> None: + """ + Increment a metric by a given amount. + """ + if not self._metrics: + # Start timing on the first entry + self._start_time_ns = time.time_ns() + if metric not in self._metrics: + self._metrics[metric] = 0 + self._metrics[metric] += value + + if extra: + for k, v in extra.items(): + if k not in self._metrics and v is not None: + self._metrics[k] = v + + def finish(self) -> None: + """ + Call the on_exit function with the metrics gathered so far and reset. 
+ """ + if self._metrics: + try: + end_time_ns = time.time_ns() + self._on_exit( + self._start_time_ns, end_time_ns, self._metrics, None, None + ) + except Exception: + log.exception("Unexpected exception logging runtime metrics") + finally: + self._metrics = {} diff --git a/phivenv/Lib/site-packages/torch/_dynamo/mutation_guard.py b/phivenv/Lib/site-packages/torch/_dynamo/mutation_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..1853a2ed99f130deb069a182b699f4afa9e6b17e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/mutation_guard.py @@ -0,0 +1,160 @@ +"""Mutation tracking and dynamic module detection system for Dynamo. + +This module provides mechanisms to track and respond to mutations in PyTorch modules +and detect dynamically created or modified modules. + +Key components: +- MutationTracker: Tracks mutations to objects and invalidates associated cached code +- GenerationTracker: Tracks module creation timing to identify dynamic instances +- Patching system for nn.Module to detect mutations and dynamic creation + +The system ensures that Dynamo's optimizations remain valid by detecting and responding +to runtime changes in module state and structure. +""" + +import functools +import weakref +from collections.abc import MutableMapping +from typing import Any + +import torch.nn +from torch.nn import Module + +from . import config +from .utils import ExactWeakKeyDictionary, nn_module_has_global_hooks + + +unpatched_nn_module_init = torch.nn.Module.__init__ + + +class MutationTracker: + db: ExactWeakKeyDictionary = ExactWeakKeyDictionary() + + def __init__(self) -> None: + self.mutation_count: int = 0 + self.watchers: list[weakref.ReferenceType[Any]] = [] + + def on_mutation(self, name: str) -> None: + self.mutation_count += 1 + tmp = self.watchers + self.watchers = [] + for ref in tmp: + guarded = ref() + if guarded is not None: + guarded.invalidate(ref) + + def track(self, guarded_code: Any) -> None: + self.watchers.append(weakref.ref(guarded_code)) + + +def watch(obj: Any, guarded_code: Any) -> None: + """invalidate guarded_code when obj is mutated""" + ensure_patched(type(obj)) + + if obj not in MutationTracker.db: + MutationTracker.db[obj] = MutationTracker() + tracker = MutationTracker.db[obj] + tracker.track(guarded_code) + + +def ensure_patched(cls: Any) -> None: + if getattr(cls, "___needs_mutation_patch", True): + cls.___needs_mutation_patch = False + original_setattr = cls.__setattr__ + + @functools.wraps(original_setattr) + def custom_setattr(self: Any, key: str, value: Any) -> None: + try: + MutationTracker.db[self].on_mutation(key) + except KeyError: + pass + return original_setattr(self, key, value) + + cls.__setattr__ = custom_setattr + + +class GenerationTracker: + generation: int = 0 + dynamic_classes: ExactWeakKeyDictionary = ExactWeakKeyDictionary() + generation_values: ExactWeakKeyDictionary = ExactWeakKeyDictionary() + + @classmethod + def tag(cls, obj: Any) -> None: + cls.generation_values[obj] = cls.generation + + @staticmethod + def mark_class_dynamic(cls: type[torch.nn.Module]) -> None: + assert issubclass(cls, torch.nn.Module) + GenerationTracker.dynamic_classes[cls] = True + + @classmethod + def get_generation_value(cls, obj: Any) -> int: + if obj not in cls.generation_values: + return -1 + return cls.generation_values[obj] + + @classmethod + def check(cls, obj: Any) -> bool: + return ( + obj in cls.generation_values + and cls.generation_values[obj] == cls.generation + ) + + @classmethod + def clear(cls) -> None: + 
cls.generation = 0 + cls.dynamic_classes = ExactWeakKeyDictionary() + cls.generation_values = ExactWeakKeyDictionary() + + +def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool: + """Check for nn.Modules() created dynamically or mutated""" + if isinstance(obj, torch.nn.Module) and ( + "forward" in obj.__dict__ or isinstance(obj, (dict, MutableMapping)) + ): + # A monkey patched `.forward` indicates something wacky is going on + # Similarly a nn module also subclassed as a dict is unusual. + return True + if hasattr(obj, "torchdynamo_force_dynamic"): + return obj.torchdynamo_force_dynamic + if ( + isinstance(obj, torch.nn.Module) + and config.inline_inbuilt_nn_modules + and (not is_export or config.install_free_tensors) + ): + return True + + if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks(): + return True + dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( + obj + ) + return dyn + + +def install_generation_tagging_init() -> None: + """ + Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__ + so we can detect nn.Module instances created dynamically inside forward methods. + """ + + if getattr(Module, "___needs_generation_tag_patch", True): + init = Module.__init__ + + def patched_init(self: Module, *args: Any, **kwargs: Any) -> None: + init(self, *args, **kwargs) + GenerationTracker.tag(self) + + Module.__init__ = patched_init # type: ignore[method-assign] + + setstate = Module.__setstate__ + + def patched_setstate(self: Module, state: Any) -> None: + setstate(self, state) + GenerationTracker.tag(self) + + Module.__setstate__ = patched_setstate # type: ignore[method-assign] + + Module.___needs_generation_tag_patch = False # type: ignore[attr-defined] + + GenerationTracker.generation += 1 diff --git a/phivenv/Lib/site-packages/torch/_dynamo/output_graph.py b/phivenv/Lib/site-packages/torch/_dynamo/output_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..6320ae1a617a02648ae84957f14c9eb95ded8be7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/output_graph.py @@ -0,0 +1,3145 @@ +# mypy: allow-untyped-defs + +""" +Core graph building functionality for PyTorch's Dynamo system. This module contains +the essential components for constructing and managing FX graphs during compilation: + +- OutputGraph: Manages the overall graph construction and compilation process. It owns + a SubgraphTracer and handles graph compilation, execution, and state management. + OutputGraph also manages features like graph deduplication, symbolic shape handling, + and tracking of side effects. + +- SubgraphTracer: Handles the actual FX graph construction by tracing Python code. + It supports advanced features like higher-order operators through nested tracers, + lifting of free variables, and handling of symbolic shapes. 
+ +The module supports key Dynamo features including: +- Higher-order operators through nested SubgraphTracers +- Graph deduplication for optimization +- Symbolic shape handling and propagation +- Side effect tracking and management +- Guard insertion and management +""" + +import collections +import contextlib +import copy +import functools +import inspect +import itertools +import logging +import operator +import re +import sys +import traceback +import weakref +from dataclasses import dataclass, field as dc_field +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch._guards +import torch._logging +import torch.distributed as dist +import torch.nn +import torch.utils._pytree as pytree +from torch import fx, Tensor +from torch._C._dynamo import guards +from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis +from torch._guards import ( + CompileContext, + CompileId, + GlobalContextCheckpointState, + Source, + tracing, + TracingContext, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import signpost_event +from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.symbolic_shapes import ( + free_symbols, + guard_scalar, + is_symbolic, + ShapeEnv, + Specialization, +) +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from . import config, exc, logging as torchdynamo_logging, variables +from .backends.registry import CompiledFn, CompilerFn +from .bytecode_transformation import ( + create_call_function, + create_instruction, + create_load_const, + Instruction, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .current_scope_id import enter_new_scope +from .device_interface import get_interface_for_device +from .exc import ( + BackendCompilerFailed, + exceptions_allowed_to_be_fallback, + SkipFrame, + unimplemented_v2, + unimplemented_v2_with_warning, +) +from .graph_deduplication import apply_graph_deduplication +from .graph_region_tracker import GraphRegionTracker +from .guards import GuardBuilder, install_guard +from .mutation_guard import is_dynamic_nn_module +from .side_effects import AttributeMutationExisting, SideEffects +from .source import ( + AttrSource, + BackwardStateSource, + ConstantSource, + GetItemSource, + GlobalStateSource, + is_constant_source, + is_from_local_source, + LocalSource, + NumpyTensorSource, + ParamBufferSource, + ShapeEnvSource, + SyntheticLocalSource, + TensorProperty, + TensorPropertySource, +) +from .utils import ( + _extract_tensor_dict, + checkpoint_params, + CleanupHook, + clone_inputs, + count_calls, + counters, + dynamo_timed, + get_instruction_source_311, + get_locals_to_steal, + get_static_address_type, + get_unique_name_wrt, + graph_break_reasons, + increment_op_count, + istype, + lazy_format_graph_code, + LazyString, + nn_module_proxy, + same, + set_example_value, +) +from .variables.base import VariableTracker +from .variables.builder import ( + BackwardStateGraphArg, + GraphArg, + TrackedFake, + wrap_fx_proxy, +) +from .variables.ctx_manager import ContextWrappingVariable +from .variables.lists import BaseListVariable +from .variables.misc import CellVariable, 
NullVariable +from .variables.nn_module import NNModuleVariable +from .variables.tensor import ( + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .variables.torch_function import TensorWithTFOverrideVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + + +log = logging.getLogger(__name__) +graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") +graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") + +RootGuardManager = guards.RootGuardManager + + +@dataclass(frozen=True) +class VariableTrackerCacheKey: + vt_id: int + # Two different source can point to the same object. However, Dynamo handles + # globals and local source differently when it comes to guards and possibly + # some other parts as well. So, cache also relies on the source. + source: Source + + +@dataclass(frozen=True) +class AliasingInfo: + has_aliasing: bool + msg: str + + +@dataclass(frozen=True) +class MutationInfo: + has_mutation: bool + msg: str + + +class VariableTrackerCache: + def __init__(self): + self.cache = {} + + def lookup(self, value, source): + key = VariableTrackerCacheKey(id(value), source) + if key not in self.cache: + return None + return self.cache[key] + + def add(self, value, source, vt): + key = VariableTrackerCacheKey(id(value), source) + self.cache[key] = vt + + def clone(self): + # Needed for copy and restore graph state + new_cache = VariableTrackerCache() + new_cache.cache.update(self.cache) + return new_cache + + def clear(self): + self.cache.clear() + + +@functools.cache +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@dataclass +class GraphCompileReason: + """Stores why a given output graph was compiled; i.e. what caused the graph break.""" + + reason: str + user_stack: list[traceback.FrameSummary] + + # Indicates if this was a graph compile reason due to graph break. 
+ graph_break: bool = True + + def __post_init__(self): + if self.graph_break: + graph_break_reasons.append(self) + + +def _get_gen_rand_values_fn(random_calls): + def _gen_rand_values(): + return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] + + return _gen_rand_values + + +class FakeRootModule(torch.nn.Module): + """Trick the constructor of fx.GraphModule""" + + def __init__(self, nn_modules: dict[str, torch.nn.Module]): + super().__init__() + for k, v in nn_modules.items(): + setattr(self, k, v) + + def __repr__(self) -> str: + return "FakeRootModule(...)" + + def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]): + for k, v in nn_modules.items(): + setattr(self, k, v) + + +class WrapperBackend: + def __init__(self, backend: CompilerFn): + self.backend: CompilerFn = backend + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]): + self.restore = checkpoint_params(gm) + self.gm = gm + copy_gm = copy.deepcopy(self.gm) + self.candidate = self.backend(copy_gm, example_inputs) + + if self.candidate is None or self.candidate is self.gm.forward: + return self.gm.forward + + if not config.verify_correctness: + return self.candidate + + # if verify_correctness=True + try: + correct = self.gm.forward(*clone_inputs(example_inputs)) + result = self.candidate(*clone_inputs(example_inputs)) + + # TODO: replace `same` function with the one in testing + if same(correct, result): + return self.candidate + + raise RuntimeError(f"incorrect results of backend {self}") + + except Exception: + log.exception("error in verify_correctness") + raise + finally: + self.restore() + + +Scope = dict[str, object] + + +@dataclass +class OutputGraphGuardsState: + """ + A base class containing fields that are considered "persistent" when we + want to save all the important state for reconstrucing guards in a different + process. Normally we don't need to add states here, but we may have to when + the information is needed to serialize the guards, so the fields here are + supposed to be serializable as a requirement. 
+ """ + + local_scope: Scope + global_scope: Scope + # This records the initial torch function mode stack for guarding + torch_function_mode_stack: list[torch.overrides.TorchFunctionMode] + guard_on_key_order: set[Source] + # Map from graph input's `Source` to sizes / strides metadata + input_source_to_sizes_strides: dict[Source, dict[str, Any]] + dual_level: int + functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter] + current_device: Optional[torch.device] + + export: bool = False + export_constraints: bool = False + + _guards: Optional[torch._guards.GuardsSet] = None + _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None + + @property + def shape_env(self): + raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}") + + @property + def guards(self): + return self._guards + + @property + def aotautograd_guards(self): + return self._aotautograd_guards + + +@dataclass +class StackLocalsMetadata: + """ + Stores metadata for a frame's stack and locals for the purposes of building resume functions + """ + + stack_null_idxes: list[int] = dc_field(default_factory=list) + locals_null_keys: list[str] = dc_field(default_factory=list) + stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list) + stack_ctx_idxes_orig: list[int] = dc_field(default_factory=list) + locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list) + + +class OutputGraph(OutputGraphGuardsState): + """ + Wrapper class to hold outputs of InstructionTranslator. Mainly the + generated fx.Graph. + + OutputGraph is 1:1 with a frame being processed. Each frame is associated + with some root InstructionTranslator. When user code calls a function, + we construct a InliningInstructionTranslator that continues to write into + the root InstructionTranslator's OutputGraph. + """ + + side_effects: SideEffects + + def __init__( + self, + code_options: dict[str, Any], + compiler_fn: Optional[CompilerFn], + root_tx, + export: bool, + export_constraints, + frame_state, + local_scope: Scope, + global_scope: Scope, + f_code, + torch_function_mode_stack, + package, + ): + super().__init__( + local_scope, + global_scope, + torch_function_mode_stack, + guard_on_key_order=set(), + input_source_to_sizes_strides={}, + dual_level=torch.autograd.forward_ad._current_level, + functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(), + current_device=torch.utils._device.CURRENT_DEVICE, + ) + self.tracers = [SubgraphTracer(self, is_export=export)] + # Map from graph input's `Source` to its `VariableTracker` to + # de-duplicate graph inputs by source and reuse the tracker + self.input_source_to_var: dict[Source, VariableTracker] = {} + self.export = export + self.export_constraints = export_constraints + self.frame_state = frame_state + self.cleanup_hooks: list[Callable[[], Any]] = [] + # compile_id is an id number for the current torch.compile + self.compile_id: int = next(_compile_id_counter) + # Set of globals installed via install_global* APIs + self.installed_globals: set[str] = set() + + # TODO: maybe should just pass the entire f_code in here? Not + # sure... + self.co_fields = { + "co_name": f_code.co_name, + "co_filename": f_code.co_filename, + "co_firstlineno": f_code.co_firstlineno, + } + + self.region_tracker = GraphRegionTracker() + + # tracked_fakes says where any tensor that was wrapped to fake came + # from. 
It is similar to GraphArg, in that all GraphArgs will get + # will get added to TrackedFakes, but TrackedFakes also contains + # GraphArgs that got pruned, and things like Tensor attributes which + # aren't explicit graph inputs. Used by shape guard + self.tracked_fakes: list[TrackedFake] = [] + + shape_env = ShapeEnv( + # Reference Cycle! + # Share a reference to the list of TrackedFake. + # + # ShapeEnv needs this in order to be able to reproduce the call + # to produce_guards at an arbitrary time point. That is because + # TrackedFake instances may have its metadata changed throughout + # the program execution. + tracked_fakes=self.tracked_fakes, + allow_scalar_outputs=config.capture_scalar_outputs, + allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, + prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts, + co_fields=self.co_fields, + ) + + # In export mode, we force the shape_env to strictly disallow any constraining + # of the user marked dynamic dims + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_mode = torch._subclasses.FakeTensorMode( + shape_env=shape_env, + # TODO (tmanlaibaatar) Remove this once we always lift params and buffers + allow_non_fake_inputs=True if self.export else False, + export=self.export, + ) + self.tracing_context: TracingContext = TracingContext(fake_mode) + self.tracing_context.traced_code.append(f_code) + self.dynamo_compile_id: Optional[CompileId] = ( + CompileContext.current_compile_id() + ) + self.init_ambient_guards() + + # Map each tensor id to a list of sources. This is necessary because + # tensor ids cannot be recovered from tracked fakes (in general). + # We use this map to interpret (i.e., check for violations of) constraints, + # specifically equality constraints, which have shared tensor ids in them. + # This map should also be generally useful, e.g., for (de)serialization. + self.tracked_fakes_id_to_source: dict[int, list[Source]] = ( + collections.defaultdict(list) + ) + # Stores the full fqn of a param or buffer to the relevant source. + self.param_name_to_source: Optional[dict[str, Source]] = {} + self.side_effects = SideEffects(self) + # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL + # and LOAD_ATTR for same python objects free. + self.variable_tracker_cache = VariableTrackerCache() + self.unique_var_id = itertools.count() + self.code_options: dict[str, Any] = dict(code_options) + self.output_instructions: list[Instruction] = [] + # used to track nodes that are added between calls of copy_graphstate + # and restore_graphstate + self.timestamp = 0 + + # A list of register_finalizer_fns to apply to the output graph module + self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = [] + + # Not checkpointed + self.compiler_fn: Optional[CompilerFn] = compiler_fn + self.root_tx = root_tx + + self.package = package + # Given a source, what are the user stacks of all locations that + # accessed it? + # + # For efficiency, we only populate this: + # - During export, and + # - If the source could potentially lead to a spurious export input + # + # Feel free to populate this more frequently if other use-cases arise, + # but be aware that we have to generate full stacks for each + # recording! 
+ self.source_to_user_stacks: dict[Source, list[traceback.StackSummary]] = {} + + self._current_tx: list[InstructionTranslatorBase] = [] + self.cleanups: list[CleanupHook] = [] + self.should_exit = False + self.unspec_variable_map: dict[str, UnspecializedPythonVariable] = {} + + # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty + self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() + + # Tracks if the output graph has a user defined allowed function in the + # graph. This is used later to determine if we should fallback to eager + # for certain exceptions. THe idea is that if the user has applied + # allow_in_graph, they would like to see the error instead of falling + # back for backend errors. + self.has_user_defined_allowed_in_graph = False + + # Tracks a list of called ops that were not tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.non_compliant_ops: set[torch._ops.OpOverload] = set({}) + + # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag". + # This information is useful for logging. + self.compliant_custom_ops: set[torch._ops.OpOverload] = set({}) + + # We save the global torch state here to be restored in case of graph + # breaks. The relevant issue is seen here + # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086 + # where inlining of a function changes the global state (because of the + # presence of torch.no_grad) and there is a graph break. + self.save_global_state() + + # Tracks the original FQNs of the constant tensors from the original graph, + # i.e. buffers and parameters. + self.dynamo_flat_name_to_original_fqn: dict[str, str] = {} + + # All calls to random() are replaced with a single call to __gen_rand_values + # functions that returns a tuple of random values for each original call. + # random_calls tracks calls to random() and random_values_var stores the name of + # the variable that stores __gen_rand_values results. + self.random_calls: list[ + tuple[Callable[..., object], tuple[object, ...], dict[str, object]] + ] = [] + self.random_values_var: Any = None + + # Bytecode to insert right before we call the graph + self.pregraph_bytecode: list[Instruction] = [] + + # Use to pass values to backward hooks when using compiled autograd + self.backward_state: dict[str, VariableTracker] = {} + self.backward_state_proxy: Optional[torch.fx.Proxy] = None + self.backward_state_var: Optional[str] = None + + self.name_of_builtins_dict_key_in_fglobals: str = ( + self.install_builtins_dict_in_fglobals() + ) + + self.compiler_trace_stack = contextlib.ExitStack() + + # These are the ambient, currently-global saved_tensor_hooks stashed in autograd, + # that are set for the entire duration of the compiled region. + # This is an invariant today because we graph break on the saved_tensor_hook + # context manager inside a compiled region + self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = ( + self.maybe_install_saved_tensors_hooks_subgraphs() + ) + + def mark_bytecode_tracing_start(self): + self.compiler_trace_stack.enter_context( + dynamo_timed( + "bytecode_tracing", + log_pt2_compile_event=True, + ) + ) + + def mark_bytecode_tracing_stop(self): + self.compiler_trace_stack.close() + + def install_builtins_dict_in_fglobals(self): + # f_globals["__builtins__"] can be a dict or a module. This is an + # implementation detail - + # https://docs.python.org/3/library/builtins.html. 
+ + # This makes guarding on any builtin messy because the guard check_fn + # has to check if the __builtins__ is a module or dict, and then access + # by either using getattr or getitem respectively. + + # To solve this problem, we insert a new entry in f_globals which points + # to the builtins __dict__ and then we guard any builtin on this dict. + # To avoid any collision with the pre-existing keys, we use the + # install_global to give us a unique dict key. + + f_builtins = self.global_scope["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + return self.install_global("__builtins_dict__", f_builtins) + + def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): + name = f"{prefix}{len(self.backward_state)}" + assert name not in self.backward_state + self.backward_state[name] = hook + return name, self.get_backward_state_proxy() + + def get_backward_state_proxy(self): + if self.backward_state_proxy is None: + if self.export: + unimplemented_v2( + gb_type="backward_state does not support export", + context="", + explanation="Compiled autograd doesn't work with `torch.export`.", + hints=[], + ) + example_value = BackwardState() + self.backward_state_proxy = self.root_tracer.create_graph_input( + "dynamo_backward_state", + type(example_value), + example_value, + source=BackwardStateSource(), + ) + self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg() + self.backward_state_var = self.new_var() + return self.backward_state_proxy + + # This gets its own helper function so guards DEBUG logs are more informative + def init_ambient_guards(self): + # Register a SHAPE_ENV guard to make sure we setup shape guards + # that show up in ShapeEnv + self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS) + ) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE)) + + self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) + + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) + ) + + ci = torch._C._functorch.peek_interpreter_stack() + if ci is not None: + self.guards.add( + GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH) + ) + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: + self.guards.add( + GlobalStateSource().make_guard( + GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS + ) + ) + + def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return None + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + hooks = get_hooks() + if not are_inline_hooks(hooks): + return None + + # If GraphModule provided by user contains fx.wrap, + # We can only rely on user provided cache hash in this case. + # If user did not provide cache hash - then we always bypass cache. 
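For context, the pack/unpack hook graphs below are wrapped into standalone fx.GraphModule objects before being installed as subgraphs. A minimal, hypothetical sketch of that construction using only public torch.fx APIs (the pack function here is a stand-in, not an actual saved-tensors hook):

import torch
import torch.fx as fx

def pack(x):
    # stand-in for a user "pack" hook
    return x.detach()

pack_graph = fx.symbolic_trace(pack).graph
# an fx.Graph plus an owner (a module or a plain dict of attributes) is all
# that is needed to build a callable GraphModule, mirroring the
# fx.GraphModule(self.nn_modules, pack_gm.graph) calls below
standalone_pack_gm = fx.GraphModule({}, pack_graph)
assert torch.equal(standalone_pack_gm(torch.ones(2)), torch.ones(2))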
+ + pack_gm, unpack_gm = hooks + pack_subgraph_name = self.install_subgraph( + "saved_tensors_hooks_pack", + torch.fx.GraphModule(self.nn_modules, pack_gm.graph), + ) + unpack_subgraph_name = self.install_subgraph( + "saved_tensors_hooks_unpack", + torch.fx.GraphModule(self.nn_modules, unpack_gm.graph), + ) + assert pack_subgraph_name == "saved_tensors_hooks_pack_0" + assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0" + return [pack_subgraph_name, unpack_subgraph_name] + + def dump_guards_state(self): + return OutputGraphGuardsState( + local_scope=self.local_scope, + global_scope=self.global_scope, + torch_function_mode_stack=self.torch_function_mode_stack, + guard_on_key_order=self.guard_on_key_order, + input_source_to_sizes_strides=self.input_source_to_sizes_strides, + dual_level=self.dual_level, + functorch_layers=self.functorch_layers, + current_device=self.current_device, + export=self.export, + export_constraints=self.export_constraints, + _guards=self.guards, + _aotautograd_guards=self.aotautograd_guards, + ) + + def synthetic_graph_input(self, fn, args): + """ + call fn(*args) before the graph runs and turn the result into a fake input. + """ + example_value = fn(*args) + varname = self.new_var() + cg = PyCodegen(self.root_tx) + cg.add_push_null( + lambda: cg.load_import_from( + fn.__module__, + fn.__name__, + ) + ) + cg.foreach(map(variables.ConstantVariable.create, args)) + cg.call_function(len(args), False) + cg.store(varname) + self.pregraph_bytecode.extend(cg.get_instructions()) + source = SyntheticLocalSource(varname) + result = VariableTracker.build(self.root_tx, example_value, source) + # Realize the VT because we will delete the guards on it in the next line. + result = result.realize() + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + + def add_cleanup_hook(self, fn: Callable[[], Any]): + self.cleanup_hooks.append(fn) + + def call_cleanup_hooks(self): + for hook in reversed(self.cleanup_hooks): + hook() + self.cleanup_hooks.clear() + + @property + def root_tracer(self): + return self.tracers[0] + + @property + def current_tracer(self): + return self.tracers[-1] + + def is_root_tracer(self): + # Helper to tell if we are inside the higher order operator tracing. + return len(self.tracers) == 1 + + @property + def graph(self): + return self.current_tracer.graph + + # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. + @graph.setter + def graph(self, value): + self.current_tracer.graph = value + + @property + def input_name_to_proxy(self): + return self.current_tracer.input_name_to_proxy + + @property + def real_value_cache(self): + return self.current_tracer.real_value_cache + + @property + def bound_symbols(self): + return self.current_tracer.bound_symbols + + # If you are here, and you're looking for create_graph_input, + # to avoid ambiguity, please call one of the following: + # - self.current_tracer.create_graph_input + # - self.root_tracer.create_graph_input + # See NOTE [HigherOrderOperator tracing design] for more context. 
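The properties above route graph state to the innermost tracer while the root tracer stays at index 0 of self.tracers. A tiny self-contained sketch of that delegation pattern (ToyTracer and ToyOutput are hypothetical stand-ins, not Dynamo classes):

class ToyTracer:
    def __init__(self, name):
        self.name = name
        self.graph = []          # placeholder for a real fx.Graph

class ToyOutput:
    def __init__(self):
        self.tracers = [ToyTracer("root")]

    @property
    def root_tracer(self):
        return self.tracers[0]

    @property
    def current_tracer(self):
        # the innermost tracer owns the graph currently being built
        return self.tracers[-1]

out = ToyOutput()
out.tracers.append(ToyTracer("nested_hop_body"))
assert out.root_tracer.name == "root"
assert out.current_tracer.name == "nested_hop_body"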
+ + def create_proxy(self, *args, **kwargs): + return self.current_tracer.create_proxy(*args, **kwargs) + + def create_node(self, *args, **kwargs): + return self.current_tracer.create_node(*args, **kwargs) + + def remove_node(self, *args, **kwargs): + return self.current_tracer.remove_node(*args, **kwargs) + + @contextlib.contextmanager + def subtracer(self, source_target, prior_tracer): + new_scope_ctx = enter_new_scope() + try: + if prior_tracer: + # Lineage MUST stay preserved + assert prior_tracer.parent is self.current_tracer + new_scope_ctx.__enter__() + tracer = ( + prior_tracer + if prior_tracer + else SubgraphTracer( + self, + parent=self.current_tracer, + source_target=source_target, + is_export=self.current_tracer.is_export, + ) + ) + self.tracers.append(tracer) + yield tracer + finally: + new_scope_ctx.__exit__(None, None, None) + self.tracers.pop() + + @property + def output(self): + return self + + @property + def fake_mode(self): + return self.tracing_context.fake_mode + + @property + def shape_env(self): + return self.tracing_context.fake_mode.shape_env + + @property + def guards(self) -> torch._guards.GuardsSet: + return self.tracing_context.guards_context.dynamo_guards + + @property + def nn_modules(self) -> dict[str, Any]: + return self.tracing_context.module_context.nn_modules + + @property + def aotautograd_guards(self): + return self.tracing_context.guards_context.aotautograd_guards + + def save_global_state(self, out=None): + """ + Saves to out if it is provided. Else saves to the tracing context's global_state. + """ + global_state = cast( + dict[str, tuple[Callable[..., Any], bool]], + ( + out + if out is not None + else self.tracing_context.global_context.global_state + ), + ) + + global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) + + global_state["autocast_enabled"] = ( + functools.partial(torch.set_autocast_enabled, "cuda"), + torch.is_autocast_enabled("cuda"), + ) + global_state["autocast_cpu_enabled"] = ( + functools.partial(torch.set_autocast_enabled, "cpu"), + torch.is_autocast_enabled("cpu"), + ) + global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment] + functools.partial(torch.set_autocast_dtype, "cuda"), + torch.get_autocast_dtype("cuda"), + ) + global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment] + functools.partial(torch.set_autocast_dtype, "cpu"), + torch.get_autocast_dtype("cpu"), + ) + global_state["autocast_cache_enabled"] = ( + torch.set_autocast_cache_enabled, + torch.is_autocast_cache_enabled(), + ) + + def push_tx(self, tx): + self._current_tx.append(tx) + + def pop_tx(self): + return self._current_tx.pop() + + @property + def current_tx(self): + return self.root_tx if not self._current_tx else self._current_tx[-1] + + def count_calls(self): + return count_calls(self.graph) + + def is_empty_graph(self): + return len(list(self.graph.nodes)) == 0 + + def get_submodule(self, keys): + assert keys + obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules + for k in keys.split("."): + if isinstance(obj, dict): + obj = obj[k] + else: + obj = getattr(obj, k) + return obj + + def new_var(self, name="tmp"): + existing = set(self.code_options["co_varnames"]) + # In common case, this will be O(1) + while True: + var = f"{name}_{next(self.unique_var_id)}" + if var not in existing: + self.code_options["co_varnames"] += (var,) + return var + + def update_co_names(self, name): + """Ensure self.code_options.co_names contains name""" + if name not in self.code_options["co_names"]: + 
self.code_options["co_names"] += (name,) + + @staticmethod + def module_key_name(*names): + # create a new unique name + name = "_".join(map(str, names)) + # Strip the guard lookup L/G access + name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) + # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv + name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) + # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + + if not name or not name[0].isalpha(): + name = "sub" + name + + return name + + def register_static_attr_and_return_proxy( + self, attr_prefix: str, attr_value: Any + ) -> fx.Proxy: + attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules) + # TODO `nn_modules` has been historically overloaded to store a lot more + # than just nn module objects, fix that. + self.nn_modules[attr_name] = attr_value + proxy = self.create_proxy("get_attr", attr_name, (), {}) + set_example_value(proxy.node, attr_value) + return proxy + + def register_attr_or_module( + self, + target: Union[torch.nn.Module, torch.Tensor, Any], + *names, + **options, + ): + if is_dynamic_nn_module(target, self.export): + # Instead of returning UnspecializedNNModuleVariable, call + # VariableTracker.build so that it is tracked for mutation. + return VariableTracker.build(self.current_tx, target, **options) + + options = dict(options) + assert "source" in options + source = options["source"] + assert not isinstance(source, ParamBufferSource) + + if isinstance(target, torch.Tensor): + tracer = self.current_tracer + if not self.is_root_tracer(): + # For higher order ops, we don't want to insert the get_attr in + # innermost graph. Instead, we want to raise the params/buffers + # as inputs to the higher-order graph, and register them as + # get_attrs in the root tracer. + + # Note that Dynamo will still call lift_tracked_freevar_to_input + # when these inputs are encountered for the inner graph. The + # only difference is what happens at the root tracer for + # nn.Parameters vs free inputs. The free inputs are registered + # as placeholders in the root graph, whereas the nn.Parameters + # are registered as get_attr nodes in the root graph. + tracer = self.root_tracer + + def wrap_name(module_key): + assert self.param_name_to_source is not None + self.param_name_to_source[module_key] = source + + # Check if the attr has already been registered. This can happen + # when two different sources point to the same tensor. + if target in self.root_tx.output.side_effects: + return self.root_tx.output.side_effects[target] + + if get_static_address_type(target) == "guarded" and not isinstance( + source, NumpyTensorSource + ): + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + elif not is_constant_source(source): + install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) + + vt = wrap_fx_proxy( + self.root_tx, + tracer.create_proxy("get_attr", module_key, (), {}), + example_value=target, + **options, + ) + + # Track the object so to avoid duplicate registration in case of + # different sources pointing to the same tensor object. 
+ vt = self.root_tx.output.side_effects.track_object_existing(target, vt) + + assert "tensor_dict" not in vt.proxy.node.meta + vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target) + + return vt + + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) + + if source: + install_guard(source.make_guard(GuardBuilder.NN_MODULE)) + + def wrap_name(module_key): + return NNModuleVariable(type(target), module_key, target, **options) + + else: + # This is Dynamo created graph module, e.g., graph module coming + # from higher order ops. NNModuleVariable tracker can't be + # sourceless, so let's return a unspecializedNNModule variable + # tracker. + def wrap_name(module_key): + return variables.UnspecializedNNModuleVariable(target, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + + def wrap_name(module_key): + return SymNodeVariable.create( + self, + self.create_proxy("get_attr", module_key, (), {}), + sym_num=target, + **options, + ) + + # HACKY CODE REGION END + else: + + def wrap_name(module_key): + self.output.update_co_names(module_key) + self.global_scope[module_key] = target + return VariableTracker.build( + self, # type: ignore[arg-type] + target, + ConstantSource(source_name=module_key), + ) + + for k, v in self.nn_modules.items(): + if v is target: + # it already exists + return wrap_name(k) + + name = OutputGraph.module_key_name(*names) + name = get_unique_name_wrt(name, self.nn_modules, self.global_scope) + self.nn_modules[name] = target + if isinstance(target, torch.nn.Module): + + def register_leaf_name(leaf_name): + assert self.param_name_to_source is not None + new_source = ParamBufferSource(source, leaf_name) + new_name = f"{name}.{leaf_name}" + self.param_name_to_source[new_name] = new_source + if isinstance(source, LocalSource): + self.dynamo_flat_name_to_original_fqn[ + OutputGraph.module_key_name(new_source.name()) + ] = leaf_name + + # annoying, but there are cases when we do not have parameters + # see test_nn_moduledict_contains + if hasattr(target, "_parameters"): + for leaf_name, _ in target.named_parameters(): + register_leaf_name(leaf_name) + if hasattr(target, "_buffers"): + for leaf_name, _ in target.named_buffers(): + register_leaf_name(leaf_name) + + return wrap_name(name) + + def handle_aliases_for_stolen_lists(self, tx): + # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive + maybe_gm = self.local_scope.get("self") + stolen_list_names = get_locals_to_steal(maybe_gm) + if not stolen_list_names: + return [], {} + + alias_insts = [] + needs_alias: dict[str, list[VariableTracker]] = {} + + queue = [ + *tx.stack, + *tx.symbolic_locals.values(), + *self.side_effects.store_attr_mutations.keys(), + ] + + while queue: + x = queue.pop() + if isinstance(x, BaseListVariable): + assert isinstance(x.items, list) + queue += x.items + continue + + if not ( + ( + x not in self.side_effects.store_attr_mutations + or isinstance(x.mutation_type, AttributeMutationExisting) + ) + and isinstance(x.source, GetItemSource) + and isinstance(x.source.base, LocalSource) + and x.source.base.local_name in stolen_list_names + ): + continue + + stolen_name = x.source.base.local_name + if 
stolen_name not in needs_alias: + needs_alias[stolen_name] = [] + needs_alias[stolen_name].append(x) + + visited = {} + overridden_sources: dict[Source, Source] = {} + for arg in self.graphargs: + if not ( + isinstance(arg._example, list) + and isinstance(arg.source, LocalSource) + and arg.source.local_name in needs_alias + ): + continue + + # arg is a list that will be cleared by the compiled function + list_name = arg.source.local_name + assert list_name in self.code_options["co_varnames"] + for x in needs_alias[list_name]: + # Skip if already handled. + if x.source in overridden_sources: + continue + + # A small codegen optimization because we might have different + # VariableTrackers that share the same source. + list_idx = x.source.index # type: ignore[attr-defined] + if list_idx not in visited: + alias_name = self.new_var( + f"{list_name}_ref" + ) # self.new_var already adds unique id suffix + + visited[list_idx] = alias_name + # bytecode of `alias_name = list_name[list_idx]` + alias_insts.extend( + [ + create_instruction("LOAD_FAST", argval=list_name), + create_load_const(list_idx), + create_instruction("BINARY_SUBSCR"), + create_instruction("STORE_FAST", argval=alias_name), + ] + ) + + # operate on alias, handled by suffix codegen + old_source = x.source + overridden_sources[old_source] = LocalSource(visited[list_idx]) + + # NOTE: we need `overridden_sources` because (1) we want to codegen for + # these list items to use the new local source, but (2) we want to avoid + # updating `source` in place because that might break invariants in + # other parts of Dynamo like guards. + return alias_insts, overridden_sources + + def _get_stack_values_to_restore(self, tx, stack_pops): + """ + Gets the stack + locals values belonging to tx that need to be restored. + + Also prunes dead tx locals and realizes all VTs in the tx's stack. + + NullVariables in stack/locals will NOT be restored, unless they are the top `stack_pops` + elements of the stack - it is expected that the next instruction to run will pop the top + `stack_pops` elements of the stack, so we should codegen NULLs. 
+ + Returns: + - stack_values: stack and locals values that need to be restored + - restore_vars: names of locals corresponding to the locals part of `stack_values` + - meta: locations of NULLs and ContextWrappingVariables in the stack/locals + (ignores the top `stack_pops` values on the stack) + """ + tx.prune_dead_locals() + + stack_values = [] + meta = StackLocalsMetadata() + + # realize any unrealized tensor VTs in case they + # need to be added to self.nn_modules as attributes + for i, value in enumerate(tx.stack): + variables.LazyVariableTracker.realize_all(value) + # ignore top `stack_pops` values on the stack + if len(tx.stack) - i <= stack_pops: + stack_values.append(value) + continue + if isinstance(value, NullVariable): + meta.stack_null_idxes.append(i) + else: + stack_values.append(value) + if isinstance(value, ContextWrappingVariable): + target_values = ( + () if value.target_values is None else tuple(value.target_values) + ) + # NOTE: track index in stack after NULLs have been removed + meta.stack_ctx_args.append((len(stack_values) - 1, target_values)) + meta.stack_ctx_idxes_orig.append(i) + + # Add all the local vars to the "stack" so restore at the end + restore_vars: list[str] = [] + val_to_names: dict[VariableTracker, list[str]] = {} + # NB: Typically (i.e., for graph compile from RETURN_VALUE), + # symbolic_locals will be empty at this point, as prune_dead_locals + # will clear out all of symbolic_locals because RETURN_VALUE is the + # last instruction and no more locals are used. The fanciness here + # is only needed for partial graphs. + # NOTE: All cell and free variables are represented as CellVariable, + # so checks for NULLs and context managers in the case of codegen'ing resume + # functions will not be performed on them. This is expected behavior. + for k, v in tx.symbolic_locals.items(): + # Note! this explicitly uses .local_name for matching + # Failure to do so will cause spurious registrations in val_to_names. + # This will in turn result in spurious variables showing up in the graph. + # This was very tricky to debug. For an example, dump the graph at call_user_compiler + # while running test_subgraphs.py + if isinstance(v.source, LocalSource) and v.source.local_name == k: + continue # no need to restore initial state + if isinstance(v, CellVariable) and v.local_name == k: + continue # no need to restore initial state + # Do not load variable if it is NULL. + if sys.version_info >= (3, 12): + # Continuation function will load the NULL for v. + if type.__instancecheck__(NullVariable, v): + meta.locals_null_keys.append(k) + continue + else: + # A variable should never be NULL in < 3.12 + assert not type.__instancecheck__(NullVariable, v) + if isinstance(v, ContextWrappingVariable): + target_values = ( + () if v.target_values is None else tuple(v.target_values) + ) + meta.locals_ctx_args.append((k, target_values)) + if v not in val_to_names: + val_to_names[v] = [] + val_to_names[v].append(k) + for v in val_to_names.keys(): + restore_vars.extend(val_to_names[v]) + stack_values.extend([v] * len(val_to_names[v])) + + return stack_values, restore_vars, meta + + def compile_subgraph( + self, + tx: "InstructionTranslatorBase", + reason: GraphCompileReason, + partial_convert=False, + stack_pops=0, + ): + """ + Compiles the current subgraph, with inputs w.r.t. 
self.root_tx, and codegens: + - Call the compiled subgraph + - Apply side effects + - Codegen stack and locals + - Store the locals + + Python does not allow NULL to be an arg to a function, so we do not codegen NULLs on the stack, + unless the value is one of the top `stack_pops` values on the stack (these values are expected to be + popped immediately after this generated code. The prologue of the resume function is expected to restore + any dropped NULLs. + + Returns stack indices and locals keys where we dropped NULLs, and where we found inactive context manager objects. + """ + + assert self.root_tx is not None + + # FIXME temporary assert to make sure we're not accidentally compiling nested graph breaks + # before we're done the full implementation + assert self.root_tx is tx + + # bytecode tracing has finished. Pop the context manager for dynamo_timed + self.mark_bytecode_tracing_stop() + + self.partial_convert = partial_convert + self.compile_subgraph_reason = reason + self.should_exit = True + + log.debug("COMPILING GRAPH due to %s", reason) + + # prefix instructions (Python 3.11+) + prefix_insts: list[Instruction] = [] + if sys.version_info >= (3, 11): + for inst in tx.prefix_insts: + if inst.opname == "MAKE_CELL": + prefix_insts.append( + create_instruction("MAKE_CELL", argval=inst.argval) + ) + elif inst.opname == "COPY_FREE_VARS": + prefix_insts.append( + create_instruction( + "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) + ) + ) + else: + prefix_insts.append(copy.copy(inst)) + self.add_output_instructions(prefix_insts) + + assert not (self.pregraph_bytecode and self.export), ( + "export does not support pregraph_bytecode" + ) + self.add_output_instructions(self.pregraph_bytecode) + + alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists( + self.root_tx + ) + self.add_output_instructions(alias_insts) + + # Exit from all context manager variables to make sure global state is restored + for block in reversed(self.root_tx.block_stack): + block.exit(self.root_tx, is_graph_break=reason.graph_break) + + self.cleanup_graph() + + # stack values and restore vars for each frame are pushed in reverse order + # i.e. last element corresponds to root frame, first element corresponds to current frame + all_stack_values = [] + all_restore_vars = [] + all_stack_locals_metas = [] + cur_tx: Optional[InstructionTranslatorBase] = tx + while True: + assert cur_tx is not None + # this should have been checked by the caller + assert all(block.can_restore() for block in cur_tx.block_stack) + stack_values, restore_vars, meta = self._get_stack_values_to_restore( + cur_tx, stack_pops + ) + all_stack_values.append(stack_values) + all_restore_vars.append(restore_vars) + all_stack_locals_metas.append(meta) + if cur_tx is self.root_tx: + break + cur_tx = tx.parent + + # Use nn.Module "proxies" in the constructed GraphModule so that + # the resulting GM does not hold additional strong references to the original modules. + # This prevents a strong ref cycle where Dynamo created code holds on to references + # to modules that also have Dynamo code cache invalidation checks. + # When cache invalidation runs, the generated GM will be invalidated, which also deletes + # the proxies. 
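The comment above is about breaking strong reference cycles between the generated artifacts and the original modules. The sketch below only illustrates the general weak-reference idea with the standard library; it does not claim that nn_module_proxy is implemented this way:

import gc
import weakref

class Holder:                      # stand-in for an nn.Module
    def forward(self):
        return 1

mod = Holder()
proxy = weakref.proxy(mod)         # holding the proxy does not keep mod alive
assert proxy.forward() == 1
del mod
gc.collect()
try:
    proxy.forward()
except ReferenceError:
    # the proxied object was collected; the proxy holder did not extend its lifetime
    pass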
+ nn_modules_proxies = { + name: nn_module_proxy(mod) for name, mod in self.nn_modules.items() + } + root = FakeRootModule(nn_modules_proxies) + + from .decorators import disable + + # to handle random calls + if len(self.random_calls) > 0: + random_calls_instructions = [] + self.random_values_var = self.new_var("random_values") + rand_fn = disable( + _get_gen_rand_values_fn(self.random_calls), + reason="do not trace into Dynamo rng recovery function", + ) + rand_fn_name = self.install_global("__gen_rand_values", rand_fn) + codegen = PyCodegen( + self.root_tx, root, overridden_sources=overridden_sources + ) + random_calls_instructions.extend( + codegen.load_function_name(rand_fn_name, True) + ) + random_calls_instructions.extend(create_call_function(0, False)) + random_calls_instructions.append( + codegen.create_store(self.random_values_var), + ) + self.add_output_instructions(random_calls_instructions) + + # call compiled fx graph + graph_output_var = None + stored_graph_output_var = False + root_stack_values = all_stack_values[-1] + if ( + self.root_tx is tx + and root_stack_values + and all( + not isinstance( + v, + ( + UnspecializedPythonVariable, + NumpyNdarrayVariable, + TensorWithTFOverrideVariable, + ), + ) + and not (isinstance(v, SymNodeVariable) and v.python_type() is float) + for v in root_stack_values + ) + and all(isinstance(x, TensorVariable) for x in root_stack_values) + and len(set(root_stack_values)) == len(root_stack_values) + and self.side_effects.is_empty() + and not tx.debug_locals + and not self.backward_state + and not all_stack_locals_metas[-1].stack_null_idxes + and not all_stack_locals_metas[-1].locals_null_keys + ): + # optimization to generate better code in a common case + self.add_output_instructions( + self.compile_and_call_fx_graph( + tx, list(reversed(root_stack_values)), root + ) + + [create_instruction("UNPACK_SEQUENCE", arg=len(root_stack_values))] + ) + else: + graph_output_var = self.new_var("graph_out") + # load stack values in a flat manner for now - will likely change later. + stack_values_flat = [ + val for vals in reversed(all_stack_values) for val in vals + ] + pass1 = PyCodegen( + self.root_tx, + root, + graph_output_var, + overridden_sources=overridden_sources, + ) + self.codegen_suffix(tx, stack_values_flat, pass1) + + # Use `pass1.uses` to selectively cache multi-user variables into a + # temporary local source. This (a). speeds up loading VTs with long + # chained source, and (b). avoids redundantly saving single-user VT + # into a temporary local. 
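A minimal sketch of the two-pass bookkeeping described above (the source strings are hypothetical): the first pass only counts how often each source is loaded, and only multi-use sources are promoted to temporaries for the second pass:

from collections import Counter

# loads that a first codegen pass might observe
observed_loads = ["L['x'].a.b", "L['y']", "L['x'].a.b"]
uses = Counter(observed_loads)

# only sources loaded more than once are worth caching in a temp local
tempvars = {src for src, count in uses.items() if count > 1}
assert tempvars == {"L['x'].a.b"}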
+ tempvars = {} # type: ignore[var-annotated] + for val, count in pass1.uses.items(): + # If it's already a local source, no need to cache it + if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)): + tempvars[val] = None + pass2 = PyCodegen( + self.root_tx, + root, + graph_output_var, + tempvars=tempvars, + overridden_sources=overridden_sources, + ) + self.codegen_suffix(tx, stack_values_flat, pass2) + + output = [] + if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: + output.extend( + self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) + ) + + if len(pass2.graph_outputs) != 0: + output.append(pass2.create_store(graph_output_var)) + stored_graph_output_var = True + else: + output.append(create_instruction("POP_TOP")) + else: + # NB: Important to run compiler collective even when there is + # a graph break + self.run_compiler_collective() + self.add_output_instructions(output + pass2.get_instructions()) + + # restore all the live local vars of the root + local_restore_cg = PyCodegen( + self.root_tx, overridden_sources=overridden_sources + ) + # TODO this local restoration should be removed when fully implementing nested graph breaks + self.add_output_instructions( + [ + local_restore_cg.create_store(var) + for var in reversed(all_restore_vars[-1]) + ] + ) + + if graph_output_var and stored_graph_output_var: + self.add_output_instructions( + [local_restore_cg.create_delete(graph_output_var)] + ) + + return all_stack_locals_metas + + def codegen_suffix(self, tx, stack_values, cg): + # NOTE: `codegen_save_tempvars` must run first to update `source` fields + # for variables with `AttributeMutationNew`, as they don't implement + # `reconstruct` themselves. + self.side_effects.codegen_save_tempvars(cg) + if self.backward_state: + assert not self.export + for name, val in self.backward_state.items(): + cg(val) + cg.append_output(cg.create_load(self.backward_state_var)) + cg.store_attr(name) + self.side_effects.codegen_hooks(cg) + + # Return variables used for logging at the end + for debug_var, args in tx.debug_locals: + cg.add_push_null(lambda: cg(debug_var)) + for arg in args: + cg(arg) + cg.extend_output(create_call_function(len(args), False)) + cg.extend_output([create_instruction("POP_TOP")]) + + cg.restore_stack(stack_values, value_from_source=not tx.export) + self.side_effects.codegen_update_mutated(cg) + + def cleanup_graph(self): + """ + Remove "creation_timestamp" from node meta + + Remove this pattern from the graph: + torch._C._set_grad_enabled(False) + torch._C._set_grad_enabled(True) + """ + assert self.should_exit + nodes = list(self.graph.nodes) + for node in nodes: + node.meta.pop("creation_timestamp", None) + + grad_enabled = torch.is_grad_enabled() + for node1, node2 in zip(nodes, nodes[1:]): + if ( + node1.target is torch._C._set_grad_enabled + and tuple(node1.args) == (not grad_enabled,) + and not node1._erased + ): + grad_enabled = node1.args[0] + if ( + node2.target is torch._C._set_grad_enabled + and tuple(node2.args) == (not grad_enabled,) + and not node2._erased + ): + grad_enabled = node2.args[0] + self.graph.erase_node(node1) + self.graph.erase_node(node2) + + def get_graph_sizes_structured(self): + ret = {} + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] + return ret + + def get_graph_sizes(self, name: str): + 
graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" + graph_sizes_str += f"===== {name} =====\n" + for node in self.graph.nodes: + example_value = node.meta.get("example_value", None) + if isinstance(example_value, torch._subclasses.FakeTensor): + size = example_value.size() + graph_sizes_str += f"{node.name}: {tuple(size)}\n" + concrete_size = [] + has_symint = False + for sz in size: + if isinstance(sz, int): + concrete_size.append(sz) + elif isinstance(sz, torch.SymInt): + has_symint = True + concrete_size.append(sz.node.hint) + else: + break + else: + if has_symint: + graph_sizes_str += ( + f"{node.name} (concrete): {tuple(concrete_size)}\n" + ) + return graph_sizes_str + + @contextlib.contextmanager + def restore_global_state(self): + """ + Momentarily restores the global state to what it was prior to tracing the current output + """ + prior_global_state = self.tracing_context.global_context.copy_graphstate() + current_global_state: dict[str, tuple[Any, bool]] = {} + self.save_global_state(out=current_global_state) + try: + # Set to state prior to tracing the graph + self.tracing_context.global_context.restore_graphstate(prior_global_state) + yield + finally: + # Reset to state at the current time (e.g. before calling the user compiler) + self.tracing_context.global_context.restore_graphstate( + GlobalContextCheckpointState(current_global_state) + ) + + def run_compiler_collective(self): + tx = self.root_tx + assert tx is not None + if (ds := tx.distributed_state) is not None and ds.all_states is None: + compile_pg = ds.compile_pg + log.info("compiler_collective %s", ds.local_state) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "compiler_collective", + "encoding": "string", + }, + payload_fn=lambda: ds.local_state.render(), + ) + device_types = compile_pg._device_types + assert len(device_types) == 1, ( + "Expect only one device type but got {}".format("+".join(device_types)) + ) + with ( + get_interface_for_device(device_types.pop()).device( # type: ignore[attr-defined] + compile_pg.rank() % torch.accelerator.device_count() + ), + dynamo_timed("compiler_collective", log_pt2_compile_event=True), + ): + all_states = [None] * compile_pg.size() + dist.all_gather_object(all_states, ds.local_state, group=compile_pg) + ds.all_states = all_states + # Clear speculation log, because are tracing may diverge due to + # this information from the compiler collective + tx.speculation_log.clear() + raise exc.CompileCollectiveRestartAnalysis + + def compile_and_call_fx_graph(self, tx, rv, root): + """ + Generate code from self.graph and return the Instruction()s to + call that generated code. + + Code is generated w.r.t. self.root_tx. + tx is only used for preserving GraphModule metadata + """ + with torch._guards.TracingContext.clear_frame(): + from .decorators import disable + + assert self.should_exit + + self.run_compiler_collective() + + name = unique_id("__compiled_fn", with_uuid=True) + + assert isinstance(rv, list) + assert isinstance(root, FakeRootModule) + + output_node = self.create_node( + "output", + "output", + (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), + {}, + ) + sub_gms = self.dedup_pass() + root.add_nn_modules(sub_gms) + + self.current_tracer._maybe_preserve_original_meta(tx, output_node) + if not config.do_not_emit_runtime_asserts: + # There is a rare scenario where codegen_suffix adds a new entry + # to self.nn_modules while `root` knows only about the + # nn_modules at the time of its creation. 
This causes failures + # while creating the graph module because self.graph and root + # are out of sync. This only happens for `get_attr` nodes, so + # here we clean up the get_attr nodes that are unused. + self.remove_unused_get_attr_nodes() + insert_deferred_runtime_asserts( + fx.GraphModule(root, self.graph), + self.shape_env, + name, + export=self.export, + ) + # NB: deferred runtime asserts can keep graphargs live, so make sure + # those are inserted before pruning + self.remove_unused_graphargs() + ncalls = count_calls(self.graph) + counters["stats"]["calls_captured"] += ncalls + + self.remove_tensorify_specialized_graphargs() + + # free a bit of memory + self.real_value_cache.clear() + + gm = _make_graph_module(root, self.graph) + + # Saved tensors hooks are not used by the graph. + # GraphModule by default only copies used in the graph submodules. + # Copying them into the result graph manually. + if self.saved_tensors_hooks_subgraph_names: + for subgraph_name in self.saved_tensors_hooks_subgraph_names: + setattr(gm, subgraph_name, getattr(root, subgraph_name)) + + for register_finalizer in self.register_finalizer_fns: + register_finalizer(gm) + + gm._backend_id = name + gm.compile_subgraph_reason = self.compile_subgraph_reason + gm.meta["dynamo_flat_name_to_original_fqn"] = ( + self.dynamo_flat_name_to_original_fqn.copy() + ) + gm.meta["dynamo_compile_id"] = self.dynamo_compile_id + + graph_code_log.debug( + "%s", + lazy_format_graph_code( + name, gm, include_stride=True, include_device=True, colored=True + ), + ) + torch._logging.trace_structured( + "dynamo_output_graph", + lambda: {"sizes": self.get_graph_sizes_structured()}, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + self.call_cleanup_hooks() + old_fake_mode = self.tracing_context.fake_mode + if not self.export: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + backend_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=old_fake_mode.shape_env, + ) + # TODO(voz): Ostensibily, this should be scoped and + # restore back to old_fake_mode, but doing so currently violates + # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode + self.tracing_context.fake_mode = backend_fake_mode + + with self.restore_global_state(): + compiled_fn = self.call_user_compiler(gm, self.example_inputs()) + + from torch.fx._lazy_graph_module import _LazyGraphModule + + if isinstance(compiled_fn, _LazyGraphModule) or ( + isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule) + and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined] + ): + # Since dynamo will run the forward method for the GraphModule shortly + # anyways, it does not hurt to do the real recompilation here if + # this is a _LazyGraphModule. This makes it easier for dynamo to + # optimize a _LazyGraphModule. 
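The branch below recovers the owning module from a bound method through __self__. A small standalone sketch of that Python-level check (Owner is a hypothetical class, not _LazyGraphModule):

class Owner:
    def forward(self):
        return "ok"

fn = Owner().forward                   # a bound method
owner = getattr(fn, "__self__", None)  # the instance the method is bound to
assert isinstance(owner, Owner)
assert fn.__name__ == "forward"
assert owner.forward() == "ok"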
+ + lazy_gm = ( + compiled_fn + if isinstance(compiled_fn, _LazyGraphModule) + else compiled_fn.__self__ # type: ignore[attr-defined] + ) + + _LazyGraphModule.force_recompile(lazy_gm) + + if not isinstance(compiled_fn, _LazyGraphModule): + # replace compiled_fn with the real forward method + compiled_fn = lazy_gm.forward + + if self.package is not None: + self.package.add_backend_id(name, compiled_fn) + + compiled_fn = disable( + compiled_fn, reason="do not trace Dynamo-compiled graph" + ) + + counters["stats"]["unique_graphs"] += 1 + if specializations := old_fake_mode.shape_env.specializations: + specialization_guards = [] + specialization_cache: dict[Specialization, Callable[[Any], Any]] = {} + sources = [a.source for a in self.graphargs] + for specialization in specializations: + source_index = sources.index(specialization.source) + check_fn_source = inspect.getsource(specialization.check_fn).strip() + check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined] + specialization.check_fn, + [check_fn_source], + ) + + log.debug( + "Compiling backend specialized graph with specialization=%s", + check_fn_source, + ) + + specialization_guards.append( + ( + functools.partial( + lambda idx, args, check_fn=check_fn: check_fn( + args[idx] + ), + source_index, + ), + specialization, + ) + ) + + @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") + def specialized_dispatch(*args, **kwargs): + for check_fn, specialization in specialization_guards: + if check_fn(args): + if specialization in specialization_cache: + return specialization_cache[specialization]( + *args, **kwargs + ) + + with self.shape_env.patch_source_specialization( + specialization.source, specialization.check_fn + ): + # Modify gm so AOTAutogradCache key changes per specialization + gm.meta["specialization"] = specialization + example_inputs: list[Tensor] = list(args) + with tracing(self.tracing_context): + specialization_cache[specialization] = ( + self.call_user_compiler(gm, example_inputs) + ) + + return specialization_cache[specialization](*args, **kwargs) + return compiled_fn(*args, **kwargs) + + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, specialized_dispatch) + else: + # This is safe because we pre-process name to be unique + self.install_global_unsafe(name, compiled_fn) + + assert self.root_tx is not None + cg = PyCodegen(self.root_tx) + cg.make_call_generated_code(name) + return cg.get_instructions() + + @property + def placeholders(self) -> list[fx.Node]: + return self.graph.find_nodes(op="placeholder") + + @property + def graphargs(self) -> list[GraphArg]: + return [node.meta["grapharg"] for node in self.placeholders] + + def call_user_compiler( + self, gm: fx.GraphModule, example_inputs: list[Tensor] + ) -> CompiledFn: + with dynamo_timed( + "OutputGraph.call_user_compiler", + phase_name="backend_compile", + log_pt2_compile_event=True, + log_waitcounter=True, + waitcounter_name_override="compile_aot_autograd", + dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", + ): + return self._call_user_compiler(gm, example_inputs) + + def _call_user_compiler( + self, gm: fx.GraphModule, example_inputs: list[Tensor] + ) -> CompiledFn: + assert self.compiler_fn is not None + tot = 0 + placeholders = [] + for node in gm.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + tot += 1 + if node.op == "placeholder": + placeholders.append(node) + increment_op_count(tot) + for pl in placeholders: + if not hasattr(pl, 
"_dynamo_source"): + arg = pl.meta["grapharg"] + # TODO: Why isn't this stored in meta :think: + # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640 + pl._dynamo_source = arg.source + + # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640 + gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] + gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] + + name = ( + self.compiler_fn.__name__ + if hasattr(self.compiler_fn, "__name__") + else "" + ) + try: + _step_logger()(logging.INFO, f"calling compiler function {name}") + compiler_fn = self.compiler_fn + if config.verify_correctness: + compiler_fn = WrapperBackend(compiler_fn) + compiled_fn = compiler_fn(gm, example_inputs) + _step_logger()(logging.INFO, f"done compiler function {name}") + assert callable(compiled_fn), "compiler_fn did not return callable" + except (TensorifyScalarRestartAnalysis, ShortenTraceback): + raise + except exceptions_allowed_to_be_fallback as e: + if self.has_user_defined_allowed_in_graph: + raise BackendCompilerFailed( + self.compiler_fn, e, inspect.currentframe() + ).with_traceback(e.__traceback__) from None + unimplemented_v2_with_warning( + e, + self.root_tx.f_code, + gb_type="Backend compiler exception", + context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}", + explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.", + hints=[ + "Report an issue to the backend compiler repo.", + ], + ) + except SkipFrame as e: + # The backend compiler has requested that we skip the frame, instead of + # aborting execution. + raise e + except Exception as e: + raise BackendCompilerFailed( + self.compiler_fn, e, inspect.currentframe() + ).with_traceback(e.__traceback__) from None + + signpost_event( + "dynamo", + "OutputGraph.call_user_compiler", + { + **self.co_fields, + "op_count": tot, + "node_count": len(gm.graph.nodes), + "input_count": len(placeholders), + }, + ) + + return compiled_fn + + def dedup_pass(self): + if torch._dynamo.config.use_graph_deduplication: + return apply_graph_deduplication(self) + else: + return {} + + def install_subgraph(self, name, sub_gm): + next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) + sub_gm.__name__ = next_name + sub_gm.torchdynamo_force_dynamic = False + # This graph module is not present in the user space, so it can't be + # accessed by a source. Set source=None. + self.register_attr_or_module(sub_gm, next_name, source=None) + return next_name + + def example_inputs(self) -> list[torch.Tensor]: + result = [arg.example for arg in self.graphargs] + return result + + def remove_unused_get_attr_nodes(self) -> None: + for node in sorted(self.graph.find_nodes(op="get_attr"), reverse=True): + if len(list(node.users)) == 0: + self.remove_node(node) + + def remove_unused_graphargs(self) -> None: + # NB: It's OK to drop GraphArg for symbols that ended up being + # specialized iff they are not used in runtime assertions. You don't + # even have to make a guard for it, because ShapeEnv produce_guards + # operates on tracked_fakes, which never gets pruned. + # That being said, you'll get marginally better generated + # guard code if you promote the guard into a Dynamo guard (since that + # allows for the guard to be done using C++ guards.) If we get + # ShapeEnv guards to go into C++ guards, this will stop being a thing + # though! 
+ + assert self.should_exit + + # Miniature DCE pass, but only for obviously trivial operations + def is_static_true(b_node: fx.node.Argument): + if b_node is True: + return True + if not isinstance(b_node, fx.Node): + return False + b = b_node.meta.get("example_value") + if b is None: + return False + if b is True: + return True + if ( + isinstance(b, torch.SymBool) + and (r := b.node.maybe_as_bool()) is not None + ): + return r + # TODO: We can also technically remove all cases when the input + # doesn't have unbacked inputs, since it's all in the ShapeEnv + return False + + def is_symnode_arg(a: fx.node.Argument): + from torch.fx.experimental.sym_node import SymTypes + + if isinstance(a, (int, float, bool)): + return True + if isinstance(a, fx.Node): + return isinstance(a.meta.get("example_value"), SymTypes) + return False + + # NB: We assume that you cannot do mutations on int/float/bool, + # because they are immutable types, and therefore is always safe to + # DCE. + def is_symnode_compute_node(node): + from torch.fx.experimental.sym_node import SymTypes + + if node.op != "call_function": + return False + # TODO: I don't think it's possible to have a bare int/float here? + if not isinstance(node.meta.get("example_value"), SymTypes): + return False + # TODO: This will bail here if you ever end up with a more complicated + # computation function, like sum(list_of_ints), even though it + # should be DCE'able + if not all(is_symnode_arg(a) for a in node.args): + return False + if not all(is_symnode_arg(a) for a in node.kwargs.values()): + return False + return True + + from torch.fx.experimental.symbolic_shapes import is_accessor_node + + for node in reversed(list(self.graph.nodes)): + if len(list(node.users)) == 0: + if ( + node.op == "get_attr" + or (node.op == "call_function" and node.target is operator.getitem) + or ( + node.op == "call_function" + and node.target is torch._check + and is_static_true(node.args[0]) + ) + or is_symnode_compute_node(node) + or is_accessor_node(node) + ): + self.remove_node(node) + + def placeholder_binds_symbol(node): + arg = node.meta["grapharg"] + example = arg.example + if isinstance(example, torch.SymInt) and isinstance( + example.node.expr, sympy.Symbol + ): + return example.node.expr + return None + + def remove_unused(node): + log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) + # I'm not really sure why you need to delete these from the + # node since the node is going to get removed + del node.meta["grapharg"] + self.remove_node(node) + self.real_value_cache.pop(node, None) + + used_symbols: set[sympy.Symbol] = set() + + def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): + used_symbols |= free_symbols(fake) + + recheck_placeholders = [] + for node in self.placeholders: + binds_symbol = placeholder_binds_symbol(node) is not None + # Don't delete symbol bindings yet + if binds_symbol: + if not node.users: + recheck_placeholders.append(node) + else: + if not node.users and not isinstance( + node.meta["grapharg"], BackwardStateGraphArg + ): + remove_unused(node) + else: + # Register the free symbols as uses + arg = node.meta["grapharg"] + if isinstance(arg, BackwardStateGraphArg): + continue + if isinstance(node.meta["grapharg"].example, torch.ScriptObject): + real_script_obj = node.meta["grapharg"].example + fake_script_obj = node.meta["grapharg"].example_strong_ref + if not torch._library.fake_class_registry.tracing_with_real( + real_script_obj + ): + flat_dict = 
dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] + for attr in flat_dict.keys(): + fake_attr_val = getattr( + fake_script_obj.wrapped_obj, attr + ) + pytree.tree_map_only( + (torch.SymInt, torch.Tensor), + lambda t: update_used_symbols(used_symbols, t), + fake_attr_val, + ) + continue + fake = ( + arg.fake_tensor if arg.fake_tensor is not None else arg.example + ) + update_used_symbols(used_symbols, fake) + + # After removing unused graphargs, prune unused binds_symbol + for node in recheck_placeholders: + symbol = placeholder_binds_symbol(node) + if symbol is not None: + if symbol not in used_symbols: + remove_unused(node) + else: + # Make sure we delete later occurrences of the same symbol + used_symbols.remove(symbol) + + def remove_tensorify_specialized_graphargs(self) -> None: + # This is a pretty interesting function. Basically we have this problem + # where our compiler tends to choke when we have unused inputs. The way + # we support dynamic float arguments is by doing a joint fx pass and + # tensorifying away as many symfloats as we can. For the remaining symfloats + # we have no choice but to specialize... HOWEVER at that point in time + # we can no longer remove graph inputs. So our sledgehammer solution is to + # save the state of what inputs we should have specialized in dynamo and + # restart analysis. This function incorporates this "view from the future" + # state and specializes inputs that we know we won't be able to tensorify + # away in the joint pass. In principle we shouldn't choke on unused inputs + # and so this shouldn't be necessary. In practice CUDA graphs choke on + # unused inputs so we need this for now. + + # Import here to prevent circular import + from torch._dynamo.symbolic_convert import TensorifyState + + for node in self.graph.nodes: + example_value = node.meta.get("example_value") + if ( + isinstance(example_value, FakeTensor) + and example_value.item_memo is not None + and hasattr(example_value.item_memo.node._expr, "name") + and all(u.target == "item" for u in node.users) + and TensorifyState.should_specialize( + # We use _expr instead of expr b/c we want the symbol not the replacement + example_value.item_memo.node._expr.name + ) + ): + for u in list(node.users): + u.replace_all_uses_with(guard_scalar(example_value.item_memo)) + self.remove_node(u) + self.remove_node(node) + + def add_output_instructions(self, prefix: list[Instruction]) -> None: + """ + We call this on the creation of a new compiled subgraph that is inserted + before user code. + """ + self.output_instructions.extend(prefix) + self.should_exit = True + + def install_global_unsafe(self, name, value) -> None: + """ + WARNING: prefer the safer `install_global_by_id/install_global`. + torch.compile instances should be independent of each other; + one footgun is to have one instance depend on the existence of + a global installed by another instance. This can happen if we mangle + a global the same way across both instances. + """ + assert name not in self.installed_globals + self.installed_globals.add(name) + self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) + + def install_global_by_id(self, prefix, value) -> str: + """ + Installs a global if it hasn't been installed already. + This is determined by (prefix, id(value)) pair. + + Returns the name of the newly installed global. 
+ """ + # NB: need self.compile_id to distinguish this global + # from another global created in a different torch.compile instance + name = f"{prefix}_{id(value)}_c{self.compile_id}" + if name in self.installed_globals: + return name + self.install_global_unsafe(name, value) + return name + + def install_global(self, prefix, value) -> str: + """ + Installs a global, generating a unique name for it. + + Returns the name of the newly installed global. + """ + # NB: unique_id is unique, even across torch.compile instances + name = unique_id(prefix) + self.install_global_unsafe(name, value) + return name + + def cleanup(self) -> None: + # There is a reference cycle between tracer and OutputGraph, causing + # some of the tensor objects to be held alive for longer than necessary. + self.root_tx = None + self.nn_modules.clear() + self.param_name_to_source = None + + for node in self.graph.nodes: + if "grapharg" in node.meta: + del node.meta["grapharg"] + self.real_value_cache.clear() + self.input_name_to_proxy.clear() + self.side_effects.clear() + self.variable_tracker_cache.clear() + self.register_finalizer_fns.clear() + self.dynamo_flat_name_to_original_fqn.clear() + self.tracing_context.clear() + self.input_source_to_var.clear() + self.unspec_variable_map.clear() + self.backward_state.clear() + + def add_graph_finalizer( + self, register_finalizer: Callable[[fx.GraphModule], None] + ) -> None: + self.register_finalizer_fns.append(register_finalizer) + + def example_value_from_input_node(self, node: torch.fx.Node): + """Extract the non-fake example tensor""" + if node.op == "placeholder": + return node.meta["grapharg"].example + assert node.op == "get_attr" + return self.nn_modules[node.target] # type: ignore[index] + + +err_epilogue = ( + "With the current config, we will graph break " + "(and fall back to eager-mode PyTorch) on all ops " + "that have do not have the 'pt2_compliant_tag'. " + "Please see the following doc for how to mark this op as PT2 compliant " + "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html" +) + + +def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): + if kind != "call_function": + return + + def encountered_compliant_op(target): + if target.namespace in {"prim", "prims", "aten"}: + return + output_graph.compliant_custom_ops.add(target) + + def encountered_non_compliant_op(target, msg): + output_graph.non_compliant_ops.add(target) + if config.only_allow_pt2_compliant_ops: + unimplemented_v2( + gb_type="Encountered non-PT2-compliant op", + context="", + explanation=msg + " " + err_epilogue, + hints=[], + ) + + if isinstance(target, torch._ops.OpOverload): + if torch.Tag.pt2_compliant_tag in target.tags: + encountered_compliant_op(target) + return + encountered_non_compliant_op( + target, + f"Encountered the torch.ops.OpOverload {target} that is not PT2 compliant.", + ) + return + + if isinstance(target, torch._ops.OpOverloadPacket): + overloads = tuple(target.overloads()) + # Optimization: Overload resolution is expensive. + # If there's only one overload, we know what it will resolve to. + if len(overloads) == 1: + op = getattr(target, overloads[0]) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + return + encountered_non_compliant_op( + op, + f"Encountered the non-overloaded " + f"torch.ops.OpOverloadPacket {target} " + f"that is not PT2 compliant. 
", + ) + return + + args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( + output_graph.current_tx, (args, kwargs), False + ) + try: + overload = torch._C._jit_resolve_packet( + target._qualified_op_name, *args, **kwargs + ) + except RuntimeError as e: + unimplemented_v2( + gb_type="Error when attempting to resolve op packet", + context="", + explanation=str(e), + hints=[], + ) + + op = getattr(target, overload) + if torch.Tag.pt2_compliant_tag in op.tags: + encountered_compliant_op(op) + else: + encountered_non_compliant_op( + op, + f"Encountered the torch.ops.OpOverloadPacket {target} " + f"which resolves to the overload ({overload}) that is " + f"not PT2 compliant.", + ) + + +_compile_id_counter = itertools.count() + + +class LazyProxy: + def __init__(self, tracer, fn, *args, **kwargs): + self.tracer = tracer + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __call__(self): + return self.fn(*self.args, **self.kwargs) + + +class SubgraphTracer(fx.Tracer): + """ + Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer + and the separation of responsibilities is that SubgraphTracer is + responsible for building the graph while OutputGraph is responsible for + compiling and executing the graph. + """ + + def __init__(self, output_graph, parent=None, is_export=False, source_target=None): + super().__init__() + self.output_graph = weakref.proxy(output_graph) + self.graph = torch.fx.Graph() + + # See note [Export inputs must be explicitly passed in] + self.is_export = is_export + # Map from graph input name to its placeholder proxy object, where the + # map's keys give all current placeholder node names and can be used to + # create unique node names + self.input_name_to_proxy: dict[str, fx.Proxy] = {} + # Node => computed real value (see utils.get_real_value) + self.real_value_cache: dict[fx.Node, torch.Tensor] = {} + + # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] + self.parent = parent + self.source_target = source_target + # A dict mapping previously free variables (Proxy objects) + # to new Proxy objects that wrap inputs to this subgraph. + # + # This dict maps proxies in outer graphs to placeholders in current graph. + # It serves two purposes: + # - Proxies are associated with VariableTrackers. If we see + # the same VariableTracker twice (and it is a free variable), + # then we want to use the same Proxy in the current subgraph to + # record the tracing. + # - If we are tracing a HigherOrderOperator's body_fn, then we + # need to keep track of what free variables were lifted so we can + # rewrite the HigherOrderOperator call using the traced body_fn. + # Dicts maintain the order of args for the HigherOrderOperator call. + self.lifted_freevars = {} + + # map basic symbols (unbacked and unbacked) to their bound proxies. + # There are only two cases where bound_symbols will be recorded: + # 1. when we create_graph_input for a backed SymInt that's basic symbol + # 2. when we track_unbacked_symbols for intermediate results that contain unbacked symints. + self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {} + + self.prev_inst = None + # True if this tracer is currently tracing into torch.utils.checkpoint + # as part of speculate_subgraph. + self.under_activation_checkpoint = False + # True if we want to allow externally visible side-effects (doesn't throw error on their existence) + # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph). 
+ # Only safe if we know for sure that *NOT* replaying these side-effects during + # backward recomputation of the checkpoint region doesn't affect its correctness. + self.allow_side_effects_under_checkpoint = False + # True if we want to allow externally visible side-effects (doesn't throw error on their existence) + # during this tracer's tracing. This is currently only used by experimental AC out-of-tree + # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer. + # Note: Externally visible side-effects are allowed if this flag OR the above flag is True. + self.unsafe_allow_externally_visible_side_effects = False + + # True if this tracer is currently tracing (reconstructing) into a Python generator + self.is_reconstructing_generator = False + + self.debug_level: int = parent.debug_level + 1 if parent is not None else 0 + + self._cur_code = None + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + # Each SubgraphTracer is associated with a source target, which indicates + # which operator this subgraph is attached to. We compute a source_fn_stack + # based on the source target. For the root tracer, it's set to []. + # This is useful for debugging and transforming the exported graph. + if self.parent is None: + self.source_fn_stack = [] + else: + self.source_fn_stack = self.parent.source_fn_stack + [ + (self.graph._target_to_str(source_target), source_target) + ] + + # This is used to create a unique name for the placeholder + self._used_names: OrderedSet[str] = OrderedSet() + # Stores the versions of the input tensors at the time they are inserted + # as placeholders in the graph. This is used to track input mutation. + self._input_versions_at_beginning: list[int] = [] + if torch.is_inference_mode_enabled(): + raise RuntimeError( + "Inference mode is supposed to be disabled during compilation. Please open an issue." + ) + + # preserve original meta if it is available + def _maybe_preserve_original_meta(self, tx, node): + if ( + self._orig_gm_meta + and self._orig_gm_lineno_map + and self._orig_gm_firstlineno + ): + lineno = tx.current_instruction.starts_line + node_idx = None + if lineno is not None: + node_idx = self._orig_gm_lineno_map.get( + lineno - self._orig_gm_firstlineno, None + ) + if node_idx is not None: + meta = self._orig_gm_meta[node_idx] + for field in fx.proxy._COPY_META_FIELDS: + if field in meta: + node.meta[field] = meta[field] + if "stack_trace" in meta: + node.meta["stack_trace"] = meta["stack_trace"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + # NOTE: [Nested SubgraphTracer and free_variable handling] + # -------------------------------------------------------- + # Read NOTE [HigherOrderOperator tracing design] first. + # + # Let's say we're in the middle of introspecting the body of a possibly + # nested HigherOrderOperator, and we see a free variable. + # + # There are two cases: + # 1. We see a free variable that is already tracked by Dynamo. + # 2. We see a free variable that has not been tracked by Dynamo + # + # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below) + # which will lift the freevar to be an input of this subgraph + # and also recursively lift it to be an input on the parent(s). + # + # In case 2, before the call to `create_proxy`, the InstructionTranslator + # will see the freevar when it gets loaded by Python bytecode. + # E.g. 
for Python 3.11 the bytecodes that may do this are LOAD_DEREF or + # LOAD_GLOBAL. + # There, the InstructionTranslator asks Dynamo to begin tracking the + # freevar by building a new Variable. + # Building a new Variable automatically lifts the freevar to be an + # input of the root SubgraphTracer. + # + # The implications for the code below are: + # - We will always be in Case 1 when we get to this code. + # - Any "free variable" we encounter here is guaranteed to already be + # bound, that is, it is either a graph input of the root graph, or + # some local variable of the root graph or a subgraph. + # - The additional work we need to do here is *only* that we need to + # lift this free variable into inputs (recursively) of each nested + # higher-order-op subgraph until we hit the subgraph where the free + # variable is bound + if self.parent is not None: + flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) + new_flat_args = [] + for arg in flat_args: + maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) + new_flat_args.append(maybe_new_arg) + + args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) + + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + # append stack trace to fx node + tx = self.output_graph.current_tx + + # log detailed location of line of code in 3.11 + if sys.version_info >= (3, 11) and kind in ( + "call_function", + "call_method", + "call_module", + ): + cur_inst = tx.current_instruction + if ( + cur_inst is not self.prev_inst + and cur_inst.positions is not None + and cur_inst.positions.lineno is not None + ): + tx_code = tx.f_code + header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) + + def get_trace_call_log_str(): + line = get_instruction_source_311(tx_code, cur_inst).rstrip() + return f"TRACE FX call {rv.node.name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + self.prev_inst = cur_inst + + # update reference to original meta if we're tracing a new code object + is_retracing = False + if tx.f_code is not self._cur_code: + orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( + "orig_graphmodule", lambda: None + )() + if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): + is_retracing = True + self._orig_gm_meta = [ + nd.meta for nd in orig_graphmodule_maybe.graph.nodes + ] + self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map + self._orig_gm_firstlineno = ( + orig_graphmodule_maybe.forward.__code__.co_firstlineno + ) + else: + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + # TODO can remove once inline_inbuilt_nn_modules is always True + unimplemented_v2( + gb_type="Invoking an nn.Module inside a higher order operator", + context=f"Higher order op name: {self.source_target}", + explanation="This is not supported.", + hints=[], + ) + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + next( + ty + for k, (_, ty) in rv.node.meta["nn_module_stack"].items() + if k.split("@")[0] == target + ), + ) + ] + + self._maybe_preserve_original_meta(tx, rv.node) + + if not is_retracing: 
+ if "nn_module_stack" not in rv.node.meta: + nn_module_stack = tx.nn_module_stack + if nn_module_stack: + rv.node.meta["nn_module_stack"] = nn_module_stack.copy() + + if "source_fn_stack" not in rv.node.meta: + if kind in {"call_function", "call_method"}: + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + (rv.node.name, target) + ] + elif kind == "call_module": + if self.parent is not None: + # TODO can remove once inline_inbuilt_nn_modules is always True + unimplemented_v2( + gb_type="Invoking an nn.Module inside a HigherOrderOperator", + context="", + explanation="This is not supported.", + hints=[], + ) + # For modules we store the class + rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ + ( + rv.node.name, + rv.node.meta["nn_module_stack"][target][1], + ) + ] + + if "stack_trace" not in rv.node.meta: + frame_summaries: list[traceback.FrameSummary] = [] + while tx: + # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of + # the user code. + if not tx.is_co_filename_from_nn_modules(): + frame_summaries.append(tx.frame_summary()) + tx = getattr(tx, "parent", None) + # Reverse the frame_summaries, such that the innermost frame is at the last + frame_summaries.reverse() + + # official from_list stub doesn't have new-style type + msgs = traceback.StackSummary.from_list(frame_summaries).format() + rv.node.stack_trace = "".join(msgs) + + if ( + torch._dynamo.config.use_graph_deduplication + or torch._dynamo.config.track_nodes_for_deduplication + ): + self.output_graph.region_tracker.track_node( + self.output_graph.current_tx, rv.node + ) + return rv + + def create_node( + self, op, target, args=None, kwargs=None, name=None, type_expr=None + ): + check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) + if self.parent is not None: + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + for arg in flat_args: + if not isinstance(arg, torch.fx.Node): + continue + assert arg.graph == self.graph, ( + "create_node using arg not from this SubgraphTracer" + ) + + node = super().create_node(op, target, args, kwargs, name, type_expr) + node.meta["creation_timestamp"] = self.output_graph.timestamp + self._used_names.add(node.name) + return node + + # Note: we did not override erase_node since + # we call self.graph.erase_node elsewhere + def remove_node(self, node): + if len(node.users) > 0: + user_graph_nodes: list[torch.fx.Node] = [] + for user in node.users.keys(): + # For the case where user.graph == self.graph, that is a real bug and will raise + # properly. + if user.graph != self.graph: + # This is a nested graph, which needs to be deleted. + # If we do not do this, we will raise on attempting to remove this. + # As we only get here during restoration cleanup, this is sound. + user_graph_nodes.extend(reversed(list(user.graph.nodes))) + for other_graph_node in user_graph_nodes: + other_graph_node.graph.erase_node(other_graph_node) + self.graph.erase_node(node) + self.input_name_to_proxy.pop(node.name, None) + + # when before=True, we will insert this input before the most recent + # inserted proxy. This is a hack to get around an ordering problem, + # where we first insert a tensor argument, and then insert bindings + # for SymInts that may occur in the tensor argument. + # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets + # fixed. 
+ def create_graph_input( + self, name, type_expr, example_value, before=False, source=None + ): + if isinstance(example_value, torch.Tensor): + self._input_versions_at_beginning.append(example_value._version) + log.debug( + "create_graph_input %s %s %s at debug_level %s before=%s", + name, + source.name() if source is not None else "(none)", + example_value, + self.debug_level, + before, + ) + if source is None: + assert self.parent is not None, ( + f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer" + ) + + # Note [Export inputs must be explicitly passed in] + # In eager, we are generally OK with adding graph inputs whenever we + # want, because we take care of writing the bytecode that knows how + # to source all the inputs. + # + # In export, this is bad, because you want a self-contained export + # object which only depends on the inputs you explicitly passed to it. + # So we are a bit more strict about what sources can become inputs + # in export + if self.is_export and self.parent is None: + if not is_from_local_source(source, only_allow_input=True): + self.output_graph.source_to_user_stacks.setdefault(source, []).append( + TracingContext.extract_stack() + ) + + # _used_names contains the names of all the nodes in the graph, + # including intermediates. This ensures that we do not have a name + # collision. + name = get_unique_name_wrt(name, self._used_names) + if self.input_name_to_proxy: + prev_name = next(reversed(self.input_name_to_proxy)) + node = self.input_name_to_proxy[prev_name].node + if before: + ctx = self.graph.inserting_before(node) + else: + ctx = self.graph.inserting_after(node) + else: + ctx = self.graph.inserting_before(None) + with ctx: + proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + set_example_value(proxy.node, example_value) + if self.input_name_to_proxy and before: + k, v = self.input_name_to_proxy.popitem() + self.input_name_to_proxy[name] = proxy + self.input_name_to_proxy[k] = v + else: + self.input_name_to_proxy[name] = proxy + + # For placeholder nodes, `name` is passed as a str to the target, + # and then torch.fx decides the node.name. So, record the `target` + # name as well in the _used_names to prevent any collision. + self._used_names.add(name) + + # NOTE: [Auto lift basic free symbols when create_graph_input] + # Whenever we call create_graph_input, we try to also lift the basic symbols in example values + # as graph input. + # This applies to both top-level graph and subgraphs in higher order ops. + # It has several cases: + # 1. When create_graph_input for a tensor that has symbolic shapes, + # we look for basic symbols in its size and stride, we check if the symbol is bound + # in current graph (i.e. bound_symbols), it it's not bound, we'll create a placeholder + # for it then recursively check its parent, creates ph if not bound. + # Every tracer maintains a mapping (i.e. lifted_freevars) + # that maps from parent proxy to proxy in current tracer for the symbol. + # 2. When create_graph_input for a tensor with unbacked symbolic shapes, + # Backed symbols all come from inputs's symbolic shape. But unbacked symbols + # can be created while tracing. So we use track_unbacked_symbols will intercept + # at wrap_fx_proxy, and try to bind the unbacked symbols immediately after they're + # created. + # 3. subgraph will also lifted basic symbols in compound exprs of tensor shape. 
+ # For example, if an input to subgraph takes size [s1+s2//8], we'll look for the + # the free symbols in the sizes and lift as inputs similar to 1 in _lift_symbols_in_symint) + # 4. When create_graph_input for a SymInt, if the symint is a basic symbol, we'll track it + # in bound_symbols so that we don't lift the same basic symbol twice. When the symint is a + # compound expr, we'll just create the proxy for the compouned expr but not lift its basic symbols. + # Also see NOTE: [Export inputs must be explicitly passed in] + is_strict_export = self.is_export + is_non_strict_export = torch.compiler.is_compiling() + if not is_strict_export and not is_non_strict_export: + if isinstance(example_value, torch.Tensor): + self._lift_basic_symbols(example_value, source) + elif isinstance(example_value, (list, tuple)): + for i, e in enumerate(example_value): + if not isinstance(e, torch.Tensor): + continue + + e_source = None + if source: + e_source = GetItemSource( + base=source, index=i, index_is_slice=False + ) + + self._lift_basic_symbols(e, e_source) + + # Bound the symbol to ph if example_value is a SymInt with basic symbol. + if isinstance(example_value, torch.SymInt) and isinstance( + example_value.node.expr, sympy.Symbol + ): + self.bound_symbols[example_value.node.expr] = proxy + return proxy + + # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details + def lift_tracked_freevar_to_input(self, proxy): + # You're doing something wrong if we are the root SubgraphTracer because + # Dynamo adds tensors to graph inputs before creating a proxy for them. + assert self.parent is not None, ( + "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" + ) + + example_value = proxy.node.meta["example_value"] + + # To avoid lifting the same symbol twice, we check whether basic symbols has been tracked. + # For example, the basic symbols may have already been lifted for current subgraph when + # we automatically lift basic symbols in the sizes/strides of a tensor t. + # Suppose parent graph calls sz = t.size()[0], it creates + # a proxy in parent and the subgraph accesses sz via closure. sz's proxy is not tracked + # in current sub-tracer so we may lift the same symbol twice. + if ( + isinstance(example_value, torch.SymInt) + and example_value.node.expr in self.bound_symbols + ): + return self.bound_symbols[example_value.node.expr] + + # Proxies are associated with VariableTracker. + # It is possible that we've already lifted the Proxy to be an input. + # If that is the case, just return the already lifted Proxy. + if proxy in self.lifted_freevars: + return self.lifted_freevars[proxy] + + # We first lift proxy to parent's graph then lift to current grpah's input + # so that when we bind symints of the sizes in current graph, those symints + # would already be lifted as inputs to parent graph. + if proxy.tracer != self.parent: + self.parent.lift_tracked_freevar_to_input(proxy) + + example_value = proxy.node.meta["example_value"] + new_proxy = self.create_graph_input( + proxy.node.name, type(example_value), example_value + ) + self.lifted_freevars[proxy] = new_proxy + return new_proxy + + def maybe_lift_tracked_freevar_to_input(self, arg): + """ + If arg is a free variable, then lift it to be an input. + Returns the new lifted arg (if arg was a freevar), else the + original arg. + """ + if not isinstance(arg, torch.fx.Proxy): + # Note: arg can be a python built-in slice type e.g. 
+ # x[:max_seq] is represented as get_item(t, (slice(None, max_seq, None))) + # we need to also look into the slice variable itself to lift the + # proxies there. + if isinstance(arg, slice): + return slice( + *( + self.maybe_lift_tracked_freevar_to_input(sub_arg) + for sub_arg in (arg.start, arg.stop, arg.step) + ) + ) + else: + return arg + elif arg.tracer == self: + return arg + return self.lift_tracked_freevar_to_input(arg) + + # See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design + # You MUST call this API every time you create a proxy in wrap_fx_proxy for a call + # that produced unbacked symints or tensors with unbacked symint shapes. + # This function is used to track the unbacked symints with the proxies created during + # dynamo tracing so that a subgraph knows how to bind a symbol input to its parent's proxy. + # LazyProxies are created for tensor shapes that are unbacked so that we don't create proxies + # for symbols that are not going to be used. + def track_unbacked_symbols( + self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy] + ): + # When binding the symbols in an example_value, we bind the symbols + # to the proxy's associated Tracer instead of the current tracer. + # This is because: + # 1. We may be calling wrap_tensors during speculate_subgraph because + # the variables are lazily realized. The proxies are top-level placeholders but + # the current tracer is a subtracer. + # 2. For autograd.Function, we trace the backward graph with a new tracer + # whose parent is the forward tracer, but we're using all the proxies created + # in the forward tracer to trace the backward. + # For example, forward calls save_for_backward for an input tensor t. + # Backward calls t.tolist(). In this case, all the proxies that the backward tracer + # sees are from the parent tracer (i.e. the forward tracer). (e.g. t[0].item()) + # See test_validate_outputs_unbacked for repro on 2.
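+ #
+ # (Illustrative sketch, editor's note, not from the original comments:) unbacked
+ # symbols typically come from data-dependent ops. For example, with
+ # torch._dynamo.config.capture_dynamic_output_shape_ops enabled, compiling
+ #
+ #     def f(x):
+ #         nz = x.nonzero()   # nz.shape[0] is an unbacked SymInt such as u0
+ #         return nz.sum()
+ #
+ # produces a proxy for `nz` via wrap_fx_proxy, which calls this method; u0 is then
+ # bound to a LazyProxy that would only emit torch.ops.aten.sym_size.int(nz, 0)
+ # if some (sub)graph later needs u0 as an input.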
+ tracer = e_proxy.tracer + assert isinstance(tracer, SubgraphTracer) + + def need_bind(s) -> bool: + from torch.fx.experimental.symbolic_shapes import is_symbolic + + return ( + is_symbolic(s) + and isinstance(s.node.expr, sympy.Symbol) + and s.node.shape_env.is_unbacked_symint(s.node.expr) + and s.node.expr not in self.bound_symbols + ) + + def _proxy_with_example_value(example_value, *args, **kwargs): + proxy = tracer.create_proxy(*args, **kwargs) + set_example_value(proxy.node, example_value) + return proxy + + if isinstance(example_value, torch.Tensor): + for i, s in enumerate(example_value.size()): + if need_bind(s): + log.debug( + "_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s", + s, + e_proxy, + i, + tracer.debug_level, + ) + lazy_proxy = LazyProxy( + tracer, + _proxy_with_example_value, + s, + "call_function", + torch.ops.aten.sym_size.int, + (e_proxy, i), + {}, + type_expr=type(s), + ) + self.track_unbacked_symbols(s, lazy_proxy) + + if example_value.layout is torch.strided: + for i, s in enumerate(example_value.stride()): + if need_bind(s): + log.debug( + "_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s", + s, + e_proxy, + i, + tracer.debug_level, + ) + lazy_proxy = LazyProxy( + tracer, + _proxy_with_example_value, + s, + "call_function", + torch.ops.aten.sym_stride.int, + (e_proxy, i), + {}, + type_expr=type(s), + ) + self.track_unbacked_symbols(s, lazy_proxy) + + elif example_value.layout is torch.sparse_coo: + self.track_unbacked_symbols(example_value._indices(), e_proxy) + self.track_unbacked_symbols(example_value._values(), e_proxy) + elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}: + self.track_unbacked_symbols(example_value.crow_indices(), e_proxy) + self.track_unbacked_symbols(example_value.col_indices(), e_proxy) + elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}: + self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy) + self.track_unbacked_symbols(example_value.row_indices(), e_proxy) + if is_traceable_wrapper_subclass(example_value): + attrs, ctx = example_value.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(example_value, attr) + self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr)) + elif isinstance(example_value, torch.SymInt): + # Only bind unbacked symbols. backed symbols are lifted as inputs. + if need_bind(example_value): + expr = example_value.node.expr + tracer.bound_symbols[expr] = e_proxy + + # See Note [Auto lift basic free symbols when create_graph_input] + def _lift_basic_symbols( + self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source] + ): + # The before arg is for inserting symints in the sizes/strides of a tensor + # before the tensor. This ordering ensures that when we look at the tensor's + # symbols, they're already lifted/tracked. E.g. this assumption is used + # in insert_deferred_runtime_asserts. + def _lift_symbols_in_symint( + s: Union[int, torch.SymInt], + source: Optional[Source], + before: bool = False, + ) -> None: + if not is_symbolic(s): + return + + assert isinstance(s, torch.SymInt) + self_to_be_bound = self.lookup_unbound_symbols(s) + if len(self_to_be_bound) == 0: + return + + # For subgraph + if self.parent is not None: + # Recursively lift symbols in symint until top-level. 
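+ #
+ # (Editorial sketch, not in the original comments:) e.g. if this subgraph receives
+ # a tensor of size [s0 + s1], both s0 and s1 are lifted in the parent first
+ # (recursively, up to the tracer where they are already bound or have a real
+ # Source), then re-bound here as placeholders and recorded in lifted_freevars so
+ # the rewritten HigherOrderOperator call can pass the parent's proxies along.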
+ self.parent._lift_basic_symbols(s, source) + for s0 in self_to_be_bound: + parent_proxy = self.parent.bound_symbols[s0] + example_val = parent_proxy.node.meta["example_value"] + assert isinstance(example_val, torch.SymInt) + ph = self.create_graph_input( + str(s0), + type(example_val), + example_val, + before=before, + source=source, + ) + log.debug( + "_lift_symbols_in_symint %s from %s at debug_level %s", + s0, + source.name() if source is not None else "subgraph inputs", + self.debug_level, + ) + self.lifted_freevars[parent_proxy] = ph + # For root_tracer: + else: + assert len(self_to_be_bound) == 1, ( + f"For root tracer, we only expect to bind basic symbols (compound symbols " + f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}" + ) + assert source is not None, ( + f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, " + "this could be because it's not tracked with lazy_bind_unbacked_symbols. " + f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer." + ) + s0 = next(iter(self_to_be_bound)) + ph = self.create_graph_input( + str(s0), + type(s), + s, + before=before, + source=source, + ) + log.debug( + "_lift_symbols_in_symint %s from %s at debug_level %s", + s, + source.name() if source is not None else "subgraph inputs", + self.debug_level, + ) + ph.node.meta["grapharg"] = GraphArg( + source, + s, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + + if isinstance(example_value, torch.Tensor): + for i, s in enumerate(example_value.size()): + _lift_symbols_in_symint( + s, + ( + TensorPropertySource(src, TensorProperty.SIZE, i) + if src is not None + else None + ), + before=True, + ) + if example_value.layout is torch.strided: + for i, s in enumerate(example_value.stride()): + _lift_symbols_in_symint( + s, + ( + TensorPropertySource(src, TensorProperty.STRIDE, i) + if src is not None + else None + ), + before=True, + ) + _lift_symbols_in_symint( + example_value.storage_offset(), + ( + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET) + if src is not None + else None + ), + before=True, + ) + elif example_value.layout is torch.sparse_coo: + self._lift_basic_symbols(example_value._indices(), src) + self._lift_basic_symbols(example_value._values(), src) + elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}: + self._lift_basic_symbols(example_value.crow_indices(), src) + self._lift_basic_symbols(example_value.col_indices(), src) + elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}: + self._lift_basic_symbols(example_value.ccol_indices(), src) + self._lift_basic_symbols(example_value.row_indices(), src) + if is_traceable_wrapper_subclass(example_value): + attrs, ctx = example_value.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(example_value, attr) + self._lift_basic_symbols( + inner_t, AttrSource(src, attr) if src is not None else None + ) + elif isinstance(example_value, torch.SymInt): + _lift_symbols_in_symint( + example_value, + src, + ) + + # Lookup the proxy in current tracer for each symbol in expressions of s, + # See Note [Auto lift basic free symbols when create_graph_input] + def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]: + free_symbols = s.node.expr.free_symbols + if len(free_symbols) == 0: + return [] + + to_be_bound = [] + for s0 in free_symbols: + if s0 not in self.bound_symbols: + to_be_bound.append(s0) + continue + + proxy = self.bound_symbols[s0] + if isinstance(proxy, 
LazyProxy): + proxy = proxy() + self.bound_symbols[s0] = proxy + assert isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self, ( + f"The proxy of symbol {s0} doesn't belong to current tracer." + ) + # Sort the symbols so that we can have a deterministic lifting order + return sorted(to_be_bound, key=lambda s: s.name) + + def has_input_mutation(self): + input_versions_at_beginning = self._input_versions_at_beginning + input_nodes = [] + + input_versions_at_end = [] + for node in self.graph.nodes: + if node.op == "placeholder": + example_value = node.meta["example_value"] + if isinstance(example_value, torch.Tensor): + input_versions_at_end.append(example_value._version) + input_nodes.append(node) + else: + break + + mutated_inputs = [ + i + for i, (v1, v2) in enumerate( + zip(input_versions_at_beginning, input_versions_at_end) + ) + if v1 != v2 + ] + + if len(mutated_inputs): + mutated_nodes = [input_nodes[i] for i in mutated_inputs] + msg = f"Input mutation detected at {mutated_nodes}" + return MutationInfo(True, msg) + + return MutationInfo(False, "") + + def has_aliasing(self): + from torch._higher_order_ops.utils import _collect_fake_inputs + + input_storages: dict[StorageWeakRef, torch.fx.Node] = dict() + + for node in self.graph.nodes: + if node.op == "placeholder": + example_value = _collect_fake_inputs([node])[0] + if isinstance(example_value, torch.Tensor): + storage = StorageWeakRef(example_value._typed_storage()) + if storage in input_storages: + # input-input aliasing + msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}" + return AliasingInfo(True, msg) + input_storages[storage] = node + else: + break + + output_storages: dict[StorageWeakRef, torch.fx.Node] = dict() + out_nodes = self.graph.find_nodes(op="output")[0] + for out_node in pytree.tree_leaves(out_nodes.args[0]): + if out_node: + example_value = _collect_fake_inputs([out_node])[0] + assert not isinstance(example_value, list) + if isinstance(example_value, torch.Tensor): + storage = StorageWeakRef(example_value._typed_storage()) + if storage in output_storages: + # output-output aliasing + msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}" + return AliasingInfo(True, msg) + output_storages[storage] = out_node + + intersected_storages = input_storages.keys() & output_storages.keys() + if len(intersected_storages) > 0: + # input-output aliasing + aliased = [ + (input_storages[s], output_storages[s]) for s in intersected_storages + ] + aliased = ", ".join([f"{i} and {o}" for i, o in aliased]) + msg = f"Input-to-output aliasing detected at nodes {aliased}" + return AliasingInfo(True, msg) + + return AliasingInfo(False, "") + + +# NOTE: [HigherOrderOperator tracing design] +# Ignoring HigherOrderOperators for a moment, +# OutputGraph represents the graph being built by Dynamo that may be compiled +# and executed. It holds a root SubgraphTracer where the FX graph is built. +# +# HigherOrderOperators are operators that take functions as their arguments. +# When Dynamo encounters a HigherOrderOperator, then it attempts to introspect +# the function passed to it (call this the "body function"), capture it into a +# GraphModule, and rewrite the call to the HigherOrderOperator to use the +# GraphModule. +# +# The way we handle the capture of body functions is through having +# (possibly nested) SubgraphTracers, one per body function. 
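+#
+# (Illustrative sketch, editor's addition, using the public torch.cond wrapper:)
+#
+#     w = torch.randn(3)
+#
+#     def true_fn(x):
+#         return x + w      # `w` is a free variable of the body function
+#
+#     def false_fn(x):
+#         return x - w
+#
+#     @torch.compile
+#     def f(x):
+#         return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
+#
+# When f is traced, each branch is traced into its own (nested) SubgraphTracer and
+# `w` is lifted to an explicit input of both branch graphs, so the rewritten
+# torch.cond call receives it as an operand.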
+# +# Mechanically, we do the introspection by: +# - Creating a new SubgraphTracer via OutputGraph.subtracer +# - Executing the body function. +# This constructs the graph of the body function in the new SubgraphTracer +# while modifying the state of the OutputGraph. For example: +# - the OutputGraph can receive new GraphArgs (if we discover any new +# untracked Tensors) +# - side effects from the body function get accumulated into +# OutputGraph.side_effects +# - guards produced by the body function get accumulated into OutputGraph.guards +# +# The traced function has some special properties that make it easier for us +# to transform later down the line: +# - we lift all free variables to being inputs. +# +# If the introspection fails (due to the existence of graph breaks), then +# we roll back the current OutputGraph state and graph break on the +# HigherOrderOperator. diff --git a/phivenv/Lib/site-packages/torch/_dynamo/package.py b/phivenv/Lib/site-packages/torch/_dynamo/package.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a22839993a4e623740dc32a3aee8e054994190 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/package.py @@ -0,0 +1,413 @@ +""" +This module provides the infrastructure for creating and managing compile package +for torch.compile. We mainly have two abstractions here: + - CompilePackage: Overarching data structure for store and lookup a list of compiled codes. + - CodeCacheEntry: Data structure for a single code being compiled by torch.compile. +The caching behavior is always under user control explicitly so that a stronger guarantee can +be provided about cache hit for a specific compiled model. Users can load the compile package +from a different process or host. +""" + +import contextlib +import dataclasses +import functools +import hashlib +import importlib +import logging +import os +import pickle +import platform +import sys +import types +from collections.abc import Generator +from typing import Any, NewType, Optional + +import torch +import torch._inductor.package +from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext +from torch.compiler._cache import CacheArtifactFactory + +from .bytecode_transformation import get_code_keys + + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class SerializedCode: + co_argcount: int + co_posonlyargcount: int + co_kwonlyargcount: int + co_nlocals: int + co_stacksize: int + co_flags: int + co_code: bytes + co_consts: tuple[Any, ...] + co_names: tuple[str, ...] + co_varnames: tuple[str, ...] + co_filename: str + co_name: str + co_firstlineno: int + co_cellvars: tuple[str, ...] + co_freevars: tuple[str, ...] 
+ co_linetable: Optional[bytes] = None + co_qualname: Optional[str] = None + co_exceptiontable: Optional[bytes] = None + co_lnotab: Optional[str] = None + + @classmethod + @functools.cache + def from_code_object(cls, code: types.CodeType) -> "SerializedCode": + kwargs = {key: getattr(code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.from_code_object(c) if isinstance(c, types.CodeType) else c + for c in kwargs["co_consts"] + ) + return cls(**kwargs) + + @classmethod + @functools.cache + def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType: + kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.to_code_object(c) if isinstance(c, SerializedCode) else c + for c in kwargs["co_consts"] + ) + return types.CodeType( + *kwargs.values(), + ) + + +@dataclasses.dataclass +class _GuardedCodeCacheEntry: + """ + Contains the serializable information associated with a single compilation in dynamo. + To restore an execution of compiled code, we will need to serialize the following data: + - Dynamo bytecode for mapping Python inputs/outputs. + - Dynamo guards. + """ + + guards_state: bytes + dynamo_code: SerializedCode + + +_BackendId = NewType("_BackendId", str) # __compiled_fn +_FunctionId = NewType("_FunctionId", str) # __resume_at + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry: + """ + Contains the serializable information associated with a single code object + in dynamo. To restore an execution of compiled code, we will need the following + ingredients: + 1. The "original" code object, which serves as the entry point for eager + execution, i.e. the code only executed when there's no cache entry hit. + 2. The python module name this code object belongs to, for identifying the + enclosing global scope to inject compiled and resume functions. + 3. A list of function names that pointing to this code object. There could be + multiple function objects pointing to the same code such as recursive functions. + 4. A list of guarded code that eval frame dispatches to. + 5. A list of imported module objects unioned from all compiled branches. + 6. A list of "backends" (compiled fx graph) unioned from all compield branches. + """ + + python_code: SerializedCode + python_module: str + function_names: list[_FunctionId] + guarded_codes: list[_GuardedCodeCacheEntry] + import_sources: dict[str, str] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCacheEntry: + codes: list[_DynamoCodeCacheEntry] + python_version: str = platform.python_version() + torch_version: str = torch.__version__ + + @property + def backend_ids(self) -> set[_BackendId]: + return {backend_id for code in self.codes for backend_id in code.backend_ids} + + +@CacheArtifactFactory.register +class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]): + @staticmethod + def type() -> str: + return "precompile_dynamo" + + def after_deserialization(self) -> _DynamoCacheEntry: + return pickle.loads(self.content) + + +class CompilePackage: + """ + CompilePackage is considered a low level component and should not be directly exposed to + end users. It has the following interface: + + 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. + a. when `dynamo` argument is None, it will construct a brand new CompilePackage object. + b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. + 2. 
`package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object. + 3. `package.install(backends) which will handle all the side-effectful global scope + updates with compiled functions and resume functions. + """ + + def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + self._innermost_fn = None + self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} + + self._current_entry: Optional[_DynamoCodeCacheEntry] = None + self._installed_globals: dict[types.ModuleType, list[str]] = {} + + # For debugging/testing purpose only. + self._cached_backends: dict[_BackendId, Any] = {} + + self._initialize(fn, dynamo) + self.uninstall() + self.validate() + + def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + from .eval_frame import innermost_fn + + self._innermost_fn = innermost_fn(fn) + assert self._innermost_fn is not None + if dynamo is not None: + assert isinstance(dynamo, _DynamoCacheEntry) + if dynamo.python_version != platform.python_version(): + raise RuntimeError( + f"Compile package was created with a different Python version: {dynamo.python_version}" + ) + if dynamo.torch_version != torch.__version__: + raise RuntimeError( + f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" + ) + + main, *codes = dynamo.codes + self._codes = {self._innermost_fn.__code__: main} + for code in codes: + self._codes[SerializedCode.to_code_object(code.python_code)] = code + else: + self._add_function( + self._innermost_fn.__code__, self._innermost_fn.__module__ + ) + + def _add_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[_FunctionId] = None, + ) -> None: + if python_code not in self._codes: + code = _DynamoCodeCacheEntry( + python_code=SerializedCode.from_code_object(python_code), + python_module=python_module, + function_names=[], + guarded_codes=[], + import_sources={}, + backend_ids=[], + ) + self._codes[python_code] = code + else: + code = self._codes[python_code] + assert code.python_module == python_module + + if name is not None: + code.function_names.append(name) + + @property + def cached_backends(self) -> dict[_BackendId, Any]: + return self._cached_backends + + @functools.cached_property + def source_id(self) -> str: + assert self._innermost_fn is not None + sha256_hash = hashlib.sha256() + sha256_hash.update(self._innermost_fn.__qualname__.encode()) + sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + @contextlib.contextmanager + def code_context(self, code: types.CodeType) -> Generator[None, None, None]: + assert self._current_entry is None + + entry = self._codes[code] + self._current_entry = entry + try: + yield + finally: + self._current_entry = None + + def add_guarded_code( + self, + guards_state: bytes, + dynamo_code: types.CodeType, + ) -> None: + assert self._current_entry is not None + guarded_code_entry = _GuardedCodeCacheEntry( + guards_state=guards_state, + dynamo_code=SerializedCode.from_code_object(dynamo_code), + ) + self._current_entry.guarded_codes.append(guarded_code_entry) + + def add_resume_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[str], + ) -> None: + self._add_function( + python_code, python_module, _FunctionId(name) if name else None + ) + + def add_import_source(self, alias: str, module_name: str) -> None: + assert self._current_entry is not None + self._current_entry.import_sources[alias] = 
module_name + + def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None: + assert self._current_entry is not None + assert backend_id.startswith("__compiled_fn_") # sanity check + backend_id = _BackendId(backend_id) + self._current_entry.backend_ids.append(backend_id) + if backend is not None: + self._cached_backends[backend_id] = backend + + def validate(self) -> None: + assert self._current_entry is None + assert self._innermost_fn is not None + assert next(iter(self._codes)) is self._innermost_fn.__code__ + + def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: + module.__dict__[name] = value + self._installed_globals.setdefault(module, []).append(name) + + def uninstall(self) -> None: + from torch._C._dynamo.eval_frame import _reset_precompile_entries + + assert self._innermost_fn is not None + for module, names in self._installed_globals.items(): + for name in names: + module.__dict__.pop(name) + + self._installed_globals = {} + + _reset_precompile_entries(self._innermost_fn.__code__) + + def install(self, backends: dict[_BackendId, Any]) -> None: + """ + Sync the package states to the compiled function. This includes the following actions: + 1. Clean up the previously installed states. + 2. Install the compiled functions to global scopes. + 3. Install the precompiled cache entries to ExtraStates on the code object. + """ + from torch._C._dynamo.eval_frame import _load_precompile_entry + + self.uninstall() + + for code, entry in self._codes.items(): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) + ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + backend = backends[backend_id] + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) + + for code, entry in self._codes.items(): + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) + + def cache_entry(self) -> _DynamoCacheEntry: + self.validate() + return _DynamoCacheEntry(codes=list(self._codes.values())) + + +@CacheArtifactFactory.register +class EagerCacheArtifact(PrecompileCacheArtifact[Any]): + @staticmethod + def type() -> str: + return "precompile_eager" + + def after_deserialization(self) -> Any: + return pickle.loads(self.content) + + +class DynamoStore: + """ + A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them. 
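+
+ A rough usage sketch (illustrative only; `fn` stands for the function that was
+ compiled, `path` for a writable directory, and `package` for an existing
+ CompilePackage built for `fn`):
+
+ store = DynamoStore()
+ store.record_package(package)                 # register with PrecompileContext
+ store.save_package(package, path)             # writes the "dynamo" and "backends" files
+ package, backends = store.load_package(fn, path)
+ package.install(backends)                     # re-attach compiled backends and guards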
+ """ + + def record_package(self, package: CompilePackage) -> None: + """Records a package to PrecompileContext, so that it can be serialized later.""" + cache_entry = package.cache_entry() + pickled_result = pickle.dumps(cache_entry) + PrecompileContext.record_artifact( + _DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result + ) + + def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None: + """Records eager fx graphs to PrecompileContext for testing purposes.""" + pickled_result = pickle.dumps(backend) + PrecompileContext.record_artifact( + EagerCacheArtifact.type(), key=backend_id, content=pickled_result + ) + + def save_package(self, package: CompilePackage, path: str) -> None: + """Saves a package to a given path. Grabs backends from PrecompileContext.""" + backend_content = {} + cache_entry = package.cache_entry() + for backend_id in cache_entry.backend_ids: + serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id) + if serialized_backend is None: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + backend_content[backend_id] = serialized_backend + try: + with open(os.path.join(path, "dynamo"), "wb") as dynamo_path: + pickle.dump(cache_entry, dynamo_path) + with open(os.path.join(path, "backends"), "wb") as backend_path: + pickle.dump(backend_content, backend_path) + except Exception as e: + raise RuntimeError(f"Failed to save package to {path}: {e}") from e + + def load_package( + self, fn: Any, path: str + ) -> tuple[CompilePackage, dict[_BackendId, Any]]: + """Loads a package from a given path and returns it plus a list of deserialized backends""" + try: + with open(os.path.join(path, "dynamo"), "rb") as dynamo_path: + cache_entry = pickle.load(dynamo_path) + with open(os.path.join(path, "backends"), "rb") as backend_path: + backend_content = pickle.load(backend_path) + except Exception as e: + raise RuntimeError(f"Failed to load package from path {path}: {e}") from e + for backend_id, backend in backend_content.items(): + backend_content[backend_id] = backend.after_deserialization() + package = CompilePackage(fn, cache_entry) + return package, backend_content diff --git a/phivenv/Lib/site-packages/torch/_dynamo/pgo.py b/phivenv/Lib/site-packages/torch/_dynamo/pgo.py new file mode 100644 index 0000000000000000000000000000000000000000..f275dc1c218cf081548f69dc6e78d58b8d44d8e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/pgo.py @@ -0,0 +1,879 @@ +""" +Profile Guided Optimization (PGO) implementation for Dynamo. + +This module provides functionality for caching and managing code state profiles +that guide optimization decisions in Dynamo. It implements both local and remote +caching mechanisms for storing profile information across runs, handles profile +merging across distributed ranks, and manages the lifecycle of profile data +during compilation. The profiles track dynamic vs static properties of tensors +and help Dynamo make better specialization decisions. 
+""" + +from __future__ import annotations + +import base64 +import copy +import dataclasses +import enum +import functools +import logging +import os +import pickle +import re +import zlib +from collections import defaultdict +from typing import Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import override, Self + +import torch._dynamo.config +import torch._utils_internal +import torch.compiler.config +import torch.distributed as dist +from torch._dynamo.utils import ( + CompileEventLogger, + dynamo_timed, + set_feature_use, + warn_once, +) +from torch._environment import is_fbcode +from torch._logging._internal import trace_structured_artifact +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + import types + + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._inductor.remote_cache import JsonDataTy, RemoteCache + + +class ReservedWorkflowIdUserError(ValueError): + pass + + +log = logging.getLogger(__name__) + +LOCK_TIMEOUT = 10 + +# How does in memory representation work? Concretely, this module is +# responsible for holding GLOBAL state representing the state it holds, no +# other copies permitted. So we retire frame_state entirely and store it +# here. This should be reset when Dynamo is reset. We never GC information +# (similar to how the filesystem doesn't get cleaned up except by tmp +# cleaner), so the expectation is the information is relatively cheap and we +# don't mind leaking it. + + +# How exactly did we design the cache key? Here are some of the questions: +# +# - JOB_ID: Do we have a unique identifier for the "training run" (such that +# it stays the same if we're running the same code, and changes if we're +# running something different). +# +# - RANK: Are we sharing the cache across ranks, or does each rank get +# an individual cache? +# +# We choose to require job_id for PGO cache. This is to prevent +# situations where unrelated invocations of PyTorch unpredictably cause +# changes to each other's behavior. With a job_id, at least you know there +# is some "state" associated with it. (State dict might be another way to +# tell if a run is related or not.) You can opt-in to YOLO everything +# aliases everything by passing a shared job_id for all your invocations. +# +# We choose to NOT share PGO cache across ranks. With no RANK_SHARING, there +# is never contention between runs, so we can leisurely update a bundle with +# information we need. Because we are grouped by job_id, we can have a single +# consolidated bundle for everything (or not; maybe worry about O(n^2) IO if +# we updated every compile--let's just instrument this.) Can even take a +# filelock for extra safety (expect no contention); expect 50ns overhead from +# uncontended filelock. +# +# If we did share ranks, everyone is storming to modify the same cache files. +# We can do this by having folks atomic write to a CAS-store and then having +# readers do on-the-fly merging (this can be implemented in remote using +# prefix iteration). As an optional optimization, one rank can be elected to +# handling bundling post facto (ideally, this is done async, after quiescence, +# without compiler collective need to wait for everyone to finish writing +# their bits.) Not sure how you can avoid a listdir because if some rank shows +# up with some new entries we need to pull them in ASAP (unless you want to +# delay bundling). 
+# +# But compiler collectives fill a similar niche: compilers chat with each +# other so rank 0 has collected everything. So elect rank 0 only to write the +# bundle. Don't even need CAS-store atomic write; just one rank writing an +# updating bundles. The point is that use compiler collectives to share +# profiles across ranks, but use the PGO cache to persist profiles per rank +# across attempts. No need to have one mechanism to do everything. + + +@functools.cache +def _hash_containing_file(filepath: str) -> str: + # if the file does not exists we consider filepath to be the hash. + if not os.path.exists(filepath): + return filepath + + with open(filepath, "rb") as file: + content = file.read() + crc32_value = zlib.crc32(content) + hash = format(crc32_value & 0xFFFFFFFF, "08x") + return hash + + +@dataclasses.dataclass(frozen=True) +class CodeId: + filename: str + firstlineno: int + name: str + # When a job restart, the code can be copied to a different path than the previous attempt. In that case + # self.filename will have a different value, we do not want to consider those differences. Instead we + # hash the content of the file and use it as an identifier of the file. + # + # self.filename is kept in the object to give readable information/pointer to the actual file, in a local + # code state it will refer to the first seen file path. + file_hash: str + + # Exclude file name. + def __eq__(self, other: object) -> bool: + if not isinstance(other, CodeId): + return False + return ( + self.file_hash == other.file_hash + and self.firstlineno == other.firstlineno + and self.name == other.name + ) + + # Ensure if two CodeIds are the same, then they have the same hash by excluding filename. + def __hash__(self) -> int: + return hash((self.file_hash, self.name, self.firstlineno)) + + def __str__(self) -> str: + return f"hash({self.file_hash}){self.filename}:{self.firstlineno}:{self.name}" + + @staticmethod + def make(code: types.CodeType) -> CodeId: + return CodeId( + code.co_filename, + code.co_firstlineno, + code.co_name, + _hash_containing_file(code.co_filename), + ) + + +@dataclasses.dataclass +class CodeState: + automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field( + default_factory=lambda: defaultdict(FrameStateSizeEntry) + ) + + +_INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None +_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None + + +@dataclasses.dataclass(frozen=True) +class InferStride: + """ + Denotes the quantity stride[dim] * size[dim], which is what the stride would + be for the next physical dimension that results in a contiguous layout. + + For example, given size = [2, 3], stride = [3, 1], we can replace this with + stride = [InferStride(1), 1], because InferStride(1) = stride[1] * size[1] = 1 * 3 = 3 + + Indirecting the representation in this way is important for the join operation + on strides as if we join [2, 3][3, 1] and [2, 4][4, 1], + we don't want [2, None][None, 1] which would get eventually symbolized into + [2, s0][s1, 1] (notice that the relationship between s0 and s1 is broken). + If we instead rewrite the expressions as InferStride so we have [2, 3][InferStride(1), 1] + and [2, 4][InferStride(1), 1] we now join to [2, None][InferStride(1), 1] will + result in [2, s0][s0, 1], as desired. + """ + + dim: int + + +_T = TypeVar("_T") + + +class AutoUnset(enum.Enum): + """ + The identity element of our semilattice, a generic "don't know" element that + is always subsumed when we get more information. 
+ """ + + token = 0 + + +auto_unset = AutoUnset.token + + +class AutoDynamic(enum.Enum): + """ + The top element of our (bounded) semilattice, whenever you merge this with + any other element you always get it again + """ + + token = 0 + + +auto_dynamic = AutoDynamic.token + + +@dataclasses.dataclass +class FrameStateSizeEntry: + scalar: Union[int, AutoDynamic, AutoUnset] = dataclasses.field(default=auto_unset) + # NB: We don't have cases where we have a known dimensionality but + # we know NOTHING about the individual sizes + size: Union[AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]] = ( + dataclasses.field(default=auto_unset) + ) + stride: Union[ + AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...] + ] = dataclasses.field(default=auto_unset) + + def render(self) -> str: + # Special cases + def render_single(s: Union[int, AutoDynamic, AutoUnset, InferStride]) -> str: + if s is auto_dynamic: + return "?" + elif s is auto_unset: + # This basically shouldn't happen, this is for debugging + return "auto unset" + elif isinstance(s, InferStride): + return f"S({s.dim})" + else: + return str(s) + + def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str: + return "[" + ", ".join(render_single(s) for s in ss) + "]" + + # Common cases + if self.size is auto_dynamic and self.stride is auto_dynamic: + if self.scalar is auto_dynamic: + return "fully dynamic scalar or tensor" + else: + return f"scalar {self.scalar}" + elif self.scalar is auto_dynamic: + if isinstance(self.size, tuple) and isinstance(self.stride, tuple): + return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}" + + # Fallback + return "unusual {repr(self)}" + + def __post_init__(self) -> None: + assert not isinstance(self.scalar, torch.SymInt), self.scalar + if isinstance(self.size, tuple): + for s in self.size: + assert not isinstance(s, torch.SymInt), s + if isinstance(self.stride, tuple): + for s1 in self.stride: + assert not isinstance(s1, torch.SymInt), s1 + + def is_size_dynamic(self, dim: int) -> bool: + if self.size is auto_dynamic: + return True + if self.size is auto_unset: + return False + return self.size[dim] is auto_dynamic + + def is_stride_dynamic(self, dim: int) -> bool: + # At the moment, dynamic strides is a bit buggy. Good test case + # here is `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py + # TestAutograd.test_gradcheck_jacobian_mismatch` + # + # This if statement preserves historical behavior, which is that we + # ONLY make strides dynamic if the size is exactly static everywhere. + # We could potentially relax this but in general we should be very + # careful about when to infer dynamic strides. + # + # Actually, the existing algorithm is already somewhat problematic. + # Suppose a tensor that is sometimes: + # f32[2, 3, 5][15, 5, 1] and other times + # f32[2, 3, 5][5, 10, 1] (specifically, dim 0 and 1 are physically transposed). + # If we infer strides should be (DYNAMIC, DYNAMIC, 1). But this is + # silly: we really should have just guarded on dim order. 
+ if not ( + isinstance(self.size, tuple) and all(type(s) is int for s in self.size) + ): + return False + if self.stride is auto_dynamic: + return True + if self.stride is auto_unset: + return False + return self.stride[dim] is auto_dynamic + + @staticmethod + def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]: + return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs) + + @classmethod + def make_scalar(cls, x: int) -> FrameStateSizeEntry: + return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic) + + @classmethod + def make_tensor( + cls, size: tuple[int, ...], stride: tuple[int, ...] + ) -> FrameStateSizeEntry: + return FrameStateSizeEntry( + scalar=auto_dynamic, + size=cls._munge_symint(size), + stride=cls._munge_symint(stride), + ) + + @classmethod + def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry: + return FrameStateSizeEntry( + scalar=auto_unset, + size=cls._munge_symint(size), + stride=auto_unset, + ) + + @staticmethod + def _merge_atom(x: _T, y: _T) -> Union[AutoDynamic, _T]: + if x is auto_unset: + return y + if y is auto_unset: + return x + if x is auto_dynamic or y is auto_dynamic or x != y: + return auto_dynamic + return x + + @classmethod + def _merge_atom_tup( + cls, + xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]], + ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]], + ) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]: + if xs is auto_unset: + return ys + if ys is auto_unset: + return xs + if xs is auto_dynamic or ys is auto_dynamic: + return auto_dynamic + if len(xs) != len(ys): + return auto_dynamic + return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys)) + + def __ior__(self, other: Self) -> Self: + self.scalar = self._merge_atom(self.scalar, other.scalar) + self.size = self._merge_atom_tup(self.size, other.size) + self.stride = self._merge_atom_tup(self.stride, other.stride) + return self + + +def update_automatic_dynamic( + tx: InstructionTranslator, + name: str, + entry: FrameStateSizeEntry, + *, + is_unspecialized_nn_module: bool = False, +) -> FrameStateSizeEntry: + code_id = CodeId.make(tx.f_code) + frame_state = get_code_state()[code_id] + if torch._dynamo.config.automatic_dynamic_shapes: + is_update = name in frame_state.automatic_dynamic + mut_entry = frame_state.automatic_dynamic[name] + old_entry = copy.copy(mut_entry) + mut_entry |= entry + + # Do some logs (damn, I spend more code logging than I do actually doing + # the updates lol) + if is_update and old_entry.scalar != mut_entry.scalar: + log.debug( + "automatic dynamic int %s val %s != %s", + name, + entry.scalar, + old_entry.scalar, + ) + CompileEventLogger.instant( + "automatic_dynamic", + { + "name": name, + "dim_changed": "scalar", + "reason": "scalar change", + "cached": str(old_entry.scalar), + "new": str(entry.scalar), + }, + ) + if is_unspecialized_nn_module: + log.info( + "%s is converted to a symbolic integer. It is an attribute of a " + "user defined nn module class. 
If you wish to keep it static, you can " + "mark the nn module class as `torch._dynamo.mark_static`.", + name, + ) + + def log_tup( + tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None + ) -> None: + entry_tup = ( + getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i] + ) + old_entry_tup = ( + getattr(old_entry, tup_name) + if i is None + else getattr(old_entry, tup_name)[i] + ) + log.debug( + "automatic dynamic %s %s %s %s != %s", + tup_name, + name, + short_reason, + # NB: We used to only report len(...) here for dim mismatch + entry_tup, + old_entry_tup, + ) + CompileEventLogger.instant( + "automatic_dynamic", + { + "name": name, + "dim_changed": "all" if i is None else i, + "reason": long_reason, + "cached": str(old_entry_tup), + "new": str(entry_tup), + }, + ) + + if is_update and old_entry.size != mut_entry.size: + if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple): + if len(old_entry.size) != len(entry.size): + log_tup("size", "dim", "dimensionality change") + else: + for i in range(len(entry.size)): + if old_entry.size[i] != entry.size[i]: + log_tup("size", f"size({i})", "size change", i) + else: + log_tup("size", "other", "other") + + if is_update and old_entry.stride != mut_entry.stride: + if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple): + if len(old_entry.stride) != len(entry.stride): + log_tup("stride", "dim", "dimensionality change") + else: + for i in range(len(entry.stride)): + if old_entry.stride[i] != entry.stride[i]: + log_tup("stride", f"stride({i})", "stride change", i) + else: + log_tup("stride", "other", "other") + else: + old_entry = frame_state.automatic_dynamic[name] + log.debug( + "automatic dynamic is off, overwriting int %s val %s -> %s", + name, + old_entry.scalar, + entry.scalar, + ) + frame_state.automatic_dynamic[name] = entry + mut_entry = entry + + return mut_entry + + +def process_automatic_dynamic( + tx: InstructionTranslator, + name: str, + entry: FrameStateSizeEntry, + *, + is_unspecialized_nn_module: bool = False, +) -> FrameStateSizeEntry: + if (st := tx.distributed_state) is None: + return update_automatic_dynamic( + tx, + name, + entry, + is_unspecialized_nn_module=is_unspecialized_nn_module, + ) + elif st.all_states is None: + # Preflight, always pretend as if it's static. The point here + # is we want to get through the preflight quickly, and static + # will run faster. The preexisting frame state will get + # applied anyway after we do compiler collectives. + # TODO: I'm not sure if we should just bong the entire pgo + # state here, it kind of depends if we're going to have other + # things that talk in compiler collective. Also, the PGO + # state, if we've already inferred something is automatic + # dynamic, will have lost the actual input sizes, which might + # be useful for debugging purposes (e.g., observing 0/1 + # specialization). Bonging the entire PGO state here would + # let us delete this logic here; the compiler collective + # would just directly update_automatic_dynamic + st.local_state.automatic_dynamic[name] = entry + return entry + else: + # Apply the updates. NB: all_states includes the local state + # too. 
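+        # Merging every rank's entry for this source below means a size that
+        # varied on any rank ends up marked dynamic consistently across ranks.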
+ res = None + for sub_state in st.all_states: + if name in sub_state.automatic_dynamic: + res = update_automatic_dynamic( + tx, + name, + sub_state.automatic_dynamic[name], + is_unspecialized_nn_module=is_unspecialized_nn_module, + ) + assert res is not None + return res + + +def get_cache_key() -> Optional[str]: + # TODO: info versions of these logs that log only once + if torch._inductor.config.force_disable_caches: + warn_once( + "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" + ) + return None + + # NB: We always use global rank for keys, even though they are overkill + # for local only cache + rank = None + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + + tag = torch.compiler.config.cache_key_tag + + # NB: We namespace the cache keys so that only user-specified job id + # can alias with each other. + if (r := torch.compiler.config.job_id) is not None: + if r.startswith("mast:"): + raise ReservedWorkflowIdUserError( + "torch.compiler.config.job_id with prefix 'mast:' is reserved for " + "automatically generated job id associated with a specific MAST job " + "name and version." + ) + return f"{r}:{rank}:{tag}" + + if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: + mast_job_name, mast_job_version = name_version + return f"mast:{mast_job_name}:{mast_job_version}:{rank}:{tag}" + + return None + + +# This solely controls local PGO +def code_state_path(cache_key: str) -> Optional[str]: + if not torch._dynamo.config.automatic_dynamic_local_pgo: + log.debug("automatic_dynamic_local_pgo not enabled") + return None + + from torch._inductor.runtime.runtime_utils import cache_dir + + code_state_key = re.sub(r'[<>:"/\\|?*]', "_", f"code_state_{cache_key}.pkl") + return os.path.join(cache_dir(), "dynamo", code_state_key) + + +def should_use_remote_dynamo_pgo_cache() -> bool: + if torch._inductor.config.force_disable_caches: + return False + + if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: + return r + + if not is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:dynamo_pgo_version" + ) + + +def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + from torch._inductor.remote_cache import create_cache + + if not should_use_remote_dynamo_pgo_cache(): + return None + + return create_cache( + "dynamo-pgo", + is_fbcode(), + "FbRemoteDynamoPGOCache", + "RemoteDynamoPGOCache", + ) + + +def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: + dynamic_sources: OrderedSet[str] = OrderedSet() + for src, fs in code_state.automatic_dynamic.items(): + dynamic = False + if isinstance(fs.size, tuple): + dynamic = auto_dynamic in fs.size # type: ignore[operator] + elif fs.scalar == auto_dynamic: + dynamic = True + if dynamic: + dynamic_sources.add(src) + return dynamic_sources + + +def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: + code_id = CodeId.make(f_code) + frame_state = get_code_state()[code_id] + frame_whitelist = ",".join(_collect_dynamic_sources(frame_state)) + if frame_whitelist: + with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile( + name, recompile_dynamic_whitelist=frame_whitelist + ) + + +def render_code_state(cs: 
defaultdict[CodeId, CodeState]) -> str: + code_state_str = "\n".join( + f"{k}:\n" + + "\n".join( + f" {src}: {fs.render()}" for src, fs in v.automatic_dynamic.items() + ) + for k, v in cs.items() + ) + dynamic_sources: OrderedSet[str] = OrderedSet() + for state in cs.values(): + dynamic_sources.update(_collect_dynamic_sources(state)) + if dynamic_sources: + code_state_str += ( + "\n\nPGO detected a recompilation due to dynamic shapes. " + "To reduce shape recompilations by compiling dynamically to start, " + f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"' + ) + return code_state_str + + +@CacheArtifactFactory.register +class PGOCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + meta = write_local_impl( + self._rewrite_cache_key_for_mega_cache(self.key), self.content + ) + assert meta is not None + + @override + @staticmethod + def type() -> str: + return "pgo" + + @staticmethod + def _rewrite_cache_key_for_mega_cache(original_key: str) -> str: + """ + The PGO cache artifact key for a MAST job contains the job name and the version. + When we want to use the cache artifact on a different MAST job, we need to + update the key to use the new MAST job's name and version. + """ + if not original_key.startswith("mast:"): + # if original_key is overridden, then dont change it + return original_key + if (new_key := get_cache_key()) is not None: + return new_key + return original_key + + +def get_code_state() -> defaultdict[CodeId, CodeState]: + global _CODE_STATE, _INIT_CODE_STATE + if _CODE_STATE is not None: + return _CODE_STATE + + # Initialize it (even if we don't look up profile) + _CODE_STATE = defaultdict(CodeState) + + cache_key = get_cache_key() + if cache_key is None: + return _CODE_STATE + + def hit(ty: str) -> defaultdict[CodeId, CodeState]: + global _INIT_CODE_STATE + assert isinstance(_CODE_STATE, defaultdict) + log.info("get_code_state %s hit %s, %d entries", path, ty, len(_CODE_STATE)) + trace_structured_artifact( + f"get_{ty}_code_state", + "string", + lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type] + ) + set_feature_use("pgo", True) + _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE) + return _CODE_STATE + + # Attempt local + path = code_state_path(cache_key) + if path is not None and os.path.exists(path): + with dynamo_timed( + name := "pgo.get_local_code_state", log_pt2_compile_event=True + ): + CompileEventLogger.pt2_compile(name, cache_key=cache_key) + # Read lock not necessary as we always write atomically write to + # the actual location + with open(path, "rb") as f: + try: + content = f.read() + _CODE_STATE = pickle.loads(content) + CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell()) + except Exception: + log.warning( + "get_code_state failed while reading %s", path, exc_info=True + ) + else: + CacheArtifactManager.record_artifact( + PGOCacheArtifact.type(), cache_key, content + ) + return hit("local") + + # Attempt remote + remote_cache = get_remote_cache() + if remote_cache is not None: + with dynamo_timed( + name := "pgo.get_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_get_remote_code_state_time_us", + ): + CompileEventLogger.pt2_compile(name, cache_key=cache_key) + # TODO: I don't really understand why there's a JSON container format + try: + cache_data = remote_cache.get(cache_key) + except Exception: + log.warning( + "get_code_state failed remote read on %s", cache_key, exc_info=True + ) + else: + if cache_data is not None: + try: + assert 
isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, str) + payload = base64.b64decode(data) + CompileEventLogger.pt2_compile( + name, cache_size_bytes=len(payload) + ) + _CODE_STATE = pickle.loads(payload) + except Exception: + log.warning( + "get_code_state failed parsing remote result on %s", + cache_key, + exc_info=True, + ) + else: + CacheArtifactManager.record_artifact( + PGOCacheArtifact.type(), cache_key, payload + ) + return hit("remote") + else: + log.info("get_code_state remote miss on %s", cache_key) + + log.info("get_code_state using default") + + assert _CODE_STATE is not None + return _CODE_STATE + + +def put_code_state() -> None: + if _CODE_STATE is None: + log.info("put_code_state: never initialized, will not write") + return + + if _CODE_STATE == _INIT_CODE_STATE: + log.info("put_code_state: no change, skipping") + return + + cache_key = get_cache_key() + if cache_key is None: + log.info("put_code_state: no cache key, skipping") + return + + put_local_code_state(cache_key) + put_remote_code_state(cache_key) + + +def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]: + path = code_state_path(cache_key) + + if path is None: + return None + + # If the user isn't misusing our API, we should have exclusive access to + # this directory. But it's not too hard + + tmp_path = path + ".tmp" + lock_path = path + ".lock" + # We /mostly/ don't need the lock but the tmp file could be clobbered + # TODO: use a safe tempfile create to eliminate lock + from torch.utils._filelock import FileLock + + os.makedirs(os.path.dirname(path), exist_ok=True) + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + with open(tmp_path, "wb") as f: + f.write(pickled_code) + size = f.tell() + os.replace(tmp_path, path) + return path, size + + +def put_local_code_state(cache_key: str) -> None: + with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile(name, cache_key=cache_key) + assert _CODE_STATE is not None + + pickled_code = pickle.dumps(_CODE_STATE) + + CacheArtifactManager.record_artifact( + PGOCacheArtifact.type(), cache_key, pickled_code + ) + + meta = write_local_impl(cache_key, pickled_code) + if meta is None: + log.info("put_code_state: local cache disabled") + return + path, size = meta + + CompileEventLogger.pt2_compile(name, cache_size_bytes=size) + log.info("put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE)) + trace_structured_artifact( + "put_local_code_state", + "string", + lambda: render_code_state(_CODE_STATE), + ) + + +def put_remote_code_state(cache_key: str) -> None: + with dynamo_timed( + name := "pgo.put_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_put_remote_code_state_time_us", + ): + CompileEventLogger.pt2_compile(name, cache_key=cache_key) + assert _CODE_STATE is not None + + remote_cache = get_remote_cache() + + if remote_cache is None: + log.info("put_code_state: remote cache disabled") + return + + content = pickle.dumps(_CODE_STATE) + CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content)) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + } + remote_cache.put(cache_key, cache_data) + log.info( + "put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE) + ) + # TODO: don't log this multiple times + trace_structured_artifact( + "put_remote_code_state", + "string", + lambda: render_code_state(_CODE_STATE), + ) + + +# NB: this does NOT reset 
the cached code state on disk +def reset_code_state() -> None: + global _CODE_STATE, _INIT_CODE_STATE + _CODE_STATE = None + _INIT_CODE_STATE = None diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__init__.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf73f82a1440c615fe82a477b4738f70755783a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__init__.py @@ -0,0 +1,331 @@ +""" +Python polyfills for common builtins. +""" + +# NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports. +# 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py. +# Add it in the TYPE_CHECKING block below as well. + +# mypy: allow-untyped-defs + +import types +from collections.abc import Iterable, MutableMapping, Sequence +from itertools import repeat as _repeat +from typing import Any, Callable, TYPE_CHECKING + +import torch + +from ..utils import dict_keys + + +if TYPE_CHECKING: + # Load by torch._dynamo.polyfills.loader + # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py + # Put the submodules here to avoid circular imports + from . import ( + builtins as builtins, + functools as functools, + itertools as itertools, + operator as operator, + os as os, + pytree as pytree, + sys as sys, + ) + +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + + +def index(iterator, item, start=0, end=None): + from itertools import islice + + for i, elem in islice(enumerate(iterator), start, end): + if item == elem: + return i + # This will not run in dynamo + raise ValueError(f"{item} is not in {type(iterator)}") + + +def repeat(item, count): + for _ in range(count): + yield item + + +def radians(x): + import math + + return math.pi / 180.0 * x + + +def accumulate_grad(x, new_grad): + # polyfills according to the Gradient Layout Contract + if new_grad is None: + return + new_grad_strided = torch.empty_like(x) + new_grad_strided.copy_(new_grad) + if x.grad is None: + x.grad = new_grad_strided + elif torch.is_grad_enabled(): + x.grad = x.grad + new_grad_strided + else: + x.grad.add_(new_grad_strided) + + +# This mirrors +# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/listobject.c#L3352-L3413 +def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequence[Any]): + """emulate `(1,2,3) > (1,2)` etc""" + # Apply `op` to the first pair that differ + for a, b in zip(left, right): + if a != b: + return op(a, b) + + # No more pairs to compare, so compare sizes. 
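+    # e.g. list_cmp(operator.gt, (1, 2, 3), (1, 2)): the shared prefix matches,
+    # so the result is gt(3, 2) -> True, matching (1, 2, 3) > (1, 2).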
+ return op(len(left), len(right)) + + +def set_symmetric_difference(set1, set2): + symmetric_difference_set = set() + for x in set1: + if x not in set2: + symmetric_difference_set.add(x) + for x in set2: + if x not in set1: + symmetric_difference_set.add(x) + return symmetric_difference_set + + +def set_symmetric_difference_update(set1, set2): + result = set1.symmetric_difference(set2) + set1.clear() + set1.update(result) + + +def set_isdisjoint(set1, set2): + for x in set1: + if x in set2: + return False + return True + + +def set_intersection(set1, *others): + if len(others) == 0: + return set1.copy() + + intersection_set = set() + for x in set1: + for set2 in others: + if x not in set2: + break + else: + intersection_set.add(x) + return intersection_set + + +def set_intersection_update(set1, *others): + result = set1.intersection(*others) + set1.clear() + set1.update(result) + + +def set_union(set1, *others): + # frozenset also uses this function + union_set = set(set1.copy()) + for set2 in others: + set_update(union_set, set2) + return type(set1)(union_set) + + +def set_update(set1, *others): + if len(others) == 0: + return set1 + + for set2 in others: + for x in set2: + if x not in set1: + set1.add(x) + + +def set_difference(set1, *others): + if len(others) == 0: + return set1.copy() + + if not all(isinstance(s, Iterable) for s in others): + raise TypeError(f"set.difference expected an iterable, got {type(others)}") + + difference_set = set() + for x in set1: + for set2 in others: + if x in set2: + break + else: + difference_set.add(x) + return difference_set + + +def set_difference_update(set1, *others): + result = set1.difference(*others) + set1.clear() + set1.update(result) + + +def getattr_and_trace(*args, **kwargs): + wrapper_obj = args[0] + attr_name = args[1] + fn = getattr(wrapper_obj, attr_name) + return fn(*args[2:], **kwargs) + + +def mapping_get(obj, key, value=None): + try: + return obj.__getitem__(key) + except KeyError: + return value + + +def instantiate_user_defined_class_object(cls, /, *args, **kwargs): + obj = cls.__new__(cls, *args, **kwargs) + + # Only call __init__ if the object is an instance of the class + # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673 + if isinstance(obj, cls): + obj.__init__(*args, **kwargs) + return obj + + +# Used with something like dict(obj) +def construct_dict(cls, /, *args, **kwargs): + dst = cls.__new__(cls) + + if args: + src = args[0] + + # Ensure that the overridden __iter__ method is invoked + if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): + for key in src: + # This will inline the __getitem__ of the src object + dst[key] = src[key] + else: + # likely a sequence like tuple of pairs + for key, value in src: + dst[key] = value + + if kwargs: + for key in kwargs: + dst[key] = kwargs[key] + + return dst + + +def foreach_map_fn(*args): + op = args[0] + new_args: list[Any] = [] + at_least_one_list = False + for arg in args[1:]: + if not isinstance(arg, (list, tuple)): + new_args.append(_repeat(arg)) + else: + at_least_one_list = True + new_args.append(arg) + + # Just apply op once to args if there are no lists + if not at_least_one_list: + return op(*args[1:]) + + out = [] + for unpacked in zip(*new_args): + out.append(op(*unpacked)) + + return out + + +def foreach_lerp_inplace(self, end, weight): + # decompose foreach lerp into constituent ops, prevents a graph break due to + # converting a value to a scalar when arg[2] is a single tensor + result = torch._foreach_sub(end, 
self) + result = torch._foreach_mul(result, weight) + return torch._foreach_add_(self, result) + + +def foreach_pow_scalar(scalar, exps): + return torch._foreach_pow([scalar for _ in exps], exps) + + +def addcmul_inplace(self, tensor1, tensor2, value): + return self.add_(tensor1 * tensor2 * value) + + +def predicate(obj: Any) -> bool: + # This will cause the rest of dynamo to handle the if statement correctly, so we don't have to rewrite it here. + # We can't just use bool() here since we can't trace into that in general. + if obj: + return True + return False + + +def cmp_eq(a, b): + # Note that the commented `is` check should ideally be removed. This is a + # CPython optimization that skips the __eq__ checks it the obj id's are + # same. But, these lines adds many `is` nodes in the Fx graph for + # SymNodeVariable. For now, we can just skip this check. This is STILL + # correct because one of the __eq__ checks will pass later, just could be + # slow in some corner cases. + # if a is b: + # return True + result = a.__eq__(b) + if result is NotImplemented: + result = b.__eq__(a) + return result is not NotImplemented and result + + +def cmp_ne(a, b): + # Check if __ne__ is overridden + if isinstance(type(a).__ne__, types.FunctionType): + return a.__ne__(b) + return not cmp_eq(a, b) + + +def cmp_lt(a, b): + result = a.__lt__(b) + if result is NotImplemented: + raise TypeError(f"{type(a)} does not support the < operator") + return result + + +def cmp_le(a, b): + # Check if __le__ is overridden + if isinstance(type(a).__le__, types.FunctionType): + return a.__le__(b) + return cmp_eq(a, b) or cmp_lt(a, b) + + +def cmp_gt(a, b): + # Check if __gt__ is overridden + if isinstance(type(a).__gt__, types.FunctionType): + return a.__gt__(b) + # a > b is equivalent to b < a + return cmp_lt(b, a) + + +def cmp_ge(a, b): + # Check if __ge__ is overridden + if isinstance(type(a).__ge__, types.FunctionType): + return a.__ge__(b) + return cmp_eq(a, b) or cmp_gt(a, b) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41fe68599a8dc169c26dbc6a1407025cf4cb3135 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ea1a690c47e530d2a777b6f97eaefa08135dbf7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/builtins.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d1c0df7beb34ab406b2a0b407ef6689cc3ea38 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/functools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a66ad05aafb7c6aa3f07f2f05bb829a651814b3c Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2927e92196b54d866fc37318eb8ff72f4ab0f9a4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/itertools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bb8130dd39dd5cfffaa0dc3d263119aa8bad003 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/loader.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c95d1b516081a1f3e12fe3b33655c5894bf903b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/operator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecfd778488dde96193c5d19d40b9e6b55a679db5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/os.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7daeac2435fa675a8c658a2fb59feb2f000f0387 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/pytree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4953416bd07cb08f5d5412d7b0a5342c2a418dc2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/sys.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4102c791c02c3d8bf948db449f4b90b1dba2c674 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/__pycache__/tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/builtins.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..acea9e82dca966b137188ddf26a098a73ff971e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/builtins.py @@ -0,0 +1,60 @@ +""" +Python polyfills for builtins +""" + +from __future__ import annotations + +import builtins +import functools +import operator +from typing import TYPE_CHECKING, TypeVar + +from ..decorators import substitute_in_graph + + +if 
TYPE_CHECKING: + from collections.abc import Iterable + + +__all__ = [ + "all", + "any", + "enumerate", + "sum", +] + + +_T = TypeVar("_T") + + +@substitute_in_graph(builtins.all, can_constant_fold_through=True) +def all(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if not elem: + return False + return True + + +@substitute_in_graph(builtins.any, can_constant_fold_through=True) +def any(iterable: Iterable[object], /) -> bool: + for elem in iterable: + if elem: + return True + return False + + +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] +def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] + return functools.reduce(operator.add, iterable, start) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/functools.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/functools.py new file mode 100644 index 0000000000000000000000000000000000000000..12453cac7ccb9d668af6bde6416778763687a14d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/functools.py @@ -0,0 +1,47 @@ +""" +Python polyfills for functools +""" + +import functools +from collections.abc import Iterable +from typing import Callable, TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = ["reduce"] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class _INITIAL_MISSING: + pass + + +# Reference: https://docs.python.org/3/library/functools.html#functools.reduce +@substitute_in_graph(functools.reduce) +def reduce( + function: Callable[[_U, _T], _U], + iterable: Iterable[_T], + initial: _U = _INITIAL_MISSING, # type: ignore[assignment] + /, +) -> _U: + it = iter(iterable) + + value: _U + if initial is _INITIAL_MISSING: + try: + value = next(it) # type: ignore[assignment] + except StopIteration: + raise TypeError( + "reduce() of empty iterable with no initial value", + ) from None + else: + value = initial + + for element in it: + value = function(value, element) + + return value diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/fx.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/fx.py new file mode 100644 index 0000000000000000000000000000000000000000..f92a5c3b7f0c437970e15e437f3e19a0d186baba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/fx.py @@ -0,0 +1,40 @@ +from typing import Any, Callable + +from torch._C import _fx_map_aggregate, _fx_map_arg +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.node import Node + +from ..decorators import substitute_in_graph + + +@substitute_in_graph(_fx_map_arg, can_constant_fold_through=True) +def map_arg(a: Any, fn: Callable[[Node], Any]) -> Any: + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + + +@substitute_in_graph(_fx_map_aggregate, can_constant_fold_through=True) +def map_aggregate(a: Any, fn: Callable[[Any], Any]) -> Any: + result: Any + if isinstance(a, tuple): + it = (map_aggregate(elem, fn) for elem in a) + # Support NamedTuple (if it has `_fields`) by repacking into original type. 
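+        # e.g. a namedtuple Point(x=1, y=2) is rebuilt as Point(fn(1), fn(2))
+        # rather than being collapsed into a plain tuple.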
+ result = type(a)(*it) if hasattr(a, "_fields") else tuple(it) + elif isinstance(a, list): + result = immutable_list([map_aggregate(elem, fn) for elem in a]) + elif isinstance(a, dict): + result = immutable_dict([(k, map_aggregate(v, fn)) for k, v in a.items()]) + elif isinstance(a, slice): + result = slice( + map_aggregate(a.start, fn), + map_aggregate(a.stop, fn), + map_aggregate(a.step, fn), + ) + else: + result = fn(a) + return result + + +__all__ = [ + "map_arg", + "map_aggregate", +] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/itertools.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/itertools.py new file mode 100644 index 0000000000000000000000000000000000000000..6941337420b1cdceaa38bef92d269b0a0e7187ad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/itertools.py @@ -0,0 +1,211 @@ +""" +Python polyfills for itertools +""" + +from __future__ import annotations + +import itertools +import sys +from typing import Callable, overload, TYPE_CHECKING, TypeVar +from typing_extensions import TypeAlias + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + +__all__ = [ + "chain", + "chain_from_iterable", + "compress", + "dropwhile", + "islice", + "tee", + "zip_longest", +] + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_Predicate: TypeAlias = Callable[[_T], object] +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) + + +chain.from_iterable = chain_from_iterable # type: ignore[attr-defined] + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.compress +@substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type] +def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]: + return (datum for datum, selector in zip(data, selectors) if selector) + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile +@substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type] +def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: + # dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8 + + iterator = iter(iterable) + for x in iterator: + if not predicate(x): + yield x + break + + yield from iterator + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + 
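+        # Bounded case: walk an explicit index range alongside the iterable,
+        # e.g. islice('ABCDEFG', 1, 5, 2) yields 'B', 'D'.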
indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise +if sys.version_info >= (3, 10): + + @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] + def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b + + __all__ += ["pairwise"] + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.tee +@substitute_in_graph(itertools.tee) +def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: + iterator = iter(iterable) + shared_link = [None, None] + + def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def] + try: + while True: + if link[1] is None: + link[0] = next(iterator) + link[1] = [None, None] + value, link = link + yield value + except StopIteration: + return + + return tuple(_tee(shared_link) for _ in range(n)) + + +@overload +def zip_longest( + iter1: Iterable[_T1], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1]]: ... + + +@overload +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, +) -> Iterator[tuple[_T1 | None, _T2 | None]]: ... + + +@overload +def zip_longest( + iter1: Iterable[_T1], + iter2: Iterable[_T2], + /, + *, + fillvalue: _U = ..., +) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ... + + +@overload +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], +) -> Iterator[tuple[_T | None, ...]]: ... + + +@overload +def zip_longest( + iter1: Iterable[_T], + iter2: Iterable[_T], + iter3: Iterable[_T], + /, + *iterables: Iterable[_T], + fillvalue: _U = ..., +) -> Iterator[tuple[_T | _U, ...]]: ... + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.zip_longest +@substitute_in_graph(itertools.zip_longest, is_embedded_type=True) # type: ignore[arg-type,misc] +def zip_longest( + *iterables: Iterable[_T], + fillvalue: _U = None, # type: ignore[assignment] +) -> Iterator[tuple[_T | _U, ...]]: + # zip_longest('ABCD', 'xy', fillvalue='-') -> Ax By C- D- + + iterators = list(map(iter, iterables)) + num_active = len(iterators) + if not num_active: + return + + while True: + values = [] + for i, iterator in enumerate(iterators): + try: + value = next(iterator) + except StopIteration: + num_active -= 1 + if not num_active: + return + iterators[i] = itertools.repeat(fillvalue) # type: ignore[arg-type] + value = fillvalue # type: ignore[assignment] + values.append(value) + yield tuple(values) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/loader.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..2fad1e5bffda535c0ec38a18c8a4306357a9f78a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/loader.py @@ -0,0 +1,39 @@ +# Used to load and initialize polyfill handlers when importing torch._dynamo +# Please add a new import when adding a new polyfill module. + +import importlib +from typing import TYPE_CHECKING + +from .. import polyfills, trace_rules + + +if TYPE_CHECKING: + from types import ModuleType + + +# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py +POLYFILLED_MODULE_NAMES: tuple[str, ...] 
= ( + "builtins", + "functools", + "itertools", + "operator", + "os", + "pytree", + "sys", + "fx", + "tensor", +) +POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( + importlib.import_module(f".{submodule}", package=polyfills.__name__) + for submodule in POLYFILLED_MODULE_NAMES +) + + +# Unregister the builtin functions from _builtin_function_ids to let them to be +# dispatched with the appropriate VariableTracker type. Otherwise, they will be +# dispatched with BuiltinVariable if present in _builtin_function_ids. +for polyfill_module in POLYFILLED_MODULES: + for polyfill_name in polyfill_module.__all__: + polyfill_handler = getattr(polyfill_module, polyfill_name) + original_fn = polyfill_handler.__torch_dynamo_original__ + trace_rules._builtin_function_ids.remove(id(original_fn)) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/operator.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..5cba337d64614f55dba2a3c2c15ab2bf6d29b5e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/operator.py @@ -0,0 +1,105 @@ +""" +Python polyfills for operator +""" + +from __future__ import annotations + +import operator +from typing import Any, Callable, overload, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +from ..decorators import substitute_in_graph + + +# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) +__all__ = ["attrgetter", "itemgetter", "methodcaller"] + + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_Ts = TypeVarTuple("_Ts") +_U = TypeVar("_U") +_U1 = TypeVar("_U1") +_U2 = TypeVar("_U2") +_Us = TypeVarTuple("_Us") + + +@overload +def attrgetter(attr: str, /) -> Callable[[Any], _U]: ... + + +@overload +def attrgetter( + attr1: str, attr2: str, /, *attrs: str +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... + + +# Reference: https://docs.python.org/3/library/operator.html#operator.attrgetter +@substitute_in_graph(operator.attrgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def attrgetter(*attrs: str) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(attrs) == 0: + raise TypeError("attrgetter expected 1 argument, got 0") + + if any(not isinstance(attr, str) for attr in attrs): + raise TypeError("attribute name must be a string") + + def resolve_attr(obj: Any, attr: str) -> Any: + for name in attr.split("."): + obj = getattr(obj, name) + return obj + + if len(attrs) == 1: + attr = attrs[0] + + def getter(obj: Any) -> Any: + return resolve_attr(obj, attr) + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(resolve_attr(obj, attr) for attr in attrs) + + return getter + + +@overload +def itemgetter(item: _T, /) -> Callable[[Any], _U]: ... + + +@overload +def itemgetter( + item1: _T1, item2: _T2, /, *items: Unpack[_Ts] +) -> Callable[[Any], tuple[_U1, _U2, Unpack[_Us]]]: ... 
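+# e.g. itemgetter(1)('ABCDEFG') -> 'B' and
+# itemgetter(1, 3, 5)('ABCDEFG') -> ('B', 'D', 'F')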
+ + +# Reference: https://docs.python.org/3/library/operator.html#operator.itemgetter +@substitute_in_graph(operator.itemgetter, is_embedded_type=True) # type: ignore[arg-type,misc] +def itemgetter(*items: Any) -> Callable[[Any], Any | tuple[Any, ...]]: + if len(items) == 0: + raise TypeError("itemgetter expected 1 argument, got 0") + + if len(items) == 1: + item = items[0] + + def getter(obj: Any) -> Any: + return obj[item] + + else: + + def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] + return tuple(obj[item] for item in items) + + return getter + + +# Reference: https://docs.python.org/3/library/operator.html#operator.methodcaller +@substitute_in_graph(operator.methodcaller, is_embedded_type=True) # type: ignore[arg-type] +def methodcaller(name: str, /, *args: Any, **kwargs: Any) -> Callable[[Any], Any]: + if not isinstance(name, str): + raise TypeError("method name must be a string") + + def caller(obj: Any) -> Any: + return getattr(obj, name)(*args, **kwargs) + + return caller diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/os.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/os.py new file mode 100644 index 0000000000000000000000000000000000000000..235299e4ece0bd3d04328f1870bf535bb8f0f211 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/os.py @@ -0,0 +1,36 @@ +""" +Python polyfills for os +""" + +from __future__ import annotations + +import os +from typing import AnyStr + +from ..decorators import substitute_in_graph + + +__all__ = ["fspath"] + + +# Copied from os.py in the standard library +@substitute_in_graph(os.fspath, can_constant_fold_through=True) +def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: + if isinstance(path, (str, bytes)): + return path + + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) # type: ignore[arg-type] + except AttributeError: + if hasattr(path_type, "__fspath__"): + raise + raise TypeError( + f"expected str, bytes or os.PathLike object, not {path_type.__name__}", + ) from None + if isinstance(path_repr, (str, bytes)): + return path_repr # type: ignore[return-value] + raise TypeError( + f"expected {path_type.__name__}.__fspath__() to return str or bytes, " + f"not {type(path_repr).__name__}", + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/pytree.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c9fefed8544b89d827ee84020b87f3052f5384 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/pytree.py @@ -0,0 +1,419 @@ +""" +Python polyfills for torch.utils.pytree +""" + +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +from typing import Any, Callable, Literal, TYPE_CHECKING +from typing_extensions import TypeIs + +import torch.utils._pytree as python_pytree +from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES + +from ..decorators import substitute_in_graph + + +if TYPE_CHECKING: + import builtins + from collections.abc import Iterable + from typing_extensions import Self + + +__all__: list[str] = [] + + +if python_pytree._cxx_pytree_dynamo_traceable: + import optree + import optree._C + + import torch.utils._cxx_pytree as cxx_pytree + + if TYPE_CHECKING: + from torch.utils._cxx_pytree import PyTree + + @substitute_in_graph( + optree._C.is_dict_insertion_ordered, + can_constant_fold_through=True, + ) + def _(*args: Any, **kwargs: Any) -> bool: + # In namespace 'torch', the 
dictionary is always traversed in insertion order. + # This function returns True. + raise ValueError( + "Should not be called directly " + "because the original function will be called in the constant fold path." + ) + + __name = "" + for __name in ( + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + "namedtuple_fields", + "structseq_fields", + ): + __func = getattr(optree, __name) + globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( + __func.__python_implementation__ + ) + __all__ += [__name] # noqa: PLE0604 + del __func + del __name + + @substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True) + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if tree is None or (is_leaf is not None and is_leaf(tree)): + return True + if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined] + return True + return False + + @substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False) + def tree_iter( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> Iterable[Any]: + stack = [tree] + while stack: + node = stack.pop() + if tree_is_leaf(node, is_leaf=is_leaf): + yield node + continue + + children, *_ = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + stack.extend(reversed(children)) + + __all__ += ["tree_iter"] + + @substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True) + def tree_leaves( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> list[Any]: + return list(tree_iter(tree, is_leaf=is_leaf)) + + __all__ += ["tree_leaves"] + + class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + + def __repr__(self) -> str: + return "*" # no quotes + + _asterisk = _Asterisk() + del _Asterisk + + @dataclass(frozen=True) + class PyTreeSpec: + """Analog for :class:`optree.PyTreeSpec` in Python.""" + + _children: tuple[PyTreeSpec, ...] + _type: builtins.type | None + _metadata: Any + _entries: tuple[Any, ...] 
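+        # _entries holds, per child, the key or index used to access it in the
+        # original container (consumed by flatten_up_to and entries()/entry()).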
+ _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None + + num_nodes: int = field(init=False) + num_leaves: int = field(init=False) + num_children: int = field(init=False) + none_is_leaf: Literal[True] = field(init=False) + namespace: Literal["torch"] = field(init=False) + + def __post_init__(self) -> None: + if self._type is None: + assert len(self._children) == 0 + assert self._metadata is None + assert self._entries == () + assert self._unflatten_func is None + num_nodes = 1 + num_leaves = 1 + num_children = 0 + else: + assert callable(self._unflatten_func) + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) + num_children = len(self._children) + + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + object.__setattr__(self, "none_is_leaf", True) + object.__setattr__(self, "namespace", "torch") + + def __repr__(self) -> str: + def helper(treespec: PyTreeSpec) -> str: + if treespec.is_leaf(): + assert treespec.type is None + return _asterisk + + assert treespec.type is not None + assert callable(treespec._unflatten_func) + children_representations = [ + helper(subspec) for subspec in treespec._children + ] + if ( + treespec.type in BUILTIN_TYPES + or optree.is_namedtuple_class(treespec.type) + or optree.is_structseq_class(treespec.type) + ): + return treespec._unflatten_func( + treespec._metadata, + children_representations, + ) + return ( + f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " + f"[{', '.join(children_representations)}])" + ) + + return ( + f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})" + ) + + def __len__(self) -> int: + return self.num_leaves + + @property + def type(self) -> builtins.type | None: + return self._type + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def children(self) -> list[PyTreeSpec]: + return list(self._children) + + def child(self, index: int) -> PyTreeSpec: + return self._children[index] + + def entries(self) -> list[Any]: + return list(self._entries) + + def entry(self, index: int) -> Any: + return self._entries[index] + + def flatten_up_to(self, tree: PyTree) -> list[PyTree]: + def helper( + treespec: PyTreeSpec, + node: PyTree, + subtrees: list[PyTree], + ) -> None: + if treespec.is_leaf(): + subtrees.append(node) + return + + node_type = type(node) + if treespec.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != treespec.type: + raise ValueError( + f"Type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=True, + namespace="torch", + ) + if len(children) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(children)}.", + ) + if metadata != treespec._metadata: + raise ValueError( + f"Node context mismatch for custom node type {treespec.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + treespec.type in STANDARD_DICT_TYPES + and node_type in STANDARD_DICT_TYPES + ) + if not both_standard_dict and node_type != treespec.type: + raise ValueError( + f"Node type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + if 
len(node) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(node)}.", + ) + + if both_standard_dict: + # dictionary types are compatible with each other + expected_keys = treespec.entries() + got_key_set = set(node) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + children = [node[key] for key in expected_keys] + else: + # node_type is treespec.type + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=True, + namespace="torch", + ) + if ( + node_type + is not deque # ignore mismatch of `maxlen` for deque + ) and metadata != treespec._metadata: + raise ValueError( + f"Node metadata mismatch for node type {treespec.type!r}; " + f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch + ) + + for subtree, subspec in zip(children, treespec._children): + helper(subspec, subtree, subtrees) + + subtrees: list[PyTree] = [] + helper(self, tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + # Recursively unflatten the children + start = 0 + end = 0 + subtrees = [] + for subspec in self._children: + end += subspec.num_leaves + subtrees.append(subspec.unflatten(leaves[start:end])) + start = end + + assert callable(self._unflatten_func) + return self._unflatten_func(self._metadata, subtrees) + + _LEAF_SPEC = PyTreeSpec((), None, None, (), None) + + def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: + return isinstance(obj, PyTreeSpec) + + @substitute_in_graph( # type: ignore[arg-type] + cxx_pytree.tree_flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, + ) + def tree_flatten( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> tuple[list[Any], PyTreeSpec]: + def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: + if tree_is_leaf(node, is_leaf=is_leaf): + leaves.append(node) + return _LEAF_SPEC + + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + # Recursively flatten the children + subspecs = tuple(helper(child, leaves) for child in children) + return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type] + + leaves: list[Any] = [] + treespec = helper(tree, leaves) + return leaves, treespec + + __all__ += ["tree_flatten"] + + @substitute_in_graph( # type: ignore[arg-type] + cxx_pytree.tree_structure, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. 
+ can_constant_fold_through=False, + ) + def tree_structure( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> PyTreeSpec: + return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value] + + __all__ += ["tree_structure"] + + @substitute_in_graph( # type: ignore[arg-type] + cxx_pytree.tree_unflatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, + ) + def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree: + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + return treespec.unflatten(leaves) + + __all__ += ["tree_unflatten"] + + @substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True) + def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> PyTree: + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + __all__ += ["tree_map"] + + @substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True) + def tree_map_( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> PyTree: + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable + return tree + + __all__ += ["tree_map_"] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/sys.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/sys.py new file mode 100644 index 0000000000000000000000000000000000000000..078c4c2939bfcb5a9a51cc5dc963e545814a5050 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/sys.py @@ -0,0 +1,25 @@ +""" +Python polyfills for sys +""" + +from __future__ import annotations + +import sys + +from ..decorators import substitute_in_graph + + +__all__ = [ + "intern", + "getrecursionlimit", +] + + +@substitute_in_graph(sys.intern, can_constant_fold_through=True) +def intern(string: str, /) -> str: + return string + + +@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True) +def getrecursionlimit() -> int: + return sys.getrecursionlimit() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/polyfills/tensor.py b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..d47533fdda547d0013ac61720e8a159da778be1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,40 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + with torch._C.DisableTorchFunctionSubclass(): + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. 
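+        # The steps below mirror it: detach the data, only touch requires_grad
+        # when it actually changes, then re-wrap the result via as_subclass.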
+ # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, ( + "_make_subclass only supports requires_grad as keyword arg" + ) + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. + return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/precompile_context.py b/phivenv/Lib/site-packages/torch/_dynamo/precompile_context.py new file mode 100644 index 0000000000000000000000000000000000000000..53b230cf134396a798318c22eb2967c3d2f77ada --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/precompile_context.py @@ -0,0 +1,152 @@ +from abc import abstractmethod +from collections import defaultdict +from typing import Any, Generic, Optional, TypeVar +from typing_extensions import override + +from torch.compiler._cache import ( + _serialize_single_cache, + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, + CacheArtifactsResult, + CacheInfo, +) +from torch.utils._appending_byte_serializer import AppendingByteSerializer +from torch.utils._ordered_set import OrderedSet + + +""" +Classes and implementations related to precompile +""" + +T = TypeVar("T") + + +class PrecompileCacheArtifact(CacheArtifact, Generic[T]): + """ + Data for each cache artifact that will be serialized and deserialized by + PrecompileContext, rather than CacheArtifactManager. + T represents the deserialized type of the artifact, i.e. the return type of after_deserialization + + PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts + as needed, and use them in after_deserialization. + + Example implementation: + + class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]): + my_field: int + + def after_deserialization(self) -> MySerializableType: + result = pickle.loads(self.content) + # Do some extra work post deserialization + result.my_post_deserialization_function(self.my_field) + return result + """ + + @override + def populate_cache(self) -> None: + raise RuntimeError("Precompile cache artifacts do not populate caches") + + @override + def precompile_compatible(self) -> bool: + return True + + @abstractmethod + def after_deserialization(self) -> T: + """ + Code to be run after reading raw byte contents from disk. + Generally converts self.content from raw bytes back into its original form. + """ + ... + + +class PrecompileContext(CacheArtifactManager): + """ + PrecompileContext is a special CacheArtifactManager for handling precompilation + It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead + of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key + together and place it into a global Precompile Cache. 
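+
+    As an illustrative sketch (not a normative API contract), a caller records an
+    artifact under its key and later serializes everything::
+
+        PrecompileContext.record_artifact(artifact_type, key, content)
+        artifact = PrecompileContext.serialize_artifact_by_key(key)
+        serialized = PrecompileContext.serialize()  # (bytes, CacheInfo) or None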
+ + The following artifact types are supported by PrecompileContext: + - BundledAOTAutogradCacheArtifact + - CodeStateArtifact (from torch._dynamo.package once available) + """ + + # Protected by the compile_lock + # _new_cache_artifacts_by_key organizes results by the key of each artifact. + # This allows us to implement serialize_by_key easily. + # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key + # are transferred to _new_cache_artifacts before serialization. + _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {} + _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) + # Keep a separate seen artifacts list to make avoid unnecessary duplicates + # This list will not be cleared between serialize() calls + _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() + # When serialize() is called, artifacts are transferred from _cache_artifacts to + # internal data structure of the _serializer + # This allows us to only pay the cost of serialization if serialize() is called + _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = ( + AppendingByteSerializer(serialize_fn=_serialize_single_cache) + ) + _cache_info: CacheInfo = CacheInfo() + + @classmethod + def clear(cls) -> None: + cls._new_cache_artifacts_by_key.clear() + super().clear() + + @override + @classmethod + def record_artifact( + cls, + artifact_type: str, + key: str, + content: Any, + ) -> None: + """ + Called from each caching operation to record the artifact in this + "mega" list + """ + artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) + # TODO: although this covers completely same artifacts, it's possible + # with AOTAutogradCacheEntries to have multiple artifacts whose keys + # (i.e. backend_ids) are different, but whose contents are equal. + # In those cases, it would be much better if we only serialize once instead + # of N times. + if artifact in cls._seen_artifacts: + return + + cls._new_cache_artifacts_by_key[key] = artifact + cls._seen_artifacts.add(artifact) + + @classmethod + def _save_artifacts_by_type(cls) -> None: + """ + We normally record artifacts by key, but serialization expects them to be organized + by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts + """ + for artifact in cls._new_cache_artifacts_by_key.values(): + cls._new_cache_artifacts[artifact.__class__.type()].append(artifact) + cls._new_cache_artifacts_by_key.clear() + + @classmethod + def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: + """ + Serialize all artifacts with the given key returned in a list. + """ + return cls._new_cache_artifacts_by_key.get(key, None) + + @classmethod + def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: + cls._save_artifacts_by_type() + return super().serialize() + + @staticmethod + def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: + raise NotImplementedError("TODO") + + @classmethod + def _ensure_cache_artifacts_registered(cls) -> None: + from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 + BundledAOTAutogradCacheArtifact, + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/profiler.py b/phivenv/Lib/site-packages/torch/_dynamo/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..645e8d73903c05981a6e407702a80eab4189395c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/profiler.py @@ -0,0 +1,172 @@ +""" +Dynamo profiling implementation. 
+
+This module provides profiling functionality for Dynamo, including:
+- ProfileMetrics: Class for collecting and aggregating performance metrics like
+  execution time, operator counts, and fusion statistics
+- ProfileResult: Class for analyzing and reporting profiling results
+- Utilities for tracking missed/uncaptured operations
+- Functions for instrumenting FX graphs with profiling capabilities
+
+The profiler helps measure and optimize the performance of Dynamo-compiled code
+by tracking both captured and total operations, timing, and graph statistics.
+"""
+
+import dataclasses
+import os
+from typing import Any
+from typing_extensions import Self
+
+import torch
+
+from .utils import print_once
+
+
+@dataclasses.dataclass
+class ProfileMetrics:
+    microseconds: float = 0.0
+    operators: int = 0
+    fusions: int = 0
+    graphs: int = 0
+
+    def __iadd__(self, other: Self) -> Self:
+        self.microseconds += other.microseconds
+        self.operators += other.operators
+        self.fusions += other.fusions
+        return self
+
+    def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
+        assert isinstance(other, ProfileMetrics)
+        return ProfileMetrics(
+            self.microseconds + other.microseconds,
+            self.operators + other.operators,
+            self.fusions + other.fusions,
+        )
+
+    def __truediv__(self, other: Any) -> "ProfileMetrics":
+        if isinstance(other, int):
+            other = ProfileMetrics(other, other, other)
+        return ProfileMetrics(
+            self.microseconds / max(1, other.microseconds),
+            self.operators / max(1, other.operators),
+            self.fusions / max(1, other.fusions),
+        )
+
+    def __str__(self) -> str:
+        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
+
+    def tocsv(self) -> list[float]:
+        return [self.operators, self.microseconds]
+
+
+class ProfileResult:
+    def __init__(
+        self, captured: ProfileMetrics, total: ProfileMetrics, unique_graphs: int
+    ) -> None:
+        self.captured: ProfileMetrics = captured or ProfileMetrics()
+        self.total: ProfileMetrics = total or ProfileMetrics()
+        self.unique_graphs: int = unique_graphs
+
+    def __iadd__(self, other: Self) -> Self:
+        self.captured += other.captured
+        self.total += other.total
+        self.unique_graphs += other.unique_graphs
+        return self
+
+    def percent(self) -> ProfileMetrics:
+        return self.captured / self.total
+
+    def __str__(self) -> str:
+        return (
+            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
+            f"{self.captured.operators:4}/{self.total.operators:4} = "
+            + str(self.percent())
+        )
+
+    def tocsv(self) -> list[Any]:
+        return [
+            self.unique_graphs,
+            self.captured.graphs,
+            self.captured.operators,
+            self.total.operators,
+        ] + self.percent().tocsv()
+
+
+def should_print_missing() -> bool:
+    return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
+
+
+def print_missing(stack: list[str]) -> None:
+    if any("/torch/autograd/profiler.py" in x for x in stack):
+        return
+    stack = [
+        x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
+    ]
+    print_once("MISSING", " >> ".join(stack[-3:]))
+
+
+class Profiler:
+    unique_graphs: int = 0
+
+    def __init__(self) -> None:
+        self.prof = torch.profiler.profile(
+            activities=[torch.profiler.ProfilerActivity.CPU],
+            with_stack=should_print_missing(),
+        )
+
+    def results(self) -> ProfileResult:
+        captured_regions = 0
+        captured_ops = 0
+        captured_microseconds = 0
+        total_ops = 0
+        total_microseconds = 0
+
+        last_op_end_time = -1
+        captured_region_end_time = -1
+        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
+        for e in events:
+            if e.name == "TORCHDYNAMO":
+                captured_region_end_time = e.time_range.end
+                captured_regions += 1
+                # ignore
`handle = torch.zeros(1)` in record_function.__init__() + total_ops -= 1 + elif e.time_range.start >= last_op_end_time: + last_op_end_time = e.time_range.end + if e.time_range.end <= captured_region_end_time: + captured_ops += 1 + captured_microseconds += e.time_range.elapsed_us() + elif should_print_missing(): + print_missing(e.stack) + total_ops += 1 + total_microseconds += e.time_range.elapsed_us() + else: + pass # ops recursively called from other ops (ignored) + + unique_graphs = Profiler.unique_graphs + Profiler.unique_graphs = 0 + # we counted one extra op that is part of the profiler setup code + total_ops -= 1 + + return ProfileResult( + captured=ProfileMetrics( + microseconds=captured_microseconds, + operators=captured_ops, + fusions=captured_ops - captured_regions, + graphs=captured_regions, + ), + total=ProfileMetrics( + microseconds=total_microseconds, + operators=total_ops, + fusions=total_ops - 1, + ), + unique_graphs=unique_graphs, + ) + + +def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: list[Any]) -> Any: + def _wrapped(*args: Any) -> Any: + with torch.profiler.record_function("TORCHDYNAMO"): + return gm.forward(*args) + + Profiler.unique_graphs += 1 + return _wrapped diff --git a/phivenv/Lib/site-packages/torch/_dynamo/replay_record.py b/phivenv/Lib/site-packages/torch/_dynamo/replay_record.py new file mode 100644 index 0000000000000000000000000000000000000000..4f99a47ae909845dc1df664fcaf5928e32eb44c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/replay_record.py @@ -0,0 +1,129 @@ +""" +Python execution state recording and replay functionality. + +This module provides mechanisms for capturing and replaying Python execution state: + +- ModuleRecord: Tracks module access patterns and attribute usage +- DummyModule: Lightweight module substitute for replay +- ExecutionRecord: Manages execution context including globals, locals and builtins +- ExecutionRecorder: Records variable states and module access during execution + +The module enables serialization and reproduction of Python execution environments, +particularly useful for debugging and testing frameworks that need to capture +and recreate specific program states. 
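+
+Illustrative flow (variable names here are hypothetical, not part of the API)::
+
+    recorder = ExecutionRecorder(code=some_code_object, closure=())
+    recorder.add_local_var("x", some_value)
+    record = recorder.get_record()
+    with open("record.dill", "wb") as f:
+        record.dump(f)  # requires `pip install dill`; ExecutionRecord.load reads it back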
+""" + +import dataclasses +from dataclasses import field +from types import CellType, CodeType, ModuleType +from typing import Any, IO +from typing_extensions import Self + +from torch.utils._import_utils import import_dill + + +dill = import_dill() + + +@dataclasses.dataclass +class ModuleRecord: + module: ModuleType + accessed_attrs: dict[str, Any] = field(default_factory=dict) + + +@dataclasses.dataclass +class DummyModule: + name: str + is_torch: bool = False + value: object = None + + @property + def __name__(self) -> str: + return self.name + + +@dataclasses.dataclass +class ExecutionRecord: + code: CodeType + closure: tuple[CellType] + globals: dict[str, Any] = field(default_factory=dict) + locals: dict[str, Any] = field(default_factory=dict) + builtins: dict[str, Any] = field(default_factory=dict) + code_options: dict[str, Any] = field(default_factory=dict) + + def dump(self, f: IO[str]) -> None: + assert dill is not None, "replay_record requires `pip install dill`" + dill.dump(self, f) + + @classmethod + def load(cls, f: IO[bytes]) -> Self: + assert dill is not None, "replay_record requires `pip install dill`" + return dill.load(f) + + +@dataclasses.dataclass +class ExecutionRecorder: + LOCAL_MOD_PREFIX = "___local_mod_" + + code: CodeType + closure: tuple[CellType] + globals: dict[str, Any] = field(default_factory=dict) + locals: dict[str, Any] = field(default_factory=dict) + builtins: dict[str, Any] = field(default_factory=dict) + code_options: dict[str, Any] = field(default_factory=dict) + name_to_modrec: dict[str, ModuleRecord] = field(default_factory=dict) + + def add_local_var(self, name: str, var: Any) -> None: + if isinstance(var, ModuleType): + self.locals[name] = self._add_mod(var) + else: + self.locals[name] = var + + def add_global_var(self, name: str, var: Any) -> None: + if isinstance(var, ModuleType): + self.globals[name] = self._add_mod(var) + else: + self.globals[name] = var + + def add_local_mod(self, name: str, mod: ModuleType) -> None: + assert isinstance(mod, ModuleType) + self.add_global_var(name, mod) + + def record_module_access(self, mod: ModuleType, name: str, val: Any) -> None: + if isinstance(val, ModuleType): + self.name_to_modrec[mod.__name__].accessed_attrs[name] = self._add_mod(val) + return + + if mod.__name__ in self.name_to_modrec: + self.name_to_modrec[mod.__name__].accessed_attrs[name] = val + + def get_record(self) -> ExecutionRecord: + return ExecutionRecord( + self.code, + self.closure, + ExecutionRecorder._resolve_modules(self.globals), + ExecutionRecorder._resolve_modules(self.locals), + self.builtins.copy(), + self.code_options.copy(), + ) + + def _add_mod(self, mod: ModuleType) -> ModuleRecord: + if mod.__name__ not in self.name_to_modrec: + self.name_to_modrec[mod.__name__] = ModuleRecord(mod) + + return self.name_to_modrec[mod.__name__] + + @classmethod + def _resolve_modules(cls, vars: dict[str, Any]) -> dict[str, Any]: + def resolve_module(var: Any) -> Any: + if not isinstance(var, ModuleRecord): + return var + + dummy_mod = DummyModule(var.module.__name__) + for attr_name, attr_value in var.accessed_attrs.items(): + attr_value = resolve_module(attr_value) + dummy_mod.__setattr__(attr_name, attr_value) + + return dummy_mod + + return {k: resolve_module(v) for k, v in vars.items()} diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/__init__.py b/phivenv/Lib/site-packages/torch/_dynamo/repro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73a31ae321a483fc85b9df883a936c40170d3b25 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deb8dd80e076653833003015f0802afdcd29d28f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20214322656dfe870f0b58729084424693b2660e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67343c114dcc40b90c589937a13027531bc37512 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/repro/__pycache__/aoti.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/after_aot.py b/phivenv/Lib/site-packages/torch/_dynamo/repro/after_aot.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d43fa64bc6a105261bda61e645d3a3af5ba439 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/repro/after_aot.py @@ -0,0 +1,1062 @@ +# mypy: allow-untyped-defs + +""" +Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. + +This module provides tools and infrastructure for: +1. Generating minimal reproducible test cases ("repros") from failing compilations +2. Analyzing accuracy issues between eager and compiled execution +3. Minifying large models/inputs to isolate problematic patterns +4. Debugging compiler errors and accuracy divergences + +The main components include: +- Repro generation: Creates standalone Python files that reproduce compiler issues +- Minification: Reduces large graphs to minimal failing examples +- Accuracy analysis: Compares compiled vs eager execution, with fp64 reference +- Debug tools: Dumps graph state, tracks intermediates, analyzes divergences + +This is primarily used by PyTorch developers and researchers to debug issues in +the Dynamo AOT compilation pipeline, particularly for the Inductor backend. 
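+
+A generated repro script ("fx_graph_runnable.py") defines `mod` and `load_args` and
+then hands control back to this module; its entry point looks roughly like::
+
+    if __name__ == '__main__':
+        from torch._dynamo.repro.after_aot import run_repro
+        with torch.no_grad():
+            run_repro(mod, load_args, accuracy='', command='run',
+                      save_dir=None, tracing_mode='real', check_str=None)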
+""" + +import argparse +import copy +import functools +import io +import logging +import os +import shutil +import subprocess +import sys +import textwrap +import uuid +from collections.abc import Sequence +from importlib import import_module +from tempfile import TemporaryFile +from typing import Any, Callable, TYPE_CHECKING, Union +from typing_extensions import Unpack + +import torch +import torch.fx as fx +import torch.nn as nn +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + AccuracyError, + backend_accuracy_fails, + BuckTargetWriter, + cast_to_fp64, + extra_deps, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + MAX_CONSTANT_NUMEL_INLINE, + minifier_dir, + NNModuleToString, + NopInputReader, + same_two_models, +) +from torch._dynamo.utils import clone_inputs, counters, same +from torch._environment import is_fbcode +from torch._inductor.output_code import OutputCode +from torch._library.fake_class_registry import FakeScriptObject +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + fx_placeholder_targets, + has_free_symbols, +) +from torch.hub import tqdm + +from .. import config + + +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs + from torch._inductor.utils import InputType + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def wrap_compiler_debug( + unconfigured_compiler_fn: "_CompileFxCallable", + compiler_name: str, +) -> "_CompileFxCallable": + """ + Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both + forward and backward call separately with the backend compiler_fn - like + inductor or nvfuser. Intercepting after Aot Autograd presents neat + abstraction, where all the params are lifted as graph inputs, making it easy + to save the graph as a string. + """ + + @functools.wraps(unconfigured_compiler_fn) + def debug_wrapper( + gm: torch.fx.GraphModule, + example_inputs: Sequence["InputType"], + **kwargs: Unpack["_CompileFxKwargs"], + ) -> OutputCode: + from torch._subclasses import FakeTensorMode + + compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + + from torch._functorch.aot_autograd import get_aot_graph_name + + graph_name = get_aot_graph_name() + + # TODO: why do we need to deepcopy the original graph? + orig_graph = copy.deepcopy(gm.graph) + assert config.repro_after in ("dynamo", "aot", None) + + try: + # Call the compiler_fn - which is either aot_autograd or inductor + # with fake inputs + inner_compiled_fn = compiler_fn(gm, example_inputs) + except Exception: + # TODO: Failures here are troublesome because no real inputs, + # need a different serialization strategy + if config.repro_after == "aot": + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + example_inputs, + compiler_name, + ) + log.error("CompilerError") + raise + + # We may run regular PyTorch compute that may trigger Dynamo, do NOT + # recursively attempt to accuracy minify in that case! 
+ def deferred_for_real_inputs( + real_inputs: Sequence["InputType"], **_kwargs: object + ) -> Any: + # This is a bit obscure: if we recursively try to accuracy minify + # the SAME function, this would trigger. But most of the time + # we should never hit this branch + assert not _kwargs + if config.repro_after != "aot": + assert not isinstance(inner_compiled_fn, str) + return inner_compiled_fn(real_inputs) + with config.patch(repro_after=None): + return inner_debug_fn(real_inputs) + + def inner_debug_fn(real_inputs): + """ + Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, + example_inputs can be fake tensors. We can call compiler_fn (which is + inductor or nvfuser) with fake tensors but the actually compiled_fn + should be called with real tensors. Therefore, the actual invocation + is deferred. + """ + # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor + # because inductor clears the tensor list in its codegen. And example_inputs + # are available only for the first invocation. + fake_mode = FakeTensorMode() + copy_tensor_attrs = [ + fake_mode.from_tensor(x) if isinstance(x, torch.Tensor) else x + for x in real_inputs + ] + if config.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + fx.GraphModule(gm, orig_graph), real_inputs, compiler_name + ) + + if config.repro_level == 4: + if compiler_name != "inductor": + raise NotImplementedError( + "Accuracy minification is supported for inductor only" + ) + failed = not same_two_models( + gm, + inner_compiled_fn, + real_inputs, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + if failed: + log.warning( + "Accuracy failed for the AOT Autograd graph %s", graph_name + ) + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + dump_to_minify( + fx.GraphModule(gm, orig_graph), + real_inputs, + f"{compiler_name}_accuracy", + ) + raise AccuracyError("Bad accuracy detected") + else: + # Call the compiled function with real inputs + return inner_compiled_fn(real_inputs) # type: ignore[operator] + else: + try: + # Call the compiled function with real inputs + out = inner_compiled_fn(real_inputs) # type: ignore[operator] + # sync cuda kernels to ensure IMA detection + for arg in example_inputs: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + torch.cuda.synchronize() + break + return out + except Exception: + if config.repro_level == 1: + dump_compiler_graph_state( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + elif config.repro_level == 2: + dump_to_minify( + fx.GraphModule(gm, orig_graph), + copy_tensor_attrs, + compiler_name, + ) + raise + + if config.repro_after == "aot": + compiled_fn = deferred_for_real_inputs + compiled_fn._boxed_call = True # type: ignore[attr-defined] + return compiled_fn # type: ignore[return-value] + else: + return inner_compiled_fn + + return debug_wrapper + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def maybe_fbcode_instructions(): + if is_fbcode(): + extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) + if len(extra_deps_formatted) > 0: + extra_deps_formatted = "\n" + extra_deps_formatted + return f"""\ +\"\"\" +To run this script in fbcode: +- Create a directory (//scripts/{{your_unixname}}/repro) +- Put this file in scripts/{{your_unixname}}/repro/fx_graph_runnable.py +- Add a 
TARGETS file that looks like the following +- `buck2 run //scripts/{{your_unixname}}/repro:repro` + +NOTE: you may need additional deps to actually be able to run the script. +``` +# Contents of TARGETS file +load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") + +python_binary( + name = "repro", + main_src = "fx_graph_runnable.py", + deps = [ + "//caffe2:torch",{extra_deps_formatted} + ], +) +``` +\"\"\" +""" + else: + return "" + + +def generate_compiler_repro_string( + gm, args, *, stable_output=False, save_dir=None, stable_hash=False +): + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +from torch import tensor, device +import torch.fx as fx +from torch._dynamo.testing import rand_strided +from math import inf +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + +{maybe_fbcode_instructions()} + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + model_str += NNModuleToString.convert(gm) + + # get hint shape/stride when dynamic shape enabled + def hint_if_symint(x): + return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x) + + writer = InputWriter(save_dir, stable_hash=stable_hash) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + elif arg is None: + writer.const(placeholder) + else: + # It's better to produce a slightly wrong repro string than none + # at all + writer.unsupported(placeholder, arg) + + model_str += "\n".join(writer.lines()) + "\n" + + model_str += "mod = Repro()\n" + return model_str + + +def save_graph_repro( + fd, + gm, + args, + compiler_name, + *, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + tracing_mode=None, + check_str=None, + stable_hash=False, +): + if any( + isinstance(arg, torch.fx.experimental._backward_state.BackwardState) + for arg in args + ): + fd.write( + "Repro is not generated due to existence of BackwardState in graph input" + ) + return + + fd.write( + generate_compiler_repro_string( + gm, + args, + stable_output=stable_output, + save_dir=save_dir, + stable_hash=stable_hash, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + if tracing_mode is None: + tracing_mode = "real" + if any( + has_free_symbols(a) for a in args if not isinstance(a, FakeScriptObject) + ): + tracing_mode = "symbolic" + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # To run it separately, do \n" + f" # mod, args = run_repro(mod, load_args, accuracy={accuracy!r}, command='get_args', " + f"save_dir={save_dir!r}, tracing_mode={tracing_mode!r}, check_str={check_str!r})\n" + f" # mod(*args)" + ) + + +def dump_compiler_graph_state(gm, args, compiler_name, *, 
accuracy=None): + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro( + fd, gm, args, compiler_name, save_dir=subdir, accuracy=accuracy + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP MINIFIER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify(gm, args, compiler_name: str): + out = io.StringIO() + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + save_graph_repro(out, gm, args, compiler_name, save_dir=subdir, command="minify") + return helper_for_dump_minify(out.getvalue()) + + +def isolate_fails( + fx_g, + args, + compiler_name: str, + env=None, + save_dir=None, + accuracy=None, + tracing_mode=None, + check_str=None, +): + if env is None: + env = {} + subdir = os.path.join(os.getcwd(), "isolate") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") + with open(file_name, "w") as fd: + save_graph_repro( + fd, + fx_g, + args, + compiler_name, + save_dir=save_dir, + command="minifier-query", + accuracy=accuracy, + tracing_mode=tracing_mode, + check_str=check_str, + ) + # with open(file_name, "r") as fd: + # print(fd.read()) + new_env = os.environ.copy() + new_env = {**new_env, **env} + stdout, stderr = TemporaryFile(), TemporaryFile() + + if use_buck: + cmd = BuckTargetWriter(file_name).write(print_msg=False) + else: + cmd = [sys.executable, file_name] + + p = subprocess.Popen( + cmd, + cwd=subdir, + stdout=stdout, + stderr=stderr, + env=new_env, + ) + p.wait() + + stdout.seek(0) + stderr.seek(0) + print( + textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "), file=sys.stdout + ) + print( + textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "), file=sys.stderr + ) + # print(f"Isolated test failed - {file_name}") + return p.returncode != 0 + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER TOOLS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def inductor_fails(fx_g, args, check_str=None): + has_cuda = False + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + has_cuda = True + break + + def sync(): + if has_cuda: + # Ensures that segfaults are surfaced + torch.cuda.synchronize() + + from torch._inductor.compile_fx import compile_fx_inner + + try: + result = fx_g(*args) + assert isinstance(result, (tuple, list)) + assert not any(isinstance(x, (tuple, list)) for x in result) + except Exception: + return False + + sync() + + try: + compile_mod = compile_fx_inner(fx_g, args) + assert not isinstance(compile_mod, str) + compile_mod(args) + sync() + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + print(repr(e)) + return True + return False + + +def inductor_accuracy_fails( + fx_g, args, 
check_str=None, *, require_fp64=False, ignore_non_fp=False +): + from torch._inductor.compile_fx import compile_fx_inner + + return backend_aot_accuracy_fails( + fx_g, + args, + compile_fx_inner, + require_fp64=require_fp64, + ignore_non_fp=ignore_non_fp, + ) + + +backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def repro_common(options, mod, load_args): + # Invariant for graphs we generate with the repro script + assert not any(mod.named_parameters()) + for n, b in mod.named_buffers(): + if b.numel() > MAX_CONSTANT_NUMEL_INLINE: + log.warning( + "Constant %s was not serialized, generated random data instead. " + "If you think this is affecting you, please comment on " + "https://github.com/pytorch/pytorch/issues/100468", + n, + ) + + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + # Turn mod into a GraphModule the slow way + # TODO: speed this up + mod = make_fx(mod, tracing_mode=options.tracing_mode)(*args) + + torch._inductor.config.generate_intermediate_hooks = True + + return mod, args + + +ACCURACY_FAILS: dict[str, Callable[[nn.Module, Any], bool]] = { + "": inductor_fails, + # This might look inverted but it's not. 
strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + "accuracy": functools.partial( + inductor_accuracy_fails, require_fp64=True, ignore_non_fp=True + ), + "strict_accuracy": inductor_accuracy_fails, +} + + +def repro_minifier_query(options, mod, load_args): + mod, args = repro_common(options, mod, load_args) + fail_fn = functools.partial( + ACCURACY_FAILS[options.accuracy], + check_str=options.check_str, # type: ignore[call-arg] + ) + if fail_fn(mod, args): + sys.exit(1) + else: + sys.exit(0) + + +def repro_minify(options, mod, load_args): + from functorch.compile import minifier + + mod, args = repro_common(options, mod, load_args) + compiler_name = "inductor_accuracy" if options.accuracy != "" else "inductor" + + favored_device = 1 if torch.cuda.device_count() >= 2 else 0 + env_variables = {"CUDA_VISIBLE_DEVICES": str(favored_device)} + + module_fails: Any + if options.isolate: + module_fails = functools.partial( + isolate_fails, + env=env_variables, + compiler_name=compiler_name, + save_dir=options.save_dir, + accuracy=options.accuracy, + tracing_mode=options.tracing_mode, + ) + else: + module_fails = ACCURACY_FAILS[options.accuracy] + + minifier( + mod, + args, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, compiler_name=compiler_name + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def repro_analyze(options, mod, load_args): + from torch._inductor.compile_fx import compile_fx_inner + from torch._inductor.hooks import intermediate_hook + + mod, args = repro_common(options, mod, load_args) + + # TODO: The logic for cloning inputs/models here is intentionally + # modeled off of run_fwd_maybe_bwd, but arguably it is better not to + # clone inputs (as you are doubling your effective GPU memory usage). + # It is certainly faster though! It probably makes sense to let the + # user specify the offload strategy. 
+ + with tqdm(desc="Compiling"): + compiled = compile_fx_inner(mod, args) + total = counters["inductor"]["intermediate_hooks"] + + known_names = set() + + def save_hook(name, val): + known_names.add(name) + if not options.skip_saving_inductor_intermediates: + writer.write_tensor(os.path.join("inductor", name), val) + pbar.update(1) # type: ignore[has-type] + + writer = torch.utils._content_store.ContentStoreWriter( + options.save_dir, stable_hash=options.stable_hash + ) + reader = torch.utils._content_store.ContentStoreReader(options.save_dir) + + new_args = clone_inputs(args) + with ( + intermediate_hook(save_hook), + tqdm(desc="Saving inductor intermediates", total=total) as pbar, + ): + assert not isinstance(compiled, str) + compiled(new_args) + assert not new_args + + def compare_tuples(tuple1, tuple2): + diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] + diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] + + if not diff_values: + return None + else: + return " and ".join(f"{a} != {b}" for a, b in diff_values) + + def check_hook(name, val): + meta = writer.compute_tensor_metadata(val) + meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC INDUCTOR at {name} ({reason})") + pbar.update(1) + + if not options.skip_check_deterministic: + new_args = clone_inputs(args) + with ( + intermediate_hook(check_hook), + tqdm(desc="Checking inductor determinism", total=total) as pbar, + ): + compiled(new_args) + assert not new_args + + class WriterInterp(fx.Interpreter): + def __init__(self, mod, subdir) -> None: + super().__init__(mod) + self.subdir = subdir + + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + pbar.update(1) + writer.write_tensor(os.path.join(self.subdir, name), r) + return r + + # NB: the module cast doesn't actually do anything, since there are no + # parameters/buffers on the module + if not options.skip_saving_float64_intermediates: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + with tqdm(desc="Saving float64 intermediates", total=total) as pbar: + WriterInterp(new_mod, "float64").boxed_run(new_args) + assert not new_args + + class ExactReaderInterp(fx.Interpreter): + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + meta = writer.compute_tensor_metadata(r) + meta2 = reader.read_tensor_metadata(os.path.join("float64", name)) + reason = compare_tuples(meta, meta2) + if reason is not None: + pbar.write(f"NONDETERMINISTIC FLOAT64 at {name} ({reason})") + pbar.update(1) + return r + + # TODO: check eager determinism + + if not options.skip_check_deterministic: + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + with tqdm(desc="Checking float64 determinism", total=total) as pbar: + ExactReaderInterp(new_mod).boxed_run(new_args) + assert not new_args + + # Now that we've saved everything, interp through the eager graph + # and do comparisons + class ReaderInterp(fx.Interpreter): + def run_node(self, n): + r = super().run_node(n) + name = n.name + if name in known_names: + inductor = reader.read_tensor(os.path.join("inductor", name)) + float64 = reader.read_tensor(os.path.join("float64", name)) + logged = False + + def log_error(msg, *args): + nonlocal logged + logged = True + pbar.write(f"DIVERGED at {name}: {msg % args}") + + if not same( + r, + inductor, + float64, + 
tol=torch._dynamo.config.repro_tolerance, + equal_nan=True, + log_error=log_error, + ): + assert logged + pbar.update(1) + return r + + with tqdm(desc="Checking divergence", total=total) as pbar: + ReaderInterp(mod).boxed_run(args) + assert not args + + +def repro_get_args(options, mod, load_args): + mod, args = repro_common(options, mod, load_args) + return mod, args + + +def repro_run(options, mod, load_args): + from torch._inductor.compile_fx import compile_fx_inner + + mod, args = repro_common(options, mod, load_args) + + from torch.cuda import synchronize + + compiled = compile_fx_inner(mod, args) + assert not isinstance(compiled, str) + + if options.accuracy != "": + # We don't really respect --accuracy vs --strict-accuracy here, it + # seems counterintuitive + if not same_two_models( + mod, + compiled, + args, + only_fwd=True, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Bad accuracy detected") + else: + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + compiled(list(args)) + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +# TODO: lazily load the inputs or something, rather than cloning them +def run_repro( + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + patch_code=None, + check_str=None, + **kwargs, +): + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + if patch_code is not None: + log.warning( + "patch_code no longer works on this version of PyTorch, silently ignoring" + ) + + parser = argparse.ArgumentParser( + description=f"""\ +An after_aot repro script, typically triggering a bug in PyTorch Inductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. + +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. 
Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--tracing-mode", + type=str, + metavar="{real,fake,symbolic}", + default=tracing_mode, + help="how to trace the repro module into a GraphModule with metadata", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify,analyze}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify_isolate = parser_minify.add_mutually_exclusive_group() + parser_minify_isolate.add_argument( + "--isolate", + action="store_true", + default=True, + help="run in separate processes to avoid interference (default)", + ) + parser_minify_isolate.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + help="speed up by running all compilation in same process", + ) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + # TODO: make this an option for --analyze too + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. 
Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + parser_analyze = subparsers.add_parser( + "analyze", help="run the accuracy analyzer on the repro" + ) + common_flags(parser_analyze) + parser_analyze.add_argument( + "--skip-saving-inductor-intermediates", + action="store_true", + help="skip saving inductor intermediates on --analyze", + ) + parser_analyze.add_argument( + "--skip-saving-float64-intermediates", + action="store_true", + help="skip saving float64 intermediates", + ) + parser_analyze.add_argument( + "--skip-check-deterministic", + action="store_true", + help="skip checking that the network is deterministic", + ) + parser_analyze.add_argument( + "--stable-hash", + action="store_true", + help="use SHA-1 checksum instead of fast (but possibly unsound) hash", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "analyze": repro_analyze, + "minifier-query": repro_minifier_query, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command](options, mod, load_args) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py b/phivenv/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py new file mode 100644 index 0000000000000000000000000000000000000000..9084a1503aba0995feb956dbdb65d434155d1cd5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/repro/after_dynamo.py @@ -0,0 +1,607 @@ +# mypy: allow-untyped-defs + +""" +Utilities for reproducing and debugging issues in Dynamo after graph capture. + +This file provides tools and infrastructure for debugging problems that occur +after Dynamo has captured the graph but before/during backend compilation. +Key components include: + +- Minification tools to reduce large graphs to minimal failing examples +- Accuracy testing to validate compiled graph outputs match eager mode +- Repro generation to create standalone reproduction scripts +- Debug backends for capturing and analyzing failures +- Utilities for saving/loading graph states and inputs + +The tools here focus specifically on the post-graph-capture stage, making them +useful for debugging backend compilation issues, AOTAutograd problems, and +accuracy discrepancies between compiled and eager execution. 
+""" + +import argparse +import copy +import functools +import logging +import os +import shutil +import sys +import textwrap +from importlib import import_module +from typing import Union + +import torch +import torch.fx as fx +from torch._dynamo.backends.registry import CompiledFn +from torch._dynamo.debug_utils import ( + AccuracyError, + backend_accuracy_fails, + BUCK_CMD_PREFIX, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + InputWriter, + minifier_dir, + NNModuleToString, + NopInputReader, + run_fwd_maybe_bwd, + same_two_models, +) +from torch.fx.experimental.symbolic_shapes import fx_placeholder_targets +from torch.hub import tqdm + +from .. import config +from ..backends.registry import lookup_backend, register_debug_backend +from ..debug_utils import clone_inputs_retaining_gradness + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MAIN ENTRY POINT +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _accuracy_fails(gm, example_inputs, compiler_fn): + return backend_accuracy_fails( + gm, + example_inputs, + compiler_fn, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ) + + +class WrapBackendDebug: + def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__(self, gm, example_inputs, **kwargs): + compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) + assert config.repro_after in ("dynamo", "aot", None) + + if config.repro_after == "dynamo": + + def add_paths(exc): + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + if use_buck: + exc.buck_command = " ".join( + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] + ) + + if config.repro_level == 3: + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) + + # Check for either accuracy (level 4) or other type of failures. + if config.repro_level == 4: + # Check Accuracy + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + if _accuracy_fails(gm, example_inputs, compiler_fn): + log.warning( + "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." + ) + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + exc = AccuracyError("Bad accuracy detected.") + add_paths(exc) + raise exc + else: + try: + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) + except Exception as exc: + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." 
+ ) + if config.repro_level == 1: + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=self._compiler_name + ) + dump_state_fn( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs + ) + elif config.repro_level == 2: + dump_to_minify_after_dynamo( + fx.GraphModule(gm, copy.deepcopy(gm.graph)), + example_inputs, + self._compiler_name, + ) + add_paths(exc) + raise + else: + compiled_gm = compiler_fn(gm, example_inputs) + + return compiled_gm + + +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO DUMPERS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=False, + *, + stable_output=False, + save_dir=None, + command="run", +): + """ + Generate a repro string for backend-agnostic minified version. + """ + + model_str = NNModuleToString.convert(gm) + + # TODO: Figure out why torch.compile'd hash isn't work on this codepath + writer = InputWriter(save_dir, stable_hash=True) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): + if isinstance(arg, (int, torch.SymInt)): + writer.symint(placeholder, arg) + elif isinstance(arg, torch.Tensor): + # TODO: improve these names with FQN + writer.tensor(placeholder, arg) + else: + raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}") + load_args = "\n".join(writer.lines()) + + return textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +from math import inf +import torch +from torch import tensor, device +import torch.fx as fx +import torch._dynamo +from torch._dynamo.testing import rand_strided +from torch._dynamo.debug_utils import run_fwd_maybe_bwd + +{generate_config_string(stable_output=stable_output)} + +{extra_imports} + +{model_str} +mod = Repro() + +{load_args} + +if __name__ == '__main__': + from torch._dynamo.repro.after_dynamo import run_repro + run_repro(mod, load_args, accuracy={check_accuracy!r}, command={command!r}, + save_dir={save_dir!r}, autocast={torch.is_autocast_enabled()!r}, backend={compiler_name!r}) +""" + ) + + +def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): + """ + Saves the repro to a repro.py file + """ + curdir = os.getcwd() + subdir = os.path.join(os.getcwd(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + + with open(file_name, "w") as fd: + fd.write( + generate_dynamo_fx_repro_string( + gm, args, compiler_name, check_accuracy, save_dir=subdir + ) + ) + latest_repro = os.path.join(curdir, "repro.py") + log.warning("Copying %s to %s for convenience", file_name, latest_repro) + + if use_buck: + BuckTargetWriter(latest_repro).write() + + shutil.copyfile(file_name, latest_repro) + + +def dump_backend_state(gm, args, compiler_name, 
check_accuracy=False): + """ + Dumps the dynamo graph to repro the issue. + 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a + repro.py file. + 2) If we can't convert Fx GraphModule to a string, we use to_folder to save + the module and save a tar file. + """ + assert NNModuleToString.can_convert_to_string(gm) + return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy) + # return dump_backend_repro_as_tarfile(gm, args, compiler_name) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER DUMPER +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def dump_to_minify_after_dynamo(gm, args, compiler_name): + # TODO: factor this out + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + helper_for_dump_minify( + generate_dynamo_fx_repro_string( + gm, + args, + compiler_name, + check_accuracy=config.repro_level == 4, + save_dir=subdir, + command="minify", + ) + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# MINIFIER BACKENDS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn +): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) + + # TODO: It's inconsistent to pass SymInt inputs but REAL tensors. + # We should pass ints and look at the GraphModule placeholders + # to resolve them to SymInt (if necessary) + example_inputs = [ + i.node.hint if isinstance(i, torch.SymInt) else i for i in example_inputs + ] + + try: + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) + raise ValueError("No issue was detected") + except Exception as exc: + orig_failure = str(exc) + log.warning( + "Compiled Fx GraphModule failed. Creating script to minify the error." + ) + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + fails_fn = functools.partial( + backend_fails, + compiler_fn=compiler_fn, + orig_failure=orig_failure, + ) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + return gm + + +@register_debug_backend # type: ignore[arg-type] +def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) + + # Set the eval mode to remove randomness. + gm.eval() + + # Check Accuracy + if _accuracy_fails(gm, example_inputs, compiler_fn): + log.warning("Accuracy failed for the TorchDynamo produced graph") + dump_state_fn = functools.partial( + dump_backend_state, compiler_name=compiler_name, check_accuracy=True + ) + fails_fn = functools.partial( + _accuracy_fails, + compiler_fn=compiler_fn, + ) + dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) + minifier( + gm, + example_inputs, + module_fails=fails_fn, + dump_state=dump_state_fn, + ) + else: + log.error("Input graph does not fail accuracy testing") + return gm + + +def backend_fails(gm, example_inputs, compiler_fn, orig_failure): + """ + Minifier uses this function to identify if the minified graph module fails + with the same error. 
+ + One caveat is that minifier can potentially go into a wrong direction when + the resulting graph module fails for a different reason. To avoid this, we + save the string for the original exception and check similarity between new + and old exception. They can be somewhat different in some cases, when the + exception string depends on the failing node information. So, we have a + loose similarity metric to guide the minifier path. + """ + from difflib import SequenceMatcher + + try: + # Run the original gm to check eager validity + run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) + except Exception as e: + new_failure = str(e) + if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: + return True + return False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# REPRO MAIN +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def run_load_args(options, mod, load_args): + if not hasattr(load_args, "_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=options.save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return args + + +def repro_minify(options, mod, load_args): + args = run_load_args(options, mod, load_args) + + # Setup debug minifier compiler + if not options.accuracy: + compiler_fn = lookup_backend("dynamo_minifier_backend") + else: + compiler_fn = lookup_backend("dynamo_accuracy_minifier_backend") + + if options.backend is None: + raise RuntimeError( + "Compiler name is None - this likely means that a custom compiler " + "was called by torchdynamo. 
Please remove this error, import your " + "custom compiler function, and replace the backend=None " + "line in run_repro to backend=" + ) + + dynamo_minifier_backend = functools.partial( + compiler_fn, + compiler_name=options.backend, + ) + opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod) + + with torch.amp.autocast("cuda", enabled=options.autocast): + opt_mod(*args) + + +def repro_run(options, mod, load_args): + opt_mod = torch._dynamo.optimize(options.backend)(mod) + + if options.accuracy != "": + mod.eval() + opt_mod.eval() + + with torch.amp.autocast("cuda", enabled=options.autocast): + # TODO: disable clone + args = run_load_args(options, mod, load_args) + assert same_two_models(mod, mod, args), "Eager itself failed" + if not same_two_models( + mod, + opt_mod, + args, + only_fwd=config.repro_forward_only, + ignore_non_fp=config.repro_ignore_non_fp, + ): + raise AccuracyError("Dynamo failed") + else: + with torch.amp.autocast("cuda", enabled=options.autocast): + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) + del args + + args = run_load_args(options, mod, load_args) + run_fwd_maybe_bwd( + opt_mod, args, only_fwd=options.only_fwd, disable_clone=True + ) + + +def run_repro( + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + autocast=False, + backend="inductor", + **kwargs, +): + for k in kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An after_dynamo repro script, typically triggering a bug in Dynamo or +AOTAutograd. When run with no arguments, this script defaults to running +'{command}'. Extra flags may be available; to find out more, try '{command} +--help'. There are also alternate subcommands available, see below. 
+ +default settings on this script: + {accuracy=} + {save_dir=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="test accuracy", + ) + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + parser.add_argument( + "--no-isolate", + dest="isolate", + action="store_false", + default=False, + help="no isolate (doesn't do anything for after_dynamo)", + ) + parser.add_argument( + "--autocast", + default=autocast, + action="store_true", + help="use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--no-autocast", + dest="autocast", + action="store_false", + help="don't use torch.cuda.amp.autocast", + ) + parser.add_argument( + "--backend", + type=str, + default=backend, + metavar="BACKEND", + help="torch.compile backend to use", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + parser_run.add_argument( + "--only-fwd", + action="store_true", + help="don't run backwards compilation for testing", + ) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + } + COMMAND_FNS[options.command](options, mod, load_args) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/repro/aoti.py b/phivenv/Lib/site-packages/torch/_dynamo/repro/aoti.py new file mode 100644 index 0000000000000000000000000000000000000000..1a23218ee15d770964d2584cd6b44be2d692ce8d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/repro/aoti.py @@ -0,0 +1,637 @@ +# mypy: allow-untyped-defs + +""" +Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. 
+ +This file provides tools and utilities for: +- Generating minimal reproducible test cases (minification) +- Handling exported programs and graph modules +- Creating debug repros for AOTI compilation issues +- Supporting both accuracy testing and error reproduction +- Managing configuration and environment for repro cases + +The main components include: +- Minification tools to reduce test cases while preserving errors +- Repro generation utilities for exported programs +- Error handling specific to AOTI compilation +- Command-line interface for running and managing repros +""" + +import argparse +import functools +import io +import logging +import os +import re +import shutil +import sys +import textwrap +from importlib import import_module +from typing import Any, Optional, Union + +import torch +from torch._dynamo.debug_utils import ( + _cuda_system_info_comment, + BuckTargetWriter, + extra_imports, + generate_config_string, + generate_env_vars_string, + helper_for_dump_minify, + InputReader, + minifier_dir, + NNModuleToString, + NopInputReader, +) +from torch.export import ExportedProgram +from torch.hub import tqdm + + +log = logging.getLogger(__name__) + + +inductor_config = import_module("torch._inductor.config") +use_buck = inductor_config.is_fbcode() + + +class AOTIMinifierError(Exception): + def __init__(self, original_exception): + additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" + full_message = f"{additional_message}: {str(original_exception)}" + super().__init__(full_message) + self.original_exception = original_exception + + +def dump_to_minify( + exported_program: ExportedProgram, + compiler_name: str, + command: str = "minify", + options: Optional[dict[str, Any]] = None, +): + """ + If command is "minify": + Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. + If command is "run": + Dump exported_program to `cwd/repro.py`, with run command. 
+ """ + assert command in ["minify", "run"] + + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + if command == "minify": + out = io.StringIO() + save_graph_repro_ep( + out, + compiler_name, + exported_program=exported_program, + save_dir=subdir, + command="minify", + config_patches=options, + ) + return helper_for_dump_minify(out.getvalue()) + else: + curdir = os.getcwd() + file_name = os.path.join(curdir, "repro.py") + try: + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + exported_program=exported_program, + config_patches=options, + save_dir=subdir, + command="run", + module_in_comment=True, + ) + log.warning("Writing repro file to %s", file_name) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", file_name) + + +def get_module_string(gm): + def _convert_to_comment(s_): + s = s_.split("\n") + if len(s) == 1: + return "# " + s_ + first = s.pop(0) + for i in range(len(s)): + line = s[i] + if line.strip() != "": + s[i] = "# " + line + else: + s[i] = "" + s = "\n".join(s) + s = first + "\n" + s + return s + + module_string = NNModuleToString.convert(gm) + return _convert_to_comment(module_string) + + +def save_graph_repro_ep( + fd, + compiler_name, + *, + exported_program: Optional[ExportedProgram] = None, + gm: Optional[torch.nn.Module] = None, + args: Optional[tuple[Any]] = None, + config_patches: Optional[dict[str, str]] = None, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + check_str=None, + module_in_comment=False, + strict=False, +): + # Save graph for reproducing the error. + # Either exported_program or gm will be saved, depending on which one is defined. + # Only one of exported_program and gm should be defined. 
+ + if exported_program is None and gm is None: + raise AOTIMinifierError("One of exported_program and gm must be defined") + if exported_program is not None and gm is not None: + raise AOTIMinifierError("Only one of exported_program and gm can be defined") + if gm is not None and args is None: + raise AOTIMinifierError("If gm is defined, args should also be defined") + + if exported_program is None: + assert gm is not None + assert args is not None + exported_program = torch.export.export(gm, args, strict=strict) + elif gm is None: + gm = exported_program.module() + + # save a graph preview using gm + module_string = get_module_string(gm) + fd.write(module_string) + + # save a graph repro using exported_program + fd.write( + generate_compiler_repro_exported_program( + exported_program, + options=config_patches, + stable_output=stable_output, + save_dir=save_dir, + ) + ) + if accuracy is None: + accuracy = "_accuracy" in compiler_name + fd.write("if __name__ == '__main__':\n") + fd.write(" from torch._dynamo.repro.aoti import run_repro\n") + fd.write( + f" with torch.no_grad():\n" + f" run_repro(exported_program, config_patches=config_patches, accuracy={accuracy!r}, command={command!r}, " + f"save_dir={save_dir!r}, check_str={check_str!r})\n" + ) + + +def dump_compiler_graph_state( + gm, + args, + compiler_name, + *, + config_patches=None, + accuracy=None, + strict=False, +): + subdir = os.path.join(minifier_dir(), "checkpoints") + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + log.warning( + "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name + ) + with open(file_name, "w") as fd: + save_graph_repro_ep( + fd, + compiler_name, + gm=gm, + args=tuple(args), + config_patches=config_patches, + save_dir=subdir, + accuracy=accuracy, + module_in_comment=True, + strict=strict, + ) + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") + try: + shutil.copyfile(file_name, repro_path) + log.warning("Copying repro file for convenience to %s", repro_path) + if use_buck: + BuckTargetWriter(file_name).write() + except OSError: + log.warning("No write permissions for %s", repro_path) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# DUMP REPROS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def generate_compiler_repro_exported_program( + exported_program, + *, + options: Optional[dict[str, str]] = None, + stable_output=False, + save_dir=None, +): + model_str = textwrap.dedent( + f""" +{generate_env_vars_string(stable_output=stable_output)} +import torch +import torch._inductor.inductor_prims + +{generate_config_string(stable_output=stable_output)} + +isolate_fails_code_str = None + +{extra_imports} + + """ + ) + if not stable_output: + model_str += f"# torch version: {torch.version.__version__}\n" + if hasattr(torch.version, "cuda"): + model_str += f"# torch cuda version: {torch.version.cuda}\n" + if hasattr(torch.version, "git_version"): + model_str += f"# torch git version: {torch.version.git_version}\n\n\n" + model_str += _cuda_system_info_comment() + + ep_path = os.path.join(save_dir, "exported_program.pt2") + torch.export.save(exported_program, ep_path) + + model_str += f"exported_program = torch.export.load('{ep_path}')\n" + model_str += "# print(exported_program.graph)\n" + model_str += f"config_patches={options}\n" + return model_str + + +def repro_load_args(load_args, save_dir): + if not hasattr(load_args, 
"_version"): + log.warning( + "load_args does not have a _version attribute, please file a bug to PyTorch " + "and describe how you generate this repro script" + ) + else: + if load_args._version > 0: + log.warning( + "load_args is version %s, but this version of PyTorch only supports " + "version 0. We will try to run it anyway but there may be an incompatibility; " + "if so, try upgrading your version of PyTorch.", + load_args._version, + ) + + nop_reader = NopInputReader() + load_args(nop_reader) + + with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar: + input_reader = InputReader(save_dir=save_dir, pbar=pbar) + load_args(input_reader) + args = input_reader.args + + return tuple(args) + + +def repro_common(options, exported_program): + torch._inductor.config.generate_intermediate_hooks = True + mod = exported_program.module() + args, kwargs = exported_program.example_inputs + return mod, args, kwargs + + +def repro_get_args(options, exported_program, config_patches): + mod, args, kwargs = repro_common(options, exported_program) + return mod, args, kwargs + + +def repro_run(options, exported_program, config_patches): + from torch._inductor import _aoti_compile_and_package_inner + + gm, args, kwargs = repro_common(options, exported_program) + + from torch.cuda import synchronize + + _aoti_compile_and_package_inner( + gm, + args, + kwargs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=config_patches, + ) + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + if need_sync: + synchronize() # ensure segfaults are surfaced + + +def export_for_aoti_minifier( + gm, tuple_inputs, strict=False, skip_export_error=True +) -> Optional[torch.nn.Module]: + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + # In these case, we return None. Otherwise, we return the exported graph module. + # This won't affect the minifier result because the minifier is only responsible for catching + # errors in AOTI, not export. + # + # Please add to this list of illegal graphs if you change the implementation here. + # - graph output is not allowed by export + # + # If skip_export_error=True, then the errors in export will not be raised, and the minifier + # will keep exploring and ignore this graph. + from torch._dynamo.exc import UserError, UserErrorType + + try: + ep = torch.export.export(gm, tuple_inputs, strict=strict) + gm = ep.module() + return gm + except Exception as e: + if skip_export_error: + return None + if isinstance(e, UserError) and e.error_type == UserErrorType.INVALID_OUTPUT: + # graph output is not allowed by export when strict=True + return None + if isinstance(e, RuntimeError): + # graph output is not allowed by export when strict=False + pattern = r"Found .* in output, which is not a known type\." 
+ if re.search(pattern, str(e)) is not None: + return None + raise AOTIMinifierError(e) from e + # we should never reach here + return None + + +def repro_minify(options, exported_program, config_patches): + from functorch.compile import minifier + from torch._inductor import _aoti_compile_and_package_inner + from torch._inductor.compile_fx import _aoti_flatten_inputs + + mod, args, kwargs = repro_common(options, exported_program) + + # update serialized_in_spec and serialized_out_spec + flat_example_inputs, inductor_configs = _aoti_flatten_inputs( + mod, args, kwargs, options=config_patches + ) + compiler_name = "aot_inductor" + assert options.minifier_export_mode in ["dynamo", "python"] + strict = options.minifier_export_mode == "dynamo" + skip_export_error = options.skip_export_error + + from torch.cuda import synchronize + + need_sync = False + + for arg in args: + if isinstance(arg, torch.Tensor) and arg.is_cuda: + need_sync = True + break + + def module_fails(gm, flat_example_inputs, check_str=None): + # Need to export first so the in_spec and out_spec are populated + tuple_inputs = tuple(flat_example_inputs) + gm = export_for_aoti_minifier( + gm, tuple_inputs, strict=strict, skip_export_error=skip_export_error + ) + + # Some graphs cannot be used for AOTI/export (illegal graphs), these should be + # considered as graphs that don't fail in the minifier, so the minifier keeps searching. + if gm is None: + return False + + assert isinstance(gm, torch.fx.GraphModule) + + try: + _aoti_compile_and_package_inner( + gm, + tuple_inputs, + load_and_run=True, + check_accuracy=options.accuracy, + inductor_configs=inductor_configs, + ) + if need_sync: + synchronize() # ensure segfaults are surfaced + return False + except Exception as e: + if check_str is not None and check_str not in repr(e): + return False + return True + + minifier( + mod, + flat_example_inputs, + module_fails=functools.partial(module_fails, check_str=options.check_str), + dump_state=functools.partial( + dump_compiler_graph_state, + compiler_name=compiler_name, + config_patches=config_patches, + accuracy=options.accuracy, + strict=strict, + ), + save_dir=options.save_dir, + offload_to_disk=options.offload_to_disk, + skip_offload=options.skip_saving_eager_intermediates, + skip_sanity=options.skip_sanity, + max_granularity=options.max_granularity, + ) + + +def run_repro( + exported_program, + *, + config_patches: Optional[dict[str, str]] = None, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + check_str=None, + minifier_export_mode="python", + skip_export_error=True, + **more_kwargs, +): + for k in more_kwargs: + log.warning( + "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", + k, + ) + + if accuracy is True: + accuracy = "accuracy" + elif accuracy is False: + accuracy = "" + + parser = argparse.ArgumentParser( + description=f"""\ +An AOTI repro script, typically triggering a bug in PyTorch AOTInductor. +When run with no arguments, this script defaults to running '{command}'. +Extra flags may be available; to find out more, try '{command} --help'. +There are also alternate subcommands available, see below. 
+ +default settings on this script: + {accuracy=} + {tracing_mode=} + {save_dir=} + {check_str=} +""", + formatter_class=argparse.RawTextHelpFormatter, + ) + + def common_flags(parser): + accuracy_group = parser.add_mutually_exclusive_group() + accuracy_group.add_argument( + "--no-accuracy", + dest="accuracy", + action="store_const", + const="", + default=accuracy, + help="do not test accuracy, just run the module and see if it errors", + ) + accuracy_group.add_argument( + "--accuracy", + action="store_const", + const="accuracy", + default=accuracy, + help="""\ +test if the RMSE between the compiled module and the fp64 reference is greater +than eager and the fp64 reference. This is usually more reliable than the +standard allclose test, as we expect numeric differences from compiling, often +improving accuracy over eager. RMSE test allows for compiled module to +diverge greatly from eager, as long as this divergence moves it closer to the +'true' mathematical value of the network. Caveats: (1) double precision can +still suffer from rounding error, so it is not a perfect reference (see for +example 'Herbie: Automatically Improving Floating Point Accuracy') for +approaches that detect the necessary working precision and compute it in +arbitrary precision floating point; unfortunately, this is not practical for +tensor computation; (2) if there are not enough samples in the output being +compared, we may get unlucky and have an unlucky greater RMSE than eager; this +could be overcome by applying a more rigorous statistical test at some +p-value, which we leave for future work. +""", + ) + accuracy_group.add_argument( + "--strict-accuracy", + dest="accuracy", + action="store_const", + const="strict_accuracy", + default=accuracy, + help="""\ +by default, when doing accuracy minification we will reject reductions which +change the divergence from a floating point divergence to a integral/boolean +divergence. This is because some operations like ReLU involve temporarily +sharp boundaries that smooth out again afterwards; without requiring +divergence on floating point, the minifier will often fixate on divergent +boolean tensor even though this is not the true source of the divergence. +However, rejecting these reductions makes it more difficult for the minifier +to make process. Using this option will let the minifier progress for ALL +divergences--you just might not end up with a useful repro in the end.""", + ) + + parser.add_argument( + "--save-dir", + type=str, + default=save_dir, + metavar="DIR", + help="directory where saved inputs live", + ) + parser.add_argument( + "--no-save-dir", + dest="save_dir", + action="store_const", + const=None, + help="don't use any directory for saved inputs", + ) + + subparsers = parser.add_subparsers( + dest="command", metavar="{run,minify}", required=True + ) + + parser_run = subparsers.add_parser( + "run", + help="just run the repro", + ) + common_flags(parser_run) + + parser_minify = subparsers.add_parser( + "minify", help="run the minifier on the repro" + ) + common_flags(parser_minify) + parser_get_args = subparsers.add_parser("get_args", help="get the args") + common_flags(parser_get_args) + parser_minify.add_argument( + "--skip-saving-eager-intermediates", + action="store_true", + help="skip saving eager intermediates on --minify", + ) + parser_minify.add_argument( + "--offload-to-disk", + action="store_true", + help="during minification, offload delta debugging intermediates to disk. 
Use if you're OOMing", + ) + parser_minify.add_argument( + "--skip-sanity", + action="store_true", + help="skip sanity check at beginning of minification on original graph", + ) + parser_minify.add_argument( + "--max-granularity", + type=int, + default=None, + help="start at this granularity and work down; must be power of 2", + ) + parser_minify.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + parser_minify.add_argument( + "--minifier-export-mode", + type=str, + default=minifier_export_mode, + help=( + "The export mode used in minifier, either dynamo or python." + "`dynamo` corresponds to strict=True, and `python` corresponds to strict=False." + ), + ) + parser_minify.add_argument( + "--skip-export-error", + type=bool, + default=skip_export_error, + help="Skip intermediate graphs that cannot be exported.", + ) + + # Run the repro in the context of minification, inverting exit code meaning + parser_minifier_query = subparsers.add_parser( + "minifier-query", + ) + common_flags(parser_minifier_query) + parser_minifier_query.add_argument( + "--check-str", + type=str, + default=check_str, + help="require minified program to fail with error containing this string", + ) + + args = None + if len(sys.argv) <= 1: + args = [command, *sys.argv[1:]] + + options = parser.parse_args(args) + COMMAND_FNS = { + "minify": repro_minify, + "run": repro_run, + "get_args": repro_get_args, + } + return COMMAND_FNS[options.command]( + options, exported_program, config_patches=config_patches + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/resume_execution.py b/phivenv/Lib/site-packages/torch/_dynamo/resume_execution.py new file mode 100644 index 0000000000000000000000000000000000000000..fe179dff6fe21cc00a453e04d6ea7e72cac8af29 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/resume_execution.py @@ -0,0 +1,568 @@ +# mypy: allow-untyped-defs + +""" +This module provides functionality for resuming Python execution at specific points in code, +primarily used by PyTorch Dynamo for control flow handling and optimization. It implements +bytecode transformation and execution state management to enable: + +- Resuming execution at arbitrary points in Python bytecode +- Managing context managers and their state across execution boundaries +- Transforming and generating new code objects with preserved execution state +- Supporting Python 3.11+ exception handling and block management +- Restoring torch function mode stacks and other execution context + +The module is critical for PyTorch Dynamo's ability to optimize code while preserving +Python semantics and execution state. 
+""" + +import copy +import dataclasses +import sys +import types +from typing import Any, cast, Optional + +from .bytecode_transformation import ( + bytecode_from_template, + create_call_function, + create_instruction, + create_jump_absolute, + create_load_const, + Instruction, + overwrite_instruction, + transform_code_object, + unique_id, +) +from .utils import ExactWeakKeyDictionary + + +# taken from code.h in cpython +CO_OPTIMIZED = 0x0001 +CO_NEWLOCALS = 0x0002 +CO_VARARGS = 0x0004 +CO_VARKEYWORDS = 0x0008 +CO_NESTED = 0x0010 +CO_GENERATOR = 0x0020 +CO_NOFREE = 0x0040 +CO_COROUTINE = 0x0080 +CO_ITERABLE_COROUTINE = 0x0100 +CO_ASYNC_GENERATOR = 0x0200 + +# trace_rules.py import this constant for consistency +TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in" + + +def _initial_push_null(insts): + if sys.version_info >= (3, 11): + insts.append(create_instruction("PUSH_NULL")) + if sys.version_info < (3, 13): + insts.append(create_instruction("SWAP", arg=2)) + + +# Generates bytecode from template and splits the code where LOAD_FAST dummy is present. +def _bytecode_from_template_with_split(template, stack_index, varname_map=None): + template_code = bytecode_from_template(template, varname_map=varname_map) + template_code.append(create_instruction("POP_TOP")) + + # adjust exception table entry depth + for inst in template_code: + if inst.exn_tab_entry: + inst.exn_tab_entry.depth += stack_index + + # search for LOAD_FAST dummy and replace it with 2 NOPs (we can break up the bytecode between them) + dummy_idx, dummy_inst = next( + ( + (i, inst) + for i, inst in enumerate(template_code) + if inst.opname == "LOAD_FAST" and inst.argval == "dummy" + ), + (None, None), + ) + assert dummy_idx is not None + + # replace LOAD_FAST dummy with first NOP marking exception area + overwrite_instruction(dummy_inst, [create_instruction("NOP")]) + + # POP_TOP follows LOAD_FAST dummy - replace with NOP marking end of exception area + assert template_code[dummy_idx + 1].opname == "POP_TOP" + overwrite_instruction(template_code[dummy_idx + 1], [create_instruction("NOP")]) + + return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :] + + +def _try_except_tf_mode_template(dummy, stack_var_name): + # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source + # on torch._dynamo.utils. 
+ global __import_torch_dot__dynamo_dot_utils + try: + dummy + except: # noqa: E722, B001 + __import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack( # type: ignore[name-defined] + stack_var_name + ) + raise + + +@dataclasses.dataclass(frozen=True) +class ReenterWith: + stack_index: int + target_values: Optional[tuple[Any, ...]] = None + + def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous tf mode stack) + raise + """ + from .variables.torch_function import get_prev_stack_var_name + + setup_try_except, epilogue = _bytecode_from_template_with_split( + _try_except_tf_mode_template, + self.stack_index, + varname_map={"stack_var_name": get_prev_stack_var_name()}, + ) + cleanup[:] = epilogue + cleanup + + return setup_try_except + + # If we do not want to destroy the stack, we can do the same thing as a + # `SETUP_WITH` block, only that we store the context manager in a local_symbol + def try_finally(self, code_options, cleanup: list[Instruction]): + """ + Codegen based off of: + load args + enter context + try: + (rest) + finally: + exit context + """ + # NOTE: we assume that TOS is a context manager CLASS! + load_args = [] + if self.target_values: + load_args = [create_load_const(val) for val in self.target_values] + ctx_name = unique_id(f"___context_manager_{self.stack_index}") + if ctx_name not in code_options["co_varnames"]: + code_options["co_varnames"] += (ctx_name,) + for name in ["__enter__", "__exit__"]: + if name not in code_options["co_names"]: + code_options["co_names"] += (name,) + + create_ctx: list[Instruction] = [] + _initial_push_null(create_ctx) + create_ctx.extend( + [ + *load_args, + *create_call_function(len(load_args), False), + create_instruction("STORE_FAST", argval=ctx_name), + ] + ) + + def _template(ctx, dummy): + ctx.__enter__() + try: + dummy + finally: + ctx.__exit__(None, None, None) + + setup_try_finally, epilogue = _bytecode_from_template_with_split( + _template, self.stack_index, varname_map={"ctx": ctx_name} + ) + cleanup[:] = epilogue + cleanup + return create_ctx + setup_try_finally + + def __call__(self, code_options, cleanup): + """ + Codegen based off of: + with ctx(args): + (rest) + """ + # NOTE: we assume that TOS is a context manager CLASS! 
+ load_args = [] + if self.target_values: + load_args = [create_load_const(val) for val in self.target_values] + + create_ctx: list[Instruction] = [] + _initial_push_null(create_ctx) + create_ctx.extend( + [ + *load_args, + *create_call_function(len(load_args), False), + ] + ) + + def _template(ctx, dummy): + with ctx: + dummy + + setup_with, epilogue = _bytecode_from_template_with_split( + _template, self.stack_index + ) + cleanup[:] = epilogue + cleanup + + load_fast_ctx_inst = next( + ( + inst + for inst in setup_with + if inst.opname == "LOAD_FAST" and inst.argval == "ctx" + ), + None, + ) + assert load_fast_ctx_inst is not None + # ctx already loaded on stack before the template - no need to LOAD_FAST + overwrite_instruction(load_fast_ctx_inst, [create_instruction("NOP")]) + + # 3.11+ only + push_exc_info_gen = ( + inst for inst in epilogue if inst.opname == "PUSH_EXC_INFO" + ) + push_exc_info_inst = next(push_exc_info_gen, None) + # expect only 1 PUSH_EXC_INFO in epilogue + assert next(push_exc_info_gen, None) is None + + return create_ctx + setup_with, push_exc_info_inst + + +@dataclasses.dataclass +class ResumeFunctionMetadata: + code: types.CodeType + instructions: list[Instruction] = dataclasses.field(default_factory=list) + # Python 3.11+ fields + # NOTE: Python 3.11 removed blocks, but for our purposes, a "block" consists + # of instructions of all exception table entries that have the same target. + + # map from PUSH_EXC_INFO's in the prefix to original block target offset + prefix_block_target_offset_remap: list[int] = dataclasses.field( + default_factory=list + ) + # map from new block target offsets to original block target offsets + block_target_offset_remap: Optional[dict[int, int]] = None + + +def _filter_iter(l1, l2, cond): + """ + Two-pointer conditional filter. + e.g. 
_filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) + returns the instructions with offsets in sorted_offsets + """ + it = iter(l2) + res: list[Instruction] = [] + try: + cur = next(it) + for val in l1: + if cond(val, cur): + res.append(val) + cur = next(it) + except StopIteration: + pass + return res + + +def _load_tuple_and_call(tup): + insts: list[Instruction] = [] + _initial_push_null(insts) + insts.extend(create_load_const(val) for val in tup) + insts.extend(create_call_function(len(tup), False)) + return insts + + +class ContinueExecutionCache: + cache = ExactWeakKeyDictionary() + generated_code_metadata = ExactWeakKeyDictionary() + + @classmethod + def lookup(cls, code, lineno, *key): + if code not in cls.cache: + cls.cache[code] = {} + key = tuple(key) + if key not in cls.cache[code]: + cls.cache[code][key] = cls.generate(code, lineno, *key) + return cls.cache[code][key] + + @classmethod + def generate( + cls, + code, + lineno, + offset: int, + setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+ + nstack: int, + argnames: tuple[str, ...], + argnames_null: tuple[str, ...], + setup_fns: tuple[ReenterWith, ...], + stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...], + argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...], + null_idxes: tuple[int, ...], + ) -> types.CodeType: + assert offset is not None + assert not ( + code.co_flags + & (CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR) + ) + assert code.co_flags & CO_OPTIMIZED + if code in ContinueExecutionCache.generated_code_metadata: + return cls.generate_based_on_original_code_object( + code, + lineno, + offset, + setup_fn_target_offsets, + nstack, + argnames, + argnames_null, + setup_fns, + stack_ctx_vars, + argnames_ctx_vars, + null_idxes, + ) + + is_py311_plus = sys.version_info >= (3, 11) + meta = ResumeFunctionMetadata(code) + + def update(instructions: list[Instruction], code_options: dict[str, Any]): + meta.instructions = copy.deepcopy(instructions) + + args = [f"___stack{i}" for i in range(nstack)] + args.extend(v for v in argnames if v not in args) + freevars = tuple(code_options["co_cellvars"] or []) + tuple( + code_options["co_freevars"] or [] + ) + freevars = tuple(sorted(freevars)) + code_options["co_name"] = ( + f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_{code_options['co_name']}_at_{lineno}" + ) + if is_py311_plus: + qualified_path = code_options["co_qualname"].rsplit(".", maxsplit=1) + if len(qualified_path) == 1: + code_options["co_qualname"] = code_options["co_name"] + else: + assert len(qualified_path) == 2 + module_name, co_name = qualified_path + code_options["co_qualname"] = ( + f"{module_name}.{TORCH_DYNAMO_RESUME_IN_PREFIX}_{co_name}_at_{lineno}" + ) + code_options["co_firstlineno"] = lineno + code_options["co_cellvars"] = () + code_options["co_freevars"] = freevars + code_options["co_argcount"] = len(args) + code_options["co_posonlyargcount"] = 0 + code_options["co_kwonlyargcount"] = 0 + code_options["co_varnames"] = tuple( + args + + [v for v in argnames_null if v not in args] + + [ + v + for v in code_options["co_varnames"] + if v not in args and v not in freevars + ] + ) + code_options["co_flags"] = code_options["co_flags"] & ~( + CO_VARARGS | CO_VARKEYWORDS + ) + target = next(i for i in instructions if i.offset == offset) + + prefix = [] + if is_py311_plus: + if freevars: + prefix.append( + create_instruction("COPY_FREE_VARS", arg=len(freevars)) + ) + prefix.append(create_instruction("RESUME", arg=0)) + + cleanup: list[Instruction] = [] + hooks = 
{fn.stack_index: fn for fn in setup_fns} + hook_target_offsets = { + fn.stack_index: setup_fn_target_offsets[i] + for i, fn in enumerate(setup_fns) + } + offset_to_inst = {inst.offset: inst for inst in instructions} + # map old hook targets to new targets generated by the hook + old_hook_target_remap = {} + null_idxes_i = 0 + stack_ctx_vars_d = dict(stack_ctx_vars) # type: ignore[var-annotated,arg-type] + for i in range(nstack): + while ( + null_idxes_i < len(null_idxes) + and null_idxes[null_idxes_i] == i + null_idxes_i + ): + prefix.append(create_instruction("PUSH_NULL")) + null_idxes_i += 1 + prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}")) + if i in hooks: + hook = hooks.pop(i) + hook_insts, exn_target = hook(code_options, cleanup) + prefix.extend(hook_insts) + if is_py311_plus: + hook_target_offset = hook_target_offsets.pop(i) + old_hook_target = offset_to_inst[hook_target_offset] + meta.prefix_block_target_offset_remap.append(hook_target_offset) + old_hook_target_remap[old_hook_target] = exn_target + if i in stack_ctx_vars_d: + # NOTE: we assume that current stack var is a context manager CLASS! + # Load args for context variable and construct it + prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[i])) + + if is_py311_plus: + # reverse the mapping since targets of later/nested contexts are inserted + # into the mapping later, but show up earlier in the prefix. + meta.prefix_block_target_offset_remap = list( + reversed(meta.prefix_block_target_offset_remap) + ) + + assert not hooks + + # NOTE: we assume that local var is a context manager CLASS! + # initialize inactive context vars in argnames + for name, vals in argnames_ctx_vars: + prefix.append(create_instruction("LOAD_FAST", argval=name)) + prefix.extend(_load_tuple_and_call(vals)) + prefix.append(create_instruction("STORE_FAST", argval=name)) + + # 3.12+: store NULL into variables that were NULL + if argnames_null: + assert sys.version_info >= (3, 12) + for v in argnames_null: + assert v not in args + prefix.extend( + [ + create_instruction("PUSH_NULL"), + create_instruction("STORE_FAST", argval=v), + ] + ) + + prefix.append(create_jump_absolute(target)) + + # because the line number table monotonically increases from co_firstlineno + # remove starts_line for any instructions before the graph break instruction + # this will ensure the instructions after the break have the correct line numbers + for inst in instructions: + if inst.offset == target.offset: + break + inst.starts_line = None + if sys.version_info >= (3, 11): + inst.positions = None + + if cleanup: + prefix.extend(cleanup) + prefix.extend(cls.unreachable_codes(code_options)) + + # remap original instructions' exception table entries + if old_hook_target_remap: + assert is_py311_plus + for inst in instructions: + if ( + inst.exn_tab_entry + and inst.exn_tab_entry.target in old_hook_target_remap + ): + inst.exn_tab_entry.target = old_hook_target_remap[ + inst.exn_tab_entry.target + ] + + # TODO(jansel): add dead code elimination here + instructions[:] = prefix + instructions + + new_code = transform_code_object(code, update) + ContinueExecutionCache.generated_code_metadata[new_code] = meta + return new_code + + @staticmethod + def unreachable_codes(code_options) -> list[Instruction]: + """Codegen a `raise None` to make analysis work for unreachable code""" + return [ + create_load_const(None), + create_instruction("RAISE_VARARGS", arg=1), + ] + + @classmethod + def generate_based_on_original_code_object( + cls, code, lineno, offset: int, 
setup_fn_target_offsets: tuple[int, ...], *args + ): + """ + This handles the case of generating a resume into code generated + to resume something else. We want to always generate starting + from the original code object so that if control flow paths + converge we only generated 1 resume function (rather than 2^n + resume functions). + """ + + meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[ + code + ] + new_offset = None + + def find_new_offset( + instructions: list[Instruction], code_options: dict[str, Any] + ): + nonlocal new_offset + (target,) = (i for i in instructions if i.offset == offset) + # match the functions starting at the last instruction as we have added a prefix + (new_target,) = ( + i2 + for i1, i2 in zip(reversed(instructions), reversed(meta.instructions)) + if i1 is target + ) + assert target.opcode == new_target.opcode + new_offset = new_target.offset + + transform_code_object(code, find_new_offset) + + if sys.version_info >= (3, 11): + # setup_fn_target_offsets currently contains the target offset of + # each setup_fn, based on `code`. When we codegen the resume function + # based on the original code object, `meta.code`, the offsets in + # setup_fn_target_offsets must be based on `meta.code` instead. + if not meta.block_target_offset_remap: + block_target_offset_remap = meta.block_target_offset_remap = {} + + def remap_block_offsets( + instructions: list[Instruction], code_options: dict[str, Any] + ): + # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, + # so we can tell which block a prefix PUSH_EXC_INFO belongs to, + # by counting. Then we can use meta.prefix_block-target_offset_remap + # to determine where in the original code the PUSH_EXC_INFO offset + # replaced. + prefix_blocks: list[Instruction] = [] + for inst in instructions: + if len(prefix_blocks) == len( + meta.prefix_block_target_offset_remap + ): + break + if inst.opname == "PUSH_EXC_INFO": + prefix_blocks.append(inst) + + # offsets into prefix + for inst, o in zip( + prefix_blocks, meta.prefix_block_target_offset_remap + ): + block_target_offset_remap[cast(int, inst.offset)] = o + + # old bytecode targets are after the prefix PUSH_EXC_INFO's + old_start_offset = ( + cast(int, prefix_blocks[-1].offset) if prefix_blocks else -1 + ) + # offsets into old bytecode + old_inst_offsets = sorted( + n for n in setup_fn_target_offsets if n > old_start_offset + ) + targets = _filter_iter( + instructions, old_inst_offsets, lambda inst, o: inst.offset == o + ) + new_targets = _filter_iter( + zip(reversed(instructions), reversed(meta.instructions)), + targets, + lambda v1, v2: v1[0] is v2, + ) + for new, old in zip(new_targets, targets): + block_target_offset_remap[old.offset] = new[1].offset + + transform_code_object(code, remap_block_offsets) + + # if offset is not in setup_fn_target_offsets, it is an error + setup_fn_target_offsets = tuple( + meta.block_target_offset_remap[n] for n in setup_fn_target_offsets + ) + return ContinueExecutionCache.lookup( + meta.code, lineno, new_offset, setup_fn_target_offsets, *args + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/side_effects.py b/phivenv/Lib/site-packages/torch/_dynamo/side_effects.py new file mode 100644 index 0000000000000000000000000000000000000000..2b406c2538f93baca24386fbe0d6d5280a53006b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/side_effects.py @@ -0,0 +1,1139 @@ +# mypy: allow-untyped-defs + +""" +Side effect tracking and management for TorchDynamo's compilation system. 
+ +This module provides infrastructure for tracking and managing side effects that occur +during symbolic execution, including: + +- Tracking mutations to objects, attributes, and variables +- Managing context changes (cell variables, global namespace modifications) +- Handling aliasing and object identity preservation +- Managing stack frame state and local variable changes +- Tracking function calls with side effects + +Key classes: +- SideEffects: Main container for tracking all side effects during execution +- MutableSideEffects: Specialization for mutable object tracking +- AttributeMutation/ValueMutation: Track specific types of mutations +- Various specialized side effect classes for different scenarios + +The side effect system ensures that mutations performed during symbolic execution +are properly replayed during runtime, maintaining the correctness of compiled code +while enabling optimizations where safe. +""" + +import collections +import contextlib +import inspect +import warnings +import weakref +from collections.abc import MutableMapping +from types import CellType +from typing import Any, Optional, TYPE_CHECKING + +import torch.nn + +from . import graph_break_hints, utils, variables +from .bytecode_transformation import ( + bytecode_from_template, + create_call_function, + create_call_method, + create_instruction, +) +from .codegen import PyCodegen +from .exc import SideEffectsError, unimplemented_v2 +from .source import GlobalSource, LocalCellSource, LocalSource, Source +from .utils import is_frozen_dataclass, nn_module_new, object_new +from .variables.base import ( + AttributeMutation, + AttributeMutationExisting, + AttributeMutationNew, + is_side_effect_safe, + ValueMutationExisting, + ValueMutationNew, + VariableTracker, +) +from .variables.user_defined import FrozenDataClassVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +def _manual_dict_setitem(dict_from, dict_to, mro_index): + # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have + # to be careful because we don't want to trigger the user defined object + # setitem or clear. The mro_index is used to find the dict/OrderedDict from + # the class mro. + dict_class = type(dict_to).__mro__[mro_index] + dict_class.clear(dict_to) + for k, v in dict_from.items(): + dict_class.__setitem__(dict_to, k, v) + + +def _manual_list_update(list_from, list_to): + list.clear(list_to) + list.extend(list_to, list_from) + + +class SideEffects: + """ + Maintain records of mutations and provide methods to apply them during code generation. + + Handles tracking and applying side effects during PyTorch Dynamo compilation, + maintaining Python semantics by managing mutations, attribute modifications, + and other side effects that occur during program execution. + + Key responsibilities: + - Tracks mutations to Python objects, lists, and dictionaries that need to be + applied after an FX graph is run. + - Manages attribute modifications and deletions + - Handles tensor hooks and backward pass state + - Tracks cell variable mutations and global variable changes + - Ensures correct ordering and application of side effects after graph execution + + This ensures that optimized code behaves identically to the original Python code with + respect to object mutations and other side effects. 
+ """ + + id_to_variable: dict[int, VariableTracker] + store_attr_mutations: dict[VariableTracker, dict[str, VariableTracker]] + keepalive: list[Any] + + def __init__( + self, + output_graph, + id_to_variable=None, + store_attr_mutations=None, + keepalive=None, + save_for_backward=None, + tensor_hooks=None, + ): + super().__init__() + self.output_graph_weakref = weakref.ref(output_graph) + self.id_to_variable = id_to_variable or {} + self.store_attr_mutations = store_attr_mutations or {} + self.keepalive = keepalive or [] + self.save_for_backward = save_for_backward or [] + self.tensor_hooks = tensor_hooks or {} + # Used by MappingProxyVariable to graph break in case of any mutated + # dict + self._has_existing_dict_mutation = False + # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. + # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. + self.ca_final_callbacks_var = None + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SideEffects) + # NB: do NOT test keepalive + return ( + self.id_to_variable == other.id_to_variable + and self.store_attr_mutations == other.store_attr_mutations + and self.save_for_backward == other.save_for_backward + and self.tensor_hooks == other.tensor_hooks + ) + + def diff(self, other: "SideEffects") -> Optional[str]: + if self.id_to_variable != other.id_to_variable: + sk_itv = self.id_to_variable.keys() + ok_itv = other.id_to_variable.keys() + if sk_itv != ok_itv: + return f"id_to_variable keys: {sk_itv} != {ok_itv}" + # Feel free to augment this with more fancy diffing logic + # if needed for debugging + return "id_to_variable: unknown diff" + elif self.store_attr_mutations != other.store_attr_mutations: + sk_sam = self.store_attr_mutations.keys() + ok_sam = other.store_attr_mutations.keys() + if sk_sam != ok_sam: + return f"store_attr_mutations keys: {sk_sam} != {ok_sam}" + return "store_attr_mutations: unknown diff" + elif self.save_for_backward != other.save_for_backward: + return "save_for_backward" + elif self.tensor_hooks != other.tensor_hooks: + return "tensor_hooks" + else: + return None + + def clone(self): + """Create a shallow copy""" + return self.__class__( + output_graph=self.output_graph_weakref(), + id_to_variable=dict(self.id_to_variable), + store_attr_mutations={ + k: dict(v) for k, v in self.store_attr_mutations.items() + }, + keepalive=list(self.keepalive), + save_for_backward=self.save_for_backward, + tensor_hooks=self.tensor_hooks, + ) + + def __contains__(self, item): + return id(item) in self.id_to_variable + + def __getitem__(self, item): + return self.id_to_variable[id(item)] + + def should_allow_side_effects_under_checkpoint(self): + output_graph = self.output_graph_weakref() + return ( + output_graph + and output_graph.current_tx.output.current_tracer.under_activation_checkpoint + and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint + ) + + def should_allow_externally_visible_side_effects_in_subtracer(self): + output_graph = self.output_graph_weakref() + return ( + output_graph + and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + ) + + def is_reconstructing_generator(self): + output_graph = self.output_graph_weakref() + + return ( + output_graph + and output_graph.current_tx.output.current_tracer.is_reconstructing_generator + ) + + def check_allowed_side_effect(self, item): + from torch._dynamo.variables.misc import 
AutogradFunctionContextVariable + + # People do things like self.dim = dim inside autograd.Function. + # These are benign. + if isinstance(item, AutogradFunctionContextVariable): + return True + if self.should_allow_externally_visible_side_effects_in_subtracer(): + return True + if self.should_allow_side_effects_under_checkpoint(): + return True + if self.is_reconstructing_generator(): + # This is missing the case where one mutates a tensor. See + # test_generator.py::test_reconstruct_generator_tensor_mutation + raise SideEffectsError( + "Cannot reconstruct a generator with variable mutations. " + "Dynamo needs to fully exhaust the generator, which may cause " + "unintended variable modifications." + ) + if not is_side_effect_safe(item.mutation_type): + # TODO plumb HOP information here + unimplemented_v2( + gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)", + context="", + explanation="This is not supported.", + hints=[], + ) + + def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): + assert self.is_attribute_mutation(item) + self.check_allowed_side_effect(item) + if item not in self.store_attr_mutations: + self.store_attr_mutations[item] = {} + self.store_attr_mutations[item][name] = value + + def load_attr(self, item, name, deleted_ok=False, check=False): + if check: + assert self.is_attribute_mutation(item) + result = self.store_attr_mutations[item][name] + if not deleted_ok and isinstance(result, variables.DeletedVariable): + unimplemented_v2( + gb_type="Attempted to read a deleted variable", + context=f"item: {item}, name: {name}", + explanation="", + hints=[*graph_break_hints.USER_ERROR], + ) + return result + + def store_cell(self, cellvar, value): + if cellvar.is_immutable(): + unimplemented_v2( + gb_type="Write to immutable cell", + context=f"cellvar: {cellvar}, value: {value}", + explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.", + hints=[*graph_break_hints.DIFFICULT], + ) + assert isinstance(cellvar, variables.CellVariable) + assert isinstance(value, variables.VariableTracker) + self.store_attr(cellvar, "cell_contents", value) + + def load_cell(self, cellvar): + assert isinstance(cellvar, variables.CellVariable) + if self.has_pending_mutation_of_attr(cellvar, "cell_contents"): + return self.load_attr(cellvar, "cell_contents", check=False) + if cellvar.pre_existing_contents: + return cellvar.pre_existing_contents + unimplemented_v2( + gb_type="Read uninitialized cell", + context=str(cellvar), + explanation="Attempted to read a cell variable that has not been populated yet.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def load_global(self, gvar: VariableTracker, name: str): + assert isinstance(gvar, variables.VariableTracker) + return self.load_attr(gvar, name) + + def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): + assert isinstance(gvar, variables.VariableTracker) + assert isinstance(value, variables.VariableTracker) + self.store_attr(gvar, name, value) + + @staticmethod + def cls_supports_mutation_side_effects(cls): + return inspect.getattr_static(cls, "__getattribute__", None) in ( + object.__getattribute__, + dict.__getattribute__, + int.__getattribute__, + str.__getattribute__, + list.__getattribute__, + tuple.__getattribute__, + BaseException.__getattribute__, + ) + + def is_attribute_mutation(self, item): + return isinstance(item.mutation_type, AttributeMutation) + + def has_pending_mutation(self, item): + return 
self.is_attribute_mutation(item) and bool( + self.store_attr_mutations.get(item) + ) + + def has_pending_mutation_of_attr(self, item, name): + return self.is_attribute_mutation( + item + ) and name in self.store_attr_mutations.get(item, ()) + + def is_modified(self, item): + if item.is_immutable(): + return False + if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)): + return True + + if isinstance(item, variables.UserDefinedObjectVariable): + # Checks if the underlying dict or tuple vt has been modified + return item in self.store_attr_mutations or item.is_underlying_vt_modified( + self + ) + + if self.is_attribute_mutation(item): + return item in self.store_attr_mutations + + return item.mutation_type.is_modified + + def _track_obj( + self, + item: Any, + variable: VariableTracker, + mutation_type_cls=ValueMutationExisting, + ): + """Start tracking an existing or new variable for mutation""" + if id(item) in self.id_to_variable: + raise AssertionError( + f"{variable} is already tracked for mutation. This could be " + "because you are not using VariableBuilder to construct " + "the variable tracker. " + f"Source of new object: {variable.source}. " + f"Source of previously tracked object: {self.id_to_variable[id(item)].source}." + ) + + variable.mutation_type = mutation_type_cls() + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + track_mutable = _track_obj + + def track_object_existing( + self, + item: Any, + variable: VariableTracker, + ): + return self._track_obj( + item, + variable, + mutation_type_cls=AttributeMutationExisting, + ) + + def track_object_new( + self, + cls_source: Source, + user_cls: Any, + variable_cls: Any, + options, + ): + if user_cls is torch.autograd.function.FunctionCtx: + with warnings.catch_warnings(record=True): + obj = torch.autograd.Function() + else: + obj = object_new(user_cls) + variable = variable_cls( + obj, + mutation_type=AttributeMutationNew(cls_source), + **options, + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def get_variable_cls(self, user_cls): + from torch.overrides import TorchFunctionMode + + from .variables.ctx_manager import GenericContextWrappingVariable + from .variables.torch_function import TorchFunctionModeVariable + from .variables.user_defined import is_forbidden_context_manager + + variable_cls: type[variables.UserDefinedObjectVariable] = ( + variables.UserDefinedObjectVariable + ) + if issubclass( + user_cls, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls): + variable_cls = TorchFunctionModeVariable + elif ( + hasattr(user_cls, "__enter__") + and hasattr(user_cls, "__exit__") + and not is_forbidden_context_manager(user_cls) + ): + variable_cls = GenericContextWrappingVariable + elif issubclass(user_cls, torch.nn.Module): + variable_cls = variables.UnspecializedNNModuleVariable + elif issubclass(user_cls, (dict, collections.OrderedDict)): + variable_cls = variables.UserDefinedDictVariable + elif issubclass(user_cls, tuple): + variable_cls = variables.UserDefinedTupleVariable + elif issubclass(user_cls, list): + variable_cls = variables.UserDefinedListVariable + elif issubclass(user_cls, MutableMapping): + variable_cls = variables.MutableMappingVariable + elif is_frozen_dataclass(user_cls): + variable_cls = FrozenDataClassVariable + elif issubclass(user_cls, BaseException): + variable_cls = variables.UserDefinedExceptionObjectVariable + assert issubclass(variable_cls, 
variables.UserDefinedObjectVariable) + return variable_cls + + def get_example_value( + self, + base_cls_vt, + cls_vt, + init_args, + ): + user_cls = cls_vt.value + if issubclass(user_cls, torch.nn.Module): + # TODO(anijain2305) - Is it possible to remove this specialization? + obj = nn_module_new(user_cls) + else: + if isinstance(base_cls_vt, variables.BuiltinVariable): + base_cls = base_cls_vt.fn + elif isinstance(base_cls_vt, variables.UserDefinedClassVariable): + base_cls = base_cls_vt.value + else: + raise RuntimeError(f"Unexpected base_cls_vt {base_cls_vt}") + + assert variables.UserDefinedClassVariable.is_supported_new_method( + base_cls.__new__ + ) + # TODO(anijain2305) - Consider adding get_example_value method to + # each VT to get an example value for all args. As we expand the + # scope to other __new__ methods, we might need to call __new__ with + # init_args (like functools.partial) + # init_args = [arg.get_example_value() for arg in init_args] + # obj = base_cls.__new__(user_cls, *init_args) + + obj = base_cls.__new__(user_cls) + return obj + + def track_new_user_defined_object( + self, + base_cls_vt, + cls_vt, + init_args, + ): + """ + Creates a UserDefinedObjectVariable (or its subclass) variable tracker + and mark it for attribute mutation tracking. + + Also records the variable trackers to call __new__ method on + reconstruction. Roughly, the reconstruction looks like this + base_cls_vt.__new__(user_cls, *init_args) + """ + cls_source = cls_vt.source + user_cls = cls_vt.value + variable_cls = self.get_variable_cls(user_cls) + obj = self.get_example_value(base_cls_vt, cls_vt, init_args) + + variable = variable_cls( + obj, + cls_source=cls_vt.source, + base_cls_vt=base_cls_vt, + init_args=init_args, + mutation_type=AttributeMutationNew(cls_source), + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def track_cell_new( + self, + ): + obj = object() + variable = variables.CellVariable( + mutation_type=AttributeMutationNew(), + ) + self.id_to_variable[id(obj)] = variable + self.keepalive.append(obj) + return variable + + def track_cell_existing( + self, source: Optional[Source], cell: CellType, contents: VariableTracker + ): + variable = variables.CellVariable( + # We don't support mutation to cell without source because we need + # source to properly codegen the mutations. + mutation_type=None if source is None else AttributeMutationExisting(), + pre_existing_contents=contents, + source=source, + ) + self.id_to_variable[id(cell)] = variable + self.keepalive.append(cell) + return variable + + def track_global_existing(self, source: Source, item: Any): + variable = variables.NewGlobalVariable( + mutation_type=AttributeMutationExisting(), + source=source, + ) + self.id_to_variable[id(item)] = variable + self.keepalive.append(item) + return variable + + def track_save_for_backward(self, ctx, args): + assert isinstance(ctx, variables.AutogradFunctionContextVariable) + self.save_for_backward.append((ctx, args)) + + def track_tensor_variables_from_runahead_side_effects(self, other): + # In higher order ops we want to keep track of tensors seen in the + # speculate_subgraph so that we don't lift them again as a new input in + # other speculate_subgraph or in the root tracer. 
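+        # (Sketch of the effect: a tensor first seen while speculating a
+        # higher-order-op subgraph keeps a single id -> VariableTracker entry
+        # here, so later tracers reuse that tracker instead of lifting the
+        # tensor again as a new graph input.)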
+ for other_item in other.keepalive: + other_id = id(other_item) + other_variable = other.id_to_variable[other_id] + if other_id not in self.id_to_variable and isinstance( + other_variable, variables.TensorVariable + ): + self.track_object_existing(other_item, other_variable) + + def prune_dead_object_new(self, tx): + # Avoid VT cycles from e.g., recursive function. + visited: set[VariableTracker] = set() + live_new_objects: set[VariableTracker] = set() + + def visit(var: VariableTracker): + if var in visited: + return + visited.add(var) + # Object may have been mutated, store this mutation. + if isinstance(var.mutation_type, AttributeMutationNew): + live_new_objects.add(var) + # It's possible that we have mutated the value of this variable + # to be another one. The new value is in store_attr_mutations. + # Also recurse through the new value to detect alive AttributeMutationNew. + if var in self.store_attr_mutations: + VariableTracker.visit( + visit, # noqa: F821 + self.store_attr_mutations[var], + ) + + def is_live(var: VariableTracker): + if isinstance(var.mutation_type, AttributeMutationNew): + return var in live_new_objects + return True + + pre_existing_vars = [ + var + for var in self.id_to_variable.values() + if not isinstance(var.mutation_type, AttributeMutationNew) + ] + + # The only live side effects come from returns (tx.stack), any intermediates + # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. + # Recursively visit Variables and see if any of them have been mutated. + VariableTracker.visit( + visit, + # TODO track from all possible sources. + ( + tx.stack, + tx.symbolic_locals, + pre_existing_vars, + tx.output.backward_state, + self.tensor_hooks, + ), + ) + # Manually release the self-referential function, which indirectly + # captures certain `VariableTracker` and affects parts of PT test/logic + # that are sensitive to when certain objects get released. + del visit + + # NB: cell variable handling.is tricky. + # cell variables must stay alive if any NestedUserFunctionVariable + # are live. "visit"-ing the NestedUserFunctionVariable visits + # the .closures field, from which we will see if we need to keep + # any mutations to cell variables alive. + + self.id_to_variable = { + k: v for k, v in self.id_to_variable.items() if is_live(v) + } + self.store_attr_mutations = { + k: v for k, v in self.store_attr_mutations.items() if is_live(k) + } + + def mutation(self, var): + self.check_allowed_side_effect(var) + if isinstance(var.mutation_type, ValueMutationExisting): + var.mutation_type.is_modified = True + if ( + var.source + and isinstance(var, variables.ConstDictVariable) + and not isinstance(var, variables.SetVariable) + ): + self._has_existing_dict_mutation = True + + def has_existing_dict_mutation(self): + return self._has_existing_dict_mutation + + def _get_modified_vars(self): + return [var for var in self.id_to_variable.values() if self.is_modified(var)] + + def codegen_save_tempvars(self, cg: PyCodegen): + # We must codegen modified VT to their source by default, so that + # mutation and aliasing are properly accounted for. + # + # Since newly constructed objects don't have a source, we manually + # codegen their construction and store them to a newly assigned local + # source. Note that `ValueMutationNew` isn't tracked by SideEffects. 
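+        #
+        # Rough sketch of the prologue emitted for a newly constructed plain
+        # object (the temp name below is illustrative, not the real one):
+        #     tmp_0 = object_new(UserCls)   # or base_cls.__new__(UserCls, *init_args)
+        # after which the variable's source becomes LocalSource("tmp_0"), so the
+        # attribute mutations emitted later target that temporary local.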
+ for var in self._get_modified_vars(): + if not isinstance(var.mutation_type, AttributeMutationNew): + assert var.source is not None + continue + + if isinstance(var, variables.CellVariable): + # Cells created in the root frame are created either by + # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit + # `make_cell` for the non-root-frame cells here. + # TODO generalize this so we never need to call `make_cell`. + if var.local_name is None: + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "make_cell") + ) + cg.extend_output(create_call_function(0, False)) + cg.add_cache(var) + var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + elif var.source is None: + var.source = LocalCellSource(var.local_name) + elif isinstance(var, variables.TensorVariable): + # NOTE: for historical reasons we never assigned local sources + # to newly constructed tensor object, so we keep it that way. + # They are always loaded from output of the fx graph, so one can + # think of it as having a "OutputGraphSource" for codegen + # purposes. + # + # However, tensor subclass objects are different, because the + # reconstruction logic in `PyCodegen` loads the data tensor from + # graph output and then calls `as_subclass`, meaning we must + # assign a source to it to ensure we only reconstruct one + # subclass instance. + if isinstance( + var, variables.torch_function.TensorWithTFOverrideVariable + ): + # Don't codegen from temp source assigned from the 1st pass. + cg(var, allow_cache=False) + cg.add_cache(var) + # `add_cache` generates STORE and consumes TOS, but we never + # cleared it. TODO move this call into `add_cache` + cg.clear_tos() + var.source = LocalSource(cg.tempvars[var]) + elif isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented_v2( + gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", + context="", + explanation="We cannot reconstruct a torch.autograd.Function's context object.", + hints=[], + ) + else: + # Reconstruct the bytecode for + # base_cls.__new__(user_cls, *args) + if isinstance(var, variables.UserDefinedObjectVariable): + + def load_new_method(): + assert var.base_cls_vt is not None + cg(var.base_cls_vt) # type: ignore[attr-defined] + cg.extend_output([cg.create_load_attr("__new__")]) + + cg.add_push_null(load_new_method) + else: + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "object_new") + ) + cg(var.mutation_type.cls_source) + + # Generate the args to the __new__ method + for arg in var.init_args: + cg(arg) + + # Call the __new__ method + cg.extend_output(create_call_function(1 + len(var.init_args), False)) + + cg.add_cache(var) + var.source = LocalSource(cg.tempvars[var]) + + for ctx, args in self.save_for_backward: + cg(ctx.source) + cg.load_method("save_for_backward") + for arg in args: + cg(arg) + cg.extend_output( + [ + *create_call_method(len(args)), + create_instruction("POP_TOP"), + ] + ) + + def register_hook(self, tensor, hook, handle, name): + assert isinstance(tensor, variables.TensorVariable) + assert isinstance(hook, variables.VariableTracker) + assert ( + isinstance(handle, variables.RemovableHandleVariable) + and handle.is_mutable() + ) + assert hasattr(torch.Tensor, name) + idx = len(self.tensor_hooks.keys()) + # duplicate index possible because of self.remove_hook() + while idx in self.tensor_hooks: + idx += 1 + self.tensor_hooks[idx] = (tensor, hook, handle, name) + assert not handle.idx + handle.idx = idx + + def remove_hook(self, idx): + del 
self.tensor_hooks[idx] + + def codegen_hooks(self, cg): + for ( + tensor, + hook, + handle, + name, + ) in self.tensor_hooks.values(): + # Note: [On tensor.register_hook] + # + # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented + # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries). + # + # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph. + # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in + # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able + # tensors. Because a source indicates knowledge of this object outside the torch compile region, and + # because we are running residuals firmly before .backward() can be run, it is sound to invoke + # `register_hook` on a known tensor. + # + # For tensors without a source, we support a limited subset of hooks. Global functions only, and + # compiled_autograd must be enabled or we will graph break. + # + # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the + # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed + # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the + # stack intact. + # + # Dynamo Tensor Hooks Workflow: + # - Functions passed to register_hook are lifted globally. + # - For tensors with sources: + # - In the "side_effects" phase of codegen, we iterate over tensors with hooks to: + # - Generate the tensor. + # - Issue a register_hook call on the tensor, linking to the globally stored function. + # - Incorporate a handle if one was established in the eager phase. + # - For tensors without sources: + # - We don't generate any instructions for registering a hook. + # - Handles from intermediary hooks are NYI. + # - We produce a call function that utilizes the trace_wrapped higher order op, closing over it. + # - We then manually insert the call function above into the graph. + # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. + assert tensor.source, "Hooks on non input tensors NYI - should not get here" + + def gen_fn(): + cg(tensor) + cg.extend_output([cg.create_load_attr(name)]) + + cg.add_push_null(gen_fn) + cg(hook) + cg.extend_output(create_call_function(1, False)) + + # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will + # be associated with the return value of register_hook(). This consumes the top of stack. 
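+            #
+            # Roughly, the residual built here is equivalent to
+            #     handle = tensor.register_hook(hook)   # `name` is e.g. "register_hook"
+            # with `hook` loaded from the lifted global; the handle is only kept
+            # if the user bound the original call's result to a local.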
+ cg.add_cache(handle) + + def get_ca_final_callbacks_var(self): + from .variables.base import ValueMutationNew + + if self.ca_final_callbacks_var is None: + self.ca_final_callbacks_var = variables.ListVariable( + [], mutation_type=ValueMutationNew() + ) + return self.ca_final_callbacks_var + + def codegen_update_mutated(self, cg: PyCodegen): + suffixes = [] + for var in self._get_modified_vars(): + if isinstance(var, variables.ListVariable): + # old[:] = new + cg(var, allow_cache=False) # Don't codegen via source + cg(var.source) # type: ignore[attr-defined] + cg.extend_output( + [ + cg.create_load_const(None), + cg.create_load_const(None), + create_instruction("BUILD_SLICE", arg=2), + ] + ) + suffixes.append([create_instruction("STORE_SUBSCR")]) + elif isinstance(var, variables.lists.DequeVariable): + # For limited maxlen, the order of operations matter for side + # effect, but we currently don't track the order, so no support. + if not ( + isinstance(var.maxlen, variables.ConstantVariable) + and var.maxlen.value is None + ): + unimplemented_v2( + gb_type="Side effect on existing deque with limited maxlen", + context="", + explanation="This is not supported.", + hints=[ + "Don't use a deque with `maxlen` specified.", + ], + ) + + # old.extend(new), this runs last + cg(var.source) + cg.load_method("extend") + cg(var, allow_cache=False) # Don't codegen via source + suffixes.append( + [ + *create_call_method(1), + create_instruction("POP_TOP"), + ] + ) + + # old.clear(), this runs first + cg(var.source) + cg.load_method("clear") + suffixes.append( + [ + *create_call_method(0), + create_instruction("POP_TOP"), + ] + ) + + elif isinstance(var, variables.ConstDictVariable): + # Reconstruct works as follow: + # (1) Skip codegen if there are no new items + # (2) codegen(...) each pair of key/value + # (3) create a new dictionary with the pairs of key/values above + # (4) clear the original dictionary + # + only if a key was removed from the input dict + # (5) update the original dictionary with the dict created in (2) + + if var.has_new_items(): + cg(var.source) # type: ignore[attr-defined] + cg.load_method("update") + cg(var, allow_cache=False) # Don't codegen via source + + if var.should_reconstruct_all: + cg(var.source) # type: ignore[attr-defined] + cg.load_method("clear") + + suffixes.append( + [ + *create_call_method(1), # update + create_instruction("POP_TOP"), + ] + ) + + if var.should_reconstruct_all: + # clear will appear before "update" as the suffixes are + # applied in reverse order. + suffixes.append( + [ + *create_call_method(0), # clear + create_instruction("POP_TOP"), + ] + ) + + elif isinstance( + var, variables.torch_function.TorchFunctionModeStackVariable + ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "set_torch_function_mode_stack" + ) + ) + + cg.foreach(var.symbolic_stack) + cg.append_output( + create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) + ) + cg.call_function(1, False) + cg.append_output(create_instruction("POP_TOP")) + + elif isinstance(var, variables.CellVariable) and var.local_name is not None: + # Emit more readable and performant bytecode. 
+ # TODO generalize this for cells created during inlining. + if var in self.store_attr_mutations: + contents_var = self.load_cell(var) + cg(contents_var) + suffixes.append([cg.create_store_deref(var.local_name)]) + + elif self.is_attribute_mutation(var): + if isinstance( + var, variables.UserDefinedDictVariable + ) and self.is_modified(var._dict_vt): + # Do dict related update manually here. The store_attr + # mutations will be applied later. + varname_map = {} + for name in _manual_dict_setitem.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + try: + mro_index = type(var.value).__mro__.index( + collections.OrderedDict + ) + except ValueError: + mro_index = type(var.value).__mro__.index(dict) + + cg.extend_output( + [ + create_instruction("LOAD_CONST", argval=mro_index), + create_instruction( + "STORE_FAST", argval=varname_map["mro_index"] + ), + ] + ) + + cg(var.source) # type: ignore[attr-defined] + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["dict_to"] + ) + ] + ) + + cg(var._dict_vt, allow_cache=False) # Don't codegen via source + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["dict_from"] + ) + ] + ) + + dict_update_insts = bytecode_from_template( + _manual_dict_setitem, varname_map=varname_map + ) + + suffixes.append( + [ + *dict_update_insts, + create_instruction("POP_TOP"), + ] + ) + elif isinstance( + var, variables.UserDefinedListVariable + ) and self.is_modified(var._list_vt): + # Update the list to the updated items. Be careful in + # calling the list methods and not the overridden methods. + varname_map = {} + for name in _manual_list_update.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + cg(var.source) # type: ignore[attr-defined] + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["list_to"] + ) + ] + ) + + cg(var._list_vt, allow_cache=False) # Don't codegen via source + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["list_from"] + ) + ] + ) + + list_update_insts = bytecode_from_template( + _manual_list_update, varname_map=varname_map + ) + + suffixes.append( + [ + *list_update_insts, + create_instruction("POP_TOP"), + ] + ) + + # Applying mutations involves two steps: 1) Push all + # reconstructed objects onto the stack. 2) Call STORE_ATTR to + # apply the mutations. + # + # Dynamo must ensure that mutations are applied in the same + # order as in the original program. Therefore, two reverse + # operations occur below. + # + # The first reverse operation concerns `suffixes`. We apply + # suffixes in reverse order due to the way Python handles the + # stack. In Step 1, we push all reconstructed objects onto the + # stack, but the item at the top of the stack refers to the last + # attribute in the mutation order. If not fixed, this will apply + # the mutations of attributes in the reverse order. To account + # for this reversal, we iterate through the mutable attributes + # in reverse order. 
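+                # Worked example (sketch): for mutations applied as `obj.a = 1`
+                # then `obj.b = 2`, the reversed iteration below pushes b's
+                # value/object pair first and a's last, leaving a's pair on top
+                # of the stack; because `suffixes` is also emitted in reverse at
+                # the very end, STORE_ATTR for `a` runs first and `b` second,
+                # matching the original mutation order.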
+ for name, value in reversed( + self.store_attr_mutations.get(var, {}).items() + ): + if isinstance(var, variables.NewGlobalVariable): + cg.tx.output.update_co_names(name) + cg(value) + assert isinstance(var.source, GlobalSource) # type: ignore[attr-defined] + suffixes.append( + [create_instruction("STORE_GLOBAL", argval=name)] + ) + elif isinstance(value, variables.DeletedVariable): + if isinstance( + var.mutation_type, AttributeMutationExisting + ) and hasattr(getattr(var, "value", None), name): + cg.tx.output.update_co_names(name) + cg(var.source) + suffixes.append( + [create_instruction("DELETE_ATTR", argval=name)] + ) + elif isinstance( + var, variables.UserDefinedObjectVariable + ) and var.should_skip_descriptor_setter(name): + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "object_setattr_ignore_descriptor" + ) + ) + cg(var.source) # type: ignore[attr-defined] + cg(variables.ConstantVariable(name)) + cg(value) + suffixes.append( + [ + *create_call_function(3, False), + create_instruction("POP_TOP"), + ] + ) + elif ( + isinstance(var, variables.UserDefinedObjectVariable) + and var.needs_slow_setattr() + ): + # __setattr__ is defined on this object, so call object.__setattr__ directly + cg.load_import_from("builtins", "object") + cg.load_method("__setattr__") + cg(var.source) # type: ignore[attr-defined] + cg(variables.ConstantVariable(name)) + cg(value) + suffixes.append( + [*create_call_method(3), create_instruction("POP_TOP")] + ) + else: + cg.tx.output.update_co_names(name) + cg(value) + cg(var) + suffixes.append([create_instruction("STORE_ATTR", argval=name)]) + elif isinstance(var, variables.ListIteratorVariable): + for _ in range(var.index): + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "iter_next") + ) + cg(var.source) # type: ignore[attr-defined] + cg.call_function(1, False) + cg.pop_top() + elif isinstance(var, variables.RandomVariable): + # set correct random seed state + def gen_fn(): + cg(var.source) # type: ignore[attr-defined] + cg.load_attr("setstate") + + cg.add_push_null(gen_fn) + cg(var.wrap_state(var.random.getstate())) + + suffixes.append( + [ + *create_call_function(1, False), # setstate + create_instruction("POP_TOP"), + ] + ) + else: + raise AssertionError(type(var)) + + # do all the actual mutations at the very end to handle dependencies + for suffix in reversed(suffixes): + cg.extend_output(suffix) + + def is_empty(self): + return not ( + any(map(self.is_modified, self.id_to_variable.values())) + or self.tensor_hooks + or self.save_for_backward + or self.tensor_hooks + ) + + def clear(self): + self.keepalive.clear() + self.id_to_variable.clear() + + +@contextlib.contextmanager +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): + assert tx.output.current_tracer.under_activation_checkpoint + orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint + try: + tx.output.current_tracer.allow_side_effects_under_checkpoint = True + yield + finally: + tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val + + +@contextlib.contextmanager +def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + try: + tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True + yield + finally: + tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = orig_val + + +@contextlib.contextmanager +def disallow_side_effects_in_generator(tx: 
"InstructionTranslator"): + orig_val = tx.output.current_tracer.is_reconstructing_generator + try: + tx.output.current_tracer.is_reconstructing_generator = True + yield + finally: + tx.output.current_tracer.is_reconstructing_generator = orig_val diff --git a/phivenv/Lib/site-packages/torch/_dynamo/source.py b/phivenv/Lib/site-packages/torch/_dynamo/source.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf19192f86cec504401814fb486bee6db8ed30c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/source.py @@ -0,0 +1,1028 @@ +# mypy: allow-untyped-defs + +""" +This module provides Source classes that track the origins of values in PyTorch Dynamo. +Sources represent where values come from (e.g. local variables, globals, attributes) and +are used for guard generation and code reconstruction during compilation. + +The module includes specialized sources for: +- Local variables and synthetic locals +- Global variables and constants +- Object attributes and method calls +- NN module specialization (specialized vs unspecialized) +- Random values and tensor properties +- Default argument handling +- FSDP (Fully Sharded Data Parallel) modules + +Sources play a key role in Dynamo's guard system by tracking value origins for +guard generation, and in code reconstruction by providing methods to rebuild +the code needed to recreate values. +""" + +import dataclasses +import enum +import functools +from typing import Any, Optional, TYPE_CHECKING, Union + +from torch._guards import ChainedSource, GuardSource, Source + +from . import utils +from .bytecode_transformation import create_call_function, create_instruction + + +if TYPE_CHECKING: + from .codegen import PyCodegen + +# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, +# so those cases are omitted intentionally + +# represents nn.Modules tracked with NNModuleVariable (specialized is implicit in the variable name) +_GUARD_SOURCE_SPECIALIZED_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_SPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, + # Just to ensure that guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + +# represents nn.Modules tracked with UnspecializedNNModuleVariable +_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + # this happens for an UnspecializedNNModule submodule on a NNModuleVariable + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, + # Just to ensure that 
guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + +# represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable +_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + # Just to ensure that guard_source() works + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + +_GUARD_SOURCE_FSDP_MODULE = { + GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, + GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE, +} + + +def is_constant_source(source): + if isinstance(source, ConstantSource): + return True + try: + if source.guard_source() == GuardSource.CONSTANT: + return True + except NotImplementedError: + pass + + return False + + +@dataclasses.dataclass(frozen=True) +class LocalSource(Source): + local_name: str + + # Whether this local is an input to the root frame. + is_input: bool = False + + # Whether we know this input is dynamic (based on example_inputs) + # For non tensors, we simply look at the first index of the tuple + dynamism: Optional[frozenset[str]] = None + + # Whether the item at this source is the _content_ of a cell that is + # dereferenced from the root frame, i.e., it's a part of the `co_cellvars` + # or `co_freevars`. 
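+    # (For example, with
+    #     def make_counter():
+    #         count = 0
+    #         def inc():
+    #             return count + 1
+    # if `inc` is the root frame being traced, `count` comes from a free-variable
+    # cell, so its LocalSource would have this flag set.)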
+ is_derefed_cell_contents: bool = False + + def reconstruct(self, codegen: "PyCodegen"): + if self.is_derefed_cell_contents: + codegen.load_deref(self.local_name) + else: + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.LOCAL + + def name(self): + return f"L[{repr(self.local_name)}]" + + +@dataclasses.dataclass(frozen=True) +class SyntheticLocalSource(Source): + local_name: str + + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.SYNTHETIC_LOCAL + + def name(self): + return f"SYNTHETIC_LOCAL[{self.local_name!r}]" + + +@dataclasses.dataclass(frozen=True) +class RandomValueSource(Source): + random_call_index: int + + def guard_source(self): + return GuardSource.RANDOM_VALUE + + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) + codegen.append_output(codegen.create_load_const(self.random_call_index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def name(self): + return f"random_value_{self.random_call_index}" + + +@dataclasses.dataclass(frozen=True) +class GlobalSource(Source): + global_name: str + + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load_global(self.global_name, add=True)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]" + + +@dataclasses.dataclass(frozen=True) +class GlobalWeakRefSource(Source): + global_name: str + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_global(self.global_name, add=True) + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): + return f"G[{repr(self.global_name)}]()" + + +@dataclasses.dataclass(frozen=True) +class WeakRefCallSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen(self.base)) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}()" + + +@dataclasses.dataclass(frozen=True) +class CallFunctionNoArgsSource(WeakRefCallSource): + pass + + +@dataclasses.dataclass(frozen=True) +class AttrSource(ChainedSource): + member: str + + def __post_init__(self): + assert self.base, "Can't construct an AttrSource without a valid base source" + if "." in self.member: + member_parts = self.member.split(".") + object.__setattr__( + self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) + ) + object.__setattr__(self, "member", member_parts[-1]) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if not self.member.isidentifier(): + return f"getattr({self.base.name()}, {self.member!r})" + return f"{self.base.name()}.{self.member}" + + +@dataclasses.dataclass(frozen=True) +class GenericAttrSource(ChainedSource): + member: str + + def __post_init__(self): + assert self.base, "Can't construct an AttrSource without a valid base source" + if "." 
in self.member: + member_parts = self.member.split(".") + object.__setattr__( + self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) + ) + object.__setattr__(self, "member", member_parts[-1]) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"object.__getattribute__({self.base.name()}, {self.member!r})" + + +@dataclasses.dataclass(frozen=True) +class LocalCellSource(Source): + """ + Conceptually, this class is `LocalSource` for cell objects implicitly + generated by Python (e.g., captured variables). + """ + + local_name: str + + def reconstruct(self, codegen: "PyCodegen"): + # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, + # Dynamo's bytecode transformation differentiates them slightly, so we + # always emit `LOAD_CLOSURE` here. + codegen.append_output(codegen.create_load_closure(self.local_name)) + + # All the other methods are intentionally unimplemented because e.g., a + # local cell object should never be used for guards. + + +# Represents tensor.grad source. It could be represented by AttrSource as well. +# But, we could access grad field on tensor directly in C++ without going +# through the Python bytecodes. Therefore, we use a separate source for grad +# field. +@dataclasses.dataclass(frozen=True) +class GradSource(ChainedSource): + member: str = "grad" + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.{self.member}" + + +@dataclasses.dataclass(frozen=True) +class ParamBufferSource(AttrSource): + def guard_source(self): + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + + +# Special AttrSource to differentiate module._buffers or module._parameters +@dataclasses.dataclass(frozen=True) +class UnspecializedParamBufferSource(AttrSource): + pass + + +# This source is intended to be used in places where a source is needed but it is expected +# that the symbol will be simplified out later on. Symbols with ephemeral sources are +# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral +# source. Guarding on this source is an error. +# +# Example: During subclass view fake-ification, any close-over ViewFunc state should be +# symbolicized / fake-ified to avoid invalid specialization during view replay. This source +# is useful for symbols utilized in the middle of the view chain that are not expected to be +# present within the final view shape metadata. 
+@dataclasses.dataclass(frozen=True) +class EphemeralSource(Source): + desc: Optional[str] = None + + def guard_source(self): + return GuardSource.EPHEMERAL + + def name(self): + return f"" + + def make_guard(self, fn): + raise NotImplementedError + + def is_ephemeral(self): + return True + + +class TensorProperty(enum.Enum): + SIZE = 0 + STRIDE = 1 + STORAGE_OFFSET = 2 + + def method_name(self): + if self is TensorProperty.SIZE: + return "size" + elif self is TensorProperty.STRIDE: + return "stride" + elif self is TensorProperty.STORAGE_OFFSET: + return "storage_offset" + + +@dataclasses.dataclass(frozen=True) +class TensorPropertySource(ChainedSource): + prop: TensorProperty + idx: Optional[int] = None # None for STORAGE_OFFSET + + def __post_init__(self): + assert self.base is not None + if self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + else: + assert self.idx is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + utils.__name__, f"call_{self.prop.method_name()}" + ) + ) + codegen(self.base) + + if self.idx is not None: + codegen.append_output(codegen.create_load_const(self.idx)) + codegen.extend_output( + create_call_function(2 if self.idx is not None else 1, False) + ) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + if self.prop is TensorProperty.SIZE: + return f"{self.base.name()}.size()[{self.idx}]" + elif self.prop is TensorProperty.STRIDE: + return f"{self.base.name()}.stride()[{self.idx}]" + elif self.prop is TensorProperty.STORAGE_OFFSET: + assert self.idx is None + return f"{self.base.name()}.storage_offset()" + else: + raise AssertionError(f"unhandled {self.prop}") + + +@dataclasses.dataclass(frozen=True) +class IndexedSource(ChainedSource): + idx: int + + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"({self.idx}, {self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class NegateSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError + + def guard_source(self): + return self.base.guard_source() + + def name(self): + # NB: use method call so that function stripping regexes work + return f"{self.base.name()}.__neg__()" + + +@dataclasses.dataclass(frozen=True) +class ConvertIntSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"cast_symbool_to_symint_guardless({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class FlattenScriptObjectSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.__obj_flatten__()" + + +@dataclasses.dataclass(frozen=True) +class ScriptObjectQualifiedNameSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}._type().qualified_name()" 
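+
+# Example (sketch): chained sources compose their `name()` strings into the
+# expression used for guard generation and reconstruction. For a local
+# variable `m`:
+#     LocalSource("m").name()                        -> "L['m']"
+#     AttrSource(LocalSource("m"), "weight").name()  -> "L['m'].weight"
+#     TensorPropertySource(
+#         AttrSource(LocalSource("m"), "weight"), TensorProperty.SIZE, 0
+#     ).name()                                       -> "L['m'].weight.size()[0]"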
+ + +class AttrProxySource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.get_base()" + + +@dataclasses.dataclass(frozen=True) +class DefaultsSource(ChainedSource): + idx_key: Union[int, str] + is_kw: bool = False + field: str = dataclasses.field(init=False, repr=False, compare=False) + _name: str = dataclasses.field(init=False, repr=False, compare=False) + + def __post_init__(self): + assert self.base, ( + "Base must be a valid source in order to properly track and guard this Defaults to its origin." + ) + if self.is_kw: + assert isinstance(self.idx_key, str) + object.__setattr__(self, "field", "__kwdefaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" + ) + else: + assert isinstance(self.idx_key, int) + object.__setattr__(self, "field", "__defaults__") + object.__setattr__( + self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.field)) + codegen.append_output(codegen.create_load_const(self.idx_key)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return self._name + + +@dataclasses.dataclass(frozen=True) +class GetItemSource(ChainedSource): + index: Any + index_is_slice: bool = False + + def __post_init__(self): + assert self.base is not None + if isinstance(self.index, slice): + # store the hashable version of the slice so the whole GetItemSource is hashable + super().__setattr__("index", self.index.__reduce__()) + super().__setattr__("index_is_slice", True) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + if self.index_is_slice: + codegen.append_output(codegen.create_load_const(self.unpack_slice())) + else: + codegen.append_output(codegen.create_load_const(self.index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def guard_source(self): + return self.base.guard_source() + + def unpack_slice(self): + assert self.index_is_slice + slice_class, slice_args = self.index + return slice_class(*slice_args) + + def name(self): + # Index can be of following types + # 1) index is a slice - example 1:4 + # 2) index is a constant - example string, integer + assert not isinstance(self.index, Source) + if self.index_is_slice: + return f"{self.base.name()}[{self.unpack_slice()!r}]" + else: + return f"{self.base.name()}[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class ConstDictKeySource(ChainedSource): + index: Any + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") + ) + codegen(self.base) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + # The list creation will be CSE'd by PyExprCSEPass + return f"list(dict.keys({self.base.name()}))[{self.index!r}]" + + def is_dict_key(self): + return True + + +# Used to access an item from the dictionary +@dataclasses.dataclass(frozen=True) +class DictGetItemSource(ChainedSource): + # Key to access in the dictionary. 
It can be one of the the following types + # 1) ConstDictKeySource + # 2) constant - like string, integer + index: Any + + def __post_init__(self): + from .variables import ConstantVariable + + assert isinstance( + self.index, ConstDictKeySource + ) or ConstantVariable.is_literal(self.index) + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + # Load dict + codegen(self.base) + + # Load key + if isinstance(self.index, Source): + codegen(self.index) + else: + codegen.append_output(codegen.create_load_const(self.index)) + codegen.append_output(create_instruction("BINARY_SUBSCR")) + + def name(self): + if isinstance(self.index, ConstDictKeySource): + return f"{self.base.name()}[{self.index.name()}]" + else: + return f"{self.base.name()}[{self.index!r}]" + + +# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that +# torch.compile does not run the overridden __getitem__ method +@dataclasses.dataclass(frozen=True) +class DictSubclassGetItemSource(ChainedSource): + # Key to access in the dictionary. It can be one of the the following types + # 1) ConstDictKeySource + # 2) constant - like string, integer + index: Any + + def __post_init__(self): + from .variables import ConstantVariable + + assert isinstance( + self.index, ConstDictKeySource + ) or ConstantVariable.is_literal(self.index) + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + # reconstruct dict.__getitem__(dct, key) + + # Load dict.__getitem__ + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dict_getitem") + ) + + # Load dict + codegen(self.base) + + # Load key + if isinstance(self.index, Source): + codegen(self.index) + else: + codegen.append_output(codegen.create_load_const(self.index)) + + codegen.extend_output(create_call_function(2, False)) + + def name(self): + if isinstance(self.index, ConstDictKeySource): + return f"dict.__getitem__({self.base.name()}, {self.index.name()})" + else: + return f"{self.base.name()}[{self.index!r}]" + + +@dataclasses.dataclass(frozen=True) +class ListGetItemSource(GetItemSource): + """ + Same as GetItemSource with reconstruct and name overridden to be list specific. + """ + + def reconstruct(self, codegen: "PyCodegen"): + # Reconstruct list.__getitem__(lst, index) to avoid any side effects + # from possibly overridden __getitem__. 
+ + # Load list.__getitem__ + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "list_getitem") + ) + + # Load the list + codegen(self.base) + + # Load the index + if self.index_is_slice: + raise RuntimeError( + "List[slice] is a temporary object and should not have a source" + ) + else: + codegen.append_output(codegen.create_load_const(self.index)) + + codegen.extend_output(create_call_function(2, False)) + + def name(self): + # Index can be of following types + # 1) index is a slice - example 1:4 + # 2) index is a constant - example string, integer + assert not isinstance(self.index, Source) + if self.index_is_slice: + raise RuntimeError( + "List[slice] is a temporary object and should not have a source" + ) + else: + return f"list.__getitem__({self.base.name()}, {self.index!r})" + + +@dataclasses.dataclass(frozen=True) +class TupleIteratorGetItemSource(GetItemSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") + ) + codegen(self.base) + codegen.append_output(codegen.create_load_const(self.index)) + codegen.extend_output(create_call_function(2, False)) + + def name(self): + return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" + + +@dataclasses.dataclass(frozen=True) +class DataclassFieldsSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dataclass_fields") + ) + codegen(self.base) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"___dataclass_fields({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class TypeSource(ChainedSource): + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) + codegen(self.base) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"type({self.base.name()})" + + +@dataclasses.dataclass(frozen=True) +class OptimizerSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return self.base.name() + + +@dataclasses.dataclass(frozen=True) +class NNModuleSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + + def name(self): + return self.base.name() + + +@dataclasses.dataclass(frozen=True) +class UnspecializedNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class FSDPNNModuleSource(NNModuleSource): + def guard_source(self): + return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] + + +@dataclasses.dataclass(frozen=True) +class GlobalStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.GLOBAL + + +@dataclasses.dataclass(frozen=True) +class 
TorchFunctionModeStackSource(Source): + ind: int + + def name(self): + return f"___get_torch_function_mode_stack_at({self._get_index()})" + + def _get_index(self): + from .variables.torch_function import TorchFunctionModeStackVariable + + return TorchFunctionModeStackVariable.get_mode_index(self.ind) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + utils.__name__, "get_torch_function_mode_stack_at" + ) + ) + codegen.extend_output([codegen.create_load_const(self._get_index())]) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return GuardSource.GLOBAL + + +@dataclasses.dataclass(frozen=True) +class ConstantSource(Source): + source_name: str + + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load_global(self.source_name, add=False)) + + def guard_source(self): + return GuardSource.CONSTANT + + def name(self): + return self.source_name + + def make_guard(self, fn): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class NumpyTensorSource(ChainedSource): + def name(self) -> str: + return f"___from_numpy({self.base.name()})" + + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) + codegen(self.base) + codegen.extend_output(create_call_function(1, False)) + + +@dataclasses.dataclass(frozen=True) +class SubclassAttrListSource(ChainedSource): + def name(self) -> str: + return f"{self.base.name()}.__tensor_flatten__()[0]" + + def guard_source(self): + return self.base.guard_source() + + +# NB: We don't expect you to actually ever generate guards against this +# source, it is ephemeral +@dataclasses.dataclass(frozen=True) +class FloatTensorSource(ChainedSource): + def name(self) -> str: + return f"___as_tensor({self.base.name()})" + + def guard_source(self): + return self.base.guard_source() + + +@dataclasses.dataclass(frozen=True) +class CallMethodItemSource(ChainedSource): + def name(self) -> str: + return f"{self.base.name()}.item()" + + def guard_source(self): + return self.base.guard_source() + + +# This is a synthetic source that is associated with the singleton +# shape env guard we always register for all frames. 
We get the actual +# guard contents from the ambient ShapeEnv +@dataclasses.dataclass(frozen=True) +class ShapeEnvSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.SHAPE_ENV + + +@dataclasses.dataclass(frozen=True) +class BackwardStateSource(Source): + def name(self): + return "" + + def guard_source(self): + return GuardSource.BACKWARD_STATE + + +def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]: + if isinstance(source, ChainedSource): + return get_local_source_name(source.base, only_allow_input=only_allow_input) + if not isinstance(source, LocalSource): + return None + if only_allow_input and not source.is_input: + return None + return source.local_name + + +def is_from_local_source(source: Source, *, only_allow_input=False): + return get_local_source_name(source, only_allow_input=only_allow_input) is not None + + +def is_from_global_source(source: Source) -> bool: + return get_global_source_name(source) is not None + + +def get_global_source_name(source: Source) -> Optional[str]: + if isinstance(source, ChainedSource): + return get_global_source_name(source.base) + if not isinstance(source, GlobalSource): + return None + return source.global_name + + +def is_from_nonlocal_source(source: Source): + if isinstance(source, ChainedSource): + return is_from_nonlocal_source(source.base) + return ( + isinstance(source, LocalSource) + and source.is_derefed_cell_contents + and not source.is_input + ) + + +def is_from_source(source: Source, target: Source): + if isinstance(source, ChainedSource): + return is_from_source(source.base, target) + return source == target + + +@functools.lru_cache +def is_from_unspecialized_nn_module_source(source: Source): + if isinstance(source, UnspecializedNNModuleSource): + return True + if isinstance(source, ChainedSource): + return is_from_unspecialized_nn_module_source(source.base) + return False + + +@functools.lru_cache +def is_from_unspecialized_builtin_nn_module_source(source: Source): + if isinstance(source, UnspecializedBuiltinNNModuleSource): + return True + if isinstance(source, ChainedSource): + return is_from_unspecialized_builtin_nn_module_source(source.base) + return False + + +@functools.lru_cache +def is_from_unspecialized_param_buffer_source(source: Source): + if isinstance(source, UnspecializedParamBufferSource): + return True + if isinstance(source, ChainedSource): + return is_from_unspecialized_param_buffer_source(source.base) + return False + + +@functools.lru_cache +def is_from_flatten_script_object_source(source: Source): + if isinstance(source, FlattenScriptObjectSource): + return True + elif isinstance(source, ChainedSource): + return is_from_flatten_script_object_source(source.base) + return False + + +@functools.lru_cache +def is_from_optimizer_source(source: Source): + if isinstance(source, OptimizerSource): + return True + if isinstance(source, ChainedSource): + return is_from_optimizer_source(source.base) + return False + + +# TODO: can probably write a generic "test this on everything in the chain" +# helper +@functools.lru_cache +def is_from_defaults(source: Source): + if isinstance(source, DefaultsSource): + return True + + # Accessed with func.__kwdefaults__["foo"] + if ( + isinstance(source, DictGetItemSource) + and isinstance(source.base, AttrSource) + and source.base.member == "__kwdefaults__" + ): + return True + + # Accessed with func.__defaults__[0] + if ( + isinstance(source, GetItemSource) + and isinstance(source.base, AttrSource) + and 
source.base.member == "__defaults__" + ): + return True + + if isinstance(source, ChainedSource): + return is_from_defaults(source.base) + return False diff --git a/phivenv/Lib/site-packages/torch/_dynamo/symbolic_convert.py b/phivenv/Lib/site-packages/torch/_dynamo/symbolic_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..f99e6693057f43da3356c0f74ef478f84f68952d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/symbolic_convert.py @@ -0,0 +1,4224 @@ +# mypy: allow-untyped-defs + +""" +Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format. + +This module implements the bytecode-level tracing system that allows TorchDynamo to analyze +and transform Python code. It converts Python bytecode instructions into a symbolic format +that tracks the flow of tensors and other values through the program. + +Key components: +- InstructionTranslatorBase: Base class for converting bytecode to symbolic execution +- InstructionTranslator: Main translator for function bytecode +- InliningInstructionTranslator: Handles inlining of called functions +- SpeculationLog: Manages state for speculative execution and rollback + +The symbolic conversion process handles: +- Control flow (loops, conditionals, etc.) +- Function inlining and call stack management +- Tracking of program values and side effects +- Graph breaks and resumption points +- Exception handling and stack frame management + +This is a core part of TorchDynamo's tracing system that enables ahead-of-time +optimization of PyTorch programs. +""" + +import collections +import collections.abc +import contextlib +import copy +import dataclasses +import dis +import functools +import importlib +import inspect +import itertools +import linecache +import logging +import operator +import re +import sys +import threading +import traceback +import types +import typing +import weakref +from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union +from unittest.mock import patch + +import torch +import torch._logging +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._guards import tracing, TracingContext +from torch._logging.structured import dump_file +from torch.fx.experimental.symbolic_shapes import guard_bool +from torch.utils._functools import cache_method + +from . 
import ( + config, + exc, + graph_break_hints, + logging as torchdynamo_logging, + trace_rules, + variables, +) +from .bytecode_analysis import ( + get_indexof, + JUMP_OPNAMES, + livevars_analysis, + propagate_line_nums, +) +from .bytecode_transformation import ( + cleaned_instructions, + create_call_function, + create_instruction, + create_jump_absolute, + create_swap, + get_code_keys, + Instruction, + is_generator, + unique_id, +) +from .code_context import code_context +from .codegen import PyCodegen +from .exc import ( + ArgsMismatchError, + BackendCompilerFailed, + collapse_resume_frames, + format_graph_break_message, + get_stack_above_dynamo, + unimplemented_v2, + Unsupported, +) +from .funcname_cache import get_funcname +from .guards import GuardBuilder, install_guard +from .output_graph import GraphCompileReason, OutputGraph +from .replay_record import DummyModule, ExecutionRecorder +from .resume_execution import ContinueExecutionCache, ReenterWith +from .source import ( + AttrSource, + DictGetItemSource, + GlobalSource, + GlobalWeakRefSource, + LocalCellSource, + LocalSource, + Source, +) +from .trace_rules import is_builtin_constant, is_forbidden +from .utils import ( + counters, + get_fake_value, + get_instruction_source_311, + get_metrics_context, + graph_break_dup_warning_checker, + istype, + LazyString, + proxy_args_kwargs, +) +from .variables.base import typestr, ValueMutationNew, VariableTracker +from .variables.builder import FrameStateSizeEntry, VariableBuilder, wrap_fx_proxy +from .variables.builtin import BuiltinVariable +from .variables.constant import ConstantVariable +from .variables.ctx_manager import ( + ContextWrappingVariable, + GenericContextWrappingVariable, + WithExitFunctionVariable, +) +from .variables.dicts import ConstDictVariable, SetVariable +from .variables.functions import ( + BaseUserFunctionVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) +from .variables.iter import MAX_ITERATOR_LIMIT +from .variables.lazy import LazyVariableTracker +from .variables.lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SliceVariable, + TupleVariable, +) +from .variables.misc import ( + CellVariable, + ExceptionVariable, + GetAttrVariable, + NullVariable, + PythonModuleVariable, + UnknownVariable, +) +from .variables.nn_module import NNModuleVariable +from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) +from .variables.user_defined import ( + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedExceptionClassVariable, + UserDefinedExceptionObjectVariable, + UserDefinedObjectVariable, +) + + +if TYPE_CHECKING: + from .package import CompilePackage + +log = logging.getLogger(__name__) +graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") +trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") +trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source") +trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode") +tls = threading.local() +compare_op_handlers: dict[str, Any] = { + k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items() +} +handle_contains = BuiltinVariable(operator.contains).call_function +handle_not = BuiltinVariable(operator.not_).call_function 
+compare_op_handlers["in"] = lambda tx, args, _: handle_contains( + tx, [*reversed(args)], {} +) +compare_op_handlers["not in"] = lambda tx, args, _: handle_not( + tx, [handle_contains(tx, [*reversed(args)], {})], {} +) + + +PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml" + + +@functools.cache +def _import_module(name: str) -> types.ModuleType: + """ + Import the named module and cache the result. importlib.import_module() + seems to do some filesystem checking to validate the name so not caching + this can be slow. + """ + return importlib.import_module(name) + + +@dataclasses.dataclass +class SpeculationEntry: + filename: str + lineno: int + instruction_pointer: int + inst: Instruction # for debugging only + failed: bool = False + reason: Optional[GraphCompileReason] = None + + def fail_and_restart_analysis(self): + """ + Start tracing of the current frame over again, and don't take this branch. + """ + self.failed = True + if self.reason is not None: + restart_reason = self.reason.reason + else: + restart_reason = "Unknown fail_and_restart_analysis" + raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason) + + +@dataclasses.dataclass +class SpeculationLog: + """ + SpeculationLog replaces the prior copy_graphstate/restore_graphstate + checkpointing. Rather than saving/restoring state, we restart the + dynamo conversion process over from the beginning -- but when we + hit the start of the speculation that failed, we instead generate + a graph break. + """ + + entries: list[SpeculationEntry] = dataclasses.field(default_factory=list) + index: int = 0 + + def restart(self): + self.index = 0 + + def clear(self): + self.entries.clear() + self.index = 0 + + def next( + self, filename: str, lineno: int, instruction_pointer, inst + ) -> SpeculationEntry: + """ + Lookup or create a SpeculationEntry() that is shared across + RestartAnalysis calls. Args are used only for debug checks. + """ + if len(self.entries) == self.index: + self.entries.append( + SpeculationEntry(filename, lineno, instruction_pointer, inst) + ) + entry = self.entries[self.index] + prev_entry_msg = "" + if self.index != 0: + prev_entry = self.entries[self.index - 1] + prev_entry_msg = ( + f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}" + f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n" + ) + if not ( + entry.instruction_pointer == instruction_pointer + and entry.filename == filename + and entry.lineno == lineno + ): + raise SpeculationLogDivergence( + f""" +SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries): +- Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer}) +- Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer}) +{prev_entry_msg} +There are two usual reasons why this may have occurred: +- When Dynamo analysis restarted, the second run took a different path than + the first. If this occurred, the previous instruction is the critical instruction that + behaved differently. +- Speculation entries are only added under certain conditions (as seen in + step()), e.g., there must exist operators in the graph; those conditions may + have changed on restart. + +If this divergence was intentional, clear the speculation log before restarting (do NOT +do this for graph breaks, you will infinite loop). 
+ +Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo +""" + ) + self.index += 1 + return entry + + +@dataclasses.dataclass +class LocalState: + automatic_dynamic: dict[str, FrameStateSizeEntry] = dataclasses.field( + default_factory=dict + ) + + def render(self) -> str: + return "\n".join( + f"{k}: {v.render()}" for k, v in self.automatic_dynamic.items() + ) + + +# Mutable box that is shared across restarts +@dataclasses.dataclass +class DistributedState: + compile_pg: Any + local_state: LocalState + all_states: Optional[list[LocalState]] = None + + +class TensorifyState: + # These are the set of string symfloats names (eg. "zf0") that we collect + # from the tensorify_python_scalars.py joint fx pass to inform us about + # which float inputs we should specialize when we restart analysis. + force_specializations: set[str] = set() + + @classmethod + def specialize(cls, index: str) -> None: + cls.force_specializations.add(index) + + @classmethod + def should_specialize(cls, index: str) -> bool: + return index in cls.force_specializations + + @classmethod + def clear(cls) -> None: + cls.force_specializations.clear() + + @classmethod + def empty(cls) -> bool: + return len(cls.force_specializations) == 0 + + +@functools.cache +def _step_logger(): + return torchdynamo_logging.get_step_logger(log) + + +@contextlib.contextmanager +def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"): + # When reconstructing a generator after a graph break, we advance it until + # it is fully exhausted. This process adds new entries to the speculation + # log that were not previously observed. Without temporarily clearing the + # speculation log, this could lead to a divergence error. + + entries = tx.speculation_log.entries + index = tx.speculation_log.index + try: + tx.speculation_log.entries = [] + tx.speculation_log.index = 0 + yield + finally: + tx.speculation_log.entries = entries + tx.speculation_log.index = index + + +@contextlib.contextmanager +def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): + try: + tmp = tx.output.should_exit + tx.output.should_exit = False + yield + finally: + tx.output.should_exit = tmp + + +@dataclasses.dataclass +class BlockStackEntry: + # Current instruction that pushes something to block_stack + inst: Instruction + target: Instruction + stack_index: int + with_context: Optional[ + Union[ContextWrappingVariable, GenericContextWrappingVariable] + ] = None + + def can_restore(self): + return self.with_context is not None + + def resume_fn(self): + assert self.stack_index is not None + if ( + self.with_context + and hasattr(self.with_context, "target_values") + and self.with_context.target_values + ): + return ReenterWith( + self.stack_index - 1, tuple(self.with_context.target_values) + ) + else: + return ReenterWith(self.stack_index - 1) + + def exit(self, tx, is_graph_break): + assert self.with_context is not None + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) + + +class SpeculationLogDivergence(AssertionError): + pass + + +class ReturnValueOp(Exception): + pass + + +class YieldValueOp(Exception): + """ + Signal to the symbolic tracer to stop and return control flow to the + caller + """ + + +def stack_op(fn: typing.Callable[..., object]): + nargs = len(inspect.signature(fn).parameters) + fn_var = BuiltinVariable(fn) + + @functools.wraps(fn) + def impl(self: "InstructionTranslator", inst: Instruction): + 
self.push(fn_var.call_function(self, self.popn(nargs), {})) + + return impl + + +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", + truth_fn: typing.Callable[[object], bool], + push: bool, +): + # Detect if this jump instruction is an assert and normalize the assert + # by pushing a dummy error message when none is given. + # + # A Python 3.9 assertion is in the following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # A Python 3.8 assertion is in the following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + assert isinstance(self.instruction_pointer, int) + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + # Use a dummy error message if it's hard to extract + error_msg = "assertion error" + + inst = self.instructions[current_instruction_pointer] + # Detect RAISE_VARARGS or LOAD_CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + error_msg = inst.argval + + # if it is LOAD_CONST, it must be followed by CALL_FUNCTION + # (PRECALL for Python 3.11, CALL for Python 3.12+) + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"): + return False + + # for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS + # for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if inst.opname == "PRECALL": + current_instruction_pointer += 1 + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + self.push(ConstantVariable.create(error_msg)) + + return True + + +explain = False + + +def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): + if user_stack is None: + user_stack = torch._guards.TracingContext.extract_stack() + + try: + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + except IndexError: + # first instruction + frame_loc = ( + code_options["co_filename"], + code_options["co_firstlineno"], + ) + + stack_above_dynamo_formatted = "" + if config.verbose: + stack_above_dynamo = get_stack_above_dynamo() + stack_above_dynamo_formatted = "".join( + traceback.format_list(stack_above_dynamo) + ) + else: + user_stack = get_stack_above_dynamo() + user_stack + user_stack = collapse_resume_frames(user_stack) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = ( + f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n" + f"Graph Break Reason: {reason}\n" + "User code traceback:\n" + ) + + if config.verbose: + user_stack_trace += ( + f"{stack_above_dynamo_formatted}\n" + "========== most recent `torch.compile` tracing attempt started here ==========\n\n" + f"{user_stack_formatted}\n" + "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! 
" + "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another " + "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python " + "function, which Dynamo intercepts as a top-level frame.\n" + ) + else: + user_stack_trace += str(user_stack_formatted) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc() if exc_info else ''}", + ) + + # torch._dynamo.explain() formats this a little nicer, and presents a slightly + # more actionable user code pointer + if ( + graph_break_log.isEnabledFor(logging.DEBUG) + and not explain + and graph_break_dup_warning_checker.add(frame_loc) + ): + # This log line MUST contain the string "Graph break in user code", + # This log line is exercised from + # python test/dynamo/test_exc.py -k test_graph_break_log + graph_break_log.debug( + user_stack_trace, + ) + else: + # This log line MUST not contain the string "Graph break in user code", + # exercised by + # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log + graph_break_log.debug( + "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s", + frame_loc[0], + frame_loc[1], + reason, + ) + + +def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): + # graph break message fields for data dependent branching + _gb_type = "Data-dependent branching" + _explanation = ( + "Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). " + "Dynamo does not support tracing dynamic control flow." + ) + _hints = [ + *graph_break_hints.FUNDAMENTAL, + "Use `torch.cond` to express dynamic control flow.", + ] + + def jump_graph_break(self, inst, value, extra_msg=""): + log_graph_break( + self.code_options, + reason=format_graph_break_message( + gb_type=_gb_type, + context=f"attempted to jump with {value}", + explanation=_explanation, + hints=_hints, + ), + ) + assert self.should_compile_partial_graph() + # compile a partial subgraph prefix then jump into user code + if self.maybe_has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) + + self.push(value) + log.debug("generic_jump triggered compile") + all_stack_locals_metadata = self.output.compile_subgraph( + self, + reason=GraphCompileReason( + f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()] + ), + stack_pops=1, + ) + self.pop() + + if_next = self.create_call_resume_at( + self.next_instruction, all_stack_locals_metadata + ) + if push: + self.push(value) + if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata) + + if sys.version_info >= (3, 13): + # 3.13 requires stack[-1] to be bool type + self.output.add_output_instructions([create_instruction("TO_BOOL")]) + + jump_inst = create_instruction(inst.opname, target=if_jump[0]) + jump_inst.copy_positions(inst) + self.output.add_output_instructions([jump_inst] + if_next + if_jump) + + def inner(self: "InstructionTranslatorBase", inst: Instruction): + value: VariableTracker = self.pop() + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + # Skip over things like `assert True` + if value.is_python_constant(): + if 
bool(value.as_python_constant()): + return self.jump(inst) + elif self.should_compile_partial_graph(): + jump_graph_break(self, inst, value) + else: + unimplemented_v2( + gb_type="Data-dependent assertion failed (cannot compile partial graph)", + context=f"value: {value}", + explanation="Dynamo has determined when encountering a data-dependent assert failure " + "that it should not compile the partial graph.", + hints=[ + *graph_break_hints.FUNDAMENTAL, + "Use `torch._assert()` to raise a hard AssertionError when the check fails. " + "This error will propagate back the user code " + "that called the compiled function (i.e. Dynamo will not trace any exception handling).", + "Remove the assert statement.", + "Move the assert statement outside of any context managers in order to graph break with " + "partial graph compilation (if fullgraph=False).", + ], + ) + + # TODO maybe should respect DtoH sync intention of users later?? + # Manually insert torch._assert_async instead of python assert and jump over + # assert related instructions as we don't need them anymore. + + # if we see Tensor as assert statement, no need to call scalar_tensor + if isinstance(value, TensorVariable): + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((value, error_msg), {}), + ) + self.jump(inst) + return + + if isinstance(value, SymNodeVariable): + # if the assertion is normal shape expression. + # just install guard and bail out. + sym_expr = value.sym_num + if not isinstance(sym_expr, torch.SymBool): + sym_expr = sym_expr != 0 + + result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr) + if not result: + unimplemented_v2( + gb_type="Assertion failed on symbolic shapes", + context=str(sym_expr), + explanation="", + hints=[*graph_break_hints.USER_ERROR], + ) + self.jump(inst) + return + + scalar_to_tensor_proxy = self.output.create_proxy( + "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {}) + ) + + scalar_to_tensor = wrap_fx_proxy( + self, + scalar_to_tensor_proxy, + example_value=get_fake_value(scalar_to_tensor_proxy.node, self), + ) + + self.output.create_proxy( + "call_function", + torch._assert_async, + *proxy_args_kwargs((scalar_to_tensor, error_msg), {}), + ) + self.jump(inst) + return + + if value.is_python_constant(): + # ConstDictVariable is optimized to be very lazy about insertion of + # guards, so we have to manually insert a SEQUENCE_LENGTH guard + # here. + if isinstance(value, ConstDictVariable) and value.source: + install_guard(value.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + if truth_fn(value.as_python_constant()): + if push: + self.push(value) + self.jump(inst) + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): + jump_graph_break(self, inst, value) + elif isinstance(value, NNModuleVariable): + # Equivalent of "self.nn_module is not None" + mod = self.output.get_submodule(value.module_key) + if truth_fn(mod): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, UserDefinedObjectVariable): + try: + x = value.var_getattr(self, "__bool__") # type: ignore[arg-type] + except exc.ObservedAttributeError: + exc.handle_observed_exception(self) + # if __bool__ is missing, trying __len__ to infer a truth value. 
+ try: + x = value.var_getattr(self, "__len__") # type: ignore[arg-type] + except exc.ObservedAttributeError: + exc.handle_observed_exception(self) + x = None + + # __bool__ or __len__ is function + if isinstance(x, UserMethodVariable): + result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] + if isinstance(result, ConstantVariable) and isinstance( + result.value, (bool, int) + ): + if truth_fn(result.value): + if push: + self.push(value) + self.jump(inst) + elif isinstance(result, SymNodeVariable): + if result.evaluate_expr(): + if push: + self.push(value) + self.jump(inst) + else: + unimplemented_v2( + gb_type="Data-dependent branching with non-constant __bool__", + context=f"method: {x}, result: {result}", + explanation="Attempted to perform data-dependent branching on a user-defined " + "object with a __bool__ method that did not return a constant.", + hints=[], + ) + # __bool__ or __len__ is non-function or not existed in the user defined object + else: + if truth_fn(True): + if push: + self.push(value) + self.jump(inst) + elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence( + self + ): + if truth_fn(len(value.unpack_var_sequence(self))): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, SymNodeVariable): + try: + # if the user is branching on a SymBool, guard on it + # if the user has code like: + # if size: + # ... + # then they are just testing truthiness: guard that the expr != 0 + if isinstance(value.sym_num, torch.SymBool): + eval_result = value.evaluate_expr(self.output) + else: + eval_result = guard_bool(value.sym_num != 0) + except exc.UserError as e: + if self.should_compile_partial_graph(): + return jump_graph_break(self, inst, value, extra_msg=f"\n{e}") + raise + if truth_fn(eval_result): + if push: + self.push(value) + self.jump(inst) + elif isinstance(value, variables.BackwardHookVariable): + if truth_fn(True): + if push: + self.push(value) + self.jump(inst) + else: + from .source import is_constant_source + + if value.source is not None and is_constant_source(value.source): + if truth_fn(value.get_real_value()): # type: ignore[attr-defined] + if push: + self.push(value) + self.jump(inst) + else: + unimplemented_v2( + gb_type=_gb_type, + context=f"attempted to jump with {value}", + explanation=_explanation, + hints=_hints, + ) + + return inner + + +def break_graph_if_unsupported(*, push): + def decorator(inner_fn): + @functools.wraps(inner_fn) + def wrapper(self: "InstructionTranslatorBase", inst: Instruction): + speculation = self.speculate() + if speculation.failed: + assert speculation.reason is not None + return handle_graph_break(self, inst, speculation.reason) + try: + return inner_fn(self, inst) + except Unsupported as excp: + if self.active_generic_context_managers: + # We don't support graph break under GenericContextWrappingVariable, + # If there is, we roll back to the checkpoint and fall back. 
+ excp.remove_from_stats() + unimplemented_v2( + gb_type="Graph break under GenericContextWrappingVariable", + context=f"Active generic context managers: {self.active_generic_context_managers}", + explanation="Attempted to graph break in an active context manager(s) that doesn't support graph breaking.", + hints=[ + "Move the offending context manager(s) to outside the compiled region.", + *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK, + ], + from_exc=excp, + ) + + if isinstance(excp, exc.UncapturedHigherOrderOpError): + raise + + if not self.should_compile_partial_graph(): + raise + + log_graph_break( + self.code_options, + exc_info=True, + reason=str(excp), + user_stack=excp.real_stack, + ) + + if self.maybe_has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop\n" + f"{self.frame_summary()}" + ) + log.info(msg) + raise exc.SkipFrame(msg) from excp + + excp.remove_from_stats() + excp.add_to_stats("graph_break") + speculation.reason = GraphCompileReason(excp.msg, excp.real_stack) + speculation.fail_and_restart_analysis() + + def handle_graph_break( + self: "InstructionTranslatorBase", + inst: Instruction, + reason: GraphCompileReason, + ): + if ( + sys.version_info >= (3, 11) + and sys.version_info < (3, 12) + and inst.opname == "CALL" + ): + # stack effect for PRECALL + CALL is split between the two instructions + stack_effect = dis.stack_effect( + dis.opmap["PRECALL"], inst.arg + ) + dis.stack_effect(dis.opmap["CALL"], inst.arg) + else: + stack_effect = dis.stack_effect(inst.opcode, inst.arg) + + all_stack_locals_metadata = self.output.compile_subgraph( + self, reason=reason, stack_pops=push - stack_effect + ) + cg = PyCodegen(self) + cleanup: list[Instruction] = [] + # Reconstruct the context variable CLASS in the block stack + for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue + assert b.with_context is not None + assert isinstance(b.with_context, (ContextWrappingVariable)) + b.with_context.reconstruct_type(cg) + cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup)) + self.output.add_output_instructions(cg.get_instructions()) + del cg + + if sys.version_info >= (3, 11) and inst.opname == "CALL": + kw_names = ( + self.kw_names.as_python_constant() + if self.kw_names is not None + else () + ) + if len(kw_names) > 0: + # KW_NAMES no longer used in 3.13 + assert sys.version_info < (3, 13) + self.output.add_output_instructions( + [create_instruction("KW_NAMES", argval=kw_names)] + ) + call_insts = create_call_function(inst.arg, False) + call_insts[-1].copy_positions(inst) + self.output.add_output_instructions(call_insts) + else: + # copy instruction, but without exception table data + assert inst.target is None + inst_copy = copy.copy(inst) + inst_copy.exn_tab_entry = None + self.output.add_output_instructions([inst_copy]) + + self.output.add_output_instructions(cleanup) + + self.popn(push - stack_effect) + for _ in range(push): + self.push(UnknownVariable()) + self.output.add_output_instructions( + self.create_call_resume_at( + self.next_instruction, all_stack_locals_metadata + ) + ) + + return wrapper + + return decorator + + +class BytecodeDistpatchTableMeta(type): + """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()""" + + def __init__(cls, 
name, bases, dct) -> None: + super().__init__(name, bases, dct) + + def _missing(opname, *args): + unimplemented_v2( + gb_type="Missing bytecode handler", + context=f"{opname} with args {args}", + explanation=f"Dynamo does not know how to handle the bytecode instruction `{opname}`.", + hints=[ + f"Do not trace code that produces the `{opname}` bytecode instruction " + "(see https://docs.python.org/3/library/dis.html for bytecode semantics).", + *graph_break_hints.SUPPORTABLE, + ], + ) + + dispatch_table = { + op: getattr(cls, opname, functools.partial(_missing, opname)) + for opname, op in dis.opmap.items() + } + cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)] + + +@dataclasses.dataclass +class ExceptionStack: + """ + Exception stack that is shared among all InstructionTranslator instances + """ + + # Exception handling in CPython is a bit confusing and some of the bytecode + # instructions have a slightly different behavior than what is documented. While reading + # the documentation, it is important to notice that the terms "current exception" + # and "stack" sometimes refer to a C variable with the same name and the + # exception stack, respectively. + # + # The lifetime of an exception is (Python 3.11+): + # + tx._raise_exception_variable(...) := sets the current_exception variable + # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* + # + POP_EXCEPT := pops TOS from the *exception stack* + + _exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list) + _current_exception: Optional[VariableTracker] = dataclasses.field(default=None) + + def clear_current_exception(self): + self._current_exception = None + + def set_current_exception(self, val): + self._set_context_and_break_context_reference_cycle(val) + self._current_exception = val + + def move_current_exception_to_stack(self): + assert self._current_exception is not None + self.append(self._current_exception) + self.clear_current_exception() + + def get_current_exception(self): + assert self._current_exception is not None + return self._current_exception + + def _set_context_recursive(self, val, prev_idx): + if (ctx := val.__context__) and type(ctx) is not ConstantVariable: + return val + if len(self._exc_stack) + prev_idx > 0: + prev = self._exc_stack[prev_idx] + self._set_context_recursive(prev, prev_idx - 1) + val.set_context(prev) + return val + + def _break_context_reference_cycle(self, val): + # See test_exceptions::test_raise_does_not_create_context_chain_cycle + # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228 + # As noted in CPython, this is O(chain length) but the context chains + # are usually very small + o = slow_o = val + slow_update_toggle = False # Floyd's algorithm for detecting a cycle + while True: + context = o.__context__ + if type(context) is ConstantVariable: # context not set + break + + if context is val: + o.set_context(ConstantVariable(None)) + break + + o = context + if o is slow_o: + # pre-existing cycle - all exceptions on the path were + # visited and checked + break + + if slow_update_toggle: + slow_o = slow_o.__context__ # visited all exceptions + slow_update_toggle = not slow_update_toggle + + def _set_context_and_break_context_reference_cycle(self, val): + # set Exception.__context__ + self._set_context_recursive(val, len(self._exc_stack) - 1) + self._break_context_reference_cycle(val) + + def pop(self): + return self._exc_stack.pop() + + def append(self, val): 
self._exc_stack.append(val) + + def __len__(self): + return len(self._exc_stack) + + def __getitem__(self, index): + return self._exc_stack[index] + + def __str__(self): + return f"{self._exc_stack=} - {self._current_exception=}" + + __repr__ = __str__ + + +class InstructionTranslatorBase( + metaclass=BytecodeDistpatchTableMeta, +): + output: OutputGraph + symbolic_locals: dict[str, VariableTracker] + symbolic_globals: dict[str, VariableTracker] + symbolic_torch_function_state: SymbolicTorchFunctionState + stack: list[VariableTracker] + instruction_pointer: Optional[int] + current_instruction: Instruction + block_stack: list[BlockStackEntry] + lineno: int + kw_names: Optional[ConstantVariable] + accept_prefix_inst: bool + prefix_insts: list[Instruction] + inline_depth: int + inconsistent_side_effects: bool + current_speculation: Optional[SpeculationEntry] + dispatch_table: list[Any] + exn_vt_stack: ExceptionStack + exec_recorder: Optional[ExecutionRecorder] + strict_checks_fn: Optional[Callable[[VariableTracker], bool]] + start_point: Optional[int] + is_leaf_tracer: bool + parent: Optional["InstructionTranslatorBase"] + debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] + package: Optional["CompilePackage"] + + def mark_inconsistent_side_effects(self): + """ + InstructionTranslator has encountered instructions which may cause + dynamo to see a different version of history from eager + See: https://github.com/pytorch/pytorch/issues/110765 + """ + self.inconsistent_side_effects = True + + def maybe_has_backedge(self): + # This function employs a heuristic. It does not reliably detect a backedge. + # The heuristic is straightforward: starting from the current instruction and + # continuing to the end, if any jump instruction targets an instruction before + # the current one, there might be a backedge. + + # Python 3.12 introduced changes to bytecode that group common paths in + # blockstacks (with or try...else) and allow for early returns. Consequently, + # there can be multiple RETURN_VALUE instructions. Another heuristic is to + # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST. + + # These heuristics can result in both false positives and negatives, but + # in either case, the Dynamo code remains valid. For false positives + # (where an edge is incorrectly marked as a backedge), Dynamo will + # perform a SkipFrame instead of potentially applying optimizations. For + # false negatives (where an edge that should be marked as a backedge + # isn't), multiple graphs may be generated if there's a break in the + # graph during a for loop. In general, its better to have fewer false + # negatives so that Dynamo does not skip the whole frame. + + cur_offset = self.current_instruction.offset + assert self.instruction_pointer is not None + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ("RETURN_VALUE", "RETURN_CONST"): + return False + if inst.opname in JUMP_OPNAMES: + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + + def cellvars(self): + if not hasattr(self, "_cellvars"): + self._cellvars = tuple(self.code_options["co_cellvars"] or []) + # An inlined function might depend on the cellvar of the parent + # function. So, recursively obtain parent cellvars. 
+ if isinstance(self, InliningInstructionTranslator): + self._cellvars += self.parent.cellvars() + return self._cellvars + + def freevars(self): + if not hasattr(self, "_freevars"): + self._freevars = tuple(self.code_options["co_freevars"] or []) + # An inlined function might depend on the freevar of the parent + # function. So, recursively obtain parent freevars. + if isinstance(self, InliningInstructionTranslator): + self._freevars += self.parent.freevars() + return self._freevars + + def cell_and_freevars(self): + if not hasattr(self, "_cell_and_freevars"): + self._cell_and_freevars = self.cellvars() + self.freevars() + return self._cell_and_freevars + + def prune_dead_locals(self): + # Only keep the locals that must remain on the stack. + reads = livevars_analysis(self.instructions, self.current_instruction) + self.symbolic_locals = { + k: v for k, v in self.symbolic_locals.items() if k in reads + } + # "Garbage collect the heap". + self.output.side_effects.prune_dead_object_new(self) + + def call_function( + self, + fn: VariableTracker, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ): + assert isinstance(fn, VariableTracker) + assert isinstance(args, list) + assert isinstance(kwargs, dict) + assert all( + isinstance(x, VariableTracker) + for x in itertools.chain(args, kwargs.values()) + ) + inner_fn = None + if hasattr(fn, "value"): + inner_fn = fn.value + if hasattr(fn, "fn"): + inner_fn = fn.fn + if inner_fn and callable(inner_fn) and is_forbidden(inner_fn): + raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") + self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] + + def inline_generator_function(self, fn, args, kwargs): + """ + Redirect the call to the generator "call_function" + """ + if not isinstance(fn, LocalGeneratorFunctionVariable): + fn = LocalGeneratorFunctionVariable(fn) + return fn.call_function(self, args, kwargs) + + def inline_user_function_return(self, fn, args, kwargs): + """ + A call to some user defined function by inlining it. 
+ """ + if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): + return self.inline_generator_function(fn, args, kwargs) + else: + return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) + + def get_line_of_code_header(self, lineno=None): + if lineno is None: + lineno = self.lineno + inline_depth_str = ( + f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else "" + ) + funcname = get_funcname(self.f_code.co_filename, lineno) + funcname_str = "" if funcname is None else f" ({funcname})" + return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}" + + def get_log_starts_line_log_str(self): + log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n" + line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip() + log_str += f" {line}" + return log_str + + def starts_line(self, lineno): + if self.lineno == lineno: + return + self.lineno = lineno + TracingContext.set_current_loc( + self.f_code.co_filename, lineno, self.f_code.co_name + ) + + if self.is_trace_source_log_enabled: + trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str)) + + def step(self): + """Process exactly one instruction, return False we should exit""" + ip = self.instruction_pointer + if ip is None: + return False + self.current_instruction = inst = self.instructions[ip] + self.instruction_pointer = ip + 1 + + if inst.starts_line: + self.starts_line(inst.starts_line) + + if ( + not self.stack + and self.should_compile_partial_graph() + and self.is_non_empty_graph() + ): + self.current_speculation = self.speculate() + if self.current_speculation.failed: + return self.step_graph_break(inst) + + if self.is_trace_bytecode_log_enabled: + trace_bytecode_log.debug( + "TRACE %s %s %s", inst.opname, inst.argval, self.stack + ) + + self.update_block_stack(inst) + + try: + self.dispatch_table[inst.opcode](self, inst) + return not self.output.should_exit + except TensorifyScalarRestartAnalysis: + raise + except exc.ObservedException as e: + self.exception_handler(e) + return True + except (ReturnValueOp, YieldValueOp): + return False + except Unsupported: + if self.current_speculation is None: + log.debug("empty checkpoint") + raise + log.debug("step triggered compile", exc_info=True) + + self.current_speculation.fail_and_restart_analysis() + + if sys.version_info >= (3, 11): + + def update_block_stack(self, inst): + # 3.11+ no longer uses a block stack, but we still keep track of one + # so that we know which contexts are currently active. + # For our purposes, all exception table entries with the same target + # are considered to be part of the same "block". + # NOTE: we only keep track of with blocks that are not contained in try blocks. + # This is because we will not create continuation functions on graph breaks in try blocks, + # but we may for with blocks. We do not push blocks here since + # with blocks are pushed when handling BEFORE_WITH. + entry = inst.exn_tab_entry + if entry: + # Detect when we have exited the top with block. + # The with blocks on the block stack are not enclosed in try + # blocks, so a with block's cleanup code should be in the + # previous with block (if any). 
+ if ( + len(self.block_stack) >= 2 + and entry.target is not self.block_stack[-1].target + and entry.target is self.block_stack[-2].target + ): + # exit the current block + self.block_stack.pop() + else: + # no longer in any block + # It is possible for NOPs to be between two instructions + # in the same block, but the NOPs are not covered by an + # exception table entry. In this case, assume that we + # are still in the same block. + # In 3.12+, JUMP_BACKWARD might also not be covered by + # an exception table entry, so we also assume that we + # are still in the same block. It is probably safe to do + # this in 3.11, even though we haven't encountered this case before. + if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"): + # If we really escape from a block and the current + # instruction is not in another block, then there + # should be no other nested blocks that we are in. + assert len(self.block_stack) == 1 + self.block_stack.pop() + + else: + + def update_block_stack(self, inst): + pass + + @property + def next_instruction(self): + return self.instructions[self.instruction_pointer] # type: ignore[index] + + def step_graph_break(self, continue_inst): + # generate code from checkpoint + assert not self.output.output_instructions + assert self.current_speculation is not None + # NOTE: adding an assert here since it seems like the only place + # where we call step_graph_break right now is when the stack is empty, + # so let's enforce that for now. + assert not self.stack + self.output.compile_subgraph( + self, + partial_convert=True, + reason=GraphCompileReason("step_unsupported", [self.frame_summary()]), + ) + self.output.add_output_instructions( + [create_jump_absolute(continue_inst)] + self.instructions + ) + + def run_ctx_mgr(self): + # NB: Don't push the top level frame summary; set_current_loc will + # take care of it. However, DO make sure we attach real_stack to + # exceptions + return TracingContext.current_frame(None) + + def run(self): + with self.run_ctx_mgr(): + dump_file(self.f_code.co_filename) + try: + self.output.push_tx(self) + self.start_point = self.instruction_pointer + while self.step(): + pass + except TensorifyScalarRestartAnalysis: + raise + except BackendCompilerFailed: + raise + except RuntimeError as e: + if hasattr(e, "msg") and "Data-dependent" in e.msg: + readable_graph = torch.fx.GraphModule( + self.output.nn_modules, self.output.graph + ).print_readable( + print_output=False, include_stride=True, include_device=True + ) + e.partial_fx_graph = readable_graph # type: ignore[attr-defined] + raise + + raise + except Exception as e: + if self.exec_recorder: + e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] + + raise + finally: + self.output.pop_tx() + # Cleanup the outputGraph to delete the held tensors. We perform the + # cleanup only for InstructionTranslator and not + # InliningInstructionTranslator. The InliningInstructionTranslator + # mutates the output object and is restored to original state if + # there was an exception. + if isinstance(self, InstructionTranslator): + self.output.cleanup() + + # Note that this call maybe redundant if compile_subgraph is + # called. This is ok, because calling exit stack close() + # twice is not an issue (second stop is a no op). 
+ self.output.mark_bytecode_tracing_stop() + + def push(self, val: Optional[VariableTracker]): + assert val is None or isinstance(val, VariableTracker), ( + f"push expects VariableTracker, got {typestr(val)}" + ) + self.stack.append(val) # type: ignore[arg-type] + + def push_many(self, vals: list[VariableTracker]): + for val in vals: + self.push(val) + + def pop(self) -> VariableTracker: + return self.stack.pop() + + def popn(self, n: int) -> list[VariableTracker]: + return [*reversed([self.pop() for _ in range(n)])] + + def LOAD_FAST(self, inst): + name = inst.argval + if self.exec_recorder and name in self.f_locals: + self.exec_recorder.add_local_var(name, self.f_locals[name]) + + try: + self.push(self.symbolic_locals[name].unwrap()) + except KeyError: + if name.startswith("."): + try: + # This happens in dict/list comprehensions + new_name = name.replace(".", "implicit") + self.push(self.symbolic_locals[new_name]) + except KeyError: + unimplemented_v2( + gb_type="Attempted to read undefined local variable (implicit)", + context=f"LOAD_FAST {name}", + explanation=f"Could not find an implicit local variable with name `{name}`", + hints=[ + "This happens in dict/list comprehensions", + *graph_break_hints.USER_ERROR, + ], + ) + else: + unimplemented_v2( + gb_type="Attempted to read undefined local variable", + context=f"LOAD_FAST {name}", + explanation=f"Could not find a local variable with name `{name}`", + hints=[*graph_break_hints.USER_ERROR], + ) + + # for continuation functions + if name.startswith("__stack"): + self.symbolic_locals.pop(name) + + def LOAD_DEREF(self, inst): + assert inst.argval in self.cell_and_freevars() + cell = self.symbolic_locals[inst.argval] + contents_var = self.output.side_effects.load_cell(cell) + self.push(contents_var) + + if self.exec_recorder and inst.argval in self.f_locals: + self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval]) + + def STORE_FAST(self, inst): + name = inst.argval + loaded_vt = self.pop() + loaded_vt.set_name_hint(name) + self.symbolic_locals[name] = loaded_vt + + def DELETE_FAST(self, inst): + del self.symbolic_locals[inst.argval] + + def STORE_DEREF(self, inst): # type: ignore[override] + assert inst.argval in self.cell_and_freevars() + cell = self.symbolic_locals[inst.argval] + val = self.pop() + self.output.side_effects.store_cell(cell, val) + + assert isinstance(cell, CellVariable) # tame mypy + if cell.local_name is not None: + val.set_name_hint(cell.local_name) # type: ignore[attr-defined] + + LOAD_CLOSURE = LOAD_FAST + + def _load_const(self, inst): + i = inst.arg + if i is None: + return ConstantVariable.create(value=inst.argval) + val = self._constants_cache[i] + if not val: + self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval) + return val + + def LOAD_CONST(self, inst): + self.push(self._load_const(inst)) + + def _load_global(self, inst): + name = inst.argval + + if self.exec_recorder: + if name in self.f_globals: + self.exec_recorder.add_global_var(name, self.f_globals[name]) + else: + assert name in self.f_builtins + self.exec_recorder.builtins[name] = self.f_builtins[name] + + if name not in self.f_globals: + return self.load_builtin(inst) + + if name in self.symbolic_globals: + variable = self.output.side_effects[self.symbolic_globals[name]] + self.push(self.output.side_effects.load_global(variable, name)) + return + + value = self.f_globals[name] + self.push(VariableTracker.build(self, value, GlobalSource(name))) + + @functools.cached_property + def nn_modules_globals_vt(self): 
+ module_name = "torch.nn.modules.module" + module_source = self.import_source(module_name) + fglobals_value = _import_module(module_name) + return VariableTracker.build(self, fglobals_value, module_source) + + def LOAD_GLOBAL(self, inst): + if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2: + self.PUSH_NULL(inst) + self._load_global(inst) + if sys.version_info >= (3, 13) and inst.arg % 2: + self.PUSH_NULL(inst) + + def STORE_GLOBAL(self, inst): + value = self.pop() + name = inst.argval + source = GlobalSource(name) + if name not in self.symbolic_globals: + self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object + variable = self.output.side_effects.track_global_existing( + source, self.symbolic_globals[name] + ) + if isinstance(value, RemovableHandleVariable): + unimplemented_v2( + gb_type="Storing Tensor hook handle in globals", + context=name, + explanation="This is not supported.", + hints=[], + ) + self.output.side_effects.store_global(variable, name, value) + + # Cache note: This cache only exists for the duration of this + # InstructionTranslator - so it should be safe to do. + @cache_method + def import_source(self, module_name): + """Create an alias to a module for use in guards""" + if "torch_package" in module_name: + value = torch.package.package_importer._package_imported_modules[ + module_name + ] + alias = ( + module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_") + ) + else: + value = _import_module(module_name) + alias = f"__import_{module_name.replace('.', '_dot_')}" + + if self.package is not None: + self.package.add_import_source(alias, module_name) + f_globals = self.output.global_scope + assert alias not in f_globals or f_globals[alias] is value + f_globals[alias] = value + self.output.update_co_names(alias) + return GlobalSource(alias) + + def resolve_name(self, name, package, level): + """ + Copied from the Cpython implementation of __import__ + Resolve a relative module name to an absolute one. + https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902 + """ + bits = package.rsplit(".", level - 1) + if len(bits) < level: + raise ImportError("attempted relative import beyond top-level package") + base = bits[0] + return f"{base}.{name}" if name else base + + def calc_package(self): + """ + Copied from the Cpython implementation of __import__ + https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090 + """ + package = self.f_globals.get("__package__") + spec = self.f_globals.get("__spec__") + if package is not None: + if spec is not None and package != spec.parent: + log.warning( + "__package__ != __spec__.parent (%r != %r)", + package, + spec.parent, + stacklevel=3, + ) + return package + elif spec is not None: + return spec.parent + else: + log.warning( + "can't resolve package from __spec__ or __package__, " + "falling back on __name__ and __path__", + stacklevel=3, + ) + package = self.f_globals["__name__"] + if "__path__" not in self.f_globals: + package = package.rpartition(".")[0] + return package + + def IMPORT_NAME(self, inst): + level, fromlist = self.popn(2) + level = level.as_python_constant() + fromlist = fromlist.as_python_constant() + module_name = inst.argval + + # Are we replaying? 
if so, load recorded module + recorded_name = ( + f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}" + ) + if recorded_name in self.f_globals: + value = self.f_globals[recorded_name] + source = GlobalSource(recorded_name) + else: + try: + value = __import__( + module_name, + fromlist=fromlist, + level=level, + globals=self.f_globals, + ) + except ImportError: + unimplemented_v2( + gb_type="Import failure", + context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}", + explanation="Failure when attempting to import.", + hints=[*graph_break_hints.USER_ERROR], + ) + + if level != 0: + pkg = self.calc_package() + module_name = self.resolve_name(module_name, pkg, level) + + # For __import__, when the name variable is of the form package.module, + # normally, the top-level package (the name up till the first dot) is + # returned, not the module named by module_name. However, when a + # non-empty fromlist argument is given, the module named by name is + # returned. Therefore, we set the source correctly here. + if not fromlist: + top_level_module_name = module_name.partition(".")[0] + source = self.import_source(top_level_module_name) + else: + source = self.import_source(module_name) + + if self.exec_recorder: + self.exec_recorder.add_local_mod(recorded_name, value) + + if istype(value, (types.ModuleType, DummyModule)): + self.push(PythonModuleVariable(value, source=source)) + else: + unimplemented_v2( + gb_type="Bad import result", + context=typestr(value), + explanation="Import result is not a Python module.", + hints=[], + ) + + # fb internal 3.12 opcode + EAGER_IMPORT_NAME = IMPORT_NAME + + def IMPORT_FROM(self, inst): + self.DUP_TOP(inst) + self._load_attr(inst) + + # Cache note: This cache only exists for the duration of this + # InstructionTranslator - so it should be safe to do. 
+ @cache_method + def load_builtin_from_argval(self, argval): + if argval not in self.f_builtins: + raise Unsupported(f"name '{argval}' is not defined") + val = self.f_builtins[argval] + + if callable(val): + builtins_source = GlobalSource( + self.output.name_of_builtins_dict_key_in_fglobals + ) + var_source = DictGetItemSource(builtins_source, argval) + return VariableTracker.build(self, val, var_source) + else: + assert is_builtin_constant(val) + return ConstantVariable.create(value=val) + + def load_builtin(self, inst): + self.push(self.load_builtin_from_argval(inst.argval)) + + def jump(self, inst): + assert self.instruction_pointer is not None + assert self.start_point is not None + get_metrics_context().increment( + "ir_count", self.instruction_pointer - self.start_point + ) + self.instruction_pointer = self.indexof[inst.target] + self.start_point = self.instruction_pointer + + JUMP_FORWARD = jump + JUMP_ABSOLUTE = jump + + POP_JUMP_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_IF_TRUE = generic_jump(operator.truth, False) + JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True) + JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True) + + def SETUP_LOOP(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def SETUP_EXCEPT(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def POP_BLOCK(self, inst): + self.block_stack.pop() + + def SETUP_WITH(self, inst): + self.setup_or_before_with(inst) + + def SETUP_FINALLY(self, inst): + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def BEGIN_FINALLY(self, inst): + self.push(None) + + def WITH_CLEANUP_START(self, inst): + exit, exc = self.popn(2) + assert exc is None + self.push(exc) + self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) + + def WITH_CLEANUP_FINISH(self, inst): + self.popn(2) + self.push(None) + + def FOR_ITER(self, inst): + it = self.pop().realize() + try: + val = it.next_variable(self) + self.push(it) + self.push(val) + except (StopIteration, exc.ObservedUserStopIteration) as e: + if isinstance(e, exc.ObservedUserStopIteration): + exc.handle_observed_exception(self) + + # leave iterator upon exhaustion in 3.12 + if sys.version_info >= (3, 12): + # CPython 3.12 actually jumps to the instruction after the END_FOR + # and performs the action of END_FOR as part of FOR_ITER. We jump + # to the END_FOR and run it, so we need to make sure 2 values are + # on the stack for it to pop. 
+ self.push(it) + self.push(ConstantVariable.create(None)) + self.jump(inst) + + def _create_exception_type(self, val): + if isinstance( + val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable) + ): + # Create the instance of the exception type + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 + val = val.call_function(self, [], {}) # type: ignore[arg-type] + return val + + def _raise_exception_variable(self, val) -> NoReturn: + # The user can raise an exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise exception instance - raise NotImplementedError("foo") + + # 1) when user raises exception type + val = self._create_exception_type(val) + + # Handle https://peps.python.org/pep-0479/ + # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this + if ( + is_generator(self.f_code) + and isinstance(val, variables.ExceptionVariable) + and val.exc_type is StopIteration + ): + val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type] + + # Save the exception in a global data structure + self.exn_vt_stack.set_current_exception(val) + + # 2) when user raises exception instance + if self._isinstance_exception(val): + observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined] + raise observed_exception_type(f"raised exception {val}") + unimplemented_v2( + gb_type="Failed to raise exception", + context=str(exc), + explanation="Attempted to raise a non-Exception type/value.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def RAISE_VARARGS(self, inst): + if inst.arg == 0: + if not len(self.exn_vt_stack): + msg = ConstantVariable("No active exception to reraise") + exc.raise_observed_exception(RuntimeError, self, args=[msg]) + + # re-raise the previous exception. Here CPython refers to the exception + # on top of the exception stack + assert len(self.exn_vt_stack) + val = self.exn_vt_stack[-1] + assert self._isinstance_exception(val), val + self._raise_exception_variable(val) + elif inst.arg == 1: + # raise TOS + val = self.stack[-1] + self._raise_exception_variable(val) + else: + # raise .. from ... + from_vt = self.pop() + val = self.pop() + try: + self._raise_exception_variable(val) + finally: + # Update __cause__/__suppress_context__ in the raised exception + curr_exc = self.exn_vt_stack.get_current_exception() + cause = self._create_exception_type(from_vt) + curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) + + def CLEANUP_THROW(self, inst): + # https://github.com/python/cpython/pull/96010 + tos = self.stack[-1] + assert isinstance(tos, ExceptionVariable) + if tos.exc_type is StopIteration: + unimplemented_v2( + gb_type="CLEANUP_THROW with StopIteration", + context="", + explanation="Received StopIteration when handling generator.throw/close. This is not supported.", + hints=[], + ) + else: + self.RERAISE(inst) + + def RERAISE(self, inst): + # https://docs.python.org/3/library/dis.html#opcode-RERAISE + # Re-raises the exception currently on top of the stack. If oparg is + # non-zero, pops an additional value from the stack which is used to + # set f_lasti of the current frame. + + if sys.version_info >= (3, 11): + # RERAISE is currently supported in a narrow case of `raise ... 
from None` + val = self.pop() + if inst.argval: + # RERAISE 1 + _ = self.pop() + self._raise_exception_variable(val) + else: + # RERAISE 0 + self.push(val) + self._raise_exception_variable(val) + else: + _exc = self.pop() + val = self.pop() + _tb = self.pop() + self._raise_exception_variable(val) + + def _isinstance_exception(self, val): + return isinstance( + val, + ( + variables.ExceptionVariable, + UserDefinedExceptionClassVariable, + UserDefinedExceptionObjectVariable, + ), + ) + + def WITH_EXCEPT_START(self, inst): + if sys.version_info >= (3, 11): + # At the top of the stack are 4 values: + # - TOP = exc_info() + # - SECOND = previous exception + # - THIRD: lasti of exception in exc_info() + # - FOURTH: the context.__exit__ bound method + # We call FOURTH(type(TOP), TOP, GetTraceback(TOP)). + # Then we push the __exit__ return value. + assert len(self.stack) >= 4 + fn = self.stack[-4] + val = self.stack[-1] + assert self._isinstance_exception(val) + typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined] + tb = ConstantVariable(None) + else: + assert len(self.stack) >= 7 + fn = self.stack[-7] + val = self.stack[-2] + assert self._isinstance_exception(val) + typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined] + tb = ConstantVariable(None) + + self.call_function(fn, [typ, val, tb], {}) + + def exception_handler(self, raised_exception): + observed_exn_gb_explanation = ( + "Dynamo found no exception handler at the top-level compiled function " + "when encountering an exception. Exception will propagate outside the compiled region." + ) + + if sys.version_info >= (3, 11): + exn_tab_entry = self.current_instruction.exn_tab_entry + if exn_tab_entry: + # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + + # 1) pop values from the stack until it matches the stack depth + # for the handler + while len(self.stack) > exn_tab_entry.depth: + self.pop() + + # 2) if 'lasti' is true, then push the offset that the exception was raised at + if exn_tab_entry.lasti: + self.push( + variables.ConstantVariable(self.current_instruction.offset) + ) + + # 3) push the exception to the stack + self.push(self.exn_vt_stack.get_current_exception()) + + # 4) jump to the handler + self.jump(exn_tab_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translator. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + unimplemented_v2( + gb_type="Observed exception", + context=str(raised_exception), + explanation=observed_exn_gb_explanation, + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) + raise raised_exception + else: + if len(self.block_stack): + # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455 + + block_stack_entry = self.block_stack.pop() + + while block_stack_entry.inst.opname == "EXCEPT_HANDLER": + # TODO(anijain2305) - This is not tested .. unable to create a testcase + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + self.popn(3) + self.exn_vt_stack.pop() + if len(self.block_stack) == 0: + # No handler found in this frame. Bubble the exception to the parent + # instruction translator. 
+ self.stack.clear() + if type(self) is InstructionTranslator: + unimplemented_v2( + gb_type="Observed exception (EXCEPT_HANDLER)", + context=str(raised_exception), + explanation=observed_exn_gb_explanation + + " This graph break is unexpected.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + raise raised_exception + block_stack_entry = self.block_stack.pop() + + exception_var = self.exn_vt_stack.get_current_exception() + self.exn_vt_stack.move_current_exception_to_stack() + + # 1) pop values from the stack until it matches the stack depth + # for the handler + while len(self.stack) > block_stack_entry.stack_index: + self.pop() + + # Push a dummy block stack entry of EXCEPT_HANDLER + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456 + except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0) + self.block_stack.append( + BlockStackEntry(except_handler_inst, None, len(self.stack)) + ) + + # Push old exception + if len(self.exn_vt_stack) >= 2: + old_exception = self.exn_vt_stack[-2] + + # Push the old exception on to stack - tb, value, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(old_exception) + self.push(variables.BuiltinVariable(old_exception.exc_type)) + else: + # Push empty exception tb, value, type + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + self.push(variables.ConstantVariable(None)) + + # Push new exception - tb, val, type + # Traceback is currently mapped to UnknownVariable + self.push(variables.UnknownVariable()) + self.push(exception_var) + self.push(variables.BuiltinVariable(exception_var.exc_type)) + + # Jump to target + self.jump(block_stack_entry) + else: + # No handler found. Bubble the exception to the parent + # instruction translator. We use special exception for this. + self.stack.clear() + if type(self) is InstructionTranslator: + unimplemented_v2( + gb_type="Observed exception", + context=str(raised_exception), + explanation=observed_exn_gb_explanation, + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) + raise raised_exception + + def PUSH_EXC_INFO(self, inst): + # https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO + # Pops a value from the stack. Pushes the current exception to the top + # of the stack. Pushes the value originally popped back to the stack. + # + # The behavior of this opcode in CPython is a bit different than what it + # is described. It pops a value from the stack, pushes the top of the + # exception stack to the interpreter stack and moves the + # "current exception" to the exception stack. 
+        #
+        # As an example, suppose the stack is in the following state:
+        #   + stack = [..., ConstantVariable(1), ConstantVariable(2)]
+        #   + current_exception = TypeError
+        #   + exception_stack = [ValueError]
+        #
+        # After PUSH_EXC_INFO is executed
+        #   + stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
+        #   + current_exception = None
+        #   + exception_stack = [ValueError, TypeError]
+
+        val = self.pop()
+        if len(self.exn_vt_stack) == 0:
+            prev_exc = ConstantVariable(None)
+        else:
+            prev_exc = self.exn_vt_stack[-1]
+        self.push(prev_exc)
+        self.push(val)
+        self.exn_vt_stack.move_current_exception_to_stack()
+
+    def POP_EXCEPT(self, inst):
+        if sys.version_info >= (3, 11):
+            _ = self.pop()
+            # This exception is handled and therefore we can clear the error indicator
+            assert len(self.exn_vt_stack)
+            self.exn_vt_stack.pop()
+        else:
+            assert len(self.block_stack) > 0
+            if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER":
+                raise AssertionError(
+                    "Bug in Dynamo tracing of exception handling. "
+                    "Top of the block stack is not EXCEPT_HANDLER."
+                )
+            self.block_stack.pop()
+
+            self.popn(3)
+
+            # This exception is handled and therefore we can clear the error indicator
+            assert len(self.exn_vt_stack)
+            self.exn_vt_stack.pop()
+
+    def check_if_exc_matches(self):
+        assert len(self.stack) >= 2
+        expected_exc_types = self.pop()
+        if sys.version_info >= (3, 11):
+            # CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop.
+            # This is the description from the disassembly doc
+            #
+            # Performs exception matching for ``except``. Tests whether the ``STACK[-2]``
+            # is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean
+            # result of the test.
+            exc_instance = self.stack[-1]
+        else:
+            # This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH
+            # There is no documentation but here is the code pointer that does 2 pops
+            # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
+            exc_instance = self.stack.pop()
+
+        # Users can check exception in 3 ways
+        # 1) except NotImplementedError --> BuiltinVariable
+        # 2) except CustomException --> UserDefinedExceptionClassVariable
+        # 3) except (NotImplementedError, AttributeError) -> TupleVariable
+
+        if not isinstance(
+            expected_exc_types,
+            (
+                BuiltinVariable,
+                TupleVariable,
+                UserDefinedExceptionClassVariable,
+                UserDefinedExceptionObjectVariable,
+            ),
+        ):
+            unimplemented_v2(
+                gb_type="Exception with bad expected type",
+                context=str(expected_exc_types),
+                explanation=f"`except ...` has unsupported type {expected_exc_types}.",
+                hints=[*graph_break_hints.USER_ERROR],
+            )
+
+        if sys.version_info >= (3, 11):
+            if not self._isinstance_exception(exc_instance):
+                unimplemented_v2(
+                    gb_type="Caught non-Exception value",
+                    context=str(exc_instance),
+                    explanation=f"Except expects to receive an object of Exception type but received {exc_instance}.",
+                    hints=[*graph_break_hints.USER_ERROR],
+                )
+
+        if isinstance(expected_exc_types, TupleVariable):
+            expected_types = expected_exc_types.items
+        else:
+            expected_types = [
+                expected_exc_types,
+            ]
+
+        for expected_type in expected_types:
+            if not isinstance(
+                expected_type,
+                (
+                    BuiltinVariable,
+                    UserDefinedExceptionObjectVariable,
+                    UserDefinedExceptionClassVariable,
+                ),
+            ):
+                unimplemented_v2(
+                    gb_type="Exception with non-type expectation",
+                    context=str(expected_type),
+                    explanation=f"`except ...` expects a non-type: {expected_type}.",
+                    hints=[*graph_break_hints.USER_ERROR],
+                )
+            if self._isinstance_exception(exc_instance) and issubclass(
exc_instance.exc_type, # type: ignore[attr-defined] + expected_type.fn, # type: ignore[attr-defined] + ): + return True + elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass( + exc_instance.fn, expected_type.fn + ): + return True + + return False + + def CHECK_EXC_MATCH(self, inst): + self.push(variables.ConstantVariable(self.check_if_exc_matches())) + + def JUMP_IF_NOT_EXC_MATCH(self, inst): + if not self.check_if_exc_matches(): + self.jump(inst) + + def COMPARE_OP(self, inst): + if inst.argval == "exception match": + self.CHECK_EXC_MATCH(inst) + else: + self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) + + def GET_ITER(self, inst): + self.call_function(BuiltinVariable(iter), [self.pop()], {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION(self, inst): + args = self.popn(inst.argval) + fn = self.pop() + self.call_function(fn, args, {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_EX(self, inst): + kwargsvars: VariableTracker + if inst.argval == 0: + kwargsvars = ConstDictVariable({}) + argsvars = self.pop() + elif inst.argval == 1: + kwargsvars = self.pop() + argsvars = self.pop() + else: + unimplemented_v2( + gb_type="Variadic function call with bad flags", + context=f"flags: {inst.argval}", + explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + if sys.version_info >= (3, 13): + # 3.13 swapped null and callable + null = self.pop() + assert isinstance(null, NullVariable) + + fn = self.pop() + + if sys.version_info >= (3, 11) and sys.version_info < (3, 13): + null = self.pop() + assert isinstance(null, NullVariable) + + if not isinstance( + argsvars, BaseListVariable + ) and argsvars.has_force_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) + + # Unpack for cases like fn(**obj) where obj is a map + if isinstance(kwargsvars, UserDefinedObjectVariable): + kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] + + if not isinstance(argsvars, BaseListVariable) or not isinstance( + kwargsvars, ConstDictVariable + ): + unimplemented_v2( + gb_type="Variadic function call with bad args/kwargs type", + context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", + explanation="Expected args to be a list and kwargs to be a dict", + hints=[*graph_break_hints.USER_ERROR], + ) + + # Map to a dictionary of str -> VariableTracker + kwargsvars = kwargsvars.keys_as_python_constant() + self.call_function(fn, argsvars.items, kwargsvars) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION_KW(self, inst): + argnames = self.pop() + args = self.popn(inst.argval) + fn = self.pop() + assert isinstance(argnames, TupleVariable) and argnames.is_python_constant() + argnames = argnames.as_python_constant() + args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :] + kwargs = dict(zip(argnames, kwargs_list)) + assert len(kwargs) == len(argnames) + self.call_function(fn, args, kwargs) + + def LOAD_METHOD_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = self.code_options["co_names"][arg] + if sys.version_info < (3, 11): + self._load_attr(dataclasses.replace(inst, argval=argval)) + else: + self.LOAD_METHOD(dataclasses.replace(inst, argval=argval)) + + def LOAD_ATTR_SUPER(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + arg = inst.argval[0] + argval = 
self.code_options["co_names"][arg] + self._load_attr(dataclasses.replace(inst, argval=argval)) + + def LOAD_METHOD(self, inst): + self._load_attr(inst) + obj = self.pop() + if sys.version_info >= (3, 13): + self.push(obj) + self.PUSH_NULL(inst) + elif sys.version_info >= (3, 11): + # always follow the NULL + fn convention, since if obj + # is actually a method, self is already bound to it, so it + # doesn't need to be passed in as an arg. + self.PUSH_NULL(inst) + self.push(obj) + else: + self.push(obj) + self.push(None) + + def CALL_METHOD(self, inst): + args = self.popn(inst.argval) + dummy = self.pop() + assert dummy is None + fn = self.pop() + self.call_function(fn, args, {}) + + def _load_attr(self, inst): + obj = self.pop() + result = BuiltinVariable(getattr).call_function( + self, # type: ignore[arg-type] + [obj, ConstantVariable.create(inst.argval)], + {}, + ) + self.push(result) + + def LOAD_ATTR(self, inst): + if sys.version_info >= (3, 12): + if inst.arg % 2: + self.LOAD_METHOD(inst) + return + self._load_attr(inst) + + def STORE_ATTR(self, inst): + speculation = self.speculate() + if speculation.failed: + return self.store_attr_graph_break(inst) + val, obj = self.popn(2) + + if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable): + # We don't allow side effects during export on non-constant values + # https://github.com/pytorch/torchdynamo/issues/1475 + assert not self.export, ( + f"Mutating module attribute {inst.argval} during export." + ) + + try: + BuiltinVariable(setattr).call_function( + self, # type: ignore[arg-type] + [obj, ConstantVariable.create(inst.argval), val], + {}, + ) + return + except Unsupported as e: + if not self.should_compile_partial_graph(): + raise + log.debug("STORE_ATTR triggered compile", exc_info=True) + e.remove_from_stats() + e.add_to_stats("graph_break") + speculation.fail_and_restart_analysis() + + def store_attr_graph_break(self, inst): + log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break") + if not self.should_compile_partial_graph(): + unimplemented_v2( + gb_type="Should not compile partial graph (STORE_ATTR)", + context="", + explanation="Dynamo has determined when encountering an unsupported " + "STORE_ATTR instruction (i.e. 
`obj.attr = val`) that it should not compile the partial graph.", + hints=[], + ) + all_stack_locals_metadata = self.output.compile_subgraph( + self, + reason=GraphCompileReason("store_attr", [self.frame_summary()]), + stack_pops=2, + ) + self.output.add_output_instructions([copy.copy(inst)]) + self.popn(2) + self.output.add_output_instructions( + self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata) + ) + + def DELETE_ATTR(self, inst): + obj = self.pop() + BuiltinVariable(delattr).call_function( + self, # type: ignore[arg-type] + [obj, ConstantVariable.create(inst.argval)], + {}, + ) + + def create_call_resume_at(self, offset, all_stack_locals_metadata): + raise AssertionError( + f"create_call_resume_at not overridden by subclass {type(self)}" + ) + + def should_compile_partial_graph(self) -> bool: + raise AssertionError( + f"should_compile_partial_graph not overridden by subclass {type(self)}" + ) + + @break_graph_if_unsupported(push=0) + def STORE_SUBSCR(self, inst): + val, obj, key = self.popn(3) + obj.call_method(self, "__setitem__", [key, val], {}) + + def DELETE_SUBSCR(self, inst): + obj, key = self.popn(2) + obj.call_method(self, "__delitem__", [key], {}) + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + self.push(TupleVariable(items)) + + def BUILD_SLICE(self, inst): + items = self.popn(inst.argval) + self.push(SliceVariable(items)) + + def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(ListVariable(items, mutation_type=ValueMutationNew())) + + def BUILD_SET(self, inst): + if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: + unimplemented_v2( + gb_type="missing BUILD_SET handler", + context="", + explanation="Missing BUILD_SET bytecode handler (for testing purposes).", + hints=[], + ) + items = self.popn(inst.argval) + new_set = SetVariable(items, mutation_type=ValueMutationNew()) + self.push(new_set) + + def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): + seqs = self.popn(inst.argval) + items = [] + for seq in seqs: + try: + items.extend(seq.force_unpack_var_sequence(self)) + except NotImplementedError: + unimplemented_v2( + gb_type="Failed to unpack object for BUILD_LIST_UNPACK", + context=str(seq), + explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK " + "bytecode (`[*x, *y, ...]`).", + hints=[*graph_break_hints.USER_ERROR], + ) + self.push(cls(items, mutation_type=ValueMutationNew())) + + def BUILD_TUPLE_UNPACK(self, inst): + self.BUILD_LIST_UNPACK(inst, cls=TupleVariable) + + BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK + + def BUILD_MAP(self, inst): + items = self.popn(inst.argval * 2) + d = dict(zip(items[::2], items[1::2])) + self.push(ConstDictVariable(d, mutation_type=ValueMutationNew())) + + def BUILD_MAP_UNPACK(self, inst): + items = self.popn(inst.argval) + # ensure everything is a dict + items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] + result = {} + for x in items: + assert isinstance(x, ConstDictVariable) + result.update(x.items) + self.push( + ConstDictVariable( + result, + mutation_type=ValueMutationNew(), + ) + ) + + BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK + + def BUILD_CONST_KEY_MAP(self, inst): + keys = self.pop() + values = self.popn(inst.argval) + assert isinstance(keys, TupleVariable) + assert keys.is_python_constant() + + keys = keys.force_unpack_var_sequence(self) + assert len(keys) == len(values) + + self.push( + ConstDictVariable( + dict(zip(keys, values)), + mutation_type=ValueMutationNew(), + ) 
+ ) + + def MAP_ADD(self, inst): + k, v = self.popn(2) + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type] + + def SET_ADD(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.is_mutable() + return obj.call_method(self, "add", [v], {}) + + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.is_mutable() + obj.call_method(self, "update", [v], {}) + + def LIST_APPEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ListVariable) + assert obj.is_mutable() + self.output.side_effects.mutation(obj) + obj.items.append(v) + + def MAKE_FUNCTION(self, inst): + flags = inst.arg + if sys.version_info < (3, 11): + fn_name = self.pop() + code = self.pop() + if sys.version_info >= (3, 11): + # MAKE_FUNCTION behavior actually changed in 3.11, see + # https://github.com/python/cpython/pull/93189/ + assert hasattr(code.value, "co_qualname") # type: ignore[attr-defined] + fn_name = ConstantVariable.create(value=code.value.co_qualname) # type: ignore[attr-defined] + defaults = None + closure = None + annotations = None + kwdefaults = None + + if sys.version_info < (3, 13): + # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE + if flags & 0x08: + closure = self.pop() + if flags & 0x04: + annotations = self.pop() + if flags & 0x02: + kwdefaults = self.pop() + if flags & 0x01: + defaults = self.pop() + + self.push( + NestedUserFunctionVariable( + fn_name, + code, + self.f_globals, + defaults, + kwdefaults, + annotations, + closure, + ) + ) + + def UNPACK_SEQUENCE(self, inst): + seq = self.pop() + if isinstance(seq, TensorVariable): + val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) # type: ignore[arg-type] + elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): + # x, y = a.shape + proxy = getattr(seq.obj.as_proxy(), seq.name) + val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] + elif seq.has_force_unpack_var_sequence(self): + val = seq.force_unpack_var_sequence(self) + else: + unimplemented_v2( + gb_type="Failed to unpack object for UNPACK_SEQUENCE", + context=str(seq), + explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode " + "(i.e. `a, b, c = d`).", + hints=[*graph_break_hints.USER_ERROR], + ) + if len(val) != inst.argval: + unimplemented_v2( + gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE", + context=f"expected length: {inst.argval}, actual: {len(val)}", + explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode " + "(i.e. 
`a, b, c = d`) with unexpected length.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + for i in reversed(val): + self.push(i) + + def UNPACK_EX(self, inst): + assert 0 <= inst.argval <= 0xFFFF + prefix = inst.argval & 0xFF # low byte + suffix = inst.argval >> 8 # high byte + seq = self.pop() + if seq.has_force_unpack_var_sequence(self): + vals = list(seq.force_unpack_var_sequence(self)) + assert len(vals) >= prefix + suffix + vals_prefix = vals[:prefix] + vals_list = vals[prefix : len(vals) - suffix] + vals_suffix = vals[len(vals) - suffix :] + for item in reversed(vals_suffix): + self.push(item) + self.push(TupleVariable(vals_list)) + for item in reversed(vals_prefix): + self.push(item) + else: + unimplemented_v2( + gb_type="Failed to unpack object for UNPACK_EX", + context=str(seq), + explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def NOP(self, inst): + pass + + def POP_TOP(self, inst): + self.pop() + + def ROT_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(a) + self.push(b) + + def ROT_THREE(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + self.push(a) + self.push(c) + self.push(b) + + def ROT_FOUR(self, inst): + a = self.pop() + b = self.pop() + c = self.pop() + d = self.pop() + self.push(a) + self.push(d) + self.push(c) + self.push(b) + + def DUP_TOP(self, inst): + a = self.pop() + self.push(a) + self.push(a) + + def DUP_TOP_TWO(self, inst): + a = self.pop() + b = self.pop() + self.push(b) + self.push(a) + self.push(b) + self.push(a) + + def _convert_value(self, value, flag): + if flag == 1: + return BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type] + elif flag == 2: + return BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type] + elif flag == 3: + return BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type] + return value + + def _format_value(self, fmt_spec, flags): + value = self.pop() + if isinstance(value, SymNodeVariable): + from torch._dynamo.variables.lazy import ( + LazySymNodeFormatString, + LazyVariableTracker, + ) + + value = LazyVariableTracker.create( + LazySymNodeFormatString(value, fmt_spec), source=value.source + ) + self.push(value) + return + + value = self._convert_value(value, flags & 0x03) + + fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") + + self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) + + def FORMAT_VALUE(self, inst): + flags = inst.arg + if (flags & 0x04) == 0x04: + fmt_spec = self.pop() + else: + fmt_spec = ConstantVariable.create("") + + return self._format_value(fmt_spec, flags) + + def BUILD_STRING(self, inst): + format_string_parts: list[str] = [] + args: list[VariableTracker] = [] + kwargs: dict[str, VariableTracker] = {} + for part in self.popn(inst.arg): + if isinstance(part, ConstantVariable): + format_string_parts.append("{}") + args.append(part) + elif isinstance(part, variables.StringFormatVariable): + format_string_parts.append(part.format_string) + args.extend(part.sym_args) + if set(kwargs.keys()) & set(part.sym_kwargs.keys()): + unimplemented_v2( + gb_type="BUILD_STRING key conflict", + context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}", + explanation="Failed to build format string due to key conflict", + hints=[*graph_break_hints.USER_ERROR], + ) + kwargs.update(part.sym_kwargs) + else: + unimplemented_v2( + gb_type="BUILD_STRING type 
error", + context=str(part), + explanation="Format string part type is not correct - expected constant or format string.", + hints=[*graph_break_hints.USER_ERROR], + ) + self.push( + variables.StringFormatVariable.create( + "".join(format_string_parts), args, kwargs + ) + ) + + def IS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + if inst.argval == 0: + new_argval = "is" + else: + new_argval = "is not" + new_inst = create_instruction("COMPARE_OP", argval=new_argval) + self.COMPARE_OP(new_inst) + + def CONTAINS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + left, right = self.popn(2) + op = inst.argval + self.push(right.call_method(self, "__contains__", [left], {})) + if op == 1: + self.UNARY_NOT(inst) + + def LIST_EXTEND(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, ListVariable) + assert obj.is_mutable() + obj.call_method(self, "extend", [v], {}) + + def LIST_TO_TUPLE(self, inst): + self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] + + def STOPITERATION_ERROR(self, inst): + # wrap the generator body in a try: ... except StopIteration: ... which + # converts the StopIteration into a RuntimeError + # https://peps.python.org/pep-0479/ + # https://github.com/python/cpython/pull/99006 + # https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2 + val = self.stack[-1] + assert self._isinstance_exception(val) + if val.exc_type is StopIteration: # type: ignore[attr-defined] + new_val = variables.BuiltinVariable(RuntimeError).call_function( + self, # type: ignore[arg-type] + [ConstantVariable("generator raised StopIteration")], + {}, + ) + new_val.call_setattr(self, ConstantVariable("__context__"), val) # type: ignore[attr-defined] + new_val.call_setattr(self, ConstantVariable("__cause__"), val) # type: ignore[attr-defined] + self.stack[-1] = new_val + + def DICT_MERGE(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg].realize() + assert isinstance(obj, ConstDictVariable) + assert obj.is_mutable() + obj.call_method(self, "update", [v], {}) + + DICT_UPDATE = DICT_MERGE + + def GEN_START(self, inst): + self.pop() + + def GET_LEN(self, inst): + tos = self.stack[-1] + if tos.is_python_constant(): + self.push(ConstantVariable.create(len(tos.as_python_constant()))) + else: + self.push(tos.call_method(self, "__len__", [], {})) + + def MATCH_MAPPING(self, inst): + tos = self.stack[-1] + assert isinstance(tos, ConstDictVariable) + if isinstance(tos.items, collections.abc.Mapping): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_SEQUENCE(self, inst): + tos = self.stack[-1] + assert tos.is_python_constant() + tos_value = tos.as_python_constant() + if isinstance(tos_value, collections.abc.Sequence) and not isinstance( + tos_value, (str, bytes, bytearray) + ): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(False)) + + def MATCH_KEYS(self, inst): + tos = self.stack[-1] + tos1 = self.stack[-2] + assert isinstance(tos1, ConstDictVariable) + + if all(k in tos1 for k in tos): # type: ignore[attr-defined] + self.push(TupleVariable([tos1.getitem_const(self, k) for k in tos])) # type: ignore[attr-defined,arg-type] + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(True)) + else: + self.push(ConstantVariable.create(None)) + if sys.version_info < (3, 11): + self.push(ConstantVariable.create(False)) + + 
def LOAD_ASSERTION_ERROR(self, inst): + self.push(self.load_builtin_from_argval("AssertionError")) + + def LOAD_BUILD_CLASS(self, inst): + unimplemented_v2( + gb_type="LOAD_BUILD_CLASS bytecode not supported", + context="", + explanation="Dynamo does not support tracing classes that are defined in the compiled region.", + hints=[ + "Move the class definition out of the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + UNARY_POSITIVE = stack_op(operator.pos) + UNARY_NEGATIVE = stack_op(operator.neg) + UNARY_NOT = stack_op(operator.not_) + UNARY_INVERT = stack_op(operator.invert) + + BINARY_POWER = stack_op(operator.pow) + BINARY_MULTIPLY = stack_op(operator.mul) + BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul) + BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv) + BINARY_TRUE_DIVIDE = stack_op(operator.truediv) + BINARY_MODULO = stack_op(operator.mod) + BINARY_REMAINDER = stack_op(operator.mod) + BINARY_ADD = stack_op(operator.add) + BINARY_SUBTRACT = stack_op(operator.sub) + BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem)) + BINARY_LSHIFT = stack_op(operator.lshift) + BINARY_RSHIFT = stack_op(operator.rshift) + BINARY_AND = stack_op(operator.and_) + BINARY_OR = stack_op(operator.or_) + BINARY_XOR = stack_op(operator.xor) + + INPLACE_POWER = stack_op(operator.ipow) + INPLACE_MULTIPLY = stack_op(operator.imul) + INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul) + INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv) + INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv) + INPLACE_MODULO = stack_op(operator.imod) + INPLACE_REMAINDER = stack_op(operator.imod) + INPLACE_ADD = stack_op(operator.iadd) + INPLACE_SUBTRACT = stack_op(operator.isub) + INPLACE_LSHIFT = stack_op(operator.ilshift) + INPLACE_RSHIFT = stack_op(operator.irshift) + INPLACE_AND = stack_op(operator.iand) + INPLACE_XOR = stack_op(operator.ixor) + INPLACE_OR = stack_op(operator.ior) + + # 3.11 opcodes + def RESUME(self, inst): + if inst.arg == 0: + self.append_prefix_inst(inst) + self.accept_prefix_inst = False + else: + assert not self.accept_prefix_inst + + if sys.version_info >= (3, 11): + + def BINARY_OP(self, inst): + return _binary_op_lookup[inst.arg](self, inst) + + def PRECALL(self, inst): + pass + + def KW_NAMES(self, inst): + kw_names = self.code_options["co_consts"][inst.arg] + assert isinstance(kw_names, tuple) + for name in kw_names: + assert isinstance(name, str) + assert self.kw_names is None + self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment] + + def PUSH_NULL(self, inst): + self.push(NullVariable()) + + def _call(self, inst, call_kw=False): + # see https://docs.python.org/3.11/library/dis.html#opcode-CALL + # for convention + if call_kw: + # TOS is kw_names for CALL_KW instruction + assert sys.version_info >= (3, 13) + kw_names = self.pop() + assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant() + kw_names = kw_names.as_python_constant() + else: + kw_names = self.kw_names.value if self.kw_names else () + + contents = self.popn(inst.arg + 2) + if sys.version_info >= (3, 13): + # NULL and callable swapped + fn = contents[0] + args = [] if isinstance(contents[1], NullVariable) else [contents[1]] + else: + if isinstance(contents[0], NullVariable): + fn = contents[1] + args = [] + else: + fn = contents[0] + args = [contents[1]] + + if kw_names: + args = args + contents[2 : -len(kw_names)] + kwargs_list = contents[-len(kw_names) :] + kwargs = dict(zip(kw_names, kwargs_list)) + assert len(kwargs) == len(kw_names) + 
else: + args = args + contents[2:] + kwargs = {} + + try: + # if call_function fails, need to set kw_names to None, otherwise + # a subsequent call may have self.kw_names set to an old value + self.call_function(fn, args, kwargs) + finally: + self.kw_names = None + + @break_graph_if_unsupported(push=1) + def CALL(self, inst): + self._call(inst) + + def COPY(self, inst): + self.push(self.stack[-inst.arg]) + + def SWAP(self, inst): + self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] + + JUMP_BACKWARD = jump + JUMP_BACKWARD_NO_INTERRUPT = jump + + POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False) + POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) + POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) + + def CACHE(self, inst): + pass + + def BEFORE_WITH(self, inst): + self.setup_or_before_with(inst) + + def setup_or_before_with(self, inst): + ctx = self.pop() + if not isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ): + unimplemented_v2( + gb_type="Unsupported context manager", + context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}", + explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.", + hints=[ + "Avoid using the unsupported context manager.", + "If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then " + "it may be the case that it was created outside the compiled region, which Dynamo does not support. " + "Supported context managers can cross graph break boundaries only if they are local non-closure " + "variables, or are intermediate values.", + "File an issue to PyTorch. Simple context managers can potentially be supported, " + "but note that context managers can't be supported in general", + ], + ) + + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): + self.active_generic_context_managers.append(ctx) + + # Need this redundant check for mypy + assert isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ) + + exit = WithExitFunctionVariable( + ctx, + inst.target, + ) + + if sys.version_info >= (3, 11): + # See create_call_resume_at for block stack details. + # Only push a block if the current instruction's block is a + # with block that is not nested in a try block - that is, the current + # instruction's block target is the same as the top block's target. + if inst.exn_tab_entry and ( + not self.block_stack + or inst.exn_tab_entry.target is not self.block_stack[-1].target + ): + target = None + else: + target = self.next_instruction.exn_tab_entry.target + else: + target = inst.target + + self.push(exit) + + if target: + if isinstance(self, InstructionTranslator): + self.block_stack.append( + BlockStackEntry(inst, target, len(self.stack), ctx) + ) + else: + self.block_stack.append(BlockStackEntry(inst, target, len(self.stack))) + + self.push(ctx.enter(self)) + + def append_prefix_inst(self, inst): + assert self.accept_prefix_inst + self.prefix_insts.append(inst) + + def MAKE_CELL(self, inst): + if sys.version_info >= (3, 12) and not self.accept_prefix_inst: + # In 3.12+, MAKE_CELL is not longer necessarily a prefix instruction. + # It can be generated by inlined comprehensions. 
+ assert isinstance(self.symbolic_locals[inst.argval], NullVariable) + self.symbolic_locals[inst.argval] = ( + self.output.side_effects.track_cell_new() + ) + else: + self.append_prefix_inst(inst) + + def COPY_FREE_VARS(self, inst): + self.append_prefix_inst(inst) + + def RETURN_GENERATOR(self, inst): + self.append_prefix_inst(inst) + + # 3.12 opcodes + # BINARY/STORE_SLICE opcodes are broken down into + # BUILD_SLICE 2 and BINARY/STORE_SUBSCR + + def END_FOR(self, inst): + if sys.version_info >= (3, 13): + self.pop() + else: + self.popn(2) + + def LOAD_FAST_CHECK(self, inst): + if isinstance(self.symbolic_locals.get(inst.argval, None), NullVariable): + unimplemented_v2( + gb_type="LOAD_FAST_CHECK on uninitialized variable", + context=inst.argval, + explanation=f"Attempted to load uninitialized local variable {inst.argval}", + hints=[*graph_break_hints.USER_ERROR], + ) + self.LOAD_FAST(inst) + + def LOAD_FAST_AND_CLEAR(self, inst): + if inst.argval not in self.symbolic_locals: + self.push(NullVariable()) + else: + self.LOAD_FAST(inst) + self.symbolic_locals[inst.argval] = NullVariable() + + def LOAD_SUPER_ATTR(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) + if inst.arg & 1: + self.LOAD_METHOD(inst) + else: + self._load_attr(inst) + + def CALL_INTRINSIC_1(self, inst): + if inst.argval == 3: + # INTRINSIC_STOPITERATION_ERROR + self.STOPITERATION_ERROR(inst) + elif inst.argval == 5: + # INTRINSIC_UNARY_POSITIVE + self.UNARY_POSITIVE(inst) + elif inst.argval == 6: + # INTRINSIC_LIST_TO_TUPLE + self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) + else: + unimplemented_v2( + gb_type="Missing CALL_INTRINSIC_1 handler", + context=f"CALL_INTRINSIC_1 operand: {inst.argval}", + explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def END_SEND(self, inst): + tos = self.pop() + self.pop() + self.push(tos) + + # 3.13 opcodes + # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST + # are broken down. + @break_graph_if_unsupported(push=1) + def CALL_KW(self, inst): + self._call(inst, call_kw=True) + + def TO_BOOL(self, inst): + # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython) + # So we can skip this instruction as long as we remember to codegen a TO_BOOL + # before conditional jumps/UNARY_NOT. 
+ assert self.next_instruction.opname in ( + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + "UNARY_NOT", + ) + + def SET_FUNCTION_ATTRIBUTE(self, inst): + flags = inst.arg + fn = self.pop() + assert isinstance(fn, NestedUserFunctionVariable) + attr = self.pop() + + if flags & 0x08: + fn.closure = attr + elif flags & 0x04: + fn.annotations = attr + elif flags & 0x02: + fn.kwdefaults = attr + elif flags & 0x01: + fn.defaults = attr + + self.push(fn) + + def CONVERT_VALUE(self, inst): + self.push(self._convert_value(self.pop(), inst.argval)) + + def FORMAT_SIMPLE(self, inst): + self._format_value(ConstantVariable.create(""), 0) + + def FORMAT_WITH_SPEC(self, inst): + self._format_value(self.pop(), 0) + + def is_non_empty_graph(self): + if self.output.count_calls() > 1: + # perf optimization only + self.is_non_empty_graph = lambda: True # type: ignore[method-assign] + return True + return False + + def format_frame_summary(self, additional_stack_frames=None): + if additional_stack_frames is None: + additional_stack_frames = [] + return "".join( + traceback.format_list( + [self.frame_summary()] + list(reversed(additional_stack_frames)) + ) + ) + + def frame_summary(self): + return traceback.FrameSummary( + getattr(self.f_code, "co_filename", ""), + self.lineno, + getattr(self.f_code, "co_name", ""), + lookup_line=False, + ) + + def is_co_filename_from_nn_modules(self): + filename = getattr(self.f_code, "co_filename", "") + nn_modules_pattern = re.compile(r".*torch/nn/modules.*") + return nn_modules_pattern.match(filename) is not None + + def store_global_weakref_by_id(self, prefix, value): + global_name = self.output.install_global_by_id(prefix, weakref.ref(value)) + install_guard( + GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE) + ) + return global_name + + @property + def fake_mode(self): + return self.output.tracing_context.fake_mode + + @contextlib.contextmanager + def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]): + """ + Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node). + """ + prior = self.strict_checks_fn + self.strict_checks_fn = check_fn + try: + yield + finally: + self.strict_checks_fn = prior + + def speculate(self) -> SpeculationEntry: + assert self.instruction_pointer is not None + assert self.instruction_pointer > 0 + return self.speculation_log.next( + self.f_code.co_filename, + self.lineno, + self.instruction_pointer - 1, + self.instructions[self.instruction_pointer - 1], + ) + + def __init__( + self, + output: OutputGraph, + instructions: list[Instruction], + f_locals: dict[str, Any], + f_globals: dict[str, Any], + f_builtins: dict[str, Any], + code_options: dict[str, Any], + symbolic_locals: dict[str, VariableTracker], + symbolic_globals: dict[str, VariableTracker], + symbolic_torch_function_state: SymbolicTorchFunctionState, + f_code: types.CodeType, + export: bool, + inline_depth: int, + speculation_log: SpeculationLog, + exn_vt_stack: ExceptionStack, + distributed_state: Optional[DistributedState], + # This determines whether to use the execution recorder. 
+ closure: Optional[tuple[types.CellType]] = None, + package: Optional["CompilePackage"] = None, + ) -> None: + super().__init__() + self.speculation_log = speculation_log + self.distributed_state = distributed_state + + # Mutable state checkpointed by copy_graphstate() + self.output = output + self.symbolic_locals = symbolic_locals + self.symbolic_globals = symbolic_globals + self.symbolic_torch_function_state = symbolic_torch_function_state + self.stack = [] + self.instruction_pointer = 0 + self.start_point = None + self.current_instruction = create_instruction("NOP") + self.block_stack = [] + # states before SETUP_WITH for checkpointing and fallback + self.active_generic_context_managers: list[GenericContextWrappingVariable] = [] + self.lineno = -1 + self.kw_names = None + self.accept_prefix_inst = True + self.prefix_insts = [] + self.exn_vt_stack = exn_vt_stack + + # Properties of the input/output code + self.instructions: list[Instruction] = instructions + self.indexof: dict[Instruction, int] = get_indexof(self.instructions) + self.f_locals: dict[str, Any] = ( + f_locals # needed for recording accessed locals for replay + ) + self.f_globals: dict[str, Any] = f_globals + self.f_builtins: dict[str, Any] = f_builtins + self.code_options: dict[str, Any] = code_options + self.f_code: types.CodeType = f_code + + # Execution record for replaying errors + if closure is not None and config.replay_record_enabled: + self.exec_recorder = ExecutionRecorder( + code=f_code, closure=closure, code_options=code_options + ) + else: + self.exec_recorder = None + # Stack of module being parsed, current nn.module is at the end of ordered dict. + # The first field of tuple is the fully qualified name of current module + # in original hierarchy. The second field is the type of current nn.module + self.nn_module_stack: dict[str, tuple[str, type[Any]]] = {} + self.num_calls: dict[str, int] = {} + # Flag to indicate whether tracing is used for export. 
+ self.export = export + self.one_graph = False + + self.current_speculation = None + + self.strict_checks_fn = None + + self.is_leaf_tracer = True + self.parent = None + self.debug_locals = [] + + self.package = package + + if sys.version_info >= (3, 10): + from .resume_execution import ( + CO_ASYNC_GENERATOR, + CO_COROUTINE, + CO_GENERATOR, + CO_ITERABLE_COROUTINE, + ) + + if f_code.co_flags & ( + CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR + ): + self.push(BuiltinVariable(None)) + + self.inline_depth = inline_depth + self.inconsistent_side_effects = False + self._constants_cache: list[Optional[VariableTracker]] = [None] * len( + f_code.co_consts + ) + + self.is_trace_bytecode_log_enabled: Optional[bool] = ( + trace_bytecode_log.isEnabledFor(logging.DEBUG) + ) + self.is_trace_source_log_enabled: Optional[bool] = ( + trace_source_log.isEnabledFor(logging.DEBUG) + ) + linecache.lazycache(f_code.co_filename, f_globals) + + +class InstructionTranslator(InstructionTranslatorBase): + @staticmethod + def current_tx() -> "InstructionTranslator": + return tls.current_tx + + @contextlib.contextmanager + def set_current_tx(self): + prior = getattr(tls, "current_tx", None) + tls.current_tx = self + try: + yield + finally: + tls.current_tx = prior + + def __init__( + self, + instructions: list[Instruction], + f_code, + f_locals, + f_globals, + f_builtins, + closure, + torch_function_mode_stack, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + frame_state, + speculation_log: SpeculationLog, + exn_vt_stack: ExceptionStack, + distributed_state: Optional[DistributedState], + package: Optional["CompilePackage"], + ) -> None: + _step_logger()( + logging.INFO, + f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}", + ) + super().__init__( + output=OutputGraph( + code_options, + compiler_fn, + self, + export, + export_constraints, + frame_state, + local_scope=f_locals, + global_scope=f_globals, + f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, + package=package, + ), + instructions=instructions, + f_locals=f_locals, + f_globals=f_globals, + f_builtins=f_builtins, + closure=closure, + code_options=code_options, + symbolic_locals={}, # set below + # A global var is inserted only after a STORE_GLOBAL happens to it + symbolic_globals={}, + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below + f_code=f_code, + export=export, + inline_depth=0, + speculation_log=speculation_log, + exn_vt_stack=exn_vt_stack, + distributed_state=distributed_state, + package=package, + ) + + self._throw_if_in_functorch() + + # as soon as we create the tracing context we should keep it active, so any calls + # into dynamo apis can rely on finding it + with tracing(self.output.tracing_context), self.set_current_tx(): + self.one_graph: bool = one_graph + self.export = export + if self.export: + assert self.one_graph, ( + "Export without one graph - something has gone wrong." + ) + + self.symbolic_locals = {} + # Populate `symbolic_locals` with non-cell variables. 
+ cell_and_freevars: set[str] = set(self.cell_and_freevars()) + + dynamism = code_context.get_context(f_code).get("dynamism", None) + for name, value in f_locals.items(): + if name not in cell_and_freevars: + local_dynamism = None + if dynamism: + local_dynamism = frozenset(dynamism.get(name, {}).items()) + var = LazyVariableTracker.create( + value, + LocalSource( + name, + is_input=True, + dynamism=local_dynamism, + ), + ) + self.symbolic_locals[name] = var + + # Populate `symbolic_locals` with cells created by this frame, + # effectively implementing the `MAKE_CELL` instructions. + side_effects = self.output.side_effects + for name in self.cellvars(): + if name in f_locals: + # This models cells that are also function inputs. + value = f_locals[name] + # NOTE: root frame inputs that are captured by a nested + # function become special cell objects -- they exist in + # `f_locals` as contents of the cells, rather than the cells + # objects themselves. + # + # In Dynamo, we choose to represent such input cell objects + # as newly created (rather than pre-existing) cell objects, + # because + # + # 1. The reason for representing a pre-existing cell object + # is to emit guard or codegen mutations. However, local + # cells should never be used for guards. Moreover, at this + # point these input cell objects should've never been + # accessed by anyone else, since Dynamo intercepts the frame + # right after its evaluation starts, i.e., right after these + # cell objects are created. So they should have no external + # reference, meaning no mutation needs to be propagated. + # + # 2. This conveniently allows codegen to prune away + # mutations to these cells, unless they escape the frame. + contents_source = LocalSource( + name, is_input=True, is_derefed_cell_contents=True + ) + contents_var: VariableTracker = LazyVariableTracker.create( + value, contents_source + ) + cell_var = side_effects.track_cell_new() + side_effects.store_cell(cell_var, contents_var) + else: + cell_var = side_effects.track_cell_new() + cell_var.local_name = name + self.symbolic_locals[name] = cell_var + + # Populate `symbolic_locals` with cells captured by this frame, + # effectively implementing the `COPY_FREE_VARS` instruction. 
+ for name, cell in zip(self.freevars(), closure): + cell_source = LocalCellSource(name) + contents_source = LocalSource(name, is_derefed_cell_contents=True) + try: + contents_var = LazyVariableTracker.create( + cell.cell_contents, contents_source + ) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing( + cell_source, cell, contents_var + ) + cell_var.local_name = name + self.symbolic_locals[name] = cell_var + + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) + + if export: + # export gets confused if we never realize unused inputs + # in export mode just eagerly realize everything + self.symbolic_locals = variables.LazyVariableTracker.realize_all( + self.symbolic_locals + ) + + def _throw_if_in_functorch(self): + # Fallback to eager in case of a graph break inside vmap + eager = torch._dynamo.lookup_backend("eager") + compiler_fn = inspect.getattr_static( + self.output.compiler_fn, "compiler_fn", self.output.compiler_fn + ) + ci = torch._C._functorch.peek_interpreter_stack() + forbidden_keys = ( + torch._C._functorch.TransformType.Vmap, + torch._C._functorch.TransformType.Grad, + torch._C._functorch.TransformType.Jvp, + ) + + if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager: + name = ci.key().name.lower() + msg = ( + "If you are reaching here, it means dynamo failed for one of the following reasons:\n" + # Calling a torch.compiled function + f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. " + f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. " + "For more information, see PyTorch issue #128711.\n" + # if it reaches here, it means Dynamo failed to inline a functorch function + f"- torch.func.{name}(fn) requires the function to be inlined by dynamo" + ) + unimplemented_v2( + gb_type="Unsupported functorch tracing attempt", + context="", + explanation=msg, + hints=[], + ) + + def get_example_value(self, source: Source): + if isinstance(source, LocalSource): + return self.f_locals[source.local_name] + if isinstance(source, GlobalSource): + return self.f_globals[source.global_name] + raise KeyError + + def run(self): + super().run() + + def should_compile_partial_graph(self): + if sys.version_info >= (3, 11): + # Do not compile if current instruction's block is not the top with block + entry = self.current_instruction.exn_tab_entry + if entry and ( + not self.block_stack or entry.target is not self.block_stack[-1].target + ): + return False + return ( + all(b.can_restore() for b in self.block_stack) + and not self.one_graph + and not self.active_generic_context_managers + ) + + def create_call_resume_at(self, inst, all_stack_locals_metadata): + self.instruction_pointer = None + + if inst.opname == "RETURN_VALUE": + return [create_instruction("RETURN_VALUE")] + elif inst.opname == "RETURN_CONST": + return [create_instruction("RETURN_CONST", argval=inst.argval)] + + reads = livevars_analysis(self.instructions, inst) + all_argnames = tuple( + k + for k in self.symbolic_locals.keys() + if k in reads and k not in self.cell_and_freevars() + ) + # NOTE: do not use isinstance, since it realizes lazy VT's + argnames_null_set = set(all_stack_locals_metadata[0].locals_null_keys) + argnames = tuple(k for k in all_argnames if k not in argnames_null_set) + argnames_null = tuple(k for k in all_argnames if k in argnames_null_set) + if sys.version_info < (3, 12): + assert 
len(argnames_null) == 0, "variables should not be NULL in < 3.12" + # compile_subgraph did not codegen any NULLs, + # so we should not count NullVariables + stack_len = len(self.stack) - len(all_stack_locals_metadata[0].stack_null_idxes) + nargs = stack_len + len(argnames) + + cg = PyCodegen(self) + + # Handle inactive context variables. + # The resume function assumes that context variables are the class, NOT the object. + # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled + # NOTE: if the unsupported instruction modifies the inactive context variable, it may + # result in silent incorrectness! + for (i, _), i_orig in zip( + all_stack_locals_metadata[0].stack_ctx_args, + all_stack_locals_metadata[0].stack_ctx_idxes_orig, + ): + # Replace the current stack var with the context class + ctx = cast(ContextWrappingVariable, self.stack[i_orig]) + ctx.reconstruct_type(cg) + cg.extend_output(create_swap(stack_len - i + 1)) + cg.append_output(create_instruction("POP_TOP")) + + for name, _ in all_stack_locals_metadata[0].locals_ctx_args: + # Replace the local with the context class + ctx = cast(ContextWrappingVariable, self.symbolic_locals[name]) + ctx.reconstruct_type(cg) + cg.append_output(create_instruction("STORE_FAST", argval=name)) + + name = unique_id(f"__resume_at_{inst.offset}", with_uuid=True) + + new_code: types.CodeType = ContinueExecutionCache.lookup( + self.f_code, + self.lineno, + inst.offset, + tuple(b.target.offset for b in self.block_stack), + stack_len, + argnames, + argnames_null, + tuple(b.resume_fn() for b in self.block_stack), + tuple(all_stack_locals_metadata[0].stack_ctx_args), + tuple(all_stack_locals_metadata[0].locals_ctx_args), + tuple(all_stack_locals_metadata[0].stack_null_idxes), + ) + + # Add original GraphModule context to the resume function to handle + # the case of a graph break while tracing a GraphModule + orig_graphmodule_maybe = code_context.get_context(self.f_code).get( + "orig_graphmodule", lambda: None + )() + if orig_graphmodule_maybe is not None: + code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref( + orig_graphmodule_maybe + ) + + if new_code.co_freevars: + # expose code object for debugging purposes + self.output.install_global_unsafe(name, new_code) + cg.make_function_with_closure(name, new_code, True, stack_len) + package_name = None + else: + # This is safe: we pre-generate a unique name + self.output.install_global_unsafe( + name, types.FunctionType(new_code, self.f_globals, name) + ) + cg.extend_output(cg.load_function_name(name, True, stack_len)) + package_name = name + + if self.package is not None: + self.package.add_resume_function( + new_code, self.f_globals["__name__"], package_name + ) + + cg.extend_output([cg.create_load(k) for k in argnames]) + cg.extend_output(create_call_function(nargs, False)) + cg.append_output(create_instruction("RETURN_VALUE")) + return cg.get_instructions() + + def symbolic_locals_contain_module_class(self): + for v in self.symbolic_locals.values(): + if isinstance(v, UserDefinedClassVariable) and issubclass( + v.as_python_constant(), torch.nn.Module + ): + return True + return False + + def replace_tos_if_return_is_generator(self): + if ( + len(self.stack) + and (tos := self.stack[-1]) + and isinstance(tos, LocalGeneratorObjectVariable) + ): + self.stack[-1] = ListIteratorVariable( + tos.force_unpack_var_sequence(self), + mutation_type=ValueMutationNew(), + ) + + def _return(self, inst): + self.replace_tos_if_return_is_generator() + assert 
self.instruction_pointer is not None + assert self.start_point is not None + get_metrics_context().increment( + "ir_count", self.instruction_pointer - self.start_point + ) + + if ( + not config.allow_empty_graphs + and self.output.count_calls() == 0 + and not self.inconsistent_side_effects + and not self.symbolic_locals_contain_module_class() + and not self.export + and not self.one_graph + ): + raise exc.SkipFrame("because no content in function call") + + self.instruction_pointer = None + _step_logger()( + logging.INFO, + f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})", + ) + log.debug("%s triggered compile", inst.opname) + all_stack_locals_metadata = self.output.compile_subgraph( + self, + reason=GraphCompileReason( + "return_value", [self.frame_summary()], graph_break=False + ), + ) + # check that our stack/locals meta are correct: + # we should only be tracing 1 frame, and there should not be any NULLs on the stack + assert len(all_stack_locals_metadata) == 1 + assert not all_stack_locals_metadata[0].stack_null_idxes + return_inst = ( + create_instruction("RETURN_VALUE") + if inst.opname == "RETURN_VALUE" + else create_instruction("RETURN_CONST", argval=inst.argval) + ) + self.output.add_output_instructions([return_inst]) + raise ReturnValueOp + + def RETURN_VALUE(self, inst): + self._return(inst) + + def RETURN_CONST(self, inst): + self._return(inst) + + +if sys.version_info >= (3, 11): + _binary_op_lookup = [ + getattr( + InstructionTranslator, + opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}", + ) + for opname, _ in dis._nb_ops # type: ignore[attr-defined] + ] + + +class InliningInstructionTranslator(InstructionTranslatorBase): + """Trace and inline a called method""" + + symbolic_result: Optional[VariableTracker] + parent: InstructionTranslatorBase + + @classmethod + def inline_call(cls, parent, func, args, kwargs): + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + tracer = cls.build_inline_tracer(parent, func, args, kwargs) + return tracer.inline_call_() + + @staticmethod + def check_inlineable(func): + if func.has_self(): + unimplemented_v2( + gb_type="Inline attempt with __self__", + context=str(func), + explanation="Attempted to inline a function with the `__self__` attribute. " + "Dynamo is expected to decompose method calls into function calls with a `self` argument.", + hints=[], + ) + + if isinstance(func, UserFunctionVariable) and inspect.getattr_static( + func.get_function(), "_torchdynamo_disable", False + ): + msg = inspect.getattr_static( + func.get_function(), "_torchdynamo_disable_msg", None + ) + unimplemented_v2( + gb_type="Skip inlining `torch.compiler.disable()`d function", + context=str(func.get_function()), + explanation=f"Skip inlining function {func.get_function()} since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", + hints=[ + "Remove the `torch.compiler.disable` call", + ], + ) + + result = trace_rules.check_verbose(func, is_inlined_call=True) + if result.skipped: + from torch._dynamo.variables.misc import produce_trampoline_autograd_apply + + # _origin marks this as coming from an internal dynamo known function that is safe to + # trace through. 
+ if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [ + produce_trampoline_autograd_apply, + ]: + # Known sound + return trace_rules.SkipResult( + False, "allowlist in dynamo known function" + ) + fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else "" + hints = [ + f"Avoid calling the function `{fn_qualname}`.", + ] + if "_dynamo" not in func.get_filename(): + hints += [ + f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{fn_qualname}` " + "to force tracing into the function. " + "More graph breaks may occur as a result of attempting to trace into the function.", + "Please file an issue to PyTorch.", + ] + unimplemented_v2( + gb_type="Attempted to inline function marked as skipped", + context=f"qualname: {fn_qualname}, name: {func.get_name()}, " + f"filename: `{func.get_filename()}`, skip reason: {result.reason}", + explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` " + "should not be traced.", + hints=hints, + ) + + return result + + @staticmethod + def build_inline_tracer( + parent, + func: VariableTracker, + args: list[VariableTracker], + kwargs, + ): + assert isinstance( + func, + ( + UserFunctionVariable, + NestedUserFunctionVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + ), + ) + code: types.CodeType = func.get_code() + result = None + tracing_ctx = parent.output.tracing_context + + # Check if we have already identified this function to be inline-able. + # The exception is dont_skip_tracing flag which affects the inline + # behavior. If the flag is True, don't rely on previous results. + if not config.dont_skip_tracing and tracing_ctx: + if previous_result := tracing_ctx.previously_inlined_functions.get( + code, None + ): + result = previous_result + + if result is None: + if isinstance(func, SkipFunctionVariable): + unimplemented_v2( + gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)", + context=f"Attempted to inline a SkipFunctionVariable {func}", + explanation=( + "Attempted to inline a function that was previously determined to be marked as intentionally skipped." 
+ ), + hints=[], + ) + result = InliningInstructionTranslator.check_inlineable(func) + assert result.skipped is False + + if not config.dont_skip_tracing and tracing_ctx: + tracing_ctx.previously_inlined_functions[code] = result + + try: + sub_locals = func.bind_args(parent, args, kwargs) + except TypeError as e: + # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info + raise ArgsMismatchError( # noqa: B904 + "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format( + reason=str(e), + func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}", + args=[arg.python_type() for arg in args], + kwargs=kwargs, + ), + ) + + for v in itertools.chain(sub_locals.values()): + if not isinstance(v, VariableTracker): + unimplemented_v2( + gb_type="Encountered unconverted argument when attempting to inline", + context=f"func: {func}, arg: {v}", + explanation="An argument to an inlined function was not successfully converted to a VariableTracker.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + if code.co_name in ("__setitem__", "__setattr__") and not ( + args and isinstance(args[0], variables.UserDefinedObjectVariable) + ): + unimplemented_v2( + gb_type="Unsupported __setitem__/__setattr__ inline attempt", + context=f"code name: {code.co_name}, args: {args}", + explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.", + hints=[], + ) + + suffix = "" + # TODO: mlazos, add support for enabling multiple artifact logs + # with a single alias + if torch._logging._internal.log_state.is_artifact_enabled("bytecode"): + suffix = f"\n{dis.Bytecode(code).dis()}" + if sys.version_info >= (3, 11): + cur_inst = parent.current_instruction + parent_code = parent.f_code + + def get_trace_call_log_str(): + header = parent.get_line_of_code_header( + lineno=cur_inst.positions.lineno + ) + line = get_instruction_source_311(parent_code, cur_inst).rstrip() + return f"TRACE inlined call {code.co_name} from {header}\n{line}" + + trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) + log.debug("INLINING %s%s, %s", code, suffix, result.reason) + + # Detect inline GraphModule calls in order to propagate node metadata, + # by checking if the first argument (self) is a variable tracking a GraphModule. + if args and isinstance(args[0], NNModuleVariable): + module = parent.output.get_submodule(args[0].module_key) + if isinstance(module, torch.fx.GraphModule): + # The inline call might not actually be a call to `forward`, + # but it is enough to add a context for `forward` in case it is called. 
+ code_context.get_context(module.forward.__code__)[ + "orig_graphmodule" + ] = weakref.ref(module) + + tracer: InliningInstructionTranslator + if is_generator(code): + tracer = InliningGeneratorInstructionTranslator( + parent, + code, + sub_locals, + parent.symbolic_globals, + parent.symbolic_torch_function_state, + func, + ) + else: + # need the line below to make MyPy happy + assert not isinstance(func, LocalGeneratorObjectVariable) + tracer = InliningInstructionTranslator( + parent, + code, + sub_locals, + parent.symbolic_globals, + parent.symbolic_torch_function_state, + func, + ) + return tracer + + def inline_call_(self): + parent = self.parent + code = self.f_code + + strict_ctx: Any = contextlib.nullcontext() + if parent.strict_checks_fn: + strict_ctx = self.strict_translation_mode(parent.strict_checks_fn) + try: + with strict_ctx: + self.run() + except exc.ObservedException as e: + msg = f"Observed exception DURING INLINING {code} : {e}" + log.debug(msg) + # bubble up the exception to the parent frame. + raise + except exc.SkipFrame as e: + msg = f"SKIPPED INLINING {code}: {e}" + log.debug(msg) + raise Unsupported(msg) from e + except Exception: + log.debug("FAILED INLINING %s", code) + raise + assert self.symbolic_result is not None + + if self.f_globals is parent.f_globals: + # Merge symbolic_globals back if parent and child are in the same namespace + parent.symbolic_globals.update(self.symbolic_globals) + + parent.inconsistent_side_effects |= self.inconsistent_side_effects + + log.debug("DONE INLINING %s", code) + self.output.tracing_context.traced_code.append(code) + + if config.enable_faithful_generator_behavior or ( + isinstance(self, InliningGeneratorInstructionTranslator) + and self.is_generator_from_ctx_manager + ): + if ( + is_generator(code) + and isinstance(self, InliningGeneratorInstructionTranslator) + and self.generator_exhausted + ): + assert isinstance(self, InliningGeneratorInstructionTranslator) + # When the generator returns None, we raise StopIteration + exc.raise_observed_exception(StopIteration, self) + else: + return self.symbolic_result + else: + if is_generator(code): + assert isinstance(self, InliningGeneratorInstructionTranslator) + assert self.symbolic_result.as_python_constant() is None + return ListIteratorVariable( + self.generated_items, + mutation_type=ValueMutationNew(), + ) + else: + return self.symbolic_result + + def __init__( + self, + parent: InstructionTranslatorBase, + code: types.CodeType, + symbolic_locals: dict[str, VariableTracker], + symbolic_globals: dict[str, VariableTracker], + symbolic_torch_function_state: SymbolicTorchFunctionState, + funcvar: BaseUserFunctionVariable, + ) -> None: + f_globals = funcvar.get_globals() # type: ignore[attr-defined] + f_builtins = f_globals["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + + # Get the cached instructions. These instructions are safe to cache + # because we don't mutate them in transform_code_object (those + # instructions are for the top most Instruction translator). Also, we + # have to be careful about not using _cached_cleaned_instructions here + # because that function is global, while we want the cache to be + # alive only during a compilation.
+ tracing_ctx = parent.output.tracing_context + instructions = None + if tracing_ctx: + if tracing_ctx.previously_cleaned_instructions.get(code): + instructions = tracing_ctx.previously_cleaned_instructions[code] + + if instructions is None: + instructions = cleaned_instructions(code) + propagate_line_nums(instructions) + if tracing_ctx: + tracing_ctx.previously_cleaned_instructions[code] = instructions + + super().__init__( + output=parent.output, + f_locals={}, + f_globals=f_globals, + f_builtins=f_builtins, + symbolic_locals=symbolic_locals, + symbolic_globals=symbolic_globals, + symbolic_torch_function_state=symbolic_torch_function_state, + instructions=instructions, + code_options={k: getattr(code, k) for k in get_code_keys()}, + f_code=code, + export=parent.export, + inline_depth=parent.inline_depth + 1, + speculation_log=parent.speculation_log, + exn_vt_stack=parent.exn_vt_stack, + distributed_state=parent.distributed_state, + package=parent.package, + ) + self.funcvar = funcvar + self.parent = parent + self.num_calls = parent.num_calls + self.symbolic_result = None + self.nn_module_stack = parent.nn_module_stack.copy() + self.one_graph = parent.one_graph + + @property + def fake_mode(self): + return self.parent.fake_mode + + def run_ctx_mgr(self): + return TracingContext.current_frame(self.parent.frame_summary()) + + def should_compile_partial_graph(self): + return False # inlining functions is all-or-nothing + + def create_call_resume_at(self, inst, all_stack_locals_metadata): + unimplemented_v2( + gb_type="Graph break in inlined function", + context="", + explanation="Graph breaks in an inlined call are not supported.", + hints=[], + ) + + def RETURN_VALUE(self, inst): + self.symbolic_result = self.pop() # type: ignore[assignment] + self.instruction_pointer = None + raise ReturnValueOp + + def RETURN_CONST(self, inst): + self.symbolic_result = self._load_const(inst) + self.instruction_pointer = None + raise ReturnValueOp + + def get_globals_source_and_value(self, name): + # NamedTuple's `__new__` has a fake global scope that's not an actual + # module. TODO generalize the check for other non-importable cases. 
+ # https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447 + if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith( + "namedtuple_" + ): + module_name = self.f_globals["__name__"] + module_source = self.import_source(module_name) + if "torch_package" in module_name: + fglobals_value = ( + torch.package.package_importer._package_imported_modules[ + module_name + ] + ) # type: ignore[assignment] + else: + fglobals_value = _import_module(module_name) + # Dont use lazy vt because we will do a setattr afterwards + fglobals_vt = VariableBuilder(self, module_source)(fglobals_value) + global_source = AttrSource(module_source, name) + else: + globals_name = self.output.install_global_by_id( + "___unnamed_scope", self.f_globals + ) + globals_source = GlobalSource(globals_name) + fglobals_value = self.f_globals # type: ignore[assignment] + # Dont use lazy vt because we will do a setattr afterwards + fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) + global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment] + return fglobals_value, fglobals_vt, global_source + + def _load_global(self, inst): + name = inst.argval + if name not in self.f_globals: + return self.load_builtin(inst) + + if self.output.global_scope is self.f_globals: + # If the global scope matches that of the root frame, use handler in + # root frame instruction translator, to enforce consistency. + super()._load_global(inst) + else: + _, fglobals_vt, global_source = self.get_globals_source_and_value(name) + if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name): + self.push(self.output.side_effects.load_attr(fglobals_vt, name)) + else: + value = self.f_globals[name] + self.push(VariableTracker.build(self, value, global_source)) + + def STORE_GLOBAL(self, inst): + if self.output.global_scope is self.f_globals: + # If the global scope matches that of the root frame, use handler in + # root frame instruction translator, to enforce consistency. + super().STORE_GLOBAL(inst) + else: + value = self.pop() + if isinstance(value, RemovableHandleVariable): + unimplemented_v2( + gb_type="Storing Tensor hook handle in globals (inline call)", + context=inst.argval, + explanation="This is not supported.", + hints=[], + ) + name = inst.argval + _fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name) + self.output.side_effects.store_attr(fglobals_vt, name, value) + + +class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): + generated_items: list[VariableTracker] + # Flag whether or not the InlineGenerator should consume the entire iterator + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.generated_items = [] + self.generator_exhausted = False + self.is_generator_from_ctx_manager = False + + def YIELD_VALUE(self, inst: Instruction): + top = self.pop() + self.generated_items.append(top) + if len(self.generated_items) > MAX_ITERATOR_LIMIT: + raise exc.InfiniteGeneratorError( + "Too many yield values in generator. Maybe you are inlining an infinite generator. 
" + f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}", + ) + self.push(ConstantVariable.create(None)) + if ( + config.enable_faithful_generator_behavior + or self.is_generator_from_ctx_manager + ): + self.symbolic_result = top + # Stop tracing + raise YieldValueOp + + def GET_YIELD_FROM_ITER(self, inst): + tos = self.stack[-1] + if not isinstance(tos, ListIteratorVariable): + self.pop() + res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type] + self.push(res) + + def RETURN_VALUE(self, inst): + self.generator_exhausted = True + return super().RETURN_VALUE(inst) + + def RETURN_CONST(self, inst): + self.generator_exhausted = True + return super().RETURN_CONST(inst) + + def YIELD_FROM(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if not (isinstance(val, ConstantVariable) and val.value is None): + # invoke send + # Unreachable code - if you hit this, you are implementing generator support and have + # lifted the `unimplemented("generator")` in frame conversion. This codepath handles + # subgenerator and lines up with this line in Python 3.10 + # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599 + unimplemented_v2( + gb_type="Unreachable sub-generator code", + context="", + explanation="Should only be encountered while implementing generator support.", + hints=[], + ) + + try: + val = tos.next_variable(self) + except (StopIteration, exc.ObservedUserStopIteration) as ex: + if isinstance(ex, exc.ObservedUserStopIteration): + exc.handle_observed_exception(self) + + # The iterator is exhausted. Stop the loop and return. + self.pop() + self.push(ConstantVariable.create(ex.value)) + else: + # Repeat the YIELD_FROM instruction in the next eval loop + assert ( + isinstance(self.instruction_pointer, int) + and self.instruction_pointer > 0 + ) + self.instruction_pointer -= 1 + + self.push(val) + # Add the value to yield into generated_items and replace the top of the stack with None + self.YIELD_VALUE(inst) + + def SEND(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or ( + isinstance(tos, UserDefinedObjectVariable) + and isinstance(tos.value, collections.abc.Iterator) + ): + if isinstance(val, ConstantVariable) and val.value is None: + try: + val = tos.next_variable(self) + except (StopIteration, exc.ObservedUserStopIteration) as ex: + # To implement SEND, we have to look at the implementation + # when the iterator returns StopIteration. This translates to this code + # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619 + # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866 + # The implementation is different in 3.11 and 3.12. In 3.12, we rely + # on END_SEND to clean up. In 3.11, SEND does the cleanup as well. + if sys.version_info < (3, 12): + self.pop() # Python 3.12 uses new opcode END_SEND + self.push(ConstantVariable.create(ex.value)) + self.jump(inst) + else: + self.push(val) + else: + # invoke send + # Unreachable code - if you hit this, you are implementing generator support and have + # lifted the `unimplemented("generator")` in frame conversion. 
This codepath handles + # subgenerator and lines up with this line in Python 3.11 + # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597 + unimplemented_v2( + gb_type="Unreachable sub-generator code", + context="", + explanation="Should only be encountered while implementing generator support.", + hints=[], + ) + else: + unimplemented_v2( + gb_type="SEND with bad type", + context=f"TOS type: {typestr(tos)}", + explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.", + hints=[], + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/tensor_version_op.py b/phivenv/Lib/site-packages/torch/_dynamo/tensor_version_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7ca282df3565c420be7b50dc8393ba607e51f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/tensor_version_op.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs + +"""This module implements tensor version operations for Dynamo tracing. + +It provides primitives for handling tensor versioning during tracing, particularly in the +context of functionalization where version operations are handled eagerly on fake tensors. + +When we functionalize _tensor_version + _unsafe_set_version_counter, the ops disappear from +the traced graph. We run them eagerly on the fake tensors used for tracing, in order to get +past asserts that would fail in autograd. + +Why is this ok? +1) Versions on functional tensors do not make any sense since you cannot mutate a functional + tensor. +2) The whole point of version munging is to trick autograd into doing what we want, and after + AotAutograd there is no longer any need for these ops. + +Note this is similar to how no_grad is handled. +""" + +import torch +from torch._prims import _make_prim, RETURN_TYPE +from torch._subclasses import FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensorMode + + +_tensor_version = _make_prim( + schema="_tensor_version(Tensor self) -> SymInt", + return_type=RETURN_TYPE.NEW, + meta=torch.ops.aten._version.default, + impl_aten=torch.ops.aten._version.default, + doc="Tracable unbacked SymInt version of torch.Tensor._version", +) + + +@_tensor_version.py_impl(FakeTensorMode) +def _tensor_version_fake(fake_mode, self_tensor): + """ + The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the + `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` + of input tensors to the graph. 
+ """ + return fake_mode.shape_env.create_unbacked_symint() + + +_unsafe_set_version_counter = _make_prim( + schema="_unsafe_set_version_counter(Tensor[] tensors, SymInt[] versions) -> ()", + return_type=RETURN_TYPE.NEW, + meta=lambda self, version: None, + impl_aten=torch._C._autograd._unsafe_set_version_counter, + doc="Tracable+SymInt version of torch._C._autograd._unsafe_set_version_counter", +) +torch.fx.node.has_side_effect(_unsafe_set_version_counter) + + +@_tensor_version.py_impl(FunctionalTensorMode) +def _tensor_version_functional(mode, self): + return self._version + + +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) +def _unsafe_set_version_counter_functional(ctx, tensors, versions): + torch._C._autograd._unsafe_set_version_counter(tensors, versions) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/test_case.py b/phivenv/Lib/site-packages/torch/_dynamo/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc409b6a17abc96446f22c5ab7f63827649b37b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/test_case.py @@ -0,0 +1,204 @@ +# mypy: allow-untyped-defs + +"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality. + +This module extends PyTorch's testing framework with Dynamo-specific testing capabilities. +It includes: +- A custom TestCase class that handles Dynamo-specific setup/teardown +- Test running utilities with dependency checking +- Automatic reset of Dynamo state between tests +- Proper handling of gradient mode state +""" + +import contextlib +import importlib +import inspect +import logging +import os +import re +import sys +import unittest +from typing import Union + +import torch +import torch.testing +from torch._logging._internal import trace_log +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + IS_WINDOWS, + TEST_WITH_CROSSREF, + TEST_WITH_TORCHDYNAMO, + TestCase as TorchTestCase, +) + +from . 
import config, reset, utils + + +log = logging.getLogger(__name__) + + +def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: + from torch.testing._internal.common_utils import run_tests + + if TEST_WITH_TORCHDYNAMO or TEST_WITH_CROSSREF: + return # skip testing + + if ( + not torch.xpu.is_available() + and IS_WINDOWS + and os.environ.get("TORCHINDUCTOR_WINDOWS_TESTS", "0") == "0" + ): + return + + if isinstance(needs, str): + needs = (needs,) + for need in needs: + if need == "cuda": + if not torch.cuda.is_available(): + return + else: + try: + importlib.import_module(need) + except ImportError: + return + run_tests() + + +class TestCase(TorchTestCase): + _exit_stack: contextlib.ExitStack + + @classmethod + def tearDownClass(cls) -> None: + cls._exit_stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls._exit_stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._exit_stack.enter_context( # type: ignore[attr-defined] + config.patch( + raise_on_ctx_manager_usage=True, + suppress_errors=False, + log_compilation_metrics=False, + ), + ) + + def setUp(self) -> None: + self._prior_is_grad_enabled = torch.is_grad_enabled() + super().setUp() + reset() + utils.counters.clear() + self.handler = logging.NullHandler() + trace_log.addHandler(self.handler) + + def tearDown(self) -> None: + trace_log.removeHandler(self.handler) + for k, v in utils.counters.items(): + print(k, v.most_common()) + reset() + utils.counters.clear() + super().tearDown() + if self._prior_is_grad_enabled is not torch.is_grad_enabled(): + log.warning("Running test changed grad mode") + torch.set_grad_enabled(self._prior_is_grad_enabled) + + +class CPythonTestCase(TestCase): + """ + Test class for CPython tests located in "test/dynamo/CPython/Py_version/*". + + This class enables specific features that are disabled by default, such as + tracing through unittest methods. + """ + + _stack: contextlib.ExitStack + dynamo_strict_nopython = True + + # Restore original unittest methods to simplify tracing CPython test cases. 
+ assertEqual = unittest.TestCase.assertEqual # type: ignore[assignment] + assertNotEqual = unittest.TestCase.assertNotEqual # type: ignore[assignment] + assertTrue = unittest.TestCase.assertTrue + assertFalse = unittest.TestCase.assertFalse + assertIs = unittest.TestCase.assertIs + assertIsNot = unittest.TestCase.assertIsNot + assertIsNone = unittest.TestCase.assertIsNone + assertIsNotNone = unittest.TestCase.assertIsNotNone + assertIn = unittest.TestCase.assertIn + assertNotIn = unittest.TestCase.assertNotIn + assertIsInstance = unittest.TestCase.assertIsInstance + assertNotIsInstance = unittest.TestCase.assertNotIsInstance + assertAlmostEqual = unittest.TestCase.assertAlmostEqual + assertNotAlmostEqual = unittest.TestCase.assertNotAlmostEqual + assertGreater = unittest.TestCase.assertGreater + assertGreaterEqual = unittest.TestCase.assertGreaterEqual + assertLess = unittest.TestCase.assertLess + assertLessEqual = unittest.TestCase.assertLessEqual + assertRegex = unittest.TestCase.assertRegex + assertNotRegex = unittest.TestCase.assertNotRegex + assertCountEqual = unittest.TestCase.assertCountEqual + assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual + assertSequenceEqual = unittest.TestCase.assertSequenceEqual + assertListEqual = unittest.TestCase.assertListEqual + assertTupleEqual = unittest.TestCase.assertTupleEqual + assertSetEqual = unittest.TestCase.assertSetEqual + assertDictEqual = unittest.TestCase.assertDictEqual + assertRaises = unittest.TestCase.assertRaises + assertRaisesRegex = unittest.TestCase.assertRaisesRegex + assertWarns = unittest.TestCase.assertWarns + assertWarnsRegex = unittest.TestCase.assertWarnsRegex + assertLogs = unittest.TestCase.assertLogs + fail = unittest.TestCase.fail + failureException = unittest.TestCase.failureException + + def compile_fn(self, fn, backend, nopython): + # We want to compile only the test function, excluding any setup code + # from unittest + method = getattr(self, self._testMethodName) + method = torch._dynamo.optimize(backend, nopython=nopython)(method) + setattr(self, self._testMethodName, method) + return fn + + def _dynamo_test_key(self): + suffix = super()._dynamo_test_key() + test_cls = self.__class__ + test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0] + py_ver = re.search(r"/([\d_]+)/", inspect.getfile(test_cls)) + if py_ver: + py_ver = py_ver.group().strip(os.sep).replace("_", "") # type: ignore[assignment] + else: + return suffix + return f"CPython{py_ver}-{test_file}-{suffix}" + + @classmethod + def tearDownClass(cls) -> None: + cls._stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls) -> None: + # Skip test if python versions doesn't match + prefix = os.path.join("dynamo", "cpython") + os.path.sep + regex = re.escape(prefix) + r"\d_\d{2}" + search_path = inspect.getfile(cls) + m = re.search(regex, search_path) + if m: + test_py_ver = tuple(map(int, m.group().removeprefix(prefix).split("_"))) + py_ver = sys.version_info[:2] + if py_ver < test_py_ver: + expected = ".".join(map(str, test_py_ver)) + got = ".".join(map(str, py_ver)) + raise unittest.SkipTest( + f"Test requires Python {expected} but got Python {got}" + ) + else: + raise unittest.SkipTest( + f"Test requires a specific Python version but not found in path {inspect.getfile(cls)}" + ) + + super().setUpClass() + cls._stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._stack.enter_context( # type: ignore[attr-defined] + config.patch( + enable_trace_unittest=True, + ), + ) diff --git 
a/phivenv/Lib/site-packages/torch/_dynamo/test_dont_skip_tracing_functions.py b/phivenv/Lib/site-packages/torch/_dynamo/test_dont_skip_tracing_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c032cf708b413dab1b5066bc92c782ae3df7f925 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/test_dont_skip_tracing_functions.py @@ -0,0 +1,40 @@ +""" +Functions used to test torch._dynamo.dont_skip_tracing. +This file is located in torch/_dynamo so that it is skipped by trace rules. +There is a special rule in trace_rules that doesn't skip this file when +dont_skip_tracing is active. +""" + +import torch + + +def f1(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + +def f2(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + +def f3(x: torch.Tensor) -> torch.Tensor: + return f2(x) + + +def f4(x: torch.Tensor) -> torch.Tensor: + x = f5(x, 1) + x = torch._dynamo.dont_skip_tracing(f6)(x) + x = f5(x, 8) + return x + + +def f5(x: torch.Tensor, n: int) -> torch.Tensor: + if torch.compiler.is_compiling(): + return x + n + return x + + +def f6(x: torch.Tensor) -> torch.Tensor: + x = f5(x, 2) + torch._dynamo.graph_break() + x = f5(x, 4) + return x diff --git a/phivenv/Lib/site-packages/torch/_dynamo/test_minifier_common.py b/phivenv/Lib/site-packages/torch/_dynamo/test_minifier_common.py new file mode 100644 index 0000000000000000000000000000000000000000..e82d252c1c9a2a6db7be352119e64df36c6b7a4a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/test_minifier_common.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs + +"""Common utilities for testing Dynamo's minifier functionality. + +This module provides the base infrastructure for running minification tests in Dynamo. +It includes: +- MinifierTestResult: A dataclass for storing and processing minifier test results +- MinifierTestBase: A base test class with utilities for: + - Running tests in isolated environments + - Managing temporary directories and configurations + - Executing minifier launcher scripts + - Running and validating reproduction scripts + - Supporting both compile-time and runtime error testing + +The minifier helps reduce failing Dynamo compilations to minimal reproductions. 
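A minimal sketch of how a subclass typically drives these three stages through `_run_full_test` (the backend name, expected error string, and test code below are illustrative assumptions, not taken from this file):

    # inside a MinifierTestBase subclass
    run_code = (
        "@torch.compile(backend='relu_compile_error_TESTING_ONLY')\n"
        "def inner(x):\n"
        "    return torch.relu(x)\n"
        "\n"
        "inner(torch.randn(20, 20))\n"
    )
    # runs the failing code, then the generated minifier_launcher.py, then repro.py
    self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)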
+""" + +import dataclasses +import io +import logging +import os +import re +import shutil +import subprocess +import sys +import tempfile +import traceback +from typing import Optional +from unittest.mock import patch + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.trace_rules import _as_posix_path +from torch.utils._traceback import report_compile_source_on_error + + +@dataclasses.dataclass +class MinifierTestResult: + minifier_code: str + repro_code: str + + def _get_module(self, t): + match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) + assert match is not None, "failed to find module" + r = match.group(0) + r = re.sub(r"\s+$", "\n", r, flags=re.MULTILINE) + r = re.sub(r"\n{3,}", "\n\n", r) + return r.strip() + + def get_exported_program_path(self): + # Extract the exported program file path from AOTI minifier's repro.py + # Regular expression pattern to match the file path + pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)' + # Search for the pattern in the text + match = re.search(pattern, self.repro_code) + # Extract and print the file path if a match is found + if match: + file_path = match.group(1) + return file_path + return None + + def minifier_module(self): + return self._get_module(self.minifier_code) + + def repro_module(self): + return self._get_module(self.repro_code) + + +class MinifierTestBase(torch._dynamo.test_case.TestCase): + DEBUG_DIR = tempfile.mkdtemp() + + @classmethod + def setUpClass(cls): + super().setUpClass() + if not os.path.exists(cls.DEBUG_DIR): + cls.DEBUG_DIR = tempfile.mkdtemp() + cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._dynamo.config.patch(debug_dir_root=cls.DEBUG_DIR) + ) + # These configurations make new process startup slower. Disable them + # for the minification tests to speed them up. + cls._exit_stack.enter_context( # type: ignore[attr-defined] + torch._inductor.config.patch( + { + # https://github.com/pytorch/pytorch/issues/100376 + "pattern_matcher": False, + # multiprocess compilation takes a long time to warmup + "compile_threads": 1, + # https://github.com/pytorch/pytorch/issues/100378 + "cpp.vec_isa_ok": False, + } + ) + ) + + @classmethod + def tearDownClass(cls): + if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": + shutil.rmtree(cls.DEBUG_DIR) + else: + print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") + cls._exit_stack.close() # type: ignore[attr-defined] + + def _gen_codegen_fn_patch_code(self, device, bug_type): + assert bug_type in ("compile_error", "runtime_error", "accuracy") + return f"""\ +{torch._dynamo.config.codegen_config()} +{torch._inductor.config.codegen_config()} +torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} +""" + + def _maybe_subprocess_run(self, args, *, isolate, cwd=None): + if not isolate: + assert len(args) >= 2, args + assert args[0] == "python3", args + if args[1] == "-c": + assert len(args) == 3, args + code = args[2] + args = ["-c"] + else: + assert len(args) >= 2, args + with open(args[1]) as f: + code = f.read() + args = args[1:] + + # WARNING: This is not a perfect simulation of running + # the program out of tree. We only interpose on things we KNOW we + # need to handle for tests. If you need more stuff, you will + # need to augment this appropriately. 
+ + # NB: Can't use save_config because that will omit some fields, + # but we must save and reset ALL fields + dynamo_config = torch._dynamo.config.get_config_copy() + inductor_config = torch._inductor.config.get_config_copy() + try: + stderr = io.StringIO() + log_handler = logging.StreamHandler(stderr) + log = logging.getLogger("torch._dynamo") + log.addHandler(log_handler) + try: + prev_cwd = _as_posix_path(os.getcwd()) + if cwd is not None: + cwd = _as_posix_path(cwd) + os.chdir(cwd) + with patch("sys.argv", args), report_compile_source_on_error(): + exec(code, {"__name__": "__main__", "__compile_source__": code}) + rc = 0 + except Exception: + rc = 1 + traceback.print_exc(file=stderr) + finally: + log.removeHandler(log_handler) + if cwd is not None: + os.chdir(prev_cwd) # type: ignore[possibly-undefined] + # Make sure we don't leave buggy compiled frames lying + # around + torch._dynamo.reset() + finally: + torch._dynamo.config.load_config(dynamo_config) + torch._inductor.config.load_config(inductor_config) + + # TODO: return a more appropriate data structure here + return subprocess.CompletedProcess( + args, + rc, + b"", + stderr.getvalue().encode("utf-8"), + ) + else: + if cwd is not None: + cwd = _as_posix_path(cwd) + return subprocess.run(args, capture_output=True, cwd=cwd, check=False) + + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. + def _run_test_code(self, code, *, isolate): + proc = self._maybe_subprocess_run( + ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR + ) + + print("test stdout:", proc.stdout.decode("utf-8")) + print("test stderr:", proc.stderr.decode("utf-8")) + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + return proc, repro_dir_match.group(1) + return proc, None + + # Runs the minifier launcher script in `repro_dir` + def _run_minifier_launcher( + self, repro_dir, isolate, *, minifier_args=(), repro_after=None + ): + self.assertIsNotNone(repro_dir) + launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) + with open(launch_file) as f: + launch_code = f.read() + self.assertTrue(os.path.exists(launch_file)) + + args = ["python3", launch_file, "minify", *minifier_args] + if not isolate and repro_after != "aot_inductor": + # AOTI minifier doesn't have --no-isolate flag. + # Everything in AOTI minifier is in no-isolate mode. + args.append("--no-isolate") + launch_proc = self._maybe_subprocess_run(args, isolate=isolate, cwd=repro_dir) + print("minifier stdout:", launch_proc.stdout.decode("utf-8")) + stderr = launch_proc.stderr.decode("utf-8") + print("minifier stderr:", stderr) + self.assertNotIn("Input graph did not fail the tester", stderr) + + return launch_proc, launch_code + + # Runs the repro script in `repro_dir` + def _run_repro(self, repro_dir, *, isolate=True): + self.assertIsNotNone(repro_dir) + repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) + with open(repro_file) as f: + repro_code = f.read() + self.assertTrue(os.path.exists(repro_file)) + + repro_proc = self._maybe_subprocess_run( + ["python3", repro_file], isolate=isolate, cwd=repro_dir + ) + print("repro stdout:", repro_proc.stdout.decode("utf-8")) + print("repro stderr:", repro_proc.stderr.decode("utf-8")) + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. 
+ # `patch_code` is the code to be patched in every generated file; usually + # just use this to turn on bugs via the config + def _gen_test_code(self, run_code, repro_after, repro_level): + repro_after_line = "" + if repro_after == "aot_inductor": + repro_after_line = ( + "torch._inductor.config.aot_inductor.dump_aoti_minifier = True" + ) + elif repro_after: + repro_after_line = f"""\ +torch._dynamo.config.repro_after = "{repro_after}" + """ + return f"""\ +import torch +import torch._dynamo +import torch._inductor +{_as_posix_path(torch._dynamo.config.codegen_config())} +{_as_posix_path(torch._inductor.config.codegen_config())} +{repro_after_line} +torch._dynamo.config.repro_level = {repro_level} +torch._inductor.config.aot_inductor.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code + # 2. Run the generated minifier launcher script + # 3. Run the generated repro script + # + # If possible, you should run the test with isolate=False; use + # isolate=True only if the bug you're testing would otherwise + # crash the process + def _run_full_test( + self, run_code, repro_after, expected_error, *, isolate, minifier_args=() + ) -> Optional[MinifierTestResult]: + if isolate: + repro_level = 3 + elif expected_error is None or expected_error == "AccuracyError": + repro_level = 4 + else: + repro_level = 2 + test_code = self._gen_test_code(run_code, repro_after, repro_level) + print("running test", file=sys.stderr) + test_proc, repro_dir = self._run_test_code(test_code, isolate=isolate) + if expected_error is None: + # Just check that there was no error + self.assertEqual(test_proc.returncode, 0) + self.assertIsNone(repro_dir) + return None + # NB: Intentionally do not test return code; we only care about + # actually generating the repro, we don't have to crash + self.assertIn(expected_error, test_proc.stderr.decode("utf-8")) + self.assertIsNotNone(repro_dir) + print("running minifier", file=sys.stderr) + _minifier_proc, minifier_code = self._run_minifier_launcher( + repro_dir, + isolate=isolate, + minifier_args=minifier_args, + repro_after=repro_after, + ) + print("running repro", file=sys.stderr) + repro_proc, repro_code = self._run_repro(repro_dir, isolate=isolate) + self.assertIn(expected_error, repro_proc.stderr.decode("utf-8")) + self.assertNotEqual(repro_proc.returncode, 0) + return MinifierTestResult(minifier_code=minifier_code, repro_code=repro_code) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/testing.py b/phivenv/Lib/site-packages/torch/_dynamo/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac094c7c1459255033c92a97c4af5de74bc90cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/testing.py @@ -0,0 +1,553 @@ +"""Testing utilities and infrastructure for Dynamo. + +This module provides a comprehensive set of testing utilities including: +- Test result collection and validation +- Graph manipulation and comparison tools +- Test case management and execution helpers +- Specialized test decorators for different Python versions and features +- RNG state management +- Compilation counting and monitoring +- Debug utilities for bytecode transformation + +The utilities in this module are used across Dynamo's test suite to ensure +consistent testing patterns and proper test isolation. 
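For example, compilation counting is typically done by passing a CompileCounter instance (defined below) as the compile backend; a minimal sketch, with a made-up function for illustration:

    cnt = CompileCounter()

    @torch.compile(backend=cnt)
    def f(x):
        return x + 1

    f(torch.randn(3))
    assert cnt.frame_count == 1  # one frame was compiled
    assert cnt.op_count >= 1     # at least one call_* node in the captured graph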
+""" + +import contextlib +import dis +import functools +import logging +import os.path +import random +import re +import sys +import types +import unittest +from collections.abc import Sequence +from typing import Any, Callable, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec +from unittest.mock import patch + +import torch +from torch import fx +from torch._dynamo.backends.debugging import aot_eager +from torch._dynamo.output_graph import OutputGraph + +from . import config, eval_frame, optimize_assert, reset +from .bytecode_transformation import ( + create_instruction, + debug_checks, + is_generator, + transform_code_object, +) +from .guards import CheckFunctionManager, CompileId, GuardedCode +from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code +from .utils import same + + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +unsupported = eval_frame.unsupported +three = 3 + +log = logging.getLogger(__name__) + +_P = ParamSpec("_P") + + +def clone_me(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + return x.detach().clone().requires_grad_(x.requires_grad) + + +def remove_optimized_module_prefix(name: str) -> str: + return re.sub(r"^_orig_mod[.]", "", name) + + +def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def] + from torch._dynamo.symbolic_convert import InstructionTranslator + + gm = None + region_tracker = None + + def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def] + nonlocal gm + nonlocal region_tracker + gm = _gm + region_tracker = InstructionTranslator.current_tx().output.region_tracker + return _gm + + torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs) + return gm.graph, region_tracker # type: ignore[union-attr] + + +def collect_results( + model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any +) -> list[Any]: + results = [] + results.append(prediction) + results.append(loss) + # if isinstance(loss, torch.Tensor) and loss.item() > 1: + # log.warning( + # f"High loss value alert - {loss:.2f}. Can result in unstable gradients." + # ) + + grads = {} + params = {} + for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + param_copy = param + grad = param.grad + # Treat None and zero grad as same + if param.grad is None: + grad = torch.zeros_like(param) + grads[name + ".grad"] = grad + params[name] = param_copy + results.append(grads) + results.append(params) + buffers = {} + for name, buffer in model.named_buffers(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) + buffers[name] = buffer + results.append(buffers) + for example in example_inputs: + if isinstance(example, (tuple, list)): + results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor)) + else: + if isinstance(example, torch.Tensor): + results.append(example.grad) + return results + + +def requires_bwd_pass(out: Any) -> bool: + if isinstance(out, torch.Tensor): + return out.requires_grad + elif isinstance(out, (list, tuple)): + return any(requires_bwd_pass(x) for x in out) + elif out is None: + return False + elif isinstance(out, int): + return False + raise NotImplementedError("Don't know how to reduce", type(out)) + + +@overload +def reduce_to_scalar_loss(out: torch.Tensor) -> torch.Tensor: ... 
+ + +@overload +def reduce_to_scalar_loss( + out: Union[list[Any], tuple[Any, ...], dict[Any, Any]], +) -> float: ... + + +def reduce_to_scalar_loss(out: Any) -> Union[torch.Tensor, float]: + """Reduce the output of a model to get scalar loss""" + if isinstance(out, torch.Tensor): + # Mean does not work on integer tensors + return out.sum() / out.numel() + elif isinstance(out, (list, tuple)): + return sum(reduce_to_scalar_loss(x) for x in out) / len(out) + elif type(out).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + ): + return reduce_to_scalar_loss(out.logits) + elif type(out).__name__ == "SquashedNormal": + return out.mean.sum() + elif isinstance(out, dict): + return sum(reduce_to_scalar_loss(value) for value in out.values()) / len( + out.keys() + ) + raise NotImplementedError("Don't know how to reduce", type(out)) + + +def debug_dir() -> str: + path = os.path.join(os.path.dirname(__file__), "../debug") + if not os.path.exists(path): + os.mkdir(path) + return path + + +def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None: + with open(os.path.join(debug_dir(), name), "w") as fd: + fd.write( + f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n" + ) + + +def debug_insert_nops( + frame: DynamoFrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0 +) -> ConvertFrameReturn: + """used to debug jump updates""" + + def insert_nops(instructions: list[Any], code_options: Any) -> None: + instructions.insert(0, create_instruction("NOP")) + instructions.insert(0, create_instruction("NOP")) + + metrics_context = torch._dynamo.utils.get_metrics_context() + with torch._dynamo.utils.dynamo_timed("debug_insert_nops"), metrics_context: + if is_generator(frame.f_code): + return ConvertFrameReturn() + + debug_checks(frame.f_code) + code = transform_code_object(frame.f_code, insert_nops) + graph = OutputGraph( + code_options={}, + compiler_fn=None, + root_tx=None, + export=False, + export_constraints=None, + frame_state={"_id": 0}, + # TODO: shouldn't this be f_locals/f_globals from frame? 
+ local_scope=locals(), + global_scope=globals(), + f_code=frame.f_code, + torch_function_mode_stack=[], + package=None, + ) + + return wrap_guarded_code( + GuardedCode( + code, + CheckFunctionManager(frame.f_code, graph).guard_manager, # type: ignore[arg-type] + CompileId(frame_id=0, frame_compile_id=0), + ) + ) + + +class CompileCounter: + def __init__(self) -> None: + self.frame_count = 0 + self.op_count = 0 + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + return gm.forward + + def clear(self) -> None: + self.frame_count = 0 + self.op_count = 0 + + +class CompileCounterWithBackend: + def __init__(self, backend: str) -> None: + self.frame_count = 0 + self.op_count = 0 + self.backend = backend + self.graphs: list[torch.fx.GraphModule] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + from .backends.registry import lookup_backend + + self.frame_count += 1 + for node in gm.graph.nodes: + if "call" in node.op: + self.op_count += 1 + self.graphs.append(gm) + return lookup_backend(self.backend)(gm, example_inputs) + + def clear(self) -> None: + self.frame_count = 0 + self.op_count = 0 + self.graphs = [] + + +# Equivalent to backend="eager", but also records graphs that +# we can assert on +class EagerAndRecordGraphs: + def __init__(self) -> None: + self.graphs: list[torch.fx.GraphModule] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + self.graphs.append(gm) + return gm.forward + + +class AotEagerAndRecordGraphs: + def __init__(self) -> None: + self.graphs: list[torch.fx.GraphModule] = [] + self.fw_graphs: list[torch.fx.GraphModule] = [] + self.bw_graphs: list[torch.fx.GraphModule] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + self.graphs.append(gm) + + def fw_compiler( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + self.fw_graphs.append(gm) + return gm.forward + + def bw_compiler( + gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] + ) -> Callable[..., Any]: + self.bw_graphs.append(gm) + return gm.forward + + return aot_eager( + gm, + example_inputs, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + ) + + +class InductorAndRecordGraphs: + def __init__(self) -> None: + self.graphs: list[torch.fx.GraphModule] = [] + self.inductor_graphs: list[torch.fx.GraphModule] = [] + + def __call__(self, gm, example_inputs): # type: ignore[no-untyped-def] + import torch._inductor.compile_fx as compile_fx_mod + + self.graphs.append(gm) + + old_compile_fx_inner = compile_fx_mod._compile_fx_inner + + def patched(*args, **kwargs): # type: ignore[no-untyped-def] + self.inductor_graphs.append(args[0]) + return old_compile_fx_inner(*args, **kwargs) + + with patch.object(compile_fx_mod, "_compile_fx_inner", new=patched): + return compile_fx_mod.compile_fx(gm, example_inputs) + + +def strip_comment(code: str) -> str: + return re.sub(r"(?m)^ *#.*\n?", "", code) + + +def remove_trailing_space(code: str) -> str: + return "\n".join([line.rstrip() for line in code.split("\n")]) + + +def normalize_gm(gm_str: str) -> str: + # strip comments as comments have path to files which may differ from + # system to system. 
+ return remove_trailing_space(strip_comment(gm_str)) + + +def empty_line_normalizer(code: str) -> str: + """ + Normalize code: remove empty lines. + """ + normal_code = re.sub(r"[\r\n]+", "\n", code) + return normal_code + + +def standard_test( + self: Any, + fn: Callable[..., Any], + nargs: int, + expected_ops: Optional[int] = None, + expected_ops_dynamic: Optional[int] = None, + expected_frame_count: int = 1, +) -> None: + if not config.assume_static_by_default and expected_ops_dynamic is not None: + expected_ops = expected_ops_dynamic + + actual = CompileCounter() + + args1 = [torch.randn(10, 10) for _ in range(nargs)] + args2 = [torch.randn(10, 10) for _ in range(nargs)] + correct1 = fn(*args1) + correct2 = fn(*args2) + reset() + opt_fn = optimize_assert(actual)(fn) + val1a = opt_fn(*args1) + val2a = opt_fn(*args2) + val1b = opt_fn(*args1) + val2b = opt_fn(*args2) + reset() + self.assertTrue(same(val1a, correct1)) + self.assertTrue(same(val1b, correct1)) + self.assertTrue(same(val2a, correct2)) + self.assertTrue(same(val2b, correct2)) + self.assertEqual(actual.frame_count, expected_frame_count) + if expected_ops is not None: + self.assertEqual(actual.op_count, expected_ops) + + +def dummy_fx_compile( + gm: fx.GraphModule, example_inputs: list[torch.Tensor] +) -> Callable[..., Any]: + return gm.forward + + +def format_speedup( + speedup: float, + pvalue: float, + is_correct: bool = True, + pvalue_threshold: float = 0.1, +) -> str: + if not is_correct: + return "ERROR" + if pvalue > pvalue_threshold: + return f"{speedup:.3f}x SAME" + return f"{speedup:.3f}x p={pvalue:.2f}" + + +def rand_strided( + size: Sequence[int], + stride: Sequence[int], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + extra_size: int = 0, +) -> torch.Tensor: + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(size, stride)) + + 1 + + extra_size + ) + if dtype.is_floating_point: + if dtype.itemsize == 1: + """ + normal distribution kernel is not implemented for fp8.. + Workaround that by creating a fp16 tensor and then cast. 
+ """ + buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to( + dtype=dtype + ) + else: + buffer = torch.randn(needed_size, dtype=dtype, device=device) + else: + buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device) + return torch.as_strided(buffer, size, stride) + + +_T = TypeVar("_T") + + +def check_dynamic_shape_capture() -> bool: + # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls` + return not config.assume_static_by_default + + +def _make_fn_with_patches(fn: Callable[_P, _T], *patches: Any) -> Callable[_P, _T]: + @functools.wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + with contextlib.ExitStack() as stack: + for module, attr, val in patches: + stack.enter_context(patch.object(module, attr, val)) + + return fn(*args, **kwargs) + + return _fn + + +def make_test_cls_with_patches( + cls: type, + cls_prefix: str, + fn_suffix: str, + *patches: Any, + xfail_prop: Optional[str] = None, + decorator: Callable[[Callable[..., Any]], Callable[..., Any]] = lambda x: x, +) -> type: + DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) + DummyTestClass.__qualname__ = DummyTestClass.__name__ + + for name in dir(cls): + if name.startswith("test_"): + fn = getattr(cls, name) + if not callable(fn): + setattr(DummyTestClass, name, getattr(cls, name)) + continue + new_name = f"{name}{fn_suffix}" + new_fn = _make_fn_with_patches(fn, *patches) + new_fn.__name__ = new_name + if xfail_prop is not None and hasattr(fn, xfail_prop): + new_fn = unittest.expectedFailure(new_fn) + setattr(DummyTestClass, new_name, decorator(new_fn)) + # NB: Doesn't handle slots correctly, but whatever + elif not hasattr(DummyTestClass, name): + setattr(DummyTestClass, name, getattr(cls, name)) + + return DummyTestClass + + +# test Python 3.11+ specific features +def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if sys.version_info >= (3, 11): + return fn + return unittest.skip(fn) + + +def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if sys.version_info >= (3, 12): + return fn + return unittest.skip("Requires Python 3.12+")(fn) + + +def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if sys.version_info >= (3, 12): + return unittest.expectedFailure(fn) + return fn + + +def skipIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if sys.version_info >= (3, 12): + return unittest.skip("Not supported in Python 3.12+")(fn) + return fn + + +def requiresPy310(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if sys.version_info >= (3, 10): + return fn + else: + return unittest.skip("Requires Python 3.10+")(fn) + + +# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py +# and test/dynamo/test_dynamic_shapes.py +def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]: + fn._expected_failure_dynamic = True # type: ignore[attr-defined] + return fn + + +# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py +def expectedFailureCodegenDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]: + fn._expected_failure_codegen_dynamic = True # type: ignore[attr-defined] + return fn + + +# Controls test generated in test/inductor/test_cpp_wrapper.py +def expectedFailureDynamicWrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + fn._expected_failure_dynamic_wrapper = True # type: ignore[attr-defined] + return fn + + +def reset_rng_state(use_xla: bool = False) -> None: + torch.manual_seed(1337) + random.seed(1337) + if np: + 
np.random.seed(1337) + if use_xla: + import torch_xla.core.xla_model as xm + + xm.set_rng_state(1337, str(xm.xla_device())) + + +def _skipped_function_for_test_reconstruct( + f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs +) -> _T: + return f(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/trace_rules.py b/phivenv/Lib/site-packages/torch/_dynamo/trace_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..423640b755749c61ecb74ab1e052fba003e4bc0a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/trace_rules.py @@ -0,0 +1,3983 @@ +# mypy: allow-untyped-defs + +""" +Tracing rules and policies for TorchDynamo compilation decisions. + +This module defines the rules that govern what code TorchDynamo should trace and compile +versus what should be executed eagerly. It contains functions and classes that determine: + +- Which modules, functions, and objects should be skipped during tracing +- Which parts of the code should cause graph breaks +- How to handle different Python libraries and third-party packages +- Rules for determining when to inline functions vs calling them eagerly + +Key components: +- Skip rules: Functions that return True if an object should be skipped during tracing +- Inlining rules: Policies for when to inline function calls during compilation +- Library-specific handling: Special cases for popular Python packages +- Performance heuristics: Rules that balance compilation overhead vs runtime benefits + +These rules are critical for TorchDynamo's ability to automatically determine +compilation boundaries and optimize PyTorch programs effectively. +""" + +import abc +import builtins +import collections +import copy +import dataclasses +import functools +import importlib +import inspect +import linecache +import operator +import os +import random +import re +import sys +import traceback +import types +import typing +import unittest +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch._inductor.test_operators +import torch.distributed +import torch.utils._content_store +from torch._environment import is_fbcode +from torch.utils import _config_module + +from . import config +from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX +from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper +from .variables import ( + BuiltinVariable, + FunctionalCallVariable, + FunctorchHigherOrderVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + NestedUserFunctionVariable, + PolyfilledFunctionVariable, + SkipFunctionVariable, + TorchInGraphFunctionVariable, + UserFunctionVariable, + UserMethodVariable, +) + + +np: Optional[types.ModuleType] = None +try: + import numpy as np +except ModuleNotFoundError: + pass + + +if typing.TYPE_CHECKING: + from .variables.base import VariableTracker + + +""" +A note on skip/inline rules: + +Dynamo consults this file to determine whether function should be inlined or skipped. + +A skip applies at the frame boundary, meaning dynamo either triggers a graph break +at the beginning of the frame or attempts to trace/inline the whole frame. When skipping +a frame, recursively called frames are still traced by dynamo unless also skipped. + +Skipfiles (skipped at the file level instead of function level) still apply on a +frame-by-frame boundary as dynamo traces, but apply to all functions in that file. 
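For example (an illustrative sketch only; the function names below are hypothetical and do not refer to any real skiplist entry):

    def helper(x):
        return x + 1          # not skipped: Dynamo attempts to trace/inline this frame

    def outer(x):             # suppose outer's file is on a skiplist
        return helper(x) * 2  # outer itself runs eagerly, but the helper frame
                              # can still be traced when it is called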
+
+@skip is a helper decorator that can be applied to your function to cause it to be
+included here.
+
+Dynamo skip/inline rules & priorities are defined as follows:
+* Inline is the default behavior and will be used unless explicitly skipped.
+* Dynamo has two SKIPLISTs: BUILTIN_SKIPLIST and THIRDPARTY_SKIPLIST.
+    * BUILTIN_SKIPLIST contains builtin python modules, such as abc, collections, etc.
+    * THIRDPARTY_SKIPLIST contains common third party libraries, such as numpy, pandas, etc.
+* Functions in these two SKIPLISTs are always skipped, except:
+    * They have an explicitly defined rule in `manual_torch_name_rule_map`;
+    * The corresponding python module has been put into MOD_INLINELIST.
+* PyTorch (torch) is in the BUILTIN_SKIPLIST by default, but there are many cases
+  where we want to inline the functions under the torch namespace.
+  We should specify an inline rule for such functions in `manual_torch_name_rule_map` or
+  put the corresponding python module into MOD_INLINELIST to make Dynamo inline them.
+* If you call functions under skipped modules/files, Dynamo will wrap these functions
+  as SkipFunctionVariable. There are a few functions (e.g., collections.OrderedDict) for which
+  we have special handling in SkipFunctionVariable.call_function.
+
+Overall: *_INLINELIST takes precedence over *_SKIPLIST, which takes precedence over the DEFAULT (inline).
+
+To figure out what the behavior is, check the following list in order:
+* `manual_torch_name_rule_map` (Inline if YES)
+* MOD_INLINELIST (Inline if YES)
+* BUILTIN_SKIPLIST & THIRDPARTY_SKIPLIST (Skip if YES)
+* MOD_SKIPLIST (Skip if YES)
+* Inline by default
+
+In general, if you want to force inline a function or module, please consider adding
+the function's python module to MOD_INLINELIST first.
+Use `manual_torch_name_rule_map` only when there are other functions under the same module
+that you don't want to inline.
+"""
+
+"""
+Map of function objects to their tracing rules (Dynamo variables).
+* TorchInGraphFunctionVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
+    - torch.add: should be put into the FX graph.
+    - torch.is_floating_point: constant folded.
+* SkipFunctionVariable: The objects should be skipped from tracing.
+* UserFunctionVariable: The functions should be inlined.
+
+For developers: If you add/remove a torch level API, it may trigger failures from
+test/dynamo/test_trace_rules.py:test_torch_name_rule_map_updated. To fix the failures:
+If you are adding a new torch level API or Dynamo implementation:
+* Add the name with the corresponding tracing rule to this map
+  if you are adding a new in-graph function or Dynamo implementation for an existing function.
+* Remove the object name from test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names if it's there.
+
+If you are removing an existing torch level API:
+* Remove the entry representing the API from this map or from
+  test/dynamo/test_trace_rules.ignored_c_binding_in_graph_function_names, depending on where it is.
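+
+As an illustration only (`torch.foo_bar` is a made-up name, not a real API), registering a
+new in-graph function would mean adding an entry such as
+    "torch.foo_bar": TorchInGraphFunctionVariable,
+to `manual_torch_name_rule_map` below; a function that must always run eagerly would map to
+SkipFunctionVariable instead, and a pure-Python helper that should be traced through would
+map to UserFunctionVariable.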
+ + +""" +manual_torch_name_rule_map: dict[str, Any] = { + "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, + "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, + "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, + "torch.jit.is_scripting": TorchInGraphFunctionVariable, + "torch.jit.is_tracing": TorchInGraphFunctionVariable, + "torch.jit.annotate": TorchInGraphFunctionVariable, + "torch.distributed.is_available": TorchInGraphFunctionVariable, + "torch.distributed.is_initialized": TorchInGraphFunctionVariable, + "torch.distributed.get_rank": TorchInGraphFunctionVariable, + "torch.distributed.get_world_size": TorchInGraphFunctionVariable, + "torch.distributed.tensor._api.DTensor#from_local": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, + "torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable, + "torch._utils.is_compiling": TorchInGraphFunctionVariable, + "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable, + "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable, + "torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer": UserFunctionVariable, + "torch.compiler.is_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, + "torch.compiler.is_exporting": TorchInGraphFunctionVariable, + "torch.autograd._profiler_enabled": SkipFunctionVariable, + "torch._C._to_dlpack": SkipFunctionVariable, + "torch.to_dlpack": SkipFunctionVariable, + # We graph break on RNG state setters or getters like + # `torch.get_rng_state` or `torch.set_rng_state`. These functions + # are not aten operations and therefore they are completely ignored + # by the AOT dispatcher. As a result, the AOT graph does not have + # these setter or getter functions, producing an incorrect graph + # when it comes to rng states. 
+ "torch.default_generator#get_state": SkipFunctionVariable, + "torch._C.Generator#get_state": SkipFunctionVariable, + "torch.get_rng_state": SkipFunctionVariable, + "torch.cuda.get_rng_state": SkipFunctionVariable, + "torch.default_generator#set_state": SkipFunctionVariable, + "torch._C.Generator#set_state": SkipFunctionVariable, + "torch.set_rng_state": SkipFunctionVariable, + "torch.cuda.set_rng_state": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/107187 + "torch.manual_seed": SkipFunctionVariable, + # https://github.com/pytorch/pytorch/issues/93501 + "torch.nn.utils.rnn.pack_padded_sequence": SkipFunctionVariable, + "torch.nn.Parameter": TorchInGraphFunctionVariable, + "torch.nn.Buffer": TorchInGraphFunctionVariable, + "torch._nested_tensor_from_mask": SkipFunctionVariable, + "torch.nested._internal.nested_tensor.nested_from_padded": TorchInGraphFunctionVariable, + "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, + "torch.nested.nested_tensor_from_padded": UserFunctionVariable, + # torch.fx map utils + "torch.fx.node.map_aggregate": UserFunctionVariable, + "torch.fx.node.map_arg": UserFunctionVariable, + "torch.fx.immutable_collections._no_mutation": UserFunctionVariable, + # symbol operators implemented in Python + "torch.sym_not": TorchInGraphFunctionVariable, + "torch.sym_float": TorchInGraphFunctionVariable, + "torch.sym_int": TorchInGraphFunctionVariable, + "torch.sym_max": TorchInGraphFunctionVariable, + "torch.sym_min": TorchInGraphFunctionVariable, + "torch.sym_sqrt": TorchInGraphFunctionVariable, + "torch.sym_ite": TorchInGraphFunctionVariable, + "torch.sym_sum": TorchInGraphFunctionVariable, + "torch.sym_fresh_size": UserFunctionVariable, + "torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable, + "torch.Tensor#__init__": SkipFunctionVariable, + "torch.Tensor#split": TorchInGraphFunctionVariable, + "torch.cuda.set_device": SkipFunctionVariable, + "torch.cuda.current_device": TorchInGraphFunctionVariable, + "torch._C.autocast_decrement_nesting": SkipFunctionVariable, + "torch._C.autocast_increment_nesting": SkipFunctionVariable, + "torch.autograd.grad": SkipFunctionVariable, + "torch.autograd.backward": SkipFunctionVariable, + "torch._C.clear_autocast_cache": SkipFunctionVariable, + "torch.distributions.constraints.is_dependent": SkipFunctionVariable, + "torch.jit.isinstance": SkipFunctionVariable, + "torch._C.set_anomaly_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cache_enabled": SkipFunctionVariable, + "torch._C.set_autocast_cpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_cpu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_enabled": SkipFunctionVariable, + "torch._C.set_autocast_gpu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_dtype": SkipFunctionVariable, + "torch._C.set_autocast_ipu_enabled": SkipFunctionVariable, + "torch._C.set_autocast_xla_dtype": SkipFunctionVariable, + "torch._C.set_autocast_xla_enabled": SkipFunctionVariable, + "torch.resize_as_": SkipFunctionVariable, + "torch.resize_as_sparse_": SkipFunctionVariable, + "torch.get_default_device": TorchInGraphFunctionVariable, + # functorch/vmap + "torch._functorch.vmap._check_int_or_none": UserFunctionVariable, + "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable, + "torch._functorch.vmap._check_randomness_arg": UserFunctionVariable, + "torch._functorch.vmap._chunked_vmap": UserFunctionVariable, + "torch._functorch.vmap._concat_chunked_outputs": UserFunctionVariable, + 
"torch._functorch.vmap._create_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._flat_vmap": UserFunctionVariable, + "torch._functorch.vmap._flatten_chunks_output": UserFunctionVariable, + "torch._functorch.vmap._get_chunked_inputs": UserFunctionVariable, + "torch._functorch.vmap._get_name": UserFunctionVariable, + "torch._functorch.vmap._maybe_remove_batch_dim": UserFunctionVariable, + "torch._functorch.vmap._num_outputs": UserFunctionVariable, + "torch._functorch.vmap._process_batched_inputs": UserFunctionVariable, + "torch._functorch.vmap._unwrap_batched": UserFunctionVariable, + "torch._functorch.vmap._validate_and_get_batch_size": UserFunctionVariable, + "torch._functorch.vmap.doesnt_support_saved_tensors_hooks": UserFunctionVariable, + "torch._functorch.vmap.get_chunk_sizes": UserFunctionVariable, + # lazy_load_decompositions uses a lock that is not supported yet in dynamo + # "torch._functorch.vmap.lazy_load_decompositions": UserFunctionVariable, + "torch._functorch.vmap.restore_vmap": UserFunctionVariable, + "torch._functorch.apis.vmap": UserFunctionVariable, + "torch._functorch.vmap.unwrap_batched": UserFunctionVariable, + "torch._functorch.vmap.vmap_impl": FunctorchHigherOrderVariable, + "torch._functorch.vmap.wrap_batched": UserFunctionVariable, + # functorch/grad + "torch._functorch.eager_transforms.grad_impl": FunctorchHigherOrderVariable, + "torch._functorch.apis.grad_and_value": UserFunctionVariable, + "torch._functorch.eager_transforms._as_tuple": UserFunctionVariable, + "torch._functorch.eager_transforms._check_unique_non_empty": UserFunctionVariable, + "torch._functorch.eager_transforms._create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._slice_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._undo_create_differentiable": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnum": UserFunctionVariable, + "torch._functorch.eager_transforms._validate_and_wrap_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_all_tensors": UserFunctionVariable, + "torch._functorch.eager_transforms._wrap_tensor_for_grad": UserFunctionVariable, + # functorch/jacrev + "torch._functorch.eager_transforms.jacrev": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms.error_if_complex": UserFunctionVariable, + "torch._functorch.eager_transforms._chunked_standard_basis_for_": UserFunctionVariable, + "torch._functorch.eager_transforms._safe_zero_index": UserFunctionVariable, + # functorch/vjp + "torch._functorch.eager_transforms.vjp": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._vjp_with_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_non_empty_tensor_output": UserFunctionVariable, + # functorch/jvp + "torch._functorch.eager_transforms._jvp_with_argnums": UserFunctionVariable, + "torch._functorch.eager_transforms.jvp": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._replace_args": UserFunctionVariable, + "torch._functorch.eager_transforms.safe_unpack_dual": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_non_empty_list_of_tensors": UserFunctionVariable, + "torch._functorch.eager_transforms.assert_output_is_tensor_or_tensors": UserFunctionVariable, + "torch.autograd.forward_ad.enter_dual_level": UserFunctionVariable, + "torch.autograd.forward_ad.exit_dual_level": UserFunctionVariable, + "torch.autograd.forward_ad.make_dual": UserFunctionVariable, + 
"torch.autograd.forward_ad.unpack_dual": UserFunctionVariable, + # functorch/linearize + "torch._functorch.eager_transforms.linearize": FunctorchHigherOrderVariable, + # functorch/jacfwd + "torch._functorch.eager_transforms.jacfwd": FunctorchHigherOrderVariable, + "torch._functorch.eager_transforms._construct_standard_basis_for": UserFunctionVariable, + "torch._functorch.eager_transforms.safe_unflatten": UserFunctionVariable, + # functorch/hessian + "torch._functorch.eager_transforms.hessian": FunctorchHigherOrderVariable, + # functional_call + "torch._functorch.functional_call.functional_call": FunctionalCallVariable, + "torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable, + # functorch/deprecated + "torch._functorch.deprecated.jvp": UserFunctionVariable, + "torch._functorch.deprecated.hessian": UserFunctionVariable, + "torch._functorch.deprecated.jacfwd": UserFunctionVariable, + "torch._functorch.deprecated.jacrev": UserFunctionVariable, + "torch._functorch.deprecated.grad": UserFunctionVariable, + "torch._functorch.deprecated.grad_and_value": UserFunctionVariable, + "torch._functorch.deprecated.vjp": UserFunctionVariable, + # functorch/C++ bindings + "torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable, + "torch._C._functorch._unwrap_batched": TorchInGraphFunctionVariable, + "torch._C._functorch.current_level": TorchInGraphFunctionVariable, + "torch._C._functorch.maybe_current_level": TorchInGraphFunctionVariable, + "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable, + "torch._C._functorch.peek_interpreter_stack": TorchInGraphFunctionVariable, + "torch._C._functorch.unwrap_if_dead": TorchInGraphFunctionVariable, + # everything else + "torch._functorch.pyfunctorch.coerce_cinterpreter": TorchInGraphFunctionVariable, + "torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable, + "torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable, + "torch._constrain_as_size": UserFunctionVariable, + "torch._tensor._convert": UserFunctionVariable, + "torch.jit._unwrap_optional": UserFunctionVariable, + "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable, + "torch._dynamo.dont_skip_tracing": UserFunctionVariable, + "torch._dynamo.mark_static": UserFunctionVariable, + "torch._dynamo.nonstrict_trace": UserFunctionVariable, + "torch._dynamo.patch_dynamo_config": UserFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.statically_known_false": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.sym_and": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.sym_or": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_scalar": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.has_static_value": TorchInGraphFunctionVariable, + "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, + "torch.utils.hooks.BackwardHook": 
TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, + "torch.sparse_bsc_tensor": SkipFunctionVariable, + "torch.sparse_bsr_tensor": SkipFunctionVariable, + "torch.sparse_csc_tensor": SkipFunctionVariable, + "torch.sparse_csr_tensor": SkipFunctionVariable, + "torch.sparse_compressed_tensor": SkipFunctionVariable, + "torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable, + "torch.xpu.get_rng_state": SkipFunctionVariable, + "torch.xpu.set_rng_state": SkipFunctionVariable, + # avoid skipping user defined modules in distributed unit tests + "torch/testing/_internal/common_fsdp.py#forward": UserFunctionVariable, + f"torch/testing/_internal/common_fsdp.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch/testing/_internal/distributed/_tensor/common_dtensor.py#forward": UserFunctionVariable, + f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, + f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, +} + + +# In graph functions (including constant folding) that are C bindings +torch_c_binding_in_graph_functions = dict.fromkeys( + [ + "math.acos", + "math.acosh", + "math.asin", + "math.asinh", + "math.atan", + "math.atan2", + "math.atanh", + "math.ceil", + "math.comb", + "math.copysign", + "math.cos", + "math.cosh", + "math.degrees", + "math.dist", + "math.erf", + "math.erfc", + "math.exp", + "math.expm1", + "math.fabs", + "math.factorial", + "math.floor", + "math.fmod", + "math.frexp", + "math.fsum", + "math.gamma", + "math.gcd", + "math.hypot", + "math.isclose", + "math.isfinite", + "math.isinf", + "math.isnan", + "math.isqrt", + "math.lcm", + "math.ldexp", + "math.lgamma", + "math.log", + "math.log10", + "math.log1p", + "math.log2", + "math.modf", + "math.nextafter", + "math.perm", + "math.pow", + "math.prod", + "math.radians", + "math.remainder", + "math.sin", + "math.sinh", + "math.tan", + "math.tanh", + "math.trunc", + "math.ulp", + "torch._adaptive_avg_pool2d", + "torch._adaptive_avg_pool3d", + "torch._add_batch_dim", + "torch._add_relu_", + "torch._add_relu", + "torch._addmm_activation", + "torch._aminmax", + "torch._amp_foreach_non_finite_check_and_unscale_", + "torch._amp_update_scale_", + "torch._assert_async", + "torch._assert_tensor_metadata", + "torch._batch_norm_impl_index", + "torch._C._accelerator_getAccelerator", + "torch._C._accelerator_getDeviceIndex", + "torch._C._accelerator_getStream", + "torch._C._accelerator_setStream", + "torch._C._accelerator_synchronizeDevice", + "torch._C._activate_gpu_trace", + "torch._C._add_cached_tensor", + "torch._C._add_docstr", + "torch._C._are_functorch_transforms_active", + "torch._C._autograd_init", + "torch._C._awaitable_nowait", + "torch._C._awaitable_wait", + "torch._C._awaitable", + "torch._C._backport_for_mobile_from_buffer_to_buffer", + "torch._C._backport_for_mobile_from_buffer", + "torch._C._backport_for_mobile_to_buffer", + "torch._C._backport_for_mobile", + "torch._C._broadcast_coalesced", + "torch._C._broadcast_out", + "torch._C._broadcast", + "torch._C._c10d_init", + "torch._C._calculate_package_version_based_on_upgraders", + "torch._C._can_use_flash_attention", + "torch._C._can_use_mem_efficient_attention", + "torch._C._can_use_cudnn_attention", + "torch._C._check_onnx_proto", + "torch._C._check_sparse_tensor_invariants", + "torch._C._collect_all", + 
"torch._C._commit_update", + "torch._C._compile_graph_to_code_table", + "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", + "torch._C._construct_storage_from_data_pointer", + "torch._C._conv_determine_backend_memory_format", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", + "torch._C._cpu._is_amx_fp16_supported", + "torch._C._cpu._init_amx", + "torch._C._crash_if_aten_asan", + "torch._C._crash_if_csrc_asan", + "torch._C._crash_if_csrc_ubsan", + "torch._C._crash_if_debug_asserts_fail", + "torch._C._crash_if_vptr_ubsan", + "torch._C._create_function_from_graph", + "torch._C._create_function_from_trace_with_dict", + "torch._C._create_function_from_trace", + "torch._C._create_graph_by_tracing", + "torch._C._create_module_with_type", + "torch._C._create_object_with_type", + "torch._C._cuda_attach_out_of_memory_observer", + "torch._C._cuda_beginAllocateCurrentStreamToPool", + "torch._C._cuda_canDeviceAccessPeer", + "torch._C._cuda_changeCurrentAllocator", + "torch._C._cuda_checkPoolLiveAllocations", + "torch._C._cuda_clearCublasWorkspaces", + "torch._C._cuda_cudaCachingAllocator_raw_alloc", + "torch._C._cuda_cudaCachingAllocator_raw_delete", + "torch._C._cuda_cudaCachingAllocator_set_allocator_settings", + "torch._C._cuda_cudaHostAllocator", + "torch._C._cuda_customAllocator", + "torch._C._cuda_emptyCache", + "torch._C._cuda_endAllocateToPool", + "torch._C._cuda_exchangeDevice", + "torch._C._cuda_get_conv_benchmark_empty_cache", + "torch._C._cuda_get_cudnn_benchmark_limit", + "torch._C._cuda_get_sync_debug_mode", + "torch._C._cuda_getAllocator", + "torch._C._cuda_getAllocatorBackend", + "torch._C._cuda_getArchFlags", + "torch._C._cuda_getCheckpointState", + "torch._C._cuda_getCompiledVersion", + "torch._C._cuda_getCurrentBlasHandle", + "torch._C._cuda_getCurrentRawStream", + "torch._C._cuda_getCurrentStream", + "torch._C._cuda_getDefaultStream", + "torch._C._cuda_getDevice", + "torch._C._cuda_getDeviceCount", + "torch._C._cuda_hasPrimaryContext", + "torch._C._cuda_hostMemoryStats", + "torch._C._cuda_init", + "torch._C._cuda_ipc_collect", + "torch._C._cuda_isCurrentStreamCapturing", + "torch._C._cuda_isHistoryEnabled", + "torch._C._cuda_isInBadFork", + "torch._C._cuda_jiterator_compile_and_launch_kernel", + "torch._C._cuda_lock_mutex", + "torch._C._cuda_maybeExchangeDevice", + "torch._C._cuda_memorySnapshot", + "torch._C._cuda_memoryStats", + "torch._C._cuda_record_memory_history_legacy", + "torch._C._cuda_record_memory_history", + "torch._C._cuda_releasePool", + "torch._C._cuda_resetAccumulatedHostMemoryStats", + "torch._C._cuda_resetAccumulatedMemoryStats", + "torch._C._cuda_resetPeakHostMemoryStats", + "torch._C._cuda_resetPeakMemoryStats", + "torch._C._cuda_set_cudnn_benchmark_limit", + "torch._C._cuda_set_sync_debug_mode", + "torch._C._cuda_setCheckpointPoolState", + "torch._C._cuda_setDevice", + "torch._C._cuda_setMemoryFraction", + "torch._C._cuda_setStream", + "torch._C._cuda_sleep", + "torch._C._cuda_synchronize", + "torch._C._cuda_unlock_mutex", + "torch._C._cudnn_set_conv_benchmark_empty_cache", + "torch._C._cudnn.getCompileVersion", + "torch._C._cudnn.getRuntimeVersion", + "torch._C._cudnn.getVersionInt", + "torch._C._current_autograd_node", + "torch._C._current_graph_task_execution_order", + "torch._C._current_graph_task_id", + "torch._C._cxx_flags", + "torch._C._debug_get_fusion_group_inlining", + 
"torch._C._debug_only_are_vmap_fallback_warnings_enabled", + "torch._C._debug_only_display_vmap_fallback_warnings", + "torch._C._debug_set_autodiff_subgraph_inlining", + "torch._C._debug_set_fusion_group_inlining", + "torch._C._demangle", + "torch._C._disabled_torch_dispatch_impl", + "torch._C._dispatch_call_boxed", + "torch._C._dispatch_check_all_invariants", + "torch._C._dispatch_check_invariants", + "torch._C._dispatch_dump_table", + "torch._C._dispatch_dump", + "torch._C._dispatch_find_dangling_impls", + "torch._C._dispatch_find_schema_or_throw", + "torch._C._dispatch_get_all_op_names", + "torch._C._dispatch_get_backend_keyset_from_autograd", + "torch._C._dispatch_get_registrations_for_dispatch_key", + "torch._C._dispatch_has_backend_fallback", + "torch._C._dispatch_has_computed_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel_for_any_dispatch_key", + "torch._C._dispatch_has_kernel_for_dispatch_key", + "torch._C._dispatch_has_kernel", + "torch._C._dispatch_is_alias_key", + "torch._C._dispatch_is_included_in_alias", + "torch._C._dispatch_is_main_interpreter", + "torch._C._dispatch_isTensorSubclassLike", + "torch._C._dispatch_key_for_device", + "torch._C._dispatch_key_name", + "torch._C._dispatch_key_parse", + "torch._C._dispatch_key_set", + "torch._C._dispatch_keys", + "torch._C._dispatch_keyset_full_after", + "torch._C._dispatch_keyset_full", + "torch._C._dispatch_keyset_to_string", + "torch._C._dispatch_library", + "torch._C._dispatch_num_backends", + "torch._C._dispatch_print_registrations_for_dispatch_key", + "torch._C._dispatch_pystub", + "torch._C._dispatch_set_report_error_callback", + "torch._C._dispatch_tls_is_dispatch_key_excluded", + "torch._C._dispatch_tls_is_dispatch_key_included", + "torch._C._dispatch_tls_local_exclude_set", + "torch._C._dispatch_tls_local_include_set", + "torch._C._dispatch_tls_set_dispatch_key_excluded", + "torch._C._dispatch_tls_set_dispatch_key_included", + "torch._C._dist_autograd_init", + "torch._C._dump_local_tls_set", + "torch._C._dump_upgraders_map", + "torch._C._enable_mobile_interface_call_export", + "torch._C._enter_dual_level", + "torch._C._error_if_any_worker_fails", + "torch._C._exit_dual_level", + "torch._C._export_operator_list", + "torch._C._export_opnames", + "torch._C._faulty_agent_init", + "torch._C._fft.fft_fft", + "torch._C._fft.fft_fft2", + "torch._C._fft.fft_fftfreq", + "torch._C._fft.fft_fftn", + "torch._C._fft.fft_fftshift", + "torch._C._fft.fft_hfft", + "torch._C._fft.fft_hfft2", + "torch._C._fft.fft_hfftn", + "torch._C._fft.fft_ifft", + "torch._C._fft.fft_ifft2", + "torch._C._fft.fft_ifftn", + "torch._C._fft.fft_ifftshift", + "torch._C._fft.fft_ihfft", + "torch._C._fft.fft_ihfft2", + "torch._C._fft.fft_ihfftn", + "torch._C._fft.fft_irfft", + "torch._C._fft.fft_irfft2", + "torch._C._fft.fft_irfftn", + "torch._C._fft.fft_rfft", + "torch._C._fft.fft_rfft2", + "torch._C._fft.fft_rfftfreq", + "torch._C._fft.fft_rfftn", + "torch._C._free_And_Remove_DeleterFn", + "torch._C._freeze_module", + "torch._C._from_dlpack", + "torch._C._functionality_to_backend_keys", + "torch._C._functionalization_reapply_views_tls", + "torch._C._fuse_to_static_module", + "torch._C._gather_out", + "torch._C._gather", + "torch._C._generate_upgraders_graph", + "torch._C._get_autograd_fallback_mode", + "torch._C._get_backcompat_broadcast_warn", + "torch._C._get_backcompat_keepdim_warn", + "torch._C._get_blas_preferred_backend", + "torch._C._get_caught_jit_exception_class_name", + "torch._C._get_caught_jit_exception_original_msg", + 
"torch._C._get_constant_bool_symnode", + "torch._C._get_cpp_backtrace", + "torch._C._get_cpu_capability", + "torch._C._get_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._get_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._get_cublas_allow_tf32", + "torch._C._get_cudnn_allow_tf32", + "torch._C._get_cudnn_benchmark", + "torch._C._get_cudnn_deterministic", + "torch._C._get_cudnn_enabled", + "torch._C._get_custom_class_python_wrapper", + "torch._C._get_default_device", + "torch._C._get_deterministic_algorithms_warn_only", + "torch._C._get_deterministic_algorithms", + "torch._C._get_deterministic_fill_uninitialized_memory", + "torch._C._get_dispatch_mode", + "torch._C._get_dispatch_stack_at", + "torch._C._get_file_format", + "torch._C._get_flash_sdp_enabled", + "torch._C._get_float32_matmul_precision", + "torch._C._get_function_stack_at", + "torch._C._get_graph_executor_optimize", + "torch._C._get_linalg_preferred_backend", + "torch._C._get_rocm_fa_preferred_backend", + "torch._C._get_math_sdp_enabled", + "torch._C._get_math_sdp_allow_fp16_bf16_reduction", + "torch._C._get_max_operator_version", + "torch._C._get_mem_efficient_sdp_enabled", + "torch._C._get_mkldnn_enabled", + "torch._C._get_cudnn_sdp_enabled", + "torch._C._set_sdp_use_cudnn", + "torch._C._get_mobile_model_contained_types_from_buffer", + "torch._C._get_mobile_model_contained_types", + "torch._C._get_model_bytecode_version_from_buffer", + "torch._C._get_model_bytecode_version", + "torch._C._get_model_extra_files_from_buffer", + "torch._C._get_model_extra_files", + "torch._C._get_model_ops_and_info_from_buffer", + "torch._C._get_model_ops_and_info", + "torch._C._get_module_info_from_flatbuffer", + "torch._C._get_nnpack_enabled", + "torch._C._get_obj_in_tls", + "torch._C._get_operation_overload", + "torch._C._get_operator_version_map", + "torch._C._get_privateuse1_backend_name", + "torch._C._get_qengine", + "torch._C._get_schema", + "torch._C._get_sm_carveout_experimental", + "torch._C._get_nested_int", + "torch._C._get_tensor_metadata", + "torch._C._get_tracing_state", + "torch._C._get_upgrader_ranges", + "torch._C._get_upgraders_entry_map", + "torch._C._get_upgraders_map_size", + "torch._C._get_value_trace", + "torch._C._get_version_calculator_flag", + "torch._C._get_warnAlways", + "torch._C._graph_pool_handle", + "torch._C._group_tensors_by_device_and_dtype", + "torch._C._hack_do_not_use_clone_module_with_class", + "torch._C._has_distributed", + "torch._C._has_Standard_Deleter", + "torch._C._has_storage", + "torch._C._has_tensorexpr_cpp_tests", + "torch._C._run_tensorexpr_cpp_tests", + "torch._C._has_torch_function_unary", + "torch._C._has_torch_function_variadic", + "torch._C._has_torch_function", + "torch._C._import_ir_module_from_package", + "torch._C._increment_version", + "torch._C._infer_size", + "torch._C._init_names", + "torch._C._initExtension", + "torch._C._is_alias_of", + "torch._C._is_any_autocast_enabled", + "torch._C._is_cached_tensor", + "torch._C._is_flash_attention_available", + "torch._C._is_fwd_grad_enabled", + "torch._C._is_key_in_tls", + "torch._C._is_multithreading_enabled", + "torch._C._is_torch_function_enabled", + "torch._C._is_torch_function_mode_enabled", + "torch._C._is_torch_function_all_disabled", + "torch._C._is_tracing", + "torch._C._is_view_replay_enabled", + "torch._C._is_xnnpack_enabled", + "torch._C._itt.is_available", + "torch._C._itt.mark", + "torch._C._itt.rangePop", + "torch._C._itt.rangePush", + "torch._C._ivalue_debug_python_object", + 
"torch._C._ivalue_tags_match", + "torch._C._jit_assert_is_instance", + "torch._C._jit_can_fuse_on_cpu_legacy", + "torch._C._jit_can_fuse_on_cpu", + "torch._C._jit_can_fuse_on_gpu", + "torch._C._jit_cat_wo_conditionals", + "torch._C._jit_check_alias_annotation", + "torch._C._jit_clear_class_registry", + "torch._C._jit_debug_fuser_num_cached_kernel_specs", + "torch._C._jit_debug_module_iterators", + "torch._C._jit_decay_packed_param_input_types", + "torch._C._jit_decomposition_graph_for_node", + "torch._C._jit_differentiate", + "torch._C._jit_erase_non_input_shape_information", + "torch._C._jit_flatten", + "torch._C._jit_fuser_get_fused_kernel_code", + "torch._C._jit_get_all_schemas", + "torch._C._jit_get_custom_class_schemas", + "torch._C._jit_get_emit_hooks", + "torch._C._jit_get_inline_everything_mode", + "torch._C._jit_get_logging_option", + "torch._C._jit_get_num_profiled_runs", + "torch._C._jit_get_operation", + "torch._C._jit_get_schemas_for_operator", + "torch._C._jit_get_te_cuda_pointwise_block_count", + "torch._C._jit_get_te_cuda_pointwise_block_size", + "torch._C._jit_get_te_cuda_pointwise_loop_levels", + "torch._C._jit_get_te_generate_block_code", + "torch._C._jit_get_te_must_use_llvm_cpu", + "torch._C._jit_get_tracer_state_warn", + "torch._C._jit_has_cpp_tests", + "torch._C._jit_init", + "torch._C._jit_interpret_graph", + "torch._C._jit_is_onnx_log_enabled", + "torch._C._jit_is_script_object", + "torch._C._jit_llga_enabled", + "torch._C._jit_nvfuser_can_be_enabled", + "torch._C._jit_nvfuser_clear_comparison_callback", + "torch._C._jit_nvfuser_enabled", + "torch._C._jit_nvfuser_horizontal_mode", + "torch._C._jit_nvfuser_set_comparison_callback", + "torch._C._jit_nvfuser_single_node_mode", + "torch._C._jit_object_is_non_holding", + "torch._C._jit_onnx_convert_pattern_from_subblock", + "torch._C._jit_onnx_create_full_scope_name", + "torch._C._jit_onnx_list_model_parameters", + "torch._C._jit_onnx_log", + "torch._C._jit_opt_conditionals", + "torch._C._jit_override_can_fuse_on_cpu_legacy", + "torch._C._jit_override_can_fuse_on_cpu", + "torch._C._jit_override_can_fuse_on_gpu", + "torch._C._jit_pass_autocast", + "torch._C._jit_pass_batch_mm", + "torch._C._jit_pass_canonicalize_graph_fuser_ops", + "torch._C._jit_pass_canonicalize", + "torch._C._jit_pass_complete_shape_analysis", + "torch._C._jit_pass_concat_frozen_linear", + "torch._C._jit_pass_constant_loop_unrolling", + "torch._C._jit_pass_constant_pooling", + "torch._C._jit_pass_constant_propagation_immutable_types", + "torch._C._jit_pass_constant_propagation", + "torch._C._jit_pass_convert_frozen_ops_to_mkldnn", + "torch._C._jit_pass_create_autodiff_subgraphs", + "torch._C._jit_pass_create_functional_graphs", + "torch._C._jit_pass_cse", + "torch._C._jit_pass_custom_pattern_based_rewrite_graph", + "torch._C._jit_pass_custom_pattern_based_rewrite", + "torch._C._jit_pass_dbr_quant_remove_redundant_aliases", + "torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects", + "torch._C._jit_pass_dce", + "torch._C._jit_pass_decompose_ops", + "torch._C._jit_pass_dedup_module_uses", + "torch._C._jit_pass_erase_number_types", + "torch._C._jit_pass_erase_shape_information", + "torch._C._jit_pass_filter_non_tensor_arguments", + "torch._C._jit_pass_fixup_onnx_controlflow_node", + "torch._C._jit_pass_fold_convbn", + "torch._C._jit_pass_fold_frozen_conv_add_or_sub", + "torch._C._jit_pass_fold_frozen_conv_bn", + "torch._C._jit_pass_fold_frozen_conv_mul_or_div", + "torch._C._jit_pass_fold_frozen_linear_bn", + 
"torch._C._jit_pass_fold_prepacking_ops", + "torch._C._jit_pass_functional_to_inplace_activation", + "torch._C._jit_pass_fuse_add_relu", + "torch._C._jit_pass_fuse_addmm", + "torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv", + "torch._C._jit_pass_fuse_frozen_conv_add_relu", + "torch._C._jit_pass_fuse_linear", + "torch._C._jit_pass_fuse_quantized_add_relu", + "torch._C._jit_pass_fuse_tensorexprs", + "torch._C._jit_pass_fuse", + "torch._C._jit_pass_inline_fork_wait", + "torch._C._jit_pass_inline_functional_graphs", + "torch._C._jit_pass_inline", + "torch._C._jit_pass_inplace_to_functional_activation", + "torch._C._jit_pass_insert_observer_method_for_ondevice_ptq", + "torch._C._jit_pass_insert_observers", + "torch._C._jit_pass_insert_prepack_unpack", + "torch._C._jit_pass_insert_prepacked_ops", + "torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq", + "torch._C._jit_pass_insert_quant_dequant", + "torch._C._jit_pass_integer_value_refinement", + "torch._C._jit_pass_lint", + "torch._C._jit_pass_loop_unrolling", + "torch._C._jit_pass_lower_all_tuples", + "torch._C._jit_pass_lower_graph", + "torch._C._jit_pass_metal_fold_prepacking_ops", + "torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_metal_insert_prepacked_ops", + "torch._C._jit_pass_metal_optimize_for_mobile", + "torch._C._jit_pass_onnx_assign_output_shape", + "torch._C._jit_pass_onnx_assign_scoped_names_for_node_and_value", + "torch._C._jit_pass_onnx_autograd_function_process", + "torch._C._jit_pass_onnx_block", + "torch._C._jit_pass_onnx_cast_all_constant_to_floating", + "torch._C._jit_pass_onnx_clear_scope_records", + "torch._C._jit_pass_onnx_constant_fold", + "torch._C._jit_pass_onnx_deduplicate_initializers", + "torch._C._jit_pass_onnx_eliminate_unused_items", + "torch._C._jit_pass_onnx_eval_peephole", + "torch._C._jit_pass_onnx_function_extraction", + "torch._C._jit_pass_onnx_function_substitution", + "torch._C._jit_pass_onnx_graph_shape_type_inference", + "torch._C._jit_pass_onnx_lint", + "torch._C._jit_pass_onnx_node_shape_type_inference", + "torch._C._jit_pass_onnx_peephole", + "torch._C._jit_pass_onnx_preprocess_caffe2", + "torch._C._jit_pass_onnx_preprocess", + "torch._C._jit_pass_onnx_quantization_insert_permutes", + "torch._C._jit_pass_onnx_remove_inplace_ops_for_onnx", + "torch._C._jit_pass_onnx_remove_print", + "torch._C._jit_pass_onnx_scalar_type_analysis", + "torch._C._jit_pass_onnx_set_dynamic_input_shape", + "torch._C._jit_pass_onnx_track_scope_attributes", + "torch._C._jit_pass_onnx_unpack_quantized_weights", + "torch._C._jit_pass_onnx", + "torch._C._jit_pass_optimize_for_inference", + "torch._C._jit_pass_optimize_for_mobile", + "torch._C._jit_pass_optimize_frozen_graph", + "torch._C._jit_pass_pattern_based_rewrite", + "torch._C._jit_pass_peephole_list_idioms", + "torch._C._jit_pass_peephole", + "torch._C._jit_pass_prepare_division_for_onnx", + "torch._C._jit_pass_propagate_device", + "torch._C._jit_pass_propagate_dtype", + "torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute", + "torch._C._jit_pass_propagate_shapes_on_graph", + "torch._C._jit_pass_quant_finalize_for_ondevice_ptq", + "torch._C._jit_pass_quant_finalize", + "torch._C._jit_pass_quant_fusion", + "torch._C._jit_pass_refine_integer_values", + "torch._C._jit_pass_refine_tuple_types", + "torch._C._jit_pass_remove_dropout", + "torch._C._jit_pass_remove_expands", + "torch._C._jit_pass_remove_inplace_ops", + "torch._C._jit_pass_remove_mutation", + "torch._C._jit_pass_replace_old_ops_with_upgraders", + 
"torch._C._jit_pass_replicate_dequantize", + "torch._C._jit_pass_run_decompositions", + "torch._C._jit_pass_specialize_autogradzero", + "torch._C._jit_pass_swap_functional_linear", + "torch._C._jit_pass_transform_conv1d_to_conv2d", + "torch._C._jit_pass_transpose_frozen_linear", + "torch._C._jit_pass_vulkan_fold_prepacking_ops", + "torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv", + "torch._C._jit_pass_vulkan_insert_prepacked_ops", + "torch._C._jit_pass_vulkan_optimize_for_mobile", + "torch._C._jit_register_decomposition_for_schema", + "torch._C._jit_register_shape_compute_graph_for_node", + "torch._C._jit_resolve_packet", + "torch._C._jit_run_cpp_tests", + "torch._C._jit_script_class_compile", + "torch._C._jit_script_compile_overload", + "torch._C._jit_script_compile", + "torch._C._jit_script_interface_compile", + "torch._C._jit_set_autocast_mode", + "torch._C._jit_set_bailout_depth", + "torch._C._jit_set_emit_hooks", + "torch._C._jit_set_fusion_strategy", + "torch._C._jit_set_inline_everything_mode", + "torch._C._jit_set_llga_enabled", + "torch._C._jit_set_logging_option", + "torch._C._jit_set_logging_stream", + "torch._C._jit_set_num_profiled_runs", + "torch._C._jit_set_nvfuser_enabled", + "torch._C._jit_set_nvfuser_guard_mode", + "torch._C._jit_set_nvfuser_horizontal_mode", + "torch._C._jit_set_nvfuser_single_node_mode", + "torch._C._jit_set_nvfuser_skip_node_kind", + "torch._C._jit_set_onnx_log_enabled", + "torch._C._jit_set_onnx_log_output_stream", + "torch._C._jit_set_profiling_executor", + "torch._C._jit_set_profiling_mode", + "torch._C._jit_set_symbolic_shapes_test_mode", + "torch._C._jit_set_te_cuda_pointwise_block_count", + "torch._C._jit_set_te_cuda_pointwise_block_size", + "torch._C._jit_set_te_cuda_pointwise_loop_levels", + "torch._C._jit_set_te_generate_block_code", + "torch._C._jit_set_te_must_use_llvm_cpu", + "torch._C._jit_set_texpr_dynamic_shape_enabled", + "torch._C._jit_set_texpr_fuser_enabled", + "torch._C._jit_set_texpr_reductions_enabled", + "torch._C._jit_set_tracer_state_warn", + "torch._C._jit_set_utf8_decoding_ignore", + "torch._C._jit_shape_compute_graph_for_node", + "torch._C._jit_symbolic_shapes_test_mode_enabled", + "torch._C._jit_texpr_dynamic_shape_enabled", + "torch._C._jit_texpr_fallback_allowed", + "torch._C._jit_texpr_fuser_enabled", + "torch._C._jit_texpr_reductions_enabled", + "torch._C._jit_texpr_set_fallback_allowed", + "torch._C._jit_to_backend_selective", + "torch._C._jit_to_backend", + "torch._C._jit_to_static_module", + "torch._C._jit_trace_graph", + "torch._C._jit_trace_module", + "torch._C._jit_tree_views.FalseLiteral", + "torch._C._jit_tree_views.NoneLiteral", + "torch._C._jit_tree_views.TrueLiteral", + "torch._C._jit_try_infer_type", + "torch._C._jit_unflatten", + "torch._C._last_executed_optimized_graph", + "torch._C._len_torch_dispatch_stack", + "torch._C._len_torch_function_stack", + "torch._C._linalg._linalg_eigvals", + "torch._C._linalg.linalg_cholesky_ex", + "torch._C._linalg.linalg_cholesky", + "torch._C._linalg.linalg_cond", + "torch._C._linalg.linalg_cross", + "torch._C._linalg.linalg_det", + "torch._C._linalg.linalg_diagonal", + "torch._C._linalg.linalg_eig", + "torch._C._linalg.linalg_eigh", + "torch._C._linalg.linalg_eigvals", + "torch._C._linalg.linalg_eigvalsh", + "torch._C._linalg.linalg_householder_product", + "torch._C._linalg.linalg_inv_ex", + "torch._C._linalg.linalg_inv", + "torch._C._linalg.linalg_ldl_factor_ex", + "torch._C._linalg.linalg_ldl_factor", + "torch._C._linalg.linalg_ldl_solve", + 
"torch._C._linalg.linalg_lstsq", + "torch._C._linalg.linalg_lu_factor_ex", + "torch._C._linalg.linalg_lu_factor", + "torch._C._linalg.linalg_lu_solve", + "torch._C._linalg.linalg_lu", + "torch._C._linalg.linalg_matmul", + "torch._C._linalg.linalg_matrix_exp", + "torch._C._linalg.linalg_matrix_norm", + "torch._C._linalg.linalg_matrix_power", + "torch._C._linalg.linalg_matrix_rank", + "torch._C._linalg.linalg_multi_dot", + "torch._C._linalg.linalg_norm", + "torch._C._linalg.linalg_pinv", + "torch._C._linalg.linalg_qr", + "torch._C._linalg.linalg_slogdet", + "torch._C._linalg.linalg_solve_ex", + "torch._C._linalg.linalg_solve_triangular", + "torch._C._linalg.linalg_solve", + "torch._C._linalg.linalg_svd", + "torch._C._linalg.linalg_svdvals", + "torch._C._linalg.linalg_tensorinv", + "torch._C._linalg.linalg_tensorsolve", + "torch._C._linalg.linalg_vander", + "torch._C._linalg.linalg_vecdot", + "torch._C._linalg.linalg_vector_norm", + "torch._C._llvm_enabled", + "torch._C._load_for_lite_interpreter_from_buffer", + "torch._C._load_for_lite_interpreter", + "torch._C._load_jit_module_from_bytes", + "torch._C._load_jit_module_from_file", + "torch._C._load_mobile_module_from_bytes", + "torch._C._load_mobile_module_from_file", + "torch._C._log_api_usage_metadata", + "torch._C._log_api_usage_once", + "torch._C._logging_set_logger", + "torch._C._meta_in_tls_dispatch_include", + "torch._C._mps_acquireEvent", + "torch._C._mps_currentAllocatedMemory", + "torch._C._mps_deviceSynchronize", + "torch._C._mps_driverAllocatedMemory", + "torch._C._mps_recommendedMaxMemory", + "torch._C._mps_elapsedTimeOfEvents", + "torch._C._mps_emptyCache", + "torch._C._mps_get_default_generator", + "torch._C._mps_is_available", + "torch._C._mps_is_in_bad_fork", + "torch._C._mps_is_on_macos_13_or_newer", + "torch._C._mps_profilerStartTrace", + "torch._C._mps_profilerStopTrace", + "torch._C._mps_queryEvent", + "torch._C._mps_recordEvent", + "torch._C._mps_releaseEvent", + "torch._C._mps_setMemoryFraction", + "torch._C._mps_synchronizeEvent", + "torch._C._mps_waitForEvent", + "torch._C._multiprocessing_init", + "torch._C._nccl_all_gather", + "torch._C._nccl_all_reduce", + "torch._C._nccl_broadcast", + "torch._C._nccl_init_rank", + "torch._C._nccl_reduce_scatter", + "torch._C._nccl_reduce", + "torch._C._nccl_unique_id", + "torch._C._nccl_version_suffix", + "torch._C._nccl_version", + "torch._C._nested.nested_tensor", + "torch._C._nested.nested_to_padded_tensor", + "torch._C._new_symbolic_shape_symbol", + "torch._C._nn_module_to_mobile", + "torch._C._nn._conv_depthwise2d", + "torch._C._nn._pad_circular", + "torch._C._nn._pad_enum", + "torch._C._nn._parse_to", + "torch._C._nn._test_ambiguous_defaults", + "torch._C._nn._test_optional_filled_intlist", + "torch._C._nn._test_optional_floatlist", + "torch._C._nn._test_optional_intlist", + "torch._C._nn._test_string_default", + "torch._C._nn._test_warn_in_autograd", + "torch._C._nn._upsample_bicubic2d_aa", + "torch._C._nn._upsample_bilinear2d_aa", + "torch._C._nn._upsample_nearest_exact1d", + "torch._C._nn._upsample_nearest_exact2d", + "torch._C._nn._upsample_nearest_exact3d", + "torch._C._nn.adaptive_avg_pool2d", + "torch._C._nn.adaptive_avg_pool3d", + "torch._C._nn.adaptive_max_pool2d", + "torch._C._nn.adaptive_max_pool3d", + "torch._C._nn.avg_pool2d", + "torch._C._nn.avg_pool3d", + "torch._C._nn.binary_cross_entropy", + "torch._C._nn.col2im", + "torch._C._nn.conv_depthwise3d", + "torch._C._nn.cross_entropy_loss", + "torch._C._nn.elu_", + "torch._C._nn.elu", + 
"torch._C._nn.flatten_dense_tensors", + "torch._C._nn.fractional_max_pool2d", + "torch._C._nn.fractional_max_pool3d", + "torch._C._nn.gelu_", + "torch._C._nn.gelu", + "torch._C._nn.glu", + "torch._C._nn.hardsigmoid_", + "torch._C._nn.hardsigmoid", + "torch._C._nn.hardswish_", + "torch._C._nn.hardswish", + "torch._C._nn.hardtanh_", + "torch._C._nn.hardtanh", + "torch._C._nn.huber_loss", + "torch._C._nn.im2col", + "torch._C._nn.l1_loss", + "torch._C._nn.leaky_relu_", + "torch._C._nn.leaky_relu", + "torch._C._nn.linear", + "torch._C._nn.log_sigmoid", + "torch._C._nn.max_pool2d_with_indices", + "torch._C._nn.max_pool3d_with_indices", + "torch._C._nn.max_unpool2d", + "torch._C._nn.max_unpool3d", + "torch._C._nn.mish_", + "torch._C._nn.mish", + "torch._C._nn.mkldnn_linear", + "torch._C._nn.mkldnn_reorder_conv2d_weight", + "torch._C._nn.mkldnn_reorder_conv3d_weight", + "torch._C._nn.mse_loss", + "torch._C._nn.multi_margin_loss", + "torch._C._nn.multilabel_margin_loss", + "torch._C._nn.nll_loss_nd", + "torch._C._nn.nll_loss", + "torch._C._nn.nll_loss2d", + "torch._C._nn.one_hot", + "torch._C._nn.pad_sequence", + "torch._C._nn.pad", + "torch._C._nn.reflection_pad1d", + "torch._C._nn.reflection_pad2d", + "torch._C._nn.reflection_pad3d", + "torch._C._nn.relu6_", + "torch._C._nn.relu6", + "torch._C._nn.replication_pad1d", + "torch._C._nn.replication_pad2d", + "torch._C._nn.replication_pad3d", + "torch._C._nn.rrelu_with_noise_", + "torch._C._nn.rrelu_with_noise", + "torch._C._nn.scaled_dot_product_attention", + "torch._C._nn.silu_", + "torch._C._nn.silu", + "torch._C._nn.slow_conv_dilated2d", + "torch._C._nn.slow_conv_dilated3d", + "torch._C._nn.slow_conv_transpose2d", + "torch._C._nn.slow_conv_transpose3d", + "torch._C._nn.slow_conv3d", + "torch._C._nn.smooth_l1_loss", + "torch._C._nn.soft_margin_loss", + "torch._C._nn.softplus", + "torch._C._nn.softshrink", + "torch._C._nn.thnn_conv2d", + "torch._C._nn.unflatten_dense_tensors", + "torch._C._nn.upsample_bicubic2d", + "torch._C._nn.upsample_bilinear2d", + "torch._C._nn.upsample_linear1d", + "torch._C._nn.upsample_nearest1d", + "torch._C._nn.upsample_nearest2d", + "torch._C._nn.upsample_nearest3d", + "torch._C._nn.upsample_trilinear3d", + "torch._C._non_sym_sizes", + "torch._C._overlaps", + "torch._C._parallel_info", + "torch._C._parse_dispatch_key", + "torch._C._parse_source_def", + "torch._C._pop_torch_dispatch_stack", + "torch._C._pop_torch_function_stack", + "torch._C._propagate_and_assign_input_shapes", + "torch._C._propagate_shapes", + "torch._C._propagate_xla_data", + "torch._C._push_on_torch_dispatch_stack", + "torch._C._push_on_torch_function_stack", + "torch._C._quantize_ondevice_ptq_dynamic", + "torch._C._register_py_class_for_device", + "torch._C._remove_cached_tensor", + "torch._C._remove_worker_pids", + "torch._C._rename_privateuse1_backend", + "torch._C._replace_", + "torch._C._replace_overloaded_method_decl", + "torch._C._resolve_type_from_object", + "torch._C._resolve_type", + "torch._C._rocm_is_backward_pass", + "torch._C._rpc_init", + "torch._C._run_emit_module_hook", + "torch._C._save_jit_module_to_bytes", + "torch._C._save_jit_module", + "torch._C._save_mobile_module_to_bytes", + "torch._C._save_mobile_module", + "torch._C._save_parameters", + "torch._C._scatter_out", + "torch._C._scatter", + "torch._C._select_conv_backend", + "torch._C._select_batch_norm_backend", + "torch._C._set_autograd_fallback_mode", + "torch._C._set_backcompat_broadcast_warn", + "torch._C._set_backcompat_keepdim_warn", + 
"torch._C._set_blas_preferred_backend", + "torch._C._set_cached_tensors_enabled", + "torch._C._set_check_sparse_tensor_invariants", + "torch._C._set_conj", + "torch._C._set_cublas_allow_bf16_reduced_precision_reduction", + "torch._C._set_cublas_allow_fp16_reduced_precision_reduction", + "torch._C._set_cublas_allow_tf32", + "torch._C._set_cudnn_allow_tf32", + "torch._C._set_cudnn_benchmark", + "torch._C._set_cudnn_deterministic", + "torch._C._set_cudnn_enabled", + "torch._C._set_default_dtype", + "torch._C._set_default_mobile_cpu_allocator", + "torch._C._set_default_tensor_type", + "torch._C._set_deterministic_algorithms", + "torch._C._set_deterministic_fill_uninitialized_memory", + "torch._C._set_dispatch_mode", + "torch._C._set_float32_matmul_precision", + "torch._C._set_fwd_grad_enabled", + "torch._C._set_grad_enabled", + "torch._C._set_graph_executor_optimize", + "torch._C._set_linalg_preferred_backend", + "torch._C._set_rocm_fa_preferred_backend", + "torch._C._set_meta_in_tls_dispatch_include", + "torch._C._set_mkldnn_enabled", + "torch._C._set_multithreading_enabled", + "torch._C._set_neg", + "torch._C._set_nnpack_enabled", + "torch._C._set_print_stack_traces_on_fatal_signal", + "torch._C._set_qengine", + "torch._C._set_sdp_use_flash", + "torch._C._set_sdp_use_math", + "torch._C._set_math_sdp_allow_fp16_bf16_reduction", + "torch._C._set_sdp_use_mem_efficient", + "torch._C._set_should_use_format_with_string_table", + "torch._C._set_sm_carveout_experimental", + "torch._C._set_storage_access_error_msg", + "torch._C._set_tensor_metadata", + "torch._C._set_tracing_state", + "torch._C._set_value_trace", + "torch._C._set_view_replay_enabled", + "torch._C._set_warnAlways", + "torch._C._set_worker_pids", + "torch._C._set_worker_signal_handlers", + "torch._C._should_allow_numbers_as_tensors", + "torch._C._show_config", + "torch._C._sparse._sparse_addmm", + "torch._C._sparse._sparse_log_softmax", + "torch._C._sparse._sparse_mm_reduce_impl", + "torch._C._sparse._sparse_mm", + "torch._C._sparse._sparse_softmax", + "torch._C._sparse._spdiags", + "torch._C._sparse.sparse_sampled_addmm", + "torch._C._special.special_airy_ai", + "torch._C._special.special_bessel_j0", + "torch._C._special.special_bessel_j1", + "torch._C._special.special_bessel_y0", + "torch._C._special.special_bessel_y1", + "torch._C._special.special_chebyshev_polynomial_t", + "torch._C._special.special_chebyshev_polynomial_u", + "torch._C._special.special_chebyshev_polynomial_v", + "torch._C._special.special_chebyshev_polynomial_w", + "torch._C._special.special_digamma", + "torch._C._special.special_entr", + "torch._C._special.special_erf", + "torch._C._special.special_erfc", + "torch._C._special.special_erfcx", + "torch._C._special.special_erfinv", + "torch._C._special.special_exp2", + "torch._C._special.special_expit", + "torch._C._special.special_expm1", + "torch._C._special.special_gammainc", + "torch._C._special.special_gammaincc", + "torch._C._special.special_gammaln", + "torch._C._special.special_hermite_polynomial_h", + "torch._C._special.special_hermite_polynomial_he", + "torch._C._special.special_i0", + "torch._C._special.special_i0e", + "torch._C._special.special_i1", + "torch._C._special.special_i1e", + "torch._C._special.special_laguerre_polynomial_l", + "torch._C._special.special_legendre_polynomial_p", + "torch._C._special.special_log_ndtr", + "torch._C._special.special_log_softmax", + "torch._C._special.special_log1p", + "torch._C._special.special_logit", + "torch._C._special.special_logsumexp", + 
"torch._C._special.special_modified_bessel_i0", + "torch._C._special.special_modified_bessel_i1", + "torch._C._special.special_modified_bessel_k0", + "torch._C._special.special_modified_bessel_k1", + "torch._C._special.special_multigammaln", + "torch._C._special.special_ndtr", + "torch._C._special.special_ndtri", + "torch._C._special.special_polygamma", + "torch._C._special.special_psi", + "torch._C._special.special_round", + "torch._C._special.special_scaled_modified_bessel_k0", + "torch._C._special.special_scaled_modified_bessel_k1", + "torch._C._special.special_shifted_chebyshev_polynomial_t", + "torch._C._special.special_shifted_chebyshev_polynomial_u", + "torch._C._special.special_shifted_chebyshev_polynomial_v", + "torch._C._special.special_shifted_chebyshev_polynomial_w", + "torch._C._special.special_sinc", + "torch._C._special.special_softmax", + "torch._C._special.special_spherical_bessel_j0", + "torch._C._special.special_xlog1py", + "torch._C._special.special_xlogy", + "torch._C._special.special_zeta", + "torch._C._stash_obj_in_tls", + "torch._C._storage_id", + "torch._C._storage_Use_Count", + "torch._C._supported_qengines", + "torch._C._te.abs", + "torch._C._te.acos", + "torch._C._te.annotate_input_shapes", + "torch._C._te.asin", + "torch._C._te.atan", + "torch._C._te.atan2", + "torch._C._te.ceil", + "torch._C._te.Compute", + "torch._C._te.Compute2", + "torch._C._te.construct_codegen", + "torch._C._te.cos", + "torch._C._te.cosh", + "torch._C._te.erf", + "torch._C._te.erfc", + "torch._C._te.exp", + "torch._C._te.expm1", + "torch._C._te.fixup_missing_shape_info", + "torch._C._te.floor", + "torch._C._te.fmod", + "torch._C._te.frac", + "torch._C._te.ifThenElse", + "torch._C._te.is_graph_compilable", + "torch._C._te.isnan", + "torch._C._te.lgamma", + "torch._C._te.log", + "torch._C._te.log10", + "torch._C._te.log1p", + "torch._C._te.log2", + "torch._C._te.lower", + "torch._C._te.make_shapes_symbolic", + "torch._C._te.pow", + "torch._C._te.Reduce", + "torch._C._te.remainder", + "torch._C._te.remove_graph_output", + "torch._C._te.remove_unused_self_argument", + "torch._C._te.replace_list_output_with_tuple", + "torch._C._te.round", + "torch._C._te.rsqrt", + "torch._C._te.sigmoid", + "torch._C._te.simplify", + "torch._C._te.sin", + "torch._C._te.sinh", + "torch._C._te.sqrt", + "torch._C._te.tan", + "torch._C._te.tanh", + "torch._C._te.trim_graph", + "torch._C._te.trunc", + "torch._C._tensor_impl_raw_handle", + "torch._C._test_only_add_entry_to_op_version_map", + "torch._C._test_only_populate_upgraders", + "torch._C._test_only_remove_entry_to_op_version_map", + "torch._C._test_only_remove_upgraders", + "torch._C._to_functionality_key", + "torch._C._tracer_set_force_outplace", + "torch._C._tracer_set_get_unique_name_fn", + "torch._C._tracer_warn_use_python", + "torch._C._unset_default_mobile_cpu_allocator", + "torch._C._unset_dispatch_mode", + "torch._C._valgrind_supported_platform", + "torch._C._valgrind_toggle_and_dump_stats", + "torch._C._valgrind_toggle", + "torch._C._verbose.mkl_set_verbose", + "torch._C._verbose.mkldnn_set_verbose", + "torch._C._vmapmode_decrement_nesting", + "torch._C._vmapmode_increment_nesting", + "torch._C._warn_deprecation", + "torch._C._warn", + "torch._C._will_engine_execute_node", + "torch._C._wrap_tensor_impl", + "torch._C._xpu_emptyCache", + "torch._C._xpu_getArchFlags", + "torch._C._xpu_getCurrentStream", + "torch._C._xpu_getCurrentRawStream", + "torch._C._xpu_getDeviceCount", + "torch._C._xpu_getDevice", + "torch._C._xpu_getMemoryInfo", + 
"torch._C._xpu_getStreamFromExternal", + "torch._C._xpu_isInBadFork", + "torch._C._xpu_init", + "torch._C._xpu_memoryStats", + "torch._C._xpu_resetAccumulatedMemoryStats", + "torch._C._xpu_resetPeakMemoryStats", + "torch._C._xpu_setStream", + "torch._C._xpu_synchronize", + "torch._C.fork", + "torch._C.get_autocast_cpu_dtype", + "torch._C.get_autocast_dtype", + "torch._C.get_autocast_gpu_dtype", + "torch._C.get_autocast_ipu_dtype", + "torch._C.get_autocast_xla_dtype", + "torch._C.get_default_dtype", + "torch._C.get_num_interop_threads", + "torch._C.get_num_threads", + "torch._C.import_ir_module_from_buffer", + "torch._C.import_ir_module", + "torch._C.init_num_threads", + "torch._C.is_anomaly_check_nan_enabled", + "torch._C.is_anomaly_enabled", + "torch._C.is_autocast_cache_enabled", + "torch._C.is_autocast_cpu_enabled", + "torch._C.is_autocast_enabled", + "torch._C.is_autocast_ipu_enabled", + "torch._C.is_autocast_xla_enabled", + "torch._C.is_grad_enabled", + "torch._C.is_inference_mode_enabled", + "torch._C.merge_type_from_type_comment", + "torch._C.parse_ir", + "torch._C.parse_schema", + "torch._C.parse_type_comment", + "torch._C.read_vitals", + "torch._C.set_vital", + "torch._C.unify_type_list", + "torch._C.vitals_enabled", + "torch._C.wait", + "torch._cast_Byte", + "torch._cast_Char", + "torch._cast_Double", + "torch._cast_Float", + "torch._cast_Half", + "torch._cast_Int", + "torch._cast_Long", + "torch._cast_Short", + "torch._choose_qparams_per_tensor", + "torch._chunk_cat", + "torch._coalesce", + "torch._compute_linear_combination", + "torch._conj_copy", + "torch._conj_physical", + "torch._conj", + "torch._convert_indices_from_coo_to_csr", + "torch._convert_indices_from_csr_to_coo", + "torch._convert_weight_to_int4pack", + "torch._convert_weight_to_int4pack_for_cpu", + "torch._convolution_mode", + "torch._convolution", + "torch._copy_from_and_resize", + "torch._copy_from", + "torch._cslt_compress", + "torch._cslt_sparse_mm", + "torch._ctc_loss", + "torch._cudnn_ctc_loss", + "torch._cudnn_init_dropout_state", + "torch._cudnn_rnn_flatten_weight", + "torch._cudnn_rnn", + "torch._cufft_clear_plan_cache", + "torch._cufft_get_plan_cache_max_size", + "torch._cufft_get_plan_cache_size", + "torch._cufft_set_plan_cache_max_size", + "torch._cummax_helper", + "torch._cummin_helper", + "torch._debug_has_internal_overlap", + "torch._dim_arange", + "torch._dirichlet_grad", + "torch._disable_functionalization", + "torch._dyn_quant_matmul_4bit", + "torch._dyn_quant_pack_4bit_weight", + "torch._efficientzerotensor", + "torch._embedding_bag_forward_only", + "torch._embedding_bag", + "torch._empty_affine_quantized", + "torch._empty_per_channel_affine_quantized", + "torch._enable_functionalization", + "torch._euclidean_dist", + "torch._fake_quantize_learnable_per_channel_affine", + "torch._fake_quantize_learnable_per_tensor_affine", + "torch._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", + "torch._fft_c2c", + "torch._fft_c2r", + "torch._fft_r2c", + "torch._fill_mem_eff_dropout_mask_", + "torch._foobar", + "torch._foreach_abs_", + "torch._foreach_abs", + "torch._foreach_acos_", + "torch._foreach_acos", + "torch._foreach_add_", + "torch._foreach_add", + "torch._foreach_addcdiv_", + "torch._foreach_addcdiv", + "torch._foreach_addcmul_", + "torch._foreach_addcmul", + "torch._foreach_asin_", + "torch._foreach_asin", + "torch._foreach_atan_", + "torch._foreach_atan", + "torch._foreach_ceil_", + "torch._foreach_ceil", + "torch._foreach_clamp_max_", + "torch._foreach_clamp_max", + 
"torch._foreach_clamp_min_", + "torch._foreach_clamp_min", + "torch._foreach_copy_", + "torch._foreach_cos_", + "torch._foreach_cos", + "torch._foreach_cosh_", + "torch._foreach_cosh", + "torch._foreach_div_", + "torch._foreach_div", + "torch._foreach_erf_", + "torch._foreach_erf", + "torch._foreach_erfc_", + "torch._foreach_erfc", + "torch._foreach_exp_", + "torch._foreach_exp", + "torch._foreach_expm1_", + "torch._foreach_expm1", + "torch._foreach_floor_", + "torch._foreach_floor", + "torch._foreach_frac_", + "torch._foreach_frac", + "torch._foreach_lerp_", + "torch._foreach_lerp", + "torch._foreach_lgamma_", + "torch._foreach_lgamma", + "torch._foreach_log_", + "torch._foreach_log", + "torch._foreach_log10_", + "torch._foreach_log10", + "torch._foreach_log1p_", + "torch._foreach_log1p", + "torch._foreach_log2_", + "torch._foreach_log2", + "torch._foreach_maximum_", + "torch._foreach_maximum", + "torch._foreach_minimum_", + "torch._foreach_minimum", + "torch._foreach_mul_", + "torch._foreach_mul", + "torch._foreach_neg_", + "torch._foreach_neg", + "torch._foreach_norm", + "torch._foreach_pow_", + "torch._foreach_pow", + "torch._foreach_reciprocal_", + "torch._foreach_reciprocal", + "torch._foreach_round_", + "torch._foreach_round", + "torch._foreach_sigmoid_", + "torch._foreach_sigmoid", + "torch._foreach_rsqrt_", + "torch._foreach_rsqrt", + "torch._foreach_sign_", + "torch._foreach_sign", + "torch._foreach_sin_", + "torch._foreach_sin", + "torch._foreach_sinh_", + "torch._foreach_sinh", + "torch._foreach_sqrt_", + "torch._foreach_sqrt", + "torch._foreach_sub_", + "torch._foreach_sub", + "torch._foreach_tan_", + "torch._foreach_tan", + "torch._foreach_tanh_", + "torch._foreach_tanh", + "torch._foreach_trunc_", + "torch._foreach_trunc", + "torch._foreach_zero_", + "torch._freeze_functional_tensor", + "torch._from_functional_tensor", + "torch._functional_assert_async", + "torch._functional_sym_constrain_range_for_size", + "torch._functional_sym_constrain_range", + "torch._functionalize_are_all_mutations_hidden_from_autograd", + "torch._functionalize_commit_update", + "torch._functionalize_enable_reapply_views", + "torch._functionalize_has_data_mutation", + "torch._functionalize_has_metadata_mutation", + "torch._functionalize_is_multi_output_view", + "torch._functionalize_mark_mutation_hidden_from_autograd", + "torch._functionalize_replace", + "torch._functionalize_sync", + "torch._functionalize_was_storage_changed", + "torch._fused_adam_", + "torch._fused_adamw_", + "torch._fused_dropout", + "torch._fused_moving_avg_obs_fq_helper", + "torch._fused_sdp_choice", + "torch._fw_primal_copy", + "torch._grid_sampler_2d_cpu_fallback", + "torch._grouped_mm", + "torch._has_compatible_shallow_copy_type", + "torch._histogramdd_bin_edges", + "torch._histogramdd_from_bin_cts", + "torch._histogramdd_from_bin_tensors", + "torch._index_put_impl_", + "torch._indices_copy", + "torch._int_mm", + "torch._is_all_true", + "torch._is_any_true", + "torch._is_functional_tensor", + "torch._is_zerotensor", + "torch._linalg_check_errors", + "torch._linalg_det", + "torch._linalg_eigh", + "torch._linalg_eigvals", + "torch._linalg_slogdet", + "torch._linalg_solve_ex", + "torch._linalg_svd", + "torch._log_softmax_backward_data", + "torch._log_softmax", + "torch._logcumsumexp", + "torch._lstm_mps", + "torch._lu_with_info", + "torch._make_dep_token", + "torch._make_dual_copy", + "torch._make_dual", + "torch._make_per_channel_quantized_tensor", + "torch._make_per_tensor_quantized_tensor", + "torch._masked_scale", + 
"torch._masked_softmax", + "torch._mirror_autograd_meta_to", + "torch._mixed_dtypes_linear", + "torch._mkldnn_reshape", + "torch._mkldnn_transpose_", + "torch._mkldnn_transpose", + "torch._mps_convolution_transpose", + "torch._mps_convolution", + "torch._native_batch_norm_legit_no_training", + "torch._native_batch_norm_legit", + "torch._native_multi_head_attention", + "torch._neg_view_copy", + "torch._neg_view", + "torch._nested_from_padded_and_nested_example", + "torch._nested_from_padded_tensor", + "torch._nested_tensor_from_mask_left_aligned", + "torch._nested_tensor_from_tensor_list", + "torch._nested_tensor_softmax_with_shape", + "torch._nested_view_from_buffer_copy", + "torch._nested_view_from_buffer", + "torch._nnpack_available", + "torch._nnpack_spatial_convolution", + "torch._pack_padded_sequence", + "torch._pad_packed_sequence", + "torch._pin_memory", + "torch._prelu_kernel", + "torch._propagate_xla_data", + "torch._remove_batch_dim", + "torch._reshape_alias_copy", + "torch._reshape_from_tensor", + "torch._resize_output_", + "torch._rowwise_prune", + "torch._sample_dirichlet", + "torch._saturate_weight_to_fp16", + "torch._scaled_dot_product_attention_math", + "torch._scaled_dot_product_efficient_attention", + "torch._scaled_dot_product_flash_attention", + "torch._scaled_dot_product_flash_attention_for_cpu", + "torch._scaled_dot_product_cudnn_attention", + "torch._scaled_mm", + "torch._scaled_grouped_mm", + "torch._shape_as_tensor", + "torch._sobol_engine_draw", + "torch._sobol_engine_ff_", + "torch._sobol_engine_initialize_state_", + "torch._sobol_engine_scramble_", + "torch._softmax_backward_data", + "torch._softmax", + "torch._sparse_broadcast_to_copy", + "torch._sparse_broadcast_to", + "torch._sparse_csr_prod", + "torch._sparse_csr_sum", + "torch._sparse_log_softmax_backward_data", + "torch._sparse_semi_structured_addmm", + "torch._sparse_semi_structured_linear", + "torch._sparse_semi_structured_mm", + "torch._sparse_softmax_backward_data", + "torch._sparse_sparse_matmul", + "torch._sparse_sum", + "torch._stack", + "torch._standard_gamma_grad", + "torch._standard_gamma", + "torch._test_autograd_multiple_dispatch_view_copy", + "torch._test_autograd_multiple_dispatch_view", + "torch._test_autograd_multiple_dispatch", + "torch._test_check_tensor", + "torch._test_functorch_fallback", + "torch._test_serialization_subcmul", + "torch._to_cpu", + "torch._to_functional_tensor", + "torch._to_sparse_semi_structured", + "torch._transform_bias_rescale_qkv", + "torch._transformer_encoder_layer_fwd", + "torch._trilinear", + "torch._triton_multi_head_attention", + "torch._triton_scaled_dot_attention", + "torch._unique", + "torch._unique2", + "torch._unpack_dual", + "torch._unsafe_index_put", + "torch._unsafe_index", + "torch._unsafe_masked_index_put_accumulate", + "torch._unsafe_masked_index", + "torch._use_cudnn_ctc_loss", + "torch._use_cudnn_rnn_flatten_weight", + "torch._values_copy", + "torch._weight_int4pack_mm", + "torch._weight_int4pack_mm_for_cpu", + "torch._weight_int4pack_mm_with_scales_and_zeros", + "torch._weight_int8pack_mm", + "torch._weight_norm_interface", + "torch._weight_norm", + "torch.abs_", + "torch.abs", + "torch.absolute", + "torch.acos_", + "torch.acos", + "torch.acosh_", + "torch.acosh", + "torch.adaptive_avg_pool1d", + "torch.adaptive_max_pool1d", + "torch.add", + "torch.addbmm", + "torch.addcdiv", + "torch.addcmul", + "torch.addmm", + "torch.addmv_", + "torch.addmv", + "torch.addr", + "torch.adjoint", + "torch.affine_grid_generator", + "torch.alias_copy", + 
"torch.all", + "torch.allclose", + "torch.alpha_dropout_", + "torch.alpha_dropout", + "torch.amax", + "torch.amin", + "torch.aminmax", + "torch.angle", + "torch.any", + "torch.arange", + "torch.arccos_", + "torch.arccos", + "torch.arccosh_", + "torch.arccosh", + "torch.arcsin_", + "torch.arcsin", + "torch.arcsinh_", + "torch.arcsinh", + "torch.arctan_", + "torch.arctan", + "torch.arctan2", + "torch.arctanh_", + "torch.arctanh", + "torch.argmax", + "torch.argmin", + "torch.argsort", + "torch.argwhere", + "torch.as_strided_", + "torch.as_strided_copy", + "torch.as_strided_scatter", + "torch.as_strided", + "torch.as_tensor", + "torch.asarray", + "torch.asin_", + "torch.asin", + "torch.asinh_", + "torch.asinh", + "torch.atan_", + "torch.atan", + "torch.atan2", + "torch.atanh_", + "torch.atanh", + "torch.avg_pool1d", + "torch.baddbmm", + "torch.bartlett_window", + "torch.batch_norm_backward_elemt", + "torch.batch_norm_backward_reduce", + "torch.batch_norm_elemt", + "torch.batch_norm_gather_stats_with_counts", + "torch.batch_norm_gather_stats", + "torch.batch_norm_stats", + "torch.batch_norm_update_stats", + "torch.batch_norm", + "torch.bernoulli", + "torch.bilinear", + "torch.binary_cross_entropy_with_logits", + "torch.bincount", + "torch.binomial", + "torch.bitwise_and", + "torch.bitwise_left_shift", + "torch.bitwise_not", + "torch.bitwise_or", + "torch.bitwise_right_shift", + "torch.bitwise_xor", + "torch.blackman_window", + "torch.bmm", + "torch.broadcast_to", + "torch.bucketize", + "torch.can_cast", + "torch.cat", + "torch.ccol_indices_copy", + "torch.ceil_", + "torch.ceil", + "torch.celu_", + "torch.celu", + "torch.channel_shuffle", + "torch.cholesky_inverse", + "torch.cholesky_solve", + "torch.cholesky", + "torch.choose_qparams_optimized", + "torch.chunk", + "torch.clamp_", + "torch.clamp_max_", + "torch.clamp_max", + "torch.clamp_min_", + "torch.clamp_min", + "torch.clamp", + "torch.clip_", + "torch.clip", + "torch.clone", + "torch.col_indices_copy", + "torch.column_stack", + "torch.combinations", + "torch.complex", + "torch.concat", + "torch.concatenate", + "torch.conj_physical_", + "torch.conj_physical", + "torch.conj", + "torch.constant_pad_nd", + "torch.conv_tbc", + "torch.conv_transpose1d", + "torch.conv_transpose2d", + "torch.conv_transpose3d", + "torch.conv1d", + "torch.conv2d", + "torch.conv3d", + "torch.convolution", + "torch.copysign", + "torch.corrcoef", + "torch.cos_", + "torch.cos", + "torch.cosh_", + "torch.cosh", + "torch.cosine_embedding_loss", + "torch.cosine_similarity", + "torch.count_nonzero", + "torch.cov", + "torch.cross", + "torch.crow_indices_copy", + "torch.ctc_loss", + "torch.cudnn_affine_grid_generator", + "torch.cudnn_batch_norm", + "torch.cudnn_convolution_add_relu", + "torch.cudnn_convolution_relu", + "torch.cudnn_convolution_transpose", + "torch.cudnn_convolution", + "torch.cudnn_grid_sampler", + "torch.cudnn_is_acceptable", + "torch.cummax", + "torch.cummin", + "torch.cumprod", + "torch.cumsum", + "torch.cumulative_trapezoid", + "torch.deg2rad_", + "torch.deg2rad", + "torch.dequantize", + "torch.det", + "torch.detach_", + "torch.detach_copy", + "torch.detach", + "torch.diag_embed", + "torch.diag", + "torch.diagflat", + "torch.diagonal_copy", + "torch.diagonal_scatter", + "torch.diagonal", + "torch.diff", + "torch.digamma", + "torch.dist", + "torch.div", + "torch.divide", + "torch.dot", + "torch.dropout_", + "torch.dropout", + "torch.dsmm", + "torch.dsplit", + "torch.dstack", + "torch.embedding_bag", + "torch.embedding_renorm_", + "torch.embedding", + 
"torch.empty_like", + "torch.empty_permuted", + "torch.empty_quantized", + "torch.empty_strided", + "torch.empty", + "torch.eq", + "torch.equal", + "torch.erf_", + "torch.erf", + "torch.erfc_", + "torch.erfc", + "torch.erfinv", + "torch.exp_", + "torch.exp", + "torch.exp2_", + "torch.exp2", + "torch.expand_copy", + "torch.expm1_", + "torch.expm1", + "torch.eye", + "torch.fake_quantize_per_channel_affine", + "torch.fake_quantize_per_tensor_affine", + "torch.fbgemm_linear_fp16_weight_fp32_activation", + "torch.fbgemm_linear_fp16_weight", + "torch.fbgemm_linear_int8_weight_fp32_activation", + "torch.fbgemm_linear_int8_weight", + "torch.fbgemm_linear_quantize_weight", + "torch.fbgemm_pack_gemm_matrix_fp16", + "torch.fbgemm_pack_quantized_matrix", + "torch.feature_alpha_dropout_", + "torch.feature_alpha_dropout", + "torch.feature_dropout_", + "torch.feature_dropout", + "torch.fill_", + "torch.fill", + "torch.fix_", + "torch.fix", + "torch.flatten", + "torch.flip", + "torch.fliplr", + "torch.flipud", + "torch.float_power", + "torch.floor_", + "torch.floor_divide", + "torch.floor", + "torch.fmax", + "torch.fmin", + "torch.fmod", + "torch.frac_", + "torch.frac", + "torch.frexp", + "torch.frobenius_norm", + "torch.from_file", + "torch.from_numpy", + "torch.frombuffer", + "torch.full_like", + "torch.full", + "torch.fused_moving_avg_obs_fake_quant", + "torch.gather", + "torch.gcd_", + "torch.gcd", + "torch.ge", + "torch.geqrf", + "torch.ger", + "torch.get_device", + "torch.gradient", + "torch.greater_equal", + "torch.greater", + "torch.grid_sampler_2d", + "torch.grid_sampler_3d", + "torch.grid_sampler", + "torch.group_norm", + "torch.gru_cell", + "torch.gru", + "torch.gt", + "torch.hamming_window", + "torch.hann_window", + "torch.hardshrink", + "torch.heaviside", + "torch.hinge_embedding_loss", + "torch.histc", + "torch.histogram", + "torch.histogramdd", + "torch.hsmm", + "torch.hsplit", + "torch.hspmm", + "torch.hstack", + "torch.hypot", + "torch.i0_", + "torch.i0", + "torch.igamma", + "torch.igammac", + "torch.imag", + "torch.index_add", + "torch.index_copy", + "torch.index_fill", + "torch.index_put_", + "torch.index_put", + "torch.index_reduce", + "torch.index_select", + "torch.indices_copy", + "torch.inner", + "torch.instance_norm", + "torch.int_repr", + "torch.inverse", + "torch.is_complex", + "torch.is_conj", + "torch.is_distributed", + "torch.is_floating_point", + "torch.is_inference", + "torch.is_neg", + "torch.is_nonzero", + "torch.is_same_size", + "torch.is_signed", + "torch.is_vulkan_available", + "torch.isclose", + "torch.isfinite", + "torch.isin", + "torch.isinf", + "torch.isnan", + "torch.isneginf", + "torch.isposinf", + "torch.isreal", + "torch.istft", + "torch.kaiser_window", + "torch.kl_div", + "torch.kron", + "torch.kthvalue", + "torch.layer_norm", + "torch.lcm_", + "torch.lcm", + "torch.ldexp_", + "torch.ldexp", + "torch.le", + "torch.lerp", + "torch.less_equal", + "torch.less", + "torch.lgamma", + "torch.linspace", + "torch.log_", + "torch.log_softmax", + "torch.log", + "torch.log10_", + "torch.log10", + "torch.log1p_", + "torch.log1p", + "torch.log2_", + "torch.log2", + "torch.logaddexp", + "torch.logaddexp2", + "torch.logcumsumexp", + "torch.logdet", + "torch.logical_and", + "torch.logical_not", + "torch.logical_or", + "torch.logical_xor", + "torch.logit_", + "torch.logit", + "torch.logspace", + "torch.logsumexp", + "torch.lstm_cell", + "torch.lstm", + "torch.lt", + "torch.lu_solve", + "torch.lu_unpack", + "torch.margin_ranking_loss", + "torch.masked_fill", + 
"torch.masked_scatter", + "torch.masked_select", + "torch.matmul", + "torch.matrix_exp", + "torch.matrix_power", + "torch.max_pool1d_with_indices", + "torch.max_pool1d", + "torch.max_pool2d", + "torch.max_pool3d", + "torch.max", + "torch.maximum", + "torch.mean", + "torch.median", + "torch.min", + "torch.minimum", + "torch.miopen_batch_norm", + "torch.miopen_convolution_add_relu", + "torch.miopen_convolution_relu", + "torch.miopen_convolution_transpose", + "torch.miopen_convolution", + "torch.miopen_depthwise_convolution", + "torch.miopen_rnn", + "torch.mkldnn_adaptive_avg_pool2d", + "torch.mkldnn_convolution", + "torch.mkldnn_linear_backward_weights", + "torch.mkldnn_max_pool2d", + "torch.mkldnn_max_pool3d", + "torch.mkldnn_rnn_layer", + "torch.mm", + "torch.mode", + "torch.moveaxis", + "torch.movedim", + "torch.msort", + "torch.mul", + "torch.multinomial", + "torch.multiply", + "torch.mv", + "torch.mvlgamma", + "torch.nan_to_num_", + "torch.nan_to_num", + "torch.nanmean", + "torch.nanmedian", + "torch.nanquantile", + "torch.nansum", + "torch.narrow_copy", + "torch.narrow", + "torch.native_batch_norm", + "torch.native_channel_shuffle", + "torch.native_dropout", + "torch.native_group_norm", + "torch.native_layer_norm", + "torch.native_norm", + "torch.ne", + "torch.neg_", + "torch.neg", + "torch.negative_", + "torch.negative", + "torch.nextafter", + "torch.nonzero_static", + "torch.nonzero", + "torch.norm_except_dim", + "torch.normal", + "torch.not_equal", + "torch.nuclear_norm", + "torch.numel", + "torch.ones_like", + "torch.ones", + "torch.orgqr", + "torch.ormqr", + "torch.outer", + "torch.pairwise_distance", + "torch.pdist", + "torch.permute_copy", + "torch.permute", + "torch.pinverse", + "torch.pixel_shuffle", + "torch.pixel_unshuffle", + "torch.poisson_nll_loss", + "torch.poisson", + "torch.polar", + "torch.polygamma", + "torch.positive", + "torch.pow", + "torch.prelu", + "torch._print", + "torch.prod", + "torch.promote_types", + "torch.put", + "torch.q_per_channel_axis", + "torch.q_per_channel_scales", + "torch.q_per_channel_zero_points", + "torch.q_scale", + "torch.q_zero_point", + "torch.qr", + "torch.quantile", + "torch.quantize_per_channel", + "torch.quantize_per_tensor_dynamic", + "torch.quantize_per_tensor", + "torch.quantized_batch_norm", + "torch.quantized_gru_cell", + "torch.quantized_lstm_cell", + "torch.quantized_max_pool1d", + "torch.quantized_max_pool2d", + "torch.quantized_max_pool3d", + "torch.quantized_rnn_relu_cell", + "torch.quantized_rnn_tanh_cell", + "torch.rad2deg_", + "torch.rad2deg", + "torch.rand_like", + "torch.rand", + "torch.randint_like", + "torch.randint", + "torch.randn_like", + "torch.randn", + "torch.randperm", + "torch.range", + "torch.ravel", + "torch.real", + "torch.reciprocal_", + "torch.reciprocal", + "torch.relu_", + "torch.relu", + "torch.remainder", + "torch.renorm", + "torch.repeat_interleave", + "torch.reshape", + "torch.resolve_conj", + "torch.resolve_neg", + "torch.result_type", + "torch.rms_norm", + "torch.rnn_relu_cell", + "torch.rnn_relu", + "torch.rnn_tanh_cell", + "torch.rnn_tanh", + "torch.roll", + "torch.rot90", + "torch.round_", + "torch.round", + "torch.row_indices_copy", + "torch.row_stack", + "torch.rrelu_", + "torch.rrelu", + "torch.rsqrt_", + "torch.rsqrt", + "torch.rsub", + "torch.saddmm", + "torch.scalar_tensor", + "torch.scatter_add", + "torch.scatter_reduce", + "torch.scatter", + "torch.searchsorted", + "torch.segment_reduce", + "torch.select_copy", + "torch.select_scatter", + "torch.select", + "torch.selu_", + "torch.selu", 
+ "torch.sgn", + "torch.sigmoid_", + "torch.sigmoid", + "torch.sign", + "torch.signal.windows.windows.sqrt", + "torch.signbit", + "torch.sin_", + "torch.sin", + "torch.sinc_", + "torch.sinc", + "torch.sinh_", + "torch.sinh", + "torch.slice_copy", + "torch.slice_scatter", + "torch.slogdet", + "torch.smm", + "torch.softmax", + "torch.sort", + "torch.split_copy", + "torch.split_with_sizes_copy", + "torch.split_with_sizes", + "torch.spmm", + "torch.sqrt_", + "torch.sqrt", + "torch.square_", + "torch.square", + "torch.squeeze_copy", + "torch.squeeze", + "torch.sspaddmm", + "torch.stack", + "torch.std_mean", + "torch.std", + "torch.sub", + "torch.subtract", + "torch.sum", + "torch.svd", + "torch.swapaxes", + "torch.swapdims", + "torch.sym_constrain_range_for_size", + "torch.sym_constrain_range", + "torch.t_copy", + "torch.t", + "torch.take_along_dim", + "torch.take", + "torch.tan_", + "torch.tan", + "torch.tanh_", + "torch.tanh", + "torch.tensor_split", + "torch.tensor", + "torch.threshold_", + "torch.threshold", + "torch.tile", + "torch.topk", + "torch.trace", + "torch.transpose_copy", + "torch.transpose", + "torch.trapezoid", + "torch.trapz", + "torch.triangular_solve", + "torch.tril_indices", + "torch.tril", + "torch.triplet_margin_loss", + "torch.triu_indices", + "torch.triu", + "torch.true_divide", + "torch.trunc_", + "torch.trunc", + "torch.unbind_copy", + "torch.unbind", + "torch.unflatten", + "torch.unfold_copy", + "torch.unsafe_chunk", + "torch.unsafe_split_with_sizes", + "torch.unsafe_split", + "torch.unsqueeze_copy", + "torch.unsqueeze", + "torch.values_copy", + "torch.vander", + "torch.var_mean", + "torch.var", + "torch.vdot", + "torch.view_as_complex_copy", + "torch.view_as_complex", + "torch.view_as_real_copy", + "torch.view_as_real", + "torch.view_copy", + "torch.vsplit", + "torch.vstack", + "torch.where", + "torch.xlogy_", + "torch.xlogy", + "torch.zero_", + "torch.zeros", + "torch.zeros_like", + "torch._fused_sgd_", + "torch.slice_inverse", + "torch._assert_scalar", + "torch._functional_assert_scalar", + "torch.xpu._get_device_properties", + ], + TorchInGraphFunctionVariable, +) + + +if sys.version_info >= (3, 11): + torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable + torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable + + +# In graph functions (including constant folding) that are not C bindings +# NOTE: [Cacheability of in-graph torch functions] +# Functions in this list have the property that graphs containing them are safe to cache/serialize. +# serialize given only the information in the graph. 
I.e, either: +# - Your function does not access or close over global state, or +# - Your function closes over global state, but this state is guarded by dynamo, either +# through constant folding or other mechanisms +# If your function needs a custom special handler (via @register on TorchInGraphFunctionVariable), +# or captures global state, please add it to manual_torch_name_rule_map instead +torch_non_c_binding_in_graph_functions = dict.fromkeys( + [ + "torch.__future__.get_overwrite_module_params_on_conversion", + "torch.__future__.set_overwrite_module_params_on_conversion", + "torch.__getattr__", + "torch._assert", + "torch._check_index", + "torch._check_is_size", + "torch._check_not_implemented", + "torch._check_tensor_all_with", + "torch._check_tensor_all", + "torch._check_type", + "torch._check_value", + "torch._check_with", + "torch._check", + "torch._compile._disable_dynamo", + "torch._functorch.apis.chunk_vmap", + "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats", + "torch._functorch.batch_norm_replacement.replace_all_batch_norm_modules_", + "torch._functorch.deprecated.combine_state_for_ensemble", + "torch._functorch.deprecated.functionalize", + "torch._functorch.deprecated.get_warning", + "torch._functorch.deprecated.make_functional_with_buffers", + "torch._functorch.deprecated.make_functional", + "torch._functorch.deprecated.setup_docs", + "torch._functorch.deprecated.warn_deprecated", + "torch._functorch.eager_transforms._any_differentiable", + "torch._functorch.eager_transforms._autograd_grad", + "torch._functorch.eager_transforms._set_tensor_requires_grad", + "torch._functorch.eager_transforms._is_differentiable", + "torch._functorch.eager_transforms._maybe_unwrap_functional_tensor", + "torch._functorch.eager_transforms._maybe_wrap_functional_tensor", + "torch._functorch.eager_transforms._unwrap_all_tensors_from_functional", + "torch._functorch.eager_transforms._wrap_all_tensors_to_functional", + "torch._functorch.eager_transforms.assert_flat_tuple_of_tensors", + "torch._functorch.eager_transforms.functionalize", + "torch._functorch.eager_transforms.lazy_dynamo_disable", + "torch._functorch.eager_transforms.noop", + "torch._functorch.utils.enable_single_level_autograd_function", + "torch._functorch.utils.exposed_in", + "torch._functorch.utils.unwrap_dead_wrappers", + "torch._functorch.vmap.lazy_load_decompositions", + "torch._guards.compile_context", + "torch._guards.detect_fake_mode", + "torch._guards.tracing", + "torch._higher_order_ops.map._has_potential_branch_input_alias", + "torch._higher_order_ops.map._has_potential_branch_input_mutation", + "torch._higher_order_ops.map._stack_pytree", + "torch._higher_order_ops.map._unstack_pytree", + "torch._higher_order_ops.map.create_fw_bw_graph", + "torch._higher_order_ops.map.map_autograd", + "torch._higher_order_ops.map.map_dense", + "torch._higher_order_ops.map.map_fake_tensor_mode", + "torch._higher_order_ops.map.map_functionalize", + "torch._higher_order_ops.map.map_proxy_torch_dispatch_mode", + "torch._higher_order_ops.map.map_wrapper", + "torch._higher_order_ops.map.trace_map", + "torch._higher_order_ops.out_dtype.elementwise_dtypes", + "torch._higher_order_ops.out_dtype.is_int_mm", + "torch._higher_order_ops.out_dtype.out_dtype_dense", + "torch._higher_order_ops.out_dtype.out_dtype_fake_tensor_mode", + "torch._higher_order_ops.out_dtype.out_dtype_fallback", + "torch._higher_order_ops.out_dtype.out_dtype_func", + "torch._higher_order_ops.out_dtype.out_dtype_proxy", + 
"torch._higher_order_ops.out_dtype.trace_out_dtype", + "torch._higher_order_ops.utils.autograd_not_implemented_inner", + "torch._higher_order_ops.utils.autograd_not_implemented", + "torch._linalg_utils._symeig", + "torch._linalg_utils.basis", + "torch._linalg_utils.bform", + "torch._linalg_utils.eig", + "torch._linalg_utils.get_floating_dtype", + "torch._linalg_utils.is_sparse", + "torch._linalg_utils.lstsq", + "torch._linalg_utils.matmul", + "torch._linalg_utils.matrix_rank", + "torch._linalg_utils.qform", + "torch._linalg_utils.solve", + "torch._linalg_utils.symeig", + "torch._load_global_deps", + "torch._lowrank._svd_lowrank", + "torch._lowrank.get_approximate_basis", + "torch._lowrank.pca_lowrank", + "torch._lowrank.svd_lowrank", + "torch._preload_cuda_deps", + "torch._register_device_module", + "torch._running_with_deploy", + "torch._utils._dummy_type", + "torch._utils._flatten_dense_tensors", + "torch._utils._unflatten_dense_tensors", + "torch._weights_only_unpickler._get_allowed_globals", + "torch._weights_only_unpickler.load", + "torch.accelerator.current_accelerator", + "torch.accelerator.current_device_index", + "torch.accelerator.current_stream", + "torch.accelerator.device_count", + "torch.accelerator.is_available", + "torch.accelerator.set_stream", + "torch.accelerator.synchronize", + "torch.align_tensors", + "torch.amp.autocast_mode._enter_autocast", + "torch.amp.autocast_mode._exit_autocast", + "torch.amp.autocast_mode.autocast_decorator", + "torch.amp.autocast_mode.custom_bwd", + "torch.amp.autocast_mode.custom_fwd", + "torch.are_deterministic_algorithms_enabled", + "torch.atleast_1d", + "torch.atleast_2d", + "torch.atleast_3d", + "torch.autograd._calculate_shape", + "torch.autograd._is_checkpoint_valid", + "torch.autograd._make_grads", + "torch.autograd._register_py_tensor_class_for_device", + "torch.autograd._tensor_or_tensors_to_tuple", + "torch.autograd.forward_ad._maybe_load_decompositions", + "torch.autograd.function._iter_filter", + "torch.autograd.function._iter_jit_values", + "torch.autograd.function._iter_None_tensors", + "torch.autograd.function._iter_tensors_permissive", + "torch.autograd.function._iter_tensors", + "torch.autograd.function._jit_unwrap_structured", + "torch.autograd.function._map_tensor_data", + "torch.autograd.function._nested_map", + "torch.autograd.function._unflatten", + "torch.autograd.function.once_differentiable", + "torch.autograd.function.traceable", + "torch.autograd.functional._as_tuple_nocheck", + "torch.autograd.functional._as_tuple", + "torch.autograd.functional._autograd_grad", + "torch.autograd.functional._check_requires_grad", + "torch.autograd.functional._construct_standard_basis_for", + "torch.autograd.functional._fill_in_zeros", + "torch.autograd.functional._grad_postprocess", + "torch.autograd.functional._grad_preprocess", + "torch.autograd.functional._jacfwd", + "torch.autograd.functional._tuple_postprocess", + "torch.autograd.functional._validate_v", + "torch.autograd.functional.hessian", + "torch.autograd.functional.hvp", + "torch.autograd.functional.jacobian", + "torch.autograd.functional.jvp", + "torch.autograd.functional.vhp", + "torch.autograd.functional.vjp", + "torch.autograd.grad_mode._enter_inference_mode", + "torch.autograd.grad_mode._exit_inference_mode", + "torch.autograd.graph._get_sid", + "torch.autograd.graph._get_tid", + "torch.autograd.graph.allow_mutation_on_saved_tensors", + "torch.autograd.graph.get_gradient_edge", + "torch.autograd.graph.increment_version", + 
"torch.autograd.graph.register_multi_grad_hook", + "torch.autograd.variable", + "torch.backends.__allow_nonbracketed_mutation", + "torch.backends.cpu.get_cpu_capability", + "torch.backends.cuda.can_use_efficient_attention", + "torch.backends.cuda.can_use_flash_attention", + "torch.backends.cuda.can_use_cudnn_attention", + "torch.backends.cuda.enable_flash_sdp", + "torch.backends.cuda.enable_math_sdp", + "torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp", + "torch.backends.cuda.enable_mem_efficient_sdp", + "torch.backends.cuda.flash_sdp_enabled", + "torch.backends.cuda.is_built", + "torch.backends.cuda.is_flash_attention_available", + "torch.backends.cuda.math_sdp_enabled", + "torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed", + "torch.backends.cuda.mem_efficient_sdp_enabled", + "torch.backends.cuda.cudnn_sdp_enabled", + "torch.backends.cuda.enable_cudnn_sdp", + "torch.backends.cuda.preferred_blas_library", + "torch.backends.cuda.preferred_linalg_library", + "torch.backends.cuda.preferred_rocm_fa_library", + "torch.backends.cuda.sdp_kernel", + "torch.backends.cudnn._init", + "torch.backends.cudnn.flags", + "torch.backends.cudnn.is_acceptable", + "torch.backends.cudnn.is_available", + "torch.backends.cudnn.set_flags", + "torch.backends.cudnn.version", + "torch.backends.disable_global_flags", + "torch.backends.flags_frozen", + "torch.backends.mkl.is_available", + "torch.backends.mkldnn.flags", + "torch.backends.mkldnn.is_available", + "torch.backends.mkldnn.set_flags", + "torch.backends.mps._init", + "torch.backends.mps.is_available", + "torch.backends.mps.is_built", + "torch.backends.mps.is_macos13_or_newer", + "torch.backends.openmp.is_available", + "torch.backends.quantized._get_qengine_id", + "torch.backends.quantized._get_qengine_str", + "torch.block_diag", + "torch.broadcast_tensors", + "torch.cartesian_prod", + "torch.cdist", + "torch.chain_matmul", + "torch.compile", + "torch.compiled_with_cxx11_abi", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", + "torch._C._cpu._is_amx_fp16_supported", + "torch.cpu._init_amx", + "torch.cpu.current_device", + "torch.cpu.current_stream", + "torch.cpu.device_count", + "torch.cpu.is_available", + "torch.cpu.set_device", + "torch.cpu.stream", + "torch.cpu.synchronize", + "torch.cuda._check_capability", + "torch.cuda._check_cubins", + "torch.cuda._device_count_amdsmi", + "torch.cuda._device_count_nvml", + "torch.cuda._get_amdsmi_handler", + "torch.cuda._get_amdsmi_device_index", + "torch.cuda._get_device", + "torch.cuda._get_generator", + "torch.cuda._get_nvml_device_index", + "torch.cuda._get_pynvml_handler", + "torch.cuda._get_rng_state_offset", + "torch.cuda._is_compiled", + "torch.cuda._lazy_call", + "torch.cuda._lazy_init", + "torch.cuda._memory_viz._block_extra_legacy", + "torch.cuda._memory_viz._block_extra", + "torch.cuda._memory_viz._format_size", + "torch.cuda._memory_viz._format_viz", + "torch.cuda._memory_viz._frame_filter", + "torch.cuda._memory_viz._frame_fmt", + "torch.cuda._memory_viz._frames_fmt", + "torch.cuda._memory_viz._profile_to_snapshot", + "torch.cuda._memory_viz._report_free", + "torch.cuda._memory_viz._write_blocks", + "torch.cuda._memory_viz.calc_active", + "torch.cuda._memory_viz.compare", + "torch.cuda._memory_viz.format_flamegraph", + "torch.cuda._memory_viz.memory", + "torch.cuda._memory_viz.profile_plot", + "torch.cuda._memory_viz.segment_plot", + 
"torch.cuda._memory_viz.segments", + "torch.cuda._memory_viz.segsum", + "torch.cuda._memory_viz.trace_plot", + "torch.cuda._memory_viz.trace", + "torch.cuda._nvml_based_avail", + "torch.cuda._parse_visible_devices", + "torch.cuda._raw_device_count_amdsmi", + "torch.cuda._raw_device_count_nvml", + "torch.cuda._raw_device_uuid_amdsmi", + "torch.cuda._raw_device_uuid_nvml", + "torch.cuda._register_triton_kernels", + "torch.cuda._set_rng_state_offset", + "torch.cuda._set_stream_by_id", + "torch.cuda._sleep", + "torch.cuda._transform_uuid_to_ordinals", + "torch.cuda._utils._get_device_index", + "torch.cuda.amp.autocast_mode._cast", + "torch.cuda.amp.autocast_mode.custom_bwd", + "torch.cuda.amp.autocast_mode.custom_fwd", + "torch.cuda.amp.common.amp_definitely_not_available", + "torch.amp.grad_scaler._refresh_per_optimizer_state", + "torch.cuda.can_device_access_peer", + "torch.cuda.check_error", + "torch.cuda.clock_rate", + "torch.cuda.cudart", + "torch.cuda.current_blas_handle", + "torch.cuda.current_stream", + "torch.cuda.default_stream", + "torch.cuda.device_count", + "torch.cuda.device_memory_used", + "torch.cuda.get_arch_list", + "torch.cuda.get_device_capability", + "torch.cuda.get_device_name", + "torch.cuda.get_device_properties", + "torch.cuda.get_gencode_flags", + "torch.cuda.get_sync_debug_mode", + "torch.cuda.graphs.graph_pool_handle", + "torch.cuda.graphs.is_current_stream_capturing", + "torch.cuda.graphs.make_graphed_callables", + "torch.cuda.init", + "torch.cuda.ipc_collect", + "torch.cuda.is_available", + "torch.cuda.is_bf16_supported", + "torch.cuda.is_initialized", + "torch.cuda.jiterator._create_jit_fn", + "torch.cuda.jiterator._create_multi_output_jit_fn", + "torch.cuda.memory_usage", + "torch.cuda.memory._dump_snapshot", + "torch.cuda.memory._free_mutex", + "torch.cuda.memory._get_current_allocator", + "torch.cuda.memory._host_allocator", + "torch.cuda.memory._record_memory_history_impl", + "torch.cuda.memory._record_memory_history_legacy", + "torch.cuda.memory._record_memory_history", + "torch.cuda.memory._save_memory_usage", + "torch.cuda.memory._save_segment_usage", + "torch.cuda.memory._set_allocator_settings", + "torch.cuda.memory._snapshot", + "torch.cuda.memory.caching_allocator_alloc", + "torch.cuda.memory.caching_allocator_delete", + "torch.cuda.memory.caching_allocator_enable", + "torch.cuda.memory.change_current_allocator", + "torch.cuda.memory.empty_cache", + "torch.cuda.memory.get_allocator_backend", + "torch.cuda.memory.get_per_process_memory_fraction", + "torch.cuda.memory.host_memory_stats_as_nested_dict", + "torch.cuda.memory.host_memory_stats", + "torch.cuda.memory.list_gpu_processes", + "torch.cuda.memory.max_memory_allocated", + "torch.cuda.memory.max_memory_cached", + "torch.cuda.memory.max_memory_reserved", + "torch.cuda.memory.mem_get_info", + "torch.cuda.memory.memory_allocated", + "torch.cuda.memory.memory_cached", + "torch.cuda.memory.memory_reserved", + "torch.cuda.memory.memory_snapshot", + "torch.cuda.memory.memory_stats_as_nested_dict", + "torch.cuda.memory.memory_stats", + "torch.cuda.memory.memory_summary", + "torch.cuda.memory.reset_accumulated_host_memory_stats", + "torch.cuda.memory.reset_accumulated_memory_stats", + "torch.cuda.memory.reset_max_memory_allocated", + "torch.cuda.memory.reset_max_memory_cached", + "torch.cuda.memory.reset_peak_host_memory_stats", + "torch.cuda.memory.reset_peak_memory_stats", + "torch.cuda.memory.set_per_process_memory_fraction", + "torch.cuda.nccl._check_sequence_type", + "torch.cuda.nccl.all_gather", + 
"torch.cuda.nccl.all_reduce", + "torch.cuda.nccl.broadcast", + "torch.cuda.nccl.init_rank", + "torch.cuda.nccl.is_available", + "torch.cuda.nccl.reduce_scatter", + "torch.cuda.nccl.reduce", + "torch.cuda.nccl.unique_id", + "torch.cuda.nccl.version", + "torch.cuda.nvtx.mark", + "torch.cuda.nvtx.range_end", + "torch.cuda.nvtx.range_pop", + "torch.cuda.nvtx.range_push", + "torch.cuda.nvtx.range_start", + "torch.cuda.nvtx.range", + "torch.cuda.power_draw", + "torch.cuda.profiler.init", + "torch.cuda.profiler.profile", + "torch.cuda.profiler.start", + "torch.cuda.profiler.stop", + "torch.cuda.random.get_rng_state_all", + "torch.cuda.random.initial_seed", + "torch.cuda.random.manual_seed_all", + "torch.cuda.random.manual_seed", + "torch.cuda.random.seed_all", + "torch.cuda.random.seed", + "torch.cuda.random.set_rng_state_all", + "torch.cuda.set_stream", + "torch.cuda.set_sync_debug_mode", + "torch.cuda.stream", + "torch.cuda.synchronize", + "torch.cuda.temperature", + "torch.cuda.utilization", + "torch.einsum", + "torch.functional._check_list_size", + "torch.functional._consecutive_return_counts", + "torch.functional._consecutive_return_inverse_false", + "torch.functional._consecutive_return_inverse_true", + "torch.functional._consecutive_return_inverse", + "torch.functional._consecutive_return_output", + "torch.functional._lu_impl", + "torch.functional._lu_no_infos", + "torch.functional._lu_with_infos", + "torch.functional._meshgrid", + "torch.functional._return_counts", + "torch.functional._return_inverse_false", + "torch.functional._return_inverse_true", + "torch.functional._return_inverse", + "torch.functional._return_output", + "torch.functional._unique_consecutive_impl", + "torch.functional._unique_impl", + "torch.functional._unravel_index", + "torch.functional.broadcast_shapes", + "torch.functional.lu", + "torch.functional.unique", + "torch.functional.unravel_index", + "torch.futures.collect_all", + "torch.futures.wait_all", + "torch.fx.experimental.const_fold.split_const_subgraphs", + "torch.fx.experimental.proxy_tensor.make_fx", + "torch.get_deterministic_debug_mode", + "torch.get_float32_matmul_precision", + "torch.is_deterministic_algorithms_warn_only_enabled", + "torch.is_storage", + "torch.is_tensor", + "torch.is_warn_always_enabled", + "torch.masked._ops._any", + "torch.masked._ops._apply_docstring_templates", + "torch.masked._ops._canonical_dim", + "torch.masked._ops._combine_input_and_mask", + "torch.masked._ops._generate_docstring", + "torch.masked._ops._input_mask", + "torch.masked._ops._output_mask", + "torch.masked._ops._reduction_identity", + "torch.masked._ops._sparse_coo_flatten_indices", + "torch.masked._ops._sparse_coo_scatter_reduction_helper", + "torch.masked._ops._sparse_coo_where", + "torch.masked._ops._sparse_csr_segment_reduction_helper", + "torch.masked._ops._sparse_csr_where", + "torch.masked._ops._std_var", + "torch.masked._ops._where", + "torch.masked._ops.amax", + "torch.masked._ops.amin", + "torch.masked._ops.argmax", + "torch.masked._ops.argmin", + "torch.masked._ops.corresponding_real_dtype", + "torch.masked._ops.cumprod", + "torch.masked._ops.cumsum", + "torch.masked._ops.log_softmax", + "torch.masked._ops.logaddexp", + "torch.masked._ops.logsumexp", + "torch.masked._ops.mean", + "torch.masked._ops.median", + "torch.masked._ops.norm", + "torch.masked._ops.normalize", + "torch.masked._ops.prod", + "torch.masked._ops.softmax", + "torch.masked._ops.softmin", + "torch.masked._ops.std", + "torch.masked._ops.sum", + "torch.masked._ops.var", + "torch.meshgrid", + 
"torch.mps._get_default_mps_generator", + "torch.mps.current_allocated_memory", + "torch.mps.driver_allocated_memory", + "torch.mps.empty_cache", + "torch.mps.get_rng_state", + "torch.mps.manual_seed", + "torch.mps.profiler.profile", + "torch.mps.profiler.start", + "torch.mps.profiler.stop", + "torch.mps.seed", + "torch.mps.set_per_process_memory_fraction", + "torch.mps.set_rng_state", + "torch.mps.synchronize", + "torch.nested._internal.nested_tensor.buffer_from_jagged", + "torch.nested._internal.nested_tensor.get_tensor_symint", + "torch.nested._internal.nested_tensor.is_expandable_to", + "torch.nested._internal.nested_tensor.jagged_from_list", + "torch.nested._internal.nested_tensor.jagged_from_tensor_and_lengths", + "torch.nested._internal.nested_tensor.nested_view_from_values_offsets", + "torch.nested._internal.nested_tensor.nested_view_from_values_offsets_lengths", + "torch.nested.as_nested_tensor", + "torch.nested.narrow", + "torch.nested.nested_tensor", + "torch.nn._reduction.get_enum", + "torch.nn._reduction.legacy_get_enum", + "torch.nn._reduction.legacy_get_string", + "torch.nn.factory_kwargs", + "torch.nn.functional.adaptive_avg_pool2d", + "torch.nn.functional.adaptive_avg_pool3d", + "torch.nn.functional.adaptive_max_pool1d_with_indices", + "torch.nn.functional.adaptive_max_pool1d", + "torch.nn.functional.adaptive_max_pool2d_with_indices", + "torch.nn.functional.adaptive_max_pool2d", + "torch.nn.functional.adaptive_max_pool3d_with_indices", + "torch.nn.functional.adaptive_max_pool3d", + "torch.nn.functional.affine_grid", + "torch.nn.functional.alpha_dropout", + "torch.nn.functional.assert_int_or_pair", + "torch.nn.functional.batch_norm", + "torch.nn.functional.binary_cross_entropy_with_logits", + "torch.nn.functional.binary_cross_entropy", + "torch.nn.functional.celu", + "torch.nn.functional.cosine_embedding_loss", + "torch.nn.functional.cross_entropy", + "torch.nn.functional.ctc_loss", + "torch.nn.functional.dropout", + "torch.nn.functional.dropout1d", + "torch.nn.functional.dropout2d", + "torch.nn.functional.dropout3d", + "torch.nn.functional.elu", + "torch.nn.functional.embedding_bag", + "torch.nn.functional.embedding", + "torch.nn.functional.feature_alpha_dropout", + "torch.nn.functional.fold", + "torch.nn.functional.fractional_max_pool2d_with_indices", + "torch.nn.functional.fractional_max_pool2d", + "torch.nn.functional.fractional_max_pool3d_with_indices", + "torch.nn.functional.fractional_max_pool3d", + "torch.nn.functional.gaussian_nll_loss", + "torch.nn.functional.glu", + "torch.nn.functional.grid_sample", + "torch.nn.functional.group_norm", + "torch.nn.functional.gumbel_softmax", + "torch.nn.functional.hardsigmoid", + "torch.nn.functional.hardswish", + "torch.nn.functional.hardtanh", + "torch.nn.functional.hinge_embedding_loss", + "torch.nn.functional.huber_loss", + "torch.nn.functional.instance_norm", + "torch.nn.functional.interpolate", + "torch.nn.functional.kl_div", + "torch.nn.functional.l1_loss", + "torch.nn.functional.layer_norm", + "torch.nn.functional.leaky_relu", + "torch.nn.functional.local_response_norm", + "torch.nn.functional.log_softmax", + "torch.nn.functional.lp_pool1d", + "torch.nn.functional.lp_pool2d", + "torch.nn.functional.margin_ranking_loss", + "torch.nn.functional.max_pool1d_with_indices", + "torch.nn.functional.max_pool1d", + "torch.nn.functional.max_pool2d_with_indices", + "torch.nn.functional.max_pool2d", + "torch.nn.functional.max_pool3d_with_indices", + "torch.nn.functional.max_pool3d", + "torch.nn.functional.max_unpool1d", + 
"torch.nn.functional.max_unpool2d", + "torch.nn.functional.max_unpool3d", + "torch.nn.functional.mish", + "torch.nn.functional.mse_loss", + "torch.nn.functional.multi_head_attention_forward", + "torch.nn.functional.multi_margin_loss", + "torch.nn.functional.multilabel_margin_loss", + "torch.nn.functional.multilabel_soft_margin_loss", + "torch.nn.functional.nll_loss", + "torch.nn.functional.normalize", + "torch.nn.functional.poisson_nll_loss", + "torch.nn.functional.relu", + "torch.nn.functional.relu6", + "torch.nn.functional.rrelu", + "torch.nn.functional.selu", + "torch.nn.functional.sigmoid", + "torch.nn.functional.silu", + "torch.nn.functional.smooth_l1_loss", + "torch.nn.functional.soft_margin_loss", + "torch.nn.functional.softmax", + "torch.nn.functional.softmin", + "torch.nn.functional.softsign", + "torch.nn.functional.tanh", + "torch.nn.functional.tanhshrink", + "torch.nn.functional.triplet_margin_loss", + "torch.nn.functional.unfold", + "torch.nn.functional.upsample_bilinear", + "torch.nn.functional.upsample_nearest", + "torch.nn.functional.upsample", + "torch.nn.grad._pair", + "torch.nn.grad._single", + "torch.nn.grad._triple", + "torch.nn.grad.conv1d_input", + "torch.nn.grad.conv1d_weight", + "torch.nn.grad.conv2d_input", + "torch.nn.grad.conv2d_weight", + "torch.nn.grad.conv3d_input", + "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.utils._list_with_default", + "torch.nn.modules.utils._ntuple", + "torch.nn.modules.utils._quadruple", + "torch.nn.modules.utils._reverse_repeat_tuple", + "torch.nn.modules.utils.consume_prefix_in_state_dict_if_present", + "torch.nn.parameter.is_lazy", + "torch.norm", + "torch.quantization.default_eval_fn", + "torch.random._seed_custom_device", + "torch.random.fork_rng", + "torch.random.initial_seed", + "torch.random.seed", + "torch.return_types.pytree_register_structseq", + "torch.set_default_dtype", + "torch.set_default_tensor_type", + "torch.set_deterministic_debug_mode", + "torch.set_float32_matmul_precision", + "torch.set_warn_always", + "torch.signal.windows.windows._add_docstr", + "torch.signal.windows.windows._window_function_checks", + "torch.signal.windows.windows.bartlett", + "torch.signal.windows.windows.blackman", + "torch.signal.windows.windows.cosine", + "torch.signal.windows.windows.exponential", + "torch.signal.windows.windows.gaussian", + "torch.signal.windows.windows.general_cosine", + "torch.signal.windows.windows.general_hamming", + "torch.signal.windows.windows.hamming", + "torch.signal.windows.windows.hann", + "torch.signal.windows.windows.kaiser", + "torch.signal.windows.windows.merge_dicts", + "torch.signal.windows.windows.nuttall", + "torch.signal.windows.windows.parse_kwargs", + "torch.sparse.semi_structured.to_sparse_semi_structured", + "torch.sparse.sum", + "torch.split", + "torch.stft", + "torch.sym_float", + "torch.sym_int", + "torch.sym_ite", + "torch.sym_max", + "torch.sym_min", + "torch.sym_not", + "torch.tensordot", + "torch.unique_consecutive", + "torch.use_deterministic_algorithms", + "torch.xpu._get_device", + "torch.xpu._get_generator", + "torch.xpu._get_rng_state_offset", + "torch.xpu._is_compiled", + "torch.xpu._lazy_call", + "torch.xpu._lazy_init", + "torch.xpu._set_rng_state_offset", + "torch.xpu._set_stream_by_id", + "torch.xpu._utils._get_device_index", + "torch.xpu.current_device", + "torch.xpu.current_stream", + "torch.xpu.device_count", + "torch.xpu.get_arch_list", + "torch.xpu.get_device_capability", + "torch.xpu.get_device_name", + 
"torch.xpu.get_device_properties", + "torch.xpu.get_gencode_flags", + "torch.xpu.get_stream_from_external", + "torch.xpu.init", + "torch.xpu.is_available", + "torch.xpu.is_bf16_supported", + "torch.xpu.is_initialized", + "torch.xpu.memory.empty_cache", + "torch.xpu.memory.max_memory_allocated", + "torch.xpu.memory.max_memory_reserved", + "torch.xpu.memory.mem_get_info", + "torch.xpu.memory.memory_allocated", + "torch.xpu.memory.memory_reserved", + "torch.xpu.memory.memory_stats_as_nested_dict", + "torch.xpu.memory.memory_stats", + "torch.xpu.memory.reset_accumulated_memory_stats", + "torch.xpu.memory.reset_peak_memory_stats", + "torch.xpu.random.initial_seed", + "torch.xpu.random.seed_all", + "torch.xpu.random.seed", + "torch.xpu.set_stream", + "torch.xpu.synchronize", + ], + TorchInGraphFunctionVariable, +) + + +torch_name_rule_map = [ + manual_torch_name_rule_map, + torch_c_binding_in_graph_functions, + torch_non_c_binding_in_graph_functions, +] + + +""" +Generate the torch object - Dynamo tracing rule (the wrapping variable) map. +""" + + +@functools.cache +def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: + d: dict[Any, type[VariableTracker]] = {} + for m in torch_name_rule_map: + for k, v in m.items(): # type: ignore[attr-defined] + if ".py#" not in k: + obj = load_object(k) + else: + obj = _module_dir(torch) + k[len("torch/") :] + if obj is not None: + if obj in d and d[obj] != v: + raise AssertionError( + f"Duplicate torch object {obj} with different rules: {v}, {d[obj]}" + ) + else: + d[obj] = v + return d + + +def _load_obj_from_str(fully_qualified_name): + module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) + return getattr(importlib.import_module(module), obj_name) + + +""" +Load string represented torch objects. +""" + + +def load_object(name): + try: + x = name.split("#") + if len(x) == 2: + obj = _load_obj_from_str(x[0]) + val = getattr(obj, x[1]) + else: + assert len(x) == 1, f"Invalid obj name {name}" + val = _load_obj_from_str(x[0]) + val = unwrap_if_wrapper(val) + except (AttributeError, ImportError): + val = None + return val + + +""" +Get all torch.Tensor methods which are allowed to be in graph functions. +""" + + +@functools.cache +def get_tensor_method(): + disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} + s = set() + for name in dir(torch.Tensor): + method = getattr(torch.Tensor, name) + if ( + isinstance( + method, + ( + types.MethodDescriptorType, + types.WrapperDescriptorType, + types.BuiltinFunctionType, + ), + ) + and name not in disallowed_tensor_methods + ): + s.add(method) + + # mlazos: these are functions which we handle specially in TensorVariable + s.add(torch.Tensor.__contains__) # type: ignore[arg-type] + s.add(torch.Tensor.register_hook) # type: ignore[arg-type] + return frozenset(s) + + +""" +Return if a torch object is ATen op or torch.Tensor method. +""" + + +def is_aten_op_or_tensor_method(obj): + return obj in get_tensor_method() or isinstance( + obj, + (torch._ops.OpOverloadPacket, torch._ops.OpOverload), + ) + + +class FunctionIdSet: + """ + Track a set of `id()`s of objects which are either allowed or not + allowed to go into the generated FX graph. Use to test for torch.*, + numpy.*, builtins.*, etc. + + Support user modification to permit customization of what can be + added to the graph and what will cause a graph break. 
+ """ + + function_ids: Optional[set[int]] = None + function_names: Optional[dict[int, str]] = None + + def __init__( + self, lazy_initializer: Callable[[], Union[dict[int, str], set[int]]] + ) -> None: + self.lazy_initializer = lazy_initializer + + def __call__(self) -> set[int]: + if self.function_ids is None: + value = self.lazy_initializer() + if isinstance(value, dict): + self.function_ids = set(value.keys()) + self.function_names = value + else: + assert isinstance(value, set) + self.function_ids = value + return self.function_ids + + def get_name(self, idx: int, default: str): + self() # lazy init + assert self.function_names is not None + return self.function_names.get(idx, default) + + def add(self, idx: int): + function_ids = self() # lazy init + function_ids.add(idx) + + def remove(self, idx: int): + function_ids = self() + if idx in function_ids: + function_ids.remove(idx) + + def __contains__(self, idx: int) -> bool: + return idx in self() + + +@FunctionIdSet +def _allowed_callable_ids() -> dict[int, str]: + rv: dict[int, str] = {} + return rv + + +@FunctionIdSet +def _disallowed_callable_ids() -> dict[int, str]: + rv: dict[int, str] = {} + return rv + + +@FunctionIdSet +def _nonstrict_trace_callable_ids() -> dict[int, str]: + rv: dict[int, str] = {} + return rv + + +@FunctionIdSet +def _builtin_function_ids() -> dict[int, str]: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and callable(v) + } + rv.update( + { + id(v): f"operator.{k}" + for k, v in operator.__dict__.items() + if not k.startswith("_") and callable(v) + } + ) + rv.update( + { + id(cast): "typing.cast", + id(copy.deepcopy): "copy.deepcopy", + } + ) + return rv + + +@FunctionIdSet +def _polyfilled_function_ids() -> set[int]: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return set() + + +@FunctionIdSet +def _numpy_function_ids() -> dict[int, str]: + unsupported_funcs = { + "seed", + "ranf", + "get_bit_generator", + "RandomState", + "set_bit_generator", + "sample", + } + + def is_supported(k, v, mod): + if not callable(v): + return False + if not getattr(v, "__module__", None): + return True + if v.__module__ == mod.__name__: + return True + if ( + v.__module__ == "numpy.random.mtrand" + and mod.__name__ == "numpy.random" + and k not in unsupported_funcs + ): + return True + return False + + rv = {} + for mod in NP_SUPPORTED_MODULES: + for k, v in mod.__dict__.items(): + if is_supported(k, v, mod): + rv[id(v)] = f"{mod.__name__}.{k}" + return rv + + +@FunctionIdSet +def _builtin_constant_ids() -> dict[int, str]: + """ + Collects constant builtins by eliminating callable items. + """ + rv = { + id(v): f"builtins.{k}" + for k, v in builtins.__dict__.items() + if not k.startswith("_") and not callable(v) + } + return rv + + +_lazy_module_init: dict[str, list[Callable[[], None]]] = defaultdict(list) + + +def add_module_init_func(name: str, init_func: Callable[[], None]) -> None: + """Register a module without eagerly importing it""" + # If the module is already imported, eagerly run init + assert "." 
not in name, f"Expected a root module name, but got {name}" + assert name not in _lazy_module_init + _lazy_module_init[name].append(init_func) + + +def _maybe_init_lazy_module(obj: object) -> None: + module = getattr(obj, "__module__", None) + if module is None: + return + + base_module = module.split(".")[0] + init_funcs = _lazy_module_init.pop(base_module, None) + if init_funcs is not None: + for fn in init_funcs: + fn() + + +def is_callable_allowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _allowed_callable_ids + + +def is_nonstrict_trace_callable(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _nonstrict_trace_callable_ids + + +def is_callable_disallowed(obj) -> bool: + _maybe_init_lazy_module(obj) + return id(obj) in _disallowed_callable_ids + + +def is_forbidden(obj) -> bool: + _maybe_init_lazy_module(obj) + return inspect.getattr_static(obj, "_dynamo_forbidden", False) + + +def is_builtin_callable(obj) -> bool: + # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids + return id(obj) in _builtin_function_ids + + +def is_builtin_constant(obj) -> bool: + return id(obj) in _builtin_constant_ids + + +def is_polyfilled_callable(obj) -> bool: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return id(obj) in _polyfilled_function_ids + + +def is_numpy(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids + + +def is_numpy_dtype(obj) -> bool: + if np is None: + return False + return isinstance(obj, np.dtype) + + +def is_numpy_type_info(obj) -> bool: + if np is None: + return False + return isinstance(obj, (np.finfo, np.iinfo)) + + +BUILTIN_SKIPLIST = ( + abc, + collections, + copy, + random, + traceback, + linecache, +) + +# third party libraries skiplist is defined by str, because users may not use these libraries. +# we should use lazy import & skip in the future. +THIRDPARTY_SKIPLIST = ( + "fx2trt_oss", + "hypothesis", + "networkx", + "numpy", + "onnx", + "onnxruntime", + "onnx_tf", + "pandas", + "sklearn", + "tabulate", + "tensorflow", + "tensorrt", + "torch2trt", + "tqdm", + "tree", + "tvm", + "xarray", +) + + +def _as_posix_path(path): + posix_path = Path(os.path.normpath(path)).as_posix() + # os.path.normpath and pathlib.Path remove trailing slash, so we need to add it back + if path.endswith((os.path.sep, "/")): + posix_path += "/" + return posix_path + + +def _strip_init_py(s): + suffix = "__init__.py" + s = s.removesuffix(suffix) + return _as_posix_path(s) + + +def _module_dir(m: types.ModuleType): + # Protect against a module not exporting __file__ - this can happen for + # frozen modules, for example. + file = getattr(m, "__file__", None) + return file and _strip_init_py(file) + + +# These are legacy workarounds, don't add new modules to this list. +# Please use the MOD_INLINELIST instead to force inline functions under particular modules. +# +# NB: The only thing that is different about MOD_INLINELIST and LEGACY_MOD_INLINELIST +# is the behavior of a function f2 in the module when called by a function f1 +# in a module in MOD_SKIPLIST (see MOD_SKIPLIST for more details) +# +# LEGACY_MOD_INLINELIST is the same thing as Dynamo's behavior on a module that +# is not in any *_INLINELIST or *_SKIPLIST. +# That being said, we prefer people to add things to MOD_INLINELIST over +# LEGACY_MOD_INLINELIST because it is less likely to break existing tests. 
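# --- Illustrative, standalone sketch (not part of the diff above): the lazily
# initialized, id()-based membership pattern that FunctionIdSet and
# _builtin_function_ids implement. Class and variable names here are
# hypothetical; this only mirrors the mechanism under simplified assumptions.
import builtins
from typing import Callable, Optional


class _LazyIdSet:
    """Build a set of object id()s the first time membership is queried."""

    def __init__(self, initializer: Callable[[], dict[int, str]]) -> None:
        self._initializer = initializer
        self._ids: Optional[set[int]] = None
        self._names: dict[int, str] = {}

    def __contains__(self, obj_id: int) -> bool:
        if self._ids is None:
            self._names = self._initializer()  # lazy init on first use
            self._ids = set(self._names)
        return obj_id in self._ids


# Collect public callable builtins by identity, then test candidate objects.
_builtin_ids = _LazyIdSet(
    lambda: {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
)

assert id(len) in _builtin_ids           # `len` is a public callable builtin
assert id("just a string") not in _builtin_ids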
+LEGACY_MOD_INLINELIST = { + "torch._dynamo.external_utils", + "torch._export.db.examples", + "torch._export.wrappers", + "torch._functorch.apis", + "torch._functorch.deprecated", + "torch.nn.attention.flex_attention", + "torch.ao.quantization.pt2e.export_utils", + "torch.ao.quantization.pt2e.qat_utils", + "torch.ao.quantization.pt2e.representation.rewrite", + "torch.ao.quantization.pt2e.utils", + "torch.ao.quantization.quantizer.xnnpack_quantizer", + "torch.export.unflatten", +} + +if torch.distributed.is_available(): + LEGACY_MOD_INLINELIST |= { + "torch.distributed.tensor._api", + "torch.distributed.tensor.device_mesh", + "torch.distributed.device_mesh", + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", + "torch.distributed.tensor.parallel._data_parallel_utils", + "torch.distributed.tensor.parallel._utils", + "torch.distributed.tensor.parallel.style", + # we have to add replicate to LEGACY_MOD_INLINELIST to ensure + # the forward_hook won't be ignored. + "torch.distributed._composable.replicate", + } + if not config.skip_fsdp_hooks: + LEGACY_MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard") + +# Force inline functions under these modules, even they are in *_SKIPLIST. +# We are using python module name instead of file or directory object to avoid circular dependency. +# Please keep this sorted alphabetically. +# +# Btw, it is not "ideal" for something to be in MOD_INLINELIST. If Dynamo +# fully supports a module, then the ideal case is that it is not in +# any *_INLINELIST or *_SKIPLIST: then, the behavior of Dynamo is that +# it will always inline into functions in the module. +MOD_INLINELIST = [ + "torch._decomp", + "torch._dynamo._trace_wrapped_higher_order_op", + "torch._dynamo.compiled_autograd", + "torch._dynamo.comptime", + "torch._dynamo.polyfills", + "torch._functorch._aot_autograd.subclass_parametrization", + "torch._functorch.autograd_function", + "torch._functorch.eager_transforms", + "torch._functorch.functional_call", + "torch._functorch.pyfunctorch", + "torch._functorch.vmap", + "torch._inductor.test_operators", + "torch._library.autograd", + "torch._library.custom_ops", + "torch._ops", + "torch._prims", + "torch._refs", + "torch._tensor", + "torch.amp.autocast_mode", + "torch.ao.nn", + "torch.autograd.function", + "torch.backends.cuda", + "torch.cuda.amp.autocast_mode", + "torch.distributions", + "torch.export._tree_utils", + "torch.export._wrapper_utils", + "torch.fx._pytree", + "torch.fx._symbolic_trace", + "torch.fx.experimental.proxy_tensor", + "torch.fx.passes.shape_prop", + "torch.nn", + "torch.overrides", + "torch.random", + "torch.return_types", + "torch.sparse", + "torch.testing", + "torch.utils._content_store", + "torch.utils._contextlib", + "torch.utils._cxx_pytree", + "torch.utils._device", + "torch.utils._foreach_utils", + "torch.utils._python_dispatch", + "torch.utils._pytree", + "torch.utils.hooks", +] +assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST +MOD_INLINELIST = set(MOD_INLINELIST) + + +if torch.distributed.is_available(): + MOD_INLINELIST.add("torch.distributed") + if not config.skip_fsdp_hooks: + MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard") + + +# By default, all functions under these modules are skipped. +# All the other knobs +# (torch_name_rule_map, MOD_INLINELIST, LEGACY_MOD_INLINELIST) +# take precedence over this list; e.g. if a function is in +# MOD_INLINELIST and MOD_SKIPLIST, then it will be inlined. +# See "A note on skip/inline rules" for more details. +# +# The skip is NOT recursive. 
If a function f1 in a module in MOD_SKIPLIST +# calls out to another function f2 in some other module, then Dynamo's +# behavior (skip/inline) depends on what we've marked f2 as: +# - if f2 is a function in a module in MOD_SKIPLIST, then we skip f2 +# - if f2 is a function in a module in MOD_INLINELIST, then we skip f2 +# - if f2 is a function in a module in LEGACY_MOD_INLINELIST, then we inline f2 +# - if f2 is a function in a module not in any *_LIST, then we inline f2 +MOD_SKIPLIST = [ + "torch._VF", + "torch.__future__", + "torch.__init__", + "torch._awaits", + "torch._classes", + "torch._compile", + "torch._custom_op", + "torch._custom_ops", + "torch._decomp", + "torch._deploy", + "torch._dispatch", + "torch._dynamo", + "torch._export", + "torch._functorch", + "torch._guards", + "torch._higher_order_ops.effects", + "torch._higher_order_ops.torchbind", + "torch._higher_order_ops.wrap", + "torch._inductor", + "torch._jit_internal", + "torch._lazy", + "torch._library", + "torch._linalg_utils", + "torch._lobpcg", + "torch._logging", + "torch._lowrank", + "torch._meta_registrations", + "torch._namedtensor_internals", + "torch._numpy", + "torch._ops", + "torch._prims", + "torch._prims_common", + "torch._python_dispatcher", + "torch._refs", + "torch._strobelight", + "torch._subclasses", + "torch._tensor", + "torch._tensor_str", + "torch._thread_safe_fork", + "torch._utils", + "torch._utils_internal", + "torch._vmap_internals", + "torch._weights_only_unpickler", + "torch.accelerator", + "torch.amp", + "torch.ao", + "torch.autograd", + "torch.backends", + "torch.compiler", + "torch.contrib", + "torch.cpu", + "torch.cuda", + "torch.distributed", + "torch.distributions", + "torch.export", + "torch.fb", + "torch.fft", + "torch.functional", + "torch.futures", + "torch.fx", + "torch.hub", + "torch.jit", + "torch.library", + "torch.linalg", + "torch.masked", + "torch.monitor", + "torch.mps", + "torch.mtia", + "torch.multiprocessing", + "torch.nested", + "torch.nn", + "torch.onnx", + "torch.overrides", + "torch.package", + "torch.profiler", + "torch.quantization", + "torch.quasirandom", + "torch.random", + "torch.serialization", + "torch.signal", + "torch.sparse", + "torch.special", + "torch.storage", + "torch.testing", + "torch.types", + "torch.utils", + "torch.xpu", +] + +assert sorted(set(MOD_SKIPLIST)) == MOD_SKIPLIST +MOD_SKIPLIST = set(MOD_SKIPLIST) + + +@functools.cache +def get_legacy_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + for m in LEGACY_MOD_INLINELIST + } + return inlinelist + + +@functools.cache +def get_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + for m in MOD_INLINELIST + } + return inlinelist + + +@functools.cache +def get_mod_skiplist(): + skiplist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + for m in MOD_SKIPLIST + } + return skiplist + + +# skip some standard python builtin libs +SKIP_DIRS = [ + "", + _as_posix_path(_config_module.__file__), + "triton/backends", +] +SKIP_DIRS.extend(map(_as_posix_path, filter(None, map(_module_dir, BUILTIN_SKIPLIST)))) + +SKIP_DIRS_RE = re.compile(r"match nothing^") + +# Skip fbcode paths(including torch.package paths) containing +# one of the following strings. +FBCODE_SKIP_DIRS: set[str] = set() + +FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})") + +# Remove this after fbcode is fully migrated to tracing through torchrec. 
+FBCODE_SKIP_TORCHREC_DIRS = { + "torchrec/distributed", + "torchrec/fb/distributed", + "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py", +} + +FBCODE_SKIP_TORCHREC_DIRS_RE = re.compile( + f".*({'|'.join(re.escape(_as_posix_path(d)) for d in FBCODE_SKIP_TORCHREC_DIRS)})" +) + +# TODO(yanboliang, anijain2305) - There are a few concerns that we should +# resolve +# 1) Audit if torchrec/distributed is even required in FBCODE_SKIPS_DIR +# 2) To inline just one file but skip others in a directory, we could use +# manual_torch_name_rule_map but this one is hard because FBCODE can add unusual +# names like torch_package. +# So, this is a stop gap solution till then. +FBCODE_INLINE_FILES_IN_SKIPPED_DIRS = { + "torchrec/distributed/types.py", +} +FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE = re.compile( + f".*({'|'.join(re.escape(_as_posix_path(d)) for d in FBCODE_INLINE_FILES_IN_SKIPPED_DIRS)})" +) + +# torch.optim is a special case, +# we usually want to inline it, but the directory +# structure does not match the module structure +# and we want to skip the functions in optim/lr_scheduler.py +# this has precedence over all other rules in check_file +FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"} + + +def _recompile_re(): + global SKIP_DIRS_RE + SKIP_DIRS_RE = re.compile( + rf"^[^\s<]*({'|'.join(re.escape(_as_posix_path(d)) for d in SKIP_DIRS)})" + ) + + +def add(import_name: str): + if isinstance(import_name, types.ModuleType): + return add(import_name.__name__) + assert isinstance(import_name, str) + from importlib.util import find_spec + + module_spec = find_spec(import_name) + if not module_spec: + return + origin = module_spec.origin + if origin is None: + return + SKIP_DIRS.append(_strip_init_py(origin)) + _recompile_re() + + +@dataclasses.dataclass +class SkipResult: + skipped: bool + reason: Optional[str] + + +def check_file(filename, is_inlined_call=False): + """Should skip this file?""" + if filename is None: + return SkipResult(True, "filename is None") + filename = _as_posix_path(filename) + if filename in FORCE_SKIP_FILES: + return SkipResult(True, "FORCE_SKIP_FILES") + + if any(filename.startswith(d) for d in get_legacy_mod_inlinelist()): + return SkipResult( + False, + "LEGACY_MOD_INLINELIST", + ) + if is_inlined_call and is_torch_inline_allowed(filename): + return SkipResult( + False, + "MOD_INLINELIST", + ) + if ( + is_fbcode() + and FBCODE_SKIP_DIRS + and bool(FBCODE_SKIP_DIRS_RE.match(filename)) + and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename)) + ): + return SkipResult( + True, + "FBCODE_SKIP_DIRS", + ) + + if ( + is_fbcode() + and config.skip_torchrec + and FBCODE_SKIP_TORCHREC_DIRS + and bool(FBCODE_SKIP_TORCHREC_DIRS_RE.match(filename)) + and not bool(FBCODE_INLINE_FILES_IN_SKIPPED_DIRS_RE.match(filename)) + ): + return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + + if ( + filename.startswith(_module_dir(unittest)) + and not torch._dynamo.config.enable_trace_unittest + ): + return SkipResult(True, "unittest") + + if bool(SKIP_DIRS_RE.match(filename)): + return SkipResult(True, "SKIP_DIRS") + + if any(filename.startswith(d) for d in get_mod_skiplist()): + return SkipResult(True, "MOD_SKIPLIST") + return SkipResult(False, "inlined by default") + + +@dataclasses.dataclass +class FunctionInfo: + py_obj: Optional[object] + name: Optional[str] + filename: str + code: Optional[types.CodeType] + + +""" +This is the main entry point to determine whether an object (function) should be inlined or skipped. 
+Let's illustrate the logic with an example: + @torch.compile + def f1(x, y): + ...... + f2(x, y) + ...... + + def f2(x, y): + ...... + f3(x, y) + ...... + + def f3(x, y): + ...... + +There are mainly three call sites of check/check_verbose: +* The compile region entrance (like function f1), the corresponding code is located at eval_frame.py. +* When tracing the recursively called functions (like function f2 and f3). + * Dynamo decides inline/skip every time it encounters a new recursively function call, and the call site + is in InliningInstructionTranslator.check_inlineable of symbolic_convert.py. + * If f2 is skipped by Dynamo, when evaluating the frame of f3, Dynamo need the inline/skip check again + and the call site is in catch_errors_wrapper.catch_errors of convert_frame.py. +* For global variables and function arguments, Dynamo needs to decide if they are wrapped as SkipFunctionVariable in builder.py. + +`is_inlined_call` is used to indicate if the current function call is inlined (f2 is inlined call if it passes check) +or not (f3 is not inlined call if f2 is skipped). Inside of the `check_verbose` function, there are more rules +to be checked if this `is_inlined_call`. +The reason to have this flag is that if the upper level function call (e.g, f2) is skipped, +we don't want to inline the lower level function call (e.g, f3) by default. +""" + + +def check_verbose(obj, is_inlined_call=False): + if isinstance( + obj, + ( + UserFunctionVariable, + UserMethodVariable, + NestedUserFunctionVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + ), + ): + try: + py_obj = obj.get_function() + except NotImplementedError: + py_obj = None + fi = FunctionInfo(py_obj, obj.get_name(), obj.get_filename(), obj.get_code()) + elif isinstance(obj, types.CodeType): + fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj) + elif isinstance(obj, (types.FunctionType, types.MethodType)): + fi = FunctionInfo( + obj, + obj.__name__, + getfile(obj), + obj.__code__, # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed + ) + else: + fi = FunctionInfo(obj, None, getfile(obj), None) + + # Consulte the central trace rules defined in torch._dynamo.trace_rules. + reasons: set[str] = set() + rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) + if issubclass( + rule, + ( + UserFunctionVariable, + LocalGeneratorFunctionVariable, + PolyfilledFunctionVariable, + ), + ): + return SkipResult( + False, + f"inlined according trace_rules.lookup {reasons.pop()}", + ) + elif issubclass(rule, TorchInGraphFunctionVariable): + return SkipResult( + False, + f"registered in torch_obj_rule {reasons.pop()}", + ) + else: + assert rule == SkipFunctionVariable, rule + return SkipResult( + True, + f"skipped according trace_rules.lookup {reasons.pop()}", + ) + + +def check(obj, is_inlined_call=False): + return check_verbose(obj, is_inlined_call).skipped + + +# skip common third party libs +for _name in THIRDPARTY_SKIPLIST: + add(_name) + +_recompile_re() + + +def is_torch_inline_allowed(filename): + return any(filename.startswith(d) for d in get_mod_inlinelist()) + + +@functools.cache +def dynamo_dir(): + import torch._dynamo + + return _module_dir(torch._dynamo) + + +def is_torch(filename): + if filename.startswith(dynamo_dir()): + return False + return filename.startswith(_module_dir(torch)) + + +""" +Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object. 
+""" + + +def lookup_callable(obj): + if not hashable(obj): + return None + # Custom allow/disallow in graph takes precedence over the general lookup. + if is_callable_disallowed(obj): + return SkipFunctionVariable + if is_callable_allowed(obj): + return TorchInGraphFunctionVariable + if is_polyfilled_callable(obj): + return PolyfilledFunctionVariable + if is_builtin_callable(obj): + return BuiltinVariable + return None + + +""" +Main entry point for looking up the trace rule (the Dynamo variable) for a given function object. +E.g, the lookup result of `torch.sin` is `TorchInGraphFunctionVariable`. +""" + + +def lookup(obj): + return lookup_inner(obj) + + +# also takes config.dont_skip_tracing into account +def lookup_inner( + obj, + name=None, + filename=None, + is_direct_call=True, + reasons: Union[None, set[str]] = None, +): + result = _lookup_inner( + obj, + name=name, + filename=filename, + is_direct_call=is_direct_call, + reasons=reasons, + ) + # There are still some modules we should absolutely NOT trace into - e.g. most of torch._dynamo, + # as this can result in really weird tracing behaviors. + # Note that if a torch._dynamo function is already not skipped (e.g. functions in external_utils.py), + # then this branch does not apply. + if config.dont_skip_tracing and result is SkipFunctionVariable: + if filename is None: + filename = getfile(obj) + filename = _as_posix_path(filename) + dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo" + if filename.startswith(dynamo_path) and not filename.endswith( + "test_dont_skip_tracing_functions.py" + ): + return SkipFunctionVariable + if reasons is not None: + reasons.add( + "Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing" + ) + return UserFunctionVariable + return result + + +def _lookup_inner( + obj, + name=None, + filename=None, + is_direct_call=True, + reasons: Union[None, set[str]] = None, +): + # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. + # The rules defined in `torch_name_rule_map` mainly includes two parts: + # - Manually defined rules for any functions. + # - The list of torch in graph functions. + try: + can_hash = hashable(obj) + except Exception: + can_hash = False + if not can_hash: + if reasons is not None: + reasons.add("obj is not hashable") + return None + if obj is not None: + if is_aten_op_or_tensor_method(obj): + return TorchInGraphFunctionVariable + rule = get_torch_obj_rule_map().get(obj, None) + if rule is not None: + if reasons is not None: + reasons.add("get_torch_obj_rule_map") + return rule + elif name is not None and filename is not None and not is_direct_call: + if name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX): + rule = get_torch_obj_rule_map().get( + filename + "#" + TORCH_DYNAMO_RESUME_IN_PREFIX, None + ) + else: + rule = get_torch_obj_rule_map().get(filename + "#" + name, None) + if rule is not None: + if reasons is not None: + reasons.add("get_torch_obj_rule_map") + return rule + elif name == "": + if reasons is not None: + reasons.add("inlining frame from list comprehension") + return UserFunctionVariable + + # Step 2: lookup obj's tracing rule by function name. 
+ if is_direct_call: + if name == "patched_init": + if reasons is not None: + reasons.add("func name is patched_init") + return SkipFunctionVariable + elif name == "__torch_function__" or ( + obj and getattr(obj, "__name__", None) == "__torch_function__" + ): + if reasons is not None: + reasons.add("func name is __torch_function__") + return UserFunctionVariable + + if not is_direct_call: + if name == "__getattr__": + # is_direct_call = False indicates that this is the top-level frame + # being traced (i.e., it is not inlined and not called from + # InliningInstructionTranslator). Tracing __getattr__ at the top + # level is unlikely because we inline it for + # UserDefinedObjectVariable. This scenario occurs only for + # UnspecializedNNModuleVariable, where Dynamo directly calls + # __getattr__ during trace time, generating LOAD_ATTR bytecode + # without going through the underlying __getattr__ data structures. + # When this optimized bytecode is executed, Dynamo is triggered + # again on the __getattr__ call. Therefore, we skip Dynamo tracing + # in this case. + if reasons is not None: + reasons.add( + "Tracing __getattr__ as the top level frame, unsuitable for tracing." + ) + return SkipFunctionVariable + + # Step 3: lookup obj's tracing rule by filename. + if filename is None: + filename = getfile(obj) + + skip_result = check_file(filename, is_direct_call) + if reasons is not None: + reasons.add(skip_result.reason) + if skip_result.skipped: + return SkipFunctionVariable + else: + return UserFunctionVariable + + +def clear_lru_cache(): + torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear() + torch._dynamo.trace_rules.get_tensor_method.cache_clear() + torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() + torch._dynamo.trace_rules.get_mod_inlinelist.cache_clear() + torch._dynamo.trace_rules.dynamo_dir.cache_clear() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/types.py b/phivenv/Lib/site-packages/torch/_dynamo/types.py new file mode 100644 index 0000000000000000000000000000000000000000..16ed1c92c3c106dcb4fa045226f03ad1d2ac8b72 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/types.py @@ -0,0 +1,139 @@ +"""This module contains the core type definitions and protocols used throughout Dynamo. + +The types defined here fall into several categories: +- Guard related types (GuardFn, GuardFail, GuardedCode): Used for tracking and managing guards that protect compiled code +- Frame and cache types (FrameState, CacheEntry): Used for managing interpreter frame state and caching +- Callback protocols (DynamoCallbackFn): Define the interface for frame evaluation callbacks +- Hook protocols (DynamoGuardHook, ProfilerStartHook, ProfilerEndHook, BytecodeHook): Define various hook points for + instrumentation and customization + +These types provide the foundational interfaces that enable Dynamo's dynamic compilation and optimization system, +ensuring type safety and clear contracts between different components of the system. +""" + +import dataclasses +import types +from typing import Any, Callable, NamedTuple, Optional, Protocol, Union + +# CacheEntry has a `guard_manager` field for the guard, and a `code` field for the code object. +from torch._C._dynamo.eval_frame import ( + _CacheEntry as CacheEntry, + _ExtraState as ExtraState, + _FrameAction as FrameAction, + _FrameExecStrategy as FrameExecStrategy, + _PyInterpreterFrame as DynamoFrameType, +) +from torch._guards import CompileId, Guard + + +# We use a dict to store additional data per frame. 
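+# (For example, Dynamo's automatic-dynamic-shape heuristics use this to remember
+# per-frame input size information across recompiles; the exact keys and values
+# are an implementation detail and may change.)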
+FrameState = dict[Any, Any] + + +class GuardFail(NamedTuple): + # A string repr of the piece of failed guard code we eval-ed + reason: str + # A code object where we failed a guard + orig_code: types.CodeType + + +@dataclasses.dataclass(frozen=True) +class GuardFilterEntry: + name: str + has_value: bool + value: object + guard_type: str + derived_guard_types: tuple[str, ...] + is_global: bool + orig_guard: Guard + + +class GuardFn(Protocol): + closure_vars: dict[str, object] + args: list[str] + code_parts: list[str] + verbose_code_parts: list[str] + global_scope: dict[str, object] + guard_fail_fn: Optional[Callable[[GuardFail], None]] + cache_entry: Optional[CacheEntry] + extra_state: Optional[ExtraState] + + # maps locals of user function to bool + def __call__(self, f_locals: dict[str, object]) -> bool: ... + + +@dataclasses.dataclass +class GuardedCode: + code: types.CodeType + guard_manager: GuardFn + compile_id: CompileId + trace_annotation: str = "Unknown" + + +@dataclasses.dataclass +class ConvertFrameReturn: + # default return is no compiled code (i.e. `return None`): + # strategy is to skip non-recursively, for all future intercepted frames too + + # eval frame execution strategy for this frame + frame_exec_strategy: FrameExecStrategy = dataclasses.field( + default_factory=lambda: FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) + ) + # also apply frame_exec strategy to future frames with same code + apply_to_code: bool = True + guarded_code: Optional[GuardedCode] = None + + +def wrap_guarded_code(guarded_code: GuardedCode) -> ConvertFrameReturn: + return ConvertFrameReturn( + frame_exec_strategy=FrameExecStrategy(FrameAction.DEFAULT, FrameAction.DEFAULT), + guarded_code=guarded_code, + ) + + +class DynamoCallbackFn(Protocol): + def __call__( + self, + frame: DynamoFrameType, + cache_entry: Optional[CacheEntry], + frame_state: FrameState, + ) -> ConvertFrameReturn: ... + + +DynamoCallback = Union[DynamoCallbackFn, None, bool] + + +class DynamoGuardHook(Protocol): + def __call__( + self, + guard_manager: GuardFn, + code: types.CodeType, + f_locals: dict[str, object], + index: int, + last: bool, + ) -> None: ... + + +class DynamoGuardCompleteHook(Protocol): + def __call__( + self, + cache_hit: bool, + ) -> bool: ... + + +class ProfilerStartHook(Protocol): + def __call__( + self, + name: str, + # TODO(whc) how do I annotate a _RecordFunction here? + ) -> Any: ... + + +class ProfilerEndHook(Protocol): + def __call__(self, record: Any) -> None: ... + + +class BytecodeHook(Protocol): + def __call__( + self, code: types.CodeType, new_code: types.CodeType + ) -> Optional[types.CodeType]: ... diff --git a/phivenv/Lib/site-packages/torch/_dynamo/utils.py b/phivenv/Lib/site-packages/torch/_dynamo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc1efb1462f233617e0d536e5ab7b655c45137b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/utils.py @@ -0,0 +1,4692 @@ +# mypy: allow-untyped-defs + +""" +Utility functions and classes used throughout the TorchDynamo system. + +This module contains a collection of helper utilities used by various parts of Dynamo for: +- Performance metrics collection and reporting +- Compilation timing and debugging +- Graph manipulation and tensor operations +- Runtime guards and checks +- Common data structure operations +- Testing and development tools + +This is an internal module that provides shared functionality used across the Dynamo codebase. 
+""" + +from __future__ import annotations + +import atexit +import collections +import contextlib +import copy +import dataclasses +import datetime +import dis +import enum +import functools +import gc +import importlib +import inspect +import itertools +import json +import linecache +import logging +import math +import operator +import os +import re +import sys +import textwrap +import threading +import time +import traceback +import types +import typing +import uuid +import warnings +import weakref +from collections import Counter, OrderedDict +from contextlib import AbstractContextManager, contextmanager +from dataclasses import is_dataclass +from functools import lru_cache +from types import CodeType, MethodWrapperType +from typing import ( + Any, + Callable, + cast, + ClassVar, + Generic, + Optional, + overload, + TypeVar, + Union, +) +from typing_extensions import Literal, TypeAlias, TypeGuard, TypeIs + +import torch +import torch._functorch.config +import torch.fx.experimental.symbolic_shapes +import torch.utils._pytree as pytree +from torch import fx +from torch._C import ( + _instruction_counter, + _len_torch_function_stack, + _pop_torch_function_stack, + _push_on_torch_function_stack, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext +from torch._guards import CompileId, Source, TracingContext +from torch._subclasses.meta_utils import is_sparse_compressed +from torch._utils_internal import ( + justknobs_check, + log_chromium_event_internal, + log_compilation_event, + record_chromium_event_internal, + signpost_event, +) +from torch.fx._utils import _format_graph_code, lazy_format_graph_code +from torch.monitor import _WaitCounter +from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._triton import has_triton, has_triton_package +from torch.utils.hooks import RemovableHandle + +from .graph_utils import _get_flat_args + + +if typing.TYPE_CHECKING: + from collections.abc import ( + Generator, + ItemsView, + Iterable, + Iterator, + KeysView, + ValuesView, + ) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + import torch._logging + import torch._numpy as tnp + from torch._guards import detect_fake_mode # noqa: F401 + from torch._logging import LazyString + + from . import config + + # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync. + if np: + NP_SUPPORTED_MODULES: tuple[types.ModuleType, ...] = ( + np, + np.fft, + np.linalg, + np.random, + ) + + NP_TO_TNP_MODULE = { + np: tnp, + np.fft: tnp.fft, + np.linalg: tnp.linalg, + np.random: tnp.random, + } + else: + NP_SUPPORTED_MODULES = () + + NP_TO_TNP_MODULE = {} + from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +except ImportError: + pass + + +T = TypeVar("T") + +unpatched_nn_module_getattr = torch.nn.Module.__getattr__ +unpatched_nn_module_call = torch.nn.Module.__call__ +unpatched_nn_module_call_impl = torch.nn.Module._call_impl + +counters: collections.defaultdict[str, Counter[str]] = collections.defaultdict( + collections.Counter +) +optimus_scuba_log: dict[str, Any] = {} +troubleshooting_url = ( + "https://pytorch.org/docs/main/torch.compiler_troubleshooting.html" +) +nnmodule_doc_url = "https://pytorch.org/docs/main/torch.compiler_nn_module.html" +nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations." 
+log = logging.getLogger(__name__) + +# profiling compilation time by function +compilation_time_metrics: dict[str, list[float]] = {} + +# This supports calculate_time_spent(), which reports cumulative times +# across the process for any "phase" populated by dynamo_timed. Reset if +# reset_frame_count() is called. +cumulative_time_spent_ns: dict[str, float] = collections.defaultdict(float) + +timer_counter = itertools.count() + + +# Abstraction on top of counters. +class ReInplaceTrigger(enum.Enum): + AUTO_FUNC_V1 = 1 + AUTO_FUNC_V2 = 2 + TRITON_OPS = 3 + + +class ReinplaceCounters: + _values: collections.defaultdict[str, int] = collections.defaultdict(int) + + # Track sizes of known not re-inplaced tensors (exclude dynamic shapes). + @classmethod + def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int): + if bytes != 0: + cls._values[f"missed_bytes_{trigger.name}"] += bytes + + # Track number of not re-inplaced tensors. + @classmethod + def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int): + if count != 0: + cls._values[f"missed_tensors_{trigger}"] += count + + @classmethod + def clear(cls): + cls._values.clear() + + @classmethod + def get_total_missed(cls): + sum = 0 + for trigger in ReInplaceTrigger: + sum += cls._values.get(f"missed_tensors_{trigger}", 0) + return sum + + @classmethod + def get_total_missed_bytes(cls): + sum = 0 + for trigger in ReInplaceTrigger: + sum += cls._values.get(f"missed_bytes_{trigger.name}", 0) + return sum + + @classmethod + def log(cls): + # if not empty log. + if cls._values: + signpost_event("inductor", "reinplace_counters", cls._values) + + +def tabulate( + rows: Union[list[tuple[str, object]], list[list[object]]], + headers: Union[tuple[str, ...], list[str]], +) -> str: + try: + import tabulate + + return tabulate.tabulate(rows, headers=headers) + except ImportError: + return "\n".join( + ", ".join(map(str, row)) for row in itertools.chain([headers], rows) + ) + + +curr_frame = 0 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def increment_frame() -> None: + global curr_frame + curr_frame = curr_frame + 1 + + +# Note: Called for you by dynamo - you almost never ever want to invoke this yourself. +def reset_frame_count() -> None: + global curr_frame + cumulative_time_spent_ns.clear() + compilation_time_metrics.clear() + curr_frame = 0 + + +op_count = 0 + + +def increment_op_count(cnt: int) -> None: + global op_count + op_count += cnt + + +# Get the total time in seconds for each "phase" +# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806} +def calculate_time_spent() -> dict[str, float]: + total_by_key = {} + for phase, timing in cumulative_time_spent_ns.items(): + total_by_key[phase] = timing / 1e9 + + total_by_key["total_wall_time"] = total_by_key.get( + "entire_frame_compile", 0 + ) + total_by_key.get("entire_backward_compile", 0) + return total_by_key + + +# Print a report of time spent so far +# Ex: +# TIMING: +# entire_frame_compile:8.574629999999999 +# backend_compile:5.26806 +def print_time_report() -> None: + total_by_key = calculate_time_spent() + + out = "TIMING:" + for key, value in total_by_key.items(): + out = f"{out} {key}:{round(value, 5)}" + + print(out) + + +# Use the following singleton to capture and log CompilationMetrics. Entering the context +# manager allocates a new record to be logged when it exits. 
(You should not need to use +# this directly unless you introduce a new code path where compilation metrics would be +# gathered). While compiling, use the setters or timer in MetricsContext to update fields +# in the current context. For example: +# +# To set a single field once (use overwrite=True to overwrite): +# get_metrics_context().set("metric_name", value) +# +# To set multiple fields at once (use overwrite=True to overwrite): +# get_metrics_context().update({"name1": val1, "name2": val2}) +# +# To increment an integer field: +# get_metrics_context().increment("metric_name", value) +# +# To record execution time, MetricsContext works with dynamo_timed: +# def foo(...): +# # Updates the "metric_us" field. +# with dynamo_timed("metric", dynamo_compile_column_us="metric_us") +# ... +# +_METRICS_CONTEXT: MetricsContext +_RUNTIME_METRICS_CONTEXT: RuntimeMetricsContext + + +def get_metrics_context() -> MetricsContext: + return _METRICS_CONTEXT + + +def get_runtime_metrics_context() -> RuntimeMetricsContext: + return _RUNTIME_METRICS_CONTEXT + + +class CompileEventLogLevel(enum.Enum): + """ + Enum that loosely corresponds with a "log level" of a given event. + + CHROMIUM_EVENT: Logs only to tlparse. + COMPILE_EVENT: Logs to tlparse + PT2 Compile Events + COMPILATION_METRIC: Logs to tlparse, PT2 Compile Events, and dynamo_compile + """ + + CHROMIUM = 1 + PT2_COMPILE = 2 + COMPILATION_METRIC = 3 + + +class CompileEventLogger: + """ + Helper class for representing adding metadata(i.e. columns) to various compile events. + Use CompileEventLogger to add event data to: + - Chromium events + - PT2 Compile Events + - CompilationMetrics + + This should be used in conjunction with dynamo_timed() and metrics contexts, which create + timed spans and events. CompileEventLogger uses three log levels (described in CompileEventLogLevel), + where each log level logs to all sources below it in the hierarchy. + + Example usages: + - I want to log to an existing chromium event within dynamo timed: + with dynamo_timed("my_event"): + CompileEventLogger.chromium("my_event", foo=bar) + + - I want to log my event to both chromium + pt2_compile_events: + with dynamo_timed("my_event", log_pt2_compile_event=True): + CompileEventLogger.pt2_compile("my_event", foo=bar) + + - I want to add information to dynamo events and dynamo_compile + CompileEventLogger.compilation_metric(foo=bar) + """ + + @staticmethod + def log_instant_event( + event_name: str, + metadata: dict[str, Any], + time_ns: Optional[int] = None, + log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, + ): + if time_ns is None: + time_ns = time.time_ns() + chromium_log = get_chromium_event_logger() + if log_level == CompileEventLogLevel.CHROMIUM: + log_pt2_compile_event = False + elif log_level == CompileEventLogLevel.PT2_COMPILE: + log_pt2_compile_event = True + else: + raise RuntimeError( + "Cannot log instant event at COMPILATION_METRIC level. Please choose one of CHROMIUM_EVENT or COMPILE_EVENT" + ) + chromium_log.log_instant_event( + event_name, time_ns, metadata, log_pt2_compile_event + ) + + @staticmethod + def add_data( + event_name: str, + log_level: CompileEventLogLevel, + overwrite: bool = False, + **metadata: object, + ): + """ + Centralized API for adding data to various events + Log an event to a toplevel "dynamo" event or metrics context + depending on log level. 
+ """ + chromium_log = get_chromium_event_logger() + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + + if log_level == CompileEventLogLevel.CHROMIUM: + chromium_log.add_event_data(event_name, **metadata) + elif log_level == CompileEventLogLevel.PT2_COMPILE: + pt2_compile_substack = chromium_log.get_pt2_compile_substack() + if event_name not in pt2_compile_substack: + raise RuntimeError( + "Error: specified log level PT2_COMPILE, but the event %s" + " is not logged to pt2_compile_events. Make sure the event is active and you passed " + "log_pt2_compile_event=True to dynamo_timed", + event_name, + ) + chromium_log.add_event_data(event_name, **metadata) + else: + assert log_level == CompileEventLogLevel.COMPILATION_METRIC + top_event = chromium_log.get_outermost_event() + + if event_name != top_event: + raise RuntimeError( + "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. " + "CompilationMetrics must be logged to the toplevel event. Consider using `log_toplevel_event_data` directly." + ) + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + raise RuntimeError( + "No metrics context is in progress. Please only call this function within a metrics context." + ) + + # TODO: should we assert that the keys of metadata are in CompilationMetrics? + metrics_context.update(metadata, overwrite) + chromium_log.add_event_data(event_name, **metadata) + + @staticmethod + def add_toplevel( + log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object + ): + """ + Syntactic sugar for logging to the toplevel event + """ + top_event = get_chromium_event_logger().get_outermost_event() + if top_event is None: + raise RuntimeError( + "No toplevel event active. Please only call this function within a dynamo_timed context." + ) + CompileEventLogger.add_data(top_event, log_level, overwrite, **metadata) + + @staticmethod + def increment( + event_name: str, log_level: CompileEventLogLevel, key: str, value: int + ): + """ + Increments an existing field, or adds it + """ + chromium_log = get_chromium_event_logger() + if ( + log_level == CompileEventLogLevel.CHROMIUM + or log_level == CompileEventLogLevel.PT2_COMPILE + ): + chromium_log.increment(event_name, key, value) + else: + assert log_level == CompileEventLogLevel.COMPILATION_METRIC + top_event = chromium_log.get_outermost_event() + if event_name != top_event: + raise RuntimeError( + "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. " + "CompilationMetrics must be logged to the toplevel event. Consider using `increment_toplevel` directly." + ) + + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + raise RuntimeError( + "No metrics context is in progress. Please only call this function within a metrics context/dynamo_timed." + ) + + metrics_context.increment(key, value) + chromium_log.increment(event_name, key, value) + + @staticmethod + def increment_toplevel( + key: str, + value: int = 1, + log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, + ): + """ + Increments a value on the toplevel metric. By default, logs to metric. + """ + chromium_log = get_chromium_event_logger() + top_event = chromium_log.get_outermost_event() + if top_event is None: + raise RuntimeError( + "No toplevel event active. Please only call this function within a metrics context/dynamo_timed." 
+ ) + CompileEventLogger.increment(top_event, log_level, key, value) + + @staticmethod + def add_to_set( + event_name: str, log_level: CompileEventLogLevel, key: str, value: Any + ): + """ + Add metadata to a set of values with key . Creates a set if it doesn't exist. + """ + chromium_log = get_chromium_event_logger() + if ( + log_level == CompileEventLogLevel.CHROMIUM + or log_level == CompileEventLogLevel.PT2_COMPILE + ): + chromium_log.add_to_set(event_name, key, value) + else: + assert log_level == CompileEventLogLevel.COMPILATION_METRIC + top_event = chromium_log.get_outermost_event() + if event_name != top_event: + raise RuntimeError( + "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. " + "CompilationMetrics must be logged to the toplevel event. Consider using `add_to_set_metric` directly." + ) + + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + raise RuntimeError( + "No metrics context is in progress. Please only call this function within a metrics context/dynamo_timed." + ) + + metrics_context.add_to_set(key, value) + chromium_log.add_to_set(event_name, key, value) + + @staticmethod + def add_to_set_toplevel( + key: str, + value: Any, + log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, + ): + """ + Same as add to set, just does it automatically to the toplevel event instead of having to explicitly name it. + Defaults to COMPILATION_METRIC log level. + """ + chromium_log = get_chromium_event_logger() + top_event = chromium_log.get_outermost_event() + if top_event is None: + raise RuntimeError( + "No toplevel event active. Please only call this function within a metrics context/dynamo_timed." + ) + CompileEventLogger.add_to_set(top_event, log_level, key, value) + + # Helper functions that are syntactic sugar + + @staticmethod + def chromium(event_name: str, **metadata: object): + """ + Add to in chromium. Each key/value of metadata will appear in the chromium trace. + should be the name of a timed event span passed to `dynamo_timed`. + """ + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.CHROMIUM, overwrite=False, **metadata + ) + + @staticmethod + def pt2_compile(event_name: str, **metadata: object): + """ + Add to in chromium and PT2 Compile Events. + Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes + a column in PT2 Compile Events, with the corresponding kwarg value. + should be the name of a timed event span passed to `dynamo_timed`, + with log_to_pt2_compile_events=True. + """ + CompileEventLogger.add_data( + event_name, CompileEventLogLevel.PT2_COMPILE, overwrite=False, **metadata + ) + + @staticmethod + def compilation_metric(overwrite: bool = False, **metadata: object): + """ + Add to the CompilationMetrics context. Also logs to PT2 Compile Events + and chromium. + Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes + a column in PT2 Compile Events and Dynamo Compile, with the corresponding kwarg value. + """ + CompileEventLogger.add_toplevel( + CompileEventLogLevel.COMPILATION_METRIC, overwrite, **metadata + ) + + @staticmethod + def instant( + event_name: str, metadata: dict[str, Any], time_ns: Optional[int] = None + ): + """ + Log an instant event to chromium logs with name at time . The `args` field in + Perfetto will point to metadata. should be a value obtained from time.time_ns(). 
+ """ + CompileEventLogger.log_instant_event( + event_name, metadata, time_ns, CompileEventLogLevel.CHROMIUM + ) + + @staticmethod + def try_add_pt2_compile(event_name: str, **metadata: object): + """ + Adds to an existing pt2_compile event, but silently returns if the event doesn't exist + or ChromiumEventLogger is not initialized. + This function is syntactic sugar for chromium_event_logger().try_add_event_data. + """ + if not chromium_event_log_active(): + return + chromium_log = get_chromium_event_logger() + chromium_log.try_add_event_data(event_name, **metadata) + + @staticmethod + def try_(method_fn, *args, **kwargs): + """ + Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set + """ + if not chromium_event_log_active(): + return + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + return + method_fn(*args, **kwargs) + + +_dynamo_timed_tls = threading.local() + + +@contextmanager +def dynamo_timed( + key: str, + # TODO(masneral): Deprecate this param. + phase_name: Optional[str] = None, + log_pt2_compile_event: bool = False, + metadata: Optional[dict[str, object]] = None, + dynamo_compile_column_us: Optional[str] = None, + compile_id: Optional[CompileId] = None, + is_backward: Optional[bool] = None, + log_waitcounter: bool = False, + waitcounter_name_override: Optional[str] = None, +) -> Generator[Any, None, None]: + """ + dynamo_timed is a context manager + By wrapping a function in dynamo_timed, we can get a few things: + + 1) Optionally log timings to pt2_compile_events. + 2) Optionally log timings to CompilationMetrics (dynamo_compile). + 3) Optionally log chromium events. + 4) Optionally increment a WaitCounter. + 5) Store a record in compilation_time_metrics + For example: + + def _foo(...): + with dynamo_timed("_foo"): + ... + + Would show up as an entry in our timing dict: + OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) + This is extremely useful for granular debugging. + + Although it is tempting to use dynamo_timed as a decorator, please do not. + In its decorator form it makes cProfile traces less useful as dynamo_timed + suddenly becomes a bottleneck for lots of function calls (as only one parent + pointer is recorded). + + Params: + - key: key into compile_time_metrics. If phase_name is not provided, this is + also the event name used for pt2_compile_events logs and chromium events. + - phase_name: Optional override for the event name. + - log_pt2_compile_event: Whether to log a pt2 compile event internally. + - metadata: Extra metadata to put in pt2_compile_events. + - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics + field to be logged to dyname_compile column. We expect all columns to be _us; + therefore, the field name must end with "_us". + - compile_id: In the typical case, this parameter should not be needed. Use to + supply the compile_id for those cases where we want to log a compile_id where + it's not naturally available, e.g., for runtime autotuning. + - is_backward: Specify forward/backward directly when not available in a + CompileContext, e.g., during runtime autotuning. + that support it. 
+ - log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}" + """ + if phase_name: + event_name = phase_name + fn_name = key + else: + event_name = key + fn_name = None + + if key not in compilation_time_metrics: + compilation_time_metrics[key] = [] + + event_metadata = {} + if metadata: + event_metadata.update(metadata) + if fn_name: + event_metadata.update({"fn_name": fn_name}) + if is_backward is not None: + event_metadata.update({"is_backward": is_backward}) + + chromium_log: ChromiumEventLogger = get_chromium_event_logger() + start_ns = time.time_ns() + chromium_log.log_event_start( + event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id + ) + + cx_mgrs: list[typing.Any] = [ + torch.profiler.record_function(f"{key} (dynamo_timed)") + ] + if log_waitcounter: + wc_name = waitcounter_name_override if waitcounter_name_override else key + cx_mgrs.append(_WaitCounter(f"pytorch.wait_counter.{wc_name}").guard()) + + is_compile_time = torch._guards.CompileContext.current_compile_id() is not None + if dynamo_compile_column_us: + # We're standardizing on microseconds for dynamo_compile timings. + assert dynamo_compile_column_us.endswith("_us") + + # Track nested dynamo_timed calls that update CompilationMetrics so we can + # bump a total duration only for the outermost metric. + if not hasattr(_dynamo_timed_tls, "depth"): + _dynamo_timed_tls.depth = 0 + _dynamo_timed_tls.depth += 1 + + # The corresponding WaitCounters that we bump for all overheads + if _dynamo_timed_tls.depth == 1: + cx_mgrs.append(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) + if not is_compile_time: + runtime_wc = "pytorch.wait_counter.compile_runtime_overheads" + cx_mgrs.append(_WaitCounter(runtime_wc).guard()) + + try: + with contextlib.ExitStack() as stack: + for cx in cx_mgrs: + stack.enter_context(cx) + yield + finally: + end_ns = time.time_ns() + time_spent_ns = end_ns - start_ns + compilation_time_metrics[key].append(time_spent_ns / 1e9) + chromium_log.log_event_end( + event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id + ) + if dynamo_compile_column_us: + # TODO: the events that we capture in calculate_time_spent() seem a little + # arbitrary. Currently, it's only those fields that are present in + # CompilationMetrics (but note that we accumulate by the associated event + # name, not the field name in CompilationMetrics). Do we want to keep it + # this way? + cumulative_time_spent_ns[event_name] += time_spent_ns + + # Bump the total duration for every outer event. + _dynamo_timed_tls.depth -= 1 + is_outer_event = _dynamo_timed_tls.depth == 0 + + duration_us = time_spent_ns // 1000 + if is_compile_time: + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.increment(dynamo_compile_column_us, duration_us) + if is_outer_event: + metrics_context.increment("duration_us", duration_us) + else: + runtime_context = get_runtime_metrics_context() + runtime_context.increment(dynamo_compile_column_us, duration_us) + if is_outer_event: + extra = { + "compile_id": compile_id, + "is_runtime": True, + "is_forward": not is_backward, + } + runtime_context.increment("duration_us", duration_us, extra) + + +@overload +def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ... + + +@overload +def compile_times( + repr: Literal["csv"], aggregate: bool = False +) -> tuple[list[str], list[object]]: ... 
+ + +def compile_times(repr="str", aggregate: bool = False): + """ + Get metrics about torchdynamo frontend/backend compilation times. + + Accumulates information from functions tagged with `dynamo_timed`. + + repr='str' returns a printable string for user interaction, and 'csv' + returns headers, rows which can be logged for output + + aggregate causes values from multiple compilations (e.g. split graphs) + to be accumulated into one value. If false, expect more than one value + per metric. + """ + + def fmt_fn(values, item_fn=lambda x: x): + if aggregate: + return item_fn(sum(values)) + return ", ".join(map(item_fn, values)) + + if repr == "str": + rows = [ + (k, fmt_fn(compilation_time_metrics[k], item_fn=lambda x: f"{x:.4f}")) + for k in compilation_time_metrics + ] + out = "TorchDynamo compilation metrics:\n" + out += tabulate(rows, headers=("Function", "Runtimes (s)")) + return out + elif repr == "csv": + values = [ + fmt_fn(v, item_fn=lambda x: f"{x:.6f}") + for v in compilation_time_metrics.values() + ] + headers = list(compilation_time_metrics.keys()) + return headers, values + return None + + +@atexit.register +def dump_compile_times() -> None: + log.info(compile_times(repr="str", aggregate=True)) + + +tensortype_to_dtype = { + torch.FloatTensor: (torch.float32, torch.float), + torch.DoubleTensor: (torch.float64, torch.double), + torch.HalfTensor: (torch.float16, torch.half), + torch.BFloat16Tensor: (torch.bfloat16,), + torch.ByteTensor: (torch.uint8,), + torch.CharTensor: (torch.int8,), + torch.LongTensor: (torch.int64, torch.long), + torch.IntTensor: (torch.int32, torch.int), + torch.ShortTensor: (torch.int16, torch.short), + torch.BoolTensor: (torch.bool,), +} + + +class DuplicateWarningChecker: + def __init__(self, maxsize: int = 4096) -> None: + self.maxsize = maxsize + self.reset() + + def reset(self): + self.set = OrderedDict() + + def add(self, key: Union[str, tuple[object, object]]) -> bool: + if key in self.set: + self.set.move_to_end(key, last=True) + if not config.verbose: + return False + else: + self.set[key] = None + while len(self.set) > self.maxsize: + self.set.popitem(last=False) + return True + + +graph_break_dup_warning_checker = DuplicateWarningChecker() + + +def setup_compile_debug(): + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + if compile_debug: + return add_file_handler() + + return contextlib.ExitStack() + + +def reset_graph_break_dup_checker() -> None: + graph_break_dup_warning_checker.reset() + + +def add_file_handler(): + log_path = os.path.join(get_debug_dir(), "torchdynamo") + os.makedirs(log_path, exist_ok=True) + + log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log")) + logger = logging.getLogger("torch._dynamo") + logger.addHandler(log_file_handler) + + exitstack = contextlib.ExitStack() + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + +def setup_log_file(): + exitstack = contextlib.ExitStack() + if config.log_file_name is not None: + log_file_handler = logging.FileHandler(config.log_file_name) + for logger in torch._logging._internal.get_loggers(): + logger.addHandler(log_file_handler) + exitstack.callback(lambda: logger.removeHandler(log_file_handler)) + return exitstack + + return exitstack + + +def gen_record_file_name(exc, code) -> str: + return f"{get_debug_dir()}/error_recordings/\ +{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" + + +def write_record_to_file(filename: str, exec_record) -> None: + try: + if os.path.exists(filename): + 
log.warning( + "Unable to write execution record %s; file already exists.", filename + ) + else: + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "wb") as f: + exec_record.dump(f) + except Exception: + log.exception("Unable to write execution record %s", filename) + + +def count_calls(g: fx.Graph) -> int: + c = 0 + for n in g.nodes: + if "call" in n.op: + c += 1 + return c + + +def identity(x: T) -> T: + return x + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + # cannot hash writable memoryview object + except ValueError: + return False + + +def nothing(*args, **kwargs): + pass + + +class ExactWeakKeyDictionary: + """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" + + def __init__(self): + self.values = {} + self.refs = {} + + def __getitem__(self, key): + return self.values[id(key)] + + def get(self, key, default=None): + return self.values.get(id(key), default) + + def __contains__(self, key): + return id(key) in self.values + + def __setitem__(self, key, value): + idx = id(key) + if idx not in self.refs: + self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) + self.values[idx] = value + + def _remove_id(self, idx): + if idx in self.values: + del self.values[idx] + if idx in self.refs: + del self.refs[idx] + + def clear(self): + self.refs.clear() + self.values.clear() + + +@overload +def istype(obj: object, allowed_types: type[T]) -> TypeIs[T]: ... + + +@overload +def istype( + obj: object, allowed_types: tuple[type[list[T]], type[tuple[T, ...]]] +) -> TypeIs[T]: ... + + +@overload +def istype(obj: object, allowed_types: Iterable[type]) -> bool: ... + + +def istype(obj, allowed_types): + """isinstance() without subclasses""" + if isinstance(allowed_types, (tuple, list, set)): + return type(obj) in allowed_types + return type(obj) is allowed_types + + +if sys.version_info >= (3, 12): + # Some typing classes moved to C in 3.12, + # which no longer have the _Final mixin. + _builtin_final_typing_classes = ( + typing.ParamSpecArgs, + typing.ParamSpecKwargs, + typing.ParamSpec, + typing.TypeVar, + typing.TypeVarTuple, + typing.TypeAliasType, + ) + + +def is_typing(value): + # _Final catches most of typing classes: + # - Any + # - Callable + # - Union + # ... + # + # NB: we intentionally ignore classes that inherit from Generic, since they + # can be used as both TypingVariable as well as UserDefinedClassVariable. + if sys.version_info >= (3, 12) and isinstance(value, _builtin_final_typing_classes): + return True + return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] + + +def is_numpy_int_type(value): + if not np: + return False + + return istype( + value, + ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ) + + +def is_numpy_float_type(value): + if not np: + return False + + return istype( + value, + ( + np.float16, + np.float32, + np.float64, + ), + ) + + +@overload +def is_lru_cache_wrapped_function( + value: Callable[..., T], +) -> TypeGuard[functools._lru_cache_wrapper[T]]: ... + + +@overload +def is_lru_cache_wrapped_function( + value: Any, +) -> TypeGuard[functools._lru_cache_wrapper[Any]]: ... 
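+# Note: this matches e.g. a plain Python function decorated with
+# @functools.lru_cache, whose __wrapped__ attribute is the original function.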
+ + +def is_lru_cache_wrapped_function( + value: Any, +) -> bool: + return isinstance(value, functools._lru_cache_wrapper) and is_function( + inspect.getattr_static(value, "__wrapped__") + ) + + +_FuncTypes: TypeAlias = Union[ + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, +] + + +def is_function_or_wrapper( + value: Any, +) -> TypeIs[Union[_FuncTypes, torch._ops.OpOverloadPacket, torch._ops.OpOverload]]: + return is_function(value) or isinstance( + value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) + + +def is_function( + value: Any, +) -> TypeIs[_FuncTypes]: + return isinstance( + value, + ( + types.FunctionType, + types.BuiltinFunctionType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + ), + ) + + +cmp_name_to_op_mapping = { + "__eq__": operator.eq, + "__ne__": operator.ne, + "__lt__": operator.lt, + "__le__": operator.le, + "__gt__": operator.gt, + "__ge__": operator.ge, +} + + +cmp_name_to_op_str_mapping = { + "__eq__": "==", + "__ne__": "!=", + "__lt__": "<", + "__le__": "<=", + "__gt__": ">", + "__ge__": ">=", +} + + +def is_wrapper_or_member_descriptor( + value: Any, +) -> TypeIs[ + Union[ + types.GetSetDescriptorType, + types.MethodDescriptorType, + types.WrapperDescriptorType, + types.MemberDescriptorType, + types.MethodWrapperType, + ] +]: + return isinstance( + value, + ( + # set up by PyGetSetDef + types.GetSetDescriptorType, + # set by PyMethodDef, e.g. list.append + types.MethodDescriptorType, + # slots - list.__add__ + types.WrapperDescriptorType, + # set up by PyMemberDef + types.MemberDescriptorType, + # wrapper over C functions + types.MethodWrapperType, + ), + ) + + +def unwrap_if_wrapper(fn): + return unwrap_with_attr_name_if_wrapper(fn)[0] + + +def unwrap_with_attr_name_if_wrapper(fn): + # TODO(anijain2305) - Investigate if we can get rid of this function + # unpack @torch._dynamo.optimize()(fn) wrapped function + if is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + attr_name = "_torchdynamo_inline" + else: + attr_name = None + return fn, attr_name + + +def is_numpy_ndarray(value): + if not np: + return False + + return istype(value, np.ndarray) + + +def istensor(obj): + """Check of obj is a tensor""" + tensor_list: tuple[type, ...] 
= ( + torch.Tensor, + torch.nn.Parameter, + *config.traceable_tensor_subclasses, + ) + tensor_list = tensor_list + (torch._subclasses.FakeTensor,) + return istype(obj, tensor_list) + + +def is_lazy_module(mod): + return isinstance(mod, LazyModuleMixin) + + +@functools.lru_cache(4096) +def print_once(*args): + print(*args) + + +def make_cell(val=None): + """Some black magic to create a cell object that usually only exists in a closure""" + x = val + + def f(): + return x + + assert f.__closure__ is not None and len(f.__closure__) == 1 + return f.__closure__[0] + + +def proxy_args_kwargs(args, kwargs): + try: + proxy_args = tuple(arg.as_proxy() for arg in args) + proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return proxy_args, proxy_kwargs + except NotImplementedError as e: + from .exc import unimplemented_v2 + from .variables.base import typestr + + unimplemented_v2( + gb_type="Failed to convert args/kwargs to proxy", + context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", + explanation="Missing `as_proxy()` implementation for some arg/kwarg.", + hints=[], + from_exc=e, + ) + + +def to_int_ms(v: Optional[float]) -> Optional[int]: + return None if v is None else int(v * 1000) + + +# float64 timestamp has a quarter microsecond precision in 2024, so while +# this is suboptimal we shouldn't meaningfully lose precision +def to_int_us(v: Optional[float]) -> Optional[int]: + return None if v is None else int(v * 1_000_000) + + +# Version field added to every log. Increment to make it easier to distinguish new +# vs. old entries when you make a substantive change to how the logs are populated. +LOG_FORMAT_VERSION = 3 + + +@dataclasses.dataclass +class CompilationMetrics: + compile_id: Optional[str] = None + frame_key: Optional[str] = None + co_name: Optional[str] = None + co_filename: Optional[str] = None + co_firstlineno: Optional[int] = None + cache_size: Optional[int] = None + accumulated_cache_size: Optional[int] = None + guard_count: Optional[int] = None + shape_env_guard_count: Optional[int] = None + graph_op_count: Optional[int] = None + graph_node_count: Optional[int] = None + graph_input_count: Optional[int] = None + start_time: Optional[float] = None + entire_frame_compile_time_s: Optional[float] = None + backend_compile_time_s: Optional[float] = None + inductor_compile_time_s: Optional[float] = None + code_gen_time_s: Optional[float] = None + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + fail_user_frame_filename: Optional[str] = None + fail_user_frame_lineno: Optional[int] = None + non_compliant_ops: Optional[set[str]] = None + compliant_custom_ops: Optional[set[str]] = None + restart_reasons: Optional[set[str]] = None + dynamo_time_before_restart_s: Optional[float] = None + # Sometimes, we will finish analyzing a frame but conclude we don't want + # to install any guarded code. 
True means we actually decided to install + # a compiled frame + has_guarded_code: Optional[bool] = None + remote_cache_time_saved_s: Optional[float] = None + structured_logging_overhead_s: Optional[float] = None + config_suppress_errors: Optional[bool] = None + config_inline_inbuilt_nn_modules: Optional[bool] = None + specialize_float: Optional[bool] = None + dynamo_config: Optional[str] = None + is_forward: Optional[bool] = None + num_triton_bundles: Optional[int] = None + remote_fx_graph_cache_get_time_ms: Optional[int] = None + remote_fx_graph_cache_put_time_ms: Optional[int] = None + start_time_us: Optional[int] = None + duration_us: Optional[int] = None + dynamo_cumulative_compile_time_us: Optional[int] = None + aot_autograd_cumulative_compile_time_us: Optional[int] = None + inductor_cumulative_compile_time_us: Optional[int] = None + inductor_code_gen_cumulative_compile_time_us: Optional[int] = None + triton_compile_time_us: Optional[int] = None + runtime_cudagraphify_time_us: Optional[int] = None + runtime_triton_autotune_time_us: Optional[int] = None + dynamo_compile_time_before_restart_us: Optional[int] = None + distributed_ephemeral_timeout_us: Optional[int] = None + structured_logging_overhead_us: Optional[int] = None + remote_fx_graph_cache_get_time_us: Optional[int] = None + remote_fx_graph_cache_put_time_us: Optional[int] = None + backward_cumulative_compile_time_us: Optional[int] = None + end_time_us: Optional[int] = None + pre_grad_pass_time_us: Optional[int] = None + post_grad_pass_time_us: Optional[int] = None + joint_graph_pass_time_us: Optional[int] = None + log_format_version: int = LOG_FORMAT_VERSION + inductor_config: Optional[str] = None + remote_cache_version: Optional[int] = None + inductor_fx_remote_cache_hit_count: Optional[int] = None + inductor_fx_remote_cache_miss_count: Optional[int] = None + inductor_fx_remote_cache_backend_type: Optional[str] = None + inductor_fx_remote_cache_hit_keys: Optional[str] = None + inductor_fx_remote_cache_miss_keys: Optional[str] = None + cuda_version: Optional[str] = None + triton_version: Optional[str] = None + feature_usage: Optional[dict[str, bool]] = None + compile_time_autotune_time_us: Optional[int] = None + is_runtime: Optional[bool] = False + gc_time_us: Optional[int] = None + tensorify_float_attempt: Optional[bool] = None + tensorify_float_success: Optional[bool] = None + tensorify_float_failure: Optional[set[str]] = None + guard_latency_us: Optional[float] = None + recompile_reason: Optional[str] = None + num_graph_breaks: Optional[int] = None + triton_kernel_compile_times_us: Optional[str] = None + ir_count: Optional[int] = None + cudagraph_skip_reason: Optional[str] = None + python_version: Optional[str] = None + pgo_put_remote_code_state_time_us: Optional[int] = None + pgo_get_remote_code_state_time_us: Optional[int] = None + # The number of elements within parameters. This is classically what people + # think of when they think of parameters in a ML model. + param_numel: Optional[int] = None + # The number of elements counted by bytes - i.e. a float32 is 4 bytes + # per element. + param_bytes: Optional[int] = None + # The number of parameters counted by fields. This is mostly a proxy for + # the number of distinct type of params. + param_count: Optional[int] = None + + @classmethod + def create(cls, metrics: dict[str, Any]): + """ + Factory method to create a CompilationMetrics from a dict of fields. 
+ Includes the logic to add legacy fields and any pre-processing, e.g., + we transform some fields to comma-separated strings for scuba logging. + """ + + def us_to_s(metric: Optional[int]) -> Optional[float]: + return metric / 1e6 if metric is not None else None + + def us_to_ms(metric: Optional[int]) -> Optional[int]: + return metric // 1000 if metric is not None else None + + def collection_to_str(metric: Optional[Any]) -> Optional[str]: + def safe_str(item: Any) -> str: + try: + return str(item) + except Exception: + return "" + + if metric is None: + return None + + if not isinstance(metric, (set, list)): + return "" + + return ",".join(safe_str(item) for item in sorted(metric)) + + def collection_to_json_str(metric: Optional[Any]) -> Optional[str]: + if metric is None: + return None + try: + return json.dumps(list(metric)) + except Exception: + return "" + + # TODO: The following are legacy fields, populated from the fields that replace + # them. Remove these when we decide we can really deprecate them. + legacy_metrics = { + "start_time": us_to_s(metrics.get("start_time_us")), + "entire_frame_compile_time_s": us_to_s( + metrics.get("dynamo_cumulative_compile_time_us") + ), + "backend_compile_time_s": us_to_s( + metrics.get("aot_autograd_cumulative_compile_time_us") + ), + "inductor_compile_time_s": us_to_s( + metrics.get("inductor_cumulative_compile_time_us") + ), + "code_gen_time_s": us_to_s( + metrics.get("inductor_code_gen_cumulative_compile_time_us") + ), + "remote_cache_time_saved_s": us_to_s( + metrics.get("distributed_ephemeral_timeout_us") + ), + "remote_fx_graph_cache_get_time_ms": us_to_ms( + metrics.get("remote_fx_graph_cache_get_time_us") + ), + "remote_fx_graph_cache_put_time_ms": us_to_ms( + metrics.get("remote_fx_graph_cache_put_time_us") + ), + "structured_logging_overhead_s": us_to_s( + metrics.get("structured_logging_overhead_us") + ), + } + + all_metrics = {**legacy_metrics, **metrics} + + # Processing before logging: + all_metrics["inductor_fx_remote_cache_hit_keys"] = collection_to_str( + all_metrics.get("inductor_fx_remote_cache_hit_keys") + ) + all_metrics["inductor_fx_remote_cache_miss_keys"] = collection_to_str( + all_metrics.get("inductor_fx_remote_cache_miss_keys") + ) + all_metrics["triton_kernel_compile_times_us"] = collection_to_json_str( + all_metrics.get("triton_kernel_compile_times_us") + ) + compile_id = all_metrics.get("compile_id") + all_metrics["compile_id"] = str(compile_id) if compile_id else None + + return cls(**all_metrics) + + +DEFAULT_COMPILATION_METRICS_LIMIT = 64 + + +_compilation_metrics: collections.deque[CompilationMetrics] = collections.deque( + maxlen=DEFAULT_COMPILATION_METRICS_LIMIT +) + + +def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: + """ + These are the common fields in CompilationMetrics that existed before + metrics_context, and aren't set by MetricsContext.set(). We add the subset + of them that make sense in `dynamo`/toplevel events in PT2 Compile Events + directly. + + If you're tempted to add to this list, consider using CompileEventLogger.compilation_metric() + instead, which will automatically also add it to tlparse and PT2 Compile Events. + TODO: Get rid of this function and replace it with CompileEventLogger directly instead. 
+ """ + event_logger = get_chromium_event_logger() + event_name = event_logger.get_outermost_event() + if not event_name: + return + event_logger.add_event_data( + event_name=event_name, + frame_key=c.frame_key, + co_name=c.co_name, + co_filename=c.co_filename, + co_firstlineno=c.co_firstlineno, + cache_size=c.cache_size, + accumulated_cache_size=c.accumulated_cache_size, + guard_count=c.guard_count, + shape_env_guard_count=c.shape_env_guard_count, + graph_op_count=c.graph_op_count, + graph_node_count=c.graph_node_count, + graph_input_count=c.graph_input_count, + fail_type=c.fail_type, + fail_reason=c.fail_reason, + fail_user_frame_filename=c.fail_user_frame_filename, + fail_user_frame_lineno=c.fail_user_frame_lineno, + # Sets aren't JSON serializable + non_compliant_ops=list(c.non_compliant_ops) + if c.non_compliant_ops is not None + else None, + compliant_custom_ops=list(c.compliant_custom_ops) + if c.compliant_custom_ops is not None + else None, + restart_reasons=list(c.restart_reasons) + if c.restart_reasons is not None + else None, + dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, + has_guarded_code=c.has_guarded_code, + dynamo_config=c.dynamo_config, + ) + + +def _get_dynamo_config_for_logging() -> Optional[str]: + def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: + blocklist = { + "TYPE_CHECKING", + "log_file_name", + "verbose", + "repro_after", + "repro_level", + "repro_forward_only", + "repro_tolerance", + "repro_ignore_non_fp", + "same_two_models_use_fp64", + "base_dir", + "debug_dir_root", + "_save_config_ignore", + "log_compilation_metrics", + "inject_BUILD_SET_unimplemented_TESTING_ONLY", + "_autograd_backward_strict_mode_banned_ops", + "reorderable_logging_functions", + "ignore_logger_methods", + "traceable_tensor_subclasses", + "nontraceable_tensor_subclasses", + "_custom_ops_profile", + } + + return { + key: sorted(value) if isinstance(value, set) else value + for key, value in d.items() + if key not in blocklist + } + + config_dict = clean_for_json(config.get_config_copy()) + return json.dumps(config_dict, sort_keys=True) + + +def _scrubbed_inductor_config_for_logging() -> Optional[str]: + """ + Method to parse and scrub uninteresting configs from inductor config + """ + + # TypeSafeSerializer for json.dumps() + # Skips complex types as values in config dict + class TypeSafeSerializer(json.JSONEncoder): + def default(self, o): + try: + return super().default(o) + except Exception: + return "Value is not JSON serializable" + + keys_to_scrub: set[Any] = set() + inductor_conf_str = None + inductor_config_copy = ( + torch._inductor.config.get_config_copy() if torch._inductor.config else None + ) + if inductor_config_copy is not None: + try: + for key, val in inductor_config_copy.items(): + if not isinstance(key, str): + keys_to_scrub.add(key) + # Convert set() to list for json.dumps() + if isinstance(val, set): + inductor_config_copy[key] = list(val) + # Evict unwanted keys + for key in keys_to_scrub: + del inductor_config_copy[key] + # Stringify Inductor config + inductor_conf_str = json.dumps( + inductor_config_copy, + cls=TypeSafeSerializer, + skipkeys=True, + sort_keys=True, + ) + except Exception: + # Don't crash because of runtime logging errors + inductor_conf_str = "Inductor Config is not JSON serializable" + return inductor_conf_str + + +def record_compilation_metrics( + start_time_ns: int, + end_time_ns: int, + metrics: dict[str, Any], + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], +): + if 
torch._inductor.utils.should_use_remote_fx_graph_cache(): + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + + remote_cache_version = REMOTE_CACHE_VERSION + inductor_fx_remote_cache_backend_type = "_ManifoldCache" + except ModuleNotFoundError: + remote_cache_version = None + inductor_fx_remote_cache_backend_type = None + else: + inductor_fx_remote_cache_backend_type = None + remote_cache_version = None + + # Populate the compile_id from the metrics context if it's set. Otherwise, + # look for it in the current compile context. + compile_id = metrics.get("compile_id") + if not compile_id: + compile_id = torch._guards.CompileContext.current_compile_id() + + common_metrics = { + "compile_id": compile_id, + "start_time_us": start_time_ns // 1000, + "end_time_us": end_time_ns // 1000, + "fail_type": exc_type.__qualname__ if exc_type else None, + "fail_reason": str(exc_value) if exc_value else None, + "structured_logging_overhead_us": to_int_us( + torch._logging.get_structured_logging_overhead() + ), + "dynamo_config": _get_dynamo_config_for_logging(), + "inductor_config": _scrubbed_inductor_config_for_logging(), + "cuda_version": torch.version.cuda, + "triton_version": triton.__version__ if has_triton() else "", + "remote_cache_version": remote_cache_version, + "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type, + "python_version": sys.version, + } + + compilation_metrics = CompilationMetrics.create({**common_metrics, **metrics}) + _compilation_metrics.append(compilation_metrics) + + name = "compilation_metrics" + if compilation_metrics.is_forward is False: + name = "bwd_" + name + if compilation_metrics.is_runtime is True: + name = name + "_runtime" + + torch._logging.trace_structured( + name, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, + # NB: Because compilation metrics *includes* the logging overhead time, + # we can't both *measure* the logging overhead of compilation metrics + # without making it inconsistent with compilation metrics itself, so + # we ignore the (hopefully small) time spent logging compilation metrics + record_logging_overhead=False, + # These may be runtime logs, e.g., runtime autotunning, so we provide + # the CompileId from the compilation metrics in case it's not available + # in the current trace. + compile_id=compile_id, + ) + + # If there's a chromium event in flight, add the CompilationMetrics to it. + add_compilation_metrics_to_chromium(compilation_metrics) + + # Finally log the compilation metrics. + if config.log_compilation_metrics: + log_compilation_event(compilation_metrics) + + +# record_compilation_metrics is called by the singleton MetricsContext exit handler. +_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics) +_RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext(on_exit=record_compilation_metrics) + + +def set_compilation_metrics_limit(new_size: int) -> None: + global _compilation_metrics + while len(_compilation_metrics) > new_size: + _compilation_metrics.popleft() + new_deque = collections.deque(_compilation_metrics, maxlen=new_size) + _compilation_metrics = new_deque + + +def clear_compilation_metrics() -> None: + global _compilation_metrics + _compilation_metrics.clear() + + +def get_compilation_metrics() -> list[CompilationMetrics]: + return list(_compilation_metrics) + + +class ChromiumEventLogger: + """Logs chromium events to structured logs. tlparse will concatenate these into a perfetto UI link. 
+ + See https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.yr4qxyxotyw for + a specification of the Chromium Event JSON format. + """ + + def get_stack(self) -> list[str]: + """ + The main event stack, with every chromium event. + Logged to tlparse. + """ + if hasattr(self.tls, "stack"): + return self.tls.stack + else: + self.tls.stack = [] + return self.tls.stack + + def get_outermost_event(self) -> Optional[str]: + """ + Get the outermost event name (i.e. the longest running event) + or None if the stack is empty. + """ + stack = self.get_stack() + return stack[0] if stack else None + + def get_pt2_compile_substack(self): + """ + A smaller subset of the main stack that gets used to log + PT2 Compile Events internally. + """ + if hasattr(self.tls, "pt2_compile_substack"): + return self.tls.pt2_compile_substack + else: + self.tls.pt2_compile_substack = [] + return self.tls.pt2_compile_substack + + def get_event_data(self) -> dict[str, Any]: + if not hasattr(self.tls, "event_data"): + self.tls.event_data = {} + return self.tls.event_data + + def __init__(self): + self.tls = threading.local() + # Generate a unique id for this logger, which we can use in scuba to filter down + # to a single python run. + self.id_ = str(uuid.uuid4()) + + # TODO: log to init/id tlparse after I add support for it + log.info("ChromiumEventLogger initialized with id %s", self.id_) + + def try_add_event_data(self, event_name: str, **kwargs) -> None: + """ + Same as add_event_data, but will silently not log if the event isn't in the stack. + """ + if event_name not in self.get_stack(): + return + self.add_event_data(event_name, **kwargs) + + def add_event_data( + self, + event_name: str, + **kwargs, + ) -> None: + """ + Adds additional metadata info to an in-progress event + This metadata is recorded in the END event + """ + if event_name not in self.get_stack(): + raise RuntimeError( + f"Event {repr(event_name)} not in {self.get_stack()}. " + "Cannot add metadata to events that aren't in progress. " + "Please make sure the event has started and hasn't ended." + ) + event_data = self.get_event_data() + if event_name not in event_data: + event_data[event_name] = {} + event_data[event_name].update(kwargs) + + def increment(self, event_name: str, key: str, value: int): + """ + Increment an integer event data field by the given amount + """ + if event_name not in self.get_stack(): + raise RuntimeError( + f"Event {repr(event_name)} not in {self.get_stack()}. " + "Cannot add metadata to events that aren't in progress. " + "Please make sure the event has started and hasn't ended." + ) + + event_data = self.get_event_data() + if event_name not in event_data: + event_data[event_name] = {} + if key not in event_data[event_name]: + event_data[event_name][key] = 0 + event_data[event_name][key] += value + + def add_to_set( + self, + event_name: str, + key: str, + value: Any, + ): + """ + Add a value to a set within a event_name's metadata if it exists + """ + if event_name not in self.get_stack(): + raise RuntimeError( + f"Event {repr(event_name)} not in {self.get_stack()}. " + "Cannot add metadata to events that aren't in progress. " + "Please make sure the event has started and hasn't ended." 
+ ) + event_data = self.get_event_data() + if event_name not in event_data: + event_data[event_name] = {} + if key not in event_data[event_name]: + event_data[event_name][key] = set() + event_data[event_name][key].add(value) + + def log_event_start( + self, + event_name: str, + time_ns: int, + metadata: dict[str, Any], + log_pt2_compile_event: bool = False, + compile_id: Optional[CompileId] = None, + ) -> None: + """ + Logs the start of a single event. + :param str event_name Name of event to appear in trace + :param time_ns Timestamp in nanoseconds + :param metadata: Any extra metadata associated with this event + :param log_pt2_compile_event: If True, log to pt2_compile_events + :param compile_id: Explicit compile_id (rather than using the current context) + """ + compile_id = compile_id or torch._guards.CompileContext.current_compile_id() + metadata["compile_id"] = str(compile_id) + self._log_timed_event( + event_name, + time_ns, + "B", + metadata, + ) + self.get_stack().append(event_name) + # Add metadata from start event + self.add_event_data(event_name, **metadata) + if log_pt2_compile_event: + self.get_pt2_compile_substack().append(event_name) + + def reset(self) -> None: + # We this on every compile in case a compile crashes or restarts and we haven't + # cleared the stack. + stack = self.get_stack() + substack = self.get_pt2_compile_substack() + stack.clear() + substack.clear() + event_data = self.get_event_data() + event_data.clear() + + def log_event_end( + self, + event_name: str, + time_ns: int, + metadata: dict[str, Any], + start_time_ns: int, + log_pt2_compile_event: bool, + compile_id: Optional[CompileId] = None, + ) -> None: + """ + Logs the end of a single event. This function should only be + called after log_event_start with the same event_name. + :param event_name: Name of event to appear in trace + :param time_ns: Timestamp in nanoseconds + :param metadata: Any extra metadata associated with this event + :param start_time_ns: The start time timestamp in nanoseconds + :param log_pt_compile_event: If True, log to pt2_compile_events + :param compile_id: Explicit compile_id (rather than using the current context) + """ + compile_id = compile_id or torch._guards.CompileContext.current_compile_id() + metadata["compile_id"] = str(compile_id) + + # Grab metadata collected during event span + all_event_data = self.get_event_data() + if event_name in all_event_data: + event_metadata = all_event_data[event_name] + del all_event_data[event_name] + else: + event_metadata = {} + # Add the passed in metadata + event_metadata.update(metadata) + + event = self._log_timed_event( + event_name, + time_ns, + "E", + event_metadata, + ) + + def pop_stack(stack): + while event_name != stack[-1]: + # If the event isn't the most recent one to end, pop + # off the stack until it is. + # Since event_name in self.stack, this pop is always safe + log.warning( + "ChromiumEventLogger: Detected overlapping events, fixing stack" + ) + stack.pop() + + event_stack = self.get_stack() + # These stack health checks currently never happen, + # but they're written this way to future proof any weird event + # overlaps in the future. 
+ if event_name not in event_stack: + # Something went wrong, we never called start on this event, + # or it was skipped due to overlapping events below + log.warning("ChromiumEventLogger: Start event not in stack, ignoring") + return + + pop_stack(event_stack) + + if log_pt2_compile_event: + pt2_compile_substack = self.get_pt2_compile_substack() + pop_stack(pt2_compile_substack) + log_chromium_event_internal( + event, pt2_compile_substack, self.id_, start_time_ns + ) + # Pop actual event off of stack + pt2_compile_substack.pop() + + # Finally pop the actual event off the stack + event_stack.pop() + + def _log_timed_event( + self, + event_name: str, + time_ns: int, + phase: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """ + Logs a timed event in chromium format. See log_event_start, log_event_end, etc. + """ + event = { + "name": event_name, + "ts": time_ns / 1000, # Chromium events are in micro seconds + "args": metadata, + "ph": phase, + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, + "pid": 0, # pid should be specified on all logs, we don't personally care about the actual process id + } + torch._logging.trace_structured( + "chromium_event", + payload_fn=lambda: event, + suppress_context=False, + expect_trace_id=False, # Not every chromium event will have a trace_id + ) + record_chromium_event_internal(event) + return event + + def log_instant_event( + self, + event_name: str, + time_ns: int, + metadata: Optional[dict[str, Any]] = None, + # By default, an instant event isn't logged internally, only to structured logging. + log_pt2_compile_event: bool = False, + ) -> None: + """ + Log an instant event with no associated duration. + :param str event_name: Name of event to appear in trace + :param int time_ns Timestamp in nanoseconds + :param Optional[Dict[str, Any]] metadata: Any extra metadata associated with this event + :param str cname optional color for the arrow in the trace + """ + if metadata is None: + metadata = {} + compile_id = str(torch._guards.CompileContext.current_compile_id()) + metadata["compile_id"] = compile_id + event = { + "name": event_name, + "ts": time_ns / 1000, + "args": metadata, + "ph": "i", + # These categories are needed in all chromium traces + "cat": "dynamo_timed", + "tid": 0, + "pid": 0, + "s": "p", # We use "process" level instant events so they all appear on the same row in the trace. + } + torch._logging.trace_structured( + "chromium_event", + payload_fn=lambda: event, + suppress_context=False, + expect_trace_id=True, + ) + if log_pt2_compile_event: + # Log an instant event with the same start and end time + log_chromium_event_internal( + event, self.get_pt2_compile_substack(), self.id_, time_ns + ) + + +CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None + + +def get_chromium_event_logger() -> ChromiumEventLogger: + global CHROMIUM_EVENT_LOG + if CHROMIUM_EVENT_LOG is None: + CHROMIUM_EVENT_LOG = ChromiumEventLogger() + return CHROMIUM_EVENT_LOG + + +def chromium_event_log_active() -> bool: + global CHROMIUM_EVENT_LOG + return CHROMIUM_EVENT_LOG is not None + + +@contextmanager +def chromium_event_timed( + event_name: str, + reset_event_log_on_exit: bool = False, + log_pt2_compile_event: bool = False, +) -> Generator[Any, None, None]: + """ + Context manager that creates a chromium start and end event. Chromium event + logging is integrated with dynamo_timed, so you probably want to use that + instead. Use this context manager only if you want to avoid dynamo_timed. 
+ """ + chromium_event_log = get_chromium_event_logger() + chromium_start_time = time.time_ns() + chromium_event_log.log_event_start( + event_name, + chromium_start_time, + {}, + log_pt2_compile_event, + ) + try: + yield + finally: + chromium_event_log.log_event_end( + event_name, + time.time_ns(), + {}, + chromium_start_time, + log_pt2_compile_event, + ) + if reset_event_log_on_exit: + chromium_event_log.reset() + + +@dataclasses.dataclass +class CleanupHook: + """Remove a global variable when hook is called""" + + scope: dict[str, Any] + name: str + + def __call__(self, *args): + # Make sure we're not shutting down + if CleanupManager is not None: + CleanupManager.count -= 1 + del self.scope[self.name] + + @staticmethod + def create(scope, name, val): + assert name not in scope + CleanupManager.count += 1 + scope[name] = val + return CleanupHook(scope, name) + + +class CleanupManager(ExactWeakKeyDictionary): + count = 0 + instance: ClassVar[CleanupManager] + + def _remove_id(self, idx): + for hook in self.values[idx]: + hook() + super()._remove_id(idx) + + +CleanupManager.instance = CleanupManager() + + +def clone_tensor(x): + """Clone the tensor and its gradient""" + y = x.clone().requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = x.grad.clone() + return y + + +def clone_input(x, *, dtype=None): + """copy while preserving strides""" + # TODO: this is questionable + if is_fake(x): + # this func fails on fake tensors in __torch_dispatch__ + return x + + def torch_clone(x): + y = torch.clone(x) + if x.is_leaf: + y.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = clone_input(x.grad, dtype=dtype) + if hasattr(x, "_dynamo_dynamic_indices"): + y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return y + + with torch.no_grad(): + if x.device.type == "xla": + # Access data_ptr() for a xla tensor will cause crash + return torch_clone(x) + + # Handle sparse storage (no stride). + if x.layout is torch.sparse_coo: + return torch.sparse_coo_tensor( + torch_clone(x._indices()), + torch_clone(x._values()), + x.shape, + is_coalesced=x.is_coalesced(), + ) + elif is_sparse_compressed(x): + if x.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices = x.crow_indices() + plain_indices = x.col_indices() + else: + compressed_indices = x.ccol_indices() + plain_indices = x.row_indices() + return torch.sparse_compressed_tensor( + torch_clone(compressed_indices), + torch_clone(plain_indices), + torch_clone(x.values()), + x.shape, + layout=x.layout, + ) + + needed_size = sum( + (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) + ) + if x.is_quantized: + result = torch.empty_quantized((needed_size + 32,), x) + else: + result = torch.empty( + needed_size + 32, dtype=dtype or x.dtype, device=x.device + ) + cache_line_offset = ( + (x.data_ptr() - result.data_ptr()) % 32 + ) // x.element_size() + result.as_strided_(x.size(), x.stride(), cache_line_offset) + try: + result.copy_(x.clone()) + if x.is_leaf: + result.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + result.grad = clone_input(x.grad, dtype=dtype) + except RuntimeError: + # RuntimeError: unsupported operation: more than one element of the written-to + # tensor refers to a single memory location. Please clone() the tensor before + # performing the operation. 
+ return torch_clone(x) + if hasattr(x, "_dynamo_dynamic_indices"): + result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() # type: ignore[attr-defined] + return result + + +def clone_inputs(example_inputs): + res: Union[dict[Any, Any], list[Any]] + if type(example_inputs) is dict: + res = dict(example_inputs) + for key, value in res.items(): + if isinstance(value, tuple): + res[key] = clone_inputs(value) + else: + assert isinstance(value, torch.Tensor), type(value) + res[key] = clone_input(value) + return res + + res = list(example_inputs) + for i in range(len(res)): + if isinstance(res[i], torch.Tensor): + res[i] = clone_input(res[i]) + return res + + +def skip_frame_if_in_functorch_mode(val: torch.Tensor): + try: + val.data_ptr() # will throw for functorch tensors + except RuntimeError as e: + from .exc import SkipFrame + + # This will be GradTrackingTensor/BatchedTensor/etc + functorch_subclass_name = re.sub(r"\(.*", "", repr(val)) + raise SkipFrame( + f"torch.compile cannot be run in context: {functorch_subclass_name}" + ) from e + + +@contextmanager +def preserve_rng_state(): + disable_functorch = torch._C._DisableFuncTorch + disable_current_modes = torch.utils._python_dispatch._disable_current_modes + with disable_current_modes(), disable_functorch(): + rng_state = torch.clone(torch.random.get_rng_state()) + skip_frame_if_in_functorch_mode(rng_state) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + try: + yield + finally: + with torch.utils._python_dispatch._disable_current_modes(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + + +def is_jit_model( + model0, +): + return isinstance( + model0, + ( + torch.jit._trace.TopLevelTracedModule, + torch.jit._script.RecursiveScriptModule, + torch.jit.ScriptFunction, + torch.jit.ScriptModule, + ), + ) + + +def torchscript(model, example_inputs, verbose=False): + if is_jit_model(model): + # already done? 
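+ # (model is already a traced/scripted TorchScript module, so return it unchanged)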
+ return model + + try: + return torch.jit.trace(model, example_inputs) + except Exception: + try: + return torch.jit.script(model) + except Exception: + if verbose: + log.exception("jit error") + else: + log.error("Both torch.jit.trace and torch.jit.script failed") + return None + + +def getfile(obj): + try: + return inspect.getfile(obj) + except (TypeError, OSError): + return None + + +def is_namedtuple(obj): + """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" + return is_namedtuple_cls(type(obj)) + + +def is_namedtuple_cls(cls): + """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" + try: + if issubclass(cls, tuple): + module = getattr(cls, "__module__", None) + if module in ("torch.return_types", "torch.autograd.forward_ad"): + return True + if isinstance(getattr(cls, "_fields", None), tuple) and callable( + getattr(cls, "_make", None) + ): + # The subclassing style namedtuple can have an extra base `typing.Generic` + bases = tuple(t for t in cls.__bases__ if t is not Generic) + if bases == (tuple,): + # This is a namedtuple type directly created by `collections.namedtuple(...)` + return True + if bases and any( + ( + # Subclass of namedtuple + is_namedtuple_cls(t) + # For subclasses of namedtuple, the __new__ method should not be customized + and cls.__new__ is t.__new__ + ) + for t in bases + ): + return True + except TypeError: + pass + return False + + +@functools.lru_cache(1) +def namedtuple_fields(cls) -> tuple[str, ...]: + """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" + if cls is slice: + return ("start", "stop", "step") + + assert issubclass(cls, tuple) + if hasattr(cls, "_fields"): + # normal namedtuples + return cls._fields + + @dataclasses.dataclass + class Marker: + index: int + + # frustrating ones e.g. 
torch.return_types.max + assert cls.__module__ == "torch.return_types" + obj = cls(map(Marker, range(cls.n_fields))) + fields: dict[str, int] = {} + for name in dir(obj): + if name[0] != "_" and isinstance(getattr(obj, name), Marker): + fields[name] = getattr(obj, name).index + assert len(fields) == cls.n_fields + return tuple(sorted(fields, key=fields.get)) # type: ignore[arg-type] + + +def checkpoint_params(gm): + with torch.no_grad(): + rng_state = torch.clone(torch.random.get_rng_state()) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + saved_state = [ + (param, param._version, torch.clone(param)) + for param in itertools.chain(gm.parameters(), gm.buffers()) + ] + + def restore(): + with torch.no_grad(): + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + for param, version, original_value in saved_state: + if param._version != version: + param.copy_(original_value) + + return restore + + +def timed(model, example_inputs, times=1): + if torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + else: + synchronize = nothing + + synchronize() + gc.collect() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize() + t1 = time.perf_counter() + return result, t1 - t0 # type: ignore[possibly-undefined] + + +def check_is_cuda(gm, example_inputs): + return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) + + +@lru_cache(32) +def rot_n_helper(n): + assert n > 1 + vars = [f"v{i}" for i in range(n)] + rotated = reversed(vars[-1:] + vars[:-1]) + fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") + fn.__name__ = f"rot_{n}_helper" + return fn + + +common_constant_types: set[type] = { + int, + float, + complex, + bool, + str, + bytes, + type(None), + Ellipsis.__class__, + NotImplemented.__class__, + types.CodeType, + # Commonly used immutable types from torch. + torch.device, + torch.dtype, + torch.memory_format, + torch.layout, + torch.finfo, + torch.iinfo, + torch.nn.attention.SDPBackend, + torch.cuda._CudaDeviceProperties, +} + +if has_triton_package(): + import triton + + common_constant_types.add(triton.language.dtype) + +""" + Difference between is_safe_constant and common_constant_types. + * common_constant_types: Constants would be wrapped by VariableBuilder.wrap_literal + as ConstantVariable. + * is_safe_constant: Constants can be loaded by LOAD_CONST bytecode. 
+""" + + +def is_safe_constant(v): + if istype(v, (tuple, frozenset)): + return all(map(is_safe_constant, v)) + return isinstance( + v, + ( + enum.Enum, + type, + torch.Size, + typing._GenericAlias, # type: ignore[attr-defined] + types.GenericAlias, + ), + ) or istype( + v, + common_constant_types | {slice}, + ) + + +@functools.cache +def common_constants(): + return { + # We zero-one specialize shapes, so specialize these constants + # too + 0, + 1, + } + + +def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]: + return isinstance(value, (torch.SymBool, torch.SymInt)) and not isinstance( + value.node, torch.nested._internal.nested_int.NestedIntNode + ) + + +def is_int_specialization_case(value, source): + from .source import is_from_defaults + + return not TracingContext.get().force_unspec_int_unbacked_size_like and ( + # Assume integers from global variables want to be specialized + not source.guard_source().is_local() + # Assume that integers that came from NN modules want to be + # specialized (as we don't expect users to be changing the + # NN modules on the fly), unless explicitly disabled + or ( + source.guard_source().is_specialized_nn_module() + and not config.allow_unspec_int_on_nn_module + ) + or ( + source.guard_source().is_unspecialized_builtin_nn_module() + and not config.allow_unspec_int_on_nn_module + ) + or ( + source.guard_source().is_unspecialized_nn_module() + and not config.allow_unspec_int_on_nn_module + ) + or is_from_defaults(source) + # TODO: Delete this condition when rollout is done. NB: this + # condition never evaluates True in open source + or ( + not justknobs_check("pytorch/dynamo:enable_unspecialize_zero_one_plain_int") + and value in common_constants() + ) + ) + + +def specialize_symnode(arg): + from .variables import ConstantVariable, LazyVariableTracker, SymNodeVariable + + # Guard and specialize + if isinstance(arg, LazyVariableTracker) and not arg.is_realized(): + # Find if the arg would be realized as SymNodeVariable later on. If yes, + # realize it and specialize. Else return the arg. 
+ + source = arg.original_source() + value = arg.original_value() + + is_symnode_vt = is_torch_sym(value) or ( + not config.specialize_int + and type(value) is int + and not is_int_specialization_case(value, source) + ) + + if not is_symnode_vt: + return arg + + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + return arg + + +def guard_if_dyn(arg): + from .variables import ConstantVariable + + arg = specialize_symnode(arg) + + if isinstance(arg, ConstantVariable): + return arg.as_python_constant() + + return arg + + +def check_constant_args(args, kwargs): + return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) + + +def check_unspec_python_args(args, kwargs): + from .variables.constant import ConstantVariable + from .variables.tensor import UnspecializedPythonVariable + + unspec_count = 0 + for x in itertools.chain(args, kwargs.values()): + if isinstance(x, UnspecializedPythonVariable): + unspec_count += 1 + elif not isinstance(x, ConstantVariable): + return False + return unspec_count > 0 + + +def check_unspec_or_constant_args(args, kwargs): + # A fused version of: + # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) + from .variables.tensor import UnspecializedPythonVariable + + for x in itertools.chain(args, kwargs.values()): + if not (x.is_python_constant() or isinstance(x, UnspecializedPythonVariable)): + return False + return True + + +def check_numpy_ndarray_args(args, kwargs): + from .variables.tensor import NumpyNdarrayVariable + + return any( + isinstance(x, NumpyNdarrayVariable) + for x in itertools.chain(args, kwargs.values()) + ) + + +dict_keys: type[KeysView[Any]] = type({}.keys()) +dict_values: type[ValuesView[Any]] = type({}.values()) +dict_items: type[ItemsView[Any, Any]] = type({}.items()) +odict_values: type[ValuesView[Any]] = type(OrderedDict().values()) +tuple_iterator: type[Iterator[Any]] = type(iter(())) +range_iterator: type[Iterator[Any]] = type(iter(range(0))) +tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] +object_new = object.__new__ +dict_new = dict.__new__ +dict_methods = { + method + for method in itertools.chain(dict.__dict__.values(), OrderedDict.__dict__.values()) + if callable(method) +} + +tuple_new = tuple.__new__ +tuple_methods = {method for method in tuple.__dict__.values() if callable(method)} +list_methods = {method for method in list.__dict__.values() if callable(method)} +list_getitem = list.__getitem__ + +str_methods = {method for method in str.__dict__.values() if callable(method)} + + +def builtin_dict_keys(d): + # Avoids overridden keys method of the dictionary + assert isinstance(d, dict) + return dict.keys(d) + + +def get_items_from_dict(obj): + # Get items without calling the user defined __getitem__ or keys method. 
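+ # Exact dict/OrderedDict instances can safely use the builtin items(); subclasses
+ # fall back to the base-class accessors so user-defined overrides are bypassed.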
+ assert isinstance(obj, dict) + if istype(obj, (dict, OrderedDict)): + return obj.items() + elif isinstance(obj, OrderedDict): + return [(k, OrderedDict.__getitem__(obj, k)) for k in OrderedDict.keys(obj)] + else: + return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] + + +def nn_module_new(cls): + obj = object_new(cls) + torch.nn.Module.__init__(obj) + return obj + + +def product(it): + return functools.reduce(operator.mul, it, 1) + + +def tuple_iterator_getitem(it, index): + _, (obj,), start = it.__reduce__() + return obj[start + index] + + +def dataclass_fields(cls): + return torch._dynamo.disable(dataclasses.fields)(cls) + + +iter_next = next + + +def normalize_range_iter(range_iter) -> tuple[int, int, int]: + _, (range_obj,), maybe_idx = range_iter.__reduce__() + # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been + # already incremented by the current index. + start = range_obj.start + (maybe_idx or 0) + stop = range_obj.stop + step = range_obj.step + return (start, stop, step) + + +def to_subclass(t, cls): + return t.as_subclass(cls) + + +dict_getitem = dict.__getitem__ + + +def dict_keys_getitem(d, n): + # Call dict(d) to prevent calling overridden __iter__/keys + dict_class = dict + if isinstance(d, OrderedDict): + dict_class = OrderedDict + return next(itertools.islice(dict_class.keys(d), n, n + 1)) + + +def enum_repr(value, local): + # enum class can override __str__ method. Use __class__ and name attribute + # to extract the class name and key name. + name = value.__class__.__name__ + val = value.name + scope = "L" if local else "G" + local_name = f'{scope}["{name}"].{val}' + return local_name + + +def set_example_value(node, example_value): + # NB: example_value is a bit of a misnomer, because this is always a fake + # tensor of some sort. Furthermore, these example values serve as the + # runtime state of Dynamo tracing, which means if metadata mutation + # occurs, the example_value gets directly updated (so you can't rely on + # this to accurately reflect what the state of the value was at the time + # the program was traced). + node.meta["example_value"] = example_value + shape_env = TracingContext.get().fake_mode.shape_env + if ( + symbol_to_path + := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings( + shape_env, example_value + ) + ): + node.meta["unbacked_bindings"] = symbol_to_path + + +def _get_fake_tensor(vt): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if not is_fake(fake_tensor): + from . 
import graph_break_hints + from .exc import unimplemented_v2 + + unimplemented_v2( + gb_type="Cannot check Tensor object identity without its fake value", + context=str(fake_tensor), + explanation="TensorVariable is missing a fake example_value.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + return fake_tensor + + +def iter_contains(items, search, tx, check_tensor_identity=False): + from .variables import ( + BuiltinVariable, + ConstantVariable, + TensorVariable, + VariableTracker, + ) + + if search.is_python_constant(): + found_const = any( + x.is_python_constant() + and x.as_python_constant() == search.as_python_constant() + for x in items + ) + return ConstantVariable.create(found_const) + + must_check_tensor_id = False + if check_tensor_identity and isinstance(search, TensorVariable): + must_check_tensor_id = True + # Match of Tensor means match of FakeTensor + search = _get_fake_tensor(search) + + found: Optional[VariableTracker] = None + for x in items: + if must_check_tensor_id: + if isinstance(x, TensorVariable): + if search is _get_fake_tensor(x): # Object equivalence + return ConstantVariable.create(True) + else: + check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) + if found is None: + found = check + else: + found = BuiltinVariable(operator.or_).call_function( + tx, [check, found], {} + ) + if found is None: + found = ConstantVariable.create(False) + return found + + +def key_is_id( + k: Any, +) -> TypeIs[Union[torch.Tensor, torch.nn.Module, MethodWrapperType]]: + """Returns whether it indexes dictionaries using its id""" + return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) + + +def key_to_id(value): + return [id(k) if key_is_id(k) else k for k in value.keys()] + + +def const_repr(x, *, local) -> str: + from .trace_rules import is_builtin_callable + + if isinstance(x, (list, tuple)): + elems_repr = ",".join(const_repr(s, local=local) for s in x) + if isinstance(x, list): + return f"[{elems_repr}]" + else: + assert isinstance(x, tuple) + if len(x) == 1: + return f"({elems_repr},)" + else: + return f"({elems_repr})" + elif isinstance(x, enum.Enum): + # To workaround repr(Enum) returning invalid global reference before python 3.11 + # by calling enum_repr and removing quotes to render enum in guard code. + return enum_repr(x, local=local).replace("'", "") + elif is_builtin_callable(x): + return x.__name__ + elif isinstance(x, type): + + def fullname(o): + klass = o.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ + + return fullname(x) + else: + return f"{x!r}" + + +def dict_keys_repr(const_keys, *, local) -> str: + keys_str = ",".join(const_repr(s, local=local) for s in const_keys) + return "[" + keys_str + "]" + + +GLOBAL_KEY_PREFIX = "__dict_key" + + +from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 + + +def get_safe_global_name(tx, root, obj): + # The global_mangled_class_name should be different for different + # invocations of torch.compile. Otherwise, we can run into a situation + # where multiple torch.compile invocations reuse the same global name, + # but the global's lifetime is tied to the first invocation (and + # may be deleted when the first torch.compile invocation is deleted) + # We mangle it based off of the output_graph's id. 
+ return f"{root}_{id(obj)}_c{tx.output.compile_id}" + + +def is_in(item: Any, *containers) -> bool: + for container in containers: + if item in container: + return True + return False + + +def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: + """ + Return a name that starts with `prefix` and is not in any of the + `containers` (e.g., map, set). + """ + if not requires_suffix and not is_in(prefix, *containers): + return prefix + + for i in itertools.count(): + candidate = f"{prefix}_{i}" + if not is_in(candidate, *containers): + return candidate + + raise AssertionError("unreachable") + + +def wrap_fake_exception(fn): + try: + return fn() + except UnsupportedFakeTensorException as e: + from .exc import unimplemented_v2 + + msg = f"Encountered exception ({e.reason}) during fake tensor propagation." + log.warning(msg) + unimplemented_v2( + gb_type="Fake tensor propagation exception", + context=str(e.reason), + explanation=msg, + hints=[], + from_exc=e, + ) + + +def deepcopy_to_fake_tensor(obj, fake_mode): + with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): + return wrap_fake_exception(lambda: copy.deepcopy(obj)) + + +def rmse(ref, res): + """ + Calculate root mean squared error + """ + return torch.sqrt(torch.mean(torch.square(ref - res))) + + +def same( + ref, + res, + fp64_ref=None, + cos_similarity=False, + tol=1e-4, + equal_nan=False, + exact_dtype=True, + relax_numpy_equality=False, + ignore_non_fp=False, + log_error=log.error, + use_larger_multiplier_for_smaller_tensor=False, + force_max_multiplier: bool = False, +): + """Check correctness to see if ref and res match""" + if fp64_ref is None: + fp64_ref = ref + if isinstance( + ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) + ): + assert isinstance(res, (list, tuple, collections.deque)), ( + f"type mismatch {type(ref)} {type(res)}" + ) + if len(ref) != len(res): + log_error("Length mismatch") + return False + return len(ref) == len(res) and all( + same( + ai, + bi, + fp64_refi, + cos_similarity, + tol, + equal_nan, + exact_dtype, + relax_numpy_equality, + ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + force_max_multiplier=force_max_multiplier, + ) + for ai, bi, fp64_refi in zip(ref, res, fp64_ref) + ) + elif type(ref).__name__ == "QuestionAnsweringModelOutput": + # This skips checking accuracy for start_logits/end_logits. + # Tentatively, start_logits/end_logits appear to be very prone to + # inaccuracies and is somewhat subsumed by checking the loss. 
+ return same( + ref.loss, + res.loss, + fp64_ref.loss, + cos_similarity, + tol, + equal_nan, + exact_dtype, + relax_numpy_equality, + ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + force_max_multiplier=force_max_multiplier, + ) + elif isinstance(ref, dict): + assert isinstance(res, dict) + assert set(ref.keys()) == set(res.keys()), ( + f"keys mismatch {set(ref.keys())} == {set(res.keys())}" + ) + for k in sorted(ref.keys()): + if not ( + same( + ref[k], + res[k], + fp64_ref[k], + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + force_max_multiplier=force_max_multiplier, + ) + ): + log_error("Accuracy failed for key name %s", k) + return False + return True + elif isinstance(ref, set): + assert isinstance(res, set) + assert set(ref) == set(res), f"elements mismatch {set(ref)} == {set(res)}" + return True + elif isinstance(ref, (torch.Tensor, float)): + assert not isinstance(ref, torch._subclasses.FakeTensor) + assert not isinstance(res, torch._subclasses.FakeTensor) + + def to_tensor(t): + return t if isinstance(t, torch.Tensor) else torch.tensor(t) + + ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) + + if ref.is_sparse: + assert res.is_sparse + ref = ref.to_dense() + res = res.to_dense() + assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" + if exact_dtype: + if ref.dtype != res.dtype: + log_error("dtype mismatch %s, %s", ref.dtype, res.dtype) + return False + if ref.dtype == torch.bool: + if ignore_non_fp: + return True + # triton stores bool as int8, so add this for more accurate checking + r = torch.allclose( + ref.to(dtype=torch.uint8), + res.to(dtype=torch.uint8), + atol=tol, + rtol=tol, + equal_nan=equal_nan, + ) + if not r: + log_error("Accuracy failed: uint8 tensor did not match") + return r + + if cos_similarity: + ref = ref.flatten().to(torch.float32) + res = res.flatten().to(torch.float32) + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True): + # early exit that handles zero/nan better + # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 + return True + score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) + if score < 0.99: + log.warning("Similarity score=%s", score.detach().cpu().item()) + return score >= 0.99 + else: + if not exact_dtype: + ref = ref.to(res.dtype) + + # First try usual allclose + if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan): + return True + + # Check error from fp64 version + if fp64_ref.dtype == torch.float64: + # Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance) + # while the ref contains NaN. In this case, RMSE should not match any ways. + # But res is 'BETTER' than ref so we count it pass. 
+ # + # This happens for Super_SloMo when loop ordering after fusion is enabled: + # https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab + loose_tol = 1e-2 * 4 + if ( + not fp64_ref.isnan().any() + and not res.isnan().any() + and ref.isnan().any() + and torch.allclose( + fp64_ref.to(dtype=res.dtype), + res, + atol=loose_tol, + rtol=loose_tol, + equal_nan=equal_nan, + ) + ): + return True + ref_error = rmse(fp64_ref, ref).item() + # ref unable to produce this with stable numerics in this precision, ignore + if math.isnan(ref_error): + log.warning( + "Found nan in reference. Consider running in higher precision." + ) + + res_error = rmse(fp64_ref, res).item() + + def get_multiplier(): + # In some particular cases, we expect high difference in results. + # At the moment one of this cases is inductor freezing bfloat16 convolution const folding. + # In case of it the res_error is at least one order of magnitude higher. + if force_max_multiplier: + return 10.0 + # In the case of using AMP (Automatic Mixed Precision), certain models have + # failed the benchmark's correctness check. However, the end-to-end model's + # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. + # Thus, it's possible that the correctness check failures for these models are + # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. + multiplier = ( + 3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0 + ) + + if use_larger_multiplier_for_smaller_tensor and ( + fp64_ref.numel() <= 10 + ): + multiplier = 10.0 + elif use_larger_multiplier_for_smaller_tensor and ( + fp64_ref.numel() <= 500 + ): + multiplier = 8.0 + elif ( + fp64_ref.numel() < 1000 + or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) + # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE + or tol >= 2 * 1e-2 + ): + # In the presence of noise, noise might dominate our error + # metric for smaller tensors. + # Similarly, for 1x1 kernels, there seems to be high noise with amp. + multiplier = 3.0 + return multiplier + + multiplier = get_multiplier() + + passes_test = res_error <= (multiplier * ref_error + tol / 10.0) + if ( + not passes_test + and equal_nan + and math.isnan(ref_error) + and math.isnan(res_error) + # Some unit test for the accuracy minifier relies on + # returning false in this case. + and not torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY + ): + passes_test = True + if not passes_test: + log_error( + "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s. 
res.dtype: %s, multiplier: %f, tol: %f" + ", use_larger_multiplier_for_smaller_tensor: %d", + res_error, + ref_error, + res.size(), + res.dtype, + multiplier, + tol, + use_larger_multiplier_for_smaller_tensor, + ) + return passes_test + + if ignore_non_fp: + return True + + log_error("Accuracy failed: allclose not within tol=%s", tol) + return False + elif isinstance(ref, (str, int, type(None), bool, torch.device)): + if ignore_non_fp: + return True + r = ref == res + if not r: + log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res) + return r + elif is_numpy_int_type(ref) or is_numpy_float_type(ref): + if relax_numpy_equality and not ( + is_numpy_int_type(res) or is_numpy_float_type(res) + ): + ref = ref.item() + r = (type(ref) is type(res)) and (ref == res) + if not r: + log_error("Accuracy failed (numpy): %s != %s", ref, res) + return r + elif is_numpy_ndarray(ref): + return (type(ref) is type(res)) and same( + torch.as_tensor(ref), + torch.as_tensor(res), + fp64_ref, + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + elif type(ref).__name__ in ( + "MaskedLMOutput", + "Seq2SeqLMOutput", + "CausalLMOutputWithCrossAttentions", + "LongformerMaskedLMOutput", + "Instances", + "SquashedNormal", + "Boxes", + "Normal", + "TanhTransform", + "Foo", + "Variable", + ): + assert type(ref) is type(res) + return all( + same( + getattr(ref, key), + getattr(res, key), + getattr(fp64_ref, key), + cos_similarity=cos_similarity, + tol=tol, + equal_nan=equal_nan, + exact_dtype=exact_dtype, + relax_numpy_equality=relax_numpy_equality, + ignore_non_fp=ignore_non_fp, + log_error=log_error, + use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, + ) + for key in ref.__dict__.keys() + ) + else: + raise RuntimeError(f"unsupported type: {type(ref).__name__}") + + +def format_func_info(code): + short_filename = code.co_filename.split("/")[-1] + return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" + + +@contextlib.contextmanager +def disable_cache_limit(): + prior = config.recompile_limit + config.recompile_limit = sys.maxsize + prior_acc_limit = config.accumulated_recompile_limit + config.accumulated_recompile_limit = sys.maxsize + + try: + yield + finally: + config.recompile_limit = prior + config.accumulated_recompile_limit = prior_acc_limit + + +# map from transformed code back to original user code +orig_code_map = ExactWeakKeyDictionary() + +# keep a record of code_obj -> list of guard failure reasons for logging +guard_failures: collections.defaultdict[Any, list[Any]] = collections.defaultdict(list) + +# Keep a record of graph break reasons for logging +graph_break_reasons: list[torch._dynamo.output_graph.GraphCompileReason] = [] + +# keep record of compiled code, if we are in "error if recompile" +# to track code that dynamo has compiled previously +seen_code_map = ExactWeakKeyDictionary() + + +# return same dir unless user changes config between calls +@functools.cache +def _get_debug_dir(root_dir): + dir_name = ( + "run_" + + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + # use pid to avoid conflicts among ranks + + "-pid_" + + str(os.getpid()) + ) + return os.path.join(root_dir, dir_name) + + +def get_debug_dir(): + debug_root = config.debug_dir_root + return _get_debug_dir(debug_root) + + +def 
extract_fake_example_value(node, required=True): + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + elif required: + from torch._dynamo.exc import unimplemented_v2 + + from . import graph_break_hints + + unimplemented_v2( + gb_type="Missing FakeTensor example value", + context=str(node), + explanation=f"`FakeTensor` example value was required for {node} but not available.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + else: + return None + + +def ensure_graph_fake(e, tx): + assert maybe_get_fake_mode(e) is tx.fake_mode + return e + + +def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): + def visit(n: torch.fx.Node): + if n.op == "call_function" and "example_value" not in n.meta: + # fake tensor validity is checked inside get_fake_value using + # ensure_graph_fake + return get_fake_value(n, tx, allow_non_graph_fake) + + elif n.op == "get_attr" and "example_value" not in n.meta: + assert n.target in tx.output.nn_modules + gm = tx.output.nn_modules[n.target] + assert isinstance(gm, torch.fx.GraphModule) + return gm + + out = n.meta["example_value"] + if not allow_non_graph_fake and isinstance(out, torch.Tensor): + return ensure_graph_fake(out, tx) + return out + + return torch.fx.node.map_arg(nodes, visit) + + +def get_fake_value(node, tx, allow_non_graph_fake=False): + """ + Run the computation represented by `node` using fake tensors and return the result. + + allow_non_graph_fake: whether to allow the return result to be: + 1. non-fake or 2. fake that is not created by this instance of Dynamo. + If `True`, you must be prepared to deal with such return values, ideally + by further wrapping them as this graph's fakes. + """ + from torch.utils._sympy.value_ranges import ValueRangeError + + from .exc import ( + TorchRuntimeError, + unimplemented_v2, + Unsupported, + UserError, + UserErrorType, + ) + + op = node.op + + # FX Node should always return the same fake value + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + + args, kwargs = get_fake_values_from_nodes( + tx, (node.args, node.kwargs), allow_non_graph_fake + ) + + if ( + torch._dynamo.config.use_graph_deduplication + or torch._dynamo.config.track_nodes_for_deduplication + ): + flat_args_kwargs = get_fake_values_from_nodes( + tx, _get_flat_args(node, {}), allow_non_graph_fake + ) + id_to_initial_version = { + id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg) + } + else: + flat_args_kwargs = [] + id_to_initial_version = {} + + nnmodule = None + if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): + # If the first argument is nn.Module, should copy to fake mode. + args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) + + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it. + # Afterwards, lazy module deletes its pre-hooks + # to avoid treating it as lazy on subsequent recompile. + nnmodule._infer_parameters(nnmodule, args) + + # no matter it's lazy module or not, we should copy to fake mode. + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any( + isinstance(a, complex) for a in args + ): + # We need to specialize symfloats for now. 
Eventually we should do a tensorify pass in dynamo. + args = tuple( + float(arg) + if isinstance(arg, torch.SymFloat) and arg.node.hint is not None + else arg + for arg in args + ) + + try: + with tx.fake_mode, enable_python_dispatcher(): + ret_val = wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + cause: BaseException = e + if e.__cause__ is not None: + cause = e.__cause__ + + if isinstance( + cause, torch._subclasses.fake_tensor.DataDependentOutputException + ): + # capture_scalar_outputs only works for these ops right now + # see torch/_subclasses/fake_impls.py + if cause.func in ( + torch.ops.aten.item.default, + torch.ops.aten._local_scalar_dense.default, + ): + # does this actually get triggered? + hints = [ + "Enable tracing of data-dependent output operators with " + "`torch._dynamo.config.capture_scalar_outputs = True`", + ] + else: + hints = [ + "Consider wrapping the operator into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)", + ] + unimplemented_v2( + gb_type="Data dependent operator", + context=str(cause.func), + explanation=f"Operator `{cause.func}` has a non-Tensor output " + "whose value is dependent on the data of Tensor inputs.", + hints=hints, + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.DynamicOutputShapeException + ): + if not torch._dynamo.config.capture_dynamic_output_shape_ops: + unimplemented_v2( + gb_type="Dynamic shape operator", + context=str(cause.func), + explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.", + hints=[ + "Enable tracing of dynamic shape operators with " + "`torch._dynamo.config.capture_dynamic_output_shape_ops = True`", + ], + ) + else: + unimplemented_v2( + gb_type="Dynamic shape operator (no meta kernel)", + context=str(cause.func), + explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes", + hints=[ + "Please report an issue to PyTorch", + ], + ) + elif isinstance( + cause, torch._subclasses.fake_tensor.UnsupportedOperatorException + ): + op = cause.func + import_suggestion = "" + if isinstance(op, torch._ops.OpOverload): + maybe_pystub = torch._C._dispatch_pystub( + op._schema.name, op._schema.overload_name + ) + if maybe_pystub is not None: + module, ctx = maybe_pystub + import_suggestion = ( + f"It's possible that the support was implemented in " + f"module `{module}` and you may need to `import {module}`" + f"({ctx}), otherwise " + ) + unimplemented_v2( + gb_type="Operator does not support running with fake tensors", + context=f"unsupported operator: {cause.func}", + explanation="", + hints=[ + f"{import_suggestion}see " + "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0" + " for how to fix", + ], + ) + elif isinstance( + cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode + ): + raise UserError( # noqa: B904 + UserErrorType.CONSTRAINT_VIOLATION, + str(cause), + case_name="constrain_as_size_example", + ) + elif isinstance(cause, ValueRangeError): + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e + elif isinstance(cause, TypeError) and "argument" in str(cause): + unimplemented_v2( + gb_type="TypeError when making fake tensor call", + context=f"TypeError {node.target}: {cause}", + explanation="", + hints=[], + ) + + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) 
from None + + if not allow_non_graph_fake: + _ = pytree.tree_map_only( + torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val + ) + + if ( + torch._dynamo.config.use_graph_deduplication + or torch._dynamo.config.track_nodes_for_deduplication + ): + tx.output.region_tracker.track_node_mutations( + node, + flat_args_kwargs, + id_to_initial_version, + ) + + return ret_val + + +_current_node = threading.local() + + +def get_current_node(): + return getattr(_current_node, "value", None) + + +@contextmanager +def set_current_node(node): + old = get_current_node() + _current_node.value = node + try: + yield + finally: + _current_node.value = old + + +def run_node(tracer, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dictated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The tracer arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. + """ + op = node.op + + with set_current_node(node): + + def make_error_message(e): + return ( + f"Dynamo failed to run FX node with fake tensors: {op} {node.target}(*{args}, **{kwargs}): got " + + repr(e) + ) + + from .exc import Unsupported + + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + if not hasattr(args[0], node.target): + from .exc import unimplemented_v2 + + unimplemented_v2( + gb_type="Missing attribute when running call_method node", + context="", + explanation=make_error_message("attribute not defined"), + hints=[], + ) + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return tracer.output_graph.get_submodule(node.target) + elif op == "placeholder": + assert "example_value" in node.meta + return node.meta["example_value"] + + except (NotImplementedError, UnsupportedFakeTensorException) as e: + # NB: mimic how wrap_fake_exception does it + from .exc import unimplemented_v2 + + hints = [] + if isinstance(e, NotImplementedError): + hints = [ + "If the op is a PyTorch op, please file an issue to PyTorch.", + ] + + unimplemented_v2( + gb_type="NotImplementedError/UnsupportedFakeTensorException when running FX node", + context="", + explanation=make_error_message(e), + hints=hints, + from_exc=e, + ) + except Unsupported: + raise + except Exception as e: + raise RuntimeError(make_error_message(e)).with_traceback( + e.__traceback__ + ) from e + + raise AssertionError(op) + + +def get_real_value(node, tracer): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. 
+ """ + from .exc import TorchRuntimeError + + cache = tracer.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( # type: ignore[misc] + (node.args, node.kwargs), + lambda n: get_real_value(n, tracer), + ) + + if op == "placeholder" and "grapharg" in node.meta: + return node.meta["grapharg"].example + + if op == "call_module": + nn_module = tracer.output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(tracer, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None + return real_value + + +def assert_no_fake_params_or_buffers(gm): + from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake + + def stack_or_hint(t): + if FakeTensorConfig.debug: + import traceback + + return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" + else: + return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." + + for name, buffer in gm.named_buffers(): + assert not is_fake(buffer), ( + f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" + ) + for name, param in gm.named_parameters(): + assert not is_fake(param), ( + f"Unexpected fake param {name} {stack_or_hint(param)}" + ) + + +def fqn(obj: Any): + """ + Returns the fully qualified name of the object. + """ + return f"{obj.__module__}.{obj.__qualname__}" + + +def ifdynstaticdefault(count1, count2): + if torch._dynamo.config.assume_static_by_default: + return count1 + else: + return count2 + + +def import_submodule(mod: types.ModuleType): + """ + Ensure all the files in a given submodule are imported + """ + for filename in sorted(os.listdir(os.path.dirname(cast(str, mod.__file__)))): + if filename.endswith(".py") and filename[0] != "_": + importlib.import_module(f"{mod.__name__}.{filename[:-3]}") + + +def object_has_getattribute(value: Any): + return class_has_getattribute(type(value)) + + +def object_setattr_ignore_descriptor(obj, name, value): + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335 + d = object.__getattribute__(obj, "__dict__") + d[name] = value + + +def class_has_getattribute(cls: type): + try: + if isinstance( + inspect.getattr_static(cls, "__getattribute__"), + types.FunctionType, + ): + return True + except AttributeError: + pass + return False + + +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): + try: + getattr_fn = inspect.getattr_static(type(value), "__getattr__") + except AttributeError: + getattr_fn = None + if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__: + # ignore this case of getattr + getattr_fn = None + return getattr_fn + + +class TensorStaticReason(enum.Enum): + PARAMETER = 2 + NOT_TENSOR = 4 + NN_MODULE_PROPERTY = 5 + + +def tensor_static_reason_to_message(reason: TensorStaticReason): + if reason == TensorStaticReason.PARAMETER: + return "mark_dynamic on parameter, parameters are always static today." + if reason == TensorStaticReason.NOT_TENSOR: + return "mark_dynamic on a non tensor, how did this happen?" + if reason == TensorStaticReason.NN_MODULE_PROPERTY: + return "tensor is static because it is nn module associated." 
+ raise AssertionError(f"Illegal reason {reason}") + + +def tensor_always_has_static_shape( + tensor: Union[torch.Tensor, Any], + is_tensor: bool, + tensor_source: Source, +) -> tuple[bool, Optional[TensorStaticReason]]: + """ + Given a tensor, source, and is_tensor flag, determine if a shape should be static. + + Args: + tensor - the real tensor to evaluate, parameters force a static shape. + is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable, + tensors not in a TensorVariable for whatever reason are forced static. + + Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape. + The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed. + """ + from .source import is_from_unspecialized_param_buffer_source + + if ( + tensor_source.guard_source().is_specialized_nn_module() + or tensor_source.guard_source().is_unspecialized_builtin_nn_module() + ) and config.force_nn_module_property_static_shapes: + return True, TensorStaticReason.NN_MODULE_PROPERTY + + if ( + type(tensor) is torch.nn.Parameter + or is_from_unspecialized_param_buffer_source(tensor_source) + ) and config.force_parameter_static_shapes: + return True, TensorStaticReason.PARAMETER + if not is_tensor: + return True, TensorStaticReason.NOT_TENSOR + return False, None + + +def lazy_format_graph_tabular(fn_name, gm): + def inner(): + try: + from tabulate import tabulate # TODO: Check that this is installed + except ImportError: + return ( + "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n" + + str(lazy_format_graph_code(fn_name, gm)) + ) + + node_specs = [ + [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes + ] + graph_str = tabulate( + node_specs, headers=["opcode", "name", "target", "args", "kwargs"] + ) + return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str) + + return LazyString(inner) + + +def format_bytecode(prefix, name, filename, line_no, code): + return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" + + +forward_hook_names = ["_forward_pre_hooks", "_forward_hooks"] +backward_hook_names = ["_backward_pre_hooks", "_backward_hooks"] +state_dict_hook_names = [ + "_state_dict_pre_hooks", + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", +] +all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names + + +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. + return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks + ) + + +def nn_module_get_all_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + """ + Sometimes its useful to differentiate between types of hooks such as forward/backward/pre + hooks executed during module.__call__, and state_dict hooks which are executed separately. 
+ """ + hook_dicts_to_check = [] + check_all_hooks = ( + not check_forward_hooks + and not check_backward_hooks + and not check_state_dict_hooks + ) + if check_forward_hooks or check_all_hooks: + hook_dicts_to_check.extend(forward_hook_names) + if check_backward_hooks or check_all_hooks: + hook_dicts_to_check.extend(backward_hook_names) + if check_state_dict_hooks: + hook_dicts_to_check.extend(state_dict_hook_names) + + all_hooks = [] + for hook_dict_name in hook_dicts_to_check: + hooks = getattr(mod, hook_dict_name, []) + for hook_name in hooks: + hook = hooks[hook_name] + + all_hooks.append(hook) + return all_hooks + + +def nnmodule_has_hooks( + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): + """ + Helper function to check if a module has any hooks attached to it. + """ + hooks = nn_module_get_all_hooks( + mod, + check_forward_hooks=check_forward_hooks, + check_backward_hooks=check_backward_hooks, + check_state_dict_hooks=check_state_dict_hooks, + ) + return bool(hooks) + + +def to_numpy_helper(value): + """Convert tensor and tnp.ndarray to numpy.ndarray.""" + if is_fake(value): + return value + if isinstance(value, tnp.ndarray): + return to_numpy_helper(value.tensor) + elif isinstance(value, torch.Tensor): + return value.numpy(force=True) + elif isinstance(value, (tuple, list)): + return type(value)(to_numpy_helper(obj) for obj in value) + else: + return value + + +def numpy_to_tensor(value): + """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" + assert np is not None + if isinstance(value, np.ndarray): + return torch.as_tensor(value) + if isinstance(value, tnp.ndarray): + return value.tensor + elif isinstance(value, (tuple, list)): + return type(value)(numpy_to_tensor(obj) for obj in value) + else: + return value + + +class numpy_to_tensor_wrapper: + def __init__(self, f): + self.f = f + self.__name__ = "wrapped_" + self.f.__name__ + + def __repr__(self) -> str: + return f">" + + def __call__(self, *args, **kwargs): + out = self.f(*args, **kwargs) + return numpy_to_tensor(out) + + +def numpy_attr_wrapper(obj, name): + if isinstance(obj, tnp.ndarray): + out = getattr(obj, name) + return numpy_to_tensor(out) + elif isinstance(obj, torch.Tensor): + out = getattr(tnp.ndarray(obj), name) + return numpy_to_tensor(out) + + +class numpy_method_wrapper: + """Convert obj from torch.Tensor to tnp.ndarray and call method. 
Then convert result back to torch.Tensor.""" + + def __init__(self, method: str): + self.method = method + self.__name__ = "wrapped_" + self.method + + def __repr__(self) -> str: + return f">" + + def __call__(self, *args, **kwargs): + obj = args[0] + if isinstance(obj, torch.Tensor): + obj = tnp.ndarray(obj) + method_callable = getattr(obj, self.method) + out = method_callable(*args[1:], **kwargs) + return numpy_to_tensor(out) + + +class numpy_operator_wrapper: + """Implements dunder methods for tnp.ndarray via functions from the operator library""" + + def __init__(self, op: Callable[..., Any]): + self.op = op + self.__name__ = f"wrapped_{op.__name__}" + + def __repr__(self) -> str: + return f">" + + def __call__(self, *args, **kwargs): + assert not kwargs + + args = ( + tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args + ) + out = self.op(*args) + return numpy_to_tensor(out) + + +def defake(x): + if not isinstance(x, FakeTensor): + return x + size: torch._prims_common.ShapeType + stride: torch._prims_common.StrideType + if x._has_symbolic_sizes_strides: + size = [] + for s in x.size(): + if isinstance(s, torch.SymInt): + size.append(s.node.shape_env.size_hint(s.node.expr)) + else: + size.append(s) + stride = [] + for s in x.stride(): + if isinstance(s, torch.SymInt): + stride.append(s.node.shape_env.size_hint(s.node.expr)) + else: + stride.append(s) + else: + size = x.size() + stride = x.stride() + y = torch.empty_strided( + size, + stride, + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + y.zero_() + return y + + +def _disable_side_effect_safety_checks_for_current_subtracer(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +def is_utils_checkpoint(obj): + # Lazy import to avoid circular dependencies + import torch.utils.checkpoint + + return obj is torch.utils.checkpoint.checkpoint + + +def is_invoke_subgraph(obj): + from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_placeholder + + return obj is invoke_subgraph_placeholder + + +def build_invoke_subgraph_variable(**options): + from .variables.higher_order_ops import TorchHigherOrderOperatorVariable + + return TorchHigherOrderOperatorVariable.make( + torch._higher_order_ops.invoke_subgraph, + **options, + ) + + +def build_checkpoint_variable(**options): + import torch._higher_order_ops.wrap as higher_order_ops + + from .variables.higher_order_ops import TorchHigherOrderOperatorVariable + + # TODO - This is a temporary situation where we have two versions of + # checkpointing implementation. We will converge on one and remove the other. + activation_checkpoint_op: torch._ops.HigherOrderOperator = ( + higher_order_ops.tag_activation_checkpoint + ) + if torch._functorch.config.functionalize_rng_ops: + activation_checkpoint_op = higher_order_ops.wrap_activation_checkpoint + + return TorchHigherOrderOperatorVariable.make( + activation_checkpoint_op, + **options, + ) + + +def is_compile_supported(device_type): + from .eval_frame import is_dynamo_supported + + type = torch.device(device_type).type + compile_supported = is_dynamo_supported() + if type == "cpu": + pass + elif type in ["cuda", "xpu"] and compile_supported: + compile_supported = has_triton() + else: + compile_supported = False + return compile_supported + + +# The following 3.11 source code functions are adapted from +# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py +# in order to output source code corresponding to bytecode in 3.11+. 
+# We need our own versions since we want to support multiline expressions. +def _fix_offset(str: str, offset: int) -> int: + """ + Convert byte offset `offset` of `str` into character offset. + Byte offset is used for 3.11+ instruction column data. + Takes things like unicode characters into consideration. + + Unchanged from CPython implementation. + """ + as_utf8 = str.encode("utf-8") + return len(as_utf8[:offset].decode("utf-8", errors="replace")) + + +@dataclasses.dataclass +class _Anchors: + # inclusive + left_end_lineno: int + left_end_offset: int + right_start_lineno: int + # exclusive + right_start_offset: int + + +def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: + """ + Given source code `segment` corresponding to a bytecode + instruction, determine: + - for binary ops, the location of the binary op + - for indexing, the location of the brackets. + `segment` is expected to be a valid Python expression + """ + assert sys.version_info >= (3, 11) + + import ast + + try: + # Without brackets, `segment` is parsed as a statement. + # We expect an expression, so wrap `segment` in + # brackets to handle multi-line expressions. + tree = ast.parse("(\n" + segment + "\n)") + except SyntaxError: + return None + + if len(tree.body) != 1: + return None + + lines = segment.split("\n") + + # get character index given byte offset + def normalize(lineno, offset): + return _fix_offset(lines[lineno], offset) + + # Gets the next valid character index in `lines`, if + # the current location is not valid. Handles empty lines. + def next_valid_char(lineno, col): + while lineno < len(lines) and col >= len(lines[lineno]): + col = 0 + lineno += 1 + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character index in `lines`. + def increment(lineno, col): + col += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + # Get the next valid character at least on the next line + def nextline(lineno, col): + col = 0 + lineno += 1 + lineno, col = next_valid_char(lineno, col) + assert lineno < len(lines) and col < len(lines[lineno]) + return lineno, col + + statement = tree.body[0] + if isinstance(statement, ast.Expr): + expr = statement.value + if isinstance(expr, ast.BinOp): + # ast gives locations for BinOp subexpressions, e.g. + # ( left_expr ) + ( right_expr ) + # left^^^^^ right^^^^^ + # -2 since end_lineno is 1-indexed and because we added an extra + # bracket to `segment` when calling ast.parse + cur_lineno = cast(int, expr.left.end_lineno) - 2 + cur_col = normalize(cur_lineno, expr.left.end_col_offset) + cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) + + # Heuristic to find the operator character. + # The original CPython implementation did not look for ), \, or #, + # leading to incorrect anchor location, e.g. 
+ # (x) + (y) + # ~~^~~~~~~ + while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#": + if ch in "\\#": + cur_lineno, cur_col = nextline(cur_lineno, cur_col) + else: + cur_lineno, cur_col = increment(cur_lineno, cur_col) + + # binary op is 1 or 2 characters long, on the same line + right_col = cur_col + 1 + if ( + right_col < len(lines[cur_lineno]) + and not (ch := lines[cur_lineno][right_col]).isspace() + and ch not in "\\#" + ): + right_col += 1 + # right_col can be invalid since it is exclusive + + return _Anchors(cur_lineno, cur_col, cur_lineno, right_col) + elif isinstance(expr, ast.Subscript): + # ast gives locations for value and slice subexpressions, e.g. + # ( value_expr ) [ slice_expr ] + # value^^^^^ slice^^^^^ + # subscript^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '[' after value) + left_lineno = cast(int, expr.value.end_lineno) - 2 + left_col = normalize(left_lineno, expr.value.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "[": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + elif isinstance(expr, ast.Call): + # ( func_expr ) (args, kwargs) + # func^^^^^ + # call^^^^^^^^^^^^^^^^^^^^^^^^ + # find left bracket (first '(' after func) + left_lineno = cast(int, expr.func.end_lineno) - 2 + left_col = normalize(left_lineno, expr.func.end_col_offset) + left_lineno, left_col = next_valid_char(left_lineno, left_col) + while lines[left_lineno][left_col] != "(": + left_lineno, left_col = increment(left_lineno, left_col) + # find right bracket (final character of expression) + right_lineno = cast(int, expr.end_lineno) - 2 + right_col = normalize(right_lineno, expr.end_col_offset) + return _Anchors(left_lineno, left_col, right_lineno, right_col) + + return None + + +def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str: + """ + Python 3.11+ only. Returns lines of source code (from code object `code`) + corresponding to `inst`'s location data, and underlines relevant code to `inst`. + + Example: CALL on `g`: + f(g( + ^^ + h(x))) + ^^^^^ + + We need our own implementation in < 3.13 since `format_frame_summary` in + Python's `traceback` module doesn't handle multi-line expressions + (and their anchor extraction code is not completely correct). 
+ """ + if sys.version_info >= (3, 13): + # multiline traceback implemented in 3.13+ + frame_summary = traceback.FrameSummary( + code.co_filename, + inst.positions.lineno, + code.co_name, + end_lineno=inst.positions.end_lineno, + colno=inst.positions.col_offset, + end_colno=inst.positions.end_col_offset, + ) + result = traceback.format_list([frame_summary])[0] + # remove first line containing filename info + result = "\n".join(result.splitlines()[1:]) + # indent lines with original indentation + orig_lines = [ + linecache.getline(code.co_filename, lineno).rstrip() + for lineno in range(inst.positions.lineno, inst.positions.end_lineno + 1) + ] + orig_lines_dedent = textwrap.dedent("\n".join(orig_lines)).splitlines() + indent_len = len(orig_lines[0]) - len(orig_lines_dedent[0]) + indent = orig_lines[0][:indent_len] + result = textwrap.indent(textwrap.dedent(result), indent) + return result + + assert inst.positions is not None + if inst.positions.lineno is None: + return "" + # The rstrip + "\n" pattern is used throughout this function to handle + # linecache.getline errors. Error lines are treated as empty strings "", but we want + # to treat them as blank lines "\n". + first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip() + if inst.positions.end_lineno is None: + return first_line + if inst.positions.col_offset is None or inst.positions.end_col_offset is None: + return first_line + + # character index of the start of the instruction + start_offset = _fix_offset(first_line, inst.positions.col_offset) + # character index of the end of the instruction + # compute later since end may be a different line + end_offset = None + # expression corresponding to the instruction so we can get anchors + segment = "" + # underline markers to be printed - start with `~` marker and replace with `^` later + markers = [] + + # Compute segment and initial markers + if inst.positions.end_lineno == inst.positions.lineno: + end_offset = _fix_offset(first_line, inst.positions.end_col_offset) + segment = first_line[start_offset:end_offset] + markers.append(" " * start_offset + "~" * (end_offset - start_offset)) + else: + segment = first_line[start_offset:] + "\n" + markers.append(" " * start_offset + "~" * (len(first_line) - start_offset)) + last_line = linecache.getline( + code.co_filename, inst.positions.end_lineno + ).rstrip() + end_offset = _fix_offset(last_line, inst.positions.end_col_offset) + for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno): + line = linecache.getline(code.co_filename, lineno).rstrip() + segment += line + "\n" + # don't underline leading spaces + num_spaces = len(line) - len(line.lstrip()) + markers.append(" " * num_spaces + "~" * (len(line) - num_spaces)) + segment += last_line[:end_offset] + num_spaces = len(last_line) - len(last_line.lstrip()) + markers.append(" " * num_spaces + "~" * (end_offset - num_spaces)) + + anchors: Optional[_Anchors] = None + try: + anchors = _extract_anchors_from_expr(segment) + except AssertionError: + pass + + # replace `~` markers with `^` where necessary + if anchors is None: + markers = [marker.replace("~", "^") for marker in markers] + else: + # make markers mutable + mutable_markers: list[list[str]] = [list(marker) for marker in markers] + + # anchor positions do not take start_offset into account + if anchors.left_end_lineno == 0: + anchors.left_end_offset += start_offset + if anchors.right_start_lineno == 0: + anchors.right_start_offset += start_offset + + # Turn `~`` markers between anchors to `^` + 
for lineno in range(len(markers)): + for col in range(len(mutable_markers[lineno])): + if lineno < anchors.left_end_lineno: + continue + if lineno == anchors.left_end_lineno and col < anchors.left_end_offset: + continue + if ( + lineno == anchors.right_start_lineno + and col >= anchors.right_start_offset + ): + continue + if lineno > anchors.right_start_lineno: + continue + if mutable_markers[lineno][col] == "~": + mutable_markers[lineno][col] = "^" + + # make markers into strings again + markers = ["".join(marker) for marker in mutable_markers] + + result = "" + for i in range(len(markers)): + result += ( + linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip() + + "\n" + ) + result += markers[i] + "\n" + return result + + +def get_static_address_type(t): + if isinstance(t, torch.Tensor): + return getattr(t, "_dynamo_static_input_type", None) + + return None + + +def is_rng_state_getter_or_setter(value): + getters = ( + # The following two functions are not identical, so don't remove anyone! + torch._C.Generator.get_state, + torch.default_generator.get_state, + torch.get_rng_state, + torch.cuda.get_rng_state, + ) + setters = ( + torch._C.Generator.set_state, + torch.default_generator.set_state, + torch.set_rng_state, + torch.cuda.set_rng_state, + ) + return value in (*setters, *getters) + + +def is_tensor_base_attr_getter(value): + return ( + isinstance(value, types.MethodWrapperType) + and value.__name__ == "__get__" + and value.__self__.__objclass__ is torch._C._TensorBase # type: ignore[attr-defined] + ) + + +def is_tensor_getset_descriptor(name): + try: + attr = inspect.getattr_static(torch.Tensor, name) + return type(attr) is types.GetSetDescriptorType + except AttributeError: + return False + + +def is_torch_function_object(value): + return hasattr(value, "__torch_function__") + + +def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: + # This emulates + # https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/disable_torch_function.cpp#L315-L323 + from torch._dynamo.variables import UserDefinedObjectVariable + from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable + + # Note on lazy vars: The value will either be realized or not throughout the course of execution + # if the value has a torch function, it will eventually be realized so we can realize it here + # if the value does not have a torch function, it may or may not be realized + # if it is realized it will be used and guards will be installed properly + # if it is not used, guards won't be installed, and it doesn't matter + # if the value has a torch function or not, so we should *not* realize it. 
+ # NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method + # but mypy does not unfortunately + if vt.is_realized() or ( + hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__") + ): + func = None + if isinstance(vt, TensorWithTFOverrideVariable): + func = getattr(vt.class_type, "__torch_function__", None) + + elif isinstance(vt, UserDefinedObjectVariable): + func = getattr(vt.value, "__torch_function__", None) + + return func not in (None, torch._C._disabled_torch_function_impl) + + return False + + +# see note [Tensor Fakification and Symbol Caching] +def to_fake_tensor(t, fake_mode): + symbolic_context = None + source = None + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + source = symbolic_context.tensor_source + + return fake_mode.from_tensor( + t, static_shapes=False, symbolic_context=symbolic_context, source=source + ) + + +# NB: this works for both classes and instances +def is_frozen_dataclass(value): + return ( + not object_has_getattribute(value) + and not class_has_getattribute(value) + and is_dataclass(value) + and hasattr(value, "__dataclass_params__") + and hasattr(value.__dataclass_params__, "frozen") + and value.__dataclass_params__.frozen + ) + + +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +@contextlib.contextmanager +def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): + if not should_enable: + yield + else: + + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): + torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 + return torch._inductor.compile(gm_, example_inputs_) + + return torch.compile( + gm, backend=inner_compiler, fullgraph=fullgraph, dynamic=dynamic + ) + + with torch._dynamo.compiled_autograd._enable(compiler_fn) as ctx: + yield ctx + + +def invalid_removeable_handle(): + # need a subclass so weakref works + class Invalid(dict): # type: ignore[type-arg] + pass + + return RemovableHandle(Invalid()) + + +# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's. +# Attribute changes to the original object/proxy will be reflected in the other. +# This is useful for cases where we want a keep-alive reference to a module without increasing +# its reference count. +def nn_module_proxy(mod): + if not isinstance(mod, torch.nn.Module): + return mod + if isinstance(mod, torch.fx.GraphModule): + # Dynamo-generated GM's shouldn't contain user-created GM's + return mod + proxy = mod.__class__.__new__(mod.__class__) + proxy.__dict__ = mod.__dict__ + return proxy + + +class GmWrapper(torch.nn.Module): + def __init__(self, gm, unflatten_fn): + super().__init__() + self.gm = gm + self.unflatten_fn = unflatten_fn + + def forward(self, *args): + args: list[Any] = list(args) + return self.gm(*self.unflatten_fn(args)) + + +def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): + """ + Mutate inputs so that they are flat and wrap gm such that it + accepts those inputs. This is needed for graphs that take + bumpy inputs. 
+ """ + inputs_idx_to_clear = [ + i + for i, node in enumerate(gm.graph.nodes) + if node.op == "placeholder" and node.meta.get("steal_arg", False) + ] + + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # fast path, avoid pytree overhead + # compiled autograd inputs are always a list of tensors, maybe followed by symints + assert inputs_idx_to_clear == [0] + assert isinstance(inputs[0], list) + boxed_inputs_count = len(inputs[0]) + + def flatten_fn(args): + return args[0] + list(args[1:]) + + def unflatten_fn(flat_args): + return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:]) + + compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs)) + else: + # slow path, don't know inputs structure + flat_inputs, spec = pytree.tree_flatten(inputs) + unflatten_fn = functools.partial(pytree.tree_unflatten, treespec=spec) + compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flat_inputs) + # note this doesn't check the spec, assuming it is the same + flatten_fn = pytree.arg_tree_leaves + + def wrapper(*args): + flat_args = flatten_fn(args) + + # flat_args is a new list, so we need to clear references from the old list + for i in inputs_idx_to_clear: + args[i].clear() + + # this call is boxed to avoid increasing refcount until we reach aot_module_simplified forward + return compiled_fn(flat_args) + + return wrapper + + +def get_locals_to_steal(maybe_gm): + if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): + return [] + return maybe_gm.meta.get("locals_to_steal", []) + + +def set_locals_to_steal(gm, locals_to_steal): + gm.meta["locals_to_steal"] = locals_to_steal + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self) -> str: + return self.s + + +warn_once_cache: set[str] = set() + + +def warn_once(msg, stacklevel=1): + # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. + # https://github.com/pytorch/pytorch/issues/128427. + # warn_once is a workaround: if the msg has been warned on before, then we will not + # warn again. + # NB: it's totally ok to store a cache of all the strings: this is what warnings.warn does as well. 
+ if msg in warn_once_cache: + return + warn_once_cache.add(msg) + warnings.warn(msg, stacklevel=stacklevel + 1) + + +def strip_color_from_string(text): + # This regular expression matches ANSI escape codes + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text) + + +@contextlib.contextmanager +def _disable_saved_tensors_hooks_during_tracing(): + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + try: + prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) + yield + finally: + torch._C._autograd._saved_tensors_hooks_set_tracing(prior) + + +def is_parameter_freezing(): + return torch._inductor.config.freezing and not torch.is_grad_enabled() + + +def get_torch_function_mode_stack(): + return [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] + + +def get_torch_function_mode_stack_at(ind): + assert ind < _len_torch_function_stack() and ind >= 0 + return torch._C._get_function_stack_at(ind) + + +def set_torch_function_mode_stack(stack): + for _ in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + for mode in stack: + _push_on_torch_function_stack(mode) + + +def clear_torch_function_mode_stack(): + for _ in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + +# call from C dynamo in order to inspect values in pdb +def _breakpoint_for_c_dynamo(*args): + breakpoint() + + +def verify_guard_fn_signature(value): + fn = value.__metadata_guard__ + sig = inspect.signature(fn) + if len(sig.parameters) != 2: + from .exc import InternalTorchDynamoError + + raise InternalTorchDynamoError( + "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments" + ) + if fn.__self__ != value.__class__: + from .exc import InternalTorchDynamoError + + raise InternalTorchDynamoError( + "Tensor subclass method __metadata_guard__ must be a classmethod" + ) + + +def does_not_override_dict_iter_methods(user_cls): + return ( + user_cls.items in (dict.items, OrderedDict.items) + and user_cls.values in (dict.values, OrderedDict.values) + and user_cls.keys in (dict.keys, OrderedDict.keys) + and user_cls.__iter__ in (dict.__iter__, OrderedDict.__iter__) + ) + + +# Helper functions below are to prevent TorchDynamo to prevent tracing of +# __torch_function__ calls triggered on tensor properties in the pre graph +# bytecode. +@torch._disable_dynamo +def call_size(x, i): + return x.size(i) + + +@torch._disable_dynamo +def call_stride(x, i): + return x.stride(i) + + +@torch._disable_dynamo +def call_storage_offset(x): + return x.storage_offset() + + +# Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. +# To avoid ref cycles, it's important that no tensors are present here, so leave those out. +def _extract_tensor_dict(t): + KEYS_TO_COPY = [ + "_dynamo_static_input_type", + "tag", + ] + + tensor_dict = { + key: copy.copy(t.__dict__[key]) for key in KEYS_TO_COPY if key in t.__dict__ + } + + return tensor_dict + + +# This is useful for reconstructing within the Dynamo graph the non-graph-input objects +# whose lifetime is governed by the user. +# e.g. torch.cuda.Event is a prime example. 
+user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {} + + +def get_user_object_from_id(obj_id): + obj = user_obj_id_to_weakref[obj_id]() + assert obj is not None, "User object is no longer alive" + return obj + + +def store_user_object_weakref(obj): + obj_id = id(obj) + user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +class CompileTimeInstructionCounter: + _counter: int = 0 + _id: int = -1 + _depth = 0 + + @classmethod + def start(cls) -> None: + cls._depth = cls._depth + 1 + if cls._depth == 1: + cls._id = _instruction_counter.start() + + @classmethod + def end(cls) -> None: + cls._depth = cls._depth - 1 + if cls._depth == 0: + cls._counter += _instruction_counter.end(cls._id) + cls._id = -1 + + @classmethod + def clear(cls) -> None: + cls._counter = 0 + + @classmethod + def value(cls) -> int: + return cls._counter + + @classmethod + @contextmanager + def record(cls): + try: + if config.record_compile_time_instruction_count: + cls.start() + yield + finally: + if config.record_compile_time_instruction_count: + cls.end() + + +def set_feature_use(feature: str, usage: bool): + """ + Records whether we are using a feature + Generally a feature is a JK. + """ + # Note that sometimes (tests etc...) we're not in a context which we can record into + if get_metrics_context().in_progress(): + get_metrics_context().set_key_value("feature_usage", feature, usage) + + +_ddp_optimization_mode: tuple[str, ...] = ( + "ddp_optimizer", + "python_reducer", # experimental mode + "python_reducer_without_compiled_forward", + "no_optimization", +) + + +def get_optimize_ddp_mode(): + optimize_ddp = config.optimize_ddp + if isinstance(optimize_ddp, bool): + mode = "ddp_optimizer" if optimize_ddp else "no_optimization" + elif isinstance(optimize_ddp, str): + mode = optimize_ddp + else: + raise ValueError( + f"Invalid dynamo config optimize_ddp type {type(optimize_ddp)=}" + ) + + assert mode in _ddp_optimization_mode, ( + f"Invalid dynamo config optimize_ddp value {mode=}" + ) + return mode + + +@contextmanager +def maybe_disable_inference_mode() -> Generator[None, None, None]: + """ + Disables torch.inference_mode for the compilation (still on at runtime). + This simplifies the compile stack where we can assume that inference_mode + will always be off. + + Since inference_mode is equivalent to no_grad + some optimizations (version + counts etc), we turn on no_grad here. The other optimizations are not + relevant to torch.compile. + """ + is_inference_mode_on = ( + config.fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled() + ) + if is_inference_mode_on: + with ( + torch.inference_mode(False), + torch.no_grad(), + ): + yield + else: + yield + + +@contextmanager +def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]: + """ + Turns off tracking of inference_mode for fake tensor propagation. With this + context manager, when a real tensor is converted to fake tensor, the fake + tensor looses its inference-ness. 
+ """ + if config.fake_tensor_disable_inference_mode: + with torch._subclasses.meta_utils.disable_inference_mode_for_fake_prop(): + yield + else: + yield + + +def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool: + return node is None or "example_value" in node.meta or "val" in node.meta + + +@torch._disable_dynamo +def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: + cm: AbstractContextManager[None] = ( + torch._C._profiler._RecordFunctionFast("Pregraph bytecode") + if torch.autograd.profiler._is_profiler_enabled + else contextlib.nullcontext() + ) + cm.__enter__() + return cm + + +@torch._disable_dynamo +def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: + cm.__exit__(None, None, None) + + +# Returns a set of code objects present traced in the current TracingContext, or None +# if there is no current TracingContext. +def get_traced_code() -> list[CodeType]: + from torch._guards import TracingContext + + return TracingContext.get_traced_code() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__init__.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c90294e2887800a414f6d570392801708eb7d72a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/__init__.py @@ -0,0 +1,224 @@ +""" +This package implements variable tracking and symbolic execution capabilities for Dynamo, +which are essential for converting Python code into FX graphs. It provides a comprehensive +set of variable types that handle different Python constructs during tracing. + +Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible +for tracking and symbolically executing operations on specific Python objects. This enables +Dynamo to: +- Track the flow of values through Python code +- Maintain correct semantics during graph conversion +- Handle complex Python features like context managers, iterators, and custom objects +- Support both eager and symbolic execution modes + +The VariableTracker base class provides the foundation for all variable types, with each +subclass implementing specific behavior for different Python constructs. This modular design +allows Dynamo to accurately trace and optimize Python code while preserving its semantics. 
+""" + +from .base import VariableTracker +from .builtin import BuiltinVariable +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + CatchWarningsCtxManagerVariable, + ContextWrappingVariable, + CUDADeviceVariable, + DeterministicAlgorithmsVariable, + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + DynamoConfigPatchVariable, + FSDPParamGroupUseTrainingStateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + StreamContextVariable, + StreamVariable, + TemporarilyPopInterpreterStackCtxManagerVariable, + VmapIncrementNestingCtxManagerVariable, + WithExitFunctionVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + NNModuleHooksDictVariable, + SetVariable, +) +from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctionDecoratedByContextlibContextManagerVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + LocalGeneratorFunctionVariable, + LocalGeneratorObjectVariable, + NestedUserFunctionVariable, + PolyfilledFunctionVariable, + SkipFunctionVariable, + TMADescriptorExperimentalVariable, + TMADescriptorStableVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, + WrapperUserMethodVariable, +) +from .higher_order_ops import ( + FunctionalCallVariable, + FunctorchHigherOrderVariable, + TorchHigherOrderOperatorVariable, +) +from .iter import ( + CountIteratorVariable, + CycleIteratorVariable, + FilterVariable, + IteratorVariable, + ItertoolsVariable, + MapVariable, + RepeatIteratorVariable, + ZipVariable, +) +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradFunctionContextVariable, + AutogradFunctionVariable, + CellVariable, + DeletedVariable, + ExceptionVariable, + GetAttrVariable, + LambdaVariable, + MethodWrapperVariable, + NewGlobalVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + RegexPatternVariable, + StringFormatVariable, + SuperVariable, + TorchVersionVariable, + TypingVariable, + UnknownVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + NNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .sdpa import SDPAParamsVariable +from .tensor import ( + DataPtrVariable, + FakeItemVariable, + NumpyNdarrayVariable, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, + UntypedStorageVariable, +) +from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable +from .user_defined import ( + MutableMappingVariable, + RemovableHandleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedExceptionObjectVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedTupleVariable, +) + + +__all__ = [ + "AutogradFunctionContextVariable", + "AutogradFunctionVariable", + 
"BackwardHookVariable", + "BaseListVariable", + "BuiltinVariable", + "CatchWarningsCtxManagerVariable", + "ConstantVariable", + "ConstDictVariable", + "ContextWrappingVariable", + "CountIteratorVariable", + "CreateTMADescriptorExperimentalVariable", + "CreateTMADescriptorStableVariable", + "CUDADeviceVariable", + "CycleIteratorVariable", + "DataPtrVariable", + "DefaultDictVariable", + "DeletedVariable", + "DeterministicAlgorithmsVariable", + "DictKeySetVariable", + "DynamoConfigPatchVariable", + "EnumVariable", + "FakeItemVariable", + "GetAttrVariable", + "GradModeVariable", + "IteratorVariable", + "ItertoolsVariable", + "LambdaVariable", + "LazyVariableTracker", + "ListIteratorVariable", + "ListVariable", + "NamedTupleVariable", + "NestedUserFunctionVariable", + "CellVariable", + "NewGlobalVariable", + "NNModuleVariable", + "NumpyNdarrayVariable", + "NumpyVariable", + "OptimizerVariable", + "PlacementVariable", + "PolyfilledFunctionVariable", + "PythonModuleVariable", + "RangeVariable", + "RegexPatternVariable", + "RemovableHandleVariable", + "RepeatIteratorVariable", + "SDPAParamsVariable", + "SkipFunctionVariable", + "SliceVariable", + "StringFormatVariable", + "SuperVariable", + "TemporarilyPopInterpreterStackCtxManagerVariable", + "TensorVariable", + "TMADescriptorExperimentalVariable", + "TMADescriptorStableVariable", + "TorchCtxManagerClassVariable", + "TorchInGraphFunctionVariable", + "TorchVersionVariable", + "TupleVariable", + "UnknownVariable", + "UnspecializedNNModuleVariable", + "UnspecializedPythonVariable", + "UntypedStorageVariable", + "UserDefinedClassVariable", + "UserDefinedTupleVariable", + "UserDefinedObjectVariable", + "UserFunctionVariable", + "UserMethodVariable", + "VariableTracker", + "WithExitFunctionVariable", + "MappingProxyVariable", +] diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db5609268b3a5c60c1e3c157fdeead4afb6e34d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c620a0169eb068c0956247da9d564b1d7a055ce7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec55777eaefb554514564f44e6799360cf3ef709 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56bf9777474c92910d883eb2c0e9fcd27f10dae4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e1d9c8815ea0a5bf49b747808709a5ab5c0a76b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fde1be710bbb1e8a3c6088d17c76535bcb20ae9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654f2e741f0c6fa2c5bcf0ca49c0afc6766d1128 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a360436b267886a349d97e5feec6b9de8626685c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..597484a7fb40c933d08bc25ee32ad306c285bf35 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5019e9ecf6180f8a134cc21589464bca399ccdcc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..573cfa5d7af4faaa7a876d381e5c4edd4dd4b525 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2be5a14fa6a59dabd76a98185201dac3aa1e3ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..77a23ceaef68e60ce39200d5bf95e209080dd86b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b229e878b1ad41d8ff90b56cd531d6a97b65a40 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b53f6f10f5f65e9eed91cb3b7546f6ceff875857 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4db242671a52ca4d9540bd3927152227c116f421 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f29f5020812eecb6bde2673c2b2034e1fb7e6c88 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/script_object.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a4e8c1918d7d4f3aaec6bbcd3ac12c420e7aebb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaae2548baf349a3ecbbf5a166ab0c613de91aef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b631fafd9e25e3a60a480cc3fae6e1218339567d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c871cd148d2c81a70501042f2e3d89f82e7a385a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc differ 
diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae2932fa851d47d6fd49fb0f9ad0b454426d55f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/base.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c5266cde61297b0ca15a5afe58bf1f22a9a94e9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/base.py @@ -0,0 +1,642 @@ +# mypy: ignore-errors + +""" +Core variable tracking functionality for Dynamo. This module defines the fundamental +classes and systems used to track and manage variables during Dynamo's operation. + +The module provides: +1. VariableTracker - The base class for tracking variables during compilation +2. MutationType system - Classes for tracking and managing mutations to variables +3. Source type management - Utilities for tracking variable origins and scope +4. Variable state management - Tools for managing variable state and transformations + +These components form the foundation of Dynamo's variable handling system, +enabling accurate tracking and transformation of Python code into optimized +computations. +""" + +import collections +from collections.abc import ItemsView, KeysView, Sequence, ValuesView +from enum import Enum +from typing import Any, Callable, Optional, TYPE_CHECKING + +from .. import graph_break_hints, variables +from ..current_scope_id import current_scope_id +from ..exc import raise_observed_exception, unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, Source +from ..utils import cmp_name_to_op_mapping, istype + + +if TYPE_CHECKING: + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase + + +class SourceType(Enum): + """ + This Enum divides VariableTracker into 2 cases, depending on the variable + it represents: + - already existed that Dynamo began tracking while introspection (Existing) + - is a new variable that is created during Dynamo introspection (New) + + In general, we have these invariants: + 1. for `VariableTracker` associated with `Existing`, its `source` field must not be None. + 2. for `VariableTracker` associated with `New`, most of the time its + `source` field is None, except for cases like side effect codegen for + `AttributeMutationNew`, during which we generate a + `LocalSource('tmp...')` for such variable, to facilitate codegen. + """ + + Existing = 0 + New = 1 + + +class MutationType: + """ + Base class for Variable.mutation_type. It encodes information about + 1. The type of mutation Dynamo allows on the variable. + 2. Whether the value represented by this variable already existed before + Dynamo tracing. + """ + + def __init__(self, typ: SourceType) -> None: + # In HigherOrderOperator tracing, we need to distinguish + # between MutationTypes inside the HigherOrderOperator and + # ones outside it. For example, it is not safe to mutate + # `a` in the following example because it was constructed + # in a different scope. + # + # def f(x): + # a = 1 + # def g(x): + # nonlocal a + # a = 2 + # return x + # return wrap(g, x) + a + # + # We use self.scope to distinguish this. 
+ # scope == 0: The object was an existing variable + # scope == 1: The object was created while Dynamo + # was introspecting a function + # (and no HigherOrderOps were involved) + # scope >= 2: The object was created through + # Dynamo introspection of a HigherOrderOp. + # The exact number corresponds to the level + # of nested HigherOrderOps. + if typ is SourceType.Existing: + self.scope = 0 + elif typ is SourceType.New: + self.scope = current_scope_id() + else: + unimplemented_v2( + gb_type="Unsupported SourceType", + context=f"MutationType.__init__ {self} {typ}", + explanation=f"Dynamo does not support the type `{typ}`", + hints=[ + "This branch is not supposed to be reachable.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +class ValueMutationNew(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value is created by the bytecode Dynamo is tracing through. + + For instance, Dynamo could model a newly created list with this marker, + indicating that while we need to model mutations to this list, we don't have + to emit bytecode for these mutations if the list doesn't escape into the + Python world. + """ + + def __init__(self) -> None: + super().__init__(SourceType.New) + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self is other + + +class ValueMutationExisting(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value itself (rather than its attributes). + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing list with this marker, + indicating that if we encounter mutations to this list, we need to buffer + and re-apply those mutations after the graph runs, since the list might be + used afterwards in Python. + """ + + # A flag to indicate whether mutation happened on the associated + # `VariableTracker`. This enables SideEffects to accurately and quickly + # filter out which pre-existing values it needs to generate mutation for. + is_modified: bool + + def __init__(self, is_modified: bool = False): + super().__init__(SourceType.Existing) + self.is_modified = is_modified + + +class AttributeMutation(MutationType): + """ + This case of VariableTracker.mutation_type marker indicates that Dynamo + allows mutation on the value's attributes. + """ + + def __init__(self, typ: SourceType): + super().__init__(typ) + + +class AttributeMutationExisting(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value exists before Dynamo tracing started. + + For instance, Dynamo could model a pre-existing object with this marker, + indicating that if we encounter mutations to this object, we need to buffer + then re-apply those mutations after the graph runs, since the object might + be used afterwards in Python. + """ + + def __init__(self): + super().__init__(SourceType.Existing) + + +class AttributeMutationNew(AttributeMutation): + """ + This case of VariableTracker.mutation_type marker indicates + 1. Dynamo allows mutation on the value's attributes. + 2. The value is created by the bytecode Dynamo is tracing through. 
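Aside, not part of the diffed file: the practical difference between ValueMutationNew and ValueMutationExisting described above is that only the latter's mutations must be re-applied after the graph runs. The sketch below is a simplified, self-contained illustration under assumed names (trace_region, the buffered-op tuples); it is not Dynamo's actual SideEffects machinery.

def trace_region(existing_list):
    # ValueMutationNew: a value created during tracing; if it never escapes,
    # its mutations need no replay bytecode at all.
    new_list = []
    new_list.append(1)
    # ValueMutationExisting (is_modified=True): the value predates tracing,
    # so the mutation is buffered and re-applied after the graph runs.
    return [("append", existing_list, 2)]

existing = [0]
for method, target, arg in trace_region(existing):
    getattr(target, method)(arg)  # replay buffered mutations post-graph
assert existing == [0, 2]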
+ + For instance, Dynamo could model a newly created object with this marker, + indicating that while we need to model mutations to this object, we don't + have to emit bytecode for these mutations if the object doesn't escape into + the Python world. + """ + + def __init__(self, cls_source: Optional[Source] = None): + super().__init__(SourceType.New) + self.cls_source = cls_source + + +def _is_top_level_scope(scope_id): + return scope_id == 1 + + +def is_side_effect_safe(m: MutationType): + scope_id = current_scope_id() + + # In the top-level scope (if no HigherOrderOperators are involved), + # we are allowed to modify variables created in this scope as well + # as existing variables. + if _is_top_level_scope(scope_id): + return True + # Otherwise, only allow local mutation of variables created in the current scope + return m.scope == scope_id + + +# This helps users of `as_python_constant` to catch unimplemented error with +# more information; it inherits `NotImplementedError` for backward +# compatibility reasons. +class AsPythonConstantNotImplementedError(NotImplementedError): + vt: "VariableTracker" + + def __init__(self, vt: "VariableTracker"): + super().__init__(f"{vt} is not a constant") + self.vt = vt + + +class VariableTrackerMeta(type): + all_subclasses = [] + + def __instancecheck__(cls, instance) -> bool: + """Make isinstance work with LazyVariableTracker""" + # This is super expensive - just having it costs over 4% of tracing + # time! + if (type(instance) is variables.LazyVariableTracker) and ( + cls not in (VariableTracker, variables.LazyVariableTracker) + ): + instance = instance.realize() + return type.__instancecheck__(cls, instance) + + def __init__(cls, name, bases, attrs) -> None: + super().__init__(name, bases, attrs) + VariableTrackerMeta.all_subclasses.append(cls) + + +class VariableTracker(metaclass=VariableTrackerMeta): + """ + Base class for tracked locals and stack values + + VariableTracker instances are immutable and should be copied in + order to change them. + + Prefer the factory function VariableTracker.build() over VariableTracker.__init__(). 
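Aside, not part of the diffed file: a minimal standalone restatement of the scope rule that is_side_effect_safe implements above, with plain integers standing in for MutationType.scope and current_scope_id().

def is_safe_to_mutate(mutation_scope: int, current_scope: int) -> bool:
    # Top-level tracing (scope id 1) may mutate existing and new values alike;
    # inside a HigherOrderOp (scope id >= 2), only values created in that same
    # nested scope may be mutated.
    if current_scope == 1:
        return True
    return mutation_scope == current_scope

assert is_safe_to_mutate(0, 1)        # pre-existing value, top-level trace
assert is_safe_to_mutate(2, 2)        # value created inside the same HOO scope
assert not is_safe_to_mutate(1, 2)    # outer-scope value mutated inside a HOO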
+ """ + + # fields to leave unmodified in apply() + _nonvar_fields = { + "value", + "guards", + "source", + "mutation_type", + "parents_tracker", + "user_code_variable_name", + } + + def clone(self, **kwargs): + """Shallow copy with some (optional) changes""" + args = dict(self.__dict__) + args.update(kwargs) + return self.__class__(**args) + + @classmethod + def visit( + cls, + fn: Callable[["VariableTracker"], None], + value: Any, + cache: Optional[dict[int, Any]] = None, + ) -> None: + """ + Walk value and call fn on all the VariableTracker instances + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = value + + if isinstance(value, VariableTracker): + value = value.unwrap() + fn(value) + value = value.unwrap() # calling fn() might have realized it + nonvars = value._nonvar_fields + for key, subvalue in value.__dict__.items(): + if key not in nonvars: + cls.visit(fn, subvalue, cache) + elif istype(value, (list, tuple)): + for subvalue in value: + cls.visit(fn, subvalue, cache) + elif istype(value, (dict, collections.OrderedDict)): + for subvalue in value.values(): + cls.visit(fn, subvalue, cache) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def debug_repr(self): + # Intended to be overridden to provide more info + try: + return repr(self.as_python_constant()) + except NotImplementedError: + return repr(self) + + def python_type(self): + """ + Abstract method to be implemented by subclasses of VariableTracker. + + This method should return the type represented by the instance of the subclass. + The purpose is to provide a standardized way to retrieve the Python type information + of the variable being tracked. + + Returns: + type: The Python type (such as int, str, list, etc.) of the variable tracked by + the subclass. If the type cannot be determined or is not relevant, + leaving it undefined or invoking super() is always sound. + + Note: + This is an abstract method and may be overridden in subclasses. + + Example: + class SetVariable(VariableTracker): + def python_type(self): + return set + + Raises: + NotImplementedError: If the method is not implemented in a subclass. 
+ """ + try: + return type(self.as_python_constant()) + except NotImplementedError: + raise NotImplementedError(f"{self} has no type") from None + + def python_type_name(self): + try: + return self.python_type().__name__ + except NotImplementedError: + return "" + + def as_python_constant(self): + """For constants""" + raise AsPythonConstantNotImplementedError(self) + + def guard_as_python_constant(self): + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + try: + return self.as_python_constant() + except NotImplementedError: + unimplemented_v2( + gb_type="Not a Python constant", + context=f"guard_as_python_constant {self}", + explanation=f"Failed to convert {self} into a Python constant.", + hints=[], + ) + + def is_python_constant(self): + try: + self.as_python_constant() + return True + except NotImplementedError: + return False + + def make_guard(self, fn): + if self.source: + return self.source.make_guard(fn) + raise NotImplementedError + + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + """getattr(self, name) returning a python constant""" + raise NotImplementedError + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + """getattr(self, name) returning a new variable""" + value = self.const_getattr(tx, name) + if not variables.ConstantVariable.is_literal(value): + raise NotImplementedError + source = self.source and AttrSource(self.source, name) + if source: + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def is_proxy(self): + try: + self.as_proxy() + return True + except NotImplementedError: + return False + + def as_proxy(self): + raise NotImplementedError(str(self)) + + def maybe_fx_node(self): + try: + proxy = self.as_proxy() + import torch.fx + + if isinstance(proxy, torch.fx.Proxy): + return proxy.node + return None + except NotImplementedError: + return None + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError + + def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + raise NotImplementedError + + def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + + def has_unpack_var_sequence(self, tx) -> bool: + try: + self.unpack_var_sequence(tx) + return True + except NotImplementedError: + return False + + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx) -> bool: + return self.has_unpack_var_sequence(tx) + + # Forces unpacking the var sequence while also applying a function to each element. + # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). + # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! 
+ def force_apply_to_var_sequence(self, tx, fn) -> None: + assert self.has_force_unpack_var_sequence(tx) + for v in self.unpack_var_sequence(tx): + fn(v) + + def inspect_parameter_names(self) -> list[str]: + unimplemented_v2( + gb_type="Unsupported inspect call", + context=f"inspect_parameter_names {self}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[], + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + unimplemented_v2( + gb_type="Unsupported hasattr call", + context=f"call_obj_hasattr {self} {name}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `hasattr({self.__class__.__name__}, {name})` in your code.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence["VariableTracker"], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented_v2( + gb_type="Unsupported function call", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", + hints=[ + f"Avoid calling `{self.debug_repr()}` in your code.", + "Please report an issue to PyTorch.", + ], + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__len__" and self.has_unpack_var_sequence(tx): + assert not (args or kwargs) + return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx))) + elif ( + name == "__getattr__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + ): + return self.var_getattr(tx, args[0].as_python_constant()) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(self, type(other)) and not ( + isinstance(self, variables.GetAttrVariable) + or isinstance(other, variables.GetAttrVariable) + ): + # NB: GetAttrVariable is a special case because sometimes an + # object can map to GetAttrVariable but other time as + # SkipFunctionVariable if it is an input to the compiled + # function, e.g. 
tensor.data_ptr + return variables.ConstantVariable.create(NotImplemented) + # NB : Checking for mutation is necessary because we compare + # constant values + if ( + not self.is_python_constant() + or not other.is_python_constant() + or tx.output.side_effects.has_pending_mutation(self) + or tx.output.side_effects.has_pending_mutation(other) + ): + unimplemented_v2( + gb_type="Builtin `operator.*` comparison with constant `self` failed", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Failed to compare {self} with {other}, " + + f"because {other} is not a Python constant or its mutation check fails.", + hints=[], + ) + + try: + return variables.ConstantVariable.create( + cmp_name_to_op_mapping[name]( + self.as_python_constant(), other.as_python_constant() + ) + ) + except Exception as e: + raise_observed_exception( + type(e), + tx, + args=[list(map(variables.ConstantVariable.create, e.args))], + ) + hints = [ + f"Avoid calling `{self.python_type_name()}.{name}` in your code.", + "Please report an issue to PyTorch.", + ] + # additional hint for method calls on improperly constructed iterators + if isinstance(self, variables.UserDefinedObjectVariable) and name in ( + "__iter__", + "__next__", + ): + if isinstance(self.value, (KeysView, ItemsView, ValuesView)): + hints.append( + "Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) " + "to the compiled region, instead of passing it as an input to the compiled region." + ) + hints.append( + "Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) " + "passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). " + "This can happen unintentionally if a previous graph break happens with a builtin iterator " + "in the local scope." 
+ ) + unimplemented_v2( + gb_type="Unsupported method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", + hints=hints, + ) + + def set_name_hint(self, name): + pass + + def realize(self) -> "VariableTracker": + """Used by LazyVariableTracker to build the real VariableTracker""" + return self + + def unwrap(self) -> "VariableTracker": + """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" + return self + + def is_realized(self): + """Used by LazyVariableTracker to indicate an unrealized node""" + return True + + def next_variable(self, tx): + unimplemented_v2( + gb_type="Unsupported next() call", + context=f"next({self})", + explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.", + hints=[*graph_break_hints.USER_ERROR], + ) + + def is_strict_mode(self, tx): + return tx.strict_checks_fn and tx.strict_checks_fn(self) + + def is_mutable(self): + """Whether Dynamo allows mutation on this variable.""" + return not self.is_immutable() + + def is_immutable(self): + """Whether Dynamo bans mutation on this variable.""" + return self.mutation_type is None + + @staticmethod + def build( + tx: "InstructionTranslatorBase", + value: Any, + source: Optional[Source] = None, + ) -> Any: + """Create a new VariableTracker from a value and optional Source""" + if source is None: + return builder.SourcelessBuilder.create(tx, value) + else: + return variables.LazyVariableTracker.create(value, source) + + def __init__( + self, + *, + source: Source = None, + mutation_type: MutationType = None, + ) -> None: + super().__init__() + self.source = source + self.mutation_type = mutation_type + + # NOTE sometimes mutation_type is set afterwards for implementation + # convenience, we don't validate those cases at the moment. + if mutation_type is not None: + if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)): + # If this fails, it's either + # 1. one mistakenly passed in a source + # 2. `mutation_type` is incorrect + assert source is None + else: + assert isinstance( + mutation_type, (ValueMutationExisting, AttributeMutationExisting) + ) + # If this fails, it's either + # 1. one forgot to pass in a source + # 2. `mutation_type` is incorrect + assert source is not None + + +def typestr(*objs): + if len(objs) == 1: + (obj,) = objs + if isinstance(obj, VariableTracker): + return str(obj) + else: + return type(obj).__name__ + else: + return " ".join(map(typestr, objs)) + + +from . import builder diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/builder.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..af4d04dcabdd15eaafe25c4a8eb8e4207b809107 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/builder.py @@ -0,0 +1,3633 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for building variable trackers in Dynamo. +Variable trackers are used to convert Python values into symbolic representations +that can be traced and transformed during graph capture. + +The key classes are: + +- VariableBuilder: Handles source-tracked objects that need guards and proper + reconstruction in the output graph. Used for inputs, module attributes, etc. + +- SourcelessBuilder: Handles ephemeral objects created during tracing that don't + need source tracking or guards. 
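Aside, not part of the diffed file: VariableTracker.build above and the builder.py docstring that follows describe the same split. Values reached from user program state carry a Source and get guarded, reconstructable wrappers; ephemeral values made during tracing are built sourcelessly. A simplified stand-in follows (Wrapped and build are illustrative names, not the real builders):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Wrapped:
    value: Any
    source: Optional[str]
    guarded: bool

def build(value, source=None):
    if source is None:
        # SourcelessBuilder path: ephemeral value, no guards, no reconstruction
        return Wrapped(value, None, guarded=False)
    # VariableBuilder path: source-tracked input, guarded and reconstructable
    return Wrapped(value, source, guarded=True)

assert build([1, 2]).guarded is False              # temp list created mid-trace
assert build([1, 2], source="L['xs']").guarded     # graph input from frame locals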
Used for temporary lists, intermediate values, etc. + +Variable trackers enable Dynamo to track the flow of values through the program, +maintain guards for dynamic properties, and reconstruct values in the output graph. +The builders in this module handle converting Python values into appropriate +VariableTracker instances based on their type and usage context. +""" + +import abc +import collections +import contextlib +import copy +import dataclasses +import enum +import functools +import inspect +import itertools +import logging +import math +import operator +import random +import re +import sys +import traceback +import types +import warnings +import weakref +from collections.abc import MutableMapping +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch import SymInt +from torch._dynamo.utils import ( + get_metrics_context, + is_int_specialization_case, + is_torch_sym, + set_feature_use, +) +from torch._guards import TracingContext +from torch._higher_order_ops.torchbind import call_torchbind +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode +from torch._subclasses.meta_utils import is_sparse_any, safe_grad +from torch._utils_internal import justknobs_check +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental._dynamism import normalize_source_name +from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + _nested_int_aware_sort, + DimDynamic, + RelaxedUnspecConstraint, + StatefulSymbolicContext, + SubclassSymbolicContext, + SymbolicContext, + SymIntSymbolicContext, + TrackedFake, +) +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils.weak import TensorWeakRef + +from .. 
import config, graph_break_hints, mutation_guard, replay_record, trace_rules +from ..device_interface import get_registered_device_interfaces +from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2 +from ..guards import GuardBuilder, install_guard, make_dupe_guard +from ..pgo import ( + auto_dynamic, + auto_unset, + FrameStateSizeEntry, + InferStride, + process_automatic_dynamic, +) +from ..side_effects import SideEffects +from ..source import ( + AttrProxySource, + AttrSource, + CallMethodItemSource, + ChainedSource, + ConstDictKeySource, + ConvertIntSource, + DictGetItemSource, + DictSubclassGetItemSource, + FloatTensorSource, + GetItemSource, + GradSource, + is_constant_source, + is_from_global_source, + is_from_nonlocal_source, + is_from_optimizer_source, + is_from_unspecialized_nn_module_source, + ListGetItemSource, + LocalSource, + NumpyTensorSource, + OptimizerSource, + RandomValueSource, + Source, + SubclassAttrListSource, + TupleIteratorGetItemSource, + UnspecializedBuiltinNNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + _extract_tensor_dict, + build_checkpoint_variable, + build_invoke_subgraph_variable, + clone_input, + common_constant_types, + dict_keys, + get_fake_value, + get_items_from_dict, + get_locals_to_steal, + get_static_address_type, + is_frozen_dataclass, + is_function_or_wrapper, + is_invoke_subgraph, + is_lru_cache_wrapped_function, + is_namedtuple, + is_parameter_freezing, + is_typing, + is_utils_checkpoint, + is_wrapper_or_member_descriptor, + istype, + namedtuple_fields, + odict_values, + proxy_args_kwargs, + range_iterator, + set_example_value, + tensor_always_has_static_shape, + tuple_iterator, + tuple_iterator_getitem, + tuple_iterator_len, + unwrap_with_attr_name_if_wrapper, + wrap_fake_exception, +) +from .base import ( + AttributeMutationNew, + typestr, + ValueMutationExisting, + ValueMutationNew, + VariableTracker, + VariableTrackerMeta, +) +from .constant import ConstantVariable, EnumVariable +from .ctx_manager import ( + AutocastModeVariable, + DynamoConfigPatchVariable, + EventVariable, + NullContextVariable, + PreserveVersionContextVariable, + StreamContextVariable, + StreamVariable, +) +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictKeySetVariable, + FrozensetVariable, + MappingProxyVariable, + SetVariable, +) +from .distributed import ( + DeviceMeshVariable, + PlacementClassVariable, + PlacementVariable, + ProcessGroupVariable, + WorldMetaClassVariable, +) +from .functions import ( + BuiltinMethodVariable, + CollectionsNamedTupleFunction, + CollectiveFunctionRewriteVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, + FunctoolsPartialVariable, + FunctoolsWrapsVariable, + SysFunctionVariable, + TracebackVariable, + TritonKernelVariable, + UserFunctionVariable, + UserMethodVariable, + WrapperUserFunctionVariable, +) +from .higher_order_ops import TorchHigherOrderOperatorVariable +from .iter import ItertoolsVariable +from .lazy import LazyVariableTracker +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + NamedTupleVariable, + RangeVariable, + SizeVariable, + SliceVariable, + TupleIteratorVariable, + TupleVariable, +) +from .misc import ( + AutogradEngineVariable, + AutogradFunctionContextVariable, + AutogradFunctionVariable, + ComptimeVariable, + DebuggingVariable, + DelayGraphBreakVariable, + GetAttrVariable, + GetSetDescriptorVariable, + LambdaVariable, + LoggingLoggerVariable, + 
MethodWrapperVariable, + NumpyDTypeVariable, + NumpyTypeInfoVariable, + NumpyVariable, + PythonModuleVariable, + RandomClassVariable, + RandomVariable, + RegexPatternVariable, + SavedTensorBox, + TorchVersionVariable, + TypingVariable, + WeakRefVariable, +) +from .nn_module import ( + FSDPManagedNNModuleVariable, + UnspecializedBuiltinNNModuleVariable, + UnspecializedNNModuleVariable, +) +from .optimizer import OptimizerVariable +from .script_object import TorchScriptObjectVariable +from .sdpa import SDPAParamsVariable +from .tensor import ( + NumpyNdarrayVariable, + supported_const_comparison_op_values, + SymNodeVariable, + TensorSubclassVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .torch import ( + DispatchKeySetVariable, + FuncTorchInterpreterVariable, + TorchCtxManagerClassVariable, + TorchInGraphFunctionVariable, +) +from .torch_function import ( + TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, + TorchFunctionModeVariable, +) +from .user_defined import ( + FrozenDataClassVariable, + IntWrapperVariable, + KeyedJaggedTensorVariable, + MutableMappingVariable, + SourcelessGraphModuleVariable, + UserDefinedClassVariable, + UserDefinedDictVariable, + UserDefinedExceptionClassVariable, + UserDefinedListVariable, + UserDefinedObjectVariable, + UserDefinedTupleVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +DimList = list + + +def safe_has_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return hasattr(t, "grad") + + +class _missing: + pass + + +@dataclasses.dataclass +class GraphArg: + source: Source + # TODO: storing a SymInt here but not a FakeTensor is a pretty strange + # thing to do. Probably should have example (which stores an int) and + # fake_example + _example: Union[TensorWeakRef, torch.SymInt] + # When True, this indicates that this GraphArg is a Python quantity (e.g., + # a float or int) which we pass to the FX graph as a Tensor. This + # controls how we codegen calls into the Dynamo graph: we will call + # torch.as_tensor on the quantity before passing it in. + # + # Note that we typically do not pass dynamic integers as tensors, because + # they will most frequently just be used for size computation. But this + # is a policy decision that we can change our mind on; in particular, when + # an int comes from a random number generator (e.g., random.randint), we + # DO pass it as a tensor. + # + # It's also worth noting that our current tracing rules for + # pass_arg_as_tensor as subtly broken: we just pun the variable as a + # 0d scalar Tensor and pray that the semantics are the same. Which they + # often are, but not necessarily. ezyang(May 2024) plans to fix this + # soon. + pass_arg_as_tensor: bool + fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] + # UnspecializedPythonVariable often masquerades as a tensor. + # We MUST NOT generate shape guard code + # that actually tries to access tensor properties on these values. + # is_tensor lets us tell if this graph arg actually is a tensor + # or not. + is_tensor: bool = True + # Sometimes, the Tensor we pass to example is freshly allocated (smh). + # Then we cannot only keep a weak reference to it. 
This lets you + # stash a strong reference too. + example_strong_ref: Optional[torch.Tensor] = None + + @property + def example(self): + if isinstance(self._example, TensorWeakRef): + r = self._example() + assert r is not None + return r + else: + return self._example + + def __post_init__(self): + if isinstance(self._example, torch.Tensor): + self._example = TensorWeakRef(self._example) + assert is_fake(self.fake_tensor) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.source) + + def erase(self): + self._example = None + self.example_strong_ref = None + + def __eq__(self, other): + return self.source.name() == other.source.name() + + +class BackwardStateGraphArg(GraphArg): + def __init__(self) -> None: + super().__init__( + source=None, + _example=BackwardState(), + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + + def reconstruct(self, codegen: "PyCodegen"): + assert codegen.tx.output.backward_state_var + codegen.add_push_null( + lambda: codegen.load_import_from(BackwardState.__module__, "BackwardState") + ) + codegen.call_function(0, False) + codegen.dup_top() + codegen.store(codegen.tx.output.backward_state_var) + + +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: frozenset[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: set[int] = set() + +# Capture fn pointer at import time +# This is to guard against trying to mark the iterated tensors +# as static in case user overrides fn ptr +og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers +og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters + + +class VariableBuilder: + """Wrap a python value in a VariableTracker() instance""" + + def __init__( + self, + tx, + source: Source, + ) -> None: + assert source is not None, ( + "Consider SourcelessBuilder for ephemeral objects, usually objects created locally." 
+ ) + assert TracingContext.try_get() is not None, "Expected active TracingContext" + super().__init__() + self.tx = tx + self.source = source + self.name = source.name() + + def __call__(self, value): + if value in self.tx.output.side_effects: + side_effect_result = self.tx.output.side_effects[value] + dup_guard = make_dupe_guard(self.source, side_effect_result.source) + if dup_guard: + self.install_guards(dup_guard) + return side_effect_result + + cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source) + if cached_vt: + return cached_vt + + vt = self._wrap(value) + + if vt.source is None: + vt.source = self.source + + if ( + self._can_lift_attrs_to_inputs(vt) + and value not in self.tx.output.side_effects + and not is_wrapper_or_member_descriptor(value) + ): + vt = self.tx.output.side_effects.track_object_existing(value, vt) + + self.tx.output.variable_tracker_cache.add(value, self.source, vt) + return vt + + def _can_lift_attrs_to_inputs(self, vt): + return type(vt) in { + TensorVariable, + TensorWithTFOverrideVariable, + UserDefinedObjectVariable, + NumpyNdarrayVariable, + } + + def get_source(self): + return self.source + + def install_guards(self, *guards): + source = self.get_source() + try: + tmp = [source.make_guard(guard) for guard in guards] + except NotImplementedError: + return None + install_guard(*tmp, skip=1) + return {} + + @classmethod + def _type_dispatch(cls): + return cls._type_dispatch_impl(config.trace_numpy) + + @classmethod + @functools.cache + def _type_dispatch_impl(cls, trace_numpy): + # NB: Careful not to close over self to avoid ref cycle from lru_cache + entries = [ + ( + ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + cls.wrap_tensor, + ), + ( + (tuple, list, odict_values, collections.deque, torch.Size), + cls.wrap_listlike, + ), + (tuple_iterator, cls.wrap_tuple_iterator), + (range_iterator, cls.wrap_range_iterator), + ((slice, range), cls.wrap_slice_range), + (tuple(common_constant_types), cls.wrap_literal), + (re.Pattern, cls.wrap_regex_pattern), + (weakref.ReferenceType, cls.wrap_weakref), + (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle), + (torch.jit.ScriptFunction, cls.wrap_jit_function), + (types.MappingProxyType, cls.wrap_mapping_proxy), + ] + + if trace_numpy and np: + entries.append((np.ndarray, cls.wrap_numpy_ndarray)) + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, tuple) else (ts,): + assert t not in result + result[t] = fn + + return result + + def wrap_regex_pattern(self, value: re.Pattern): + # TODO(jansel): something like a REPR_MATCH might be more robust here + self.install_guards(GuardBuilder.ID_MATCH) + return RegexPatternVariable(value) + + def wrap_weakref(self, value: weakref.ReferenceType): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WeakRefVariable.build(self.tx, value, source=self.source) + + def wrap_removable_handle(self, value): + # This means that the removable handle was created in some other frame. + # Our current infra requires the hook to be registered and removed in + # the same frame. So graph break. + # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks + unimplemented_v2( + gb_type="Attempted to represent unregistered RemovableHandle", + context="", + explanation="Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, " + "which is not supported. 
This happens because the RemovableHandle was created in another frame.", + hints=[], + ) + + def wrap_jit_function(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + + def wrap_mapping_proxy(self, value): + self.install_guards(GuardBuilder.TYPE_MATCH) + # This might be suboptimal compared to dict guards. But mappingproxy is + # not very common, so its ok to guard on all keys. + self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK) + all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) + + if not all_const: + unimplemented_v2( + gb_type="non-const keys in mappingproxy", + context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", + explanation="Dynamo expects mappingproxy keys to be constants.", + hints=[ + "Ensure your mappingproxy keys are constants (e.g. int, float, strings)", + ], + ) + + def build_key_value(k, v): + key = ConstantVariable.create(k) + source_key = k + + source_value = GetItemSource(self.get_source(), source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + items = dict(build_key_value(k, v) for k, v in value.items()) + + # Create a dict_vt to be used in the mapping proxy variable + dict_vt = ConstDictVariable(items, source=None) + result = MappingProxyVariable(dict_vt, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + @classmethod + @functools.cache + def _id_dispatch( + cls, + ) -> dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: + from ..comptime import comptime + + entries = [ + (comptime, lambda self, value: ComptimeVariable()), + ( + dataclasses.fields, + lambda self, value: LambdaVariable( + _dataclasses_fields_lambda, + source=self.source, + **self.install_guards(GuardBuilder.FUNCTION_MATCH), + ), + ), + (torch.__version__, lambda self, value: TorchVersionVariable()), + ] + + result = {} + for ts, fn in entries: + for t in ts if isinstance(ts, (tuple, list)) else (ts,): + assert t not in result + result[id(t)] = fn + + return result + + def _wrap(self, value): + # import here to avoid circular dependencies + from torch.utils._triton import ( + has_triton, + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, + ) + + from ..decorators import DynamoConfigPatchProxy + + if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + else: + + class JITFunction: + pass + + class Autotuner: + pass + + # default implementations, in case we don't have triton (or the wrong triton version) + def create_1d_tma_descriptor(): + pass + + def create_2d_tma_descriptor(): + pass + + class TensorDescriptor: + @staticmethod + def from_tensor(): + pass + + if has_triton_experimental_host_tma(): + from triton.tools.experimental_descriptor import ( # noqa: F811 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 + + # Handle exact type() match + type_dispatch = self._type_dispatch().get(type(value)) + if type_dispatch is not None: + return type_dispatch(self, value) + + # Handle exact id() match + id_dispatch = self._id_dispatch().get(id(value)) + if id_dispatch is not None: + return id_dispatch(self, value) + + # Everything else (NB: order matters!) 
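Aside, not part of the diffed file: _wrap below resolves a handler in three stages, first an exact type() table (_type_dispatch), then an exact id() table (_id_dispatch), and only then the long ordered chain of isinstance checks. A toy model of that lookup order, with placeholder handlers:

def wrap_list(v):
    return ("list", v)

SPECIAL = object()

def wrap_special(v):
    return ("special", v)

TYPE_DISPATCH = {list: wrap_list}          # exact type match, subclasses excluded
ID_DISPATCH = {id(SPECIAL): wrap_special}  # exact object identity match

def wrap(value):
    handler = TYPE_DISPATCH.get(type(value))
    if handler is not None:
        return handler(value)
    handler = ID_DISPATCH.get(id(value))
    if handler is not None:
        return handler(value)
    return ("fallback", value)             # ordered isinstance checks live here

assert wrap([1])[0] == "list"
assert wrap(SPECIAL)[0] == "special"
assert wrap(3.0)[0] == "fallback"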
+ if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses + ): + if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: + # This case it's either tensor or subclass with default + # torch_dispatch (they might override torch_function or not), + # and we can always trace into them. + return self.wrap_tensor(value) + elif is_traceable_wrapper_subclass(value): + # For non-default torch_dispatch, we have more requirements. + return self.wrap_tensor(value) + + if is_namedtuple(value): + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + output = [ + LazyVariableTracker.create( + getattr(value, name), + source=AttrSource(self.source, name), + ) + for name in namedtuple_fields(type(value)) + ] + result = NamedTupleVariable( + output, tuple_cls=type(value), source=self.source + ) + return result + elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): + self.install_guards(GuardBuilder.TYPE_MATCH) + all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) + + # For all_const, we don't have to guard on anything yet. We guard on + # keys lazily by adding a dict_getitem entry for each accessed key. + # For cases where we need to guard on all keys, we lazily put guards + # during the dict call_method (check dicts.py) + if not all_const: + # Guard on the key order + # This is not ideal, i.e., there is no need to guard on the key + # order. But we guard on the key order because of the complexity + # + # 1) For non-constant objects, we can't save the key in the + # guard context because it can be memory heavy. We can add + # weakrefs but this complicates the accesses. + # + # 2) For non-constant objects, we also have to guard on the keys + # (like TENSOR_MATCH on tensor). We might also have guards on + # the attributes of the keys (like tensor.grad). To make this + # work in tree structure is complicated. + # + # So, instead we guard on the key order. While guarding on key + # order, we just save the indices and use it to access keys and + # values. Indices are cheap to save. + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + if all_const: + key = ConstantVariable.create(k) + source_key = k + else: + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + source_value = DictGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. 
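Aside, not part of the diffed file: the dict-wrapping comments above say that per-key guards are added lazily, only when a key is actually read. A self-contained sketch of that idea, with Lazy standing in for LazyVariableTracker and a plain list standing in for the installed guards:

class Lazy:
    def __init__(self, build):
        self._build = build
        self._realized = False
        self._value = None

    def realize(self):
        if not self._realized:
            self._value = self._build()
            self._realized = True
        return self._value

guards = []

def wrap_dict(d):
    def make(key, val):
        def build():
            guards.append(f"dict_getitem({key!r})")  # guard emitted on first access
            return val
        return Lazy(build)
    return {k: make(k, v) for k, v in d.items()}

wrapped = wrap_dict({"a": 1, "b": 2})
assert wrapped["a"].realize() == 1
assert guards == ["dict_getitem('a')"]  # the untouched key "b" produced no guard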
+ result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + if istype(value, collections.defaultdict): + factory_source = AttrSource(self.source, "default_factory") + result = DefaultDictVariable( + result, + type(value), + default_factory=VariableBuilder(self.tx, factory_source)( + value.default_factory + ), + source=self.source, + ) + else: + result = ConstDictVariable( + result, user_cls=type(value), source=self.source + ) + + return self.tx.output.side_effects.track_mutable(value, result) + elif isinstance(value, torch.nn.Module): + return self.wrap_module(value) + elif ConstantVariable.is_literal(value): # non-atomic literals + return self.wrap_literal(value) + elif isinstance(value, torch.overrides.TorchFunctionMode): + var = TorchFunctionModeVariable(value, source=self.source) + self.tx.output.side_effects.track_object_existing(value, var) + return var + elif istype(value, frozenset) and all( + ( + # For DBR quantization, we could get a frozenset of torch funcs. + (type(x) is types.BuiltinMethodType and x.__module__ == "torch") + or + # Another commonly used frozenset of types. + x in torch.utils._pytree.BUILTIN_TYPES + ) + for x in value + ): + # For the limited cases of frozenset here, we know the items won't + # change across runs, so we can safely create sourceless VTs for + # them and only guard on the frozenset id. + # TODO support source for sets and remove the special logics here. + items = [SourcelessBuilder.create(self.tx, v) for v in value] + self.install_guards(GuardBuilder.ID_MATCH) + return FrozensetVariable(items, source=self.source) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + self.install_guards(GuardBuilder.ID_MATCH) + return EnumVariable(value=value, source=self.source) + elif DebuggingVariable.is_reorderable_logging_function(value): + # Put this above builtin_callable so that print() can be handled + # along with other builtin debugging functions + self.install_guards(GuardBuilder.BUILTIN_MATCH) + return DebuggingVariable(value, source=self.source) + elif isinstance(value, logging.Logger): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return LoggingLoggerVariable(value, source=self.source) + elif is_utils_checkpoint(value): + return build_checkpoint_variable(source=self.source) + elif is_invoke_subgraph(value): + return build_invoke_subgraph_variable(source=self.source) + elif isinstance(value, functools.partial): + func_src = AttrSource(self.get_source(), "func") + func_obj = VariableBuilder(self.tx, func_src)(value.func) + + args = [] + args_source = AttrSource(self.get_source(), "args") + for i, arg in enumerate(value.args): + args.append( + VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) + ) + + keywords = {} + keywords_source = AttrSource(self.get_source(), "keywords") + for k, v in value.keywords.items(): + if not ConstantVariable.is_literal(k): + unimplemented_v2( + gb_type="functools.partial() with non-literal keyword", + context=f"non-literal keyword: {k}", + explanation="functools.partial() expects literal/string keywords", + hints=[*graph_break_hints.USER_ERROR], + ) + keywords[k] = VariableBuilder( + self.tx, DictGetItemSource(keywords_source, k) + )(v) + + install_guard( + self.get_source().make_guard(GuardBuilder.TYPE_MATCH), + keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH), + args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), + ) + return FunctoolsPartialVariable(func_obj, args, keywords) + elif is_typing(value): 
+ # typing.List, typing.Mapping, etc. + self.install_guards(GuardBuilder.ID_MATCH) + return TypingVariable( + value, + source=self.source, + ) + elif np is not None and isinstance(value, np.generic): + # numpy array scalars: convert to 0D arrays + return self.wrap_numpy_ndarray(np.asarray(value)) + elif trace_rules.is_numpy(value): + assert np + self.install_guards( + GuardBuilder.FUNCTION_MATCH + if callable(value) + else GuardBuilder.TYPE_MATCH + ) + return NumpyVariable(value, source=self.source) + elif trace_rules.is_numpy_dtype(value): + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyDTypeVariable(value, source=self.source) + elif trace_rules.is_numpy_type_info(value): + if isinstance(value, np.iinfo): + self.install_guards(GuardBuilder.TYPE_MATCH) + dt_source = AttrSource(self.source, "dtype") + install_guard(dt_source.make_guard(GuardBuilder.ID_MATCH)) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return NumpyTypeInfoVariable(value, source=self.source) + # NB: These can't be put in type_dispatch, they have to run later + elif CollectiveFunctionRewriteVariable.can_rewrite(value): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return CollectiveFunctionRewriteVariable.create( + self.tx, + value, + source=self.source, + ) + elif istype(value, torch.autograd.function.FunctionMeta): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return AutogradFunctionVariable( + value, + source=self.source, + ) + elif isinstance(value, torch.autograd.function.FunctionCtx): + actual_saved_tensors = None + try: + actual_saved_tensors = value.saved_tensors + except RuntimeError: + pass + + saved_tensors = [] + guards = [self.source.make_guard(GuardBuilder.TYPE_MATCH)] + if isinstance(actual_saved_tensors, tuple): + saved_tensors_source = AttrSource(self.source, "saved_tensors") + guards.append( + saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + for i, v in enumerate(actual_saved_tensors): + saved_tensors.append( + VariableBuilder( + self.tx, GetItemSource(saved_tensors_source, i) + )(v) + ) + install_guard(*guards) + + return self.tx.output.side_effects.track_object_existing( + value, + AutogradFunctionContextVariable( + value, + source=self.source, + saved_tensors=SavedTensorBox(saved_tensors), + ), + ) + elif ( + isinstance(value, types.MethodType) + and istype( + getattr(value, "__self__", None), torch.autograd.function.FunctionMeta + ) + and getattr(value, "__name__", "") == "apply" + and value == getattr(value.__self__, "apply", None) + ): + # handle aliased autograd function `apply` calls + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return GetAttrVariable( + AutogradFunctionVariable( + value.__self__, source=AttrSource(self.source, member="__self__") + ), + "apply", + ) + elif isinstance(value, torch._C._ImperativeEngine): + self.install_guards(GuardBuilder.ID_MATCH) + return AutogradEngineVariable(value, source=self.source) + elif ( + value + is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub + ): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return LambdaVariable( + lambda: UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, + ).call_function( + self.tx, + (self.tx.output.side_effects.get_ca_final_callbacks_var(),), + {}, + ) + ) + elif isinstance(value, DynamoConfigPatchProxy): + return DynamoConfigPatchVariable(value.changes) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + 
self.tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value).create_with_source( + value, source=self.source + ) + elif np and isinstance(value, np.number): + return self.wrap_unspecialized_primitive(value) + elif isinstance(value, HigherOrderOperator): + if value is torch._higher_order_ops.invoke_subgraph: + unimplemented_v2( + gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", + context="", + explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", + hints=[], + ) + self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH) + return TorchHigherOrderOperatorVariable.make(value, source=self.source) + elif isinstance(value, torch.cuda.StreamContext): + self.install_guards(GuardBuilder.ID_MATCH) + stream_source = AttrSource(self.source, "stream") + stream_var = VariableBuilder(self.tx, stream_source)(value.stream) + return StreamContextVariable.create(self.tx, stream_var) + elif isinstance(value, torch.Stream): + self.install_guards(GuardBuilder.ID_MATCH) + stream_proxy = self.tx.output.create_proxy( + "call_function", + type(value), + (), + { + "stream_id": value.stream_id, + "device_index": value.device_index, + "device_type": value.device_type, + }, + ) + set_example_value(stream_proxy.node, value) + return StreamVariable( + stream_proxy, + value, + value.device, + source=self.source, + ) + elif isinstance(value, (torch._C._SDPAParams)): + self.install_guards(GuardBuilder.TYPE_MATCH) + return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter): + self.install_guards(GuardBuilder.ID_MATCH) + return FuncTorchInterpreterVariable(value) + elif isinstance(value, torch.Event): + self.install_guards(GuardBuilder.ID_MATCH) + torch._dynamo.utils.store_user_object_weakref(value) + event_proxy = self.tx.output.create_proxy( + "call_function", + torch._dynamo.utils.get_user_object_from_id, + (id(value),), + {}, + ) + set_example_value(event_proxy.node, value) + return EventVariable( + event_proxy, + value, + source=self.source, + ) + elif ( + istype(value, contextlib.nullcontext) + and inspect.getattr_static(value, "enter_result", None) is None + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return NullContextVariable(source=self.source) + elif KeyedJaggedTensorVariable.is_matching_object(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = KeyedJaggedTensorVariable(value, source=self.source) + # TODO: this doing it manually is bad + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, torch.optim.Optimizer): + self.install_guards(GuardBuilder.ID_MATCH) + self.source = OptimizerSource(self.source) + return OptimizerVariable(value, source=self.source) + elif isinstance(value, torch.DispatchKeySet): + self.install_guards(GuardBuilder.DISPATCH_KEY_SET_MATCH) + return DispatchKeySetVariable(value) + elif WorldMetaClassVariable.is_group_member_type(value): + return WorldMetaClassVariable(value, source=self.source) + elif ProcessGroupVariable.is_process_group(value): + self.install_guards(GuardBuilder.ID_MATCH) + return ProcessGroupVariable(value, source=self.source) + elif DeviceMeshVariable.is_device_mesh(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return DeviceMeshVariable(value, source=self.source) + elif PlacementClassVariable.is_placement_type(value): + # TODO: see if we 
need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.ID_MATCH) + return PlacementClassVariable(value, source=self.source) + elif PlacementVariable.is_placement(value): + # TODO: see if we need to add custom guard instead of a simple ID_MATCH + self.install_guards(GuardBuilder.EQUALS_MATCH) + return PlacementVariable( + value, + source=self.source, + ) + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return ItertoolsVariable(value, source=self.source) + elif is_torch_sym(value): + # Note: this doesn't handle nested symints. + # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo. + + # Concretely, + # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). + # so that guards on the SymInts can be effectively applied on the original SymBool in user program. + # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program + # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly. + source = ( + self.source + if isinstance(value, torch.SymInt) + else ConvertIntSource(self.source) + ) + if value.node.has_hint(): + new_symint = ( + self.tx.output.shape_env.create_unspecified_symint_and_symbol( + int(value.node.hint), + source, + dynamic_dim=DimDynamic.DYNAMIC, + ) + ) + else: + if isinstance(value, torch.SymBool): + # We need to create an unbacked symint to replace the unbacked symbool. + new_symint = self.tx.output.shape_env.create_unbacked_symint() + else: + # TODO (yidi): we need to figure out a way to propagate the guards + # we accumulated when tracing the subggraph to outer shape_env. For normal symints, + # this is automatically done by evaluating the guards once but this + # will cause data-dependent error when we evaluate the outer unbacked symints. + # The test case that triggers this graph break is test_cond_unbacked_symint_closure + unimplemented_v2( + gb_type="Attempted to wrap unbacked SymInt", + context="", + explanation="Unbacked SymInt input is not supported yet.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + sym_node_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(new_symint), + new_symint, + source=source, + ) + + sym_node_proxy.node.meta["grapharg"] = GraphArg( + source, + new_symint, + False, + None, + is_tensor=False, + example_strong_ref=new_symint, + ) + # We bind the new_symint to graph input. + sym_expr = new_symint.node.expr + assert isinstance(sym_expr, sympy.Symbol), ( + f"{sym_expr} is not a basic Symbol." 
+ ) + self.tx.output.tracked_fakes.append(TrackedFake(new_symint, source, None)) + + tracing_symint = ( + new_symint if isinstance(value, torch.SymInt) else new_symint == 1 + ) # cast it back to symbool for tracing + return SymNodeVariable(sym_node_proxy, tracing_symint) + + elif isinstance(value, (JITFunction, Autotuner)): + self.install_guards(GuardBuilder.ID_MATCH) + return TritonKernelVariable( + value, + None, # No kernel idx provided + None, # No grid provided + source=self.source, + ) + elif value is create_1d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=1) + elif value is create_2d_tma_descriptor: + return CreateTMADescriptorExperimentalVariable(rank=2) + elif value is TensorDescriptor.from_tensor: + return CreateTMADescriptorStableVariable() + elif isinstance(value, torch.amp.autocast_mode.autocast): + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) + elif TorchCtxManagerClassVariable.is_matching_cls(value): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return TorchCtxManagerClassVariable(value, source=self.source) + elif inspect.getattr_static(value, "__script_if_tracing_wrapper", False): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "__original_fn", source=self.source + ) + elif is_lru_cache_wrapped_function(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif value is traceback.clear_frames: + return TracebackVariable(source=self.source) + elif value is sys.exc_info or ( + sys.version_info >= (3, 11) and value is sys.exception + ): + return SysFunctionVariable(value, source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) + elif value is functools.wraps: + self.install_guards(GuardBuilder.ID_MATCH) + return FunctoolsWrapsVariable(value, source=self.source) + elif value is collections.namedtuple: + self.install_guards(GuardBuilder.ID_MATCH) + return CollectionsNamedTupleFunction(value, source=self.source) + elif isinstance( + value, types.BuiltinMethodType + ) and BuiltinMethodVariable.is_supported_builtin_method(value): + self.install_guards(GuardBuilder.ID_MATCH) + return BuiltinMethodVariable(value, source=self.source) + elif is_function_or_wrapper(value): + value, attr_name = unwrap_with_attr_name_if_wrapper(value) + # For these wrappers, Dynamo points to the wrapped function, + # so source needs to be updated as well. + if attr_name is not None: + self.source = AttrSource(self.source, attr_name) + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + elif value is random.Random: + self.install_guards(GuardBuilder.ID_MATCH) + return RandomClassVariable(source=self.source) + elif istype(value, random.Random) and RandomVariable.is_supported_random_obj( + value + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = RandomVariable(value, source=self.source) + self.tx.output.side_effects.track_mutable(value, result) + return result + # Don't use istype, since some python modules are not subclasses of types.ModuleType directly. 
+ # E.g, type(torch.ops) -> , + # type(torch.backends.cudnn) -> + elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + result = PythonModuleVariable( + value, + source=self.source, + ) + self.tx.output.side_effects.track_object_existing(value, result) + return result + elif isinstance(value, types.MethodType) and isinstance( + value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) + ): + # don't let MethodTypes fall through to UserDefinedObject, + # which doesn't support 'CALL_FUNCTION' + + # TODO(whc): Why do we limit this to methods on NNModules? + # I don't have a good reason for this, but it preserves the existing behavior + # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. + # I suspect we probably want to relax this check and dig deeper there. + + # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, + # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here + # and then `__func__` gets wrapped inside UserMethodVariable. + self_obj = VariableBuilder( + self.tx, source=AttrSource(self.source, "__self__") + )(value.__self__) + assert self_obj and isinstance(self_obj, VariableTracker), ( + "Failed to produce a valid self obj" + ) + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return UserMethodVariable( + value.__func__, + self_obj, + source=self.source, + ) + elif isinstance(value, types.GetSetDescriptorType): + # GetSet descriptors are C functions attached to an attribute lookup + # using PyGetSetDef. Python, on attribute lookup, can decide to + # create a new object on the fly, and therefore the `id` of the + # descriptors is not guaranteed to be same for different attribute + # accesses. Since these are unlikely to change during the program + # execution, we can skip guarding on them. + return GetSetDescriptorVariable(value) + elif isinstance(value, types.MethodWrapperType): + # Method-wrappers are written in C, and they are not guaranteed to + # return the same object on attribute lookup. Therefore, we cannot + # insert a FUNCTION_MATCH guard here. method-wrappers are very + # unlikely to change, so its ok to skip the guard here. + return MethodWrapperVariable(value) + elif issubclass(type(value), type) and issubclass(value, BaseException): + # match user defined exceptions + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedExceptionClassVariable(value) + elif issubclass(type(value), type): + if value in ( + torch.utils.hooks.BackwardHook, + torch.nn.Parameter, + torch.nn.Buffer, + ): + # TODO(jansel): combine this case with the one above + return trace_rules.lookup(value).create_with_source( + value, source=self.source + ) + if value is torch.autograd._unsafe_preserve_version_counter: + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). 
+ and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) + # This is a userdefined class, so install an ID_MATCH even if its a + # global variable. + self.install_guards(GuardBuilder.ID_MATCH) + return UserDefinedClassVariable( + value, + source=self.source, + ) + elif TorchScriptObjectVariable.is_matching_cls(type(value)): + from ..source import ( + FlattenScriptObjectSource, + ScriptObjectQualifiedNameSource, + ) + + if torch._library.fake_class_registry.tracing_with_real(value): + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, value + ) + return TorchScriptObjectVariable.create( + proxy, + value, + source=self.source, + ) + + # This exists to allow a smoother transition. + # The implications are: + # The script objects won't be tracked as proxies. + # Methods on these objects won't show up in the graph. + # The original script object might be mutated. + if not hasattr(value, "__obj_flatten__"): + return self.wrap_user_defined(value) + + # Install the guards on the fully qualified name of the script object + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))( + value._type().qualified_name() # type: ignore[attr-defined] + ) + ) + # Install the guards on the content of the script object by setting the source + # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents. + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))( + value.__obj_flatten__() + ) + ) + + fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( + self.tx.output.fake_mode, value + ) + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + fake_script_obj, + source=self.source, + ) + + # setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default + # setting example to be real value because these example values will be used + # as example_inputs for user compiler. + proxy.node.meta["grapharg"] = GraphArg( + self.source, value, False, None, False, fake_script_obj + ) + return TorchScriptObjectVariable.create( + proxy, + fake_script_obj, + source=self.source, + ) + elif ( + isinstance(value, (dict, collections.OrderedDict)) + and type(value).__new__ is dict.__new__ + ): + # Construct a dict_vt that will reside inside the UserDefinedDictVariable + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Guard on the key order + self.tx.output.guard_on_key_order.add(self.source) + + # We need all the keys to be hashable. We do this within the + # _HashableTracker class in dicts.py + def build_key_value(i, k, v): + base = self.get_source() + source_key = ConstDictKeySource(base, i) + key = LazyVariableTracker.create(k, source_key) + + source_value = DictSubclassGetItemSource(base, source_key) + res_value = LazyVariableTracker.create(v, source_value) + + return key, res_value + + # Ensure that we call dict.keys and not value.keys (which can call + # overridden keys method). 
In the C++ guards, we relied on + # PyDict_Next to traverse the dictionary, which uses the internal + # data structure and does not call the overridden keys method. + result = dict( + build_key_value(i, k, v) + for i, (k, v) in enumerate(get_items_from_dict(value)) + ) + + dict_vt = ConstDictVariable( + result, + user_cls=( + collections.OrderedDict + if isinstance(value, collections.OrderedDict) + else dict + ), + mutation_type=ValueMutationExisting(), + source=self.source, + ) + # Force this to reconstruct on mutation to keep the reconstruction + # bytecode simple + dict_vt.should_reconstruct_all = True + + result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, tuple): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying tuple data structure. + output = [ + LazyVariableTracker.create( + tuple.__getitem__(value, i), + source=GetItemSource(self.get_source(), i), + ) + for i in range(tuple.__len__(value)) + ] + + tuple_vt = TupleVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedTupleVariable( + value, tuple_vt=tuple_vt, source=self.source + ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, list): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying list data structure. + output = [ + LazyVariableTracker.create( + list.__getitem__(value, i), + source=ListGetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(value)) + ] + list_vt = ListVariable( + output, source=self.source, mutation_type=ValueMutationExisting() + ) + result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass(type(value), MutableMapping): + self.install_guards(GuardBuilder.TYPE_MATCH) + return MutableMappingVariable(value, source=self.source) + elif is_frozen_dataclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, dict_keys): + if all(ConstantVariable.is_literal(k) for k in value): + # If the dict_keys object is passed from outside the compile region, it must either be passed along with + # the corresponding dict object or treated as a set (when only the keys are passed into the compiled region). + # - If it is passed along with the dict, the dict object itself is already guarded. + # - If only the dict_keys object is passed, we add EQUALS_MATCH and SEQUENCE_LENGTH guards + # to ensure it remains unchanged across multiple runs. 
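+ # For example, roughly: if a compiled function receives {"a": 0, "b": 1}.keys(),
+ # the EQUALS_MATCH / SEQUENCE_LENGTH guards installed below trigger a
+ # recompile when a later call passes a keys view with different contents or
+ # a different number of keys (e.g. {"a": 0, "c": 1}.keys()).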
+ items = [SourcelessBuilder.create(self.tx, v) for v in value] + install_guard( + self.get_source().make_guard(GuardBuilder.SEQUENCE_LENGTH), + self.get_source().make_guard(GuardBuilder.EQUALS_MATCH), + ) + return DictKeySetVariable(items, source=self.source) + else: + unimplemented_v2( + gb_type="non-const keys in dict_keys", + context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", + explanation="Dynamo expects dict_keys keys to be constants.", + hints=[ + "Ensure your dict_keys keys are constants (e.g. int, float, strings)", + ], + ) + elif IntWrapperVariable.is_matching_object(value): + from torch.export.dynamic_shapes import _DimHintType + + if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC: + return self.wrap_symint(value.val) + elif value.dynamism.type == _DimHintType.DYNAMIC: + log.debug( + "%s marked %s via IntWrapper", + self.source.name(), + DimDynamic.DYNAMIC, + ) + return self.wrap_symint( + value.val, + dynamism=DimDynamic.DYNAMIC, + context=SymIntSymbolicContext( + constraint=RelaxedUnspecConstraint(warn_only=False) + ), + ) + elif value.dynamism.type == _DimHintType.AUTO: + log.debug( + "%s marked %s via IntWrapper", + self.source.name(), + DimDynamic.DYNAMIC, + ) + return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC) + else: + raise RuntimeError(f"Undefined dynamism {value.dynamism}") + else: + return self.wrap_user_defined(value) + + def wrap_user_defined(self, value: Any): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = UserDefinedObjectVariable(value, source=self.source) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + + def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): + for item in value: + if item is value: + unimplemented_v2( + gb_type="list elements are pointing to the list itself", + context="", + explanation="Dynamo does not support lists whose items reference to itself", + hints=["Avoid using self referential list"], + ) + + if config.specialize_int and type(value) is torch.Size: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + # One can index a tensor with a list/tuple. Therefore, we need to + # have a stricter match. + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # Tuples are immutable objects, so we should mark its items static. This + # avoids wrapping of tuple items as symints. This helps for nn module + # attributes like conv2d strides, dilations. + if ( + istype(value, tuple) + and all(ConstantVariable.is_literal(item) for item in value) + and self.source.guard_source().is_unspecialized_nn_module() + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TupleVariable([ConstantVariable.create(item) for item in value]) + + output = [ + LazyVariableTracker.create( + item, + source=GetItemSource(self.get_source(), i), + ) + for i, item in enumerate(value) + ] + + maybe_gm = self.tx.output.local_scope.get("self") + if isinstance( + self.source, LocalSource + ) and self.source.local_name in get_locals_to_steal(maybe_gm): + # The input tensor list to dynamo from compiled autograd may contain activations + # which are freed as they are used in inductor. 
Dynamo's default behavior is to + # lift all tensors to the graph inputs, but this will cause dynamo to hold an + # extra reference to the activation tensors and increase peak memory usage. + # To allow freeing ASAP, we keep the list as graph argument to the dynamo output + # graph, and unpack it locally. + # e.g. instead of `def forward(self, L_inputs_0_, L_inputs_1_, ...):`, we have + # `def forward(self, L_inputs_):` + source = self.source + assert isinstance(value, list) + tensor_list_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + value, + source=source, + ) + tensor_list_proxy.node.meta["steal_arg"] = True + + list_variable = wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=self.tx, + proxy=tensor_list_proxy, + example_value=value, + subclass_type=None, + source=source, + ) + + guards = [] + for i, tensor_variable in enumerate(list_variable.items): + source_i = GetItemSource(base=source, index=i, index_is_slice=False) + # access unpacked tensor from this list instead of from a lifted arg + self.tx.output.input_source_to_var[source_i] = tensor_variable + tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict( + value[i] + ) + + guard = functools.partial( + GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i]) + ) + guards.append(source_i.make_guard(guard)) + + install_guard(*guards, skip=1) + + grapharg = GraphArg( + source, + value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + ) + tensor_list_proxy.node.meta["grapharg"] = grapharg + + result = BaseListVariable.cls_for_instance(value)(output, source=self.source) + if istype(value, (list, collections.deque)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def wrap_tuple_iterator(self, value: tuple_iterator): + self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN) + output = [ + VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( + tuple_iterator_getitem(value, i) + ) + for i in range(tuple_iterator_len(value)) + ] + result = TupleIteratorVariable(output, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_range_iterator(self, value: range_iterator): + self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH) + # Get all the values from the range iterator; no need to install guards + # on items since `RANGE_ITERATOR_MATCH` guarantees the same items. 
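+ # For example, for iter(range(2, 10, 3)) the deepcopy below yields the
+ # remaining values (2, 5, 8 for a fresh iterator) without advancing the
+ # user's iterator, and RANGE_ITERATOR_MATCH re-checks on later calls that
+ # the iterator would still produce that same sequence, so per-item guards
+ # would be redundant.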
+ items = [ConstantVariable.create(v) for v in copy.deepcopy(value)] + result = ListIteratorVariable(items, source=self.source) + return self.tx.output.side_effects.track_mutable(value, result) + + def wrap_slice_range(self, value: Union[slice, range]): + items = [ + VariableBuilder(self.tx, AttrSource(self.get_source(), k))( + getattr(value, k) + ) + for k in ("start", "stop", "step") + ] + self.install_guards(GuardBuilder.TYPE_MATCH) + if isinstance(value, slice): + return SliceVariable(items, source=self.source) + else: + return RangeVariable(items, source=self.source) + + def mark_static_input(self, value: torch.Tensor, guard: bool): + from ..decorators import mark_static_address + + static_inputs_log.debug( + "Marking static input %s, id: %s)", self.source.name(), id(value) + ) + mark_static_address(value, guard=guard) + + # Check if we've seen this tensor before and update graph metadata if needed + # As long as this runs before AOT this is sound + if value in self.tx.output.side_effects: + var = self.tx.output.side_effects[value] + var.proxy.node.meta["tensor_dict"]["_dynamo_static_input_type"] = ( + value._dynamo_static_input_type + ) + + def wrap_module(self, value: torch.nn.Module): + from ..eval_frame import OptimizedModule + + if len(value.__dict__) == 0: + unimplemented_v2( + gb_type="Uninitialized nn.Module", + context=typestr(value), + explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", + hints=[ + *graph_break_hints.USER_ERROR, + "Ensure your nn.Module instance has called `super().__init__()`.", + ], + ) + if istype(value, OptimizedModule): + # Check if the optimized module was disabled + if inspect.getattr_static(value.forward, "_torchdynamo_disable", False): + # This bytecode is mostly of kind LOAD_ATTR or LOAD_METHOD. If + # we graph break here, Dynamo does not know how to create + # continuation functions for such bytecodes. So, we delay the + # graph break to CALL_FUNCTION. + msg = inspect.getattr_static( + value.forward, "_torchdynamo_disable_msg", None + ) + return DelayGraphBreakVariable( + source=self.source, + msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", + ) + + self.install_guards(GuardBuilder.TYPE_MATCH) + self.source = AttrSource(self.source, "_orig_mod") + return self.wrap_module(value._orig_mod) + + if ( + isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) + and not config.allow_rnn + ): + unimplemented_v2( + gb_type="Attempted to wrap RNN, GRU, or LSTM", + context=str(value), + explanation="Dynamo does not support RNN, GRU, or LSTM.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if getattr(value, "_is_fsdp_managed_module", False): + # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + # in fully_sharded_data_parallel.py for more information + + # we can't do this assert inside FSDP constructor, + # since we don't know yet whether dynamo will be used + if not getattr(value, "_fsdp_use_orig_params", False): + unimplemented_v2( + gb_type="FSDP with use_orig_params=False", + context="", + explanation="Dynamo only supports FSDP with use_orig_params=True", + hints=[], + ) + + # Note on FSDP guarding + # Eager FSDP already assumes (requires, but without enforcement) + # that users don't mutate their model parameters/structure after + # FSDP wrapping, because FSDP wouldn't notice or update its + # FlatParams. 
+ # + # Therefore, torch.compile can skip guarding on params or submodule + # structure of fsdp_managed modules, by using FSDPNNModuleSource as + # the guard source. This behavior is gated on + # config.skip_fsdp_guards. + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FSDPManagedNNModuleVariable(value, source=self.get_source()) + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): + # created dynamically, don't specialize on it + + # Note [Tracing a torch.compiled function] + # when make_fx tracing a compiled function, we need + if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): + value = value.get_base() + self.source = AttrProxySource(self.source) + + if torch._dynamo.config.inline_inbuilt_nn_modules: + freezing = is_parameter_freezing() + + # Guard against the case where user may overwrite named parameters + # / named buffers + # NOTE: This is not likely to happen but worth guarding to avoid + # exception + if ( + callable(value.named_parameters) + and value.named_parameters.__func__ + is og_module_named_parameters_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, p in value.named_parameters(): + self.mark_static_input(p, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if ( + callable(value.named_buffers) + and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules + for _, b in value.named_buffers(): + self.mark_static_input(b, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if freezing: + # we need to add the module to tracing context + # in order to allow its params to get invalidated + # this will get cleaned up once compile ends + self.tx.output.nn_modules[self.name] = value + + if ( + value.__module__.startswith(("torch.nn.modules", "torch.ao.")) + and not value.__module__.startswith("torch.nn.modules.container") + ) or getattr(value.__class__, "_dynamo_marked_static", False): + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedBuiltinNNModuleSource(self.source) + result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + else: + new_source = self.source + if config.inline_inbuilt_nn_modules and ( + not self.tx.output.export or config.install_free_tensors + ): + # Export corner case - look at test_repros.py test_inlining_cornercase + new_source = UnspecializedNNModuleSource(self.source) + result = UnspecializedNNModuleVariable(value, source=new_source) + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) + + if not SideEffects.cls_supports_mutation_side_effects(type(value)): + # don't allow STORE_ATTR mutation with custom __setattr__ + return result + return self.tx.output.side_effects.track_object_existing(value, result) + elif issubclass( + value.__class__, torch.nn.parallel.distributed.DistributedDataParallel + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return 
UnspecializedNNModuleVariable(value, source=self.get_source()) + else: + return self.tx.output.register_attr_or_module( + value, + self.name, + source=self.get_source(), + # Guards are added inside register_attr_or_module + ) + + def wrap_literal(self, value): + if type(value) is int: + # allowlist has higher precedence over specialization control. + if is_dynamic_source(self.source.name()): + log.debug("%s marked dynamic via source whitelist", self.source.name()) + return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC) + + if is_unbacked_source(self.source.name()): + log.debug("%s marked unbacked via source whitelist", self.source.name()) + return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED) + + if not config.specialize_int: + # unspecializing int by default, but still + # specialize for the following conditions + if is_int_specialization_case(value, self.source): + recompile_hint = None + if ( + self.source.guard_source().is_unspecialized_builtin_nn_module() + or self.source.guard_source().is_unspecialized_nn_module() + ): + # This means that it is an integer from a NN module. + # Dynamo considers nn module int attributes to be static + # (a good heuristic). But a user might want to mark the + # int attribute to be a symint, so track this integer + # for recompilation later. + recompile_hint = ( + "torch.compile considers integer attributes of the nn.Module to be static. " + "If you are observing recompilation, you might want to make this integer dynamic " + "using torch._dynamo.config.allow_unspec_int_on_nn_module = True, or convert this " + "integer into a tensor." + ) + + self.install_guards( + functools.partial( + GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint + ) + ) + return ConstantVariable.create(value=value, source=self.source) + + return self.wrap_symint(value) + elif not config.specialize_float and type(value) is float: + return self.wrap_symfloat(value) + else: + self.install_guards(GuardBuilder.CONSTANT_MATCH) + result = ConstantVariable.create(value=value, source=self.source) + if isinstance(value, (list, set)): + return self.tx.output.side_effects.track_mutable(value, result) + return result + + def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): + if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: + raise InternalTorchDynamoError( + "Cannot wrap a Tensor that has already been", + "wrapped by this instance of Dynamo", + ) + + def wrap_tensor(self, value: torch.Tensor): + source = self.get_source() + + # We cannot already be tracking the tensor, which implies + # it would have already been wrapped + assert value not in self.tx.output.side_effects + + is_static_input = get_static_address_type(value) is not None + + if ( + config.inline_inbuilt_nn_modules + and not is_static_input + and ( + isinstance(value, torch.nn.Parameter) + # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior + # compatible with previous behavior. + or (source and source.guard_source().is_unspecialized_nn_module()) + ) + ): + self.mark_static_input(value, guard=is_parameter_freezing()) + is_static_input = True + + # Install any tensors which are "free" variables; that is: + # 1. Globals + # 2. NonLocals + # 3. 
tensors that are attributes of nn module + should_install_free_tensor = config.install_free_tensors and ( + is_from_global_source(source) + or is_from_nonlocal_source(source) + or is_from_unspecialized_nn_module_source(source) + ) + + make_graph_attribute = is_static_input and ( + not config.inline_inbuilt_nn_modules + or is_parameter_freezing() + or torch._dynamo.config.prepare_freezing + ) + + if should_install_free_tensor or ( + (source.guard_source().is_specialized_nn_module() or make_graph_attribute) + and not source.guard_source().is_fsdp_module() + ): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if get_static_address_type(value) == "guarded": + # If it's a guarded tensor, we can install the parameter directly + # into the Fx graph instead of lifting it as an input. Lifting + # offers no benefit, such as regional compilation, since we still + # guard on the tensor's ID. Moreover, installing it in the Fx graph + # eliminates the pre-graph bytecode required to extract the tensor + # from locals/globals, reducing overhead. This can lead to + # significant cost savings, especially for optimizers handling many + # tensors. + self.install_guards(GuardBuilder.ID_MATCH) + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, self.name, source=source + ) + + if is_constant_source(source): + self.assert_not_wrapped_by_this_graph(value) + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=source, + # Guards are added inside register_attr_or_module + ) + + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. 
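+ # In short: for plain torch.Tensor / Parameter / fake / functional tensors
+ # (and traceable wrapper subclasses) subclass_type stays None and the value
+ # is wrapped as an ordinary fake tensor below; otherwise the else branch
+ # records the subclass type (plus a TYPE_MATCH guard) so that wrap_fx_proxy
+ # can build a __torch_function__-aware TensorWithTFOverrideVariable instead
+ # of a plain TensorVariable.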
+ subclass_type = None + else: + subclass_type = type(value) + self.install_guards(GuardBuilder.TYPE_MATCH) + + if get_static_address_type(value) == "guarded": + self.install_guards(GuardBuilder.ID_MATCH) + + # By this point, we should have deduplicated all tensors + self.assert_not_wrapped_by_this_graph(value) + + if ( + isinstance(value, torch.Tensor) + and value.is_nested + and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) + ): + unimplemented_v2( + gb_type="Attempted to wrap strided NestedTensor", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + # TODO(pearu,sparse-team) - Add the corresponding SPARSE_TENSOR_MATCH guards + if ( + isinstance(value, torch.Tensor) + and is_sparse_any(value) + and (not self.tx.export or not config.capture_sparse_compute) + ): + # A hot fix for sparse tensors + torch.compile. Support for + # export + sparsity is being added but we need to create + # SPARSE_TENSOR_GUARDS for guards to work properly. + unimplemented_v2( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented_v2( + gb_type="dtype mismatch between tensor and its gradient", + context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", + explanation="Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # tx.output has multiple tracers if we're introspecting HigherOrderOperator. + # When we've discovered an untracked tensor, then we actually need + # to get Dynamo to track the tensor (which is what this function does) + # and put it as a graph input on the root tracer. Later on, + # if the input is actually used in the body of the HigherOrderOperator, + # then the relevant SubgraphTracer will lift it to being an input of + # the subgraph. + # See NOTE [HigherOrderOperator tracing design] for more details. + + example_value = wrap_to_fake_tensor_and_record( + value, tx=self.tx, is_tensor=True, source=source + ) + + tensor_proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, tensor_proxy, value) + + tensor_variable = wrap_fx_proxy( + tx=self.tx, + proxy=tensor_proxy, + example_value=example_value, + subclass_type=subclass_type, + source=source, + **options, + ) + + if value._is_view(): + # If value is a view, add its base tensor to the tracked fakes list. + # This is so we are able to access the correct source for its symbolic + # shape values, in case we need them. + wrap_to_fake_tensor_and_record( + value._base, + tx=self.tx, + source=AttrSource(source, "_base"), + is_tensor=True, + ) + + guard_type = GuardBuilder.TENSOR_MATCH + + if isinstance(source, GradSource) and is_from_optimizer_source(source): + guard_type = GuardBuilder.NOT_NONE_MATCH + + self.install_guards( + functools.partial( + guard_type, + value=( + value + if isinstance(source, NumpyTensorSource) + else TensorWeakRef(value) + ), + ) + ) + + # We install TYPE_MATCH guards for traceable wrapper subclass object, + # and recursively install corresponding guard for each inner attribute. 
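+ # For example, a wrapper subclass whose __tensor_flatten__ returns
+ # (["a", "b"], ctx) gets guards installed recursively on value.a and
+ # value.b (via the VariableBuilder calls below), in addition to the
+ # metadata and type guards on the subclass object itself.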
+ if is_traceable_wrapper_subclass(value): + self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) + self.install_guards(GuardBuilder.TYPE_MATCH) + install_guard( + SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) + ) + + attrs, _ = value.__tensor_flatten__() + for attr in attrs: + inner_value = getattr(value, attr) + inner_source = AttrSource(self.source, attr) + LazyVariableTracker.realize_all( + VariableBuilder(self.tx, inner_source)(inner_value) + ) + + self.tx.output.input_source_to_var[source] = tensor_variable + assert "tensor_dict" not in tensor_proxy.node.meta + tensor_proxy.node.meta["tensor_dict"] = _extract_tensor_dict(value) + + # Note: this information is conveyed via subclass_type now + fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] + if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: + raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") + + grapharg = GraphArg(source, value, False, fake_tensor_value) + tensor_proxy.node.meta["grapharg"] = grapharg + return tensor_variable + + def wrap_numpy_ndarray(self, value): + assert np is not None + assert isinstance(value, np.ndarray) + + source = NumpyTensorSource(self.get_source()) + + from torch._numpy import _util + + readonly = not value.flags.writeable + if readonly: + try: + value.flags.writeable = True + except ValueError: + # One can not easily make nditer elements writable, + # but warning is not the end of the world + assert isinstance(value.base, np.nditer) + + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides + + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented_v2( + gb_type="failed to convert numpy.ndarray to Tensor", + context=str(value), + explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor", + hints=[], + from_exc=e, + ) + + # We do this because we want the full behavior of guarding the numpy ndarray as if it were + # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here + # that there's not another great way to do this atm. + # This creates the right graphargs, as well as registration for guards in tensor names and shape env. + LazyVariableTracker.realize_all(VariableBuilder(self.tx, source)(tensor_value)) + example_value = wrap_to_fake_tensor_and_record( + tensor_value, + tx=self.tx, + is_tensor=False, + source=source, + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(tensor_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, tensor_value) + options = {"source": source} + numpy_ndarray_variable = wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + self.tx.output.input_source_to_var[source] = numpy_ndarray_variable + example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] + + # pass_arg_as_tensor should be true because we are wrapping a np.ndarray as argument input, and it needs to be + # converted to a tensor. 
+ grapharg = GraphArg( + source, + tensor_value, + pass_arg_as_tensor=True, + fake_tensor=example_value, + is_tensor=True, + example_strong_ref=tensor_value, + ) + proxy.node.meta["grapharg"] = grapharg + + # TODO - Why do we need to set the source of the np ndarray vt back to + # original source. Many tests fails. + numpy_ndarray_variable.source = self.source + + return numpy_ndarray_variable + + def wrap_symint( + self, + value, + dynamism: Optional[DimDynamic] = None, + context: Optional[SymIntSymbolicContext] = None, + ): + assert type(value) is int + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + shape_env = self.tx.output.shape_env + if TracingContext.get().force_unspec_int_unbacked_size_like: + wrapped_value = shape_env.create_unbacked_symint() + _constrain_range_for_size(wrapped_value) + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, None) + ) + + # NB: We do not do float. For motivation, see + # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit + # but the general idea is that we generate kernels that can + # take unspecialized floats and use them in sizevar computation + elif not is_constant_source(self.get_source()): + if dynamism is None and torch._dynamo.config.specialize_int: + # If specialize_int is False, also return + # a constant (but this should have been handled + # in the caller, TBH). But if `dynamism` is set, then actually + # turn it into a symint + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + name = self.source.name() + + frame_state_entry = process_automatic_dynamic( + self.tx, + name, + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), + ) + + # TODO: This should be dynamic, as we in general do not + # know if bare integers are actually going to be sizevars + # and it is inappropriate to eagerly duck size them with + # real sizevars + normalized_source_name = normalize_source_name(self.source.name()) + base_source = self.source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if dynamism is not None: + dynamic_dim = dynamism + elif ( + config.automatic_dynamic_shapes + and frame_state_entry.scalar is auto_dynamic + ): + set_feature_use("dynamo.automatic_dynamic_shapes", True) + dynamic_dim = get_automatic_dynamic_shapes_mark_as() + elif ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {0: False})[ + 0 + ] + ) or not config.assume_static_by_default: + dynamic_dim = DimDynamic.DYNAMIC + else: # assume_static_by_default + # TODO: dynamic_dim = DimDynamic.STATIC should work but + # for some reason it doesn't + if frame_state_entry.scalar is auto_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value) + + wrapped_value = shape_env.create_unspecified_symint_and_symbol( + value, + source=self.source, + dynamic_dim=dynamic_dim, + ) + + self.tx.output.tracked_fakes.append( + TrackedFake(wrapped_value, self.source, context) + ) + else: + assert is_constant_source(self.get_source()) + # TODO: Do I actually need guard for constant source? 
+ self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + wrapped_value, + source=self.get_source(), + ) + + sym_expr = wrapped_value.node.expr + assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol." + self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy + unspec_var = SymNodeVariable(proxy, wrapped_value, **options) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if not is_constant_source(self.get_source()): + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=False, + fake_tensor=None, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + return unspec_var + + def wrap_symfloat(self, value): + # SymFloat wrapping is special. We first wrap it in the same way we + # do an unspecialized primitive, and then we item() it into a + # SymFloat. Removal of the item() call is left to a later FX pass, + # mostly because that pass is more easily done after we have lowered + # to ATen ops. (Dynamo doesn't do decomposition right now). + + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + frame_state_entry = process_automatic_dynamic( + self.tx, + self.source.name(), + FrameStateSizeEntry.make_scalar(value), + is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), + ) + + # NB: we specialize on nan input, because our guard modeling in + # ShapeEnv cannot deal with nan + if ( + torch._dynamo.config.specialize_float + or is_constant_source(self.get_source()) + or math.isnan(value) + or math.isinf(value) + # We don't support cudagraphs for now. Without this cudagraphs + # break because they expect all cuda inputs but our tensorified + # float will be a f64[] cpu tensor. Fixes the following test + # when specialize_float=False + # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950 + or torch._inductor.config.triton.cudagraphs + or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False) + or ( + config.assume_static_by_default + and frame_state_entry.scalar is not auto_dynamic + ) + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # NB: At the point we've gotten here, we don't assume static by + # default. Since we have a guard mechanism, there isn't really any + # downside to trying to be dynamic for float all the time. Unlike + # ints, this won't make codegen perf worse. Modest cost to compile + # time. + + wrapped_value = torch.tensor(value, dtype=torch.float64) + + # We don't support specializing floats for grad checking tensors + # See https://github.com/pytorch/pytorch/pull/140828 for more + # context. 
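+ # That is: if torch.tensor(value) above came back as a functorch
+ # grad-tracking tensor (we are tracing under a grad transform), fall back
+ # to treating the float as a compile-time constant instead of tensorifying it.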
+ if torch._C._functorch.is_gradtrackingtensor(wrapped_value): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return ConstantVariable.create(value=value, source=self.source) + + # TODO: Switch RandomValueSource over to use this, this is more + # accurate + assert not isinstance(self.get_source(), RandomValueSource) + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + # The FloatTensorSource here is just for pedantic correctness: if you + # guard against an UnspecializedPythonVariable, you need to guard + # against the tensor-ified version of the local, otherwise it's not a + # Tensor. However, we never let the UnspecializedPythonVariable escape + # here, so there should never actually be any guards against this + # source. + source = FloatTensorSource(self.get_source()) + options = {"source": source, "raw_value": value} + + # TODO: Maybe the tensor-ification should be built into the source, + # rather than by special pattern match + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=source + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=source, + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + assert isinstance(unspec_var, UnspecializedPythonVariable) + self.tx.output.unspec_variable_map[self.name] = unspec_var + + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + # There's something a bit incoherent about pass_arg_as_tensor, + # specifically regarding sources. + # + # Specifically, suppose we have "x: float" local argument. We + # eventually end up with an UnspecializedPythonVariable denoting + # torch.as_tensor(x)... but it's source is still L['x'] (which if you + # accessed it directly is a float!) So you gotta be careful when + # setting up your guards, because it's still going to be a float at + # this point, the conversion happens only precisely at the point we're + # actually calling the FX graph. This happens to be what we want for + # shape guard generation, but it's kind of unintuitive. 
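+ # Concretely: a local float such as x = 0.5 enters the graph as a f64[]
+ # CPU tensor argument (pass_arg_as_tensor=True below), and the .item()
+ # call traced right after converts it back to a SymFloat for the rest of
+ # tracing; removing that item() call is left to a later FX pass.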
+ proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + + # Directly do item to bypass capture_scalar_outputs + r = wrap_fx_proxy( + self.tx, + self.tx.output.create_proxy( + "call_method", + "item", + *proxy_args_kwargs([unspec_var], {}), + ), + ) + self.tx.output.tracked_fakes.append(TrackedFake(r.sym_num, self.source, None)) + + get_metrics_context().set("tensorify_float_attempt", True, overwrite=True) + + return r + + def wrap_unspecialized_primitive(self, value): + if self.name in self.tx.output.unspec_variable_map: + return self.tx.output.unspec_variable_map[self.name] + + wrapped_value = torch.tensor(value) + if not isinstance(self.get_source(), RandomValueSource): + install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) + + options = {"source": self.get_source()} + options.update({"raw_value": value}) + + example_value = wrap_to_fake_tensor_and_record( + wrapped_value, tx=self.tx, is_tensor=False, source=self.get_source() + ) + proxy = self.tx.output.root_tracer.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + type(wrapped_value), + example_value, + source=self.get_source(), + ) + cache_real_value_when_export(self.tx, proxy, wrapped_value) + + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=self.tx, + proxy=proxy, + example_value=example_value, + **options, + ) + self.tx.output.unspec_variable_map[self.name] = unspec_var + if not is_constant_source(self.get_source()): + if self.tx.export and not isinstance(self.get_source(), LocalSource): + raise AssertionError( + f"Dynamo attempts to add additional input during export: value={wrapped_value}, source={self.get_source()}" + ) + fake_tensor_value = None + if isinstance(unspec_var, ConstantVariable): + # TODO: when can this happen? + example_value = unspec_var.value + else: + example_value = unspec_var.proxy.node.meta["example_value"] + assert is_fake(example_value) + + fake_tensor_value = example_value + assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( + f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" + "({self.tx.fake_mode}) from InstructionTranslator" + ) + + proxy.node.meta["grapharg"] = GraphArg( + self.get_source(), + wrapped_value, + pass_arg_as_tensor=True, + fake_tensor=fake_tensor_value, + is_tensor=False, + example_strong_ref=wrapped_value, + ) + return unspec_var + + +def _dataclasses_fields_lambda(obj): + if isinstance(obj, UserDefinedObjectVariable): + value = obj.value + else: + unimplemented_v2( + gb_type="dataclass fields failure", + context=f"obj: {obj}; variable type: {type(obj)}", + explanation=f"Dataclass fields handling fails for {obj}. 
Expected it to be a user-defined object.", + hints=[], + ) + items = [] + for field in dataclasses.fields(value): + source = None + if obj.source: + base_src = AttrSource(obj.source, "__dataclass_fields__") + source = DictGetItemSource(base_src, field.name) + items.append(UserDefinedObjectVariable(field, source=source)) + return TupleVariable(items) + + +def _clone_input(value, fake_mode): + if isinstance(value, torch.Tensor): + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not ( + isinstance(value, FakeTensor) + or ( + # Is functional tensor fakeified by this instance of Dynamo + torch._is_functional_tensor(value) + and maybe_get_fake_mode(value) is fake_mode + ) + or value.is_nested + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + +def wrap_fx_proxy( + tx, proxy, example_value=None, subclass_type=None, **options +) -> VariableTracker: + kwargs = { + "tx": tx, + "proxy": proxy, + "example_value": example_value, + "subclass_type": subclass_type, + **options, + } + if subclass_type is None: + return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) + else: + result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) + result.install_global(tx) + return result + + +def cache_real_value_when_export(tx, proxy, example_value): + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value, tx.fake_mode + ) + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +# +# This is a horribly complicated function that does too many things, to +# explain what it does, let's first talk about the classic usage wrap_fx_proxy +# for a TensorVariable. There are two primary modes of use: +# +# 1. Wrapping a pre-existing Tensor. In this case, example_value is set +# to the pre-existing Tensor. (Note that this example_value will NOT +# be the final example_value we put into node.meta['example_value'], +# instead it is converted into a fake tensor using +# wrap_to_fake_tensor_and_record and registered as a graph input.) +# +# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In +# this case, example_value is None (and we are going to figure it out +# ourselves using FakeTensors, via get_fake_value, which will run +# the operation represented by the (singular!) FX node referenced by +# the passed in proxy.) +# +# The expectation is you end up with a Tensor output, and everything is +# straightforwardly traced into the graph. +# +# In all cases, the returned `TensorVariable` subclass will have an `example_value` +# and that `example_value` must be a `FakeTensor` produced by the currently running +# instance of Dynamo. +# +# Upon closer inspection, you may notice that there are a slurry of non-Tensor +# output cases in handle_traced_output. What gives? Well, we sometimes trace operations into the +# graph that don't involve tensors. +# +# * Some operators return tuples; we need to recursively handle their +# contents +# +# * Some operators have side effects that will affect subsequent AOTAutograd +# tracing but don't otherwise return anything. 
+# +# * Some operators return symbolic ints/floats/bools which can go in the +# graph and be traced (but only if they're actually symbolic! If they're +# static you don't want to put them in the graph, which means you +# shouldn't call this function.) +# +# The common theme is that you only use this function WHEN YOU ARE TRACING +# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call +# this function without a proxy. +def wrap_fx_proxy_cls( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + if example_value is None: + return _wrap_fx_proxy( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + elif isinstance(example_value, torch.Tensor): + return _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, example_value, subclass_type, **options + ) + else: + # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported + # data structures. In essence this just handles tracing some other value which may + # contain Fake Tensors or is otherwise proxyable. + return handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + +# This is 1 above (wrapping a preexisting tensor) +def _wrap_fx_preexisting_tensor( + target_cls, tx, proxy, tensor, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tensor, torch.Tensor), ( + f"_wrap_fx_preexisting_tensor expected tensor, got {type(tensor)}" + ) + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + # Placeholders always carry example_value in node.meta. + # non-placeholders always have no example_value in node.meta + if proxy.node.op == "placeholder": + assert "example_value" in proxy.node.meta, ( + f"placeholder {proxy} doesn't have 'example_value' in node.meta" + ) + else: + assert "example_value" not in proxy.node.meta, ( + f"{proxy.node.meta['example_value']}" + ) + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # Handle recursive calls here + if maybe_get_fake_mode(tensor) is tx.fake_mode: + pass + else: + cache_real_value_when_export(tx, proxy, tensor) + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + tensor, tx.fake_mode + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + kwargs = { + "is_tensor": target_cls + in (TensorVariable, TensorWithTFOverrideVariable), + } + assert "source" in options and options["source"] is not None + kwargs["source"] = options["source"] + tensor = wrap_to_fake_tensor_and_record(tensor, tx=tx, **kwargs) + + if tensor.device.type != "meta" and ( + maybe_get_fake_mode(tensor) is not tx.fake_mode + ): + raise InternalTorchDynamoError( + "`tensor` needs to be a `FakeTensor`" + f"wrapped by this instance of Dynamo. 
Found: {tensor}" + ) + + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) + + +# This is 2 in the above comment (wrapping the output of a traced op) +def _wrap_fx_proxy( + target_cls, tx, proxy, example_value=None, subclass_type=None, **options +): + from ..symbolic_convert import InstructionTranslatorBase + + assert isinstance(tx, InstructionTranslatorBase) + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # with preserve_rng_state(): + # only allow_non_graph_fake in this instance because we handle the non-fake + # cases properly below. + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + + return handle_traced_output( + example_value, tx, proxy, options, subclass_type, target_cls + ) + + +# This handles wrapping of the output of an op traced into the graph +def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls): + import torch._functorch.vmap + import torch._subclasses.fake_tensor + import torch._utils + + if isinstance(example_value, torch.Tensor): + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # specially here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. + # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target == torch.random.set_rng_state + ): + return TorchInGraphFunctionVariable(proxy.node.target) + elif ( + proxy.node.target == torch._C._DisableFuncTorch + or proxy.node.target == torch.cuda._is_in_bad_fork + ): + return UserDefinedObjectVariable(example_value) + elif istype(example_value, torch.Size) and all( + isinstance(x, int) for x in example_value + ): + sizes = [ConstantVariable.create(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + set_example_value(proxy.node, example_value) + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable.create(None, **options), + ) + else: + proxy_i = proxy.tracer.create_proxy( + kind="call_function", + target=operator.getitem, + args=(proxy, i), + kwargs={}, + ) + + if "source" in options: + # This path should only trigger for list stealing, so it's + # safe to use `GetItemSource`. 
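+ # ("List stealing" refers to the compiled-autograd input list handled in
+ # wrap_listlike above: the whole list stays a single graph argument and is
+ # unpacked via getitem proxies, so each unpacked element's source is a
+ # GetItemSource derived from the parent list's source.)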
+ assert isinstance(example_value, list) + source = options["source"] + options_i = options.copy() + options_i["source"] = GetItemSource( + base=source, index=i, index_is_slice=False + ) + else: + # use the same options object as parent + options_i = options + + # WARNING: this assumes the same target_cls as this tuple/list call + unpacked.append( + wrap_fx_proxy_cls( + target_cls=target_cls, + tx=tx, + proxy=proxy_i, + example_value=val, + **options_i, + ) + ) + if isinstance(example_value, torch.Size): + # NB: Keep the old proxy around. See SizeVariable for an + # explanation why + return SizeVariable(unpacked, proxy, **options) + elif istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ( + f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" + ) + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable.create(None, **options) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + set_example_value(proxy.node, example_value) + return SymNodeVariable(proxy, example_value, **options) + elif ( + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Stream) + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + return StreamVariable(proxy, example_value, example_value.device, **options) + elif ( + inspect.isclass(proxy.node.target) + and issubclass(proxy.node.target, torch.Event) + ) or proxy.node.target in [ + device_interface.Event + for _, device_interface in get_registered_device_interfaces() + ]: + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, **options) + elif proxy.node.target == "query" and proxy.node.op == "call_method": + set_example_value(proxy.node, example_value) + return ConstantVariable(example_value, **options) + elif ( + example_value is not None + and isinstance(example_value, torch.Event) + and proxy.node.target == "record_event" + and proxy.node.op == "call_method" + ): + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, **options) + elif isinstance(example_value, int) and ( + proxy.node.target + in [ + torch.sym_int, + getattr, + operator.getitem, + torch._utils._element_size, + torch.seed, + operator.mod, + torch._functorch.vmap._validate_and_get_batch_size, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + # This always wants to be in the graph, even if the constraint + # results in a constant int + torch._constrain_as_size, + ] + or ( + # TODO: this is a little sus, because we didn't check what the self is + proxy.node.op == "call_method" and proxy.node.target in ["bit_length"] + ) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch.backends.cuda.SDPAParams): + from .sdpa import SDPAParamsVariable + + 
set_example_value(proxy.node, example_value) + return SDPAParamsVariable(proxy, **options) + elif isinstance(example_value, bool) and ( + proxy.node.target + in [ + torch._C._are_functorch_transforms_active, + torch._C._functorch.is_batchedtensor, + torch.backends.cuda.is_flash_attention_available, + torch.backends.cuda.can_use_flash_attention, + torch.backends.cuda.can_use_efficient_attention, + "is_integer", + ] + + list(supported_const_comparison_op_values.keys()) + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif ( + isinstance(example_value, (int, float, bool)) + and proxy.node.target is call_torchbind + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) + else: + unimplemented_v2( + gb_type="torch.* op returned non-Tensor", + context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", + explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", + hints=[], + ) + + + def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/strides of the tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. 
+ tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + options.update(specialized_props) + return target_cls(proxy, **options) + + +def get_automatic_dynamic_shapes_mark_as(): + if config.automatic_dynamic_shapes_mark_as == "dynamic": + return DimDynamic.DYNAMIC + elif config.automatic_dynamic_shapes_mark_as == "unbacked": + return DimDynamic.SIZE_LIKE_UNBACKED + elif config.automatic_dynamic_shapes_mark_as == "oblivious": + return DimDynamic.OBLIVIOUS_SIZE + else: + raise ValueError( + f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}" + ) + + +_DYNAMIC_SOURCES: Optional[set[str]] = None +_DYNAMIC_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_dynamic_sources() -> set[str]: + global _DYNAMIC_SOURCES, _DYNAMIC_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.dynamic_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _DYNAMIC_SOURCES is not None and _DYNAMIC_SOURCES_CONFIG_HASH == current_hash: + return _DYNAMIC_SOURCES + + # Config has changed or first time, (re)calculate the sources + _DYNAMIC_SOURCES = { + s + for s in torch.compiler.config.dynamic_sources.replace(" ", "").split(",") + if s + } + _DYNAMIC_SOURCES_CONFIG_HASH = current_hash + + return _DYNAMIC_SOURCES + + +def is_dynamic_source(source_name: str) -> bool: + dynamic_sources = get_dynamic_sources() + for pattern in dynamic_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked dynamic due to dynamic source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +def record_automatic_dynamic( + tx: "InstructionTranslator", name: str, e: torch.Tensor +) -> FrameStateSizeEntry: + # This mimics stride inference algorithm in _create_symbolic_sizes_strides_storage_offset + ex_size = e.size() + if not is_sparse_any(e): + ex_stride = e.stride() + dim = e.dim() + + stride = [None] * dim + pending = [(ex_stride[i], -i) for i in range(dim)] + pending.sort(key=_nested_int_aware_sort) + candidates = {} + for i_stride, neg_i in pending: + i = -neg_i + stride[i] = candidates.get(i_stride, i_stride) + candidates.setdefault(i_stride * ex_size[i], InferStride(i)) + else: + stride = [] + + return process_automatic_dynamic( + tx, name, FrameStateSizeEntry.make_tensor(tuple(ex_size), tuple(stride)) + ) + + +_UNBACKED_SOURCES: Optional[set[str]] = None +_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None + + +def get_unbacked_sources() -> set[str]: + global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH + + current_hash = hash(torch.compiler.config.unbacked_sources) + + # If we have already calculated the sources and the config hasn't changed, return cached result + if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash: + return _UNBACKED_SOURCES + + # Config has changed or first time, (re)calculate the sources + _UNBACKED_SOURCES = { + s + for s in 
torch.compiler.config.unbacked_sources.replace(" ", "").split(",") + if s + } + _UNBACKED_SOURCES_CONFIG_HASH = current_hash + + return _UNBACKED_SOURCES + + +def is_unbacked_source(source_name: str) -> bool: + unbacked_sources = get_unbacked_sources() + for pattern in unbacked_sources: + if pattern == source_name or re.match(pattern, source_name): + log.debug( + "%s was marked unbacked due to unbacked source allowlist pattern: %s", + source_name, + pattern, + ) + return True + return False + + +# Performs automatic dynamic dim determination. +# Returns a SymbolicContext +def _automatic_dynamic( + e, tx, source, static_shapes, outer_only=False +) -> SymbolicContext: + # strided NT not supported + if e.is_nested and not isinstance( + e, torch.nested._internal.nested_tensor.NestedTensor + ): + unimplemented_v2( + gb_type="Encountered strided NestedTensor in automatic dynamic dim determination", + context="", + explanation="torch.compile does not support strided NestedTensor", + hints=[], + ) + + name = source.name() + prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) + shape_env_to_source_to_symbol_cache = ( + prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None + ) + + # Get base context if the tensor is a view + view_base_context: Optional[SymbolicContext] = None + if e._is_view(): + base_source = AttrSource(source, "_base") + view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) + + if is_traceable_wrapper_subclass(e) and not outer_only: + # Get symbolic context for outer tensor + outer_context = _automatic_dynamic( + e, tx, source, static_shapes, outer_only=True + ) + + # Get symbolic contexts for inner tensors + inner_contexts = {} # mapping from attr -> symbolic context + attrs, _ = type(e).__tensor_flatten__(e) + for attr in attrs: + inner_tensor = getattr(e, attr) + inner_source = AttrSource(source, attr) + inner_contexts[attr] = _automatic_dynamic( + inner_tensor, tx, inner_source, static_shapes + ) + + return SubclassSymbolicContext( + dynamic_sizes=outer_context.dynamic_sizes, + dynamic_strides=outer_context.dynamic_strides, + constraint_sizes=outer_context.constraint_sizes, + constraint_strides=outer_context.constraint_strides, + view_base_context=view_base_context, + tensor_source=outer_context.tensor_source, + shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, + inner_contexts=inner_contexts, + ) + + if static_shapes and not is_dynamic_source(name): + record_automatic_dynamic(tx, name, e) + return StatefulSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * e.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # We preserve the dynamism of inputs. For example, when users call + # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. 
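+ # In that case only the dims that are already SymInts are marked DYNAMIC below; the remaining dims stay STATIC.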
+ from torch.fx.experimental.symbolic_shapes import is_nested_int + + if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): + return StatefulSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC + for s in e.size() + ], + dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), + constraint_sizes=[None] * e.dim(), + constraint_strides=[None] * e.dim(), + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + # Prep for automatic dynamic + frame_state_entry = record_automatic_dynamic(tx, name, e) + + # TODO: index export_constraints ahead of time so we don't have to + # do a linear scan every time here + t_id = id(e) + dim2constraint = {} + + def update_dim2constraint(dim, constraint_range, name): + if dim in dim2constraint: + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + old_constraint_range, old_name = dim2constraint[dim] + new_constraint_range = StrictMinMaxConstraint( + vr=constraint_range.vr & old_constraint_range.vr, + warn_only=False, + ) + # It is possible for (non-None) old_name and name to be different + # but this will only happen the corresponding Dims can be derived equal. + new_name = old_name or name + dim2constraint[dim] = new_constraint_range, new_name + else: + dim2constraint[dim] = constraint_range, name + + from torch.export.dynamic_shapes import _RelaxedConstraint + + if tx.output.export_constraints: + for constraint in tx.output.export_constraints: + if isinstance(constraint, _RelaxedConstraint): + continue + if constraint.t_id == t_id: + update_dim2constraint( + constraint.dim, constraint.constraint_range, constraint.name + ) + + dynamic_sizes = [] + dynamic_strides = [] + constraint_sizes = [] + constraint_strides = [] + specialize_on = [] + for i in range(e.dim()): + # NB: mark dynamic has precedence over static + marked_strict_unbacked = i in getattr( + e, "_dynamo_strict_unbacked_indices", set() + ) + marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) + marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) + marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) + marked_static = i in getattr(e, "_dynamo_static_indices", set()) + + specialize_on.append(getattr(e, "_specialize_on", {}).get(i, [])) + + # Reflect the user directive in the frame_state + # For dynamic, apply None always + + normalized_source_name = normalize_source_name(source.name()) + base_source = source + if isinstance(base_source, ChainedSource): + base_source = base_source.get_base() + + if marked_dynamic or ( + isinstance(base_source, LocalSource) + and base_source.dynamism is not None + and dict(base_source.dynamism).get(normalized_source_name, {i: False})[i] + ): + # TODO: This can be batched + # TODO: Doing this here is kind of sus, maybe better to set this + # up when we initially created the FrameStateSizeEntry to bong + # into the mutable state + log.debug("automatic dynamic %s marked dynamic", name) + mark_size = [auto_unset] * e.dim() + mark_size[i] = auto_dynamic + frame_state_entry |= FrameStateSizeEntry.make_size(size=mark_size) + + # NB: both static and dynamic have precedence over + automatic_dynamic_size = ( + config.automatic_dynamic_shapes and frame_state_entry.is_size_dynamic(i) + ) + # NB: previously, if size was dynamic, we wouldn't make its stride + # dynamic. 
But now, because of InferStride concept, we will properly + # not make stride dynamic even if it's wobbling + automatic_dynamic_stride = ( + config.automatic_dynamic_shapes and frame_state_entry.is_stride_dynamic(i) + ) + + if is_dynamic_source(name): + log.debug("%s marked dynamic via source whitelist", name) + automatic_dynamic_size = True + automatic_dynamic_stride = True + + if is_unbacked_source(name): + log.debug("%s marked unbacked via source whitelist", name) + automatic_dynamic_size = True + automatic_dynamic_stride = True + + automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride + + # We will process constraints first, as they will imply that we + # have a dynamic dimension + # Precedence: export constraints > eager constraints + constraint = dim2constraint.get(i) + if constraint is None: + constraint_size = None + constraint_stride = None + if marked_dynamic and not config.allow_ignore_mark_dynamic: + # constraint_stride is deliberately kept None because there is no easy way to provide value ranges for mark dynamic + constraint_stride = None + if hasattr(e, "_dynamo_dynamic_range"): + dim_range = [ + dr for dr in e._dynamo_dynamic_range if dr.dim == i + ].pop() + if dim_range.min is None and dim_range.max is None: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + else: + from torch.fx.experimental.symbolic_shapes import ( + StrictMinMaxConstraint, + ) + + constraint_size = StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), + warn_only=False, + ) + else: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif marked_strict_unbacked: + constraint_size = RelaxedUnspecConstraint(warn_only=False) + elif not marked_static and automatic_dynamic: + set_feature_use("dynamo.automatic_dynamic_shapes", True) + if automatic_dynamic_size: + constraint_size = RelaxedUnspecConstraint(warn_only=True) + if automatic_dynamic_stride: + constraint_stride = RelaxedUnspecConstraint(warn_only=True) + else: + if not marked_static and not config.automatic_dynamic_shapes: + set_feature_use("dynamo.automatic_dynamic_shapes", False) + constraint_size = None + constraint_stride = None + else: + constraint_size, name_ = constraint + constraint_stride = None + dim_name = f"{name}.size()[{i}]" + tx.output.shape_env.source_name_to_debug_name[dim_name] = name_ + constraint_sizes.append(constraint_size) + constraint_strides.append(constraint_stride) + + if marked_unbacked or is_unbacked_source(name): + dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED + elif ( + constraint_size is not None + or marked_dynamic + or marked_weak_dynamic + or is_nested_int(e.size()[i]) + ): + # NB: We could assert static_shapes is False here, but it + # seems better to allow the user to override symbolic_context in this + # case + if automatic_dynamic: + dynamic_size = get_automatic_dynamic_shapes_mark_as() + else: + dynamic_size = DimDynamic.DYNAMIC + elif static_shapes or config.assume_static_by_default or marked_static: + dynamic_size = DimDynamic.STATIC + else: + # TODO: When does this show up? 
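+ # Fall back to duck sizing: the dim gets a dynamic symbol that may be shared with other dims of the same size.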
+ dynamic_size = DimDynamic.DUCK + + if constraint_stride is not None: + dynamic_stride = DimDynamic.DYNAMIC + else: + dynamic_stride = DimDynamic.INFER_STRIDE + + dynamic_sizes.append(dynamic_size) + dynamic_strides.append(dynamic_stride) + + return StatefulSymbolicContext( + dynamic_sizes=dynamic_sizes, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + specialize_on=specialize_on, + view_base_context=view_base_context, + tensor_source=source, + shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, + ) + + +# See note [Tensor Fakification and Symbol Caching] +def wrap_to_fake_tensor_and_record( + e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None +): + if ( + type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) + or isinstance(e, torch.Tensor) + or is_traceable_wrapper_subclass(e) + ): + assert source is not None + static_shapes, _reason = tensor_always_has_static_shape( + e, + is_tensor, + tensor_source=source, + ) + + if not parent_context: + symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) + else: + # Parent contexts are passed in when we are recursively creating + # fake tensors for subclasses. A better design would be not to create a + # parent/child relationship, but to recursively call _automatic_dynamic + # as we recursively call wrap_to_fake_tensor_and_record. This runs + # into bugs around how meta_utils knows and works to create fake tensors + # with tensor subclasses. Ideally, dynamo would drive both the recursive + # wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. + assert isinstance(source, AttrSource) + inner_context_name = source.member + symbolic_context = parent_context.inner_contexts[inner_context_name] + + log.debug( + "wrap_to_fake %s %s %s %s", + source.name(), + tuple(e.shape), + symbolic_context, + type(e), + ) + + fake_e = wrap_fake_exception( + lambda: tx.fake_mode.from_tensor( + e, + source=source, + symbolic_context=symbolic_context, + ) + ) + if ( + source is not None + and isinstance(fake_e, FakeTensor) + and (sym_val := fake_e.item_memo) is not None + ): + tx.output.tracked_fakes.append( + TrackedFake(sym_val, CallMethodItemSource(source), symbolic_context) + ) + + if is_traceable_wrapper_subclass(fake_e): + attrs, _ = fake_e.__tensor_flatten__() + for attr in attrs: + fake_inner = getattr(fake_e, attr) + inner = getattr(e, attr) + inner_source = AttrSource(source, attr) + wrap_to_fake_tensor_and_record( + inner, + tx, + source=inner_source, + is_tensor=isinstance(fake_inner, torch.Tensor), + parent_context=symbolic_context, + ) + + tx.output.tracing_context.tensor_to_context[e] = symbolic_context + if is_sparse_any(fake_e): + # TODO: for TensorGuards, this eventually may need more + # fields for the size/stride of any other constituents + values = fake_e._values() if fake_e.is_sparse else fake_e.values() + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + # TODO: revise this, but for now this stride instead of () + # avoids SegFault with PYTORCH_TEST_WITH_DYNAMO=1 + "stride": (1,) * fake_e.ndim, + "values_size": values.size(), + "values_stride": values.stride(), + } + else: + tx.output.input_source_to_sizes_strides[source] = { + "size": fake_e.size(), + "stride": fake_e.stride(), + } + + if ( + is_tensor + and not (static_shapes and source.is_specialized_nn_module()) + and not is_constant_source(source) + ): + tx.output.tracked_fakes.append( + TrackedFake(fake_e, source, 
symbolic_context) + ) + tx.output.tracked_fakes_id_to_source[id(e)].append(source) + + return fake_e + else: + return e + + +class SourcelessBuilder: + """ + Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects + that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over + .), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, + there may be reasons to represent it as a ListVariable internally. + + NOTE - Objects produced here are born UNGUARDED due to the nature of sources! + + NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant + if/else type->VariableTracker trees that were cropping up all over dynamo. + """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + fast_handler = SourcelessBuilder._type_handlers.get(value_type) + if fast_handler: + return fast_handler(tx, value) + + if isinstance(value, VariableTracker): + # This is always valid to call, and useful for recursive calls. + return value + elif isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): + return UserDefinedObjectVariable(value) + elif ConstantVariable.is_literal(value): + return ConstantVariable.create(value) + elif callable(value) and trace_rules.lookup_callable(value) is not None: + if trace_rules.is_callable_allowed(value): + tx.output.has_user_defined_allowed_in_graph = True + return trace_rules.lookup_callable(value)(value) + elif is_function_or_wrapper(value): + return trace_rules.lookup(value)(value) + elif isinstance( + value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + ): + return EnumVariable(value) + elif isinstance(value, (type, abc.ABCMeta)): + return UserDefinedClassVariable(value) + elif isinstance(value, types.MethodWrapperType): + return MethodWrapperVariable(value) + elif ( + isinstance(value, types.MethodType) + # We only want to support sourceless class objects here + # An instance variable is not allowed and it should have source + and isinstance(value.__self__, (type, abc.ABCMeta)) + ): + # value is a classmethod + assert getattr(value.__self__, value.__func__.__name__) == value + cls_obj_vt = SourcelessBuilder.create(tx, value.__self__) + try: + return cls_obj_vt.var_getattr(tx, value.__func__.__name__) + except NotImplementedError: + pass # failthrough to unimplemented branch + elif isinstance(value, torch.fx.graph_module.GraphModule): + return SourcelessGraphModuleVariable(value) + elif isinstance( + value, (torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec) + ): + return UserDefinedObjectVariable(value) + elif PlacementVariable.is_placement(value): + return PlacementVariable(value) + elif DeviceMeshVariable.is_device_mesh(value): + return DeviceMeshVariable(value) + elif value is functools.wraps: + return FunctoolsWrapsVariable(value) + elif isinstance(value, re.Pattern): + return RegexPatternVariable(value) + elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString): + return ConstantVariable.create(str(value)) + elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)): + return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable( + value + ) + elif isinstance(value, types.GenericAlias): + return 
TypingVariable(value) + elif is_namedtuple(value): + output = [ + SourcelessBuilder.create(tx, getattr(value, name)) + for name in namedtuple_fields(type(value)) + ] + return NamedTupleVariable(output, tuple_cls=type(value)) + elif ( + isinstance(value, torch.SymInt) + and value.node.expr in tx.output.bound_symbols + ): + proxy = tx.output.bound_symbols[value.node.expr] + return SymNodeVariable.create(tx, proxy) + unimplemented_v2( + gb_type="Unexpected type in sourceless builder", + context=f"{value_type.__module__}.{value_type.__qualname__}", + explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + @staticmethod + def wrap_constant_literal(value): + assert ConstantVariable.is_literal(value) + return ConstantVariable.create(value=value) + + @staticmethod + def make_type_handlers(): + create = SourcelessBuilder.create + handlers = {} + for t in common_constant_types: + handlers[t] = lambda tx, value: ConstantVariable(value) + handlers[set] = lambda tx, value: SetVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[dict] = lambda tx, value: ConstDictVariable( + {create(tx, k): create(tx, v) for k, v in value.items()}, + type(value), + mutation_type=ValueMutationNew(), + ) + handlers[list] = lambda tx, value: ListVariable( + [create(tx, x) for x in value], mutation_type=ValueMutationNew() + ) + handlers[tuple] = lambda tx, value: TupleVariable( + [create(tx, x) for x in value] + ) + handlers[torch.Size] = lambda tx, value: SizeVariable( + [create(tx, x) for x in value] + ) + handlers[collections.OrderedDict] = handlers[dict] + handlers[immutable_dict] = handlers[dict] + handlers[immutable_list] = handlers[list] + handlers[random.Random] = lambda tx, value: RandomClassVariable() + handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value) + + handlers[torch.DispatchKeySet] = lambda tx, value: DispatchKeySetVariable( + value, mutation_type=ValueMutationNew() + ) + handlers[torch._functorch.pyfunctorch.FuncTorchInterpreter] = ( + lambda tx, value: FuncTorchInterpreterVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + handlers[torch.distributions.constraints._Real] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints._Interval] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + handlers[torch.distributions.constraints.Constraint] = ( + lambda tx, value: UserDefinedObjectVariable( + value, mutation_type=ValueMutationNew() + ) + ) + + def passthrough(tx: "InstructionTranslator", value): + return value + + for cls in VariableTrackerMeta.all_subclasses: + handlers[cls] = passthrough + return handlers + + +SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers() + + +class SourcelessUserDefinedObjectBuilder: + """ + SourceLessBuilder does not return a UserDefinedObjectVariable, but in some + cases it might be ok to return UserDefinedObjects. In such case, use this + builder. 
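+ Objects built here are tagged with ValueMutationNew, since they have no source and are treated as newly created during tracing.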
+ """ + + def __init__(self) -> None: + raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()") + + @staticmethod + def create(tx: "InstructionTranslator", value) -> VariableTracker: + value_type = type(value) + if issubclass(value_type, MutableMapping): + return MutableMappingVariable(value, mutation_type=ValueMutationNew()) + elif isinstance(value, torch.nn.Module): + return UnspecializedNNModuleVariable( + value, mutation_type=ValueMutationNew() + ) + else: + return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew()) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/builtin.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..e151865ad5de42d2b3fe38dab6c41499de962b86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/builtin.py @@ -0,0 +1,2552 @@ +# mypy: allow-untyped-defs + +""" +Built-in function and type variable tracking for TorchDynamo's symbolic execution. + +This module contains variable tracker classes for Python built-in functions, types, +and operations during graph compilation. It handles symbolic execution of: + +- Built-in functions (len, getattr, isinstance, etc.) +- Type constructors (int, float, str, list, dict, etc.) +- Built-in operators and methods +- Special Python constructs (super, hasattr, etc.) + +Key classes: +- BuiltinVariable: Tracks built-in functions and handles their execution +- TypeVariable: Manages type constructor calls and type checking +- SuperVariable: Handles super() calls in class hierarchies + +These variable trackers ensure that built-in Python operations are correctly +handled during symbolic execution, either by executing them directly when safe +or by creating appropriate graph nodes when needed. +""" + +import contextlib +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +import types +import typing +import unittest +from collections import defaultdict, OrderedDict +from collections.abc import KeysView, Sequence +from typing import Callable, TYPE_CHECKING, Union + +import torch +from torch import sym_float, sym_int +from torch._subclasses.meta_utils import is_sparse_any +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. 
import config, graph_break_hints, polyfills, variables +from ..exc import ( + AttributeMutationError, + ObservedAttributeError, + raise_observed_exception, + unimplemented_v2, + Unsupported, + UserError, + UserErrorType, +) +from ..guards import GuardBuilder, install_guard +from ..replay_record import DummyModule +from ..source import ( + AttrSource, + GetItemSource, + GlobalSource, + is_constant_source, + TypeSource, +) +from ..utils import ( + check_constant_args, + check_numpy_ndarray_args, + check_unspec_or_constant_args, + check_unspec_python_args, + cmp_name_to_op_mapping, + dict_methods, + extract_fake_example_value, + get_fake_value, + guard_if_dyn, + is_tensor_getset_descriptor, + is_wrapper_or_member_descriptor, + istype, + numpy_operator_wrapper, + proxy_args_kwargs, + str_methods, + tensortype_to_dtype, +) +from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .ctx_manager import EventVariable, StreamVariable +from .dicts import ( + ConstDictVariable, + DefaultDictVariable, + DictViewVariable, + FrozensetVariable, + is_hashable, + SetVariable, +) +from .lists import ( + BaseListVariable, + ListIteratorVariable, + ListVariable, + SizeVariable, + TupleIteratorVariable, + TupleVariable, +) +from .tensor import ( + FakeItemVariable, + supported_comparison_ops, + SymNodeVariable, + TensorVariable, + UnspecializedPythonVariable, +) +from .user_defined import UserDefinedObjectVariable, UserDefinedVariable + + +if TYPE_CHECKING: + # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +log = logging.getLogger(__name__) + + +IN_PLACE_DESUGARING_MAP = { + operator.iadd: operator.add, + operator.isub: operator.sub, + operator.imul: operator.mul, + operator.ifloordiv: operator.floordiv, + operator.itruediv: operator.truediv, + operator.imod: operator.mod, + operator.imatmul: operator.imatmul, + operator.ilshift: operator.lshift, + operator.irshift: operator.rshift, + operator.ipow: operator.pow, + operator.iand: operator.and_, + operator.ior: operator.or_, + operator.ixor: operator.xor, +} + + +_HandlerCallback = Callable[ + ["InstructionTranslator", typing.Any, typing.Any], VariableTracker +] +_TrackersType = Union[type[VariableTracker], tuple[type[VariableTracker], ...]] +polyfill_fn_mapping = { + operator.eq: polyfills.cmp_eq, + operator.ne: polyfills.cmp_ne, + operator.lt: polyfills.cmp_lt, + operator.le: polyfills.cmp_le, + operator.gt: polyfills.cmp_gt, + operator.ge: polyfills.cmp_ge, +} + + +class BuiltinVariable(VariableTracker): + """ + A VariableTracker that represents a built-in value (functions and operators). + A lot of the code here assumes it will be a function object. + + The BuiltinVariable class wraps Python built-in functions (like len, isinstance, etc.) + and operators (like +, -, *, etc.) to enable symbolic execution during tracing. This allows + Dynamo to properly handle these operations when converting Python code to FX graphs while + maintaining correct semantics and enabling optimizations. 
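+ For example, operator.add on two SymNodeVariable inputs is dispatched through the binop handler table and emitted as a node in the FX graph, while a call like abs(-3) on constant arguments is constant-folded.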
+ """ + + _SENTINEL = object() + _nonvar_fields = { + "fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + return cls(value, source=source) + + @staticmethod + @functools.cache + def _constant_fold_functions(): + fns = { + abs, + all, + any, + bool, + callable, + chr, + divmod, + float, + getattr, + int, + len, + max, + min, + ord, + pow, + repr, + round, + str, + str.format, + sum, + type, + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.truth, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.sub, + operator.getitem, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + operator.index, + } + from .tensor import supported_comparison_ops + + fns.update(supported_comparison_ops.values()) + fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) + return fns + + def can_constant_fold_through(self): + return self.fn in self._constant_fold_functions() + + @staticmethod + @functools.cache + def _fx_graph_functions(): + fns = { + operator.abs, + operator.pos, + operator.neg, + operator.not_, + operator.invert, + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.length_hint, + operator.lshift, + operator.rshift, + operator.and_, + operator.or_, + operator.xor, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.getitem, + operator.imod, + operator.iadd, + operator.isub, + operator.ilshift, + operator.irshift, + operator.iand, + operator.ixor, + operator.ior, + } + return fns + + @staticmethod + @functools.cache + def _binops() -> dict[ + Callable[..., object], tuple[list[str], Callable[..., object]] + ]: + # function -> ([forward name, reverse name, in-place name], in-place op) + fns: dict[Callable[..., object], tuple[list[str], Callable[..., object]]] = { + operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), + operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), + operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), + operator.truediv: ( + ["__truediv__", "__rtruediv__", "__itruediv__"], + operator.itruediv, + ), + operator.floordiv: ( + ["__floordiv__", "__rfloordiv__", "__ifloordiv__"], + operator.ifloordiv, + ), + operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), + pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), + operator.lshift: ( + ["__lshift__", "__rlshift__", "__ilshift__"], + operator.ilshift, + ), + operator.rshift: ( + ["__rshift__", "__rrshift__", "__irshift__"], + operator.irshift, + ), + # NB: The follow binary operators are not supported for now, since the + # corresponding magic methods aren't defined on SymInt / SymFloat: + # operator.matmul + # divmod + # operator.and_ + # operator.or_ + # operator.xor + } + return fns + + 
@staticmethod + @functools.cache + def _binop_handlers(): + # Multiple dispatch mechanism defining custom binop behavior for certain type + # combinations. Handlers are attempted in order, and will be used if the type checks + # match. They are expected to have the signature: + # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker + from .functions import BaseUserFunctionVariable, UserFunctionVariable + from .nn_module import NNModuleVariable + from .tensor import supported_const_comparison_ops + from .torch import BaseTorchVariable + from .user_defined import ( + UserDefinedClassVariable, + UserDefinedObjectVariable, + UserDefinedVariable, + ) + + # Override table contains: op_fn -> [list of handlers] + op_handlers: dict[ + Callable[..., object], + list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ], + ] = {} + for ( + op, + (magic_method_names, in_place_op), + ) in BuiltinVariable._binops().items(): + op_handlers[op] = [] + op_handlers[in_place_op] = [] + + forward_name, reverse_name, inplace_name = magic_method_names + + # User-defined args (highest precedence) + def user_defined_handler( + tx, + a, + b, + *, + forward_name=forward_name, + reverse_name=reverse_name, + ): + # Manually handle reversing logic if needed (e.g. call __radd__) + + # TODO: If we expand this to handle tensor args, we need to manually + # handle cases like this: + # + # class A(int): + # def __radd__(self, other): + # print("woof") + # torch.randn(3) + A(3) + # + # In this example, A.__radd__() is not called -> nothing is printed, because + # Tensor.__add__ only does a subtype test against int, ignoring the subclass. + # To be fully correct, we should not call A.__radd__() here, and there may be + # other cases to reason about and add exceptions for. + if isinstance(a, UserDefinedVariable): + return a.call_method(tx, forward_name, [b], {}) + else: + return b.call_method(tx, reverse_name, [a], {}) + + op_handlers[op].append( + ((UserDefinedVariable, VariableTracker), user_defined_handler) + ) + op_handlers[op].append( + ((VariableTracker, UserDefinedVariable), user_defined_handler) + ) + + def user_defined_inplace_handler( + tx: "InstructionTranslator", a, b, *, forward_name=inplace_name + ): + return a.call_method(tx, forward_name, [b], {}) + + op_handlers[in_place_op].append( + ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler) + ) + + # Dynamic shape args + def dynamic_handler(tx: "InstructionTranslator", a, b, *, fn=op): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", fn, *proxy_args_kwargs([a, b], {}) + ), + ) + + op_handlers[op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # NB: Prefer out-of-place op when calling in-place op to generate valid graph + op_handlers[in_place_op].append( + ((SymNodeVariable, VariableTracker), dynamic_handler) + ) + op_handlers[in_place_op].append( + ((VariableTracker, SymNodeVariable), dynamic_handler) + ) + + # Special cases - lower precedence but still prefer these over constant folding + + # List-like addition (e.g. 
[1, 2] + [3, 4]) + def tuple_add_handler(tx: "InstructionTranslator", a, b): + return TupleVariable([*a.items, *b.unpack_var_sequence(tx)]) + + def size_add_handler(tx: "InstructionTranslator", a, b): + return SizeVariable([*a.items, *b.unpack_var_sequence(tx)]) + + list_like_addition_handlers: list[ + tuple[ + tuple[ + type[VariableTracker], + _TrackersType, + ], + _HandlerCallback, + ] + ] = [ + # NB: Prefer the tuple-specific logic over base logic because of + # some SizeVariable weirdness. Specifically, the tuple-specific logic + # drops the subclass type (e.g. SizeVariable) and returns TupleVariables. + ( + (SizeVariable, SizeVariable), + size_add_handler, + ), + ( + (SizeVariable, TupleVariable), + size_add_handler, + ), + ( + (TupleVariable, SizeVariable), + size_add_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ( + (ConstantVariable, TupleVariable), + lambda tx, a, b: TupleVariable( + [ + *a.unpack_var_sequence(tx), + *b.items, + ], + ), + ), + ( + ( + ListVariable, + (BaseListVariable, ConstantVariable, ListIteratorVariable), + ), + lambda tx, a, b: ListVariable( + [*a.items, *b.unpack_var_sequence(tx)], + mutation_type=ValueMutationNew(), + ), + ), + ( + (BaseListVariable, BaseListVariable), + lambda tx, a, b: type(a)( + [ + *a.items, + *b.items, + ] + ), + ), + ] + op_handlers[operator.add].extend(list_like_addition_handlers) + + def list_iadd_handler(tx: "InstructionTranslator", a, b): + if a.is_immutable() or not b.has_unpack_var_sequence(tx): + # Handler doesn't apply + return None + + seq = b.unpack_var_sequence(tx) + tx.output.side_effects.mutation(a) + a.items.extend(seq) + return a + + list_like_iadd_handlers: list[ + tuple[ + tuple[type[VariableTracker], type[VariableTracker]], + _HandlerCallback, + ] + ] = [ + ( + (ListVariable, VariableTracker), + list_iadd_handler, + ), + ( + (TupleVariable, TupleVariable), + tuple_add_handler, + ), + ( + (TupleVariable, ConstantVariable), + tuple_add_handler, + ), + ] + op_handlers[operator.iadd].extend(list_like_iadd_handlers) + + # List-like expansion (e.g. [1, 2, 3] * 3) + def expand_list_like(tx: "InstructionTranslator", lst, const): + if isinstance(lst, ConstantVariable): + lst, const = const, lst + return lst.__class__( + items=lst.items * const.as_python_constant(), + mutation_type=ValueMutationNew(), + ) + + list_like_expansion_handlers: list[ + tuple[ + tuple[type[VariableTracker], type[VariableTracker]], + _HandlerCallback, + ] + ] = [ + ((ListVariable, ConstantVariable), expand_list_like), + ((TupleVariable, ConstantVariable), expand_list_like), + ((ConstantVariable, ListVariable), expand_list_like), + ((ConstantVariable, TupleVariable), expand_list_like), + ] + op_handlers[operator.mul].extend(list_like_expansion_handlers) + + def create_cmp_op_handlers(op): + def compare_by_value(tx: "InstructionTranslator", a, b): + return ConstantVariable(op(a.value, b.value)) + + result: list[ + tuple[ + tuple[ + _TrackersType, + _TrackersType, + ], + _HandlerCallback, + ] + ] = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in polyfill_fn_mapping: + # For constants, speedup the comparison instead of using + # polyfill. Removing this line causes major regression for pr + # time benchmark - add_loop_eager. 
+ result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + op_var = BuiltinVariable(op) + # Special handling of SymNode variable + result.extend( + [ + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handler(tx, a, b): + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfill_fn_mapping[op]), [a, b], {} + ) + + result.append(((VariableTracker, VariableTracker), handler)) + return result + + result = [((ConstantVariable, ConstantVariable), compare_by_value)] + + if op in supported_const_comparison_ops.values() and op.__name__.startswith( + "is_" + ): + # Tensor is None, List is not None, etc + none_result = op(object(), None) + + def never(tx: "InstructionTranslator", a, b): + return ConstantVariable(none_result) + + obj_op_none = never + none_op_obj = never + + types_that_are_never_none = ( + TensorVariable, + SymNodeVariable, + NNModuleVariable, + BaseListVariable, + UserDefinedVariable, + BaseUserFunctionVariable, + ConstDictVariable, + BaseTorchVariable, + ) + result.extend( + [ + ( + (types_that_are_never_none, ConstantVariable), + obj_op_none, + ), + ( + (ConstantVariable, types_that_are_never_none), + none_op_obj, + ), + ] + ) + + op_var = BuiltinVariable(op) + result.extend( + [ + ( + ( + (UserFunctionVariable, BuiltinVariable), + (UserFunctionVariable, BuiltinVariable), + ), + lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)), + ), + ( + ( + NNModuleVariable, + NNModuleVariable, + ), + lambda tx, a, b: ConstantVariable( + op( + tx.output.get_submodule(a.module_key), + tx.output.get_submodule(b.module_key), + ) + ), + ), + ( + (UserDefinedObjectVariable, UserDefinedObjectVariable), + compare_by_value, + ), + ( + (UserDefinedClassVariable, UserDefinedClassVariable), + compare_by_value, + ), + ( + ( + (StreamVariable, EventVariable, ConstantVariable), + (StreamVariable, EventVariable, ConstantVariable), + ), + compare_by_value, + ), + ( + (TensorVariable, VariableTracker), + op_var._comparison_with_tensor, + ), + ( + (VariableTracker, TensorVariable), + op_var._comparison_with_tensor, + ), + ( + (SymNodeVariable, VariableTracker), + op_var._comparison_with_symnode, + ), + ( + (VariableTracker, SymNodeVariable), + op_var._comparison_with_symnode, + ), + ] + ) + + def handle_is(tx: "InstructionTranslator", left, right): + # If the two objects are of different type, we can safely return False + # and True for `is` and `is not`, respectively + if type(left) is not type(right): + return ConstantVariable.create(op.__name__ != "is_") + if left is right: + return ConstantVariable.create(op(left, right)) + if ( + istype(left, variables.ExceptionVariable) + and istype(right, variables.ExceptionVariable) + and left.exc_type is not right.exc_type + ): + return ConstantVariable.create(op(left, right)) + + result.append(((VariableTracker, VariableTracker), handle_is)) + + return result + + for op in supported_comparison_ops.values(): + assert callable(op) + assert op not in op_handlers + op_handlers[op] = create_cmp_op_handlers(op) + + return op_handlers + + @staticmethod + def _find_binop_handler(op, a_type, b_type): + handlers = BuiltinVariable._binop_handlers().get(op) + if handlers is None: + return None + + matches = [] + for (type1, type2), handler in handlers: + if issubclass(a_type, type1) and issubclass(b_type, type2): + matches.append(handler) + return matches + + def can_insert_in_graph(self): + return self.fn in 
self._fx_graph_functions() + + def __init__(self, fn, **kwargs) -> None: + super().__init__(**kwargs) + self.fn = fn + + def __repr__(self) -> str: + if self.fn is None: + name = "None" + else: + name = self.fn.__name__ + + return f"{self.__class__.__name__}({name})" + + def as_python_constant(self): + return self.fn + + def as_proxy(self): + DTYPE = { + bool: torch.bool, + int: torch.int64, + float: torch.float64, + } + if self.fn in DTYPE: + return DTYPE[self.fn] + return super().as_proxy() + + def reconstruct(self, codegen: "PyCodegen"): + name = self.fn.__name__ + assert self.fn.__module__ == "builtins" + assert name not in codegen.tx.f_globals, "shadowed global" + codegen.append_output(codegen.create_load_global(name, add=True)) + + def constant_args(self, *args, **kwargs): + return check_constant_args(args, kwargs) + + def tensor_args(self, *args): + any_tensor = False + for arg in args: + if isinstance(arg, variables.GetAttrVariable): + return False + any_tensor = any_tensor or isinstance(arg, variables.TensorVariable) + return any_tensor + + def tensor_args_type(self, arg_types): + any_tensor = False + for arg_type in arg_types: + if issubclass(arg_type, variables.GetAttrVariable): + return False + any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable) + return any_tensor + + def python_and_tensor_constant_only(self, *args, **kwargs): + tensor_args = [] + non_tensor_args = [] + for i in itertools.chain(args, kwargs.values()): + if isinstance(i, variables.TensorVariable): + tensor_args.append(i) + else: + non_tensor_args.append(i) + return all( + is_constant_source(t.source) if t.source is not None else False + for t in tensor_args + ) and self.constant_args(*non_tensor_args) + + @staticmethod + def unwrap_unspec_args_kwargs(args, kwargs): + return [x.as_python_constant() for x in args], { + k: v.as_python_constant() for k, v in kwargs.items() + } + + def has_constant_handler(self, args, kwargs): + return self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ) + + @staticmethod + def _make_handler(fn, arg_types: list[type], has_kwargs: bool): + from .lazy import LazyVariableTracker + + obj = BuiltinVariable(fn) + handlers: list[_HandlerCallback] = [] + + if any(issubclass(t, LazyVariableTracker) for t in arg_types): + return lambda tx, args, kwargs: obj.call_function( + tx, [v.realize() for v in args], kwargs + ) + + if inspect.isclass(fn) and ( + issubclass(fn, Exception) + # GeneratorExit doesn't inherit from Exception + # >>> issubclass(GeneratorExit, Exception) + # False + or fn is GeneratorExit + ): + + def create_exception_class_object( + tx: "InstructionTranslator", args, kwargs + ): + if fn is AssertionError and not all( + isinstance(x, variables.ConstantVariable) + and isinstance(x.value, str) + for x in args + ): + unimplemented_v2( + gb_type="assert with non-string message", + context=str(args), + explanation="Dynamo only supports asserts with string messages", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + return variables.ExceptionVariable(fn, args, **kwargs) + + return create_exception_class_object + + if obj.can_insert_in_graph() and not ( + fn is operator.getitem + and not issubclass(arg_types[0], variables.TensorVariable) + ): + if obj.tensor_args_type(arg_types): + return obj._handle_insert_op_in_graph + elif has_kwargs: + # need runtime check for kwargs + handlers.append(obj._handle_insert_op_in_graph) + + # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.) 
+ # NB: Tensor args are handled above and not here + if len(arg_types) == 2 and not has_kwargs: + # Try to find a handler for the arg types; otherwise, fall through to constant handler + binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types) + if not binop_handlers: + pass + elif len(binop_handlers) == 1: + (binop_handler,) = binop_handlers + handlers.append(lambda tx, args, _: binop_handler(tx, *args)) + else: + + def call_binop_handlers(tx: "InstructionTranslator", args, _): + for fn in binop_handlers: + rv = fn(tx, *args) + if rv: + return rv + + handlers.append(call_binop_handlers) + + self_handler = getattr(obj, f"call_{fn.__name__}", None) + if self_handler: + + def call_self_handler(tx: "InstructionTranslator", args, kwargs): + try: + result = self_handler(tx, *args, **kwargs) + if result is not None: + return result + except TypeError: + # Check if binding is bad. inspect signature bind is expensive. + # So check only when handler call fails. + try: + inspect.signature(self_handler).bind(tx, *args, **kwargs) + except TypeError as e: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + log.warning( + "incorrect arg count %s %s and no constant handler", + self_handler, + e, + ) + unimplemented_v2( + gb_type="invalid call to builtin op handler", + context=f"invalid args to {self_handler}: {args} {kwargs}", + explanation=f"Encountered TypeError when trying to handle op {fn.__name__}", + hints=[*graph_break_hints.DIFFICULT], + ) + else: + raise + except Unsupported as exc: + has_constant_handler = obj.has_constant_handler(args, kwargs) + if not has_constant_handler: + raise + # Actually, we will handle this just fine + exc.remove_from_stats() + + handlers.append(call_self_handler) + + if obj.can_constant_fold_through(): + if ( + all(issubclass(x, ConstantVariable) for x in arg_types) + and not has_kwargs + ): + + def constant_fold_handler(tx: "InstructionTranslator", args, kwargs): + # fast path + try: + res = fn( + *[x.as_python_constant() for x in args], + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented_v2( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + return VariableTracker.build(tx, res) + + else: + + def constant_fold_handler(tx: "InstructionTranslator", args, kwargs): + # path with a runtime check + if check_unspec_or_constant_args(args, kwargs): + try: + res = fn( + *[x.as_python_constant() for x in args], + **{ + k: v.as_python_constant() for k, v in kwargs.items() + }, + ) + except AsPythonConstantNotImplementedError as exc: + unimplemented_v2( + gb_type="constant fold exception", + context=f"attempted to run function {fn} with arguments {args}", + explanation="Encountered exception when attempting to constant fold.", + hints=[*graph_break_hints.DYNAMO_BUG], + from_exc=exc, + ) + except Exception as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) + return VariableTracker.build(tx, res) + + handlers.append(constant_fold_handler) + + def call_unimplemented_v2(args): + real_arg_types = [arg.python_type_name() for arg in args] + unimplemented_v2( + gb_type="Failed to trace builtin operator", + context=f"builtin 
{fn.__name__} {arg_types} {has_kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` " + f"with argument types {real_arg_types} (has_kwargs {has_kwargs})", + hints=[ + f"Avoid calling builtin `{fn.__name__}` with argument types {real_arg_types}. " + f"Consider using an equivalent alternative function/method to `{fn.__name__}`.", + "If you are attempting to call a logging function (e.g. `print`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please report an issue to PyTorch.", + ], + ) + + if len(handlers) == 0: + return lambda tx, args, kwargs: call_unimplemented_v2(args) + elif len(handlers) == 1: + (handler,) = handlers + + def builtin_dispatch(tx: "InstructionTranslator", args, kwargs): + rv = handler(tx, args, kwargs) + if rv: + return rv + call_unimplemented_v2(args) + + else: + + def builtin_dispatch(tx: "InstructionTranslator", args, kwargs): + for fn in handlers: + rv = fn(tx, args, kwargs) + if rv: + return rv + call_unimplemented_v2(args) + + return builtin_dispatch + + def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs): + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + if kwargs and not self.tensor_args(*args, *kwargs.values()): + return + + # insert handling for torch function here + from .builder import SourcelessBuilder + from .torch_function import ( + BUILTIN_TO_TENSOR_FN_MAP, + BUILTIN_TO_TENSOR_RFN_MAP, + can_dispatch_torch_function, + dispatch_torch_function, + ) + + if can_dispatch_torch_function(tx, args, kwargs): + # Only remap the fn to tensor methods if we aren't exporting + # export serde does not handle method descriptors today + if not tx.export: + # Use sourceless builder, we built the map ourselves + if not isinstance(args[0], TensorVariable): + if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: + func = BUILTIN_TO_TENSOR_RFN_MAP[self.fn] + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + + tmp = args[0] + # swap args and call reverse version of func + args[0] = args[1] + args[1] = tmp + else: + func = BUILTIN_TO_TENSOR_FN_MAP[self.fn] + else: + func = self.fn + + fn_var = SourcelessBuilder.create(tx, func) + + return dispatch_torch_function(tx, fn_var, args, kwargs) + + fn = self.fn + try: + # Constant fold for constant tensor and python constants + if self.python_and_tensor_constant_only(*args, **kwargs): + from ..bytecode_transformation import unique_id + from .functions import invoke_and_store_as_constant + + return invoke_and_store_as_constant( + tx, fn, unique_id(fn.__name__), args, kwargs + ) + + if fn in IN_PLACE_DESUGARING_MAP and isinstance( + args[0], variables.ConstantVariable + ): + # In-place operators like += usually mustate tensor + # values, but in the edge case of immutable values they + # re-bind the variable. + # + # The easiest way to keep the graph consistent in this + # scenario is to de-sugar eagerly. + fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]] + + if fn is operator.getitem and isinstance(args[1], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. 
Rewrite as a regular torch op which will + # trace fine + fn, args = ( + torch.select, + [ + args[0], + variables.ConstantVariable.create(0), + args[1], + ], + ) + + # Interaction between ndarray and tensors: + # We prefer the tensor op whenever there are tensors involved + if check_numpy_ndarray_args(args, kwargs) and not any( + type(arg) == variables.TensorVariable for arg in args + ): + proxy = tx.output.create_proxy( + "call_function", + numpy_operator_wrapper(fn), + *proxy_args_kwargs(args, kwargs), + ) + + return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy) + + if ( + fn is operator.eq + and len(args) == 2 + and isinstance(args[0], variables.TensorVariable) + ): + # Dynamo expects `__eq__` str while operator.eq gives just `eq` + # TODO - supporting all comparison operators could also work but + # it fails lots of tests because graph str changes. + return args[0].call_method(tx, "__eq__", args[1:], kwargs) + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs(args, kwargs), + ) + if any(isinstance(arg, FakeItemVariable) for arg in args): + return wrap_fx_proxy_cls( + FakeItemVariable, + tx, + proxy, + ) + elif check_unspec_python_args(args, kwargs): + _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) + raw_value = fn(*_args, **_kwargs) + + need_unwrap = any( + x.need_unwrap + for x in itertools.chain(args, kwargs.values()) + if isinstance(x, variables.UnspecializedPythonVariable) + ) + + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx, + proxy, + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + elif all(isinstance(x, SymNodeVariable) for x in args): + return SymNodeVariable.create(tx, proxy, None) + else: + # Work around for vision_maskrcnn due to precision difference + # specialize the dividend when float divide by tensor + if fn is operator.truediv and isinstance( + args[0], variables.UnspecializedPythonVariable + ): + args[0] = args[0].as_python_constant() + return wrap_fx_proxy(tx, proxy) + + except NotImplementedError: + unimplemented_v2( + gb_type="unimplemented builtin op on tensor arguments", + context=f"partial tensor op: {self} {args} {kwargs}", + explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + call_function_handler_cache: dict[ + tuple[object, ...], + Callable[ + [ + "InstructionTranslator", + Sequence[VariableTracker], + dict[str, VariableTracker], + ], + VariableTracker, + ], + ] = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence["VariableTracker"], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + key: tuple[object, ...] 
+ if kwargs: + kwargs = {k: v.realize() for k, v in kwargs.items()} + key = (self.fn, *(type(x) for x in args), True) + else: + key = (self.fn, *(type(x) for x in args)) + + handler = self.call_function_handler_cache.get(key) + if not handler: + self.call_function_handler_cache[key] = handler = self._make_handler( + self.fn, [type(x) for x in args], bool(kwargs) + ) + return handler(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if self.fn is object and name == "__setattr__": + assert len(args) == 3 + assert len(kwargs) == 0 + obj, name_var, val = args + obj = obj.realize() + if ( + isinstance(obj, UserDefinedObjectVariable) + and tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + return obj.method_setattr_standard(tx, name_var, val) + + if name == "__new__": + # Supported __new__ methods + if self.fn is object and len(args) == 1: + assert len(kwargs) == 0 + return tx.output.side_effects.track_new_user_defined_object( + self, args[0], args[1:] + ) + + if self.fn is dict and len(args) == 1 and not kwargs: + dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is dict: + return dict_vt + # We don't have to set the underlying dict_vt in + # UserDefinedDictVariable because it will be set to empty + # ConstDictVariableTracker in the constructor. + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if ( + self.fn is tuple + and len(args) == 2 + and args[1].has_unpack_var_sequence(tx) + and not kwargs + ): + if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: + init_args = args[1].unpack_var_sequence(tx) + return variables.TupleVariable( + init_args, mutation_type=ValueMutationNew() + ) + + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if self.fn is list: + list_vt = ListVariable([], mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is list: + return list_vt + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + + if self.fn is object and name == "__init__": + # object.__init__ is a no-op + return variables.ConstantVariable(None) + + if self.fn is dict and name == "fromkeys": + return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) + + if self.fn is dict: + resolved_fn = getattr(self.fn, name) + if resolved_fn in dict_methods: + if isinstance(args[0], variables.UserDefinedDictVariable): + return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], variables.ConstDictVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + if self.fn is str and len(args) >= 1: + resolved_fn = getattr(self.fn, name) + if resolved_fn in str_methods: + if isinstance(args[0], ConstantVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + return super().call_method(tx, name, args, kwargs) + + def _call_int_float(self, tx: "InstructionTranslator", arg): + # Handle cases like int(torch.seed()) + # Also handle sym_float to sym_int cases + if isinstance(arg, (SymNodeVariable, variables.TensorVariable)): + if isinstance(arg, variables.TensorVariable): + item = arg.call_method(tx, "item", [], {}) + else: + item = arg + fn_ = sym_int if self.fn is int else sym_float + from torch._dynamo.variables.builder import 
wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (item.as_proxy(),), + {}, + ), + ) + + call_int = _call_int_float + call_float = _call_int_float + + def call_bool(self, tx: "InstructionTranslator", arg): + # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. + # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 + if isinstance(arg, SymNodeVariable): + # Note that we delay specializing on symbolic values to avoid + # unnecessary guards. Specialization will happen later if, e.g., the + # resulting boolean is used for branching. + if isinstance(arg.sym_num, torch.SymBool): + return arg + + # Emulate `nb_bool` of int/float objects + # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944 + # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882 + assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) + return SymNodeVariable.create(tx, arg.as_proxy() != 0) + + # TODO handle more cases and merge this with `generic_jump`. + + def call_str(self, tx: "InstructionTranslator", arg): + # Handle `str` on a user defined function or object + if isinstance(arg, (variables.UserFunctionVariable)): + return variables.ConstantVariable.create(value=str(arg.fn)) + elif isinstance(arg, (variables.UserDefinedObjectVariable)): + # Check if object has __str__ method + if hasattr(arg.value, "__str__"): + str_method = arg.value.__str__ + elif hasattr(arg.value, "__repr__"): + # account for __repr__ functions when __str__ is absent + str_method = arg.value.__repr__ + else: + unimplemented_v2( + gb_type="failed to call str() on user defined object", + context=str(arg), + explanation="User defined object has no __str__ or __repr__ method", + hints=[*graph_break_hints.USER_ERROR], + ) + + if type(arg.value).__str__ is object.__str__: + # Rely on the object str method + try: + return variables.ConstantVariable.create(value=str_method()) + except AttributeError: + # Graph break + return + elif is_wrapper_or_member_descriptor(str_method): + unimplemented_v2( + gb_type="Attempted to call a str() method implemented in C/C++", + context="", + explanation=f"{type(arg.value)} has a C/C++ based str method. 
This is not supported.", + hints=["Write the str method in Python"], + ) + else: + # Overrides for custom str method + # Pass method as function to call tx.inline_user_function_return + bound_method = str_method.__func__ # type: ignore[attr-defined] + + try: + # Only supports certain function types + user_func_variable = variables.UserFunctionVariable(bound_method) + except AssertionError as e: + # Won't be able to do inline the str method, return to avoid graph break + log.warning("Failed to create UserFunctionVariable: %s", e) + return + + # Inline the user function + return tx.inline_user_function_return(user_func_variable, [arg], {}) + elif isinstance(arg, (variables.ExceptionVariable,)): + if len(arg.args) == 0: + value = f"{arg.exc_type}" + else: + value = ", ".join(a.as_python_constant() for a in arg.args) + return variables.ConstantVariable.create(value=value) + + def _call_min_max(self, tx: "InstructionTranslator", *args): + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + return self._call_min_max_seq(tx, items) + elif len(args) == 2: + return self._call_min_max_binary(tx, args[0], args[1]) + elif len(args) > 2: + return self._call_min_max_seq(tx, args) + + def _call_min_max_seq(self, tx: "InstructionTranslator", items): + assert len(items) > 0 + if len(items) == 1: + return items[0] + + return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) + + def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return + if self.tensor_args(a, b): + if not isinstance(a, variables.TensorVariable): + a, b = b, a + assert isinstance(a, variables.TensorVariable) + + # result of an item call is a scalar convert to a tensor + if isinstance(a, FakeItemVariable): + a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function( + tx, [a], {} + ) + + # Dynamic input does not get resolved, rather, gets stored as call_function + if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + type(a), + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.fn, + *proxy_args_kwargs([a, b], {}), + ), + ) + + # convert min/max to torch ops + if b.is_python_constant(): + fn: VariableTracker + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + fn = variables.NumpyVariable(np.clip) + else: + fn = variables.TorchInGraphFunctionVariable(torch.clamp) + kwargs = {"min": b} if (self.fn is max) else {"max": b} + result = fn.call_function(tx, [a], kwargs) + else: + if isinstance(a, variables.NumpyNdarrayVariable): + import numpy as np + + np_fn = {max: np.maximum, min: np.minimum}[self.fn] + fn = variables.NumpyVariable(np_fn) + else: + torch_fn = {max: torch.maximum, min: torch.minimum}[self.fn] + fn = variables.TorchInGraphFunctionVariable(torch_fn) + result = fn.call_function(tx, [a, b], {}) + + # return unspec if both a, b are unspec or const + if all( + isinstance( + i, + ( + variables.UnspecializedPythonVariable, + variables.ConstantVariable, + ), + ) + for i in [a, b] + ): + if any(isinstance(val, FakeItemVariable) for val in [a, b]): + return variables.FakeItemVariable.from_tensor_variable(result) + + if b.is_python_constant(): + raw_b = b.as_python_constant() + else: + raw_b = b.raw_value + if self.fn is max: + raw_res = max(a.raw_value, raw_b) + else: + raw_res = 
min(a.raw_value, raw_b) + + need_unwrap = any( + x.need_unwrap + for x in [a, b] + if isinstance(x, variables.UnspecializedPythonVariable) + ) + return variables.UnspecializedPythonVariable.from_tensor_variable( + result, raw_res, need_unwrap + ) + # otherwise return tensor + else: + return result + elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): + py_fn = torch.sym_max if self.fn is max else torch.sym_min + proxy = tx.output.create_proxy( + "call_function", py_fn, *proxy_args_kwargs([a, b], {}) + ) + return SymNodeVariable.create(tx, proxy, None) + elif isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + value = self.fn( + a.as_python_constant(), + b.as_python_constant(), + ) + return ConstantVariable(value) + + call_min = _call_min_max + call_max = _call_min_max + + def call_abs(self, tx: "InstructionTranslator", arg: "VariableTracker"): + # Call arg.__abs__() + abs_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__abs__")], {} + ) + return abs_method.call_function(tx, [], {}) + + def call_pos(self, tx: "InstructionTranslator", arg: "VariableTracker"): + # Call arg.__pos__() + pos_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__pos__")], {} + ) + return pos_method.call_function(tx, [], {}) + + def call_index(self, tx: "InstructionTranslator", arg: "VariableTracker"): + if isinstance(arg, variables.TensorVariable): + unimplemented_v2( + gb_type="unsupported index(Tensor)", + context="", + explanation="Dynamo does not support tracing builtin index() on a Tensor", + hints=[], + ) + + arg = guard_if_dyn(arg) + constant_value = operator.index(arg) + return variables.ConstantVariable.create(constant_value) + + def call_round(self, tx: "InstructionTranslator", arg, *args, **kwargs): + # Call arg.__round__() + round_method = BuiltinVariable(getattr).call_function( + tx, [arg, ConstantVariable.create("__round__")], {} + ) + return round_method.call_function(tx, args, kwargs) + + def call_range(self, tx: "InstructionTranslator", *args): + if check_unspec_or_constant_args(args, {}): + return variables.RangeVariable(args) + elif self._dynamic_args(*args): + args = tuple( + variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args + ) + return variables.RangeVariable(args) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args, **kwargs): + return any(isinstance(x, SymNodeVariable) for x in args) or any( + isinstance(x, SymNodeVariable) for x in kwargs.values() + ) + + def call_slice(self, tx: "InstructionTranslator", *args): + return variables.SliceVariable(args) + + def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + ) + + # NOTE must handle IteratorVariable separately! 
+ def _call_iter_tuple_list( + self, tx: "InstructionTranslator", obj=None, *args, **kwargs + ): + assert not isinstance(obj, variables.IteratorVariable) + + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + cls = variables.BaseListVariable.cls_for(self.fn) + if obj is None: + return cls( + [], + mutation_type=ValueMutationNew(), + ) + elif obj.has_unpack_var_sequence(tx): + if obj.source and not is_constant_source(obj.source): + if isinstance(obj, TupleIteratorVariable): + install_guard( + obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) + ) + else: + if ( + getattr(obj, "source", False) + and isinstance(obj, ConstDictVariable) + and not istype(obj, SetVariable) + ): + tx.output.guard_on_key_order.add(obj.source) + + if isinstance(obj, variables.MappingProxyVariable): + # This could be an overguarding, but its rare to iterate + # through a mapping proxy and not use the keys. + install_guard( + obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK) + ) + elif not isinstance(obj, variables.UnspecializedNNModuleVariable): + # Prevent calling __len__ method for guards, the tracing + # of __iter__ will insert the right guards later. + install_guard( + obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + + return cls( + list(obj.unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + + def _call_iter_tuple_generator(self, tx, obj, *args, **kwargs): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), # exhaust generator + mutation_type=ValueMutationNew(), + ) + + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + elif isinstance(obj, variables.LocalGeneratorObjectVariable): + return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + + def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): + if isinstance(obj, variables.IteratorVariable): + ret = obj + else: + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + + if ret is None: + # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. + # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call + # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. 
+ return obj.call_method(tx, "__iter__", args, kwargs) + return ret + + call_tuple = _call_tuple_list + call_list = _call_tuple_list + + def call_callable(self, tx: "InstructionTranslator", arg): + from .functions import BaseUserFunctionVariable, FunctoolsPartialVariable + from .nn_module import NNModuleVariable + + if isinstance( + arg, + ( + variables.UserDefinedClassVariable, + BaseUserFunctionVariable, + FunctoolsPartialVariable, + NNModuleVariable, + ), + ): + return variables.ConstantVariable.create(True) + elif isinstance(arg, UserDefinedVariable): + return variables.ConstantVariable.create(callable(arg.value)) + elif isinstance( + arg, + ( + ConstantVariable, + SymNodeVariable, + TensorVariable, + ListVariable, + TupleVariable, + ListIteratorVariable, + ), + ): + return variables.ConstantVariable.create(False) + + def call_cast(self, _, *args, **kwargs): + if len(args) == 2: + return args[1] + + unimplemented_v2( + gb_type="bad args to builtin cast()", + context=f"got args {args} {kwargs}", + explanation="Dynamo expects exactly 2 args to builtin cast().", + hints=["Ensure your call to cast() has exactly 2 arguments."], + ) + + def call_dict(self, tx: "InstructionTranslator", *args, **kwargs): + return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + @staticmethod + def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [VariableTracker.build(tx, user_cls), *args], + kwargs, + ) + + @staticmethod + def call_custom_dict_fromkeys( + tx: "InstructionTranslator", user_cls, *args, **kwargs + ): + assert user_cls in {dict, OrderedDict, defaultdict} + if kwargs: + # Only `OrderedDict.fromkeys` accepts `value` passed by keyword + assert user_cls is OrderedDict + assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs + args = (*args, kwargs.pop("value")) + if len(args) == 0: + raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0") # type: ignore[arg-type] + if len(args) == 1: + args = (*args, ConstantVariable.create(None)) + assert len(args) == 2 + arg, value = args + DictVariableType = ( + ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable + ) + + if isinstance(arg, dict): + arg = [ConstantVariable.create(k) for k in arg.keys()] + return DictVariableType( + dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew() + ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + dict.fromkeys(keys, value), + user_cls, + mutation_type=ValueMutationNew(), + ) + + unimplemented_v2( + gb_type="failed to call dict.fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + "arguments could not be automatically converted to a list, " + "or some dict key is not hashable.", + hints=[ + "Manually convert the argument to a list.", + "Ensure all keys are hashable.", + ], + ) + + def call_set(self, tx: "InstructionTranslator", *args, **kwargs): + # Can we merge this implementation and call_dict's one? 
+ assert not kwargs + if not args: + return SetVariable([], mutation_type=ValueMutationNew()) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"set() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if isinstance(arg, variables.SetVariable): + return arg.clone(mutation_type=ValueMutationNew()) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return SetVariable(items, mutation_type=ValueMutationNew()) + elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, KeysView + ): + iter_fn = arg.var_getattr(tx, "__iter__") + if isinstance(iter_fn, variables.UserMethodVariable): + out = tx.inline_user_function_return(iter_fn, args, kwargs) + if isinstance(out, SetVariable): + return out + return BuiltinVariable(set).call_set(tx, out) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin set()")], + ) + + def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): + assert not kwargs + if not args: + return FrozensetVariable([]) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + ConstantVariable.create( + f"frozenset() takes 1 positional argument but {len(args)} were given" + ) + ], + ) + arg = args[0] + if isinstance(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) + return FrozensetVariable(items) + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable.create("failed to construct builtin frozenset()")], + ) + + def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): + if kwargs: + assert len(kwargs) == 1 and "strict" in kwargs + strict = kwargs.pop("strict", False) + args = [ + arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg + for arg in args + ] + return variables.ZipVariable( + args, strict=strict, mutation_type=ValueMutationNew() + ) + + def call_len(self, tx: "InstructionTranslator", *args, **kwargs): + try: + return args[0].call_method(tx, "__len__", args[1:], kwargs) + except AttributeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) + + def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs): + return args[0].call_method(tx, "__getitem__", args[1:], kwargs) + + def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type): + try: + arg_type = arg.python_type() + except NotImplementedError: + unimplemented_v2( + gb_type="builtin isinstance() cannot determine type of argument", + context=f"isinstance({arg}, {isinstance_type})", + explanation=f"Dynamo doesn't have a rule to determine the type of argument {arg}", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + isinstance_type = isinstance_type.as_python_constant() + + if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: + + def _tensor_isinstance(tensor_var, tensor_type): + def check_type(ty): + if ty not in tensortype_to_dtype: + example_val = arg.as_proxy().node.meta["example_value"] + if ( + is_traceable_wrapper_subclass(example_val) + and ty is torch.nn.parameter.Parameter + ): + # N.B: we are calling isinstance directly on the example value. + # torch.nn.Parameter has a meta-class that overrides __isinstance__, + # the isinstance check here allows us to invoke that logic. 
+ return isinstance(example_val, ty) + else: + return issubclass(arg.python_type(), ty) + + dtypes = tensortype_to_dtype[ty] + return arg.dtype in dtypes + + if type(tensor_type) is tuple: + return any(check_type(ty) for ty in tensor_type) + else: + return check_type(tensor_type) + + return variables.ConstantVariable.create( + _tensor_isinstance(arg, isinstance_type) + ) + # UserDefinedObject with C extensions can have torch.Tensor attributes, + # so break graph. + if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, types.MemberDescriptorType + ): + unimplemented_v2( + gb_type="isinstance() called on user defined object with C extensions", + context=f"isinstance({arg}, {isinstance_type})", + explanation="User-defined object with C extensions can have torch.Tensor " + "attributes; intentionally graph breaking.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + # handle __instancecheck__ defined in user class + if ( + isinstance(arg, variables.UserDefinedObjectVariable) + and "__instancecheck__" in isinstance_type.__class__.__dict__ + ): + return variables.ConstantVariable.create( + isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) + ) + + if isinstance(arg, variables.UserDefinedExceptionClassVariable): + return ConstantVariable.create(isinstance(arg_type, isinstance_type)) + + isinstance_type_tuple: tuple[type, ...] + if isinstance(isinstance_type, type) or callable( + # E.g. isinstance(obj, typing.Sequence) + getattr(isinstance_type, "__instancecheck__", None) + ): + isinstance_type_tuple = (isinstance_type,) + elif sys.version_info >= (3, 10) and isinstance( + isinstance_type, types.UnionType + ): + isinstance_type_tuple = isinstance_type.__args__ + elif isinstance(isinstance_type, tuple) and all( + isinstance(tp, type) or callable(getattr(tp, "__instancecheck__", None)) + for tp in isinstance_type + ): + isinstance_type_tuple = isinstance_type + else: + raise_observed_exception( + TypeError, + tx, + args=[ + "isinstance() arg 2 must be a type, a tuple of types, or a union" + ], + ) + + try: + # NB: `isinstance()` does not call `__subclasscheck__` but use `__instancecheck__`. + # But usually `isinstance(obj, type_info)` and `issubclass(type(obj), type_info)` gives + # the same result. + # WARNING: This might run arbitrary user code `__subclasscheck__` and we did not trace + # through it. This is a limitation of the current implementation. + # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it + # might not be a big issue and we trade off it for performance. + val = issubclass(arg_type, isinstance_type_tuple) + except TypeError: + val = arg_type in isinstance_type_tuple + return variables.ConstantVariable.create(val) + + def call_issubclass(self, tx: "InstructionTranslator", left_ty, right_ty): + """Checks if first arg is subclass of right arg""" + try: + left_ty_py = left_ty.as_python_constant() + right_ty_py = right_ty.as_python_constant() + except NotImplementedError: + unimplemented_v2( + gb_type="issubclass() with non-constant arguments", + context=f"issubclass({left_ty}, {right_ty})", + explanation="issubclass() with non-constant arguments not supported.", + hints=[ + "Make sure your arguments are types.", + *graph_break_hints.USER_ERROR, + ], + ) + + # WARNING: This might run arbitrary user code `__subclasscheck__`. + # See the comment in call_isinstance above. 
+ return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) + + def call_super(self, tx: "InstructionTranslator", a, b): + return variables.SuperVariable(a, b) + + def call_next(self, tx: "InstructionTranslator", arg: VariableTracker): + try: + return arg.next_variable(tx) + except Unsupported as ex: + if isinstance(arg, variables.BaseListVariable): + ex.remove_from_stats() + return arg.items[0] + raise + + def call_hasattr(self, tx: "InstructionTranslator", obj, attr): + if attr.is_python_constant(): + name = attr.as_python_constant() + if isinstance(obj, variables.BuiltinVariable): + return variables.ConstantVariable(hasattr(obj.fn, name)) + return obj.call_obj_hasattr(tx, name) + + def call_map(self, tx: "InstructionTranslator", fn, *seqs): + seqs = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable(fn, seqs, mutation_type=ValueMutationNew()) + + def call_filter(self, tx: "InstructionTranslator", fn, seq): + seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew()) + + def call_getattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + default=None, + ): + if not name_var.is_python_constant(): + unimplemented_v2( + gb_type="getattr() with non-constant name argument", + context=f"getattr({obj}, {name_var}, {default})", + explanation="getattr() with non-constant name argument is not supported", + hints=["Ensure the name argument of getattr() is a string"], + ) + + name = name_var.as_python_constant() + + # See NOTE [Tensor "grad" and "_grad" attr] + if isinstance(obj, TensorVariable) and name == "_grad": + name = "grad" + + if tx.output.side_effects.is_attribute_mutation(obj): + if isinstance(obj, variables.UnspecializedNNModuleVariable): + if ( + name + in ( + "named_parameters", + "parameters", + "named_buffers", + "buffers", + "named_modules", + "modules", + ) + and obj.is_state_mutated + and tx.output.side_effects.has_pending_mutation(obj) + ): + unimplemented_v2( + gb_type="getattr() on nn.Module with pending mutation", + context=f"getattr({obj}, {name}, {default})", + explanation="Intentionally graph breaking on getattr() on a nn.Module " + "with a pending mutation", + hints=[], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): + return tx.output.side_effects.load_attr(obj, name) + + if default is not None: + hasattr_var = self.call_hasattr(tx, obj, name_var) + assert hasattr_var.as_python_constant() in (True, False) + if not hasattr_var.as_python_constant(): + return default + + source = obj.source and AttrSource(obj.source, name) + if name in {"__bases__", "__base__", "__flags__"}: + try: + value = obj.as_python_constant() + if isinstance(value, type): + if name == "__bases__": + tuple_args = [ + VariableTracker.build( + tx, b, source and GetItemSource(source, i) + ) + for i, b in enumerate(value.__bases__) + ] + return variables.TupleVariable(tuple_args, source=source) + if name == "__base__": + return VariableTracker.build(tx, value.__base__, source) + if name == "__flags__": + return ConstantVariable.create(value.__flags__) + except NotImplementedError: + pass + + if isinstance(obj, variables.NNModuleVariable): + return obj.var_getattr(tx, name) + elif isinstance( + obj, + ( + variables.TensorVariable, + variables.NamedTupleVariable, + variables.ConstantVariable, + variables.DistributedVariable, + 
variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertDictEqual", + "assertSequenceEqual", + "assertWarns", + ) + ): + unimplemented_v2( + gb_type="Failed to trace unittest method", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace unittest method `{name}` ", + hints=[ + f"Avoid calling `TestCase.{name}`. " + "Please report an issue to PyTorch.", + ], + ) + if isinstance(obj, TensorVariable): + fake_val = obj.proxy.node.meta["example_value"] + if ( + isinstance(fake_val, torch.Tensor) + and is_sparse_any(fake_val) + and (not tx.export or not config.capture_sparse_compute) + ): + unimplemented_v2( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, variables.TorchInGraphFunctionVariable): + # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. + member = getattr(obj.value, name) + if isinstance( + member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member): + return variables.TorchInGraphFunctionVariable(member, source=source) + elif name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, DummyModule): + # TODO(mlazos) - Do we need this? + if obj.is_torch or name not in obj.value.__dict__: + member = getattr(obj.value, name) + else: + member = obj.value.__dict__[name] + + if config.replay_record_enabled: + tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr] + return VariableTracker.build(tx, member, source) + + elif istype(obj, variables.UserFunctionVariable) and name in ( + "__name__", + "__module__", + ): + return ConstantVariable.create(getattr(obj.fn, name)) + else: + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + + def call_setattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + val: VariableTracker, + ): + if isinstance( + obj, + ( + variables.PlacementVariable, + variables.NamedTupleVariable, + variables.UserDefinedObjectVariable, + variables.NestedUserFunctionVariable, + variables.ExceptionVariable, + ), + ): + return obj.call_method(tx, "__setattr__", [name_var, val], {}) + elif ( + tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + name = name_var.as_python_constant() + if isinstance(obj, variables.TensorVariable): + from .builder import wrap_fx_proxy + + # Some special handling for tensor attributes. + if name == "requires_grad": + # TODO(voz): Make it work properly + unimplemented_v2( + gb_type="setattr() on Tensor.requires_grad", + context=f"setattr({obj}, {name}, {val})", + explanation="setattr() on Tensor.requires_grad not supported. 
" + "Mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " + "the middle of the graph, which AOTAutograd does not currently know how to handle.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented_v2( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + + # Remove the old reference in tracked fakes - if we don't do this + # new .data value size and shape differences will cause + # tracked fakes to produce incorrect guards. This is sound because the TensorVariable + # coming out of set_() below will be a new one, and get + # installed in tracked fakes. + to_remove = [ + tf for tf in tx.output.tracked_fakes if tf.source == obj.source + ] + for tf in to_remove: + tx.output.tracked_fakes.remove(tf) + + # Step 1 - disable grads + with dynamo_disable_grad(tx), torch.no_grad(): + # Step 2 - call `set_` + out = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + torch.Tensor.set_, + *proxy_args_kwargs([obj, val], {}), + ), + ) + + # Step 3 - drop the version counter - this is a step required to get + # .data setting to play correctly with the autograd engine. + # Essentially, dynamo is trying to faithfully preserve the (absurd) + # behavior of .data= from eager mode + def _lower_version_count_by_1(x): + version = x._version + if version > 0: + version = version - 1 + torch._C._autograd._unsafe_set_version_counter((x,), (version,)) + return x + + tx.output.create_proxy( + "call_function", + _lower_version_count_by_1, + (out.as_proxy(),), + {}, + ) + _lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) + # This handles options prop, guards and ends with a clone + # Step 4 - replace all reference to the current object with the new one + return out + elif name in ("_grad", "grad"): + # NOTE: [Tensor "grad" and "_grad" attr] + # _grad and grad share the same setter/getter, see + # THPVariable_properties, and here we make sure setting one + # enables reading `val` from the other, by routing all + # read/write to `grad`. + name = "grad" + elif is_tensor_getset_descriptor(name): + # Attribute like `torch.Tensor.real` has special setters we + # don't yet support; it's not as simple adding an entry to + # the side effect mapping. 
+ unimplemented_v2( + gb_type="Failed to set tensor attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dynamo doesn't support setting these tensor attributes", + hints=[ + f"Don't mutate attribute '{name}' on tensors, or " + "move the mutation out of `torch.compile` region", + ], + ) + + tx.output.side_effects.store_attr(obj, name, val) + return val + elif isinstance(obj, variables.NNModuleVariable): + if not tx.output.is_root_tracer(): + raise AttributeMutationError( + "Can't inplace modify module params/buffers inside HigherOrderOp" + ) + if name_var.is_python_constant() and isinstance( + val, variables.TensorVariable + ): + assigning_fake_val = get_fake_value(val.as_proxy().node, tx) + + try: + getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) + except (AttributeError, ObservedAttributeError): + getattr_var = None + + if isinstance(getattr_var, variables.TensorVariable): + # get_fake_val will get the same fake tensor + existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) + + # same tensor identity, setattr is a no-op + mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") + if ( + existing_fake_attr is assigning_fake_val + and mod_setattr is torch.nn.Module.__setattr__ + ): + return getattr_var + + obj.convert_to_unspecialized(tx) + + def call_delattr( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + name_var: VariableTracker, + ): + return obj.call_method(tx, "__delattr__", [name_var], {}) + + def call_type(self, tx: "InstructionTranslator", obj: VariableTracker): + try: + py_type = obj.python_type() + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None + + source = obj.source and TypeSource(obj.source) + if ( + source is None + and isinstance(obj, variables.UserDefinedObjectVariable) + and obj.cls_source + ): + source = obj.cls_source + if py_type is torch.Tensor: + # In some cases torch isn't available in globals + name = tx.output.install_global_by_id("", torch) + source = AttrSource(GlobalSource(name), "Tensor") + + return VariableTracker.build(tx, py_type, source) + + def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker): + if obj.has_unpack_var_sequence(tx): + items = list(reversed(obj.unpack_var_sequence(tx))) + return variables.TupleVariable(items) + + def call_sorted( + self, + tx: "InstructionTranslator", + obj: VariableTracker, + **kwargs: VariableTracker, + ): + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable + ): + list_var = variables.ListVariable( + obj.force_unpack_var_sequence(tx), + mutation_type=ValueMutationNew(), + ) + list_var.call_method(tx, "sort", [], kwargs) + return list_var + + # neg is a constant fold function, so we only get here if constant fold is not valid + def call_neg(self, tx: "InstructionTranslator", a): + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + (operator.neg)(a.as_proxy()), + sym_num=None, + ) + # None no-ops this handler and lets the driving function proceed + return None + + def call_format(self, tx: "InstructionTranslator", _format_string, *args, **kwargs): + format_string = _format_string.as_python_constant() + format_string = str(format_string) + return variables.StringFormatVariable.create(format_string, args, kwargs) + + def call_id(self, tx: "InstructionTranslator", *args): + if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): + nn_mod_variable 
= args[0] + mod = tx.output.get_submodule(nn_mod_variable.module_key) + return variables.ConstantVariable.create(id(mod)) + elif len(args) == 1 and isinstance( + args[0], + (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable), + ): + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) + constant_result = id(args[0].value) + return variables.ConstantVariable.create(constant_result) + elif len(args) == 1 and isinstance(args[0], TensorVariable): + tensor_variable = args[0] + return tensor_variable.call_id(tx) + elif istype(args[0], variables.UserFunctionVariable): + return variables.ConstantVariable.create(id(args[0].fn)) + elif istype(args[0], variables.SkipFunctionVariable): + return variables.ConstantVariable.create(id(args[0].value)) + elif istype(args[0], variables.FunctoolsPartialVariable): + return variables.ConstantVariable.create(id(args[0].fake_value)) + else: + unimplemented_v2( + gb_type="id() with unsupported args", + context=str(args), + explanation=f"Dynamo doesn't know how to trace id() call with args {args}", + hints=[ + "Supported args are Tensors, and functions/nn.Modules/user-defined objects " + "from outside the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_deepcopy(self, tx: "InstructionTranslator", x): + unimplemented_v2( + gb_type="copy.deepcopy()", + context=f"copy.deepcopy({x})", + explanation="Dynamo does not support copy.deepcopy()", + hints=[ + "Avoid calling copy.deepcopy()", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def _comparison_with_tensor(self, tx: "InstructionTranslator", left, right): + from .builder import wrap_fx_proxy_cls + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op in [operator.is_, operator.is_not]: + is_result = ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and id(extract_fake_example_value(left.as_proxy().node)) + == id(extract_fake_example_value(right.as_proxy().node)) + ) + if op is operator.is_: + return ConstantVariable.create(is_result) + else: + return ConstantVariable.create(not is_result) + + if op not in supported_tensor_comparison_op_values: + unimplemented_v2( + gb_type="unsupported Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with Tensor arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if ( + isinstance(left, TensorVariable) + and isinstance(right, TensorVariable) + and (left.size and right.size) is not None + and left.size != right.size + ): + try: + torch.broadcast_shapes(left.size, right.size) + except RuntimeError: + # not broadcastable, can't be compared + unimplemented_v2( + gb_type="failed to broadcast when attempting Tensor comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo was unable to broad cast the arguments {left}, {right} " + f"when attempting to trace the comparison op {op.__name__}.", + hints=[*graph_break_hints.USER_ERROR], + ) + tensor_cls = left if isinstance(left, TensorVariable) else right + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return wrap_fx_proxy_cls( + type(tensor_cls), # handle Ndarrays and Tensors + tx, + proxy, + ) + + def _comparison_with_symnode(self, tx: "InstructionTranslator", left, right): + from .tensor import supported_tensor_comparison_op_values + + op = self.fn + + if op not in supported_tensor_comparison_op_values: + 
unimplemented_v2( + gb_type="unsupported SymNode comparison op", + context=f"{op.__name__}({left}, {right})", + explanation=f"Dynamo does not support the comparison op {op.__name__} " + f"with SymNode arguments {left}, {right}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + # This is seen in inspect signature where we check if the value is a default value + if isinstance(right, variables.UserDefinedClassVariable): + return variables.ConstantVariable(op(object(), None)) + + proxy = tx.output.create_proxy( + "call_function", op, (left.as_proxy(), right.as_proxy()), {} + ) + return SymNodeVariable.create( + tx, + proxy, + sym_num=None, + ) + + def call_and_(self, tx: "InstructionTranslator", a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items & b.set_items)) + # None no-ops this handler and lets the driving function proceed + + call_iand = call_and_ + + def call_or_(self, tx: "InstructionTranslator", a, b): + # Rely on constant_handler + if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): + return None + if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( + b, (SymNodeVariable, ConstantVariable) + ): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) + ), + sym_num=None, + ) + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items | b.set_items)) + # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. + if isinstance(a, ConstDictVariable): + return a.call_method(tx, "__or__", args=[b], kwargs={}) + # None no-ops this handler and lets the driving function proceed + return None + + call_ior = call_or_ + + def call_not_(self, tx: "InstructionTranslator", a): + if isinstance(a, SymNodeVariable): + return SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", operator.not_, *proxy_args_kwargs([a], {}) + ), + sym_num=None, + ) + + # Unwrap the underlying ConstDictVariable + if isinstance(a, DictViewVariable): + a = a.dv_dict + if isinstance(a, (ListVariable, ConstDictVariable)): + return ConstantVariable.create(len(a.items) == 0) + + return None + + def call_contains( + self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker + ): + return a.call_method(tx, "__contains__", [b], {}) + + +@contextlib.contextmanager +def dynamo_disable_grad(tx): + from . import GradModeVariable + + gmv = GradModeVariable.create(tx, False) + try: + gmv.enter(tx) + yield + finally: + gmv.exit(tx) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/constant.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..c10b9aa29435996252a1e73f766dd9df7fe54fd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/constant.py @@ -0,0 +1,267 @@ +# mypy: ignore-errors + +""" +Constant and enum variable tracking in Dynamo. 
+ +This module is fundamental to Dynamo's ability to track and propagate constant +values during compilation, ensuring proper handling of Python literals and +maintaining type safety through the compilation process. +""" + +import operator +from typing import TYPE_CHECKING + +import torch +from torch._dynamo.source import AttrSource, GetItemSource + +from .. import graph_break_hints, variables +from ..exc import raise_observed_exception, unimplemented_v2 +from ..utils import cmp_name_to_op_mapping, common_constant_types, istype, np +from .base import VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ConstantVariable(VariableTracker): + """ + Variable tracker for Python literals and basic immutable types, with automatic + routing support for collection types (lists, tuples, sets, etc.). + + The create() method intelligently constructs appropriate variable types for + nested collections. + """ + + @staticmethod + def create(value, **kwargs) -> VariableTracker: + """ + Create a `ConstantVariable` based on the given value, with support for + automatic routing for collection types like `tuple` (in which case we'd + create `ConstantVariable` for the leaf items). + + NOTE: the caller must install the proper guards if needed; most often + the guard will be `CONSTANT_MATCH`. + """ + source = kwargs.get("source", None) + + # Routing for supported collection literals. + if isinstance(value, set): + items = [ConstantVariable.create(x) for x in value] + return variables.SetVariable(items, **kwargs) + elif isinstance(value, frozenset): + items = [ConstantVariable.create(x) for x in value] + return variables.FrozensetVariable(items, **kwargs) + elif isinstance(value, (list, tuple)): + items = [] + for i, x in enumerate(value): + item_source = GetItemSource(source, i) if source else None + items.append( + ConstantVariable.create( + x, + source=item_source, + ) + ) + return variables.BaseListVariable.cls_for(type(value))(items, **kwargs) + + return ConstantVariable(value, **kwargs) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + assert ConstantVariable.is_base_literal(value), f""" +Cannot construct `ConstantVariable` for value of type {type(value)}. + +This failure is likely due to PyTorch-internal use of `ConstantVariable` on +non-literal python values; please try using `VariableTracker.build` instead. If +you believe it's a necessary and legitimate use case (the value is immutable and +can't easily be represented with another `VariableTracker` class), please add +its type to `common_constant_types`. +""" + if np is not None and isinstance(value, np.number): + self.value = value.item() + else: + self.value = value + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" + + def as_python_constant(self): + return self.value + + def is_python_constant(self): + return True + + @property + def items(self): + """ + Need this when adding a BaseListVariable and a ConstantVariable together. + Happens in detectron2. 
+ """ + return self.unpack_var_sequence(tx=None) + + def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + return ConstantVariable.create( + self.value[arg.as_python_constant()], + ) + + @staticmethod + def is_base_literal(obj): + return type(obj) in common_constant_types + + @staticmethod + def is_literal(obj): + if type(obj) in (list, tuple, set, frozenset, torch.Size): + return all(ConstantVariable.is_literal(x) for x in obj) + return ConstantVariable.is_base_literal(obj) + + def unpack_var_sequence(self, tx): + try: + return [ConstantVariable.create(x) for x in self.as_python_constant()] + except TypeError as e: + raise NotImplementedError from e + + def const_getattr(self, tx: "InstructionTranslator", name): + if not hasattr(self.value, name): + raise NotImplementedError + member = getattr(self.value, name) + if callable(member): + raise NotImplementedError + return member + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .tensor import SymNodeVariable + + if name == "format" and istype(self.value, str): + return variables.BuiltinVariable(str.format).call_function( + tx, [self, *args], kwargs + ) + elif name == "join" and istype(self.value, str): + assert len(args) == 1 and len(kwargs) == 0 + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + + if any(isinstance(x, SymNodeVariable) for x in args): + # Promote to SymNodeVariable for operations involving dynamic shapes. + return variables.SymNodeVariable(self.as_proxy(), self.value).call_method( + tx, name, args, kwargs + ) + + try: + const_args = [a.as_python_constant() for a in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) + + if isinstance(self.value, str) and name in str.__dict__.keys(): + method = getattr(self.value, name) + try: + return ConstantVariable.create(method(*const_args, **const_kwargs)) + except Exception as e: + raise_observed_exception(type(e), tx) + elif isinstance(self.value, (float, int)): + if not (args or kwargs): + return ConstantVariable.create(getattr(self.value, name)()) + if ( + hasattr(operator, name) + and len(args) == 1 + and args[0].is_python_constant() + ): + add_target = const_args[0] + op = getattr(operator, name) + if isinstance( + add_target, (torch.SymBool, torch.SymFloat, torch.SymInt) + ): + # Addition between a non sym and sym makes a sym + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return SymNodeVariable.create(tx, proxy, add_target) + else: + try: + return ConstantVariable.create(op(self.value, add_target)) + except Exception as e: + raise_observed_exception( + type(e), tx, args=list(map(ConstantVariable.create, e.args)) + ) + elif isinstance(self.value, bytes) and name == "decode": + method = getattr(self.value, name) + return ConstantVariable.create(method(*const_args, **const_kwargs)) + + if name == "__len__" and not (args or kwargs): + return ConstantVariable.create(len(self.value)) + elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): + return ConstantVariable.create( + round(self.value, args[0].as_python_constant()) + ) + elif name == 
"__contains__" and len(args) == 1 and args[0].is_python_constant(): + assert not kwargs + search = args[0].as_python_constant() + result = search in self.value + return ConstantVariable.create(result) + return super().call_method(tx, name, args, kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + +class EnumVariable(VariableTracker): + """VariableTracker for enum.Enum and enum.IntEnum instances + + Provides specialized handling for Python enum types, supporting + both standard Enum and IntEnum with proper value tracking and comparison. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def create(cls, cls_type, value_vt, options): + if isinstance(value_vt, variables.ConstantVariable): + for member in list(cls_type): + if member.value == value_vt.as_python_constant(): + return cls(member, **options) + unimplemented_v2( + gb_type="Failed to construct Enum variable", + context=f"value: {value_vt}, allowed enum values: {list(cls_type)}", + explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) " + "or is not an acceptable value for the Enum. " + f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", + hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], + ) + + def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int + return self.value + + def __repr__(self) -> str: + return f"EnumVariable({type(self.value)})" + + def as_python_constant(self): + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name): + if not hasattr(self.value, name): + raise NotImplementedError + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + member = getattr(self.value, name) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, member, source=source) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9d19325d08076691eb737d6854bb2211ecb92ec8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/ctx_manager.py @@ -0,0 +1,1456 @@ +# mypy: ignore-errors + +""" +This file contains a collection of context manager classes used by Dynamo for tracking +and managing various PyTorch runtime states during graph compilation. These context +managers handle different aspects of PyTorch's execution environment, including: + +- Autograd states (grad mode, inference mode) +- CUDA streams and events +- Profiling contexts +- Deterministic algorithms +- Forward/backward AD modes +- SDPA (Scaled Dot Product Attention) kernels +- FSDP (Fully Sharded Data Parallel) states +- AMP (Automatic Mixed Precision) autocast states + +The context managers ensure proper state transitions during graph compilation by +tracking enter/exit points and managing cleanup operations. They help maintain +consistency between eager execution and compiled graph behavior by capturing and +restoring state changes. +""" + +import inspect +import sys +import warnings +from typing import TYPE_CHECKING, Union + +import torch._C +from torch._guards import Guard + +from .. 
import graph_break_hints, variables +from ..bytecode_transformation import ( + create_call_function, + create_instruction, + create_setup_with, +) +from ..device_interface import get_interface_for_device +from ..exc import unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, GlobalStateSource +from .base import VariableTracker +from .functions import ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + WrappedNestedUserFunctionVariable, + WrappedSkipFunctionVariable, + WrappedUserFunctionVariable, + WrappedUserMethodVariable, +) +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ContextWrappingVariable(VariableTracker): + _nonvar_fields = { + "cm_obj", + "target_values", + "initial_values", + "state", + *VariableTracker._nonvar_fields, + } + + def __init__(self, target_values, initial_values=None, **kwargs) -> None: + super().__init__(**kwargs) + self.target_values = target_values + self.initial_values = initial_values + + def enter(self, tx): + self._call_func(tx, self.target_values) + self.set_cleanup_hook(tx) + return variables.ConstantVariable.create(None) + + def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): + if fn is None: + + def fn(): + self._call_func(tx, self.initial_values) + + self.cleanup_fn = fn + tx.output.add_cleanup_hook(self.cleanup) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup_assert() + return variables.ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen"): + codegen( + AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: self.reconstruct_type(codegen)) + target_values = self.target_values + if not target_values: + target_values = () + codegen.extend_output([codegen.create_load_const(val) for val in target_values]) + codegen.extend_output(create_call_function(len(target_values), False)) + + def module_name(self): + raise NotImplementedError("module_name called on base") + + def fn_name(self): + raise NotImplementedError("fn_name called on base") + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert len(args) == 1 + assert isinstance( + args[0], + ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserMethodVariable, + UserFunctionVariable, + ), + ) + + if isinstance(args[0], NestedUserFunctionVariable): + return WrappedNestedUserFunctionVariable(args[0], self) + + if isinstance(args[0], SkipFunctionVariable): + return WrappedSkipFunctionVariable(args[0], self) + + if isinstance(args[0], UserMethodVariable): + return WrappedUserMethodVariable(args[0], self) + + if isinstance(args[0], UserFunctionVariable): + return WrappedUserFunctionVariable(args[0], self) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + + def cleanup(self): + if self.cleanup_fn is not None: + self.cleanup_fn() + self.cleanup_fn = None + + def cleanup_assert(self): + assert self.cleanup_fn, "multiple exits?" + self.cleanup() + + +class GenericContextWrappingVariable(UserDefinedObjectVariable): + # Some methods in ContextWrappingVariable assumes the arguments are + # python constants. 
Which might not always be the case here. + def __init__(self, cm_obj, **kwargs) -> None: + assert cm_obj is not None + super().__init__( + value=cm_obj, + value_type=cm_obj.__class__, + **kwargs, + ) + self.cm_obj = cm_obj + + def module_name(self): + return self.cm_obj.__module__ + + def fn_name(self): + return type(self.cm_obj).__name__ + + def enter(self, tx): + source = None if self.source is None else AttrSource(self.source, "__enter__") + return variables.UserMethodVariable( + self.cm_obj.__enter__.__func__, + self, + source=source, + ).call_function(tx, [], {}) + + def exit(self, tx: "InstructionTranslator", *args): + source = None if self.source is None else AttrSource(self.source, "__exit__") + x = variables.UserMethodVariable( + self.cm_obj.__exit__.__func__, + self, + source=source, + ).call_function(tx, args, {}) + tx.active_generic_context_managers.pop() + return x + + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + + +class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): + """represents torch grad requires grad""" + + @staticmethod + def create(tx: "InstructionTranslator", target_values, **kwargs): + return GradInplaceRequiresGradCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx): + [enabled] = self.target_values + self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed(enabled) + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.set_inplace_requires_grad_allowed( + self.prev_state + ), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (enabled,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): + """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" + + @staticmethod + def create(tx: "InstructionTranslator", target_values, **kwargs): + return TemporarilyPopInterpreterStackCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx): + self.saved = torch._C._functorch.pop_dynamic_layer_stack() + self.set_cleanup_hook( + tx, + lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch.pop_dynamic_layer_stack, + (), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._functorch.push_dynamic_layer_stack, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.jvp increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if jvp is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a jvp + # call from eager that calls the compiled function, as the jvp levels + # may be different. 
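# Illustrative eager-mode sketch (not part of the patched file), assuming the
# public torch.func API: torch.func.jvp pushes a new functorch interpreter
# level for the duration of the call, and the level active at trace time is
# what gets baked into the FX graph, hence the FUNCTORCH_STACK_MATCH guard.
#
#     >>> import torch
#     >>> def f(x): return x.sin()
#     >>> x, t = torch.randn(3), torch.ones(3)
#     >>> out, jvp_out = torch.func.jvp(f, (x,), (t,))   # one jvp level
#
# Calling the compiled function from inside another eager jvp would run it at
# a different level, invalidating the traced graph.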
+ _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs): + var = JvpIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx): + install_guard(self._guards_singleton) + jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() + self.set_cleanup_hook( + tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting() + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._jvp_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(jvp_level) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class SetFwdGradEnabledContextManager(ContextWrappingVariable): + """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" + + @staticmethod + def create(tx: "InstructionTranslator", target_values, **kwargs): + return SetFwdGradEnabledContextManager( + target_values=target_values, + initial_values=None, + **kwargs, + ) + + def enter(self, tx): + [mode] = self.target_values + self.prev_state = torch._C._is_fwd_grad_enabled() + torch._C._set_fwd_grad_enabled(mode) + self.set_cleanup_hook( + tx, + lambda: torch._C._set_fwd_grad_enabled(self.prev_state), + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (mode,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._set_fwd_grad_enabled, + (self.prev_state,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class DualLevelContextManager(ContextWrappingVariable): + """Represents torch.autograd.forward_ad.dual_level ctx manager""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs): + return DualLevelContextManager( + target_values=None, + initial_values=None, + **kwargs, + ) + + def enter(self, tx): + install_guard(self._guards_singleton) + self.new_level = torch.autograd.forward_ad.enter_dual_level() + self.set_cleanup_hook( + tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level) + ) + self.proxy = tx.output.create_node( + "call_function", + torch._C._enter_dual_level, + (), + {}, + ) + return variables.ConstantVariable.create(self.new_level) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", + torch._C._exit_dual_level, + (self.new_level,), + {}, + ) + return variables.ConstantVariable.create(None) + + +class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch.func.grad increment/decrement nesting""" + + # A guard is needed as the grad level is baked into the torch FX graph + # This is fine if grad is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a grad + # call from eager that calls the compiled function, as the grad levels + # may be different. 
+ _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs): + var = GradIncrementNestingCtxManagerVariable( + target_values=None, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx): + install_guard(self._guards_singleton) + grad_level = torch._C._functorch._grad_increment_nesting() + self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._grad_increment_nesting, + (), + {}, + ) + return variables.ConstantVariable.create(grad_level) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._grad_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class CatchWarningsCtxManagerVariable(ContextWrappingVariable): + """Delay a call to warnings.catch_warnings""" + + @staticmethod + def create(tx: "InstructionTranslator", catch_warnings_args): + return CatchWarningsCtxManagerVariable( + catch_warnings_args=catch_warnings_args, + target_values=None, + initial_values=None, + ) + + def __init__(self, catch_warnings_args, **kwargs) -> None: + assert isinstance(catch_warnings_args, dict), catch_warnings_args + super().__init__(**kwargs) + self.catch_warnings_args = catch_warnings_args + + def enter(self, tx): + kwargs = { + k: v.as_python_constant() for k, v in self.catch_warnings_args.items() + } + ctx_val = warnings.catch_warnings(**kwargs) + self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) + return variables.ConstantVariable.create(ctx_val.__enter__()) + + def reconstruct(self, cg): + cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) + cg.foreach(self.catch_warnings_args.values()) + keys = tuple(self.catch_warnings_args.keys()) + cg.extend_output(cg.create_call_function_kw(len(keys), keys, False)) + + +class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): + """represents torch VMap increment/decrement nesting""" + + # A guard is needed as the vmap level is baked into the torch FX graph + # generated. This is fine if vmap is only called from within the function + # being compiled. But the FX graph may be invalid in the case of a vmap + # call from eager that calls the compiled function, as the vmap levels + # may be different. 
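# Illustrative eager-mode sketch (not part of the patched file), assuming the
# public torch.func API: vmap likewise pushes a functorch interpreter level,
# and the batch size it was entered with is captured in the graph.
#
#     >>> import torch
#     >>> x = torch.randn(4, 3)
#     >>> out = torch.func.vmap(torch.sin)(x)   # batch size 4, one vmap level
#
# When the batched dimension is dynamic, Dynamo sees the batch size as a
# SymNodeVariable, which is why enter() below distinguishes symbolic from
# constant batch sizes before creating the _vmap_increment_nesting node.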
+ _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + + @staticmethod + def create(tx: "InstructionTranslator", target_values, **kwargs): + var = VmapIncrementNestingCtxManagerVariable( + target_values=target_values, + initial_values=None, + **kwargs, + ) + return var + + def enter(self, tx): + install_guard(self._guards_singleton) + batch_size, randomness = self.target_values + if isinstance(batch_size, variables.SymNodeVariable): + batch_size_value = batch_size.sym_num + batch_size_node = batch_size.as_proxy().node + else: + batch_size_value = batch_size.as_python_constant() + batch_size_node = batch_size.as_python_constant() + randomness = randomness.as_python_constant() + vmap_level = torch._C._functorch._vmap_increment_nesting( + batch_size_value, randomness + ) + self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting()) + self.proxy = tx.output.create_node( + "call_function", + torch._C._functorch._vmap_increment_nesting, + (batch_size_node, randomness), + {}, + ) + return variables.ConstantVariable.create(vmap_level) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup() + tx.output.create_node( + "call_function", torch._C._functorch._vmap_decrement_nesting, (), {} + ) + return variables.ConstantVariable.create(None) + + +class GradModeVariable(ContextWrappingVariable): + """represents torch.{no_grad,enable_grad,set_grad_mode}()""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) + + @staticmethod + def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs): + var = GradModeVariable( + target_values=[target_value], + initial_values=[torch.is_grad_enabled()], + **kwargs, + ) + if initialized: + var._call_func(tx, var.target_values) + return var + + def __init__( + self, target_values, initial_values=None, initialized=True, **kwargs + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx): + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + value = values[0] + # Coalesce grad mode mutations + if torch.is_grad_enabled() != value: + tx.output.create_node( + "call_function", torch._C._set_grad_enabled, (value,), {} + ) + torch._C._set_grad_enabled(value) + + def module_name(self): + return "torch" + + def fn_name(self): + return "set_grad_enabled" + + +class InferenceModeVariable(ContextWrappingVariable): + @staticmethod + def create(tx: "InstructionTranslator", target_value, **kwargs): + var = InferenceModeVariable( + [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs + ) + return var + + def __init__( + self, + target_values, + initial_values=None, + **kwargs, + ) -> None: + if initial_values is None: + # This must be called here since function defaults are evaluated at import time + initial_values = torch.is_inference_mode_enabled() + super().__init__( + target_values=target_values, 
initial_values=initial_values, **kwargs + ) + self.target_values = target_values + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.autograd.grad_mode._exit_inference_mode, + (self.proxy,), + {}, + ) + + def enter(self, tx): + disabled_inference_mode_forcibly = False + if ( + torch._dynamo.config.fake_tensor_disable_inference_mode + and self.target_values[0] + ): + # Do not set the inference mode because we keep it off during + # compilation. Set the grad_enabled to False to reflect the relevant + # part of inference_mode to torch.compile. + disabled_inference_mode_forcibly = True + prior = torch.is_grad_enabled() + torch._C._set_grad_enabled(False) + else: + ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) + + def cleanup_hook(): + if disabled_inference_mode_forcibly: + torch._C._set_grad_enabled(prior) + else: + torch.autograd.grad_mode._exit_inference_mode(ctx) + + self.set_cleanup_hook(tx, cleanup_hook) + self.proxy = tx.output.create_node( + "call_function", + torch.autograd.grad_mode._enter_inference_mode, + (*self.target_values,), + {}, + ) + + def module_name(self): + return "torch" + + def fn_name(self): + return "inference_mode" + + +class CUDADeviceVariable(ContextWrappingVariable): + """represents torch.cuda.device""" + + @staticmethod + def create(tx: "InstructionTranslator", device, **kwargs): + var = CUDADeviceVariable( + target_values=[torch.cuda._get_device_index(device, optional=True)], + initial_values=None, + **kwargs, + ) + return var + + def __init__( + self, + target_values, + initial_values=None, + **kwargs, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.target_values = target_values + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup_assert() + tx.output.create_node( + "call_function", + torch.cuda._maybe_exchange_device, + (self.proxy,), + {}, + ) + return variables.ConstantVariable.create(False) + + def enter(self, tx): + prev_idx = torch.cuda._exchange_device(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) + self.proxy = tx.output.create_node( + "call_function", + torch.cuda._exchange_device, + (*self.target_values,), + {}, + ) + + def module_name(self): + return "torch.cuda" + + def fn_name(self): + return "device" + + +class TorchFunctionDisableVariable(ContextWrappingVariable): + """represents whether torch function overrides are enabled or not""" + + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) + + @staticmethod + def create(tx: "InstructionTranslator", **kwargs): + var = TorchFunctionDisableVariable( + target_values=[], + initial_values=[], + **kwargs, + ) + return var + + def __init__( + self, target_values, initial_values=None, only_subclass=True, **kwargs + ) -> None: + assert len(target_values) == 0 + assert len(initial_values) == 0 + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + self.only_subclass = only_subclass + self.initial_torch_function_subclass_enabled = ( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + self.initial_torch_function_mode_enabled = ( + tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def set_cleanup_hook(self, tx: 
"InstructionTranslator", fn=None): + if fn is None: + + def fn(): + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + self.initial_torch_function_subclass_enabled + ) + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = ( + self.initial_torch_function_subclass_enabled + ) + + self.cleanup_fn = fn + tx.output.add_cleanup_hook(self.cleanup) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 0 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + if not self.only_subclass: + tx.symbolic_torch_function_state.torch_function_mode_enabled = False + + def module_name(self): + return "torch._C" + + def fn_name(self): + if self.only_subclass: + return "DisableTorchFunctionSubclass" + return "DisableTorchFunction" + + +class DeterministicAlgorithmsVariable(ContextWrappingVariable): + """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" + + _guards_singleton = Guard( + GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS + ) + + @staticmethod + def create(tx: "InstructionTranslator", target_value, **kwargs): + var = DeterministicAlgorithmsVariable( + target_values=[target_value], + initial_values=[torch.are_deterministic_algorithms_enabled()], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__(self, target_values, initial_values=None, **kwargs) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + install_guard(self._guards_singleton) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + value = values[0] + ( + tx.output.create_node( + "call_function", torch._C._set_deterministic_algorithms, (value,), {} + ), + ) + torch._C._set_deterministic_algorithms(value) + + def module_name(self): + return "torch" + + def fn_name(self): + return "use_deterministic_algorithms" + + +class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): + """represents torch.autograd.graph.disable_saved_tensors_hook.""" + + @staticmethod + def create(tx: "InstructionTranslator", target_value, **kwargs): + var = DisabledSavedTensorsHooksVariable( + target_values=[target_value], + initial_values=[ + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ], + **kwargs, + ) + var._call_func(tx, [target_value]) + var.set_cleanup_hook(tx) + return var + + def __init__(self, target_values, initial_values=None, **kwargs) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + value = values[0] + if value is not None: + # Disable `saved_tensors_hooks` with message (`value`) + # OR + # we are exiting this context and restoring the previous message. + tx.output.create_node( + "call_function", + torch._C._autograd._saved_tensors_hooks_disable, + (value,), + {}, + ) + torch._C._autograd._saved_tensors_hooks_disable(value) + else: + # We are exiting this context and if prev_message was None, we re-enable `saved_tensors_hooks`. 
+ tx.output.create_node( + "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {} + ) + torch._C._autograd._saved_tensors_hooks_enable() + + def module_name(self): + return "torch.autograd.graph" + + def fn_name(self): + return "disable_saved_tensors_hooks" + + +class AutocastModeVariable(ContextWrappingVariable): + @staticmethod + def create(func, args, kwargs): + assert func in [ + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ] + # device_type : str, + # dtype : Optional[_dtype] = None, + # enabled : bool = True, + # cache_enabled : Optional[bool] = None):cache_enabled + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + target_values = [] + kwargs.clear() + + for key in ["device_type", "dtype", "enabled", "cache_enabled"]: + if key == "device_type" and func in [ + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ]: + arg = "cuda" if func is torch.cuda.amp.autocast else "cpu" + else: + arg = bound_args.arguments[key] + if isinstance(arg, VariableTracker): + target_values.append(arg.as_python_constant()) + else: + target_values.append(arg) + + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) + return var + + def __init__(self, target_values, initial_values=None, **kwargs) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.target_values = target_values + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup_assert() + tx.output.create_node( + "call_function", torch.amp._exit_autocast, (self.proxy,), {} + ) + return variables.ConstantVariable.create(None) + + def enter(self, tx): + ctx = torch.amp._enter_autocast(*self.target_values) + self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) + self.proxy = tx.output.create_node( + "call_function", torch.amp._enter_autocast, (*self.target_values,), {} + ) + + def module_name(self): + return "torch.amp.autocast_mode" + + def fn_name(self): + return "autocast" + + +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + """ + + def __init__(self, target_values=None, **kwargs) -> None: + super().__init__(target_values=target_values, **kwargs) + + def enter(self, tx): + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + return variables.ConstantVariable.create(None) + + def module_name(self): + return "contextlib" + + def fn_name(self): + return "nullcontext" + + +class ProfilerContextVariable(ContextWrappingVariable): + """ + This class represents a set of torch profiler context objects, where Dynamo + ignores all the side-effects in the __init__, __enter__ and __exit__ methods + by treating the object mostly as a `contextlib.nullcontext`, except for edge + cases like the `__enter__` method which returns the object itself rather + than `None`, per implementation of the torch objects. 
+ """ + + def __init__(self, **kwargs) -> None: + super().__init__(target_values=None, **kwargs) + + def enter(self, tx): + return self + + def exit(self, tx: "InstructionTranslator", *args): + return variables.ConstantVariable.create(None) + + def module_name(self): + return "contextlib" + + def fn_name(self): + return "nullcontext" + + def reconstruct(self, cg): + unimplemented_v2( + gb_type="torch.profiler object escaped from compiled region", + context=str(self), + explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class StreamContextVariable(ContextWrappingVariable): + @staticmethod + def create(tx: "InstructionTranslator", target_value, **kwargs): + from .builder import wrap_fx_proxy_cls + + current_stream_method = get_interface_for_device( + target_value.device + ).current_stream + current_stream = wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + current_stream_method, + (None,), + {}, + ), + ) + return StreamContextVariable( + target_values=[target_value], + initial_values=[current_stream], + device=target_value.device, + **kwargs, + ) + + def __init__(self, target_values, device, initial_values=None, **kwargs) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.device = device + self.set_stream = get_interface_for_device(self.device).set_stream + self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id + + def enter(self, tx): + # stream generated inside the traced function + if self.target_values[0].as_proxy() is not None: + tx.output.create_proxy( + "call_function", + self.set_stream, + (self.target_values[0].as_proxy(),), + {}, + ) + # stream passed from outside the traced function + else: + stream = self.target_values[0].value + tx.output.create_proxy( + "call_function", + self.set_stream_id, + (stream.stream_id, stream.device_index, stream.device_type), + {}, + ) + self.set_stream(self.target_values[0].value) + self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value)) + + def exit(self, tx: "InstructionTranslator", *args): + tx.output.create_proxy( + "call_function", + self.set_stream, + (self.initial_values[0].as_proxy(),), + {}, + ) + self.cleanup_assert() + + +class PreserveVersionContextVariable(ContextWrappingVariable): + """ + Wraps torch.autograd._unsafe_preserve_version_counter + """ + + @staticmethod + def _create_lambda_from_tensors(tx, tensors): + if isinstance(tensors, variables.TensorVariable): + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in [tensors]] + ) + tensors = variables.TupleVariable([tensors]) + else: + versions = variables.TupleVariable( + [x.var_getattr(tx, "_version") for x in tensors.items] + ) + return PreserveVersionContextVariable(tensors, versions) + + @staticmethod + def constructor(tx): + return variables.LambdaVariable( + lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( + tx, tensors + ) + ) + + def __init__(self, tensors, prev_versions, **kwargs) -> None: + kwargs.setdefault("target_values", None) + super().__init__(**kwargs) + self.tensors = tensors + self.prev_versions = prev_versions + # The context manager accepts Union[Tensor, Tuple[Tensor]] + if isinstance(self.tensors, variables.TensorVariable): + self.tensors = variables.TupleVariable([self.tensors]) + if isinstance( + self.prev_versions, (variables.ConstantVariable, 
variables.SymNodeVariable) + ): + self.prev_versions = variables.TupleVariable([self.prev_versions]) + + def enter(self, tx): + pass + + def exit(self, tx: "InstructionTranslator", *args): + from ..tensor_version_op import _unsafe_set_version_counter + + return variables.TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [self.tensors, self.prev_versions], {}) + + def reconstruct(self, codegen: "PyCodegen"): + unimplemented_v2( + gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", + context=str(self), + explanation=( + "Dynamo doesn't support compiling a region that returns " + "a torch.autograd._unsafe_preserve_version_counter context manager." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) + + @staticmethod + def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs): + var = FSDPParamGroupUseTrainingStateVariable( + param_group_var=param_group_var, + target_values=[target_value], + initial_values=[param_group_var.value._training_state], + **kwargs, + ) + return var + + def __init__( + self, param_group_var, target_values, initial_values=None, **kwargs + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.param_group_var = param_group_var + install_guard(self._guards_singleton) + + def enter(self, tx): + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + value = values[0] + if self.param_group_var.value._training_state != value: + self.param_group_var.call_method( + tx, + "__setattr__", + ( + variables.ConstantVariable.create("_training_state"), + variables.EnumVariable(value), + ), + {}, + ) + self.param_group_var.value._training_state = value + + def module_name(self): + return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" + + def fn_name(self): + return "use_training_state" + + +class SDPAKernelVariable(ContextWrappingVariable): + """represents torch.nn.attention.sdpa_kernel""" + + @staticmethod + def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): + if isinstance(backends, torch.nn.attention.SDPBackend): + backends = [backends] + var = SDPAKernelVariable( + target_values=backends, + initial_values=None, + set_priority=set_priority, + **kwargs, + ) + return var + + def __init__( + self, + target_values: list[torch.nn.attention.SDPBackend], + initial_values=None, + set_priority: bool = False, + **kwargs, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.set_priority = set_priority + + @staticmethod + def _backends_to_nodes(tx, backends): + # convert to/from string in order to bake the backend into FX graph + nodes = [ + tx.output.create_node( + "call_function", + torch.nn.attention._backend_from_string, + (backend.name,), + {}, 
+ ) + for backend in backends + ] + return nodes + + def enter(self, tx): + self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( + with_priority=self.set_priority + ) + self.set_cleanup_hook( + tx, + lambda: torch.nn.attention._sdpa_kernel( + self.prev_backends, set_priority=self.set_priority + ), + ) + torch.nn.attention._sdpa_kernel( + self.target_values, set_priority=self.set_priority + ) + arg = self._backends_to_nodes(tx, self.target_values) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.cleanup_assert() + arg = self._backends_to_nodes(tx, self.prev_backends) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg, bool(self.set_priority)), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self): + return "torch.nn.attention" + + # use a private version of sdpa_kernel that accepts variadic arguments + # since dynamo reconstructs the contents of target_values one-by-one + def fn_name(self): + return "_sdpa_kernel_variadic" + + +class StreamVariable(VariableTracker): + def __init__(self, proxy, value, device, **kwargs) -> None: + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + assert value.device.type == device.type, ( + "stream value is not equal to the passed device" + ) + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + self.device = device + + def python_type(self): + return torch.Stream + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert hasattr(self.value, name), f"no stream method found named {name}" + + from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait_stream", "synchronize", "wait_event"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return variables.ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=variables.ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name == "record_event": + return wrap_fx_proxy_cls( + target_cls=EventVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: + # NB : Checking for mutation is necessary because we compare + # constant values + other = args[0] + if not isinstance(other, StreamVariable): + return variables.ConstantVariable.create(NotImplemented) + return variables.ConstantVariable.create( + cmp_name_to_op_mapping[name](self.value, other.value) + ) + + return super().call_method(tx, name, args, kwargs) + + def as_proxy(self): + return self.proxy + + def reconstruct(self, codegen: "PyCodegen"): + # If we got here, this stream is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + # Since we just proved that - for other such structures, like lists and dicts, reconstruction + # is fine and sound according to dynamo principles of treating collectives. 
However, + # streams are special in that we want to preserve the identity of the stream as the same as in the graph + # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not + # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending + # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there. + prefix = f"_stream_{self.device}" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) + + +class EventVariable(VariableTracker): + def __init__(self, proxy, value, **kwargs) -> None: + if proxy is not None and "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == value + super().__init__(**kwargs) + self.proxy = proxy + self.value = value + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..utils import proxy_args_kwargs + from .builder import wrap_fx_proxy_cls + + if name in ("wait", "record", "synchronize"): + tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ) + return variables.ConstantVariable(None) + elif name == "query": + return wrap_fx_proxy_cls( + target_cls=variables.ConstantVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + ), + ) + else: + method_name = ( + f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}" + ) + unimplemented_v2( + gb_type="Unsupported event method", + context=str(name), + explanation=f"Dynamo doesn't support tracing the {method_name} method. " + f"We currently support wait, record, synchronize, and query.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def as_proxy(self): + return self.proxy + + def reconstruct(self, codegen: "PyCodegen"): + # If we got here, this event is fully subsumed by the graph - this means it is + # not an input or global + assert not self.source + # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there. 
+ prefix = "_event" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) + + +class DynamoConfigPatchVariable(ContextWrappingVariable): + """represents torch._dynamo.patch_dynamo_config""" + + # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness + # (though it may affect tracing behavior) + def __init__(self, target_values, **kwargs) -> None: + target_values = tuple(target_values.items()) + super().__init__(target_values=(target_values,), initial_values=None, **kwargs) + self.initial_values = {} + for key, _ in target_values: + self.initial_values[key] = torch._dynamo.config.__getattr__(key) + self.initial_values = (tuple(self.initial_values.items()),) + + def enter(self, tx): + # resets all config patches at the end of tracing + self.set_cleanup_hook(tx) + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def _call_func(self, tx: "InstructionTranslator", values): + assert len(values) == 1 + value = values[0] + # manually patch dynamo config + for key, val in value: + torch._dynamo.config.__setattr__(key, val) + # No need to keep track of global side effects because + # dynamo will properly restore this context manager for + # unsupported instructions and continuation functions. + # Dynamo config also should not affect the semantics of the compiled graph. + + def module_name(self): + return "torch._dynamo" + + def fn_name(self): + return "patch_dynamo_config" + + +class WithExitFunctionVariable(VariableTracker): + _nonvar_fields = { + "target", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], + target, + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance( + ctx, (ContextWrappingVariable, GenericContextWrappingVariable) + ) + self.ctx = ctx + self.target = target + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert not kwargs + return self.ctx.exit(tx, *args) + + def reconstruct(self, codegen: "PyCodegen"): + # Note here we reconstruct the context manager rather than the + # exit function. The handler generated by BlockStackEntry + # will re-enter the context in the resume function. + self.ctx.reconstruct_type(codegen) + if codegen.tx.output.partial_convert: + if sys.version_info >= (3, 11): + codegen.append_output(create_instruction("PUSH_NULL")) + if sys.version_info < (3, 13): + codegen.append_output(create_instruction("SWAP", arg=2)) + codegen.extend_output( + [codegen.create_load_const(val) for val in self.ctx.target_values] + ) + codegen.extend_output( + create_call_function(len(self.ctx.target_values), False) + ) + codegen.append_output(create_setup_with(self.target)) + codegen.append_output(create_instruction("POP_TOP")) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/dicts.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd9ef173a98d563a8c59624a21ed8faf8167a53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/dicts.py @@ -0,0 +1,1117 @@ +# mypy: ignore-errors + +""" +Dictionary-related variable tracking classes for PyTorch Dynamo. 
+ +This module implements variable tracking for different types of dictionary-like objects: +- Regular Python dictionaries (dict) +- Ordered dictionaries (collections.OrderedDict) +- Default dictionaries (collections.defaultdict) +- Dictionary views (keys and values) +- Sets and frozensets (implemented internally using dictionaries) + +These classes are responsible for tracking dictionary operations during graph compilation, +maintaining proper guards for dictionary mutations and key existence checks. They handle +dictionary creation, modification, key/value access, and view operations while ensuring +correct behavior in the compiled code through appropriate guard installation. + +The implementation uses a special _HashableTracker wrapper to handle dictionary keys +while preserving proper aliasing semantics. Sets are implemented as dictionaries with +None values for efficiency and code reuse. +""" + +import collections +import functools +import inspect +import operator +import types +from collections.abc import Hashable as py_Hashable +from typing import Optional, TYPE_CHECKING + +from torch._subclasses.fake_tensor import is_fake + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import raise_observed_exception, unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..source import is_from_local_source +from ..utils import ( + cmp_name_to_op_mapping, + dict_items, + dict_keys, + dict_values, + istype, + specialize_symnode, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +# [Adding a new supported class within the keys of ConstDictVarialble] +# - Add its tracker type to is_hashable +# - (perhaps) Define how it is compared in _HashableTracker._eq_impl + + +def raise_args_mismatch(tx, name): + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable(f"wrong number of arguments for {name}() call")], + ) + + +def was_instancecheck_override(obj): + return type(obj).__dict__.get("__instancecheck__", False) + + +def raise_unhashable(arg, tx=None): + if tx is None: + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + raise_observed_exception( + TypeError, tx, args=[ConstantVariable(f"unhashable type: {type(arg)}")] + ) + + +def is_hashable(x): + # NB - performing isinstance check on a LazVT realizes the VT, accidentally + # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at + # the underlying value without realizing the VT. Consider updating the + # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. + if ( + isinstance(x, variables.LazyVariableTracker) + and not x.is_realized() + and x.is_hashable() + ): + return True + + if isinstance(x, variables.TensorVariable): + # Tensors are hashable if they have an example_value (a fake tensor) + # Most VT's should have one. 
+ # It'd be nice if at some point we could assert that they all have one + return x.as_proxy().node.meta.get("example_value") is not None + elif isinstance(x, variables.TupleVariable): + return all(is_hashable(e) for e in x.items) + elif ( + isinstance(x, variables.UserDefinedObjectVariable) + and not was_instancecheck_override(x.value) + and inspect.getattr_static(x.value, "__hash__") is int.__hash__ + and isinstance(x.value, int) + ): + return isinstance(x.value, py_Hashable) + else: + return isinstance( + x, + ( + variables.BuiltinVariable, + variables.SymNodeVariable, + variables.ConstantVariable, + variables.EnumVariable, + variables.UserDefinedClassVariable, + variables.UserFunctionVariable, + variables.SkipFunctionVariable, + variables.misc.NumpyVariable, + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + variables.MethodWrapperVariable, + variables.TorchInGraphFunctionVariable, + variables.TypingVariable, + variables.FunctoolsPartialVariable, + variables.WeakRefVariable, + ), + ) + + +class ConstDictVariable(VariableTracker): + _nonvar_fields = { + "user_cls", + *VariableTracker._nonvar_fields, + } + + class _HashableTracker: + """ + Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable + This should not be seen or touched by anything outside of ConstDictVariable and its children + Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing + """ + + def __init__(self, vt) -> None: + # We specialize SymNodes + vt = specialize_symnode(vt) + # TODO Temporarily remove to figure out what keys are we breaking on + # and add proper support for them + if not is_hashable(vt): + raise_unhashable(vt) + self.vt = vt + + @property + def underlying_value(self): + if ( + isinstance(self.vt, variables.LazyVariableTracker) + and not self.vt.is_realized() + and self.vt.is_hashable() + ): + return self.vt.original_value() + if isinstance(self.vt, variables.TensorVariable): + x = self.vt.as_proxy().node.meta["example_value"] + elif isinstance(self.vt, variables.TupleVariable): + Hashable = ConstDictVariable._HashableTracker + x = tuple(Hashable(e).underlying_value for e in self.vt.items) + elif isinstance(self.vt, variables.NNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UnspecializedNNModuleVariable): + return self.vt.value + elif isinstance(self.vt, variables.UserFunctionVariable): + return self.vt.get_function() + elif isinstance(self.vt, variables.WeakRefVariable): + # Access the underlying value inside the referent_vt for the key representation + Hashable = ConstDictVariable._HashableTracker + return Hashable(self.vt.referent_vt).underlying_value + elif isinstance(self.vt, variables.UserDefinedObjectVariable): + # The re module in Python 3.13+ has a dictionary (_cache2) with + # an object as key (`class _ZeroSentinel(int): ...`): + # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual + return self.vt.value + else: + x = self.vt.as_python_constant() + return x + + def __hash__(self): + return hash(self.underlying_value) + + @staticmethod + def _eq_impl(a, b): + # TODO: Put this in utils and share it between variables/builtin.py and here + if type(a) != type(b): + return False + elif isinstance(a, tuple): + Hashable = ConstDictVariable._HashableTracker + return len(a) == len(b) and all( + Hashable._eq_impl(u, v) for u, v in zip(a, b) + ) + elif is_fake(a): + return a is b + else: + return a == b + + def __eq__(self, other: 
"ConstDictVariable._HashableTracker") -> bool: + Hashable = ConstDictVariable._HashableTracker + assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( + type(other) + ) + if isinstance(other, Hashable): + return Hashable._eq_impl(self.underlying_value, other.underlying_value) + + # constant + return Hashable._eq_impl(self.underlying_value, other) + + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls=dict, + **kwargs, + ) -> None: + # .clone() pass these arguments in kwargs but they're recreated a few + # lines below + if "original_items" in kwargs: + kwargs.pop("original_items") + if "should_reconstruct_all" in kwargs: + kwargs.pop("should_reconstruct_all") + + super().__init__(**kwargs) + + Hashable = ConstDictVariable._HashableTracker + + # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers + assert all( + isinstance(x, (VariableTracker, Hashable)) + and isinstance(v, VariableTracker) + for x, v in items.items() + ) + + def make_hashable(key): + return key if isinstance(key, Hashable) else Hashable(key) + + self.items = {make_hashable(x): v for x, v in items.items()} + # need to reconstruct everything if the dictionary is an intermediate value + # or if a pop/delitem was executed + self.should_reconstruct_all = not is_from_local_source(self.source) + self.original_items = items.copy() + self.user_cls = user_cls + + def as_proxy(self): + return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} + + def debug_repr(self): + return ( + "{" + + ", ".join( + f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items() + ) + + "}" + ) + + def as_python_constant(self): + return { + k.vt.as_python_constant(): v.as_python_constant() + for k, v in self.items.items() + } + + def keys_as_python_constant(self): + self.install_dict_keys_match_guard() + return {k.vt.as_python_constant(): v for k, v in self.items.items()} + + def python_type(self): + return self.user_cls + + def __contains__(self, vt) -> bool: + assert isinstance(vt, VariableTracker) + Hashable = ConstDictVariable._HashableTracker + return ( + is_hashable(vt) + and Hashable(vt) in self.items + and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + ) + + def len(self): + return len( + [ + x + for x in self.items.values() + if not isinstance(x, variables.DeletedVariable) + ] + ) + + def has_new_items(self): + if self.should_reconstruct_all: + return True + return any( + self.is_new_item(self.original_items.get(key.vt), value) + for key, value in self.items.items() + ) + + def is_new_item(self, value, other): + # compare the id of the realized values if both values are not lazy VTs + if value and value.is_realized() and other.is_realized(): + return id(value.realize()) != id(other.realize()) + return id(value) != id(other) + + def reconstruct_kvs_into_new_dict(self, codegen): + # Build a dictionary that contains the keys and values. 
+ num_args = 0 + for key, value in self.items.items(): + # We can safely call realize() here as it won't introduce any new guards + item = self.original_items.get(key.vt) + if self.is_new_item(item, value) or self.should_reconstruct_all: + codegen(key.vt) + codegen(value) + num_args += 1 + codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) + + def reconstruct(self, codegen: "PyCodegen"): + if self.user_cls is collections.OrderedDict: + # emit `OrderedDict(constructed_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("OrderedDict"), + ] + ) + ) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(1, False)) + else: + self.reconstruct_kvs_into_new_dict(codegen) + + def getitem_const_raise_exception_if_absent( + self, tx: "InstructionTranslator", arg: VariableTracker + ): + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + raise_observed_exception(KeyError, tx) + return self.items[key] + + def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + msg = f"Dictionary key {arg.value} not found during tracing" + unimplemented_v2( + gb_type="key not found in dict", + context=f"Key {arg.value}", + explanation=msg, + hints=[ + "Check if the key exists in the dictionary before accessing it.", + *graph_break_hints.USER_ERROR, + ], + ) + return self.items[key] + + def maybe_getitem_const(self, arg: VariableTracker): + key = ConstDictVariable._HashableTracker(arg) + if key not in self.items: + return None + return self.items[key] + + def realize_key_vt(self, arg: VariableTracker): + # Realize the LazyVT on a particular index + assert arg in self + key = ConstDictVariable._HashableTracker(arg) + index = tuple(self.items.keys()).index(key) + original_key_vt = tuple(self.original_items.keys())[index] + if isinstance(original_key_vt, variables.LazyVariableTracker): + original_key_vt.realize() + + def install_dict_keys_match_guard(self): + if self.source: + install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) + + def install_dict_contains_guard(self, tx, args): + # Key guarding - These are the cases to consider + # 1) The dict has been mutated. In this case, we would have already + # inserted a DICT_KEYS_MATCH guard, so we can skip. + # + # 2) args[0].source is None. This happens for const keys. Here, we + # have to insert the DICT_CONTAINS guard. + # + # 3) args[0].source is not None. This can happen for non-const VTs. + # 3a) contains=True. In this case, we can access the lazyVT from + # original_items and selectively realize it. + # 3b) contains=False. There is no easy way to selectively apply this + # DICT_NOT_CONTAINS guard because our guard are represented via trees. + # Be conservative and add DICT_KEYS_MATCH guard. + from . 
import ConstantVariable + + if not self.source: + return + + if tx.output.side_effects.is_modified(self): + return + + contains = args[0] in self + if args[0].source is None and isinstance(args[0], ConstantVariable): + install_guard( + self.make_guard( + functools.partial( + GuardBuilder.DICT_CONTAINS, + key=args[0].value, + invert=not contains, + ) + ) + ) + elif args[0].source: + if contains: + self.realize_key_vt(args[0]) + else: + self.install_dict_keys_match_guard() + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # NB - Both key and value are LazyVariableTrackers in the beginning. So, + # we have to insert guards when a dict method is accessed. For this to + # be simple, we are conservative and overguard. We skip guard only for + # get/__getitem__ because the key guard will be inserted by the + # corresponding value VT. For __contains__, we add a DICT_CONTAINS + # guard. But for all the other methods, we insert the DICT_KEYS_MATCH + # guard to be conservative. + from . import BuiltinVariable, ConstantVariable + + Hashable = ConstDictVariable._HashableTracker + + arg_hashable = args and is_hashable(args[0]) + + if name == "__init__": + temp_dict_vt = variables.BuiltinVariable(dict).call_dict( + tx, *args, **kwargs + ) + tx.output.side_effects.mutation(self) + self.items.update(temp_dict_vt.items) + return ConstantVariable.create(None) + elif name == "__getitem__": + # Key guarding - Nothing to do. LazyVT for value will take care. + assert len(args) == 1 + return self.getitem_const_raise_exception_if_absent(tx, args[0]) + elif name == "items": + assert not (args or kwargs) + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + return DictItemsVariable(self) + elif name == "keys": + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + assert not (args or kwargs) + return DictKeysVariable(self) + elif name == "values": + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + assert not (args or kwargs) + return DictValuesVariable(self) + elif name == "copy": + self.install_dict_keys_match_guard() + assert not (args or kwargs) + return self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None + ) + elif name == "__len__": + assert not (args or kwargs) + self.install_dict_keys_match_guard() + return ConstantVariable.create(len(self.items)) + elif name == "__setitem__" and self.is_mutable(): + if not arg_hashable: + raise_unhashable(args[0]) + + self.install_dict_keys_match_guard() + assert not kwargs and len(args) == 2 + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = args[1] + return ConstantVariable.create(None) + elif name == "__delitem__" and arg_hashable and self.is_mutable(): + self.install_dict_keys_match_guard() + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.__delitem__(Hashable(args[0])) + return ConstantVariable.create(None) + elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self: + # missing item, return the default value. Install no DICT_CONTAINS guard. 
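# For reference, the eager dict semantics this missing-key branch mirrors
# (illustrative, not part of the patched file):
#
#     >>> d = {"a": 1}
#     >>> d.get("missing")          # returns None
#     >>> d.get("missing", 0)       # returns the supplied default, 0
#     >>> d.pop("missing", 0)       # returns 0, dict unchanged
#     >>> d.pop("missing")          # raises KeyError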
+ self.install_dict_contains_guard(tx, args) + if len(args) == 1: + if name == "pop": + raise_observed_exception(KeyError, tx) + return ConstantVariable(None) + else: + return args[1] + elif name == "pop" and arg_hashable and self.is_mutable(): + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + return self.items.pop(Hashable(args[0])) + elif name == "clear": + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif name == "update" and self.is_mutable(): + # In general, this call looks like `a.update(b, x=1, y=2, ...)`. + # Either `b` or the kwargs is omittable, but not both. + self.install_dict_keys_match_guard() + has_arg = len(args) == 1 + has_kwargs = len(kwargs) > 0 + if has_arg or has_kwargs: + tx.output.side_effects.mutation(self) + if has_arg: + if isinstance(args[0], ConstDictVariable): + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() + dict_vt = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) + self.items.update(dict_vt.items) + if has_kwargs: + # Handle kwargs + kwargs = { + Hashable(ConstantVariable.create(k)): v + for k, v in kwargs.items() + } + self.items.update(kwargs) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + elif name in ("get", "__getattr__") and args[0] in self: + # Key guarding - Nothing to do. + return self.getitem_const(tx, args[0]) + elif name == "__contains__" and len(args) == 1: + if not arg_hashable: + raise_unhashable(args[0]) + + self.install_dict_contains_guard(tx, args) + contains = args[0] in self + return ConstantVariable.create(contains) + elif name == "setdefault" and arg_hashable and self.is_mutable(): + self.install_dict_keys_match_guard() + assert not kwargs + assert len(args) <= 2 + value = self.maybe_getitem_const(args[0]) + if value is not None: + return value + else: + if len(args) == 1: + x = ConstantVariable.create(None) + else: + x = args[1] + tx.output.side_effects.mutation(self) + self.items[Hashable(args[0])] = x + return x + elif name == "move_to_end": + self.install_dict_keys_match_guard() + assert not kwargs and len(args) == 1 + tx.output.side_effects.mutation(self) + key = Hashable(args[0]) + val = self.items[key] + self.items.pop(key) + self.items[key] = val + return ConstantVariable.create(None) + elif name == "__or__": + assert len(args) == 1 + if not isinstance(args[0], ConstDictVariable): + raise TypeError( + f"unsupported operand type(s) for |: 'dict' and '{args[0].python_type().__name__}'" + ) + + self.install_dict_keys_match_guard() + new_dict_vt = self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None + ) + + # NB - Guard on all the keys of the other dict to ensure + # correctness. + args[0].install_dict_keys_match_guard() + new_dict_vt.items.update(args[0].items) + return new_dict_vt + else: + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + self.install_dict_keys_match_guard() + return [x.vt for x in self.items.keys()] + + def call_obj_hasattr(self, tx, name): + # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. + # OrderedDict though requires side effects tracking because it supports arbitrary setattr. 
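For reference, the plain-Python semantics mirrored by the `update`, `__or__` and `setdefault` branches above:

d = {"a": 1}
d.update({"b": 2}, c=3)            # mapping argument plus kwargs, mutates d in place
merged = d | {"d": 4}              # __or__ builds a brand-new dict; d is untouched
assert "d" not in d and merged["d"] == 4
assert d.setdefault("a", 99) == 1  # existing key: return current value, no mutation
assert d.setdefault("e") is None   # missing key: insert the default (None) and return it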
+ if self.user_cls is dict: + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + + msg = f"hasattr on {self.user_cls} is not supported" + unimplemented_v2( + gb_type="unsupported hasattr operation", + context=f"Class {self.user_cls}", + explanation=msg, + hints=[ + "Consider using a regular dictionary instead", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def clone(self, **kwargs): + self.install_dict_keys_match_guard() + return super().clone(**kwargs) + + +class MappingProxyVariable(VariableTracker): + # proxies to the original dict_vt + def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + def python_type(self): + return types.MappingProxyType + + def unpack_var_sequence(self, tx): + return self.dv_dict.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen"): + # load types.MappingProxyType + if self.source: + msg = ( + f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed " + "because the connection to the original dict will be lost." + ) + unimplemented_v2( + gb_type="mapping proxy cannot be reconstructed", + context=f"Source: {self.source}", + explanation=msg, + hints=[ + "Use a mapping proxy constructed in the same `torch.compile` region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(types), + codegen.create_load_attr("MappingProxyType"), + ] + ) + ) + codegen(self.dv_dict) + codegen.extend_output(create_call_function(1, False)) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if self.source and tx.output.side_effects.has_existing_dict_mutation(): + msg = ( + "A dict has been modified while we have an existing mappingproxy object. " + "A mapping proxy object, as the name suggest, proxies a mapping " + "object (usually a dict). If the original dict object mutates, it " + "is reflected in the proxy object as well. For an existing proxy " + "object, we do not know the original dict it points to. Therefore, " + "for correctness we graph break when there is dict mutation and we " + "are trying to access a proxy object." + ) + + unimplemented_v2( + gb_type="mapping proxy affected by dictionary mutation", + context=f"Source: {self.source}, Dict mutation detected", + explanation=msg, + hints=[ + "Avoid modifying dictionaries that might be referenced by mapping proxy objects", + "Or avoid using the mapping proxy objects after modifying its underlying dictionary", + ], + ) + return self.dv_dict.call_method(tx, name, args, kwargs) + + +class NNModuleHooksDictVariable(ConstDictVariable): + # Special class to avoid adding any guards on the nn module hook ids. + def install_dict_keys_match_guard(self): + pass + + def install_dict_contains_guard(self, tx, args): + pass + + +class DefaultDictVariable(ConstDictVariable): + def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: + super().__init__(items, user_cls, **kwargs) + assert user_cls is collections.defaultdict + self.default_factory = default_factory + + def is_python_constant(self): + # Return false for unsupported defaults. This ensures that a bad handler + # path is not taken in BuiltinVariable for getitem. 
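The mapping-proxy graph break above follows from a proxy being a live, read-only view of its dict, as in this minimal sketch:

import types

d = {"a": 1}
proxy = types.MappingProxyType(d)

d["b"] = 2                 # mutating the underlying dict ...
assert proxy["b"] == 2     # ... is immediately visible through the proxy
try:
    proxy["c"] = 3         # the proxy itself rejects writes
except TypeError:
    pass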
+ if self.default_factory not in [list, tuple, dict] and not self.items: + return False + return super().is_python_constant() + + def debug_repr(self): + return ( + f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" + ) + + @staticmethod + def is_supported_arg(arg): + if isinstance(arg, variables.BuiltinVariable): + return arg.fn in (list, tuple, dict, set) + else: + return isinstance(arg, variables.functions.BaseUserFunctionVariable) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__": + assert len(args) == 1 + + if args[0] in self: + return self.getitem_const(tx, args[0]) + else: + if self.default_factory is None: + raise KeyError(f"{args[0]}") + else: + default_var = self.default_factory.call_function(tx, [], {}) + super().call_method( + tx, "__setitem__", (args[0], default_var), kwargs + ) + return default_var + else: + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen): + # emit `defaultdict(default_factory, new_dict)` + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(collections), + codegen.create_load_attr("defaultdict"), + ] + ) + ) + codegen(self.default_factory) + self.reconstruct_kvs_into_new_dict(codegen) + codegen.extend_output(create_call_function(2, False)) + + +# TODO: Implementing this via inheritance rather than composition is a +# footgun, because self method calls in dict will route back to the set +# implementation, which is almost assuredly wrong +class SetVariable(ConstDictVariable): + """We model a sets as dictionary with None values""" + + def __init__( + self, + items: list[VariableTracker], + **kwargs, + ) -> None: + items = dict.fromkeys(items, SetVariable._default_value()) + super().__init__(items, **kwargs) + + def debug_repr(self): + if not self.items: + return "set()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + + @property + def set_items(self): + return set(self.items.keys()) + + @staticmethod + def _default_value(): + # Variable to fill in he keys of the dictionary + return ConstantVariable.create(None) + + def as_proxy(self): + return {k.vt.as_proxy() for k in self.set_items} + + def python_type(self): + return set + + def as_python_constant(self): + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen: "PyCodegen"): + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + # We forward the calls to the dictionary model + if name == "__init__": + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) + tx.output.side_effects.mutation(self) + self.items.clear() + self.items.update(temp_set_vt.items) + return ConstantVariable.create(None) + elif name == "add": + assert not kwargs + if len(args) != 1: + raise_args_mismatch(tx, name) + name = "__setitem__" + args = (args[0], SetVariable._default_value()) + elif name == "pop": + assert not kwargs + assert not args + # Choose an item at random and pop it via the Dict.pop method + try: + result = self.set_items.pop().vt + except KeyError as e: + raise_observed_exception( + KeyError, tx, args=list(map(ConstantVariable.create, e.args)) + ) + super().call_method(tx, name, (result,), 
kwargs) + return result + elif name == "isdisjoint": + assert not kwargs + assert len(args) == 1 + return variables.UserFunctionVariable( + polyfills.set_isdisjoint + ).call_function(tx, [self, args[0]], {}) + elif name == "intersection": + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_intersection + ).call_function(tx, [self, *args], {}) + elif name == "intersection_update": + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_intersection_update + ).call_function(tx, [self, *args], {}) + elif name == "union": + assert not kwargs + return variables.UserFunctionVariable(polyfills.set_union).call_function( + tx, [self, *args], {} + ) + elif name == "difference": + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_difference + ).call_function(tx, [self, *args], {}) + elif name == "difference_update": + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference": + if len(args) != 1: + raise_args_mismatch(tx, name) + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference_update": + if len(args) != 1: + raise_args_mismatch(tx, name) + assert not kwargs + return variables.UserFunctionVariable( + polyfills.set_symmetric_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "update" and self.is_mutable(): + assert not kwargs + return variables.UserFunctionVariable(polyfills.set_update).call_function( + tx, [self, *args], {} + ) + elif name == "remove": + assert not kwargs + assert len(args) == 1 + if args[0] not in self: + raise_observed_exception(KeyError, tx, args=args) + return super().call_method(tx, "pop", args, kwargs) + elif name == "discard": + assert not kwargs + assert len(args) == 1 + if args[0] in self: + return super().call_method(tx, "pop", args, kwargs) + else: + return ConstantVariable.create(value=None) + elif name in ("issubset", "issuperset"): + op = { + "issubset": operator.le, + "issuperset": operator.ge, + } + other = args[0].realize() + if not istype(other, SetVariable): + other = variables.BuiltinVariable(set).call_function(tx, [other], {}) + return variables.BuiltinVariable(op.get(name)).call_function( + tx, [self, other], {} + ) + return super().call_method(tx, name, args, kwargs) + + def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + raise RuntimeError("Illegal to getitem on a set") + + def install_dict_keys_match_guard(self): + # Already EQUALS_MATCH guarded + pass + + def install_dict_contains_guard(self, tx, args): + # Already EQUALS_MATCH guarded + pass + + +class FrozensetVariable(SetVariable): + def __init__( + self, + items: list[VariableTracker], + **kwargs, + ) -> None: + super().__init__(items, **kwargs) + + def debug_repr(self): + if not self.items: + return "frozenset()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + + @property + def set_items(self): + return self.items.keys() + + def python_type(self): + return frozenset + + def as_python_constant(self): + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen: "PyCodegen"): + codegen.foreach([x.vt for x in self.set_items]) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + 
codegen.extend_output(create_call_function(0, False)) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + elif name == "__init__": + # frozenset is immutable. Calling __init__ again shouldn't have any effect + # In[1]: s = frozenset([1, 2]) + # + # In[2]: s.__init__([3, 4]) + # + # In[3]: s + # frozenset({1, 2}) + return ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) + + +class DictKeySetVariable(SetVariable): + def __init__( + self, + items: list[VariableTracker], + **kwargs, + ) -> None: + super().__init__(items, **kwargs) + + def debug_repr(self): + if not self.items: + return "dict_keys([])" + else: + return ( + "dict_keys([" + + ",".join(k.vt.debug_repr() for k in self.items.keys()) + + "])" + ) + + @property + def set_items(self): + return self.items + + def python_type(self): + return dict_keys + + def as_python_constant(self): + return dict.fromkeys( + {k.vt.as_python_constant() for k in self.set_items}, None + ).keys() + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a dict_keys") + return super().call_method(tx, name, args, kwargs) + + +class DictViewVariable(VariableTracker): + """ + Models _PyDictViewObject + + This is an "abstract" class. Subclasses will override kv and the items method + """ + + kv: Optional[str] = None + + def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + super().__init__(**kwargs) + assert self.kv in ("keys", "values", "items") + assert isinstance(dv_dict, ConstDictVariable) + self.dv_dict = dv_dict + + @property + def view_items(self): + return getattr(self.dv_dict.items, self.kv)() + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + # Implement in the subclasses + raise NotImplementedError + + def unpack_var_sequence(self, tx): + return self.view_items_vt + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.dv_dict) + codegen.load_method(self.kv) + codegen.call_method(0) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__len__": + return self.dv_dict.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + +class DictKeysVariable(DictViewVariable): + kv = "keys" + + @property + def set_items(self): + return set(self.view_items) + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + return [x.vt for x in self.view_items] + + def python_type(self): + return dict_keys + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__contains__": + return self.dv_dict.call_method(tx, name, args, kwargs) + if name in cmp_name_to_op_mapping: + if not isinstance(args[0], (SetVariable, DictKeysVariable)): + return ConstantVariable.create(NotImplemented) + return ConstantVariable.create( + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + ) + return super().call_method(tx, name, args, kwargs) + + +class DictValuesVariable(DictViewVariable): + # 
DictValuesVariable is an iterable but cannot be compared. + kv = "values" + + @property + def view_items_vt(self): + return list(self.view_items) + + def python_type(self): + return dict_values + + +class DictItemsVariable(DictViewVariable): + kv = "items" + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] + + def python_type(self): + return dict_items diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/distributed.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..b299b814ad3ed9a963c0e8f3f4e7904949ee9e99 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/distributed.py @@ -0,0 +1,442 @@ +# mypy: ignore-errors + +""" +Distributed computing variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for distributed computing components: +- Process Groups (for collective communication) +- Device Meshes (for distributed tensor sharding) +- Placement Types (for specifying distribution strategies) +- Distributed Tensors and their operations +- Backward hooks for distributed module operations + +These classes are responsible for tracking distributed operations during graph +compilation while maintaining proper guards and handling distributed-specific +behaviors. They ensure correct handling of distributed components like process +groups, device meshes, and placement strategies while preserving proper semantics +for distributed tensor operations in the compiled code. + +The implementation provides special handling for distributed package availability +checks and proper tracking of distributed state and operations across processes. +""" + +import functools +import inspect +from typing import TYPE_CHECKING + +import torch +from torch.fx.experimental._backward_state import BackwardState + +from .. import compiled_autograd, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..exc import unimplemented_v2 +from ..external_utils import call_module_hooks_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import istype +from .base import VariableTracker +from .constant import ConstantVariable, EnumVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class DistributedVariable(VariableTracker): + """ + The base distributed variable that encapsulates common methods + for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.). + Concrete distributed objects could inherit this class and add object + specific logic. + + i.e. It provides the check on the distributed package existence + and hold the tracking value for the corresponding distributed object. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + if not DistributedVariable.is_available(): + unimplemented_v2( + gb_type="torch.distributed package is not available!", + context="", + explanation="The PyTorch package doesn't include torch.distributed when building from source.", + hints=[ + "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source." 
+ ], + ) + self.value = value + + def python_type(self): + return type(self.value) + + @staticmethod + def is_available(): + # check if the distributed package is available or not + return torch.distributed.is_available() + + +def is_from_local(value): + if not DistributedVariable.is_available(): + return False + from torch.distributed.tensor import DTensor + + return inspect.isfunction(value) and value is DTensor.from_local + + +def is_constant_pg_functions(value): + if not DistributedVariable.is_available(): + return False + + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + + constant_processgroup_functions = [ + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ] + + return inspect.isfunction(value) and value in constant_processgroup_functions + + +class WorldMetaClassVariable(DistributedVariable): + """ + Tracks torch.distributed.GroupMember and torch.distributed.group, which are + instances of the metaclass _WorldMeta. + """ + + @classmethod + def is_group_member_type(cls, value): + if not cls.is_available(): + return False + + from torch.distributed.distributed_c10d import _WorldMeta + + return type(value) is _WorldMeta + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "WORLD": + source = AttrSource(base=self.source, member="WORLD") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return ProcessGroupVariable(self.value.WORLD) + elif name == "NON_GROUP_MEMBER": + source = AttrSource(base=self.source, member="NON_GROUP_MEMBER") + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return EnumVariable(self.value.NON_GROUP_MEMBER) + return super().var_getattr(tx, name) + + +class PlacementClassVariable(DistributedVariable): + @staticmethod + def is_placement_type(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.tensor.placement_types import Placement + + return type(value) is type and issubclass(value, Placement) + + def as_python_constant(self): + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + inspect.getattr_static(self.value, "__new__", None) in (object.__new__,) + and self.source + ): + # NOTE: we don't need to track mutations to the placement class as they + # suppose to be immutable. + new_obj = object.__new__(self.value) + var = PlacementVariable(new_obj) + if inspect.getattr_static(self.value, "__init__", None): + var.call_method(tx, "__init__", args, kwargs) + return var + + return super().call_function(tx, args, kwargs) + + +class PlacementVariable(DistributedVariable): + @staticmethod + def is_placement(value): + # we can't rely on importing/accessing torch distributed, it is not always built. 
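A minimal sketch of the placement objects tracked here, assuming `Shard`/`Replicate` live next to the `Placement` import used above and guarding on distributed availability:

import torch

if torch.distributed.is_available():
    # assumed import path, alongside the Placement import used by is_placement()
    from torch.distributed.tensor.placement_types import Replicate, Shard

    p = Shard(0)
    # Constant inputs, constant outputs: this is why PlacementVariable can
    # constant-fold is_shard/is_replicate/is_partial during tracing.
    assert p.is_shard() and not p.is_replicate()
    assert Replicate().is_replicate()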
+ if not DistributedVariable.is_available(): + return False + + from torch.distributed.tensor.placement_types import Placement + + return isinstance(value, Placement) + + def as_python_constant(self): + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "dim": + return ConstantVariable.create(self.value.dim) + return super().var_getattr(tx, name) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable + + # Placement types dynamo tracking only allows following methods + # and __setattr__ is for case like `Shard(dim)` and methods. + # Methods in the list must satisfy: + # 1. Input arguments are constants and do not need to be guarded on; + # 2. Output is constant with respect to their inputs + constant_fold_functions = [ + "__init__", + "__setattr__", + "is_shard", + "is_partial", + "is_replicate", + ] + + if name in constant_fold_functions: + try: + value_type = type(self.value) + assert ( + inspect.getattr_static(value_type, "__getattr__", None) is None + ), "no custom getattr allowed!" + method = inspect.getattr_static(value_type, name) + except AttributeError: + method = None + if method is object.__init__: + return ConstantVariable.create(None) + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + if name == "__setattr__": + method(self.value, *args, **kwargs) + return self + constant_val = method(self.value, *args, **kwargs) + return ConstantVariable.create(constant_val) + + return super().call_method(tx, name, args, kwargs) + + +class DeviceMeshVariable(DistributedVariable): + @staticmethod + def is_device_mesh(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + + from torch.distributed.device_mesh import DeviceMesh + + return istype(value, DeviceMesh) + + def as_python_constant(self): + return self.value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "ndim": + return ConstantVariable.create(self.value.ndim) + if name == "device_type": + return ConstantVariable.create(self.value.device_type) + return super().var_getattr(tx, name) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "size": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ConstantVariable.create(self.value.size(*const_args, **const_kwargs)) + if name == "get_coordinate": + return ConstantVariable.create(self.value.get_coordinate()) + if name == "get_group": + const_args = [x.as_python_constant() for x in args] + const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + return ProcessGroupVariable( + self.value.get_group(*const_args, **const_kwargs) + ) + if name == "_get_or_create_default_group": + return ProcessGroupVariable(self.value._get_or_create_default_group()) + return super().call_method(tx, name, args, kwargs) + + +class ProcessGroupVariable(DistributedVariable): + """ + We don't want a ProcessGroup object to end up in our output graph. 
+ + But it's common for dynamo to intercept a PG that is then used to get info like + rank() or world_size(), as well as passed to utility functions in distributed_c10d + which desugar it into plain types like a ranklist and tag. + + For convenience and proper guarding, we construct a variable type. + + TODO: make it possible to use ProcessGroupVariable as input to simple functions + like _expand_group without dynamo complaining about making a proxy for it. + It is not a tensor-like type, and we don't want a proxy- but dynamo assumes + torch library functions are dealing with tensor-like types and would have proxies + for their args. + TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors + or just graph-break whenever one of our special cases is not hit? + """ + + def as_python_constant(self): + return self.value + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "rank": + return variables.ConstantVariable.create(self.value.rank()) + if name == "size": + return variables.ConstantVariable.create(self.value.size()) + if name == "_get_backend_name": + return variables.ConstantVariable.create(self.value._get_backend_name()) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "group_name": + return variables.ConstantVariable.create(self.value.group_name) + if name in ["rank", "size"]: + return variables.LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + # TODO should this just raise unimplemented? + return super().var_getattr(tx, name) + + @staticmethod + def is_process_group(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch._C._distributed_c10d import ProcessGroup + from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + + return istype(value, (ProcessGroup, FakeProcessGroup)) + + +class BackwardHookVariable(VariableTracker): + """ + Handles torch.utils.hooks.BackwardHook for module-level backward + hooks. + """ + + @staticmethod + def create( + tx, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + ): + if not compiled_autograd.compiled_autograd_enabled: + unimplemented_v2( + gb_type="Module-level backwards hooks require compiled autograd.", + context="", + explanation="", + hints=[ + "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True." + ], + ) + + def _in_graph_bw_hooks(bw_state: BackwardState): + """ + Rather than installing the user hooks in the graph (which + don't survive AotAutograd), we install hooks that will call + trace_wrapped in the backward pass that CompiledAutograd + can turn into actual hook calls. 
+ """ + return torch.utils.hooks.BackwardHook( + None, + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_hooks_name, + module_name=module_name, + ), + ), + ( + functools.partial( + trace_wrapped, + fn=call_module_hooks_from_backward_state, + bw_state=bw_state, + hooks_name=user_pre_hooks_name, + module_name=module_name, + ), + ), + ) + + module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod") + user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks) + user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks) + proxy = tx.output.create_proxy( + "call_function", + _in_graph_bw_hooks, + (bw_state_proxy,), + {}, + ) + proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ()) + return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks) + + def __init__( + self, + proxy: torch.fx.Proxy, + module: VariableTracker, + user_hooks: VariableTracker, + user_pre_hooks: VariableTracker, + **options, + ) -> None: + super().__init__(**options) + self.proxy = proxy + self.module = module + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + + def as_proxy(self): + return self.proxy + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ("setup_input_hook", "setup_output_hook"): + return self._setup_hook(tx, name, *args, **kwargs) + return super().call_method(tx, name, args, kwargs) + + def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + hook_method_name, + (self.as_proxy(), args.as_proxy()), + {}, + ), + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/functions.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4305521b42a6a86f3a554b4052c56730c8605d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/functions.py @@ -0,0 +1,2369 @@ +# mypy: ignore-errors + +""" +Function-related variable tracking classes for Dynamo's symbolic execution. + +This module contains classes that track different types of functions during graph +compilation, including: +- User-defined functions and methods +- Built-in functions and methods +- Wrapped functions (e.g. from decorators) +- Special function types (e.g. functools.partial) +- Triton kernels and related function types + +These classes are responsible for: +- Tracking function calls and their arguments +- Managing function closures and cell variables +- Handling function attributes and special methods +- Maintaining guards for function identity and closure contents +- Supporting function inlining and specialization +- Enabling proper symbolic execution of different function types + +The variable trackers here work together with the rest of Dynamo to enable +accurate graph capture while handling Python's various function-related behaviors. 
+""" + +import builtins +import functools +import inspect +import itertools +import logging +import sys +import traceback +import types +from collections.abc import Sequence +from types import FunctionType +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing_extensions import Never +from unittest.mock import patch +from weakref import WeakKeyDictionary + +import torch +from torch._dynamo.exc import get_stack_above_dynamo + +from .. import config, graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_rot_n, is_generator +from ..exc import ( + get_dynamo_observed_exception, + handle_observed_exception, + InfiniteGeneratorError, + ObservedException, + ObservedGeneratorExit, + ObservedUserStopIteration, + raise_observed_exception, + SkipFrame, + unimplemented_v2, + Unsupported, +) +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource +from ..utils import ( + check_constant_args, + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + counters, + identity, + is_function, + is_wrapper_or_member_descriptor, + istype, + make_cell, +) +from .base import ( + AsPythonConstantNotImplementedError, + AttributeMutationNew, + ValueMutationNew, + VariableTracker, +) +from .constant import ConstantVariable + + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._higher_order_ops.triton_kernel_wrap import ( + TritonGridType, + TritonKernelType, + ) + + +_F = TypeVar("_F", bound=Callable) +CO_VARARGS = 0x04 +CO_VARKEYWORDS = 0x08 + + +# Module‐level cache keyed by the function object +_spec_cache = WeakKeyDictionary() + + +class FunctionSpec: + def __init__(self, func: FunctionType): + code = func.__code__ + vn = code.co_varnames + + self.posonly_count = code.co_posonlyargcount + self.arg_count = code.co_argcount + self.kwonly_count = code.co_kwonlyargcount + + self.posonly_names = vn[: self.posonly_count] + self.pos_or_kw_names = vn[self.posonly_count : self.arg_count] + self.all_pos_names = self.posonly_names + self.pos_or_kw_names + self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count] + + off = self.arg_count + self.kwonly_count + self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None + off += 1 if self.varargs_name else 0 + self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None + + def update_defaults(self, func: FunctionType): + # Defaults can change from function call to function call. So re-update + # them on every call. 
+ self.defaults = func.__defaults__ or () + self.kwdefaults = func.__kwdefaults__ or {} + + # Map positional‐default names → their index in self.defaults + self.pos_default_map = dict( + zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults))) + ) + + +def _get_spec(func: FunctionType) -> FunctionSpec: + spec = _spec_cache.get(func) + if spec is None: + spec = FunctionSpec(func) + _spec_cache[func] = spec + return spec + + +def bind_args_cached(func, tx, fn_source, args, kwargs): + spec = _get_spec(func) + spec.update_defaults(func) + ba = {} + rem_kw = dict(kwargs) + + # 1) Bind all positional (pos-only + pos-or-kw) + for i, name in enumerate(spec.all_pos_names): + if i < len(args): + ba[name] = wrap_bound_arg(tx, args[i]) + elif name in rem_kw: + if name in spec.posonly_names: + raise TypeError(f"{name} is positional-only") + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name in spec.pos_default_map: + idx = spec.pos_default_map[name] + default_source = None + if fn_source: + default_source = DefaultsSource(fn_source, idx) + ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) + else: + raise TypeError(f"Missing required positional argument: {name}") + + # 2) *args + extra = args[len(spec.all_pos_names) :] + if spec.varargs_name: + ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra)) + elif extra: + raise TypeError( + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" + ) + + # 3) Keyword-only + for name in spec.kwonly_names: + if name in rem_kw: + ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) + elif name in spec.kwdefaults: + kwdefault_source = None + if fn_source: + kwdefault_source = DefaultsSource(fn_source, name, is_kw=True) + ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source) + else: + raise TypeError(f"Missing required keyword-only argument: {name}") + + # 4) **kwargs + if spec.varkw_name: + ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw) + elif rem_kw: + raise TypeError(f"Unexpected keyword arguments: {list(rem_kw)}") + + return ba + + +def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): + # Source propagation is best effort since not every object we encounter has a source to begin with. + if isinstance(val, VariableTracker): + return val + elif not source: + return VariableTracker.build(tx, val) + else: + # Create a lazy variable to avoid guarding on __defaults__ unless really + # needed. + return variables.LazyVariableTracker.create(val, source) + + +def wrap_args_kwargs(tx: "InstructionTranslator", result): + for k, v in list(result.items()): + if isinstance(v, (tuple, dict)): + # args/kwargs + result[k] = wrap_bound_arg(tx, v) + + +def init_cellvars(parent, result: dict[str, VariableTracker], code): + """ + Update `result` to add mapping from local name to new cells created + directly by `code`, or update SideEffects in `parent` if the a local cell is + already in `result` (cell argument). + """ + side_effects = parent.output.side_effects + + for name in code.co_cellvars: + new_cell = side_effects.track_cell_new() + if name in result: + # This handles when a function argument is a cell (e.g., captured by + # a nested func). See `MAKE_CELL` bytecode for more info. 
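`bind_args_cached` performs, in cached and VariableTracker-aware form, the same binding that `inspect.Signature.bind` does for plain Python; a reference sketch:

import inspect

def g(a, b=2, *args, c, **kw):
    pass

bound = inspect.signature(g).bind(1, 10, 20, c=3, d=4)
bound.apply_defaults()
# Same shape as the `ba` dict built above: positionals, the *args tuple,
# keyword-only args, and the **kw mapping.
assert bound.arguments == {"a": 1, "b": 10, "args": (20,), "c": 3, "kw": {"d": 4}}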
+ side_effects.store_cell(new_cell, result.pop(name)) + result[name] = new_cell + + +def _create_nested_fn( + code, f_globals, name, defaults, closure, kwdefaults, annotations +): + from types import FunctionType + + func = FunctionType(code, f_globals, name, defaults, closure) + func.__kwdefaults__ = kwdefaults + + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert annotations is None or isinstance(annotations, dict) + func.__annotations__ = annotations + + return func + + +fn_known_dunder_attrs = { + "__annotations__", + "__defaults__", + "__kwdefaults__", + "__code__", + "__globals__", + "__closure__", + "__doc__", +} + + +def fn_var_getattr(tx, fn, source, name): + source = source and AttrSource(source, name) + try: + subobj = inspect.getattr_static(fn, name) + except AttributeError: + # function does not have a __getattr__ or __getattribute__ method, + # so we can safely assume that this attribute is absent + raise_observed_exception(AttributeError, tx) + + # Special handling for known dunder attributes + if name in fn_known_dunder_attrs: + subobj = getattr(fn, name) + if source: + return variables.LazyVariableTracker.create(subobj, source) + return VariableTracker.build(tx, subobj) + + +class BaseUserFunctionVariable(VariableTracker): + def get_filename(self): + return self.get_code().co_filename + + def get_name(self): + return self.get_code().co_name + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + result = False + + try: + result = hasattr(self.get_function(), name) + except NotImplementedError: + if name == "__name__" and isinstance(self, NestedUserFunctionVariable): + result = True + return variables.ConstantVariable.create(result) + + def inspect_parameter_names(self): + return list(inspect.signature(self.get_function()).parameters) + + def closure_vars(self, tx): + return {} + + +class UserFunctionVariable(BaseUserFunctionVariable): + """Some unsupported user-defined global function""" + + _nonvar_fields = { + "fn", + "is_constant", + *BaseUserFunctionVariable._nonvar_fields, + } + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) + return cls(value, source=source) + + def __init__(self, fn, is_constant=False, **kwargs) -> None: + super().__init__(**kwargs) + if getattr(fn, "_dynamo_marked_constant", False): + # This method should be treated as a constant for the purposes of compilation + self.is_constant = True + else: + self.is_constant = False + + # TODO putting this here to avoid duplication, because we could hit this + # from several paths (e.g., SuperVariable or `var_getattr`s). + if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): + unimplemented_v2( + gb_type="can't handle functions not implemented in python ", + context=f"{fn}", + explanation="Dynamo can only handle functions defined in python", + hints=[ + "Move usage of this function out of `torch.compile` region", + *graph_break_hints.INFERENCE_MODE, + ], + ) + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. 
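A minimal sketch of the `types.FunctionType` construction used by `_create_nested_fn`; `template`/`clone` are illustrative names only:

import types

def template(x, y=1):
    return x + y

clone = types.FunctionType(
    template.__code__,
    template.__globals__,
    "clone",
    template.__defaults__,
    template.__closure__,
)
clone.__kwdefaults__ = template.__kwdefaults__
assert clone(2) == 3             # behaves like the original, under a new name
assert clone.__name__ == "clone"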
+ # unpack @torch._dynamo.optimize()(fn) wrapped function + fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) + self.fn: types.FunctionType = fn + + def as_python_constant(self): + if istype(self, UserFunctionVariable): + return self.fn + # subclasses (such as methods) usually aren't a constant + return super().as_python_constant() + + def self_args(self): + return [] + + def get_function(self): + return self.fn + + def get_code(self): + return self.fn.__code__ + + def python_type(self): + return types.FunctionType + + def has_self(self): + return getattr(self.fn, "__self__", None) is not None + + def get_globals(self): + return self.fn.__globals__ + + def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: + """ + Assume `args` and `kwargs` are VariableTracker arguments for a call to + this function, create new bindings for initial locals. + """ + assert not self.is_constant + + fn: types.FunctionType = self.fn + + if not isinstance(fn, FunctionType): + raise TypeError("Only supports regular Python functions.") + root_tx = parent.output.root_tx + result = bind_args_cached(fn, root_tx, self.source, args, kwargs) + + init_cellvars(parent, result, fn.__code__) + closure = self.fn.__closure__ or () + assert len(closure) == len(self.fn.__code__.co_freevars) + for idx, name, cell in zip( + itertools.count(), self.fn.__code__.co_freevars, closure + ): + # TODO refactor these 3 branches. + side_effects = parent.output.side_effects + if cell in side_effects: + cell_var = side_effects[cell] + + elif self.source: + closure_cell = GetItemSource( + AttrSource(self.source, "__closure__"), idx + ) + closure_cell_contents = AttrSource(closure_cell, "cell_contents") + try: + contents_var = VariableTracker.build( + parent, cell.cell_contents, closure_cell_contents + ) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing( + closure_cell, cell, contents_var + ) + + else: + # TODO figure out why source isn't available here, and whether + # we can fix that and remove this branch. + try: + contents_var = VariableTracker.build(parent, cell.cell_contents) + except ValueError: + # Cell has not yet been assigned + contents_var = variables.DeletedVariable() + cell_var = side_effects.track_cell_existing(None, cell, contents_var) + + result[name] = cell_var + + return result + + def var_getattr(self, tx: "InstructionTranslator", name: str): + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + return fn_var_getattr(tx, self.fn, self.source, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + result = hasattr(self.fn, name) + return variables.ConstantVariable.create(result) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # Handle patch_dynamo_config call + + if self.fn is torch._dynamo.patch_dynamo_config: + try: + args_const = [arg.as_python_constant() for arg in args] + kwargs_const = { + key: val.as_python_constant() for key, val in kwargs.items() + } + changes = torch._dynamo.patch_dynamo_config( + *args_const, **kwargs_const + ).changes + return variables.DynamoConfigPatchVariable(changes) + except AsPythonConstantNotImplementedError as e: + raise RuntimeError( + "Cannot convert patch_dynamo_config args/kwargs to constants. " + "Please fix your call to patch_dynamo_config by using simpler inputs. 
" + f"args: {args}, kwargs: {kwargs}" + ) from e + # Handle a `nonstrict_trace(fn)` call + if self.fn is torch._dynamo.nonstrict_trace: + bound = inspect.signature(self.fn).bind(*args, **kwargs) + fn_var = bound.args[0] + if not isinstance(fn_var, BaseUserFunctionVariable): + typ = fn_var.python_type() + msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" + unimplemented_v2( + gb_type="TypeError from user code", + context=f"call_function({self.value}, {args}, {kwargs})", + explanation=msg, + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + if not isinstance(fn_var, UserFunctionVariable): + fn_name = fn_var.get_name() + msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950 + unimplemented_v2( + gb_type="Limitation of `nonstrict_trace", + context=f"{self}", + explanation=msg, + hints=[ + f"make sure definition of {fn_name} is outside ", + "`torch.compile` region", + ], + ) + + fn = fn_var.fn + return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) + + if self.is_constant: + return invoke_and_store_as_constant( + tx, self.fn, self.get_name(), args, kwargs + ) + + if ( + not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects + and self.fn + is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer + ): + with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer( + tx + ): + return super().call_function(tx, args, kwargs) + + if ( + tx.output.current_tracer.under_activation_checkpoint + and not tx.output.current_tracer.allow_side_effects_under_checkpoint + ): + try: + from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState + except Exception: + FSDPState = None + if FSDPState is not None and self.fn in [ + FSDPState._pre_forward, + FSDPState._post_forward, + ]: + with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): + return super().call_function(tx, args, kwargs) + return super().call_function(tx, args, kwargs) + + +class BuiltinMethodVariable(BaseUserFunctionVariable): + def __init__(self, fn, is_constant=False, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(fn, types.BuiltinMethodType) + self.fn = fn + + @staticmethod + def is_supported_builtin_method(obj): + method_self = obj.__self__ + method_name = obj.__name__ + + # TODO(anijain2305) - Add support for more builtin methods + # Supports tuple.__new__ and frozenset({....}).__contains__ + return (method_self is tuple and method_name == "__new__") or ( + type(method_self) is frozenset and method_name == "__contains__" + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method_self = self.fn.__self__ + name = self.fn.__name__ + obj_source = self.source and AttrSource(self.source, "__self__") + obj_vt = VariableTracker.build(tx, method_self, obj_source) + return obj_vt.call_method(tx, name, args, kwargs) + + +class LocalGeneratorObjectVariable(VariableTracker): + def __init__( + self, + code: types.CodeType, + f_globals, + inline_tracer: Optional["InstructionTranslator"], + **kwargs, + ): + super().__init__(**kwargs) + self.code = code + self.f_globals = f_globals + self.inline_tracer = inline_tracer + + def get_code(self): + return self.code + + def get_filename(self): + return self.get_code().co_filename + + def get_name(self): + 
return self.get_code().co_name + + def get_function(self): + raise NotImplementedError + + def has_self(self): + return False + + def __name__(self): + return self.get_name() + + def __str__(self): + return f"{self.__class__.__name__}({self.get_name()})" + + __repr__ = __str__ + + def reconstruct(self, codegen: "PyCodegen"): + from torch._dynamo.side_effects import disallow_side_effects_in_generator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + save_and_restart_speculation_log, + temporarely_allow_writes_to_output_graph, + ) + + tx = InstructionTranslator.current_tx() + save = save_and_restart_speculation_log(tx) + disallow = disallow_side_effects_in_generator(tx) + temp = temporarely_allow_writes_to_output_graph(tx) + + with save, disallow, temp: + tracer = self._get_inline_tracer(tx) + if not tracer.generator_exhausted: + self.remaining_items = self.force_unpack_var_sequence(tx) + variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) + + def bind_args(self, tx, args, kwargs): + return self.fn.bind_args(tx, args, kwargs) + + def get_globals(self): + return self.f_globals + + def python_type(self): + return types.GeneratorType + + def _get_inline_tracer(self, tx): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + if self.inline_tracer is None: + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( + tx, self, [], {} + ) + return self.inline_tracer + + def next_variable(self, tx): + tracer = self._get_inline_tracer(tx) + + if self._is_generator_exhausted(): + raise_observed_exception(StopIteration, tx) + + try: + # Hierarchically, tx can be seen as the parent of the inline tracer + # created on call_function. Any exception needs to be propagated to tx + # for Dynamo to behave correctly + with patch.dict(counters, {"unimplemented": counters["inline_call"]}): + return tracer.inline_call_() + except ObservedException as e: + tracer.generator_exhausted = True + raise e + except InfiniteGeneratorError: + # test/dynamo/test_misc.py::test_iterator_limit + raise + except Unsupported as e: + torch._dynamo.eval_frame.skip_code(self.get_code()) + raise SkipFrame from e + finally: + counters["unimplemented"] |= counters["inline_call"] + + def has_unpack_var_sequence(self, tx): + return False + + def has_force_unpack_var_sequence(self, tx) -> builtins.bool: + return True + + def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: + result = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence(self, tx, fn) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + def _setup_exception(self, tx, exc): + tracer = self._get_inline_tracer(tx) + try: + tracer._raise_exception_variable(exc) + except ObservedException as e: + # if no handler is available (i.e. user code doesn't catch it), the + # exception is raised again. 
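Plain-generator behaviour that `next_variable` and `force_unpack_var_sequence` mirror, as a minimal sketch:

def gen():
    yield 1
    yield 2

g = gen()
assert next(g) == 1
assert list(g) == [2]      # force-unpacking drains whatever is left
try:
    next(g)
except StopIteration:
    pass                   # an exhausted generator keeps raising StopIteration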
+ tracer.exception_handler(e) + + def _is_generator_just_started(self): + return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 + + def _is_generator_exhausted(self): + return getattr(self.inline_tracer, "generator_exhausted", False) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__next__": + return self.next_variable(tx) + elif name == "__iter__": + # iter(gen) returns itself + return self + elif name == "send": + # Sends a value into the generator function. Returns the next value + # yielded by the generator, or raises StopIteration if the generator + # exits without yielding another value + if self._is_generator_just_started() and len(args): + # can't send non-None value to a just-started generator + # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen + if not all( + isinstance(arg, ConstantVariable) and arg.value is None + for arg in args + ): + raise_observed_exception(TypeError, tx) + tracer = self._get_inline_tracer(tx) + tracer.push_many(args) + return self.next_variable(tx) + elif name == "close": + # * Raises a GeneratorExit at the point where the generator function was paused. + # * If the generator function catches the exception and returns a + # value, this value is returned from close() - Python 3.13+ + # * If the generator function is already closed, or raises GeneratorExit + # (by not catching the exception), close() returns None. + # * If the generator yields a value, a RuntimeError is raised. + # * If the generator raises any other exception, it is propagated to the caller. + # * If the generator has already exited due to an exception or normal + # exit, close() returns None and has no other effect. + + # Return None if close is called on a just-started generator + # See test GeneratorCloseCpythonTests::test_close_not_started + + tracer = self._get_inline_tracer(tx) + if self._is_generator_just_started() or self._is_generator_exhausted(): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + # Raise GeneratorExit to see if user code catches it. Any other exception + # is propagated to the parent frame. 
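The `send` and `close` branches above follow CPython's generator protocol; a quick reference sketch:

def gen():
    received = yield 1
    yield received

g = gen()
try:
    g.send(42)             # non-None send to a just-started generator is an error
except TypeError:
    pass
assert g.send(None) == 1   # equivalent to next(g)
assert g.send(42) == 42    # resumes the paused yield with the sent value
g.close()                  # raises GeneratorExit inside the paused generator, returns None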
+ try: + self._setup_exception( + tx, variables.ExceptionVariable(GeneratorExit, ()) + ) + # There's an extra block on Python 3.12+ to handle StopIteration + # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 + # + # 1 0 RETURN_GENERATOR + # 2 POP_TOP + # 4 RESUME 0 + + # 2 6 LOAD_CONST 1 (1) + # 8 YIELD_VALUE 1 + # 10 RESUME 1 + # 12 POP_TOP + # 14 RETURN_CONST 0 (None) + # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR) + # 18 RERAISE 1 + # ExceptionTable: + # 4 to 14 -> 16 [0] lasti + if ( + sys.version_info >= (3, 12) + and tracer.next_instruction.opname == "CALL_INTRINSIC_1" + ): + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedGeneratorExit: + # If it doesn't catch, we just return None, as per the text above + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + + try: + # Raise RuntimeError if the generator yields any other value + if self.next_variable(tx): + raise_observed_exception(RuntimeError, tx) + except ObservedGeneratorExit: + tracer.generator_exhausted = True + return variables.ConstantVariable(None) + except ObservedUserStopIteration: + # In Python 3.13+, one can capture GeneratorExit and return a value + # See test_generator.py::test_close_capture_GeneratorExit_return + # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26 + # https://github.com/python/cpython/pull/104771 + assert tracer.symbolic_result is not None + return tracer.symbolic_result + elif name == "throw": + # * Raises an exception at the point where the generator was paused, and + # returns the next value yielded by the generator. + # * If the generator exits without yielding, raise StopIteration + # * If the generator function does not catch the passed-in exception, + # or raises a different exception, then that exception propagates to the caller. + + # Setup the exception table and jump target in case of try...finally + tracer = self._get_inline_tracer(tx) + try: + # In Python 3.9, the exception is represented as a triple (typ, val, tb) + # In such cases, we re-raise the exception object given to avoid + # creating a new object, so that IS_OP works. + # See: https://github.com/pytorch/pytorch/pull/146496 + self._setup_exception(tx, args[1] if len(args) == 3 else args[0]) + except ObservedException: # noqa: TRY203 + # propagate the exception back to the parent caller + raise + + retval = self.next_variable(tx) + + # The exception raised before is still active. We need to check the exception + # table one more time to find the next target. But why? Let’s walk + # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M + # + # z = 0 + # def whoo(): + # global z + # z = 0 + # try: + # yield 1 + # except ValueError: + # yield 2 + # finally: + # z += 1 + # z += 10 + # + # gen = whoo() + # next(gen) + # gen.throw(ValueError) + # print('z', z) -> z = 1 + # + # ... + # >> 58 PUSH_EXC_INFO + # + # 8 60 LOAD_GLOBAL 2 (ValueError) + # 70 CHECK_EXC_MATCH + # 72 POP_JUMP_IF_FALSE 7 (to 88) + # 74 POP_TOP + # + # 9 76 LOAD_CONST 3 (2) + # 78 YIELD_VALUE 3 <------ ValueError is still active here + # 80 RESUME 1 + # 82 POP_TOP + # 84 POP_EXCEPT + # 86 jump_backward 34 (to 20) + # ... 
+ # + # ExceptionTable: + # 4 to 8 -> 124 [0] lasti + # 12 to 18 -> 58 [0] + # 20 to 56 -> 124 [0] lasti + # 58 to 82 -> 90 [1] lasti <------ move to 90 + # 84 to 86 -> 96 [0] + # 88 to 88 -> 90 [1] lasti + # 90 to 94 -> 96 [0] + # 96 to 116 -> 118 [1] lasti + # 118 to 122 -> 124 [0] lasti + # + # In this scenario, a generator can yield after `throw()` is called. Even + # after the exception is raised a few lines above, it remains active + # within the `78 YIELD_VALUE` instruction. When the generator resumes + # after the second yield on instruction `80 RESUME`, we cannot simply + # return the control flow to the next instruction. Instead, one must + # check the exception table (or equivalent) to find the next target + # In this case, it says the instruction pointer must be moved to 90. + # + # Without this step, if we let the trace proceed to the next + # instruction, it would follow the control flow where the exception + # raised by `throw()` was handled and swallowed, potentially leading + # to incorrect behavior. + exc_type = type("__InternalThrowException", (Exception,), {}) + + try: + self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self.next_variable(tx) + except get_dynamo_observed_exception(exc_type): + # We should get back the exception raised before. + pass + else: + raise_observed_exception(RuntimeError, tracer) + return retval + + super().call_method(tx, name, args, kwargs) + + +class ContextlibContextManagerLocalGeneratorObjectVariable( + LocalGeneratorObjectVariable +): + """ + .. note:: + + This is only used when the function is annotated with @contextlib.contextmanager + + It is a special case of a generator function as we do not allow return a context manager + from a torch.compile function. + """ + + +class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): + """functions that behaves like iterators + + .. note:: + + This is a wrapper around (Nested)UserFunctionVariable + """ + + def __init__( + self, + vt: VariableTracker, + *, + generator_cls=LocalGeneratorObjectVariable, + **kwargs, + ): + super().__init__(**kwargs) + self.vt = vt + self.generator_cls = generator_cls + + def __getattr__(self, name): + if name in self.__class__.__dict__.keys(): + return getattr(self, name) + return getattr(self.vt, name) + + def _build_inline_tracer(self, tx, args, kwargs): + from torch._dynamo.symbolic_convert import InliningInstructionTranslator + + return InliningInstructionTranslator.build_inline_tracer( + tx, + self, + args, + kwargs, + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert is_generator(self.vt.get_code()) + + inline_tracer = self._build_inline_tracer(tx, args, kwargs) + code = self.vt.get_code() + f_globals = self.vt.get_globals() + + # calling a generator returns a generator object + return self.generator_cls( + code, + f_globals, + inline_tracer, + source=self.source, + ) + + +class FunctionDecoratedByContextlibContextManagerVariable( + LocalGeneratorFunctionVariable +): + """ + .. 
note:: + + This is only used when the function is annotated with @contextlib.contextmanager + """ + + def __init__(self, vt, **kwargs): + super().__init__( + vt, + generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, + **kwargs, + ) + + def _build_inline_tracer(self, tx, args, kwargs): + # NOTE: This only exists to not break support for context manager when + # config.enable_faithful_generator_behavior = False and + # config.enable_trace_contextlib = True. In case the former is false, + # Dynamo should still be able to trace through @contextmanager functions + tracer = super()._build_inline_tracer(tx, args, kwargs) + assert isinstance( + tracer, + torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator, + ) + tracer.is_generator_from_ctx_manager = True + return tracer + + +class UserMethodVariable(UserFunctionVariable): + """Some unsupported user-defined method""" + + def __init__(self, fn, obj, **kwargs) -> None: + super().__init__(fn=fn, **kwargs) + self.obj = obj + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.fn}, {self.obj})" + + def self_args(self): + return [self.obj] + + def python_type(self): + return types.MethodType + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # NOTE this is to handle methods annotated by `nonstrict_trace`. Usually + # a `nonstrict_trace`-ed function will be wrapped by + # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, + # but in the case of method, we manually wrap it with `UserMethodVariable` + # inside `UserDefinedObjectVariable.var_getattr`. + # + # We might be able to simplify this away by canonicalizing the + # function/method wrapping code paths. + from ..trace_rules import is_nonstrict_trace_callable + + if is_nonstrict_trace_callable(self.fn): + call_args = [*self.self_args(), *args] + var = variables.TorchInGraphFunctionVariable( + self.fn, nonstrict_traceable=True + ) + return var.call_function(tx, call_args, kwargs) + + # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution + # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method + # since we ensure `forward` of allowed modules can be traced by AOT safely. + # Note this is not only for allowed modules, as user customized modules can extend from + # allowed modules but using parent's `forward` method, which is also covered by this branch. + + # If we are tracing the higher order op, we want Dynamo to step inside + # the module call so that Dynamo can see the underlying parameters and + # buffers and raise them as inputs to the graph. The is_root_tracer + # check bypasses the if condition for non-root tracers and directly + # calls the super().call_function at the end, which is basically + # equivalent of inlining the method. 
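A rough standalone illustration of the idea (this uses torch.fx.symbolic_trace rather than Dynamo, so it is only an analogy for the routing described above): a module call is kept as a single graph node instead of inlining its forward.

import torch

class _Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))

gm = torch.fx.symbolic_trace(_Demo())
# `lin` shows up as one call_module node; its weights and matmul are not inlined,
# similar in spirit to routing `forward` through NNModuleVariable.call_method above.
print(gm.graph)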
+ if tx.output.is_root_tracer() and isinstance( + self.obj, variables.NNModuleVariable + ): + module_attr = getattr(self.fn, "__module__", "") + # inline torch.nn.utils.parametrize + if ( + module_attr is not None + and module_attr.startswith("torch.nn.") + and module_attr != "torch.nn.utils.parametrize" + or self.is_constant + ): + return self.obj.call_method( + tx, self.fn.__name__, args, kwargs, constant=self.is_constant + ) + elif ( + _fsdp_param_group is not None + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + return variables.TorchCtxManagerClassVariable(self.fn).call_function( + tx, (self.obj, *args), kwargs + ) + if self.is_constant: + fn = getattr(self.obj.value, self.fn.__name__) + return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) + return super().call_function(tx, args, kwargs) + + def inspect_parameter_names(self): + return super().inspect_parameter_names()[1:] + + def var_getattr(self, tx: "InstructionTranslator", name: str): + source = self.source and AttrSource(self.source, name) + if name == "__self__": + return self.obj + if name == "__func__": + return VariableTracker.build(tx, self.fn, source) + return super().var_getattr(tx, name) + + +class WrappedUserMethodVariable(UserMethodVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("fn", None) + kwargs.pop("obj", None) + super().__init__(wrapped.fn, wrapped.obj, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrappedUserFunctionVariable(UserFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("fn", None) + super().__init__(wrapped.fn, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): + def convert(x): + if isinstance(x, variables.TensorVariable): + return x.get_real_value() + return x.as_python_constant() + + args = [convert(x) for x in args] + kwargs = {k: convert(v) for k, v in kwargs.items()} + res = fn(*args, **kwargs) + return tx.output.register_attr_or_module( + res, + name, + source=ConstantSource(name), + ) + + +class NestedUserFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "f_globals", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + fn_name, + code, + f_globals, + defaults, + kwdefaults, + annotations, + closure, + # This is present when this function is created by + # `functools.wrap(wrapped_fn)(this_fn)`. 
+ wrapped_fn=None, + **kwargs, + ) -> None: + if kwargs.get("mutation_type") is None: + kwargs.update(mutation_type=AttributeMutationNew()) + super().__init__(**kwargs) + assert isinstance(fn_name.as_python_constant(), str) + assert isinstance(code.as_python_constant(), types.CodeType) + assert isinstance(f_globals, dict) + self.fn_name = fn_name + self.code = code + self.f_globals = f_globals + self.defaults = defaults + self.kwdefaults = kwdefaults + self.annotations = annotations + self.closure = closure + self.wrapped_fn: Optional[VariableTracker] = wrapped_fn + + def self_args(self): + return [] + + def get_code(self): + return self.code.as_python_constant() + + def python_type(self): + return types.FunctionType + + def get_function(self): + if self.closure: + raise NotImplementedError + func = types.FunctionType( + self.code.as_python_constant(), + self.f_globals, + self.fn_name.as_python_constant(), + ) + if self.defaults: + func.__defaults__ = self.defaults.as_python_constant() + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.as_python_constant() + if self.annotations: + annotations = self.annotations.as_python_constant() + if isinstance(annotations, tuple): + from itertools import pairwise + + annotations = dict(pairwise(annotations)) + + # TypeError: __annotations__ must be set to a dict object + assert isinstance(annotations, dict) + func.__annotations__ = annotations + return func + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ): + tx.output.side_effects.store_attr(self, name_var.value, val) + return ConstantVariable(None) + + def call_method(self, tx, name, args, kwargs): + if name == "__setattr__": + return self.call_setattr(tx, *args) + return super().call_method(tx, name, args, kwargs) + + def has_closure(self): + return self.closure is not None + + def const_getattr(self, tx, name): + if name == "__name__": + return self.fn_name.as_python_constant() + return super().const_getattr(tx, name) + + def has_self(self): + return False + + def get_globals(self): + return self.f_globals + + def bind_args(self, parent, args, kwargs): + code = self.get_code() + func = types.FunctionType( + code, + self.f_globals, + self.fn_name.as_python_constant(), + tuple(self.defaults.items) if self.defaults else None, + tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), + ) + if self.kwdefaults: + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() + bound = inspect.signature(func).bind(*args, **kwargs) + bound.apply_defaults() + result = dict(bound.arguments.items()) + wrap_args_kwargs(parent.output.root_tx, result) + init_cellvars(parent, result, code) + + for idx, name in enumerate(code.co_freevars): + assert name not in result + cell = self.closure.items[idx] + result[name] = cell + + return result + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(__name__, "_create_nested_fn") + ) + codegen(self.code) + codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) + codegen(ConstantVariable.create(self.code.value.co_name)) + + if self.defaults: + codegen(self.defaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.closure: + codegen(self.closure) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.kwdefaults: + codegen(self.kwdefaults) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + if self.annotations: + try: + annotations 
= self.annotations.as_python_constant() + codegen.extend_output( + [codegen.create_load_const_unchecked(annotations)] + ) + except NotImplementedError: + codegen(self.annotations) + else: + codegen.extend_output([codegen.create_load_const(None)]) + + codegen.extend_output(create_call_function(7, False)) + + if self.wrapped_fn: + codegen.add_push_null( + lambda: codegen.load_import_from("functools", "wraps") + ) + codegen(self.wrapped_fn) + codegen.extend_output(create_call_function(1, False)) + codegen.extend_output(create_rot_n(2)) + codegen.extend_output(create_call_function(1, True)) + + # codegen attributes + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if tx.output.side_effects.has_pending_mutation(self): + for name, value in tx.output.side_effects.store_attr_mutations[ + self + ].items(): + codegen.dup_top() + codegen(value) + codegen.extend_output(create_rot_n(2)) + codegen.store_attr(name) + + +class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("fn_name", None) + kwargs.pop("code", None) + kwargs.pop("f_globals", None) + kwargs.pop("defaults", None) + kwargs.pop("kwdefaults", None) + kwargs.pop("annotations", None) + kwargs.pop("closure", None) + kwargs.pop("wrapped_fn", None) + super().__init__( + wrapped.fn_name, + wrapped.code, + wrapped.f_globals, + wrapped.defaults, + wrapped.kwdefaults, + wrapped.annotations, + wrapped.closure, + wrapped.wrapped_fn, + ) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class SkipFunctionVariable(VariableTracker): + _nonvar_fields = { + "value", + "reason", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value, reason=None, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + self.reason = reason + + def as_python_constant(self): + return self.value + + @classmethod + def create_with_source(cls, value, source): + if not is_wrapper_or_member_descriptor(value): + # These descriptors are not guaranteed to return the same object on + # attribute lookup. They are unlikely to be changed, so we can skip + # guarding them. 
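For context, a tiny plain-Python illustration (unrelated to any Dynamo API) of why identity-based guards are unreliable for wrapper descriptors: attribute lookup can build a fresh object on every access.

class _C:
    pass

c = _C()
# Each lookup of a slot wrapper on an instance creates a new method-wrapper object,
# so an identity guard on it would spuriously fail on the next lookup.
print(c.__str__ is c.__str__)   # False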
+ install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls(value, source=source) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) + unimplemented_v2( + gb_type="Skip calling `torch.compiler.disable()`d function", + context=str(self.value), + explanation=f"Skip calling function `{self.value}` since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", + hints=[ + "Remove the `torch.compiler.disable` call", + ], + ) + elif self.value is torch._dynamo.graph_break: + graph_break_msg = kwargs.get("msg", None) + if graph_break_msg: + graph_break_msg = graph_break_msg.as_python_constant() + unimplemented_v2( + gb_type="Call to `torch._dynamo.graph_break()`", + context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", + explanation=f"User-inserted graph break. Message: {graph_break_msg}", + hints=[ + "Remove the `torch._dynamo.graph_break()` call.", + ], + ) + elif self.value is torch._dynamo.skip_frame: + skip_frame_msg = kwargs.get("msg", None) + if skip_frame_msg: + skip_frame_msg = skip_frame_msg.as_python_constant() + raise SkipFrame( + f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}" + ) + else: + if config.dont_skip_tracing: + from .builder import SourcelessBuilder + + # re-build the function, attempting to not skip + rebuilt_fn = SourcelessBuilder.create(tx, self.value) + # if we still get SkipFunctionVariable, then we *really* should skip this function + if not isinstance(rebuilt_fn, SkipFunctionVariable): + return rebuilt_fn.call_function(tx, args, kwargs) + qualname = getattr(self.value, "__qualname__", "") + module_or = getattr(self.value, "__module__", None) + module_name = "" if module_or is None else str(module_or) + try: + path = inspect.getfile(self.value) + explanation = ( + f"Dynamo developers have intentionally marked that the function `{qualname}` " + f"in file `{path}` should not be traced." + ) + hints = [ + f"Avoid calling the function `{qualname}`.", + ] + # TODO improve trace_rules reasoning to provide better hints. + # How do we tell that a function/file should NOT be removed from skip files? + # Do a very basic check for now. + if "_dynamo" not in path: + hints += [ + f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` " + "to force tracing into the function. " + "More graph breaks may occur as a result of attempting to trace into the function.", + "Please file an issue to PyTorch.", + ] + except TypeError: + known_python_builtin_modules = {"_abc", "_warnings"} + if module_or in known_python_builtin_modules: + explanation = ( + f"Dynamo does not know how to trace the Python builtin " + f"`{module_name}.{qualname}`." + ) + hints = [ + "If you are attempting to call a logging function (e.g. `_warnings.warn`), " + "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.", + "Please file an issue on GitHub " + "so the PyTorch team can add support for it. ", + ] + elif module_or is not None and module_or.startswith("optree"): + explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}." 
+ hints = [ + " Consider using torch.utils._pytree - " + "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py" + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + else: + explanation = ( + f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` " + f"This function is either a Python builtin (e.g. _warnings.warn) " + f"or a third-party C/C++ Python extension (perhaps created with pybind)." + ) + hints = [ + "If it is a Python builtin, please file an issue on GitHub " + "so the PyTorch team can add support for it and see the next case for a workaround.", + "If it is a third-party C/C++ Python extension, please " + "either wrap it into a PyTorch-understood custom operator " + "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html " + "for more details) or, if it is traceable, use " + "`torch.compiler.allow_in_graph`.", + ] + # also warn on it because most users won't see the graph break message + torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) + if qualname == "allow_in_graph": + explanation = ( + "Found an allow_in_graph decorator to a function which " + "is created inside the parent function that is getting " + "compiled. This is not supported for now." + ) + hints = [] + reason = self.reason if self.reason else "" + unimplemented_v2( + gb_type="Attempted to call function marked as skipped", + context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}", + explanation=explanation, + hints=hints, + ) + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + return variables.ConstantVariable.create(hasattr(self.value, name)) + + def var_getattr(self, tx: "InstructionTranslator", name: str): + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + + return fn_var_getattr(tx, self.value, self.source, name) + + +class WrappedSkipFunctionVariable(SkipFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("value", None) + kwargs.pop("reason", None) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + +class WrapperUserFunctionVariable(VariableTracker): + """ + Used to represent a wrapper object that contains the actual callable as an + attribute. For example, torch.jit.script/trace have the original function at + their _torchdynamo_inline attribute. Similarly, functions with + __script_if_tracing_wrapper have the original attr at "__original_fn". 
+ """ + + def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None: + super().__init__(**kwargs) + self.wrapper_obj = wrapper_obj + self.attr_to_trace = attr_to_trace + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == self.attr_to_trace: + val = getattr(self.wrapper_obj, self.attr_to_trace) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, val, source) + + return super().var_getattr(tx, name) + + def self_args(self): + return [] + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if hasattr(self.wrapper_obj, "cache_info"): + target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) + module_name = getattr(target_fn, "__module__", "") or "" + + if module_name.split(".", maxsplit=1)[0] != "torch": + msg = ( + "Dynamo detected a call to a `functools.lru_cache`-wrapped " + "function. Dynamo ignores the cache wrapper and directly " + "traces the wrapped function. Silent incorrectness is only " + "a *potential* risk, not something we have observed. " + 'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.' + ) + + torch._dynamo.utils.warn_once(msg) + + dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo") + if dynamo_logger.isEnabledFor(logging.DEBUG): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack = get_stack_above_dynamo() + user_stack + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n" + user_stack_trace += str(user_stack_formatted) + dynamo_logger.debug(user_stack_trace) + + all_args = self.self_args() + args + return variables.UserFunctionVariable( + polyfills.getattr_and_trace + ).call_function( + tx, + [self, variables.ConstantVariable(self.attr_to_trace), *all_args], + kwargs, + ) + + +class WrapperUserMethodVariable(WrapperUserFunctionVariable): + """ + Similar to WrapperUserFunctionVariable, but for methods. The only delta is + saving the vt for `self` object of the method which is then used by + WrapperUserFunctionVariable in `call_function` method. + """ + + def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None: + super().__init__(wrapper_obj, attr_to_trace, **kwargs) + self.obj = self_obj + + def self_args(self): + return [self.obj] + + +def _traceable_collective_remaps(): + # We can't rely on importing from distributed, since it's not always built + if torch.distributed.is_available(): + from torch.distributed._functional_collectives import ( + traceable_collective_remaps, + ) + + return traceable_collective_remaps + return {} + + +def _traceable_collectives_source(tx: "InstructionTranslator", fn): + assert torch.distributed.is_available(), "Illegal invocation." + assert fn in _traceable_collective_remaps().values() + + inner_name = fn.__name__ + path_source = tx.import_source("torch.distributed._functional_collectives") + return AttrSource(path_source, inner_name) + + +class CollectiveFunctionRewriteVariable(UserFunctionVariable): + """ + Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives. + + This class provides both a way to check if a function is remappable, and perform the remapping. 
+ + In the case that a function is 'remappable' but only for some combinations of call-time arguments, + we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse + than status-quo as we currently graph-break on all distributed.* collectives. + """ + + def __init__(self, fn, *, replacement_var, **kwargs) -> None: + super().__init__(fn, **kwargs) + assert isinstance(replacement_var, UserFunctionVariable) + self.replacement_var = replacement_var + + @staticmethod + def create(tx: "InstructionTranslator", old_fn, source, **options): + new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) + return CollectiveFunctionRewriteVariable( + old_fn, + replacement_var=UserFunctionVariable(new_fn, source=new_source, **options), + source=source, + **options, + ) + + @staticmethod + def can_rewrite(variable): + return ( + inspect.isfunction(variable) and variable in _traceable_collective_remaps() + ) + + @staticmethod + def rewrite(tx: "InstructionTranslator", fn): + new_fn = _traceable_collective_remaps()[fn] + return new_fn, _traceable_collectives_source(tx, new_fn) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # call_function must check any unsupported arguments and graph-break. + # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, + # since that's the contract for putting a mapping in `traceable_collective_remaps` + import torch.distributed as dist + from torch.distributed._functional_collectives import REDUCE_OP_TO_STR + + # Merge args into kwargs so positional and keyword args + # can be processed the same way. + signature = inspect.signature(self.fn) + kwargs = dict(signature.bind(*args, **kwargs).arguments) + args = () + + if "async_op" in kwargs and kwargs["async_op"].as_python_constant(): + unimplemented_v2( + gb_type="async_op=True for distributed collectives", + context=f"{self.fn}, {args=}, {kwargs=}", + explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + if self.fn in ( + dist.all_reduce, + dist.reduce_scatter_tensor, + dist._reduce_scatter_base, + ): + reduce_op_var = kwargs.get("op") + reduce_op = ( + reduce_op_var.value + if reduce_op_var is not None + else signature.parameters["op"].default + ) + if reduce_op not in REDUCE_OP_TO_STR: + raise ValueError(f"Unsupported all_reduce op: {reduce_op}") + kwargs["op"] = variables.ConstantVariable.create( + REDUCE_OP_TO_STR[reduce_op] + ) + return self.replacement_var.call_function(tx, args, kwargs) + + +class FunctoolsWrapsVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not kwargs and len(args) == 1: + + def wraps(fn): + if isinstance(fn, variables.NestedUserFunctionVariable): + return fn.clone(wrapped_fn=args[0]) + unimplemented_v2( + gb_type="functools.wraps", + context=f"{fn}", + explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.LambdaVariable(wraps) + + return super().call_function(tx, args, kwargs) + + +class CollectionsNamedTupleFunction(UserFunctionVariable): + def as_python_constant(self): + return self.fn + + def call_function( + self, + tx: "InstructionTranslator", + 
args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + constant_args = check_constant_args(args, kwargs) + if constant_args: + value = self.fn( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + return variables.UserDefinedClassVariable( + value, mutation_type=ValueMutationNew() + ) + unimplemented_v2( + gb_type="namedtuple construction", + context=f"{args=}, {kwargs=}", + explanation="`torch.compile` only support certain input types for namedtuple", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + +class FunctoolsPartialVariable(VariableTracker): + def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: + super().__init__(**kwargs) + self.func = func + assert isinstance(args, list) + self.args = args + assert isinstance(keywords, dict) + self.keywords = keywords + # fake_value is used for id calculation. Creating this value and id'ng + # on it is sufficient for the tracing purposes. + self.fake_value = functools.partial(identity) + + def python_type(self): + return functools.partial + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) + codegen(self.func) + if self.args: + codegen.foreach(self.args) + if not self.keywords: + codegen.extend_output(create_call_function(len(self.args) + 1, False)) + return + + codegen.foreach(self.keywords.values()) + keys = tuple(self.keywords.keys()) + codegen.extend_output( + codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) + ) + + def get_function(self): + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + merged_args = self.args + args + merged_kwargs = {**self.keywords, **kwargs} + return self.func.call_function(tx, merged_args, merged_kwargs) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + # functools.partial uses slots, so attributes are constant + return variables.ConstantVariable.create( + hasattr(functools.partial(identity), name) + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str): + source = self.source and AttrSource(self.source, name) + # Handle __slots__ + if name == "func": + return self.func + if name == "args": + return variables.ListVariable(self.args, source=source) + if name == "keywords": + items = {ConstantVariable.create(k): v for k, v in self.keywords.items()} + return variables.ConstDictVariable(items, source=source) + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + raise_observed_exception(AttributeError, tx) + + def as_python_constant(self): + return functools.partial( + self.func.as_python_constant(), + *[arg.as_python_constant() for arg in self.args], + **{k: v.as_python_constant() for k, v in self.keywords.items()}, + ) + + def guard_as_python_constant(self): + """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" + return functools.partial( + self.func.guard_as_python_constant(), + *[v.guard_as_python_constant() for v in self.args], + **{k: v.guard_as_python_constant() for k, v in self.keywords.items()}, + ) + + +class PolyfilledFunctionVariable(VariableTracker): + _nonvar_fields = { + "fn", + "wrapped_fn", + "traceable_fn", + *VariableTracker._nonvar_fields, + } + + @classmethod + 
@functools.cache + def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: + return {} + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + + return cls(value, source=source) + + def __init__(self, fn: _F, **kwargs) -> None: + super().__init__(**kwargs) + self.fn: _F = fn + + handler = self._get_polyfill_handlers().get(fn, fn) + assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" + for candidate_attr in ( + "__torch_dynamo_polyfill__", # registered polyfill + "__python_implementation__", # self handler from third-party libraries + ): + candidate = getattr(handler, candidate_attr, None) + if candidate: + assert callable(candidate) + traceable_fn = candidate + break + else: + raise RuntimeError( + f"Polyfill handler {handler} does not have a traceable function" + ) + + self.wrapped_fn: _F = handler + self.traceable_fn: _F = traceable_fn + + @property + def polyfill_fn(self) -> _F: + return self.traceable_fn + + def can_constant_fold_through(self): + return getattr( + self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False + ) + + def get_function(self): + return self.as_python_constant() + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + result = ( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + return VariableTracker.build(tx, result) + + # Special case for sum on tuple/list of ints + if ( + self.fn is builtins.sum + and len(args) == 1 + and not kwargs + and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) + and all( + (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int)) + or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) + for x in args[0].items + ) + ): + return variables.SymNodeVariable.create( + tx, + tx.output.create_proxy( + "call_function", + torch.sym_sum, + (tuple(a.as_proxy() for a in args[0].items),), + {}, + ), + sym_num=torch.sym_sum( + [ + ( + x.value + if isinstance(x, variables.ConstantVariable) + else x.sym_num + ) + for x in args[0].items + ] + ), + ) + + traceable_function_variable = VariableTracker.build(tx, self.traceable_fn) + return traceable_function_variable.call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + assert method is not None, f"Member {name} not found in {self.fn}" + assert is_function(method), f"Member {name} is not callable in {self.fn}" + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) + return polyfilled_method_variable.call_function(tx, args, kwargs) + + def as_python_constant(self): + return self.fn + + +class TracebackVariable(VariableTracker): + # We don't track traceback. A call to any function in this module is a no-op + def call_function(self, tx, args, kwargs): ... 
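For reference, a hedged sketch of the `builtins.sum` fast path described above. The assumptions are that `torch.sym_sum` accepts a plain Python list of ints and that dynamic shapes are enabled so tensor sizes appear as SymInts inside the compiled region.

import torch

# torch.sym_sum is the op the special case emits instead of tracing the sum() polyfill
assert torch.sym_sum([1, 2, 3]) == 6

@torch.compile(dynamic=True)
def _demo(x):
    # sum() over a list of SymInt sizes can take the sym_sum fast path
    return torch.zeros(sum([x.shape[0], x.shape[1]]))

_demo(torch.randn(2, 3))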
+ + +class SysFunctionVariable(VariableTracker): + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def exc_info(self, tx): + if len(tx.exn_vt_stack): + exn = tx.exn_vt_stack[-1] + typ = exn.exc_type + tb = None + items = [ + VariableTracker.build(tx, typ), + exn, + VariableTracker.build(tx, tb), + ] + else: + items = [ + variables.ConstantVariable(None), + variables.ConstantVariable(None), + variables.ConstantVariable(None), + ] + return variables.TupleVariable(items) + + def exception(self, tx): + return self.exc_info(tx).items[1] + + def call_function(self, tx, args, kwargs): + if self.value is sys.exc_info: + return self.exc_info(tx) + assert self.value is sys.exception + return self.exception(tx) + + +from torch._higher_order_ops.triton_kernel_wrap import ( + create_tma_experimental_metadata, + create_tma_stable_metadata, + TMADescriptorMetadata, + TritonHOPifier, +) + + +class DynamoTritonHOPifier(TritonHOPifier): + def raise_unsupported(self, msg: str) -> Never: + raise Unsupported(msg) + + def is_callable(self, maybe_callable: Any) -> bool: + return isinstance( + maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) + ) + + def get_value(self, val: Any) -> Any: + return val.value + + def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: + from .lists import BaseListVariable + + if isinstance(grid, BaseListVariable): + return grid.as_proxy() + else: + unimplemented_v2( + gb_type="unsupported grid type for triton hop check_grid", + context=f"grid type = {type(grid)}", + explanation="`torch.compile` only supports list-like grid for check_grid", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def call_grid(self, grid, meta, tx): + meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta], {}) + return grid + + # We use this function to wrap call_prune_configs + def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable): + from .builder import SourcelessBuilder + + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) + result = wrapped_user_function.call_function(tx, args, kwargs) + return result + + def wrap_user_defined_obj(self, user_obj, tx, variable, name): + from .builder import VariableBuilder + + wrapped_user_obj = VariableBuilder( + tx, AttrSource(variable.kernel_source, f"{name}") + )._wrap(user_obj) + return wrapped_user_obj + + def maybe_unpack_configs(self, configs, tx): + # unpack the list of configs + configs = configs.unpack_var_sequence(tx) + + # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed. + configs = [config.guard_as_python_constant() for config in configs] + + return configs + + def maybe_unpack_heuristic_result(self, result: Any) -> Any: + if not result.is_python_constant(): + self.raise_unsupported( + "@triton.heuristics must return constant values because configs can only contain constant values." 
+ ) + + return result.guard_as_python_constant() + + # We need to override call_getitem here so that we can add the source in the case + # where we call the triton kernel with a grid + def call_getitem( + self, + variable: "TritonKernelVariable", + args: Sequence[Any], + ) -> "TritonKernelVariable": + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if variable.grid is not None or len(args) != 1: + self.raise_unsupported( + "Triton kernels should be called with only a single grid" + ) + return type(variable)( + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=args[0], + kernel_source=variable.source, + ) + + def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: + from .constant import ConstantVariable + from .dicts import ConstDictVariable + + # as we can only pass tensors as non-const args in fx graph, + # here we replace TMA descriptors + # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable + # instances) with the underlying tensors, while moving the + # TMA descriptor-related metadata to a separate argument, + # so that we can reconstruct the TMA descriptors downstream + tma_descriptor_metadata: TMADescriptorMetadata = {} + for k in list(combined_args_raw.keys()): + v = combined_args_raw[k] + if isinstance( + v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable) + ): + tma_descriptor_metadata[k] = v.to_metadata() + combined_args_raw[k] = v.get_tensor() + + combined_args = { + variables.ConstantVariable.create(k): v + for k, v in combined_args_raw.items() + } + + from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_mutation, + ) + + # Combine args and kwargs and pass as a dict so that if user defined triton + # kernel uses variables as 'grid' or 'kernel', it does not conflict with + # parameters of the wrapper function + constant_args = { + k: v.as_python_constant() + for k, v in combined_args_raw.items() + if isinstance(v, ConstantVariable) + } + non_constant_args = { + k: v + for k, v in combined_args.items() + if not isinstance(v, ConstantVariable) + } + + for v in non_constant_args.values(): + v = v.realize() + if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)): + self.raise_unsupported( + f"Unexpected argument type for a Triton kernel: {repr(v)}." 
+ ) + + constant_args_idx = kernel_side_table.add_constant_args(constant_args) + meta = ConstDictVariable(non_constant_args, dict) + tx.output.create_proxy( + "call_function", + triton_kernel_wrapper_mutation, + (), + { + "kernel_idx": variable.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grids, + "tma_descriptor_metadata": tma_descriptor_metadata, + "kwargs": meta.as_proxy(), + }, + ) + + return variables.ConstantVariable( + None, + ) + + +dynamo_triton_hopifier_singleton = DynamoTritonHOPifier() + + +class TritonKernelVariable(VariableTracker): + grid: "TritonGridType" + kernel: "TritonKernelType" + kernel_idx: Optional[int] + kernel_source: "AttrSource" + + def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: + self.kernel_source = kwargs.pop("kernel_source", None) + super().__init__(**kwargs) + dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return dynamo_triton_hopifier_singleton.call_triton_kernel( + self, args, kwargs, tx + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__getitem__": + return dynamo_triton_hopifier_singleton.call_getitem(self, args) + elif name == "run": + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) + + # Bail out to parent's implementation + return super().call_method(tx, name, args, kwargs) + + def specialize_symbolic(self, arg: Any) -> Any: + from .constant import ConstantVariable + from .tensor import SymNodeVariable + + # See [Note: Specialize tl.constexpr args in user-defined triton kernels] + if isinstance(arg, SymNodeVariable): + return ConstantVariable.create(arg.evaluate_expr()) + return arg + + +class TMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + data_ptr: "variables.DataPtrVariable", + dims: "list[ConstantVariable]", + block_dims: "list[ConstantVariable]", + element_size: "ConstantVariable", + **kwargs, + ): + assert isinstance(data_ptr, variables.DataPtrVariable) + super().__init__(**kwargs) + self.data_ptr = data_ptr + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + + def to_metadata(self): + return create_tma_experimental_metadata( + [dim.as_proxy() for dim in self.dims], + [dim.as_proxy() for dim in self.block_dims], + self.element_size.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.experimental_descriptor", + f"create_{len(self.dims)}d_tma_descriptor", + ) + ) + self.data_ptr.reconstruct(codegen) + args = [*self.dims, *self.block_dims, self.element_size] + codegen.foreach(args) + codegen.call_function(len(args) + 1, False) + + def get_tensor(self): + return self.data_ptr.from_tensor + + +class TMADescriptorStableVariable(VariableTracker): + def __init__( + self, + tensor: "variables.TensorVariable", + block_shape: "variables.ListVariable", + **kwargs, + ): + assert isinstance(tensor, variables.TensorVariable) + super().__init__(**kwargs) + self.tensor = tensor + self.block_shape = block_shape + + def to_metadata(self): + return create_tma_stable_metadata( + self.block_shape.as_proxy(), + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + 
"triton.tools.tensor_descriptor", + "TensorDescriptor", + ) + ) + codegen.load_method("from_tensor") + self.tensor.reconstruct(codegen) + codegen(self.block_shape) + codegen.call_method(2) + + def get_tensor(self) -> "variables.TensorVariable": + return self.tensor + + +class CreateTMADescriptorExperimentalVariable(VariableTracker): + def __init__( + self, + rank: int, + **kwargs, + ) -> None: + assert rank in (1, 2) + super().__init__(**kwargs) + self.rank = rank + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] + + if not isinstance(ptr, variables.DataPtrVariable): + raise Unsupported( + "Please ensure there were no graph breaks between " + f"create_{self.rank}d_tma_descriptor and the upstream " + ".data_ptr() call." + ) + + if self.rank == 1: + assert len(args) + len(kwargs) == 4 + dims = [ + kwargs["dim"] if "dim" in kwargs else args[1], + ] + block_dims = [ + kwargs["block_dim"] if "block_dim" in kwargs else args[2], + ] + else: + assert len(args) + len(kwargs) == 6 + dims = [ + kwargs["dim1"] if "dim1" in kwargs else args[1], + kwargs["dim0"] if "dim0" in kwargs else args[2], + ] + block_dims = [ + kwargs["block_dim1"] if "block_dim1" in kwargs else args[3], + kwargs["block_dim0"] if "block_dim0" in kwargs else args[4], + ] + element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] + + return TMADescriptorExperimentalVariable( + data_ptr=ptr, + dims=dims, + block_dims=block_dims, + element_size=element_size, + ) + + +class CreateTMADescriptorStableVariable(VariableTracker): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] + block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] + + return TMADescriptorStableVariable( + tensor=tensor, + block_shape=block_shape, + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..71a1934ca4b29d1a5b9f638a9217f25390019958 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/higher_order_ops.py @@ -0,0 +1,3440 @@ +# mypy: ignore-errors + +""" +This module contains classes and utilities for handling higher-order operators in Dynamo. +It provides functionality for tracing and transforming control flow constructs like +conditions (torch.cond), loops (torch.while_loop), maps (torch.ops.higher_order.map), +and other higher-order operations. + +The module includes specialized VariableTracker classes for different types of +higher-order operations, along with utilities for: +- Speculating and capturing subgraphs +- Managing control flow +- Handling autograd function applications +- Supporting function transformations +- Processing activation checkpoints + +These classes work together to enable Dynamo to correctly trace and compile code +containing complex control flow patterns and higher-order functions while preserving +their semantic behavior. 
+""" + +import contextlib +import functools +import inspect +import itertools +import logging +import types +import warnings +from typing import Optional, TYPE_CHECKING + +import torch._C +import torch.fx +import torch.nn +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import get_fake_value +from torch._dynamo.variables.builtin import BuiltinVariable +from torch._dynamo.variables.constant import ConstantVariable +from torch._dynamo.variables.functions import UserFunctionVariable +from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable +from torch._dynamo.variables.tensor import SymNodeVariable +from torch._guards import Source +from torch._ops import HigherOrderOperator +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils import _pytree as pytree + +from .. import graph_break_hints, variables +from ..exc import ( + IncorrectUsage, + ObservedException, + UncapturedHigherOrderOpError, + unimplemented, + unimplemented_v2, + Unsupported, +) +from ..source import AttrSource, DictGetItemSource +from ..utils import proxy_args_kwargs, set_example_value +from .base import VariableTracker +from .dicts import ConstDictVariable +from .lazy import LazyVariableTracker +from .lists import ListVariable, TupleVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + + +def raise_hard_error_if_graph_break(reason): + def deco(fn): + @functools.wraps(fn) + def graph_break_as_hard_error(*args, **kwargs): + try: + return fn(*args, **kwargs) + except (Unsupported, ObservedException) as e: + msg = " Scroll up to find out what causes the graph break." + raise UncapturedHigherOrderOpError(reason + msg) from e + + return graph_break_as_hard_error + + return deco + + +# This function is a syntax sugar for creating a dummy new subtracer so that +# newly added nodes are added to a separate subgraph in this subtracer instead of affecting +# the main graph. This is useful for creating sample inputs for tracing the subgraph. +# For example, in FlexAttentionHigherOrderVariable, we want to create several scalars +# to trace the score_mod function but we don't want the operators that creates the scalar to +# show up in the graph, we could this function to discard the graph changes. +# Example usage: +# with discard_graph_changes(): +# sample_input= create_sample_inputs() +# speculate_subgraph(tx, f, sample_inputs, {}) +@contextlib.contextmanager +def discard_graph_changes(tx): + ctx = tx.output.subtracer("subgraph_wrapper", None) + try: + ctx.__enter__() + yield + finally: + ctx.__exit__(None, None, None) + + +def check_meta_consistency_vt( + vars1: list[VariableTracker], + vars2: list[VariableTracker], + lhs_name: str, + rhs_name: str, + include_contiguity: bool = True, +) -> None: + from torch._higher_order_ops.utils import check_meta_consistency + + from . 
import TensorVariable + + def _unwrap_var(var): + if isinstance(var, TensorVariable): + return var.proxy.node.meta["example_value"] + elif isinstance(var, SymNodeVariable): + return var.sym_num + elif isinstance(var, ConstantVariable): + return var.as_python_constant() + else: + unimplemented(f"Cannot unwrap var {var}") + + unwrapped1 = [_unwrap_var(var) for var in vars1] + unwrapped2 = [_unwrap_var(var) for var in vars2] + + return check_meta_consistency( + unwrapped1, + unwrapped2, + lhs_name, + rhs_name, + include_contiguity=include_contiguity, + ) + + +@contextlib.contextmanager +def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): + from . import GradModeVariable + + org_value = torch.is_grad_enabled() + try: + GradModeVariable.create(tx, enable, initialized=True) + yield + finally: + GradModeVariable.create(tx, org_value, initialized=True) + + +@contextlib.contextmanager +def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.under_activation_checkpoint + try: + tx.output.current_tracer.under_activation_checkpoint = True + yield + finally: + tx.output.current_tracer.under_activation_checkpoint = orig_val + + +def find_mismatched_vars(var, types, allow_none=False): + """ + Recursively finds variables whose type is not an instance of the specified types. + Args: + var: The variable to check. + types: A tuple of allowed types. + allow_none (bool): Whether to allow None values. Defaults to False. + Returns: + A set of variables whose type is not an instance of the specified types. + """ + mismatched_vars = set() + if isinstance(var, (TupleVariable, ListVariable)): + for item in var.items: + mismatched_vars.update(find_mismatched_vars(item, types, allow_none)) + elif isinstance(var, ConstDictVariable): + for value in var.items.values(): + mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) + else: + + def _is_none(var): + return var.is_python_constant() and var.as_python_constant() is None + + if not isinstance(var, types) and not (allow_none and _is_none(var)): + mismatched_vars.add(var) + return mismatched_vars + + +def only_consist_of(var, types, allow_none=False): + mismatch_vars = find_mismatched_vars(var, types, allow_none=allow_none) + return len(mismatch_vars) == 0 + + +# A more read-able syntax sugar for creating a UserFunctionVariable for f +# and run call_function on it. Make it return a function to preserve the calling +# convention of the original f. +def _make_inlined(tx: "InstructionTranslator", f): + assert callable(f), "Expect f to be a python callable." + + def inline_call(*args, **kwargs): + return UserFunctionVariable(f).call_function(tx, args, kwargs) + + return inline_call + + +def _call_function_and_unflatten_output( + tx, fn, args, kwargs, flat_example_value, ret_treespec +): + from .builder import wrap_fx_proxy + + # Store the invocation as a call + flat_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn, + args=args, + kwargs=kwargs, + ), + example_value=flat_example_value, + ) + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. 
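As a standalone reminder of the pytree round trip this helper mirrors (plain torch.utils._pytree usage, nothing Dynamo-specific; the example structure is made up):

import torch.utils._pytree as pytree

out = {"a": 1, "b": (2, 3)}
leaves, spec = pytree.tree_flatten(out)         # leaves == [1, 2, 3]
restored = pytree.tree_unflatten(leaves, spec)  # rebuilds the original structure
assert restored == out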
+ flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {}) + return ( + _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec) + if ret_treespec + else flat_variable + ) + + +def _assert_tensors_nonaliasing(inputs, outputs): + input_tensor_ids = { + id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor) + } + output_tensor_ids = { + id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor) + } + assert input_tensor_ids.isdisjoint(output_tensor_ids), ( + "inputs to function body cannot alias outputs" + ) + + +def _check_all_tensorvariable(args): + from . import TensorVariable + + if not all(type(a.realize()) is TensorVariable for a in args): + unimplemented( + f"Expected all leaves to be of torch.Tensor type, but got {[type(a.realize()) for a in args]}." + ) + + +def _check_supported_callable_arg( + tx: "InstructionTranslator", func_var: VariableTracker, arg_name +): + is_callable = ( + BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant() + ) + if not is_callable: + unimplemented( + f"{arg_name} should be a Callable but is of type {str(func_var)}." + ) + + +def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode): + from torch._subclasses._fake_tensor_utils import _CacheKeyState + from torch._subclasses.fake_tensor import extract_tensor_metadata + + # Maps the equivalent nodes from a to b + node_map = {} + + def check_all_args(a_nodes, b_nodes): + for arg_a, arg_b in zip(a_nodes, b_nodes): + if isinstance(arg_a, torch.fx.Node): + if node_map[arg_a] != arg_b: + return False + elif isinstance(arg_a, slice): + if not isinstance(arg_b, slice): + return False + if not check_all_args( + (arg_a.start, arg_a.stop, arg_a.step), + (arg_b.start, arg_b.stop, arg_b.step), + ): + return False + elif arg_a != arg_b: + # This is a catch-all for everything else. `slice` was a + # surprise but can there be other data structures that can + # contain fx.Nodes in them? 
+ return False + return True + + for a_node, b_node in zip(a_mod.graph.nodes, b_mod.graph.nodes): + if a_node.op != b_node.op: + return False + + if a_node.op == "placeholder": + a_value = a_node.meta["example_value"] + b_value = b_node.meta["example_value"] + + if isinstance(a_value, torch.Tensor): + if not isinstance(b_value, torch.Tensor): + return False + # Extract fake tensor metadata for a and b and then compare + a_result = [] + state = _CacheKeyState(fake_mode.shape_env) + a_metadata = extract_tensor_metadata(a_value) + a_metadata._flatten_into(a_result, fake_mode, state) + + b_result = [] + state = _CacheKeyState(fake_mode.shape_env) + b_metadata = extract_tensor_metadata(b_value) + b_metadata._flatten_into(b_result, fake_mode, state) + if a_result != b_result: + return False + elif isinstance(a_value, torch.SymInt): + if not isinstance(b_value, torch.SymInt): + return False + if a_value is not b_value: + return False + elif a_node.op == "call_function": + if a_node.target is not b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_function): %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "call_method": + if a_node.target != b_node.target: + return False + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug( + "%s: Graph comparison failed at node (call_method) : %s", + fn_name, + a_node, + ) + return False + elif a_node.op == "output": + a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs)) + b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs)) + if not check_all_args(a_flat, b_flat): + hc_log.debug("%s: Graph comparison failed at the output node", fn_name) + return False + elif a_node.op == "get_attr": + a_attr = getattr(a_mod, a_node.target) + b_attr = getattr(b_mod, b_node.target) + if isinstance(a_attr, torch.fx.GraphModule): + if not isinstance(b_attr, torch.fx.GraphModule): + return False + # This is an example of a HOP inside a HOP + if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode): + return False + else: + # TODO - write an example with tensor as a graph attribute in + # the Fx graph + raise NotImplementedError(f"get_attr with {type(a_attr)}") + else: + # TODO - call_module is not supported because Dynamo Fx graph does + # not install a call_module + raise NotImplementedError(f"Graph equivalence check saw a {a_node.op}") + + # Two nodes are equal - add them to them map + node_map[a_node] = b_node + + return True + + +def validate_args_and_maybe_create_graph_inputs( + sub_args, + tracer, + tx, + set_subgraph_inputs, + description, + sub_args_names=None, +): + from . 
import AutogradFunctionContextVariable + from .builder import wrap_fx_proxy_cls + + assert tracer.parent is not None + + if set_subgraph_inputs == "flatten_manual": + flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)( + ListVariable(sub_args) + ).unpack_var_sequence(tx) + + flat_inputs = validate_args_and_maybe_create_graph_inputs( + flat_args.unpack_var_sequence(tx), + tracer, + tx, + set_subgraph_inputs="manual", + description=description, + ) + + return _make_inlined(tx, pytree.tree_unflatten)( + ListVariable(flat_inputs), tree_spec + ).unpack_var_sequence(tx) + else: + if sub_args_names is not None: + # Can be greater if user passes some args as kwargs + assert len(sub_args_names) >= len(sub_args) + args = [] + for idx, a in enumerate(sub_args): + assert isinstance(a, VariableTracker) + if set_subgraph_inputs == "automatic": + args.append(a) + continue + elif set_subgraph_inputs == "semi_automatic": + if isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + example_value = ( + node.meta["example_value"] + if "example_value" in node.meta + else None + ) + a = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + args.append(a) + continue + + if a.is_python_constant(): + # This arg is not used in the body of the higher order op. + # Currently, this new input is added to make the calls + # happy, which expect a fixed number of arguments. In + # future, we can clean this up. + arg_name = ( + "const_unused" + if sub_args_names is None + else f"const_unused_{sub_args_names[idx]}" + ) + tracer.create_graph_input( + arg_name, a.python_type(), a.as_python_constant() + ) + new_arg = a + # Weird special case, we probably want to delete it or fold it + # into the next case (of `a` being placeable into a graph) + elif isinstance(a, AutogradFunctionContextVariable): + example_value = a.as_proxy().node.meta["example_value"] + arg_name = ( + a.as_proxy().node.name + if sub_args_names is None + else sub_args_names[idx] + ) + tracer.create_graph_input(arg_name, a.python_type(), example_value) + new_arg = a + # If `a` can be put into a graph + elif a.maybe_fx_node() is not None: + node = a.maybe_fx_node() + example_value = ( + node.meta["example_value"] if "example_value" in node.meta else None + ) + arg_name = node.name if sub_args_names is None else sub_args_names[idx] + new_proxy = tracer.create_graph_input( + arg_name, a.python_type(), example_value + ) + new_arg = wrap_fx_proxy_cls( + target_cls=type(a), + tx=tx, + proxy=new_proxy, + example_value=example_value, + ) + # If `a` cannot be put into a graph + else: + # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). + unimplemented( + f"{description} with body that accepts non-Tensors as input. " + f"Got: {a.python_type()}" + ) + args.append(new_arg) + return args + + +# This helper function is used to make sure two graphs share the same input signature. 
For example, +# in torch.cond, two branches might lift different set of tensors as inputs. This function helps to +# dedup the inputs and modify the graphs to take the same set of inputs. +def _merge_graph_inputs( + l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name +): + def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars): + # The nn module attributes are guaranteed to be registered into the top-level graph module during + # higher order op speculation. Therefore, get_attr nodes in two branches with the same + # target refer to the same attribute and we can safely deduplicate them with their target. + # + # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But + # true_branch and false_branch belong to two separate tracing contexts, they may register the same + # attribute to top level separately. This creates two get_attr proxies for the same attribute + # that have different meta data such as stack_trace (one stack trace for the true_branch, + # and the other for false_branch). It seems better to discard the proxy explicitly in cond + # than make dynamo create a single proxy for the same get_attr target. + def shared_getattrs(l_lifted_proxies, r_lifted_proxies): + true_targets = { + proxy.node.target: proxy + for proxy in l_lifted_proxies + if proxy.node.op == "get_attr" + } + l_shared_getattrs = {} + r_shared_getattrs = {} + + for false_proxy in r_lifted_proxies: + if ( + false_proxy.node.op == "get_attr" + and false_proxy.node.target in true_targets + ): + true_proxy = true_targets[false_proxy.node.target] + l_shared_getattrs[true_proxy] = true_proxy + r_shared_getattrs[false_proxy] = true_proxy + return l_shared_getattrs, r_shared_getattrs + + l_shared_getattrs, r_shared_getattrs = shared_getattrs( + l_lifted_freevars.keys(), r_lifted_freevars.keys() + ) + + l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + l_shared_getattrs.keys() + ) + r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union( + r_shared_getattrs.keys() + ) + unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars + unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars + + def _sort_by_name(vars): + return sorted(vars, key=lambda var: var.node.name) + + return ( + list(_sort_by_name(list(l_shared_freevars))), + list(_sort_by_name(list(r_shared_freevars))), + list(_sort_by_name(list(unique_l_freevars))), + list(_sort_by_name(list(unique_r_freevars))), + ) + + (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars( + l_lifted_freevars, r_lifted_freevars + ) + + # Let's say we capture cond(pred, true_fn, false_fn, (x,)) + # With set_graph_input set to automatic, + # true_fn has lifted variables x, a, b, c + # false_fn has lifted variables x, a, b, d + # Then fixup_branch_inps make sure both branches have the same signature, i.e.: + # - true_fn(x, a, b, c_true_branch, d_false_branch) + # - false_fn(x, a, b, c_true_branch, d_false_branch) + # + # More formally, the signature has three parts in the following order: + # 1. used in both branches: x, a, b + # 2. only used in true branches: c, suffixed with _true_branch + # 3. only used in false branches: d, suffixed with _false_branch + # Within each part, we re-order the nodes by name to have a derterministic ordering for testing. 
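+ # For illustration, a rough user-level sketch (tensor names are made up) that
+ # produces exactly this situation:
+ #
+ #     def true_fn(x):
+ #         return x + a + b + c          # closes over a, b, c
+ #
+ #     def false_fn(x):
+ #         return x + a + b + d          # closes over a, b, d
+ #
+ #     torch.cond(pred, true_fn, false_fn, (x,))
+ #
+ # After merging, both subgraphs take (x, a, b, c_true_branch, d_false_branch),
+ # so torch.ops.higher_order.cond can invoke either branch with one shared
+ # argument tuple.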
+ def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): + def _insert_or_replace_phs(new_args, name_suffix): + for arg in new_args: + new_ph = graph.placeholder(arg.node.name + name_suffix) + # Override with new_ph if there exists a old placeholder. + if arg in lifted_freevars: + old_ph = lifted_freevars[arg].node + old_ph.replace_all_uses_with(new_ph) + # replace_all_uses_with doesn't clean users. Clean it manually so that we could erase it. + old_ph.users = {} + graph.erase_node(old_ph) + + first_not_ph_node = next( + node for node in graph.nodes if node.op != "placeholder" + ) + with graph.inserting_before(first_not_ph_node): + _insert_or_replace_phs(shared, "") + _insert_or_replace_phs(unique_l, "_" + l_name) + _insert_or_replace_phs(unique_r, "_" + r_name) + + fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r) + fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r) + return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r + + +# See NOTE [HigherOrderOperator tracing design] for details of the design +def speculate_subgraph( + tx, + f, + sub_args, + sub_kwargs, + description, + *, + # source_target is the .value of HigherOrderOpVariable and is the + # target of the proxy that we created for the higherOrderOperator. + source_target=None, + always_restore=False, + enable_grad=None, + # NOTE [argument `set_subgraph_inputs`] + # set_subgraph_inputs controls what how to construct subgraphs' placeholders from sub_args. + # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended). + # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended). + # If sub_args contain Pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first. + # Then the flattened args are manually set as subgraph's placeholders. + # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders e.g. AutogradFunctionContextVariable + # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the + # restriction that user need to manually control how to create placeholders and VariableTrackers for the args. + set_subgraph_inputs="automatic", + restore_side_effects=True, + should_flatten_outputs=False, + under_activation_checkpoint=False, + # TODO - supports input_mutation and aliasing should be False by default for strictness + supports_input_mutation=True, + supports_aliasing=True, + # Pass in an originating tracer - this is needed for preserving context + # across fwd-bwd for autograd.Function + tracer=None, +): + if sub_kwargs is None: + sub_kwargs = {} + + assert set_subgraph_inputs in { + "automatic", + "semi_automatic", + "flatten_manual", + "manual", + }, "Please use one of the supported set_subgraph_inputs options." + + # See NOTE [Temporary argument `set_subgraph_inputs`] + if sub_kwargs and set_subgraph_inputs != "automatic": + unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.") + + try: + # ensure guards on args get installed in parent subgraph + f, sub_args, sub_kwargs = LazyVariableTracker.realize_all( + (f, sub_args, sub_kwargs), + ) + + with tx.output.subtracer(source_target, tracer) as subtracer: + sub_args_names = maybe_positional_arg_names(f) + # User mismatch in the number of args. Will eventually lead to an error. 
+ if sub_args_names is not None and len(sub_args_names) < len(sub_args): + sub_args_names = None + args = validate_args_and_maybe_create_graph_inputs( + sub_args, + subtracer, + tx, + set_subgraph_inputs, + description, + sub_args_names, + ) + + validate_args_and_maybe_create_graph_inputs( + sub_kwargs.values(), + subtracer, + tx, + set_subgraph_inputs="automatic", + description=description, + ) + + autograd_ctx = ( + dynamo_enable_grad(tx, enable_grad) + if enable_grad is not None + else contextlib.nullcontext() + ) + checkpoint_ctx = ( + dynamo_under_activation_checkpoint(tx) + if under_activation_checkpoint + else contextlib.nullcontext() + ) + + # For handling side effects, we can make an argument that we don't + # have to do anything here. The side effects infra does a good job + # of graph breaking if we mutate any nonlocal or global variable + # while subtracing. As a result if tracing succeeds, side effects + # data structure will only contain read-only data structures that + # are put there for tracking purposes. + # But on the other hand, there is an argument that if we ever write + # a new side effect in Dynamo which does not go through the side + # effect infra, we can end up in bad state. + # Therefore we restore the side effects after tracing. The catch is + # that we have to special handle tensor variables. If we have seen a + # nonlocal variable tensor during subtracing, we want to keep a + # track of that tensor, so that later subtracing or the root tracer + # itself does not create a new proxy for the already observed tensor + # variable. + if restore_side_effects: + prev_side_effects = tx.output.side_effects.clone() + + with autograd_ctx, checkpoint_ctx: + output = f.call_function(tx, args, sub_kwargs) + + if restore_side_effects: + new_side_effects = tx.output.side_effects.clone() + prev_side_effects.track_tensor_variables_from_runahead_side_effects( + new_side_effects + ) + tx.output.side_effects = prev_side_effects + + treespec = None + if should_flatten_outputs: + # Flatten the speculated subgraph output. + output, treespec = _make_inlined(tx, pytree.tree_flatten)( + output + ).unpack_var_sequence(tx) + # Actually, transform the list (returned by flatten) into a tuple + # for dynamo consistency. + output = BuiltinVariable(tuple).call_function(tx, [output], {}) + + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support pytree output + # We check always_restore because we dont use the output or side effects of always_restore code, + # like bwd. + if always_restore: + # Nothing left to do here + return (output, treespec), tx.output.graph, subtracer.lifted_freevars + else: + validate_subgraph_output_types(output) + + # The output proxies might not belong to this SubgraphTracer + # (if they are free variables that were never lifted) + # so lift them here. + output_proxies = output.as_proxy() + output_proxies = pytree.tree_map( + subtracer.maybe_lift_tracked_freevar_to_input, output_proxies + ) + + tx.output.create_node( + "output", + "output", + (subtracer.create_arg((output_proxies,))), + {}, + ) + graph = tx.output.graph + graph.lint() + lifted_freevars = subtracer.lifted_freevars + + # NOTE: [HigherOrderOperator subgraph input ordering] + # The input ordering of the higher order ops is determined by the order of + # the creation of the placeholder. + # Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before + # speculating subgraph. 
+ # During subgraph speculation, we may lift closured tensors and free symbols as inputs, + # their ordering is determined by the time they are lifted: earlier lifted ones precede later + # lifted ones. + # + # Suppose the placeholders are + # O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs + # The following code re-order the placeholders to + # O1, O2, O3, O4, O5, X1, X2, X3 + def move_lifted_freevars_phs_to_end( + graph: torch.fx.Graph, lifted_freevars: tuple[torch.fx.Node] + ): + lifted_ph_set = { + child_p.node for child_p in lifted_freevars.values() + } + + prev_phs = [n for n in graph.nodes if n.op == "placeholder"] + + # No need to reorder when graph doesn't have args or doesn't + # have lifted freevars or all inputs are lifted freevars. + if ( + len(prev_phs) == 0 + or len(lifted_ph_set) == 0 + or len(prev_phs) == len(lifted_ph_set) + ): + return + + # Step 1: find first X1 + for x1 in prev_phs: + if x1 in lifted_ph_set: + break + + assert x1 is not None and x1.op == "placeholder" + # Step 2: starting from the X1, skip Xs and prepend Os before X1. + cand_x = x1.next + while cand_x is not None and cand_x.op == "placeholder": + if cand_x in lifted_ph_set: + cand_x = cand_x.next + else: + nxt = cand_x.next + cand_x._remove_from_list() + x1.prepend(cand_x) + cand_x = nxt + + # Step 3: assert that all placeholders are in the correct order as . + # in lifted_freevars + after_phs = [ + node for node in graph.nodes if node.op == "placeholder" + ][-len(lifted_freevars) :] + assert len(after_phs) == len(lifted_freevars) + for child_proxy, ph in zip(lifted_freevars.values(), after_phs): + assert child_proxy.node is ph, ( + "The order of placeholders is different from the order of lifted_freevars" + ) + + graph.lint() + + if len(lifted_freevars) > 0: + move_lifted_freevars_phs_to_end(graph, lifted_freevars) + + if not supports_input_mutation: + mutation_info = subtracer.has_input_mutation() + if mutation_info.has_mutation: + context = f"{mutation_info.msg} in\n {graph}" + unimplemented_v2( + gb_type="Encountered input mutation during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}", + hints=[ + "Consider using the debug context to change user code to avoid mutation.", + "Please open an issue.", + ], + ) + + if not supports_aliasing: + aliasing_info = subtracer.has_aliasing() + if aliasing_info.has_aliasing: + context = f"{aliasing_info.msg} in\n {graph}" + unimplemented_v2( + gb_type="Encountered aliasing during higher order op tracing", + context=context, + explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}", + hints=[ + "Consider using the debug context to change user code to avoid aliasing.", + "Please open an issue.", + ], + ) + + return ( + (output, treespec), + graph, + lifted_freevars, + ) + + except Unsupported as ex: + f_name = f"{type(f).__name__}" + if isinstance(f, UserFunctionVariable): + f_name = f.get_name() + msg = ( + f"speculate_subgraph: while introspecting {description}, we were unable " + f"to trace function `{f_name}` into a single graph. This means " + f"that Dynamo was unable to prove safety for this API and will " + f"fall back to eager-mode PyTorch, which could lead to a slowdown." 
+ ) + log.info(msg) + log.info(ex) + raise ex + + +def make_attr(tx: "InstructionTranslator", name): + node = tx.output.create_proxy( + "get_attr", + name, + (), + {}, + ) + return node + + +class TorchHigherOrderOperatorVariable(VariableTracker): + def __init__( + self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.value = value + self.source = source + + @staticmethod + def make(value, source=None, **kwargs): + from torch._higher_order_ops import BaseHOP + + if value.__name__ == "cond": + return CondHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "while_loop": + return WhileLoopHigherOrderVariable(value, source, **kwargs) + elif value.__name__ in ("map", "map_impl"): + return MapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "executorch_call_delegate": + return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "out_dtype": + return OutDtypeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap": + return WrapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "hints_wrapper": + return HintsWrapperHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "flex_attention": + return FlexAttentionHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "flex_attention_backward": + return FlexAttentionBackwardHighOrderVariable(value, source, **kwargs) + elif value.__name__ in ( + "wrap_activation_checkpoint", + "tag_activation_checkpoint", + ): + return CheckpointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "_export_tracepoint": + return ExportTracepointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "trace_wrapped": + return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs) + elif value.__name__ == "strict_mode": + return StrictModeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "run_with_rng_state": + return RunWithRNGStateHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "associative_scan": + return AssociativeScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "scan": + return ScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "call_torchbind": + return CallTorchbindHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap_with_set_grad_enabled": + return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap_with_autocast": + return WrapWithAutocastHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "dynamo_bypassing_wrapper": + return DynamoBypassingWrapperHigherOrderVariable(value, source, **kwargs) + elif ( + value.__name__ == "auto_functionalized" + or value.__name__ == "auto_functionalized_v2" + ): + return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "invoke_subgraph": + return InvokeSubgraphHigherOrderVariable(value, source, **kwargs) + elif isinstance(value, BaseHOP): + return BaseHOPVariable(value, source, **kwargs) + elif value.__name__ == "custom_function_call": + return CustomFunctionHigherOrderOperatorVariable(value, source, **kwargs) + else: + unimplemented(f"HigherOrderOperator {value.__name__}") + + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + unimplemented(f"HigherOrderOperator 
{self.value.__name__}") + + def as_python_constant(self): + return self.value + + +class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Wraps torch._functorch.autograd_function.custom_function_call + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return torch._dynamo.variables.UserMethodVariable( + self.value.__call__.__func__, + torch._dynamo.variables.UserDefinedObjectVariable( + self.value, source=self.source + ), + source=AttrSource(AttrSource(self.source, "__call__"), "__func__"), + ).call_function(tx, args, kwargs) + + +class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="Cond doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ListVariable, TensorVariable + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + if kwargs: + unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}") + + # TODO(voz): Support fake tensor dispatch for recursive + # ops - see torch/dispatch/_dispatcher.py + if len(args) != 4: + unimplemented( + f"Expected 4 arguments but got {len(args)}.\n" + f"Usage: cond(pred, true_fn, false_fn, operands)", + ) + + # Specialize into one of the branches since pred is constant + pred, true_fn, false_fn, operands = args + if type(args[0]) is ConstantVariable: + warnings.warn( + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." + " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool.", + UserWarning, + ) + if pred.as_python_constant(): + return true_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + else: + return false_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) + + # predicate + if type(pred) not in (ConstantVariable, TensorVariable, SymNodeVariable): + unimplemented( + f"Expected pred to be bool or a boolean tensor with single " + f"item but got {str(type(pred))} " + f"with original python type {str(pred.python_type())}.", + ) + + # operands + if not isinstance(operands, (ListVariable, TupleVariable)): + unimplemented( + f"Expected operands to be a list/tuple but got " + f"{operands.python_type()}", + ) + operands_seq = operands.unpack_var_sequence(tx) + if not only_consist_of(operands, (TensorVariable, ConstantVariable)): + unimplemented( + "Expect operands to be a tuple of pytrees that only consists of tensor leaves." + ) + + # branches + _check_supported_callable_arg(tx, true_fn, "true_fn") + _check_supported_callable_arg(tx, false_fn, "false_fn") + + # Our strategy for tracing the true/false branches of cond + # are to checkpoint our graphstate, run the true branch, + # roll it back to the checkpoint, and run the false + # branch, and then merge the graphstates. Well, perhaps + # "merge" is too strong a word: we mostly assert that + # the resulting graphstates have to be the same. 
+ # + # We only permit guards to diverge (we union the guards from + # both branches). In particular, this means that side + # effects are NOT permitted inside true/false branches; this + # would be difficult to implement, because of the path + # explosion problem. + + def speculate_branch(branch): + # NB: 0 is predicate + ix = 1 if branch else 2 + # TODO: Support kwargs + ( + (ret_val, ret_treespec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[ix], + operands_seq, + {}, + "cond", + source_target=self.value, + should_flatten_outputs=True, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)): + unimplemented( + "Expected branches to return a possibly nested pytree of tensors " + "or constant ints but it consists of others.", + ) + for ret in ret_val.unpack_var_sequence(tx): + if isinstance(ret, ConstantVariable) and ret.python_type() is not int: + unimplemented( + "Expected branches to return a possibly nested pytree of tensors " + f"or constant ints but it consists of others {ret.python_type()}.", + ) + return ret_val, ret_treespec, ret_graph, ret_lifted_freevars + + (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch( + True + ) + true_nn_modules = dict(tx.output.nn_modules) + + ( + false_r, + false_treespec, + false_graph, + false_lifted_freevars, + ) = speculate_branch(False) + false_nn_modules = dict(tx.output.nn_modules) + + same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)( + true_treespec, false_treespec + ) + if not same_treespec.as_python_constant(): + unimplemented("Expected branches to return the same pytree structure.") + + ( + true_graph, + false_graph, + true_shared, + _false_shared, + unique_true, + unique_false, + ) = _merge_graph_inputs( + true_graph, + true_lifted_freevars, + "true_branch", + false_graph, + false_lifted_freevars, + "false_branch", + ) + + true_name = tx.output.install_subgraph( + "cond_true", + torch.fx.GraphModule(true_nn_modules, true_graph), + ) + false_name = tx.output.install_subgraph( + "cond_false", + torch.fx.GraphModule(false_nn_modules, false_graph), + ) + + true_node = make_attr(tx, true_name) + false_node = make_attr(tx, false_name) + + p_args = ( + pred.as_proxy(), + true_node, + false_node, + # We pick true_shared but it shouldn't matter + tuple(true_shared + unique_true + unique_false), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.cond, + p_args, + {}, + None, + true_treespec, + ) + + +class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable): + def __init__(self, hop, source, script_obj_var, method_name) -> None: + super().__init__(hop, source) + self.script_obj_var = script_obj_var + self.method_name = method_name + + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + args_proxy = [arg.as_proxy() for arg in args] + kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple( + [self.script_obj_var.as_proxy(), self.method_name] + args_proxy + ), + kwargs=kwargs_proxy, + ), + ) + + +def validate_subgraph_output_types(output: VariableTracker): + """Verify that that the output of the subgraph 
is a tensor, + int, bool, SymBool, or SymInt. + """ + from . import TensorVariable + + if non_tensor_output := find_mismatched_vars( + output, TensorVariable, allow_none=True + ): + for out in non_tensor_output: + if ( + isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) + ) or ( + isinstance(out, ConstantVariable) and out.python_type() in (int, bool) + ): + continue + unimplemented( + f"HigherOrderOperator body's output must consist of tensors or ints only but got {out.python_type()}" + ) + + +class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="while_loop doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.while_loop import _create_unbacked_symint + + from . import TensorVariable + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + cond_fn, body_fn, operands, additional_inputs = args + + # Input checks + for i, k in enumerate(["cond_fn", "body_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + if kwargs: + unimplemented( + f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + if len(args) != 4: + unimplemented( + f"Expected 4 arguments but got {len(args)}.\n" + f"Usage: while_loop(cond_fn, body_fn, operands)", + ) + + # cond_fn and body_fn input check + _check_supported_callable_arg(tx, cond_fn, "cond_fn") + _check_supported_callable_arg(tx, body_fn, "body_fn") + + # operands input check + operands_seq = operands.unpack_var_sequence(tx) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected additional_inputs to be a list/tuple but got " + f"{additional_inputs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) + + with discard_graph_changes(tx): + # See NOTE [unspecialize int carry with unbacked symints] + # Note: this must be run under discard graph changes. + def create_unbacked_sym_node_var(tx) -> SymNodeVariable: + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value + ) + return SymNodeVariable.create(tx, proxy, example_value) + + new_operands_seq = [ + ( + create_unbacked_sym_node_var(tx) + if ( + isinstance(carry, ConstantVariable) + and carry.python_type() is int + ) + or (isinstance(carry, SymNodeVariable)) + else carry + ) + for carry in operands_seq + ] + + # create cond subgrpahs + ( + (cond_r, _cond_treespec), + cond_graph, + cond_lifted_freevars, + ) = speculate_subgraph( + tx, + cond_fn, + new_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + # NOTE [why we cannot use "automatic" for while_loop]: + # The reason is that we want to enforce + # the ordering of inputs and outputs to be consistent and the the ordering + # of cond_fn and body_fn to the consistent. + # e.g. 
suppose we use "automatic" and we have: + # + # def body_fn(ph1, ph2): + # new_a, new_b = ph2.cos(), ph1.sin() + # return new_a, new_b + # + # a, b = torch.randn(3), torch.randn(3) + # new_a, new_b = body_fn(a, b) + # + # Using automatic, the ordering of arguments will be the order that they're + # used. In this example, the capture graph looks like: + # + # def captured_body(ph1, ph2): + # new_a, new_b = ph1.cos(), ph2.add_(1) + # return new_a, new_b + # + # This is fine when we change the calling convention of captured_body to be + # new_a, new_b = captured_body(b, a). + # But for while_loop, the next iteration's input is previous iteration output + # we'll end up feeding captured_body(new_a, new_b) instead. + # So it's best we always enforce the ordering of carried_inputs the same as outputs + # with "flatten_manual". + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + cond_nn_modules = dict(tx.output.nn_modules) + validate_subgraph_output_types(cond_r) + if isinstance(cond_r, TensorVariable): + cond_r_meta = _extract_tensor_metadata( + cond_r.proxy.node.meta["example_value"], include_contiguity=False + ) + if ( + not cond_r_meta.dtype == torch.bool + or not cond_r_meta.shape == torch.Size([]) + ): + unimplemented( + f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}" + ) + elif isinstance(cond_r, ConstantVariable): + # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False + pred = cond_r.as_python_constant() + if pred: + unimplemented( + f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}" + ) + else: + return operands + + # create body subgraph + ( + (body_r, body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + body_fn, + new_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + supports_input_mutation=False, + supports_aliasing=False, + ) + validate_subgraph_output_types(body_r) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + body_r.unpack_var_sequence(tx), + operands_seq, + "body_fn_output", + "carried_inputs", + include_contiguity=False, + ) + + ( + cond_graph, + body_graph, + cond_shared, + _body_shared, + cond_unique, + body_unique, + ) = _merge_graph_inputs( + cond_graph, + cond_lifted_freevars, + "cond_fn", + body_graph, + body_lifted_freevars, + "body_fn", + ) + + # Note: cond_shared and body_shared refer to the same proxy in parent graph + # so using either of them is OK. Use cond_shared as it doesn't matter. 
+ additional_lifted_inputs = cond_shared + cond_unique + body_unique + + body_nn_modules = dict(tx.output.nn_modules) + + cond_name = tx.output.install_subgraph( + "cond_fn", + torch.fx.GraphModule(cond_nn_modules, cond_graph), + ) + body_name = tx.output.install_subgraph( + "body_fn", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + cond_node = make_attr(tx, cond_name) + body_node = make_attr(tx, body_name) + + p_args = ( + cond_node, + body_node, + tuple([operand.as_proxy() for operand in operands_seq]), + tuple( + [inp.as_proxy() for inp in additional_inputs_seq] + + additional_lifted_inputs + ), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + unspecialized_flat_example_value = pytree.tree_map_only( + (int, torch.SymInt), + lambda _: _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=False + ), + flat_example_value, + ) + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.while_loop, + p_args, + {}, + unspecialized_flat_example_value, + body_treespec, + ) + + +class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="associative_scan must be captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + def arg_extractor(combine_fn, xs, additional_inputs): + return combine_fn, xs, additional_inputs + + combine_fn, xs, additional_inputs = arg_extractor(*args, **kwargs) + + if args[0].python_type() is functools.partial: + # This is the standard case when the user calls the frontend + # and the frontend invokes dynamo + if len(args) != 2: + unimplemented( + f"Expected 2 positional arguments but got {len(args)}.\n" + f"Usage: associative_scan(combine_fn, xs)", + ) + + xs_treespec = args[0].keywords["spec"] + + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + else: + # This case is hit during re-tracing, for example in export tests + # In this case, the combine_fn is a callable and not a functools.partial + xs_treespec = _make_inlined(tx, pytree.tree_structure)(xs) + + _check_supported_callable_arg(tx, combine_fn, "combine_fn") + + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected xs to be a list/tuple but got " + f"{xs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + xs_vars = xs.unpack_var_sequence(tx) + _check_all_tensorvariable(xs_vars) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected additional_inputs to be a list/tuple but got " + f"{additional_inputs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." 
+ ) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + _check_all_tensorvariable(additional_inputs_vars) + + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented( + "associative_scan() operator doesn't support zero-sized tensors during tracing." + ) + + # Trace the subgraph + # The sub_args is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args shape will be (4, ). + with discard_graph_changes(tx): + sub_args = [ + _make_inlined(tx, first_slice_copy)(leaf) + for leaf in itertools.chain(xs_vars, xs_vars) + ] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args + sub_args_additional_inputs + ( + (combine_result, _combine_treespec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="associative_scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + + # Collect the results from the combine_fn + results, _combine_treespec = _make_inlined(tx, pytree.tree_flatten)( + combine_result + ).unpack_var_sequence(tx) + + # Check whether the combine_fn returns one child tree for the output. + if _combine_treespec.as_python_constant().num_leaves < 1: + unimplemented( + f"combine_fn needs to produce one pytree for the output " + f"but combine_fn produces the pytree {_combine_treespec.as_python_constant()}." + ) + + # Check whether the outs produced by combine_fn has the same treespec as xs + # We need to have this check this way, because in case init is a TreeSpec and carry + # but carry is only a LeafSpec, these two cannot be compared correctly. + if ( + isinstance(xs_treespec.as_python_constant(), pytree.LeafSpec) + != isinstance(_combine_treespec.as_python_constant(), pytree.LeafSpec) + ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( + xs_treespec, _combine_treespec + ).as_python_constant(): + unimplemented( + f"The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " + f"xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}." + ) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. 
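+ # For illustration (shapes are made up): if every xs leaf has shape (T, B, D)
+ # and the scan runs over dim 0, combine_fn is traced with two (B, D) slices
+ # per leaf and must return leaves of the same (B, D) shape and dtype, e.g.
+ #
+ #     def combine_fn(a, b):
+ #         return a + b                  # cumulative sum
+ #
+ # The check below compares the metadata of a first slice of each xs leaf
+ # against the corresponding combine_fn output.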
+ check_meta_consistency_vt( + [_make_inlined(tx, first_slice_copy)(t) for t in xs_vars], + results.items, + "initial_xs", + "combine_fn_output", + include_contiguity=False, + ) + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + + # Compute the proxies for the input check + proxy_vars_inputcheck = ( + tuple(sarg.as_proxy() for sarg in sub_args) + combine_freevars_proxy + ) + + from torch._higher_order_ops.utils import _maybe_fake_tracing + from torch._inductor.utils import is_pointwise_use + + with tx.fake_mode: + sub_args_fake = [ + leaf.node.meta["example_value"].clone() + if hasattr(leaf.node.meta["example_value"], "clone") + else leaf.node.meta["example_value"] + for leaf in pytree.tree_leaves(proxy_vars_inputcheck) + ] + pre_dispatch = False + + fx = _maybe_fake_tracing( + combine_gm, sub_args_fake, pre_dispatch=pre_dispatch + ) + + for node in fx.graph.nodes: + # Check that the combine_fn is pointwise, if combine_mode='pointwise' + if not all( + is_pointwise_use(use) or use.op == "output" for use in node.users + ): + raise RuntimeError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + + combine_fn_name = tx.output.install_subgraph( + "associative_scan_combine_fn", combine_gm + ) + + # Compute the proxies + xs_proxy = xs.as_proxy() + combine_freevars_proxy = tuple(combine_lifted_freevars.keys()) + additional_inputs_proxy = additional_inputs.as_proxy() + combine_freevars_proxy + + p_args = ( + make_attr(tx, combine_fn_name), + xs_proxy, + additional_inputs_proxy, + ) + + with tx.fake_mode: + out_meta = tuple( + inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.associative_scan, + p_args, + {}, + out_meta, + xs_treespec, + ) + + +class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="scan must be captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.scan import _extract_carry_and_out, stack_y + from torch._higher_order_ops.utils import first_slice_copy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + # combine_fn input check + def _check_combine_fn_is_normalized(combine_fn_var): + if not isinstance( + combine_fn_var, + ( + variables.nn_module.NNModuleVariable, + variables.FunctoolsPartialVariable, + ), + ): + unimplemented( + f"Expected combine_fn to be wrapped as functools.partial in scan user-facing api " + f"or a graph module if we're re-exporting but got " + f"{combine_fn.python_type()}. Please report an issue to PyTorch if you're seeing this." 
+ ) + return isinstance(combine_fn_var, variables.nn_module.NNModuleVariable) + + def arg_extractor(combine_fn, init, xs, additional_inputs): + return combine_fn, init, xs, additional_inputs + + combine_fn, init, xs, additional_inputs = arg_extractor(*args, **kwargs) + init_vars = init.unpack_var_sequence(tx) + xs_vars = xs.unpack_var_sequence(tx) + additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) + + # combine_fn input check + combine_fn_is_normalized = _check_combine_fn_is_normalized(combine_fn) + if combine_fn_is_normalized: + combine_gm = combine_fn.value + assert isinstance(combine_gm, torch.fx.GraphModule), ( + combine_fn, + combine_gm, + ) + else: + # combine_fn input check + # We need to get the pure combine_fn from the functools.partial + _check_supported_callable_arg( + tx, combine_fn.keywords["combine_fn"], "combine_fn" + ) + # xs input check + if not isinstance(xs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected xs to be a list/tuple but got " + f"{xs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + # init input check + if not isinstance(init, (ListVariable, TupleVariable)): + unimplemented( + f"Expected init to be a list/tuple with at least one element but got " + f"{init.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + if len(init_vars) == 0: + unimplemented( + "scan() operator requires init leaves. It seems to be an " + "internal error, please report an issue to PyTorch." + ) + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected additional_inputs to be a list/tuple but got " + f"{additional_inputs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + # scan_length check + scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] + if scan_length == 0: + unimplemented("NYI: scan() operator doesn't support zero scan_length.") + _check_all_tensorvariable(init_vars) + _check_all_tensorvariable(xs_vars) + _check_all_tensorvariable(additional_inputs_vars) + + with discard_graph_changes(tx): + sub_args_init = [ + ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init_vars + ] + # The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args_inp shape will be (4, ). 
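+ # For illustration, the user-level contract being traced here is roughly
+ #
+ #     def combine_fn(carry, x):
+ #         next_carry = ...              # same metadata as carry
+ #         y = ...                       # per-step output
+ #         return next_carry, y
+ #
+ #     final_carry, ys = scan(combine_fn, init, xs)
+ #
+ # so the subgraph is traced with clones of init, a first slice of every xs
+ # leaf, and clones of the additional inputs.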
+ sub_args_inp = [_make_inlined(tx, first_slice_copy)(inp) for inp in xs_vars] + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs_vars + ] + + sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs + ( + (combine_result, _combine_treespec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Ensure that the output of scan is a flattened list of elements, + # because downstream operations assume that the output of HOPs + # is flattened + output_node = combine_graph.find_nodes(op="output")[0] + output_node.args = (pytree.tree_leaves(output_node.args),) + combine_graph.lint() + combine_freevars_proxy = list(combine_lifted_freevars.keys()) + combine_result_vars = combine_result.unpack_var_sequence(tx) + + if combine_fn_is_normalized: + carry_vars, out_vars = _extract_carry_and_out( + combine_result_vars, len(init_vars) + ) + else: + if len(combine_result_vars) != 2: + unimplemented( + f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}" + ) + carry_tree, out_vars = combine_result_vars + carry_vars, carry_treespec = _make_inlined(tx, pytree.tree_flatten)( + carry_tree + ).unpack_var_sequence(tx) + carry_vars = carry_vars.unpack_var_sequence(tx) + out_vars = _make_inlined(tx, pytree.tree_leaves)( + out_vars + ).unpack_var_sequence(tx) + + # additional output checking + _combine_treespec = _make_inlined(tx, pytree.tree_structure)(combine_result) + + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + ) + + # Check meta data of carries and inits. If we pass this stage, we are sure that the init and carries + # have the same tree structure. + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + init_vars, + carry_vars, + "init", + "carry", + include_contiguity=False, + ) + + xs_proxy = xs.as_proxy() + init_proxy = init.as_proxy() + additional_inputs_proxy = list(additional_inputs.as_proxy()) + list( + combine_freevars_proxy + ) + y_proxies = [out_var.as_proxy() for out_var in out_vars] + + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm) + + p_args = ( + make_attr(tx, combine_fn_name), + init_proxy, + xs_proxy, + additional_inputs_proxy, + ) + + with tx.fake_mode: + example_carry = [ + init_p.node.meta["example_value"].clone() for init_p in init_proxy + ] + # For the fake mode, we need to duplicate the init tensor along the dim + # to have the same size as the xs arguments + example_stacked_out = [ + stack_y(y.node.meta["example_value"], scan_length) for y in y_proxies + ] + out_meta = [*example_carry, *example_stacked_out] + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.scan, p_args, {}, out_meta, _combine_treespec + ) + + +def non_single_tensor_return_unsupported(api, ret): + from . 
import TensorVariable + + if not isinstance(ret, TensorVariable): + raise Unsupported( + f"{api} over function that returns something other than one Tensor" + ) + + +class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = False + supports_aliasing = False + + @raise_hard_error_if_graph_break( + reason="map doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if len(kwargs) > 0: + unimplemented( + "torch.ops.higher_order.map: kwargs are not supported in the map operator." + ) + + _check_supported_callable_arg(tx, args[0], "map_fn") + + # args = f, flat_xs, flat_args + assert isinstance(args[1], (ListVariable, TupleVariable)), args[1] + assert isinstance(args[2], (ListVariable, TupleVariable)), args[2] + unpacked_xs = args[1].unpack_var_sequence(tx) + unpacked_args = args[2].unpack_var_sequence(tx) + + sample_shape = get_fake_value(unpacked_xs[0].as_proxy().node, tx).size() + + if len(sample_shape) < 1 or sample_shape[0] == 0: + unimplemented( + "map() operator doesn't support scalar or zero-sized tensors during tracing." + ) + + # To get the example output from map() we will need to provide at least one sample to + # the loop body. In our case we will always use xs[0], and our map() won't support zero + # sized tensor during tracing. + with discard_graph_changes(tx): + sliced_xs = [ + xs.call_method( + tx, + "select", + args=(VariableTracker.build(tx, 0), VariableTracker.build(tx, 0)), + kwargs={}, + ) + for xs in unpacked_xs + ] + + # TODO: Support kwargs + ( + (body_r, body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + [ + *sliced_xs, + *unpacked_args, + ], + {}, + "torch.ops.higher_order.map", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + # Check all outputs of map are tensors. + # For map, outputting None is OK, thus ignore None values in the check + body_r_vars = body_r.unpack_var_sequence(tx) + none_mask = [ + type(x.realize()) is ConstantVariable and x.as_python_constant() is None + for x in body_r_vars + ] + _check_all_tensorvariable( + [br for bm, br in zip(none_mask, body_r_vars) if not bm] + ) + + body_nn_modules = dict(tx.output.nn_modules) + + body_name = tx.output.install_subgraph( + "map_body", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + p_args = ( + body_node, + [xs.as_proxy() for xs in unpacked_xs], + [arg.as_proxy() for arg in unpacked_args] + + list(body_lifted_freevars.keys()), + ) + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec + ) + + +class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + # This is operator for delegation within Executorch which calls a + # specific function in the given lowered module with the given + # operators. The actual operator is defined in the Executorch codebase. 
+ # This is a bad hierarchical violation since + # executorch_call_delegate sits at a higher level than dynamo, but + # there's no real solution to this issue yet. + if len(kwargs) > 0: + unimplemented( + "executorch_call_delegate: kwargs arguments were not enabled." + ) + lowered_module = tx.output.get_submodule(args[0].module_key) + + lowered_node = make_attr(tx, args[0].module_key) + + p_args = tuple(arg.as_proxy() for arg in args[1:]) + real_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args + ) + + with tx.fake_mode: + example_value = lowered_module.original_module.module()(*real_sub_args) + + # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]: + # executorch modules promise not to alias inputs and outputs. + # Thus, output FakeTensors will correctly not alias input FakeTensors. + _assert_tensors_nonaliasing(real_sub_args, example_value) + + p_args = (lowered_node,) + p_args + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class FunctorchHigherOrderVariable(UserFunctionVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return super().call_function(tx, args, kwargs) + + +class FunctionalCallVariable(FunctorchHigherOrderVariable): + def call_function( + self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] + ) -> VariableTracker: + if not torch._dynamo.config.inline_inbuilt_nn_modules: + unimplemented( + "torch.func.functional_call capture is disabled, " + "it can be turned on by setting " + "`torch._dynamo.config.inline_inbuilt_nn_modules=True`" + ) + return super().call_function(tx, args, kwargs) + + +class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): + supports_input_mutation = True + supports_aliasing = True + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return tx.output.install_subgraph( + f"{attr_name}", + body_gmod, + ) + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + fn_vt, + fn_args_vt, + kwargs, + description, + under_activation_checkpoint=False, + *, + subgraph_name="wrap_body", + ): + # See NOTE [HigherOrderOperator tracing design] for more details + + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_vt, + fn_args_vt, + kwargs, + description, + source_target=self.value, + should_flatten_outputs=True, + under_activation_checkpoint=under_activation_checkpoint, + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = self.install_subgraph_in_output_graph( + tx, + fn_vt, + fn_args_vt, + kwargs, + body_gmod, + attr_name=subgraph_name, + ) + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. 
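+ # For illustration (names are made up): if the wrapped body closes over a
+ # tensor `w`, e.g.
+ #
+ #     y = wrap(lambda x: x.sin() + w, x)
+ #
+ # then both the explicit operand x and the closed-over w show up in
+ # body_lifted_freevars, and the proxy call assembled below is roughly
+ # wrap(wrap_body, x, w).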
+ lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + + proxy_args = (body_node,) + lifted_args + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return proxy_args, {}, example_value, body_r, treespec, body_gmod, body_name + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + _example_value, + body_r, + treespec, + _, + _, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") + + if len(p_kwargs) > 0: + unimplemented("kwargs should have been flattened into lifted args") + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec + ) + + +class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + f"wrap_with_set_grad_enabled: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + grad_enabled, fn_var, *rest_args = args + + if not isinstance(grad_enabled, ConstantVariable): + unimplemented("grad_enabled must be a constant") + + _check_supported_callable_arg(tx, fn_var, "enable_grad_fn") + + with torch.set_grad_enabled(grad_enabled.as_python_constant()): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_set_grad_enabled", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + f"wrap_with_set_grad_enabled: Got unexpected freevars {body_lifted_freevars}" + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + grad_enabled.as_python_constant(), + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec + ) + + +class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. 
+ """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + f"wrap_with_autocast: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args + + for arg in [device_type, dtype, enabled, cache_enabled]: + if not isinstance(arg, ConstantVariable): + unimplemented( + "device_type, dtype, enabled, cache_enabled must be constants" + ) + + _check_supported_callable_arg(tx, fn_var, "autocast") + + python_constants = [ + arg.as_python_constant() + for arg in [device_type, dtype, enabled, cache_enabled] + ] + + with torch.autocast(*python_constants): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_autocast", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + f"wrap_with_autocast: Got unexpected freevars {body_lifted_freevars}" + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + *python_constants, + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec + ) + + +class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + _check_supported_callable_arg(tx, args[0], "body_fn") + + # inputs + if len(args) != 3: + unimplemented( + f"Expected 3 arguments but got {len(args)}.\n" + f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n" + f"kwargs required to be provided explicitly." + ) + + if not isinstance(args[1], (ListVariable, TupleVariable)): + unimplemented( + f"Expected a tuple but got {args[1].python_type()}", + ) + operands = args[1].unpack_var_sequence(tx) + + if not isinstance(args[2], ConstDictVariable): + unimplemented( + f"Expected a dict but got {args[2].python_type()}", + ) + + if "hints" not in kwargs: + raise IncorrectUsage("hints_wrapper - key hints not provided") + + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], # function + operands, + args[2].as_python_constant(), + "hints_wrapper", + source_target=self.value, + should_flatten_outputs=True, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph( + "hints_wrapper_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. 
+ lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + p_args = (body_node, lifted_args, {}) + + p_kwargs = {} + # add hints into p_kwargs + p_kwargs["hints"] = kwargs["hints"].as_python_constant() + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, p_args, p_kwargs, flat_example_value, treespec + ) + + +class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + if len(kwargs) > 0: + unimplemented("out_dtype does not handle kwargs") + + p_args = tuple(arg.as_proxy() for arg in args) + op = p_args[0] + output_dtype = p_args[1] + fake_sub_args = pytree.tree_map_only( + torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:] + ) + # This is a simplified implementation of this operator just for tracing. + # Actual implementation may also first promote the arguments + example_value = op(*fake_sub_args).to(dtype=output_dtype) + + # Store the invocation as a call + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + +class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unpacked_sequence = args[1].unpack_var_sequence(tx) + # TODO (tmanlaibaatar) support pytree here + for arg in unpacked_sequence: + if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): + unimplemented("strict_mode HOO only works for flat inputs for now") + + if kwargs: + unimplemented( + f"strict_mode HOO received unexpected kwargs: {list(kwargs.keys())}" + ) + + ( + (ret_val, ret_treespec), + ret_graph, + ret_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], + unpacked_sequence, + {}, + "strict_mode", + source_target=self.value, + should_flatten_outputs=True, + ) + + strict_mode_nn_modules = dict(tx.output.nn_modules) + + strict_mode_name = tx.output.install_subgraph( + "strict_mode_body", + torch.fx.GraphModule(strict_mode_nn_modules, ret_graph), + ) + + strict_mode_node = make_attr(tx, strict_mode_name) + p_args = ( + strict_mode_node, + tuple(arg for arg in ret_lifted_freevars.keys()), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + ret_val.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.strict_mode, + p_args, + {}, + flat_example_value, + ret_treespec, + ) + + +class CheckpointHigherOrderVariable(WrapHigherOrderVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.wrap import TagActivationCheckpoint + from torch.utils.checkpoint import noop_context_fn + + from .builder import wrap_fx_proxy + + context_fn = None + if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn: + ctx = kwargs.pop("context_fn") + if isinstance(ctx, 
torch._dynamo.variables.UserFunctionVariable): + context_fn = ctx.fn + elif isinstance( + ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + context_fn = ctx.as_python_constant() + else: + raise NotImplementedError( + f"checkpoint not implemented for {type(ctx)} context_fn" + ) + + checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs) + + # Here we use checkpoint_kwargs (and not gmod kwargs). gmod_kwargs are + # already flattened above and managed inside the fx graph. + ( + p_args, + _, + example_value, + _body_r, + treespec, + checkpointed_gmod, + _, + ) = self.create_wrapped_node( + tx, + args[0], + args[1:], + gmod_kwargs, + "torch.utils.checkpoint.checkpoint", + under_activation_checkpoint=True, + ) + if context_fn is not None: + checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn + + _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs) + + # Store the invocation as a call + variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs=checkpoint_kwargs, + ), + example_value=example_value, + ) + + if treespec is None: + return variable + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + variable = BuiltinVariable(list).call_function(tx, [variable], {}) + + return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec) + + +class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable): + def __init__(self, hop, source) -> None: + super().__init__(hop, source) + + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + func_var = args[0] + + if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable): + func = func_var.fn + elif isinstance( + func_var, torch._dynamo.variables.functions.FunctoolsPartialVariable + ): + func = func_var.as_python_constant() + else: + raise RuntimeError( + f"DynamoBypassingWrapperHigherOrderVariable: Unsupported function {type(func_var)}" + ) + ( + p_args, + _, + example_value, + _body_r, + treespec, + gmod, + _, + ) = self.create_wrapped_node( + tx, + args[1], + args[2:], + kwargs, + str(func), + ) + + # Alternatively, we could've stored only the function's fqn and + # reconstructed, but that requires the function to be a global. + gmod_meta_key = "_dynamo_bypassing_wrapper_fn" + gmod.meta[gmod_meta_key] = func + + # Store the invocation as a call + variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=(gmod_meta_key,) + tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + if treespec is None: + return variable + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. 
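+ # Round-trip sketch (illustrative): speculate_subgraph flattened the user's
+ # output, e.g. {"a": t0, "b": (t1, t2)} -> leaves (t0, t1, t2) plus a
+ # treespec; converting the proxy result to a list and tree_unflatten-ing it
+ # below restores the original container structure for the caller.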
+ variable = BuiltinVariable(list).call_function(tx, [variable], {}) + + return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec) + + +class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): + def proxy_submod(self, tx, arg): + assert isinstance(arg.source.base, DictGetItemSource) + submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value) + p_submod = make_attr(tx, submod_name) + set_example_value(p_submod.node, arg.value) + return p_submod + + def to_proxy(self, tx, arg): + if isinstance(arg, UnspecializedNNModuleVariable): + return self.proxy_submod(tx, arg) + elif isinstance(arg, (ListVariable, TupleVariable)): + return arg.python_type()( + self.to_proxy(tx, nested_arg) for nested_arg in arg.items + ) + else: + return arg.as_proxy() + + def call_function( + self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + try: + p_args = tuple(self.to_proxy(tx, arg) for arg in args) + p_kwargs = {key: self.to_proxy(tx, arg) for key, arg in kwargs.items()} + except (NotImplementedError, Unsupported) as err: + raise Unsupported( + "Missing Dynamo support for FlexAttentionBackward HOP argument. Please file an issue." + ) from err + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + +class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): + """ + Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace + by unwrapping the higher order op and inlining through it. This op + is created by dynamo to survive through AotAutograd, then unwrapped + here in the call to dynamo from compiled autograd. 
+ """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + kwargs = dict(kwargs) + fn = kwargs.pop("fn") + return fn.call_function(tx, args, kwargs) + + +class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable): + @staticmethod + def normalize_to_args(args, kwargs): + # input signature is (query, key, value, score_mod, block_mask, *other_buffers), + # block_mask is a tuple, and we don't want to flatten it. + # only flatten kwargs into lists + flat_kwargs = pytree.tree_flatten(kwargs)[0] + + # Combine the flattened lists + all_args = args + flat_kwargs + return all_args + + def create_wrapped_node( + self, + tx: "InstructionTranslator", + query: "VariableTracker", + fn: "VariableTracker", + fn_name: str, + ): + from .._trace_wrapped_higher_order_op import TransformGetItemToIndex + + tx: InstructionTranslator = tx + + def create_scalar(): + return query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + { + "dtype": VariableTracker.build(tx, torch.int32), + }, + ) + + with discard_graph_changes(tx): + bhmn = [create_scalar() for _ in range(4)] + if fn_name == "score_mod": + scores_require_grad: bool = query.requires_grad + score = query.call_method( + tx, + "new_empty", + (VariableTracker.build(tx, []),), + {"requires_grad": VariableTracker.build(tx, scores_require_grad)}, + ) + new_args = [score, *bhmn] + else: + assert fn_name == "mask_fn", "Illegal function name: " + fn_name + new_args = [*bhmn] + + with TransformGetItemToIndex(): + ( + (_body_output, _body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn, + new_args, + {}, # expect only args no kwargs for now + description=fn_name, + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + body_name = tx.output.install_subgraph( + fn_name, + torch.fx.GraphModule(tx.output.nn_modules, body_graph), + ) + + body_node = make_attr(tx, body_name) + + # It is possible that the score-mod function captures some free variables that are not + # passed in as arguments. In this case, we need to lift them, which is handled by speculate_subgraph. + # We then need to create proxies for this + the inputs. + + lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + + proxy_args = (body_node, lifted_args) + + return proxy_args + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + ) = self.normalize_to_args(args, kwargs) + + score_mod_node, score_mod_lifted_args = self.create_wrapped_node( + tx, query, score_mod, "score_mod" + ) + mask_fn = block_mask.items[-1] + if isinstance(mask_fn, ConstantVariable): + mask_fn = UserFunctionVariable(torch.nn.attention._flex_attention._no_mask) + mask_fn_node, mask_fn_lifted_args = self.create_wrapped_node( + tx, query, mask_fn, "mask_fn" + ) + + proxied_args = [ + query, + key, + value, + TupleVariable(block_mask.items[:-1], source=block_mask.source), + scale, + kernel_options, + ] + + # Store the invocation as a call + # Norm_kwargs contains the score_function and we dont want to proxy this because + # Proxying user defined functions is not supported. 
+ inp_args, _ = proxy_args_kwargs(proxied_args, {}) + + # Compose the ordered HOO args: + # - inp_args: [query, key, value, block_mask, scale, kernel_options] + # - subgraph node: [score_mod, mask_fn_node] + # - lifted args from tracing subgraph: [score_mod_other_buffers, mask_fn_other_buffers] + _, _, _, inp_arg_block_mask, inp_arg_scale, inp_arg_kernel_options = inp_args + block_mask = tuple(inp_arg_block_mask + (mask_fn_node,)) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=inp_args[:3] + + ( + score_mod_node, + block_mask, + inp_arg_scale, + inp_arg_kernel_options, + score_mod_lifted_args, + mask_fn_lifted_args, + ), + kwargs={}, + ), + example_value=None, + ) + + +class AutogradFunctionApplyVariable(VariableTracker): + def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs) -> None: + super().__init__(**kwargs) + self.fwd_graph = fwd_graph + self.bwd_graph = bwd_graph + self.parent_source = parent_source + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + AutogradFunctionContextVariable, + UserDefinedClassVariable, + UserFunctionVariable, + UserMethodVariable, + ) + from .builder import wrap_fx_proxy + + """ + Consider the following: + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + @staticmethod + def backward(ctx, grad): + x, = ctx.saved_tensors + return grad * x.cos() + We want the resulting graphs to look like: + def fwd(ctx, x): + # (output, saved tensors / attrs) + return (x.sin(), [x]) + # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs) + def bwd(ctx, grad, x): + return grad * x.cos() + To accomplish this, we're going to: + 1. Construct a ctx object + 2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True) + 3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting + the ctx and grad inputs. + 4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph) + Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is + just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward + doesn't capture any arguments. + All these steps work if MySin.backward doesn't capture any values. This is a + limitation in general that we should check for. + """ + + prev_side_effects = tx.output.side_effects.clone() + fwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=tx.output.current_tracer, + source_target="autograd.Function", + ) + + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + with discard_graph_changes(tx): + # A little hacky, but we need a dummy ctx proxy for speculate_subgraph. + # We should clean this up at some point. 
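+ # Rough picture (illustrative): the dummy FunctionCtx proxy created below
+ # stands in for the `ctx` argument so that e.g. MySin.forward(ctx, x) from
+ # the docstring above can be speculated, even though ctx never becomes a
+ # real input of the final autograd_function_apply call.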
+ proxy = tx.output.create_proxy( + "call_function", torch.autograd.function.FunctionCtx, (), {} + ) + set_example_value(proxy.node, ctx.value) + ctx.proxy = proxy + + if isinstance(self.fwd_graph, types.FunctionType): + fwd_fn = UserFunctionVariable(self.fwd_graph) + fwd_args = [ctx, *args] + elif isinstance(self.fwd_graph, types.MethodType): + fwd_fn = UserMethodVariable( + self.fwd_graph.__func__, + UserDefinedClassVariable(self.fwd_graph.__class__), + ) + fwd_args = [fwd_fn.obj, ctx, *args] + else: + unimplemented("non-function or method") + + # Speculate subgraph on the fwd + (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph( + tx, + fwd_fn, + fwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="semi_automatic", + restore_side_effects=False, + tracer=fwd_tracer, + ) + + if ctx in tx.output.side_effects.store_attr_mutations: + if ( + "_materialize_non_diff_grads" + in tx.output.side_effects.store_attr_mutations[ctx] + ): + unimplemented("NYI") + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + + # Speculate subgraph on the backward. We make the + # bwd tracer a child of the fwd tracer, because backward may rely on + # tensors/attrs created in the fwd tracer. + + if isinstance(fwd_out, variables.BaseListVariable): + bwd_args = [ctx, *fwd_out.items] + else: + bwd_args = [ctx, fwd_out] + + bwd_src = AttrSource(self.parent_source, member="backward") + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + self.bwd_graph.__func__, + UserDefinedClassVariable(self.bwd_graph.__class__), + source=bwd_src, + ) + bwd_args = [bwd_fn.obj, *bwd_args] + else: + unimplemented("non-function or method") + + def is_strict_for(v: VariableTracker): + if isinstance(v, variables.TensorVariable): + # we can be more lax for stuff from forward + return v.proxy.tracer is not fwd_tracer + return True + + with ( + tx.output.subtracer(fwd_fn, fwd_tracer), + tx.strict_translation_mode(is_strict_for), + ): + try: + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + tracer=bwd_tracer, + ) + except torch._dynamo.exc.Unsupported as e: + if isinstance( + e, torch._dynamo.exc.UnknownPropertiesDuringBackwardTrace + ): + from unittest import mock + + bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( + tx.output, + parent=fwd_tracer, + source_target="autograd.Function", + ) + from .._trace_wrapped_higher_order_op import ( + autograd_function_backward_rewritten, + ) + + if isinstance(self.bwd_graph, types.FunctionType): + bwd_fn = UserFunctionVariable( + autograd_function_backward_rewritten(self.bwd_graph) + ) + elif isinstance(self.bwd_graph, types.MethodType): + bwd_fn = UserMethodVariable( + autograd_function_backward_rewritten( + self.bwd_graph.__func__ + ), + UserDefinedClassVariable(self.bwd_graph.__class__), + ) + else: + unimplemented("non-function or method") + + with mock.patch( + "torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops", + [], + ): + (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph( + tx, + bwd_fn, + bwd_args, + kwargs, + "autograd.Function", + enable_grad=False, + set_subgraph_inputs="manual", + restore_side_effects=False, + 
tracer=bwd_tracer, + ) + else: + raise e + + # TODO: assert that bwd_graph didn't capture values that were + # not created inside fwd_graph. + + # TODO(oulgen): Ideally, we would not do a linear search for output + # node but as things currently are there could be nodes after the + # output node + # This is bug prone as if there's code after the output node, then + # graph.output will append the output at the very end + # This might be a behavior difference + + # If users call ctx.mark_non_differentiable, we should capture these output tensors who + # are marked as non-differentiable and pass them to ApplyTemplate + # at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction. + non_differentiable_idx = [] + if ctx.non_differentiable is not None: + non_differentiable_set = set(ctx.non_differentiable) + assert isinstance(fwd_out, variables.BaseListVariable) + for i, x in enumerate(fwd_out.items): + if ( + isinstance(x, variables.TensorVariable) + and x.as_proxy() in non_differentiable_set + ): + non_differentiable_idx.append(i) + + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) + for node in fwd_graph.find_nodes(op="output"): + fwd_graph.erase_node(node) + break + + # Because we lift the bwd_freevars as inputs of the bwd_graph, + # we have to manually add the bwd_freevars as output of fwd_graph. + # However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph, + # we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output. + fwd_proxy_of_bwd_freevars = [] + for k in bwd_freevars.keys(): + if k in fwd_freevars: + fwd_proxy_of_bwd_freevars.append(fwd_freevars[k]) + else: + fwd_proxy_of_bwd_freevars.append(k) + + def unwrap_proxy(x): + if isinstance(x, torch.fx.Proxy): + return x.node + else: + assert variables.ConstantVariable.is_literal(x), ( + f"Only constant is allowed. Got {x}" + ) + return x + + new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars) + new_fwd_graph_outputs = pytree.tree_map(unwrap_proxy, new_fwd_graph_outputs) + fwd_graph.output(new_fwd_graph_outputs) + fwd_graph.lint() + + # Store fwd_body + fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + fwd_name = tx.output.install_subgraph( + "fwd_body", + torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph), + ) + + fwd_node = make_attr(tx, fwd_name) + + # The type of original args can be arbitrary, but we only support basic type in FX graph. + # So the speculated subgraph input includes original tensor args and the lifted freevars. + # We need to filter out the original tensor args and concat them with the lifted freevars + # to generate the proxy args for the FX call_function node. + filtered_args = [] + # A boolean list to mark if the type of corresponding argument is tensor. + # This is used to determine if a FX node's argument should be an argument of + # ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward + # at torch._functorch.autograd_function.AutogradFunctionApply. + args_tensor_mask = [False] * len(args) + for i, arg in enumerate(args): + if isinstance(arg, (variables.TensorVariable, variables.SymNodeVariable)): + filtered_args.append(arg) + args_tensor_mask[i] = True + + # Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args. 
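+ # Worked example (hypothetical): for forward(ctx, x, n) with a Tensor x and
+ # an int n, args_tensor_mask is [True, False]; if backward returns
+ # (grad_x, None), the None entry is dropped here so the bwd graph only
+ # yields gradients for Tensor inputs.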
+ new_bwd_graph_outputs = None + for node in bwd_graph.find_nodes(op="output"): + bwd_graph.erase_node(node) + break + + # The same as the above fwd proxies, we need to use the bwd proxies in the bwd_graph + # if some of the output is from fwd_freevars. + bwd_out_proxy = bwd_out.as_proxy() + bwd_proxy_of_fwd_freevars = [] + if isinstance(bwd_out_proxy, (tuple, list)): + for k in bwd_out_proxy: + if k in bwd_freevars: + bwd_proxy_of_fwd_freevars.append(bwd_freevars[k]) + else: + bwd_proxy_of_fwd_freevars.append(k) + else: + if bwd_out_proxy in bwd_freevars: + bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy] + else: + bwd_proxy_of_fwd_freevars = bwd_out_proxy + + # Remove bwd output for non-Tensor args. + output_proxy = bwd_proxy_of_fwd_freevars + if isinstance(output_proxy, (tuple, list)): + new_bwd_graph_outputs = () + for x, mask in zip(output_proxy, args_tensor_mask): + if mask: + new_bwd_graph_outputs = new_bwd_graph_outputs + (x,) + else: + assert x is None, f"Grad of non-Tensor arg {x} is not None." + else: + new_bwd_graph_outputs = output_proxy + + # Update the bwd graph output. + new_bwd_graph_outputs = pytree.tree_map( + lambda x: None if x is None else x.node, new_bwd_graph_outputs + ) + bwd_graph.output(new_bwd_graph_outputs) + bwd_graph.lint() + + # Store bwd_body + bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate() + bwd_name = tx.output.install_subgraph( + "bwd_body", + torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph), + ) + + bwd_node = make_attr(tx, bwd_name) + + tx.output.side_effects = prev_side_effects + + p_args = ( + fwd_node, + bwd_node, + *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())), + ) + kwargs = { + "args_tensor_mask": args_tensor_mask, + "non_differentiable_idx": non_differentiable_idx, + } + + # Store the invocation as a call + from torch._functorch.autograd_function import autograd_function_apply + + # We use speculate_subgraph to get the fwd graph, but it's always under no grad mode like what eager mode does. + # The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes + # (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing. + # Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it. 
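+ # Concretely (sketch): the installed fwd/bwd GraphModules are run through
+ # autograd_function_apply on fake inputs below, and the fake outputs become
+ # the example_value attached to the proxy call.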
+ with enable_python_dispatcher(): + with tx.output.fake_mode: + fake_args = ( + tx.output.nn_modules[fwd_node.node.name], + tx.output.nn_modules[bwd_node.node.name], + *( + [ + _get_fake_value(arg) + for arg in filtered_args + list(fwd_freevars.keys()) + ] + ), + ) + example_value = autograd_function_apply(*fake_args, **kwargs) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + autograd_function_apply, + args=p_args, + kwargs=kwargs, + ), + example_value=example_value, + ) + + +def _get_fake_value(x): + if isinstance(x, variables.VariableTracker): + return x.as_proxy().node.meta["example_value"] + elif isinstance(x, torch.fx.Proxy): + return x.node.meta["example_value"] + else: + return x + + +def maybe_positional_arg_names(func): + result = [] + if not hasattr(func, "get_function"): + return None + try: + fn = func.get_function() + except (Unsupported, NotImplementedError): + return None + try: + sig = inspect.signature(fn) + except ValueError: + return None + for name, param in sig.parameters.items(): + if param.kind is inspect.Parameter.VAR_POSITIONAL: + return None + if ( + param.kind is inspect.Parameter.POSITIONAL_ONLY + or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + ): + if name == "self": + # FX graphs can't have a placeholder named self + result.append("self_") + else: + result.append(name) + return result + + +class BaseHOPVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + def python_type(self): + return type(self.value) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + ( + p_args, + p_kwargs, + example_value, + body_r, + treespec, + body_gmod, + body_name, + ) = self.create_wrapped_node( + tx, args[0], args[1:], {}, self.value._name, subgraph_name="subgraph" + ) + assert len(p_kwargs) == 0 + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} + return _call_function_and_unflatten_output( + tx, self.value, p_args, p_kwargs, flat_example_value, treespec + ) + + +class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): + supports_input_mutation = False + supports_aliasing = False + + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name + ): + # Check if the subgraph from speculate_subgraph (body_gmod) and the fake + # inputs have already been seen before. If yes, the subgraph is already + # installed in the output graph and we can just access the subgraph + # using the saved attr name. 
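+ # Illustrative effect of this dedup: calling the same nested_compile_region
+ # function twice with inputs of matching fake metadata reuses the previously
+ # installed "subgraph" attribute instead of installing a structurally
+ # identical GraphModule a second time.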
+ + if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): + unimplemented_v2( + gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", + context=str(fn_vt), + explanation="invoke_subgraph does not support non user function variable", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + invoke_subgraph_cache = ( + tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + ) + + if isinstance(fn_vt, UserFunctionVariable): + fn_id = id(fn_vt.get_function()) + fn_name = fn_vt.get_function().__name__ + else: + assert isinstance(fn_vt, UnspecializedNNModuleVariable) + fn_id = id(fn_vt.value.forward.__func__) + fn_name = fn_vt.value.forward.__name__ + previously_installed_submodules = [] + if invoke_subgraph_cache: + previously_installed_submodules = ( + invoke_subgraph_cache.get_dynamo_installed_submodules(fn_id) + ) + current_mod = body_gmod + # NB - reverse is more likely to cause a hit sooner because first + # graph can have requires_grad=False for a few inputs + for submodule_name in reversed(previously_installed_submodules): + assert submodule_name in tx.output.nn_modules + previous_mod = tx.output.nn_modules[submodule_name] + if are_same_graph_modules( + fn_name, previous_mod, current_mod, tx.fake_mode + ): + return submodule_name + + body_name = super().install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" + ) + hc_log.debug( + "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", + fn_name, + body_name, + fn_name, + len(previously_installed_submodules) + 1, + ) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name) + + return body_name + + @raise_hard_error_if_graph_break( + reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + example_value, + body_r, + treespec, + body_gmod, + body_name, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph") + + if len(p_kwargs) > 0: + unimplemented("kwargs should have been flattened into lifted args") + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + p_args = ( + p_args[0], + body_name, + *p_args[1:], + ) + return _call_function_and_unflatten_output( + tx, + torch._higher_order_ops.invoke_subgraph, + tuple(p_args), + p_kwargs, + flat_example_value, + treespec, + ) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/iter.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/iter.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7e30791bd3d3e71d2ed7e30ac99c389d7ff74d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/iter.py @@ -0,0 +1,630 @@ +# mypy: ignore-errors + +""" +This module provides iterator-related variable tracking functionality for Dynamo. +It implements variable classes for handling Python iterators and itertools functions +during symbolic execution and tracing. 
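+
+For example (illustrative): when user code iterates over ``zip(xs, ys)``,
+Dynamo models the zip object as a ZipVariable so each next() call can be
+traced, and any remaining items can be reconstructed in the generated
+bytecode if the iterator escapes the compiled region.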
+ +The module includes: +- Base iterator variable classes for tracking iterator state +- Implementations of built-in iterators (zip, map, filter) +- Support for itertools functions (product, accumulate, combinations, etc.) +- Mutation tracking and reconstruction capabilities for iterator operations + +These classes integrate with Dynamo's variable tracking system to enable proper +handling of iterator operations during code transformation and optimization. +""" + +import itertools +import operator +import sys +from typing import Optional, TYPE_CHECKING, Union + +from .. import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import ( + handle_observed_exception, + ObservedUserStopIteration, + raise_observed_exception, + unimplemented_v2, + UserError, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +MAX_ITERATOR_LIMIT = 100 * 1024 # 100k + + +class ItertoolsVariable(VariableTracker): + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def __repr__(self) -> str: + return f"ItertoolsVariable({self.value})" + + def as_python_constant(self): + return self.value + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # See also: module `torch._dynamo.polyfills.itertools` + + if ( + self.value is itertools.product + and not kwargs + and all(arg.has_unpack_var_sequence(tx) for arg in args) + ): + seqs = [arg.unpack_var_sequence(tx) for arg in args] + items = [ + variables.TupleVariable(list(item)) for item in itertools.product(*seqs) + ] + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() + ) + elif self.value is itertools.accumulate: + from .builtin import BuiltinVariable + + if any(key not in ["initial", "func"] for key in kwargs.keys()): + unimplemented_v2( + gb_type="Unsupported kwargs for itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'initial', 'func', but got " + f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + + if "func" in kwargs and len(args) == 1: + func = kwargs["func"].call_function + elif len(args) == 2: + func = args[1].call_function + elif len(args) == 1: + # Default to operator.add + func = BuiltinVariable(operator.add).call_function + else: + unimplemented_v2( + gb_type="Unsupported `func` in itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to get the " + "function to use for itertools.accumulate. " + "itertools.accumulate expects the `func` as the second " + "argument or as a keyword argument.", + hints=[*graph_break_hints.USER_ERROR], + ) + else: + unimplemented_v2( + gb_type="Unsupported arguments for itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.accumulate with args: {args} and kwargs: {kwargs}. 
" + "itertools.accumulate expects an iterable, an optional " + "binary function for accumulation, and an optional initial " + "value to set the starting state.", + hints=[ + "Make sure the arguments to itertools.accumulate are correct.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + items = [] + acc = kwargs.get("initial") + if acc is not None: + items.append(acc) + for item in seq: + if acc is None: + acc = item + else: + try: + acc = func(tx, [acc, item], {}) + except Exception as e: + unimplemented_v2( + gb_type="Unexpected failure during itertools.accumulate() iteration", + context=f"call_function {self} {args} {kwargs}", + explanation="Unexpected failure in invoking function during accumulate. " + f"Failed running func {func}({item}{acc})", + hints=[*graph_break_hints.DIFFICULT], + from_exc=e, + ) + items.append(acc) + + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() + ) + elif ( + self.value is itertools.combinations + and not kwargs + and len(args) == 2 + and args[0].has_unpack_var_sequence(tx) + and args[1].is_python_constant() + ): + iterable = args[0].unpack_var_sequence(tx) + r = args[1].as_python_constant() + + items = [] + for item in itertools.combinations(iterable, r): + items.append(variables.TupleVariable(list(item))) + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() + ) + elif self.value is itertools.groupby: + if any(kw != "key" for kw in kwargs.keys()): + unimplemented_v2( + gb_type="Unsupported kwargs for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'key', but got " + f"{','.join(set(kwargs.keys()) - {'key'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + def retrieve_const_key(key): + if isinstance(key, variables.SymNodeVariable): + return key.evaluate_expr() + elif isinstance(key, variables.ConstantVariable): + return key.as_python_constant() + else: + unimplemented_v2( + gb_type="Unsupported key type for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with key type: {str(type(key))}. " + "We only support grouping keys that are constants (int, float, str, etc.)", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + else: + unimplemented_v2( + gb_type="Unsupported arguments for itertools.groupby", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.groupby with args: {args} and kwargs: {kwargs}. 
" + "itertools.groupby expects an iterable to group and an " + "optional key function to determine groupings.", + hints=[ + "Make sure the arguments to itertools.groupby are correct.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if "key" in kwargs: + + def keyfunc(x): + return retrieve_const_key( + kwargs.get("key").call_function(tx, [x], {}) + ) + + else: + + def keyfunc(x): + return retrieve_const_key(x) + + result = [] + try: + for k, v in itertools.groupby(seq, key=keyfunc): + result.append( + variables.TupleVariable( + [ + variables.ConstantVariable.create(k) + if variables.ConstantVariable.is_literal(k) + else k, + variables.ListIteratorVariable( + list(v), mutation_type=ValueMutationNew() + ), + ], + mutation_type=ValueMutationNew(), + ) + ) + except Exception as e: + unimplemented_v2( + gb_type="Unexpected failure during itertools.groupby() iteration", + context=f"call_function {self} {args} {kwargs}", + explanation="Unexpected failure in invoking function during groupby", + hints=[*graph_break_hints.SUPPORTABLE], + from_exc=e, + ) + return variables.ListIteratorVariable( + result, mutation_type=ValueMutationNew() + ) + elif self.value is itertools.repeat: + if len(args) < 2: + return variables.RepeatIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.repeat), args, kwargs + ) + elif self.value is itertools.count: + return variables.CountIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + elif self.value is itertools.cycle: + return variables.CycleIteratorVariable( + *args, mutation_type=ValueMutationNew() + ) + else: + return super().call_function(tx, args, kwargs) + + +class IteratorVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def next_variable(self, tx): + unimplemented_v2( + gb_type="Unimplemented next() call", + context=f"next({self})", + explanation="This abstract method must be implemented", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: + result = [] + self.force_apply_to_var_sequence(tx, result.append) + return result + + def force_apply_to_var_sequence(self, tx, fn) -> None: + while True: + try: + fn(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! 
+ def has_force_unpack_var_sequence(self, tx) -> bool: + return True + + +class RepeatIteratorVariable(IteratorVariable): + def __init__(self, item: VariableTracker, **kwargs) -> None: + super().__init__(**kwargs) + self.item = item + + # Repeat needs no mutation, clone self + def next_variable(self, tx): + return self.item + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + + +class CountIteratorVariable(IteratorVariable): + def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: + super().__init__(**kwargs) + if not isinstance(item, VariableTracker): + item = ConstantVariable.create(item) + if not isinstance(step, VariableTracker): + step = ConstantVariable.create(step) + self.item = item + self.step = step + + def next_variable(self, tx): + assert self.is_mutable() + old_item = self.item + tx.output.side_effects.mutation(self) + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) + + +class CycleIteratorVariable(IteratorVariable): + def __init__( + self, + iterator: IteratorVariable, + saved: Optional[list[VariableTracker]] = None, + saved_index: int = 0, + item: Optional[VariableTracker] = None, + **kwargs, + ) -> None: + if saved is None: + saved = [] + super().__init__(**kwargs) + self.iterator = iterator + self.saved = saved + self.saved_index = saved_index + self.item = item + + def next_variable(self, tx): + assert self.is_mutable() + + if self.iterator is not None: + try: + new_item = self.iterator.next_variable(tx) + if len(self.saved) > MAX_ITERATOR_LIMIT: + unimplemented_v2( + gb_type="input iterator to itertools.cycle has too many items", + context=f"next({self})", + explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}", + hints=[], + ) + tx.output.side_effects.mutation(self) + self.saved.append(new_item) + self.item = new_item + if self.item is None: + return self.next_variable(tx) + return self.item + except ObservedUserStopIteration: + handle_observed_exception(tx) + self.iterator = None + return self.next_variable(tx) + elif len(self.saved) > 0: + tx.output.side_effects.mutation(self) + self.saved_index = (self.saved_index + 1) % len(self.saved) + return self.item + else: + raise_observed_exception(StopIteration, tx) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: list[Union[list[VariableTracker], VariableTracker]], + strict: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self): + return zip + + def has_unpack_var_sequence(self, tx) -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence(self, tx) -> 
list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx): + assert self.is_mutable() + old_index = self.index + args = [] + + def get_item(it): + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx) + return it[old_index] + else: + return it.next_variable(tx) + + try: + for idx, it in enumerate(self.iterables): + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen: "PyCodegen"): + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(remaining_items)) + ) + else: + codegen(it) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(self.iterables)) + ) + if sys.version_info >= (3, 10): + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + create_instruction("CALL_FUNCTION_EX", arg=1), + ] + ) + else: + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: list[Union[list[VariableTracker], VariableTracker]], + **kwargs, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self): + return map + + def has_unpack_var_sequence(self, tx) -> bool: + return False + + def next_variable(self, tx): + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1), + create_instruction("CALL_FUNCTION_EX", arg=0), + ] + ) + + +class FilterVariable(IteratorVariable): + """ + Represents filter(fn, iterable) + """ + + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + fn: VariableTracker, + iterable: Union[list[VariableTracker], VariableTracker], + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.fn = fn + self.iterable = iterable + self.index = 0 + + def python_type(self): + return 
filter + + def has_unpack_var_sequence(self, tx) -> bool: + return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence( + tx + ) + + def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + it = None + if isinstance(self.iterable, list): + it = self.iterable[self.index :] + else: + it = self.iterable.unpack_var_sequence(tx) + filtered = self.fn.call_function(tx, it, {}) + return [variables.TupleVariable([filtered])] + + def next_variable(self, tx): + def _next(): + old_index = self.index + if isinstance(self.iterable, list): + if old_index >= len(self.iterable): + raise_observed_exception(StopIteration, tx) + return self.iterable[old_index] + else: + return self.iterable.next_variable(tx) + + # A do-while loop to find elements that make fn return true + while True: + item = _next() + self.index += 1 + res = self.fn.call_function(tx, [item], {}) + pred_res = variables.UserFunctionVariable( + polyfills.predicate + ).call_function(tx, [res], {}) + if pred_res.as_python_constant(): + return item + + def reconstruct_items(self, codegen: "PyCodegen"): + if isinstance(self.iterable, list): + remaining_items = self.iterable[self.index :] + codegen.foreach(remaining_items) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(remaining_items)) + ) + else: + codegen(self.iterable) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output(create_call_function(2, False)) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/lazy.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..1b803da19b739bd0d25f2b18eec9c892e86b9346 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/lazy.py @@ -0,0 +1,213 @@ +import collections +import functools +import inspect +from typing import Any, Callable, final, Optional, Union +from typing_extensions import Self + +from ..utils import is_function_or_wrapper +from .base import VariableTracker +from .tensor import SymNodeVariable + + +class LazyCache: + """Container to cache the real VariableTracker""" + + def __init__(self, value: Any, source: Any) -> None: + if not isinstance(value, LazySymNodeFormatString): + assert source + self.value = value + self.source = source + self.vt: Optional[VariableTracker] = None + + def realize(self) -> None: + assert self.vt is None + from ..symbolic_convert import InstructionTranslator + from . import builder + + tx = InstructionTranslator.current_tx() + + if isinstance(self.value, LazySymNodeFormatString): + self.vt = builder.SourcelessBuilder.create(tx, self.value) + else: + self.vt = builder.VariableBuilder(tx, self.source)(self.value) + + del self.value + del self.source + + +@final +class LazyVariableTracker(VariableTracker): + """ + A structure that defers the creation of the actual VariableTracker + for a given underlying value until it is accessed. + + The `realize` function invokes VariableTracker.build() to produce the real object. + Once a LazyVariableTracker has been realized, internal bookkeeping will + prevent double realization. + + This object should be utilized for processing containers, or objects that + reference other objects where we may not want to take on creating all the + VariableTrackers right away. 
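+
+ Example (illustrative): the values of a large input dict can be wrapped
+ lazily so that only the entries the compiled code actually reads are built
+ into full VariableTrackers (and pay for the associated guard/build work).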
+ """ + + _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields} + + @staticmethod + def create(value: Any, source: Any, **options: Any) -> "LazyVariableTracker": + return LazyVariableTracker(LazyCache(value, source), source=source, **options) + + def __init__(self, _cache: LazyCache, **kwargs: Any) -> None: + assert isinstance(_cache, LazyCache) + super().__init__(**kwargs) + self._cache = _cache + + def realize(self) -> VariableTracker: + """Force construction of the real VariableTracker""" + if self._cache.vt is None: + self._cache.realize() + assert self._cache.vt is not None + return self._cache.vt + + def unwrap(self) -> Union[VariableTracker, Self]: + """Return the real VariableTracker if it already exists""" + if self.is_realized(): + assert self._cache.vt is not None + return self._cache.vt + return self + + def is_realized(self) -> bool: + return self._cache.vt is not None + + def clone(self, **kwargs: Any) -> VariableTracker: + assert kwargs.get("_cache", self._cache) is self._cache + if kwargs.get("source", self.source) is not self.source: + self.realize() + return VariableTracker.clone(self.unwrap(), **kwargs) + + def peek_type(self) -> type[Any]: + assert not self.is_realized() + return type(self._cache.value) + + def peek_value(self) -> Any: + assert not self.is_realized() + return self._cache.value + + def __str__(self) -> str: + if self.is_realized(): + return repr(self.unwrap()) + return super().__repr__() + + def __getattr__(self, item: str) -> Any: + return getattr(self.realize(), item) + + # most methods are auto-generated below, these are the ones we want to exclude + visit = VariableTracker.visit # type: ignore[assignment] + __repr__ = __str__ + + @classmethod + def realize_all( + cls, + value: Any, + cache: Optional[dict[int, tuple[Any, Any]]] = None, + ) -> Any: + """ + Walk an object and realize all LazyVariableTrackers inside it. + """ + if cache is None: + cache = {} + + idx = id(value) + if idx in cache: + return cache[idx][0] + + value_cls = type(value) + if issubclass(value_cls, LazyVariableTracker): + result = cls.realize_all(value.realize(), cache) + elif issubclass(value_cls, VariableTracker): + # update value in-place + result = value + value_dict = value.__dict__ + nonvars = value._nonvar_fields + for key in value_dict: + if key not in nonvars: + value_dict[key] = cls.realize_all(value_dict[key], cache) + elif value_cls is list: + result = [cls.realize_all(v, cache) for v in value] + elif value_cls is tuple: + result = tuple(cls.realize_all(v, cache) for v in value) + elif value_cls in (dict, collections.OrderedDict): + result = {k: cls.realize_all(v, cache) for k, v in list(value.items())} + else: + result = value + + # save `value` to keep it alive and ensure id() isn't reused + cache[idx] = (result, value) + return result + + def is_hashable(self) -> bool: + # Checks that the underlying value is hashable without realizing the VT. + # This is used by ConstDictVariable tracker to find if the key LazyVT + # can be hashed. + def _helper(value: Any) -> bool: + # TODO: Add support for more types + return ( + inspect.isbuiltin(value) + or issubclass(type(value), type) + or is_function_or_wrapper(value) + ) + + assert not self.is_realized() + value = self._cache.value + if isinstance(value, tuple): + return all(_helper(v) for v in value) + return _helper(value) + + def original_value(self) -> Any: + # Returns the value without realizing the VT. 
+ assert not self.is_realized() + return self._cache.value + + def original_source(self) -> Any: + # Returns the source without realizing the VT. + assert not self.is_realized() + return self._cache.source + + +class LazySymNodeFormatString: + def __init__( + self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker + ) -> None: + from .constant import ConstantVariable + + self.sym_node_var = sym_node_variable + self.fmt_var = ConstantVariable.create( + "{:" + fmt_spec_var.as_python_constant() + "}" + ) + + def __repr__(self) -> str: + return str.format( + self.fmt_var.as_python_constant(), + str(self.sym_node_var.evaluate_expr()), + ) + + +def _create_realize_and_forward( + name: str, +) -> Callable[[LazyVariableTracker, Any, Any], Any]: + @functools.wraps(getattr(VariableTracker, name)) + def realize_and_forward( + self: LazyVariableTracker, *args: Any, **kwargs: Any + ) -> Any: + return getattr(self.realize(), name)(*args, **kwargs) + + return realize_and_forward + + +def _populate() -> None: + for name, value in VariableTracker.__dict__.items(): + if name not in LazyVariableTracker.__dict__: + if callable(value): + setattr(LazyVariableTracker, name, _create_realize_and_forward(name)) + + +_populate() diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/lists.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46270fd22dabe0807a986b6da3417ce28ce199 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/lists.py @@ -0,0 +1,1144 @@ +# mypy: ignore-errors + +""" +Variable tracking implementations for list-like data structures in Dynamo. + +This module provides specialized variable tracking for various collection types: +- Lists and list subclasses (including torch.nn.ModuleList, ParameterList) +- Tuples and named tuples +- Ranges and slices +- Collections.deque +- torch.Size with special proxy handling + +The implementations support both mutable and immutable collections, iteration, +and common sequence operations. Each collection type has a dedicated Variable +class that handles its unique behaviors while integrating with Dynamo's +variable tracking system. +""" + +import collections +import inspect +import operator +from typing import Optional, TYPE_CHECKING + +import torch +import torch.fx + +from .. 
import graph_break_hints, polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import raise_observed_exception, unimplemented_v2 +from ..source import AttrSource +from ..utils import ( + cmp_name_to_op_mapping, + cmp_name_to_op_str_mapping, + get_fake_value, + guard_if_dyn, + iter_contains, + Lit, + namedtuple_fields, + odict_values, + set_example_value, +) +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class BaseListVariable(VariableTracker): + @staticmethod + def cls_for_instance(obj): + return BaseListVariable.cls_for(type(obj)) + + @staticmethod + def cls_for(obj): + return { + iter: ListIteratorVariable, + list: ListVariable, + slice: SliceVariable, + torch.Size: SizeVariable, + tuple: TupleVariable, + odict_values: ListVariable, + torch.nn.ParameterList: ListVariable, + torch.nn.ModuleList: ListVariable, + collections.deque: DequeVariable, + }[obj] + + def __init__( + self, + items: list[VariableTracker], + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + assert all(isinstance(x, VariableTracker) for x in items) + self.items: list[VariableTracker] = items + + def _as_proxy(self): + return [x.as_proxy() for x in self.items] + + def modified(self, items, **kwargs): + return type(self)(items, **kwargs) + + @property + def value(self): + return self.as_python_constant() + + def debug_repr_helper(self, prefix, suffix): + return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix + + def as_python_constant(self): + return self.python_type()([x.as_python_constant() for x in self.items]) + + def as_proxy(self): + assert self.python_type() is not SizeVariable + return self.python_type()(self._as_proxy()) + + def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + # Set source to None because slicing a list gives a new local + return self.clone( + items=self.items[index], + source=None, + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + else: + assert isinstance(index, (int, torch.SymInt)) + try: + return self.items[index] + except IndexError: + raise_observed_exception( + IndexError, tx, args=["list index out of range"] + ) + + def unpack_var_sequence(self, tx): + return list(self.items) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__getitem__": + from .tensor import TensorVariable + + assert not kwargs and len(args) == 1 + if isinstance(args[0], TensorVariable): + value = get_fake_value(args[0].as_proxy().node, tx) + if value.constant is not None and value.constant.numel() == 1: + value = variables.ConstantVariable.create(value.constant.item()) + else: + unimplemented_v2( + gb_type="Indexing list with non-scalar tensor", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=( + "Attempted to index list-like object with tensor with > 1 element." 
+ ), + hints=[*graph_break_hints.USER_ERROR], + ) + else: + value = args[0] + return self.getitem_const(tx, value) + elif name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.unpack_var_sequence(tx), args[0], tx) + elif name == "index": + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.index), + [self] + list(args), + kwargs, + ) + elif name in cmp_name_to_op_mapping: + left = self + right = args[0] + # TODO this type check logic mirrors the following + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/object.c#L991-L1007 + # But we should probably move it up the stack to so that we don't + # need to duplicate it for different VTs. + if not isinstance(left, BaseListVariable) or not isinstance( + right, BaseListVariable + ): + if name == "__eq__": + return variables.BuiltinVariable(operator.is_).call_function( + tx, (left, right), {} + ) + elif name == "__ne__": + return variables.BuiltinVariable(operator.is_not).call_function( + tx, (left, right), {} + ) + else: + op_str = cmp_name_to_op_str_mapping[name] + left_ty = left.python_type_name() + right_ty = right.python_type_name() + msg = f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'" + raise_observed_exception(TypeError, tx, args=[msg]) + + return variables.UserFunctionVariable(polyfills.list_cmp).call_function( + tx, + [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right], + {}, + ) + + return super().call_method(tx, name, args, kwargs) + + +class RangeVariable(BaseListVariable): + def __init__(self, items, **kwargs) -> None: + items_to_map = items + start = variables.ConstantVariable.create(0) + stop = None + step = variables.ConstantVariable.create(1) + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + assert stop is not None + super().__init__([start, stop, step], **kwargs) + + def debug_repr(self): + return self.debug_repr_helper("range(", ")") + + def python_type(self): + return range + + def start(self): + return self.items[0].as_python_constant() + + def stop(self): + return self.items[1].as_python_constant() + + def step(self): + return self.items[2].as_python_constant() + + def range_length(self): + lo = self.start() + hi = self.stop() + step = self.step() + + assert step != 0 + if step > 0 and lo < hi: + return 1 + (hi - 1 - lo) // step + elif step < 0 and lo > hi: + return 1 + (lo - 1 - hi) // (0 - step) + else: + return 0 + + def _get_slice_indices(self, length, slice): + step_is_negative = 0 + + if slice.step is None: + step = 1 + step_is_negative = False + else: + step = slice.step + step_is_negative = slice.step < 0 + + # Find lower and upper bounds for start and stop. + if step_is_negative: + lower = -1 + upper = length + lower + else: + lower = 0 + upper = length + + # Compute start + if slice.start is None: + start = upper if step_is_negative else lower + else: + start = slice.start + + if start < 0: + start += length + if start < lower: + start = lower + else: + if start > upper: + start = upper + + # Compute stop. 
+ if slice.stop is None: + stop = lower if step_is_negative else upper + + else: + stop = slice.stop + + if stop < 0: + stop += length + if stop < lower: + stop = lower + else: + if stop > upper: + stop = upper + + return [start, stop, step] + + def apply_index(self, index): + length = self.range_length() + if index < 0: + index = length + index + + if index < 0 or index >= length: + raise IndexError(f"index {index} is out of range") + + return variables.ConstantVariable.create(self.start() + (index * self.step())) + + def apply_slice(self, slice): + (slice_start, slice_stop, slice_step) = self._get_slice_indices( + self.range_length(), slice + ) + + def compute_item(index): + return self.start() + (index * self.step()) + + sub_step = self.step() * slice_step + sub_start = compute_item(slice_start) + sub_stop = compute_item(slice_stop) + + result = RangeVariable( + [ + variables.ConstantVariable.create(x) + for x in [sub_start, sub_stop, sub_step] + ], + mutation_type=ValueMutationNew() if self.mutation_type else None, + ) + return result + + def as_python_constant(self): + return range(*[x.as_python_constant() for x in self.items]) + + def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c + index = arg.as_python_constant() + + if isinstance(index, slice): + return self.apply_slice(index) + else: + return self.apply_index(index) + + def as_proxy(self): + return self.python_type()(*self._as_proxy()) + + def unpack_var_sequence(self, tx=None): + return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] + + def reconstruct(self, codegen: "PyCodegen") -> None: + assert "range" not in codegen.tx.f_globals + codegen.add_push_null( + lambda: codegen.append_output(codegen.create_load_python_module(range)) + ) + codegen.foreach(self.items) + codegen.extend_output(create_call_function(3, False)) + + def var_getattr(self, tx: "InstructionTranslator", name): + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented_v2( + gb_type="Unsupported attribute for range() object", + context=f"var_getattr {self} {name}", + explanation=f"Expected attribute to be one of {','.join(fields)} " + f"but got {name}", + hints=[*graph_break_hints.USER_ERROR], + ) + return self.items[fields.index(name)] + + +class CommonListMethodsVariable(BaseListVariable): + """ + Implement methods common to List and other List-like things + """ + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + from .tensor import SymNodeVariable + + if name == "append" and self.is_mutable(): + assert not kwargs + (arg,) = args + tx.output.side_effects.mutation(self) + self.items.append(arg) + return ConstantVariable.create(None) + elif ( + name == "extend" + and self.is_mutable() + and args + and args[0].has_force_unpack_var_sequence(tx) + ): + assert not kwargs + (arg,) = args + arg.force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "append", [item], {}) + ) + return ConstantVariable.create(None) + elif name == "insert" and self.is_mutable(): + assert not kwargs + idx, value = args + if isinstance(idx, SymNodeVariable): + const_idx = idx.evaluate_expr() + else: + const_idx = idx.as_python_constant() + tx.output.side_effects.mutation(self) + self.items.insert(const_idx, value) + return ConstantVariable.create(None) + elif name == "pop" and self.is_mutable(): + assert not kwargs + 
tx.output.side_effects.mutation(self) + return self.items.pop(*[a.as_python_constant() for a in args]) + elif name == "clear" and self.is_mutable(): + assert not kwargs and not args + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + self.items[key.as_python_constant()] = list(value.items) + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "copy": + # List copy() doesn't have args and kwargs + assert not kwargs + assert not args + items = list(self.items) + return self.modified(items, mutation_type=ValueMutationNew()) + elif name == "reverse" and self.is_mutable(): + assert not kwargs + assert not args + self.items.reverse() + tx.output.side_effects.mutation(self) + return ConstantVariable.create(None) + else: + return super().call_method(tx, name, args, kwargs) + + +class ListVariable(CommonListMethodsVariable): + def python_type(self): + return list + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self): + return self.debug_repr_helper("[", "]") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items))) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + if not value.has_force_unpack_var_sequence(tx): + unimplemented_v2( + gb_type="Unsupported conversion for slice assignment", + context=f"call_method {self} {name} {args}", + explanation=f"Missing dynamo support for converting {value} into a list for slice assignment.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + self.items[key.as_python_constant()] = value.force_unpack_var_sequence( + tx + ) + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + + if name == "sort" and self.is_mutable(): + assert len(args) == 0 + key_fn_var = kwargs.pop("key", ConstantVariable.create(None)) + reverse = kwargs.pop( + "reverse", ConstantVariable.create(False) + ).as_python_constant() + assert len(kwargs) == 0 + + if ( + key_fn_var.is_python_constant() + and key_fn_var.as_python_constant() is None + ): + keys = self.items.copy() + else: + keys = [key_fn_var.call_function(tx, [x], {}) for x in self.items] + + if not all(k.is_python_constant() for k in keys): + first_non_constant_key = None + for k in keys: + if not k.is_python_constant(): + first_non_constant_key = k + assert first_non_constant_key is not None + + try: + python_type = first_non_constant_key.python_type() + except NotImplementedError: + python_type = "unknown" + + unimplemented_v2( + gb_type="sort with non-constant keys", + context=str(first_non_constant_key), + explanation=( + f"Cannot perform sort with non-constant key. " + f"First non-constant key type: {python_type}. " + f"Most notably, we cannot sort with Tensor or SymInt keys, but we can " + f"sort ints." 
+ ), + hints=["Use something else as the key."], + ) + + tx.output.side_effects.mutation(self) + sorted_items_with_keys = sorted( + ( + ( + x, + k.as_python_constant(), + -i if reverse else i, # extra key to ensure stable sort + ) + for i, (k, x) in enumerate(zip(keys, self.items)) + ), + key=operator.itemgetter(1, 2), + reverse=reverse, + ) + self.items[:] = [x for x, *_ in sorted_items_with_keys] + return ConstantVariable.create(None) + + if name == "__init__" and self.is_mutable(): + assert not kwargs + if len(args) == 0: + return ConstantVariable.create(None) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + (arg,) = args + tx.output.side_effects.mutation(self) + self.items[:] = arg.force_unpack_var_sequence(tx) + return ConstantVariable.create(None) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is list: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.python_type() is not list: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr([], name)) + + +class DequeVariable(CommonListMethodsVariable): + def __init__(self, items, maxlen=None, **kwargs) -> None: + if maxlen is None: + maxlen = ConstantVariable.create(None) + assert maxlen.is_python_constant(), ( + f"maxlen must be a constant, got: {maxlen.debug_repr()}" + ) + self.maxlen = maxlen + items = list(items) + if self.maxlen.as_python_constant() is not None: + items = items[-maxlen.as_python_constant() :] + super().__init__(items, **kwargs) + + def python_type(self): + return collections.deque + + def debug_repr(self): + if self.maxlen.as_python_constant() is None: + return self.debug_repr_helper( + "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")" + ) + return self.debug_repr_helper("deque([", "])") + + def as_python_constant(self): + return self.python_type()( + [x.as_python_constant() for x in self.items], + maxlen=self.maxlen.as_python_constant(), + ) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_python_module(collections.deque) + ) + ) + codegen.foreach(self.items) + codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))]) + codegen(self.maxlen) + codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False)) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "maxlen": + return self.maxlen + return super().var_getattr(tx, name) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + assert len(args) == 2 + assert not kwargs + key, value = args + assert key.is_python_constant() + assert isinstance(key.as_python_constant(), int) + tx.output.side_effects.mutation(self) + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + + maxlen = self.maxlen.as_python_constant() + if maxlen is not None: + slice_within_maxlen = slice(-maxlen, None) + else: + slice_within_maxlen = None + 
+ if ( + name == "extendleft" + and self.is_mutable() + and len(args) > 0 + and args[0].has_force_unpack_var_sequence(tx) + ): + assert len(args) == 1 + assert not kwargs + # NOTE this is inefficient, but the alternative is to represent self.items + # as a deque, which is a more intrusive change. + args[0].force_apply_to_var_sequence( + tx, lambda item: self.call_method(tx, "appendleft", [item], {}) + ) + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "popleft" and self.is_mutable(): + assert not args + assert not kwargs + tx.output.side_effects.mutation(self) + result, *self.items[:] = self.items + elif name == "appendleft" and len(args) > 0 and self.is_mutable(): + assert len(args) == 1 + assert not kwargs + tx.output.side_effects.mutation(self) + self.items[:] = [args[0], *self.items] + slice_within_maxlen = slice(None, maxlen) + result = ConstantVariable.create(None) + elif name == "insert" and len(args) > 0 and self.is_mutable(): + assert len(args) == 2 + assert not kwargs + if maxlen is not None and len(self.items) == maxlen: + raise_observed_exception( + IndexError, tx, args=["deque already at its maximum size"] + ) + result = super().call_method(tx, name, args, kwargs) + else: + result = super().call_method(tx, name, args, kwargs) + + if ( + slice_within_maxlen is not None + and maxlen is not None + and len(self.items) > maxlen + ): + self.items[:] = self.items[slice_within_maxlen] + return result + + +class TupleVariable(BaseListVariable): + def python_type(self): + return tuple + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)})" + + def debug_repr(self): + return self.debug_repr_helper("(", ")") + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items))) + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is tuple: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.python_type() is not tuple: + return super().call_obj_hasattr(tx, name) + return variables.ConstantVariable.create(hasattr((), name)) + + +class SizeVariable(TupleVariable): + """torch.Size(...)""" + + _nonvar_fields = { + "proxy", + *TupleVariable._nonvar_fields, + } + + def __init__( + self, + items: list[VariableTracker], + proxy: Optional[torch.fx.Proxy] = None, + **kwargs, + ) -> None: + self.proxy = proxy + super().__init__(items, **kwargs) + + def debug_repr(self): + return self.debug_repr_helper("torch.Size([", "])") + + def python_type(self): + return torch.Size + + def as_proxy(self): + if self.proxy is not None: + return self.proxy + + # torch.Size needs special handling. Normally, we pun a list-like + # container to directly contain Proxy/Node objects from FX, and FX + # knows to look inside containers (via map_aggregate). 
But torch.Size + # is weird; although it subclasses from tuple, it doesn't allow + # members which aren't int-like (rejecting Proxy and Node). This + # means we can't use the normal representation trick + # torch.Size([proxy0, proxy1]). I looked into seeing if I could + # relax torch.Size in PyTorch proper, but if torch.Size constructor + # sees a type that it doesn't recognize, it will try to call + # __index__() on it, so there is no BC way to actually change this + # behavior (though it occurs to me that I could have just added a + # YOLO no checking alternate constructor.) + # + # To work around this problem, I represent a torch.Size proxy as + # a straight up proxy, that would have been constructed by taking + # the constituent proxies as arguments. This trick can be generally + # used for any construct that we need a proxy for but we can't + # directly represent as an aggregate; I don't see very many examples + # of this in torchdynamo though! + + # Look for a proxy. If there are none, do the legacy behavior + tracer = None + proxies = self._as_proxy() + for proxy in proxies: + if isinstance(proxy, torch.fx.Proxy): + tracer = proxy.tracer + break + + if tracer is None: + return torch.Size(proxies) + + proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {}) + set_example_value( + proxy.node, + torch.Size( + [ + p.node.meta["example_value"] if not isinstance(p, int) else p + for p in proxies + ] + ), + ) + return proxy + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size")) + codegen.foreach(self.items) + build_torch_size = [ + create_instruction("BUILD_TUPLE", arg=len(self.items)), + ] + create_call_function(1, False) + codegen.extend_output(build_torch_size) + + def unpack_var_sequence(self, tx): + return list(self.items) + + def numel(self, tx): + from .builtin import BuiltinVariable + from .tensor import SymNodeVariable + + const_result = 1 + sym_sizes = [] + + for v in self.items: + if isinstance(v, ConstantVariable): + const_result *= v.value + else: + assert isinstance(v, SymNodeVariable), type(v) + # Delay proxy calls until we know it will be necessary + sym_sizes.append(v) + + result = ConstantVariable.create(const_result) + if sym_sizes and const_result == 1: + # Skip multiplying by 1 + result, *sym_sizes = sym_sizes + + if not sym_sizes or const_result == 0: + return result + + mul = BuiltinVariable(operator.mul) + for v in sym_sizes: + result = mul.call_function(tx, [result, v], {}) + return result + + def call_method( + self, + tx, + name, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + if name == "__getitem__": + assert not kwargs and len(args) == 1 + out = self.get_item_dyn(tx, args[0]) + return out + elif name == "numel": + assert not args and not kwargs + return self.numel(tx) + + return super().call_method(tx, name, args, kwargs) + + def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker): + from .tensor import SymNodeVariable + + if isinstance(arg, SymNodeVariable): + index = arg.sym_num + else: + index = arg.as_python_constant() + + if isinstance(index, slice): + return SizeVariable(self.items[index]) + else: + assert isinstance(index, (int, torch.SymInt)) + return self.items[index] + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + return variables.ConstantVariable.create(hasattr(torch.Size, name)) + + +class NamedTupleVariable(TupleVariable): + 
_nonvar_fields = { + "tuple_cls", + "dynamic_attributes", + *TupleVariable._nonvar_fields, + } + + def __init__(self, items, tuple_cls, **kwargs) -> None: + super().__init__(items, **kwargs) + self.tuple_cls = tuple_cls + self.dynamic_attributes = {} + + def is_namedtuple(self): + return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( + getattr(self.tuple_cls, "_make", None) + ) + + def is_structseq(self): + return not self.is_namedtuple() + + def fields(self): + return namedtuple_fields(self.tuple_cls) + + def debug_repr(self): + if self.is_structseq(): + # StructSequenceType(iterable) + return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) + # NamedTupleType(*iterable) + return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) + + def python_type(self): + return self.tuple_cls + + def as_python_constant(self): + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()([x.as_python_constant() for x in self.items]) + # NamedTupleType(*iterable) + return self.python_type()(*[x.as_python_constant() for x in self.items]) + + def as_proxy(self): + assert self.python_type() is not SizeVariable + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()(self._as_proxy()) + # NamedTupleType(*iterable) + return self.python_type()(*self._as_proxy()) + + def reconstruct(self, codegen: "PyCodegen") -> None: + # Constructors: + # StructSequenceType(iterable) + # NamedTupleType(*iterable) + # NamedTupleType._make(iterable) + create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(create_fn) + ) + ) + codegen.foreach(self.items) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.items)), + ] + + create_call_function(1, False) + ) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__setattr__": + assert len(args) == 2 + assert len(kwargs) == 0 + attr, value = args + attr = attr.as_python_constant() + if ( + # structseq is immutable + self.is_structseq() + # namedtuple directly created by `collections.namedtuple` is immutable + or self.tuple_cls.__bases__ == (tuple,) + # fields are immutable + or attr in self.fields() + ): + raise_observed_exception(AttributeError, tx) + # Subclass of namedtuple type can have dynamic attributes + tx.output.side_effects.mutation(self) + self.dynamic_attributes[attr] = value + return ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + def check_and_create_method(): + method = inspect.getattr_static(self.tuple_cls, name, None) + if isinstance(method, classmethod): + # We need the unbounded cls method to avoid the inline __self__ + return UserMethodVariable( + method.__func__, + variables.UserDefinedClassVariable(self.tuple_cls), + ) + elif isinstance(method, staticmethod): + return UserFunctionVariable(method.__func__) + elif inspect.isfunction(method): + return UserMethodVariable(method, self) + else: + return None + + if name in self.dynamic_attributes: + return self.dynamic_attributes[name] + + fields = self.fields() + if name not in fields: + method = check_and_create_method() + if not method: + return super().var_getattr(tx, name) + return method + return self.items[fields.index(name)] + + def call_obj_hasattr( + self, tx: 
"InstructionTranslator", name: str + ) -> "VariableTracker": + return variables.ConstantVariable.create( + name in self.dynamic_attributes or hasattr(self.tuple_cls, name) + ) + + +class SliceVariable(VariableTracker): + def __init__(self, items, **kwargs) -> None: + items_to_map = items + start, stop, step = [variables.ConstantVariable.create(None)] * 3 + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError + + if isinstance(start, variables.TensorVariable) or isinstance( + stop, variables.TensorVariable + ): + unimplemented_v2( + gb_type="Dynamic slicing with Tensor arguments", + context=f"SliceVariable start: {start}, stop: {stop}, step: {step}", + explanation="Creating slices with Tensor arguments is not supported. " + "e.g. `l[:x]`, where `x` is a 1-element tensor.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + self.items = (start, stop, step) + + super().__init__(**kwargs) + + def debug_repr(self): + return self.debug_repr_helper("slice(", ")") + + def as_proxy(self): + return slice(*[x.as_proxy() for x in self.items]) + + def python_type(self): + return slice + + def as_python_constant(self): + return slice(*[guard_if_dyn(x) for x in self.items]) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach(self.items) + codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name in cmp_name_to_op_mapping: + return variables.GetAttrVariable(self, name) + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented_v2( + gb_type="Unsupported attribute for slice() object", + context=f"var_getattr {self} {name}", + explanation=f"Expected attribute to be one of {','.join(fields)} " + f"but got {name}", + hints=[*graph_break_hints.USER_ERROR], + ) + return self.items[fields.index(name)] + + +class ListIteratorVariable(IteratorVariable): + _nonvar_fields = { + "index", + *IteratorVariable._nonvar_fields, + } + + def __init__(self, items, index: int = 0, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(items, list) + # Removing this check as it slows things down too much + # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 + + # assert all(isinstance(x, VariableTracker) for x in items) + self.items = items + self.index = index + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" + + def next_variable(self, tx): + assert self.is_mutable() + old_index = self.index + if old_index >= len(self.items): + raise_observed_exception(StopIteration, tx) + + tx.output.side_effects.mutation(self) + self.index += 1 + return self.items[old_index] + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + if name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.items[self.index :], args[0], tx) + + return super().call_method(tx, name, args, kwargs) + + def python_type(self): + return type(iter([])) + + def as_python_constant(self): + if self.index > 0: + raise NotImplementedError + return iter([x.as_python_constant() for x in self.items]) + + def unpack_var_sequence(self, tx): + return list(self.items[self.index :]) + + def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: + return 
self.unpack_var_sequence(tx) + + def reconstruct(self, codegen: "PyCodegen") -> None: + remaining_items = self.items[self.index :] + codegen.foreach(remaining_items) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(remaining_items)), + create_instruction("GET_ITER"), + ] + ) + + +class TupleIteratorVariable(ListIteratorVariable): + pass diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/misc.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ab3e11e4816aa26728a5532226fcd56abe3d12 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/misc.py @@ -0,0 +1,1934 @@ +# mypy: ignore-errors + +""" +This module contains miscellaneous variable tracker implementations for various Python types +and features used in Dynamo's symbolic execution. These classes help track and propagate +information about different kinds of variables during graph capture. + +Key classes include: +- SuperVariable: Handles super() calls and method resolution +- ExceptionVariable: Tracks exception objects +- RandomVariable: Manages random number generators +- GetAttrVariable: Tracks attribute access +- MethodWrapperVariable: Handles method wrappers +- PythonModuleVariable: Tracks Python modules +- NumpyVariable: Handles numpy functions and types +- StringFormatVariable: Manages string formatting +- DebuggingVariable: Handles print and logging +""" + +import dataclasses +import functools +import inspect +import itertools +import random +import re +import sys +import types +import warnings +from typing import Optional, TYPE_CHECKING + +import torch._C +import torch._numpy as tnp +import torch.utils._pytree as pytree + +from .. import config, graph_break_hints, trace_rules, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import unpatched_nn_module_init +from ..source import ( + AttrSource, + GenericAttrSource, + GetItemSource, + TypeSource, + WeakRefCallSource, +) +from ..utils import ( + check_unspec_or_constant_args, + cmp_name_to_op_mapping, + identity, + is_tensor_base_attr_getter, + istype, + list_methods, + proxy_args_kwargs, + tuple_methods, +) +from .base import VariableTracker +from .constant import ConstantVariable +from .functions import NestedUserFunctionVariable, UserFunctionVariable +from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class NO_SUCH_SUBOBJ: + pass + + +class SuperVariable(VariableTracker): + _nonvar_fields = { + *VariableTracker._nonvar_fields, + } + + def __init__(self, typevar, objvar=None, **kwargs) -> None: + super().__init__(**kwargs) + # typevar is the first argument to super(). In the case where no argument + # is provided to super(), it is the __class__ object where + # the super() function is being called + self.typevar = typevar + # objvar here must be an instance or subtype of typevar. 
+ # In the case where super() is called without arguments, it is the first argument + # to the current function where super() is called from (self for regular method, + # cls for a classmethod) + self.objvar = objvar + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) + codegen(self.typevar) + if self.objvar is not None: + codegen(self.objvar) + codegen.extend_output(create_call_function(2, False)) + else: + codegen.extend_output(create_call_function(1, False)) + + def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): + assert self.objvar, "1-arg super not implemented" + search_type = self.typevar.as_python_constant() + + # The rest of this function does two things: + # - Walk the mro to find where the attribute comes from to be + # able to provide accurate source + # - Call the getattr to get the object + + # Find the class object, where the function lives. + # When objvar is "self", use type(self), when objvar is "cls", use it as-is + type_to_use = self.objvar.python_type() + type_to_use_source = ( + TypeSource(self.objvar.source) if self.objvar.source else None + ) + if issubclass(type_to_use, type): + type_to_use = self.objvar.value + type_to_use_source = self.objvar.source + + source = None + search_mro = type_to_use.__mro__ + + try: + start_index = search_mro.index(search_type) + 1 + except ValueError: + # Corner case where the typevar is not in the mro of the objvar + # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 + return getattr(super(search_type, type_to_use), name), None + # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 + # super has its getattro implementation. The key point is that instead of calling getattr, it checks the + # attribute in the class __dict__ + for index in range(start_index, len(search_mro)): + # Dont call getattr, just check the __dict__ of the class + if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): + if resolved_getattr is not NO_SUCH_SUBOBJ: + # Equivalent of something like type(L['self']).__mro__[1].attr_name + if type_to_use_source: + source = AttrSource( + GetItemSource( + AttrSource(type_to_use_source, "__mro__"), index + ), + name, + ) + return resolved_getattr, source + + unimplemented_v2( + gb_type="Unable to resolve super getattr", + context="", + explanation=f"Dynamo failed to trace attribute `{name}` accessed " + f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) " + "because the resolved attribute type is not supported.", + hints=[ + "Ensure the attribute exists in the parent class.", + "Check the arguments passed to `super()`.", + ], + ) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # Check if getattr is a constant. If not, delay the actual work by + # wrapping the result in GetAttrVariable. Mostly super is called with a + # method, so most of the work is delayed to call_function. + # + # We could have just implemented a const_getattr. However, super is + # special when it comes to finding sources. Compared to other VTs, super + # requires the attr name to walk the mro and find the actual source (and + # not just AttrSource). 
+ value, source = self._resolved_getattr_and_source(self, name) + if not variables.ConstantVariable.is_literal(value): + return GetAttrVariable(self, name) + if source: + install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) + return variables.ConstantVariable.create(value, source=source) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). + if inner_fn is object.__init__: + return LambdaVariable(identity) + elif inner_fn is torch.nn.Module.__init__: + objvar = self.objvar + from ..side_effects import AttributeMutationNew + + if ( + isinstance(objvar, variables.UserDefinedObjectVariable) + and isinstance(objvar.mutation_type, AttributeMutationNew) + and not (args or kwargs) + ): + with do_not_convert_to_tracable_parameter(): + return variables.UserFunctionVariable( + unpatched_nn_module_init, source=source + ).call_function(tx, [self.objvar] + args, kwargs) + else: + unimplemented_v2( + gb_type="Unsupported super().__init__() call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered a super().__init__() call " + f"on {objvar} that resolved to a `torch.nn.Module.__init__()` " + "call that we cannot trace.", + hints=[*graph_break_hints.DIFFICULT], + ) + elif ( + self.objvar.source + and hasattr(inner_fn, "__name__") + and inner_fn.__name__ == "__new__" + and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn) + ): + user_cls = inner_fn.__self__ + if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins": + user_cls_vt = variables.BuiltinVariable(user_cls) + else: + user_cls_source = source.member + user_cls_vt = variables.UserDefinedClassVariable( + user_cls, source=user_cls_source + ) + return user_cls_vt.call_method(tx, "__new__", args, kwargs) + elif isinstance(inner_fn, staticmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + return variables.UserFunctionVariable( + inner_fn.__func__, source=source + ).call_function(tx, args, kwargs) + elif isinstance(inner_fn, classmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + if isinstance(self.objvar, variables.UserDefinedClassVariable): + # super().classmethod is called from a classmethod itself. So, + # super was converted to super(__class__, cls) in bytecode and + # therefore we have to propagate the cls. + cls_variable = self.objvar + else: + # current function is an instance method, therefore super was + # converted to super(__class__, self). We have to find + # type(self) to bind the cls to the parent classmethod. + # Note that it can't be the self.typevar because __class__ is + # the class where the method is defined, which could be + # different from type(self) with polymorphism. 
+ cls_source = None + if self.objvar.source: + cls_source = AttrSource(self.objvar.source, "__class__") + cls_variable = VariableTracker.build( + tx, self.objvar.value_type, cls_source + ) + + return variables.UserMethodVariable( + inner_fn.__func__, cls_variable, source=source + ).call_function(tx, args, kwargs) + elif isinstance(inner_fn, types.FunctionType): + return variables.UserFunctionVariable( + inner_fn, source=source + ).call_function(tx, [self.objvar] + args, kwargs) + elif isinstance(inner_fn, types.MethodType): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source + ).call_function(tx, args, kwargs) + elif is_standard_setattr(inner_fn) and isinstance( + self.objvar, UserDefinedObjectVariable + ): + return self.objvar.method_setattr_standard(tx, *args, **kwargs) + elif inner_fn is object.__delattr__: + attr = args[0] + try: + attr = attr.as_python_constant() + except NotImplementedError as exc: + unimplemented_v2( + gb_type="Non-constant attribute given to `super().__delattr__()`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the attribute name passed to " + "`super().__delattr__(...)` to be a constant (string).", + hints=[ + "Ensure the attribute name is a string literal or a constant variable." + ], + from_exc=exc, + ) + if not tx.output.side_effects.is_attribute_mutation(self.objvar): + unimplemented_v2( + gb_type="Attempted super().__delattr__() on an object without mutation tracking", + context=f"call_method {self} {name}", + explanation="Dynamo needs to track mutations on an object " + "before `super().__delattr__` can be used on it. But the " + f"object ({self.objvar}) doesn't have attribute mutation " + "tracking enabled.", + hints=[ + "Ensure the object is tracked by Dynamo's side effect system.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + + tx.output.side_effects.store_attr( + self.objvar, attr, variables.DeletedVariable() + ) + return variables.ConstantVariable(None) + elif ( + isinstance(self.objvar, variables.UserDefinedDictVariable) + and inner_fn in self.objvar._dict_methods + ): + return self.objvar._dict_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedTupleVariable) + and inner_fn in tuple_methods + ): + return self.objvar._tuple_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedListVariable) + and inner_fn in list_methods + ): + return self.objvar._list_vt.call_method(tx, name, args, kwargs) + elif inner_fn is object.__getattribute__: + # object.__getattribute__ has no side-effects. We can directly call + # __getattribute__ to access the attribute. 
+ attr_name = args[0].value + if tx.output.side_effects.has_pending_mutation_of_attr( + self.objvar, attr_name + ): + result = tx.output.side_effects.load_attr( + self.objvar, attr_name, deleted_ok=True + ) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception(AttributeError, tx) + return result + + try: + # NB - use object.__getattribute__ to prevent running any user code + attr_value = object.__getattribute__(self.objvar.value, attr_name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + + attr_source = None + if self.objvar.source is not None: + # setup a object.__getattribute__(self.objvar, name) source + attr_source = GenericAttrSource(self.objvar.source, attr_name) + return VariableTracker.build(tx, attr_value, attr_source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. + # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) + + unimplemented_v2( + gb_type="Attempted to call a super() attribute that is " + "not a function or method", + context=f"call_method {self} {name}", + explanation="Dynamo does not know how to trace the call " + f"`super().{name}()` because `super().{name}` is not a " + "function or method attribute.", + hints=[ + "Ensure the attribute accessed via `super()` is a standard method or function.", + ], + ) + + +class ExceptionVariable(VariableTracker): + # The ExceptionVariable corresponds to the BaseException class in Python + def __init__(self, exc_type, args, **kwargs) -> None: + super().__init__(**kwargs) + self.exc_type = exc_type + self.args = args + # When raising a new exception while another exception is already being + # handled, the new exception's __context__ attribute is automatically + # set to the handled exception. + self.__context__ = ConstantVariable(None) + # Set when user raised an exception from another: + # raise ... from ... + self.__cause__ = ConstantVariable(None) + # Boolean flag that controls whether the __context__ attribute is set + self.__suppress_context__ = ConstantVariable(False) + # Contains the call stack where the exception was raised. Dynamo does + # not track traceback. 
So, this variable is always set to None + self.__traceback__ = ConstantVariable(None) + + def set_context(self, context: "ExceptionVariable"): + self.__context__ = context + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", self.exc_type.__name__) + ) + codegen.foreach(self.args) + codegen.call_function(len(self.args), False) + + def codegen_attr(name: str) -> None: + attr = getattr(self, name) + if istype(attr, ConstantVariable): + assert attr.value in (True, False, None), attr + else: + codegen.dup_top() + codegen(attr) + codegen.extend_output(codegen.rot_n(2)) + codegen.store_attr(name) + + codegen_attr("__context__") + codegen_attr("__cause__") + codegen_attr("__suppress_context__") + + def python_type(self): + return self.exc_type + + def call_setattr( + self, + tx: "InstructionTranslator", + name_var: VariableTracker, + val: VariableTracker, + ): + def raise_error(msg): + raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)]) + + name = name_var.as_python_constant() + if name == "__context__": + self.set_context(val) + elif name == "__cause__": + if (isinstance(val, ConstantVariable) and val.value is None) or isinstance( + val, + ( + variables.BuiltinVariable, + variables.ExceptionVariable, + variables.UserDefinedExceptionClassVariable, + variables.UserDefinedExceptionObjectVariable, + ), + ): + self.__cause__ = val + self.__suppress_context__ = variables.ConstantVariable(True) + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__suppress_context__": + if isinstance(val, ConstantVariable) and val.value in (True, False): + self.__suppress_context__ = val + else: + raise_error("exception cause must be None or derive from BaseException") + elif name == "__traceback__": + if isinstance(val, ConstantVariable) and val.value is None: + self.__traceback__ = val + else: + unimplemented_v2( + gb_type="Set Exception object `__traceback__` attribute to not-`None`", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + "'__traceback__' on tracked exception objects to anything " + "other than None.", + hints=[ + "Avoid setting '__traceback__' on exception objects " + "within traced code, or set it to None." + ], + ) + else: + unimplemented_v2( + gb_type="Unsupported attribute assignment on Exception object", + context=f"call_setattr {self} {name}", + explanation="Dynamo does not support setting the attribute " + f"'{name}' on tracked exception objects. 
Only `__context__`, " + "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + return variables.ConstantVariable(None) + + def call_method(self, tx, name, args, kwargs): + if name == "__setattr__": + return self.call_setattr(tx, *args) + elif name == "with_traceback": + [tb] = args + self.call_setattr(tx, ConstantVariable("__traceback__"), tb) + return self + else: + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name == "__context__": + return self.__context__ + elif name == "__cause__": + return self.__cause__ + elif name == "__suppress_context__": + return self.__suppress_context__ + elif name == "__traceback__": + return variables.ConstantVariable(None) + elif name == "args": + return variables.ListVariable(self.args, source=self.source) + return super().var_getattr(tx, name) + + def __str__(self): + return f"{self.__class__.__name__}({self.exc_type})" + + __repr__ = __str__ + + +class UnknownVariable(VariableTracker): + """ + It could be anything! + """ + + +class DelayGraphBreakVariable(UnknownVariable): + """ + Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. + """ + + def __init__(self, msg=None, **kwargs): + super().__init__(**kwargs) + self.msg = msg + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented_v2( + gb_type="Unsupported function call (delayed)", + context=f"source: {self.source}", + explanation="Dynamo determined that a graph break should occur " + f"when calling `{self.source.name()}`. Reason: {self.msg}", + hints=[], + ) + + +class ComptimeVariable(VariableTracker): + """ + This variable is special, it lets you execute arbitrary code at + Dynamo compile time + """ + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError("comptime is special form") + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from ..comptime import comptime + + # To support the comptime.print_graph convenience accessors + from .functions import UserFunctionVariable + + return UserFunctionVariable( + getattr(comptime, name), source=AttrSource(self.source, name) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..comptime import ComptimeContext + + # TODO: support an expression form as well + + assert not kwargs + # Second argument is runtime lambda, ignored + assert len(args) <= 2 + fn = args[0] + if isinstance(fn, UserFunctionVariable): + fn.get_function()(ComptimeContext(tx)) + elif isinstance(fn, NestedUserFunctionVariable): + # We have to manually bind the freevars ourselves + code = fn.get_code() + assert not fn.closure, ( + "comptime function must not have free variables, " + f"but these variables were free: {code.co_freevars}" + ) + func = types.FunctionType( + code, + fn.f_globals, + fn.fn_name.as_python_constant(), + tuple(fn.defaults.items) if fn.defaults else None, + # We could automatically promote free variables into + # ComptimeVar but this is confusing if you access + # a free variable that we actually DO have the runtime + # value for + # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items) + (), + ) + func(ComptimeContext(tx)) + else: + raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") + + return 
variables.ConstantVariable.create(None) + + +class CellVariable(VariableTracker): + # If the cell existed before Dynamo tracing started, this will be the + # VariableTracker that represents the cell content. + # + # Note that all mutation to the cell (i.e., its content) will be buffered in + # SideEffects, rather than being reflected here. One can think of + # `CellVariable` as a special case for `UserDefinedObjectVariable`. + pre_existing_contents: Optional[VariableTracker] + + # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the + # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`). + local_name: Optional[str] = None + + def __init__( + self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.pre_existing_contents = pre_existing_contents + + +class NewGlobalVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + +def produce_trampoline_autograd_apply(fn_cls): + def trampoline_autograd_apply(*args, **kwargs): + return fn_cls.apply(*args, **kwargs) + + trampoline_autograd_apply._origin = produce_trampoline_autograd_apply + return trampoline_autograd_apply + + +class AutogradFunctionVariable(VariableTracker): + """represents a torch.autograd.Function subclass""" + + _nonvar_fields = { + "fn_cls", + *VariableTracker._nonvar_fields, + } + + def __init__(self, fn_cls, **kwargs) -> None: + super().__init__(**kwargs) + self.fn_cls = fn_cls + + def call_apply(self, tx: "InstructionTranslator", args, kwargs): + requires_grad = False + + def visit(node): + nonlocal requires_grad + if isinstance(node, variables.TensorVariable): + if node.requires_grad is not False: + requires_grad = True + if isinstance(node, variables.NNModuleVariable): + if node.is_training(tx): + requires_grad = True + + VariableTracker.visit(visit, (args, kwargs)) + + if requires_grad and torch.is_grad_enabled(): + if config.capture_autograd_function is False: + warnings.warn( + "The config.capture_autograd_function flag is deprecated, it's now always true." + ) + + from torch._functorch.autograd_function import ( + autograd_function_forward_rewritten, + ) + from torch.autograd.function import _is_setup_context_defined + + forward_fn = self.fn_cls.forward + + is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context) + if is_setup_ctx_defined: + # If setup_context is defined, we generate a new forward function which includes + # the original forward and setup_context function, and trace the new forward function. 
+ forward_fn = autograd_function_forward_rewritten( + self.fn_cls.forward, self.fn_cls.setup_context + ) + + vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] + if vjp_fn is not torch.autograd.Function.vjp: + unimplemented_v2( + gb_type="Unsupported custom vjp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `vjp` method.", + hints=[ + "Remove the custom `vjp` method if possible.", + "Use standard `backward` instead if applicable.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] + if jvp_fn is not torch.autograd.Function.jvp: + unimplemented_v2( + gb_type="Unsupported custom jvp", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo does not support tracing " + "`torch.autograd.Function` subclasses that define " + "a custom `jvp` method.", + hints=[ + "Remove the custom `jvp` method if possible.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + from .higher_order_ops import AutogradFunctionApplyVariable + + source = self.source + if source is None: + source = AttrSource( + tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__ + ) + + val = AutogradFunctionApplyVariable( + forward_fn, + self.fn_cls.backward, + source, + source=AttrSource(source, member="apply"), + ).call_function(tx, args, kwargs) + # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping + # the forward function, as we don't want to generate guards for new_forward.__closure__ + # if forward is rewritten by autograd_function_forward_rewritten. + # But we still need to generate correct guards for the original forward and setup_context + # functions, so we have to add guards manually. + if self.source: + fwd_src = AttrSource(self.source, "forward") + install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH)) + if is_setup_ctx_defined: + setup_ctx_src = AttrSource(self.source, "setup_context") + install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH)) + + return val + + if self.source: + source = AttrSource(self.source, "forward") + else: + source = None + + fn = self.fn_cls.forward + ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) + args = [ctx, *args] + if isinstance(fn, types.FunctionType): + sig = inspect.signature(fn) + if len(args) - 1 == len(sig._parameters): + args = args[1:] # Don't use context + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, args, kwargs + ) + elif isinstance(fn, types.MethodType): + return variables.UserMethodVariable( + fn.__func__, + variables.UserDefinedClassVariable(self.fn_cls), + source=source, + ).call_function(tx, args, kwargs) + else: + unimplemented_v2( + gb_type="Non-function or method in subclass of torch.autograd.Function", + context=f"call_apply {self} {args} {kwargs}", + explanation="Dynamo requires the `forward` attribute of a " + "`torch.autograd.Function` subclass to be a standard Python " + f"function or method. Found type `{type(fn).__name__}` instead.", + hints=[ + "Ensure the `forward` method is defined as a regular " + "function or instance method." 
+ ], + ) + + def call_backward(self, tx: "InstructionTranslator", args, kwargs): + fn = self.fn_cls.backward + assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction + assert isinstance(fn, types.FunctionType) + + fn_source = AttrSource(self.source, "backward") + return variables.UserFunctionVariable(fn, source=fn_source).call_function( + tx, args, kwargs + ) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + return AutogradFunctionVariable(self.fn_cls) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + from .builder import wrap_fx_proxy + + if name == "apply": + if trace_rules.is_callable_allowed(self.fn_cls): + trampoline_autograd_apply = produce_trampoline_autograd_apply( + self.fn_cls + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + trampoline_autograd_apply, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + return self.call_apply(tx, args, kwargs) + + elif name == "backward": + return self.call_backward(tx, args, kwargs) + else: + source = AttrSource(self.source, name) if self.source is not None else None + try: + obj = inspect.getattr_static(self.fn_cls, name) + except AttributeError: + obj = None + + if isinstance(obj, staticmethod): + func = obj.__get__(self.fn_cls) + if source is not None: + return ( + trace_rules.lookup(func) + .create_with_source(func, source=source) + .call_function(tx, args, kwargs) + ) + else: + return trace_rules.lookup(func)(func).call_function( + tx, args, kwargs + ) + elif isinstance(obj, classmethod): + return variables.UserMethodVariable( + obj.__func__, self, source=source + ).call_function(tx, args, kwargs) + else: + unimplemented_v2( + gb_type="Unsupported autograd.Function method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` directly on the `torch.autograd.Function` " + "instance. 
Supported methods include `apply`, `backward`, " + "static methods, and class methods.", + hints=[ + "Ensure the method is decorated with `@staticmethod` " + "or `@classmethod` if it's meant to be called on the class.", + ], + ) + + +@dataclasses.dataclass +class SavedTensorBox: + tensors: list[VariableTracker] = dataclasses.field(default_factory=list) + + +class AutogradFunctionContextVariable(UserDefinedObjectVariable): + """ + Tracks an autograd.Function() context using mutation tracking in side_effects.py + """ + + _nonvar_fields = { + "proxy", + "inference", + "saved_tensors", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value, + value_type=None, + inference=False, + saved_tensors=None, + needs_input_grad=None, + non_differentiable=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + self.inference = inference + self.saved_tensors = saved_tensors + self.needs_input_grad = needs_input_grad + self.non_differentiable = non_differentiable + + @staticmethod + def create(tx: "InstructionTranslator", args=None, kwargs=None): + needs_input_grad = None + if args and not kwargs: + needs_input_grad = tuple( + isinstance(x, variables.TensorVariable) and x.requires_grad + for x in args + ) + out = tx.output.side_effects.track_object_new( + None, + torch.autograd.function.FunctionCtx, + functools.partial( + AutogradFunctionContextVariable, + inference=True, + saved_tensors=SavedTensorBox(), + needs_input_grad=needs_input_grad, + ), + {}, + ) + return out + + def as_proxy(self): + if self.proxy is None: + unimplemented_v2( + gb_type="proxy not set", + context=f"as_proxy {self}", + explanation="Dynamo requires the autograd.Function context " + "to be initialized with a proxy.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + return self.proxy + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__setattr__": + return super().call_method(tx, name, args, kwargs) + elif name == "mark_non_differentiable": + assert len(kwargs) == 0 + self.non_differentiable = proxy_args_kwargs(args, {})[0] + return variables.ConstantVariable.create(None) + + if name != "save_for_backward": + unimplemented_v2( + gb_type="Unsupported autograd.Function context method", + context=f"call_method {self} {name}", + explanation="Dynamo does not support calling the method " + f"`{name}` on `autograd.Function` context objects. Supported " + "methods are `__setattr__`, `save_for_backward` and " + "`mark_non_differentiable`.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + if self.saved_tensors is None: + unimplemented_v2( + gb_type="Unsupported autograd.Function context `save_for_backward`", + context=f"call_method {self} {name}", + explanation="Dynamo requires the `saved_tensors` attribute " + "to be initialized on the `autograd.Function` context object.", + hints=[ + "Ensure that the `saved_tensors` attribute is properly " + "initialized before calling `save_for_backward`. " + "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.", + ], + ) + + if not self.inference: + assert self.source and not kwargs + tx.output.side_effects.track_save_for_backward(self, args) + + # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls. 
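+            # e.g. ctx.save_for_backward(a) followed by ctx.save_for_backward(b) leaves only (b,) in
+            # ctx.saved_tensors, so any previously buffered tensors are dropped before recording the new args.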
+ if len(self.saved_tensors.tensors) > 0: + self.saved_tensors.tensors = [] + for arg in args: + self.saved_tensors.tensors.append(arg) + return variables.ConstantVariable.create(None) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name in ["save_for_backward", "mark_non_differentiable"]: + return LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ) + if name == "saved_tensors" and self.saved_tensors is not None: + return variables.TupleVariable(list(self.saved_tensors.tensors)) + if name == "needs_input_grad": + if self.needs_input_grad is not None: + return variables.ConstantVariable.create(self.needs_input_grad) + if self.source: + source = AttrSource(self.source, "needs_input_grad") + return VariableTracker.build(tx, self.value.needs_input_grad, source) + + return super().var_getattr(tx, name) + + +class AutogradEngineVariable(UserDefinedObjectVariable): + """ + Represents a torch._C._ImperativeEngine instance. + """ + + def __init__( + self, + value, + value_type=None, + **kwargs, + ) -> None: + super().__init__(value=value, value_type=value_type, **kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "queue_callback": + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + assert tx.one_graph, ( + "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + ) + return variables.UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, + source=self.source, + ).call_function( + tx, + (tx.output.side_effects.get_ca_final_callbacks_var(), *args), + kwargs, + ) + else: + unimplemented_v2( + gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()", + context=f"call_method {self} {name}", + explanation="queue_callback() is only supported when " + "Compiled Autograd is enabled with fullgraph=True.", + hints=[], + ) + else: + unimplemented_v2( + gb_type="Unsupported torch._C._ImperativeEngine method", + context=f"call_method {self} {name}", + explanation="Dynamo only supports the `queue_callback` method " + f"on a torch._C._ImperativeEngine instance, but found: `{name}`.", + hints=[], + ) + + +class LambdaVariable(VariableTracker): + def __init__(self, fn, **kwargs) -> None: + super().__init__(**kwargs) + self.fn = fn + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.fn(*args, **kwargs) + + +class GetAttrVariable(VariableTracker): + _nonvar_fields = { + "name", + "py_type", + *VariableTracker._nonvar_fields, + } + + def __init__(self, obj, name, py_type=None, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(obj, VariableTracker) + assert isinstance(name, str) + self.obj = obj + self.name = name + self.py_type = py_type # In some cases we know the type (ex. 
tensor methods) + + def python_type(self): + if self.py_type is not None: + return self.py_type + else: + return super().python_type() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.obj}, {self.name})" + + @staticmethod + def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr): + return getattr(base_proxy, attr) + + def as_proxy(self): + return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name) + + def as_python_constant(self): + constant = self.obj.as_python_constant() + try: + return getattr(constant, self.name) + except AttributeError: + raise NotImplementedError(f"{self} is not a constant") from None + + def const_getattr(self, tx: "InstructionTranslator", name): + if not isinstance(self.obj, variables.NNModuleVariable): + raise NotImplementedError + step1 = tx.output.get_submodule(self.obj.module_key) + if self.name not in step1.__dict__: + raise NotImplementedError + step2 = inspect.getattr_static(step1, self.name) + if name not in step2.__dict__: + raise NotImplementedError + return inspect.getattr_static(step2, name) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.obj) + codegen.extend_output(codegen.create_load_attrs(self.name)) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.obj.call_method(tx, self.name, args, kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if ( + name in ("__getitem__", "get") + and self.name == "__dict__" + and not kwargs + and args[0].is_python_constant() + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + # redirect to var_getattr on the original obj + return obj.var_getattr(tx, key) + + # Return the default value for get + if name == "get": + if len(args) == 2: + return args[1] + else: + return variables.ConstantVariable(None) + + elif ( + name == "__contains__" + and self.name == "__dict__" + and len(args) == 1 + and args[0].is_python_constant() + and not kwargs + and isinstance( + self.obj, + ( + variables.UserDefinedObjectVariable, + variables.NNModuleVariable, + variables.UserDefinedClassVariable, + ), + ) + ): + obj = self.obj + key = args[0].as_python_constant() + if obj.has_key_in_generic_dict(tx, key): + return variables.ConstantVariable(True) + else: + return variables.ConstantVariable(False) + + elif name == "__setitem__" and self.name == "__dict__" and not kwargs: + if isinstance(self.obj, variables.UserDefinedObjectVariable): + # Bypass any custom setattr as we are updating the `__dict__` itself + return self.obj.method_setattr_standard( + tx, args[0], args[1], directly_update_dict=True + ) + if isinstance(self.obj, variables.NNModuleVariable): + # This matches how `setattr` is handled for NNModuleVariable + self.obj.convert_to_unspecialized(tx) + + return super().call_method(tx, name, args, kwargs) + + +class MethodWrapperVariable(VariableTracker): + def __init__(self, method_wrapper, **kwargs) -> None: + super().__init__(**kwargs) + self.method_wrapper = method_wrapper + self._builtin_fns = {} + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> 
"VariableTracker": + if is_tensor_base_attr_getter(self.method_wrapper) and isinstance( + args[0], variables.TensorVariable + ): + assert len(args) == 1 and len(kwargs) == 0 + + return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__) + + # method-wrapper variables are common in __init__ calls. For example, + # str("foo").__init__ is a method-wrapper. These method wrappers point + # to C functions. Here we intercept if these method-wrappers are from + # builtins and then call the function counterpart directly by obtaining + # the self object. + self_obj = self.method_wrapper.__self__ + wrapper_name = self.method_wrapper.__name__ + # TODO(dynamo-team) - We can perhaps expand the scope to more names and + # more builtins. + if wrapper_name == "__init__": + fn_obj = type(self_obj).__init__ + if fn_obj is object.__init__: + return variables.BuiltinVariable(object).call_method( + tx, wrapper_name, [self_obj, *args], kwargs + ) + + return super().call_function(tx, args, kwargs) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.method_wrapper + + +class GetSetDescriptorVariable(VariableTracker): + def __init__(self, desc, **kwargs) -> None: + super().__init__(**kwargs) + self.desc = desc + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "__get__" and self.source: + source = AttrSource(self.source, "__get__") + return VariableTracker.build(tx, self.desc.__get__, source) + else: + return super().var_getattr(tx, name) + + def is_python_constant(self): + return True + + def as_python_constant(self): + return self.desc + + +class PythonModuleVariable(VariableTracker): + _nonvar_fields = { + "value", + "is_torch", + *VariableTracker._nonvar_fields, + } + + def __init__(self, value: types.ModuleType, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + self.is_torch = self.value is torch or self.value.__name__.startswith("torch.") + + def python_type(self): + return types.ModuleType + + def as_python_constant(self): + return self.value + + def __repr__(self) -> str: + return f"PythonModuleVariable({self.value})" + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def var_getattr(self, tx: "InstructionTranslator", name): + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.output.side_effects.load_attr(self, name) + + if self.is_torch or name not in self.value.__dict__: + try: + attr_value = getattr(self.value, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + else: + attr_value = self.value.__dict__[name] + + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) + + +class TypingVariable(VariableTracker): + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # Create a new typing variable, e.g., `List[int]` + if name == "__getitem__" and len(args) == 1: + new_typing = self.value[args[0].as_python_constant()] + return TypingVariable(new_typing) + unimplemented("unsupported method call on typing variablel") + + def var_getattr(self, tx: "InstructionTranslator", name: str): + from .builder import SourcelessBuilder, VariableBuilder + + if name in cmp_name_to_op_mapping: + return 
variables.GetAttrVariable(self, name) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.side_effects.load_attr(self, name) + + value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(value) + else: + return SourcelessBuilder.create(tx, value) + + def as_python_constant(self): + return self.value + + def reconstruct(self, codegen: "PyCodegen") -> None: + # We're just trying to load the type here. Reconstructing the type from + # scratch is tricky - for a type like `typing.List[int]` we'd need to + # deconstruct the origin and args. The origin for `List[int]` is `list` + # and the args is `(int,)`. When we recombine those we get the parts + # back and need to emit code for: + # + # `typing.List[int]` + # + # But it's # worse than that - what if `typing` isn't in the globals (or + # was loaded like `import typing as _typing ; _typing.List[int]`?) so we + # really need to do something like: + # + # `sys.modules["typing"].List[int]` + # + # Argh - but what if they rewrote the global `int`? So we have to do: + # + # `sys.modules["typing"].List[sys.modules["builtins"].int]` + # + # But where do we get `sys`? What if they never imported it or have + # something ELSE called `sys`? + # + # Let's skip all that noise and just emit it as a simple const. + # + codegen.append_output(codegen.create_load_const(self.value)) + + +@functools.lru_cache(maxsize=1) +def get_np_to_tnp_map(): + """ + This generates a mapping from numpy modules to their torch._numpy + modules equivalents. + """ + from ..utils import NP_TO_TNP_MODULE + + np_fn_to_tnp_fn = {} + + for np_mod, tnp_mod in NP_TO_TNP_MODULE.items(): + for fn_name, tnp_fn in tnp_mod.__dict__.items(): + if callable(tnp_fn): + # some internal details do leak from tnp + # which are not part of numpy API. + if np_fn := getattr(np_mod, fn_name, None): + np_fn_to_tnp_fn[np_fn] = tnp_fn + + return np_fn_to_tnp_fn + + +@functools.lru_cache(maxsize=1) +def get_tnp_to_np_map(): + """ + This is just the reverse mapping of get_np_to_tnp_map() - mapping from + torch._numpy modules to numpy equivalents. + """ + m = get_np_to_tnp_map() + return {v: k for k, v in m.items()} + + +class NumpyVariable(VariableTracker): + """ + Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. + """ + + constant_fold_functions = (tnp.issubdtype,) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @classmethod + def can_constant_fold_through(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return fn in cls.constant_fold_functions + + @classmethod + def get_constant_collection_for_func(cls, fn): + mod = fn.__module__.split(".") + assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] + return np_constant_collections_map.get(fn, None) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not config.trace_numpy: + unimplemented(f"numpy.{self.value}()") + + from ..utils import numpy_to_tensor_wrapper + from .tensor import NumpyNdarrayVariable + + func = get_np_to_tnp_map().get(self.value) + if func is None: + unimplemented( + f"Can't find numpy function {self.value} in torch._numpy. " + " Please file an issue to request support for this function." 
+ ) + + # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) + if ( + collection_variable_typ := self.get_constant_collection_for_func(func) + ) is not None: + try: + return collection_variable_typ( + self.value( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + except NotImplementedError: + unimplemented( + f"{self.value.__name__} with non-const args: {args} {kwargs}" + ) + else: + if ( + func.__module__ == "torch._numpy.random" + and config.use_numpy_random_stream + ): + msg = f"delegate '{func.__qualname__}' to NumPy itself via " + msg += ( + f"config.use_numpy_random_stream={config.use_numpy_random_stream}" + ) + unimplemented(msg) + + args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) + + if self.can_constant_fold_through(func) and ( + check_unspec_or_constant_args(args, kwargs) + ): + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + # TODO Add all the functions that go from constants to constants to can_constant_fold_through + proxy = tx.output.create_proxy( + "call_function", + numpy_to_tensor_wrapper(func), + *proxy_args_kwargs(args, kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented("numpy") + + def as_python_constant(self): + return self.value + + def as_proxy(self): + if config.trace_numpy and isinstance(self.value, type): + # This handles numpy dtype attributes such as np.float32 + # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph + # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does + return self.value.__name__ + + return super().as_proxy() + + +# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls +class NullVariable(VariableTracker): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "NullVariable" + + def reconstruct(self, codegen: "PyCodegen"): + if sys.version_info < (3, 11): + unimplemented("cannot reconstruct NullVariable in < Python 3.11") + codegen.append_output(create_instruction("PUSH_NULL")) + + +class DeletedVariable(VariableTracker): + """Marker used to implement delattr()""" + + +class StringFormatVariable(VariableTracker): + """ + Represents a call to str.format(), we delay calling format until after the graph. 
+ """ + + _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} + + @classmethod + def create(cls, format_string, sym_args, sym_kwargs): + if all( + x.is_python_constant() + for x in itertools.chain(sym_args, sym_kwargs.values()) + ): + return variables.ConstantVariable.create( + format_string.format( + *[v.as_python_constant() for v in sym_args], + **{k: v.as_python_constant() for k, v in sym_kwargs.items()}, + ) + ) + return cls(format_string, list(sym_args), dict(sym_kwargs)) + + def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: + super().__init__(**kwargs) + assert isinstance(format_string, str) + self.format_string = format_string + self.sym_args = sym_args + self.sym_kwargs = sym_kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_const(self.format_string), + codegen.create_load_attr("format"), + ] + ), + call_function_ex=True, + ) + codegen(variables.TupleVariable(self.sym_args)) + kwargs = { + variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items() + } + codegen(variables.ConstDictVariable(kwargs)) + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1)) + + +class DebuggingVariable(VariableTracker): + """ + Represents a call to a debugging function like print(), or something + registered to config.reorderable_logging_functions. + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + @staticmethod + def is_reorderable_logging_function(obj): + return ( + callable(obj) + and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)) + and obj in torch._dynamo.config.reorderable_logging_functions + ) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + + if not self.can_reorder_logs(self.value, args, kwargs): + unimplemented( + f"Reordering debugging function {self.value} " + f"with inputs {args} {kwargs} is not yet implemented." + ) + + tx.debug_locals.append((self, list(args))) + + def reconstruct(self, codegen: "PyCodegen"): + return self.source.reconstruct(codegen) + + @staticmethod + def can_reorder_logs(fn, args, kwargs) -> True: + """ + Run some additional checks for what sort of function calls can we + actually reorder. 
+ """ + + allowed_input_types = ( + variables.TensorVariable, + variables.ConstantVariable, + StringFormatVariable, + ) + + flat_args = pytree.tree_leaves([args, kwargs]) + for arg in flat_args: + if not isinstance(arg, allowed_input_types): + return False + + return True + + +class LoggingLoggerVariable(VariableTracker): + """ + Represents a call to any of logging.Logger methods + """ + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if tx.export: + # For export cases, we can just make debugging functions no-ops + return + method = getattr(self.value, name, None) + function = getattr(method, "__func__", None) + if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): + return variables.ConstantVariable.create(None) + unimplemented( + "Logger not supported for non-export cases. " + "To avoid graph breaks caused by logger in compile-mode, it is recommended to" + " disable logging by adding logging methods to config.ignore_logger_methods" + ) + + +class ConstantLikeVariable(VariableTracker): + """self.value is a compile-time constant, but not a literal""" + + _error_prefix = "ConstantLikeVariable" + try: + from numpy import ( + dtype as np_dtype, + floating as np_floating, + generic as np_generic, + ) + except ImportError: + np_floating = type("invalid_type", (), {}) + np_dtype = type("invalid_type", (), {}) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + # we only support constant propagation for methods + cargs = [x.as_python_constant() for x in args] + ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + except NotImplementedError: + unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})") + + result = getattr(self.value, name)(*cargs, **ckwargs) + + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + if isinstance(result, re.Match): + return ConstantRegexMatchVariable(result) + + unimplemented(f"{self._error_prefix}.{name}() -> {result}") + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + result = getattr(self.value, name) + if isinstance(result, self.np_floating): + result = float(result) + if isinstance(result, self.np_dtype): + return NumpyDTypeVariable(result) + if isinstance(result, type) and issubclass(result, self.np_generic): + # things like x.dtype.type + return NumpyVariable(result) + if variables.ConstantVariable.is_literal(result): + return variables.ConstantVariable.create(result) + return GetAttrVariable(self, name) + + +class RegexPatternVariable(ConstantLikeVariable): + _error_prefix = "re.Pattern" + + +class ConstantRegexMatchVariable(ConstantLikeVariable): + _error_prefix = "re.Match" + + +class TorchVersionVariable(ConstantLikeVariable): + _error_prefix = "torch.__version__" + + def __init__(self, **kwargs) -> None: + kwargs.setdefault("value", torch.__version__) + assert kwargs["value"] is torch.__version__ + super().__init__(**kwargs) + + +class NumpyTypeInfoVariable(ConstantLikeVariable): + _error_prefix = "np.iinfo/np.finfo" + + +class 
NumpyDTypeVariable(ConstantLikeVariable): + _error_prefix = "np.dtype[...]" + + def as_proxy(self): + """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable: + + np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype. + This also handles unsupported things nicely (i.e. structured arrays and object arrays). + """ + return self.value.type.__name__ + + +np_constant_collections_map = { + tnp.finfo: NumpyTypeInfoVariable, + tnp.iinfo: NumpyTypeInfoVariable, + tnp.dtype: NumpyDTypeVariable, +} + + +class RandomClassVariable(VariableTracker): + """random.Random""" + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def call_function(self, tx: "InstructionTranslator", args, kwargs): + if len(args) > 1: + unimplemented("random.Random() with > 1 arg") + elif kwargs: + unimplemented("random.Random() with kwargs") + seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] + return RandomVariable( + seed=seed, mutation_type=variables.base.ValueMutationNew() + ) + + +class RandomVariable(VariableTracker): + """random.Random() + + Implemented by wrapping a VariableTracker around a random.Random object. + The supported methods for the random.Random object cannot be overridden. + Assumes that random objects behave the same given a set seed or state. + """ + + _nonvar_fields = { + "random", + *VariableTracker._nonvar_fields, + } + + _supported_fn_names = { + "random", + "randint", + "randrange", + "uniform", + } + + def __init__( + self, + rand: Optional[random.Random] = None, + seed: Optional[VariableTracker] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if rand is not None: + assert self.is_supported_random_obj(rand) + self.random = random.Random() + self.random.setstate(rand.getstate()) + else: + seed = seed.as_python_constant() if seed is not None else None + self.random = random.Random(seed) + + def python_type(self): + return random.Random + + def as_python_constant(self): + return self.random + + @staticmethod + def is_supported_random_obj(val): + if type(val) is not random.Random: + return False + for name in itertools.chain( + RandomVariable._supported_fn_names, ("seed", "getstate", "setstate") + ): + if not hasattr(val, name): + return False + meth = getattr(val, name) + if inspect.isbuiltin(meth): + # e.g. 
random.Random.random + if meth != getattr(random.Random, name).__get__(val): + return False + else: + if getattr(meth, "__func__", None) is not getattr(random.Random, name): + return False + return True + + @staticmethod + def check_state(state): + assert type(state) is tuple + assert type(state[0]) is int + assert type(state[1]) is tuple + assert all(type(x) is int for x in state[1]) + assert state[2] is None or type(state[2]) is float + + @staticmethod + def wrap_state(state): + RandomVariable.check_state(state) + return variables.TupleVariable( + [ + variables.ConstantVariable.create(state[0]), + variables.TupleVariable( + [variables.ConstantVariable.create(x) for x in state[1]] + ), + variables.ConstantVariable.create(state[2]), + ] + ) + + @staticmethod + def unwrap_state(state): + state_obj = state.as_python_constant() + RandomVariable.check_state(state_obj) + return state_obj + + def call_method( + self, + tx: "InstructionTranslator", + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "seed": + tx.output.side_effects.mutation(self) + self.random.seed( + *[x.as_python_constant() for x in args], + **{key: val.as_python_constant() for key, val in kwargs.items()}, + ) + return variables.ConstantVariable.create(None) + elif name == "getstate": + return self.wrap_state(self.random.getstate()) + elif name == "setstate": + tx.output.side_effects.mutation(self) + self.random.setstate(self.unwrap_state(args[0])) + return variables.ConstantVariable.create(None) + elif name in self._supported_fn_names: + tx.output.side_effects.mutation(self) + state = self.random.getstate() + + def call_random_meth(*args, **kwargs): + r = random.Random() + r.setstate(state) + return getattr(r, name)(*args, **kwargs) + + # self.random state not actually updated by call_random_meth, so update here + # by calling the method + getattr(self.random, name)( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + + return call_random_fn(tx, call_random_meth, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(random), + codegen.create_load_attr("Random"), + ] + ) + ) + codegen.call_function(0, False) + # NOTE using add_push_null may result in NULL being duplicated + # so defer the push_null to call_function + codegen.dup_top() + codegen.load_attr("setstate") + codegen(self.wrap_state(self.random.getstate())) + codegen.call_function(1, True) + codegen.pop_top() + + +class WeakRefVariable(VariableTracker): + @staticmethod + def build(tx, weakref_value, **options): + source = options.get("source", None) + callback = weakref_value.__callback__ + callback_source = source and AttrSource(source, "__callback__") + callback_vt = VariableTracker.build(tx, callback, callback_source) + referent = weakref_value() + source = source and WeakRefCallSource(source) + referent_vt = VariableTracker.build(tx, referent, source) + options["source"] = source + return WeakRefVariable(referent_vt, callback_vt, **options) + + def __init__(self, referent_vt, callback_vt, **options): + super().__init__(**options) + self.referent_vt = referent_vt + self.callback_vt = callback_vt + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return self.referent_vt + + def 
reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) + codegen(self.referent_vt) + codegen(self.callback_vt) + codegen.extend_output(create_call_function(2, False)) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/nn_module.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/nn_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1f64fe836ed6af0e226be32e7fd36b69f6218403 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/nn_module.py @@ -0,0 +1,1227 @@ +# mypy: ignore-errors + +""" +This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing. + +It provides specialized handling for different types of nn.Module instances through several key classes: + +- NNModuleVariable: Handles instance-specific module tracing, specializing on module id() and placing + parameters directly on the torch.fx.GraphModule. This creates one graph per module instance. + +- UnspecializedNNModuleVariable: Provides class-level module tracing, treating nn.Modules like other + user-defined objects and passing parameters as inputs to the FX graph. This creates one graph per + module class. + +- UnspecializedBuiltinNNModuleVariable: Specifically handles built-in PyTorch modules (e.g. nn.Linear) + with appropriate optimizations. + +- FSDPManagedNNModuleVariable: Special handling for FSDP-wrapped modules with modified guarding behavior + and parameter handling. + +The module integrates with Dynamo's broader tracing functionality to handle module method calls, +parameter access, hooks, and other nn.Module behaviors while maintaining proper scoping and guarding +of module state. +""" + +import functools +import inspect +import itertools +import types +from contextlib import contextmanager, nullcontext +from typing import TYPE_CHECKING + +import torch.nn + +from .. import graph_break_hints, trace_rules, variables +from ..exc import ( + raise_observed_exception, + unimplemented_v2, + UnspecializeRestartAnalysis, + Unsupported, +) +from ..guards import GuardBuilder, install_guard +from ..mutation_guard import GenerationTracker +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + FSDPNNModuleSource, + GetItemSource, + NNModuleSource, + UnspecializedNNModuleSource, +) +from ..utils import ( + get_custom_getattr, + get_fake_value, + is_lazy_module, + is_namedtuple, + is_safe_constant, + istensor, + istype, + nnmodule_has_hooks, + object_has_getattribute, + proxy_args_kwargs, + set_example_value, + unpatched_nn_module_call, + unpatched_nn_module_call_impl, +) +from .base import typestr, ValueMutationNew, VariableTracker +from .functions import invoke_and_store_as_constant +from .lazy import LazyVariableTracker +from .lists import SliceVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): + """ + Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable. + + Used to cause lazy module to be initialized (and delete its init hook) before tracing. Especially + useful now that 'allowed' modules graph-break on hooks, calling this first ensures there is no hook + by the time we trace __call__ and thus no graph-break for lazy allowed modules. 
+ """ + if hasattr(mod, "_initialize_hook"): + + def convert_to_fake(x): + if is_namedtuple(x): + return type(x)(*(convert_to_fake(elem) for elem in x)) + elif isinstance(x, dict): + return {k: convert_to_fake(v) for k, v in x.items()} + elif isinstance(x, (list, tuple, set)): + return type(x)(convert_to_fake(elem) for elem in x) + elif isinstance(x, torch.fx.Proxy): + return get_fake_value(x.node, tx) + else: + return x + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + fake_args = [convert_to_fake(arg) for arg in proxy_args] + fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} + mod._infer_parameters(mod, fake_args, fake_kwargs) + + +@contextmanager +def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): + fully_qualified_name = source.name() + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key + try: + tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 + yield + finally: + del tx.nn_module_stack[module_key] + + +def guard_to_detect_forward_monkeypatching(source, mod): + # Users sometimes patch the forward method of a nn module instance to + # perform optimizations like quantization. Though this is not a good + # software practice, but python allows this and Dynamo needs to detect + # this patching. + # + # One way to do this is to add an ID_MATCH guard on every function + # getting inlined (https://github.com/pytorch/pytorch/pull/124975). But + # this increased guard overhead by around 20%. + # + # To keep the guard overhead down, we just guard on the `forward` being + # not present in the mod __dict__. The common case of patching forward + # method adds `forward` in the instance __dict__, whereas the unpatched + # `forward` sits in the type(mod).__dict__ + if source: + if "forward" in mod.__dict__ and callable(mod.__dict__["forward"]): + # Monkeypatched forward method, add an ID_MATCH guard on forward function + fwd = mod.__dict__["forward"] + forward_source = AttrSource(source, "forward") + if type(fwd) is types.MethodType: + forward_source = AttrSource(forward_source, "__func__") + install_guard(forward_source.make_guard(GuardBuilder.CLOSURE_MATCH)) + else: + # Common case - check that the forward key is absent in mod __dict__ + install_guard( + source.make_guard( + functools.partial( + GuardBuilder.NOT_PRESENT_IN_GENERIC_DICT, attr="forward" + ) + ) + ) + + +class NNModuleVariable(VariableTracker): + _nonvar_fields = { + "module_type", + "module_key", + "value", + "nn_module_stack_source", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, module_type: type, module_key: str, value: torch.nn.Module, **kwargs + ) -> None: + super().__init__(**kwargs) + self.module_type = module_type + self.module_key = module_key + self.value = value + assert self.source + self.nn_module_stack_source = self.source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source + + def python_type(self): + return self.module_type + + def _wrap_submodule( + self, tx: "InstructionTranslator", source, submod, *key_extra, **options + ): + return + + def unpack_var_sequence(self, tx): + # implement list/iter/tuple/etc calls + base = tx.output.get_submodule(self.module_key) + if isinstance(base, torch.nn.ModuleDict): + result = [] + for name, submod in base.items(): + 
name_var = variables.ConstantVariable.create(name) + tx.output.register_attr_or_module( + submod, + self.module_key, + name, + source=NNModuleSource(GetItemSource(self.source, name)), + ) + result.append(name_var) + return result + + assert isinstance( + base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential) + ), typestr(base) + assert self.source + result = [] + for idx, submod in enumerate(base): + result.append( + tx.output.register_attr_or_module( + submod, + self.module_key, + idx, + source=NNModuleSource(GetItemSource(self.source, idx)), + ) + ) + return result + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + mod = tx.output.get_submodule(self.module_key) + result = hasattr(mod, name) + install_guard( + NNModuleSource(AttrSource(self.source, name)).make_guard( + GuardBuilder.HASATTR + ) + ) + return variables.ConstantVariable.create(result) + + def is_training(self, tx): + mod = tx.output.get_submodule(self.module_key) + return getattr(mod, "training", False) + + def convert_to_unspecialized(self, tx): + """Restart analysis treating this module as an UnspecializedNNModuleVariable""" + mod = tx.output.get_submodule(self.module_key) + GenerationTracker.tag(mod) + + # Mark the class dynamic unless its module initialization + if tx.f_code.co_name != "__init__": + GenerationTracker.mark_class_dynamic(type(mod)) + raise UnspecializeRestartAnalysis + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + base = tx.output.get_submodule(self.module_key) + + if object_has_getattribute(base): + unimplemented_v2( + gb_type="Custom __getattribute__ in nn.Module dict key check", + context=f"has_key_in_generic_dict {self} {key}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + base_dict = object.__getattribute__(base, "__dict__") + return key in base_dict + + def _custom_getattr_fallback(self, base, tx, name, obj_source): + """Check for a __getattr__ and handle it specially if it is implemented""" + if object_has_getattribute(base): + unimplemented_v2( + gb_type="Custom __getattribute__ in nn.Module attribute access", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not support checking key existence " + "on `nn.Module` instances that have a custom " + "`__getattribute__` method defined.", + hints=[ + "Avoid defining `__getattribute__` in your module.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + getattr_fn = get_custom_getattr(base, ignore_nn_module_getattr=True) + if getattr_fn is None: + return None + + if not isinstance(getattr_fn, types.FunctionType): + unimplemented_v2( + gb_type="torch.nn.Module with a non-function custom __getattr__", + context=f"var_getattr {self} {name}", + explanation=( + "Dynamo detected a nn.Module object with a custom " + "`__getattr__` method, but this method is not a standard " + "Python function (e.g., it might be implemented in C/C++). " + "Dynamo cannot currently trace into such non-standard " + "`__getattr__` methods." + ), + hints=[ + "Avoid using objects with non-standard __getattr__ methods " + "within the compiled region. 
If possible, implement " + "__getattr__ as a standard Python function.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + options = {"source": AttrSource(obj_source, "__getattr__")} + return variables.UserMethodVariable(getattr_fn, self, **options).call_function( + tx, [variables.ConstantVariable.create(name)], {} + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + source = self.source and AttrSource(self.source, name) + + base = tx.output.get_submodule(self.module_key) + base_dict = object.__getattribute__(base, "__dict__") + object_member = True + all_class_attribute_names = set() + for x in inspect.getmro(base.__class__): + all_class_attribute_names.update(x.__dict__.keys()) + + if not self.source: + unimplemented_v2( + gb_type="getattr with no source", + context=f"var_getattr {self} {name}", + explanation="Dynamo does not know how to access an attribute " + "on an `nn.Module` instance that lacks a source. This is " + "usually an internal error in Dynamo.", + hints=[*graph_break_hints.DYNAMO_BUG], + ) + + if name == "__dict__": + return variables.GetAttrVariable(self, name, source=source) + + if name in base_dict: + subobj = base_dict[name] + elif ( + "_modules" in base_dict + and name in base_dict["_modules"] + and name not in all_class_attribute_names + ): + subobj = base_dict["_modules"][name] + elif "_parameters" in base_dict and name in base_dict["_parameters"]: + subobj = base_dict["_parameters"][name] + elif "_buffers" in base_dict and name in base_dict["_buffers"]: + subobj = base_dict["_buffers"][name] + else: + try: + subobj = inspect.getattr_static(base, name) + object_member = False + except AttributeError: + # see if we can fallback to __getattr__, which is not checked by getattr_static + result = self._custom_getattr_fallback( + base=base, tx=tx, name=name, obj_source=self.source + ) + if result is not None: + return result + # if we can't find a __getattr__, we can't parse this, raise attribute error + raise_observed_exception( + AttributeError, + tx, + ) + + if name == "forward": + guard_to_detect_forward_monkeypatching(self.source, base) + + if name == "__class__" and not object_member: + return variables.UserDefinedClassVariable(base.__class__, source=source) + + if object_member: + out = VariableTracker.build(tx, subobj, NNModuleSource(source)) + + if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. 
+ out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + else: + if istype(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = AttrSource(AttrSource(self.source, "__class__"), name) + # Get the getter function + source = AttrSource(source, "fget") + return variables.UserFunctionVariable( + subobj.fget, + source=source, + ).call_function(tx, [(self)], {}) + elif istype(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, + variables.UserDefinedObjectVariable(type(base)), + source=source, + ) + elif istype(subobj, staticmethod): + return variables.UserFunctionVariable( + subobj.__get__(base), source=source + ) + elif istype(subobj, types.FunctionType): + return variables.UserMethodVariable(subobj, self, source=source) + elif is_safe_constant(subobj) or istensor(subobj): + # Support possibly common cases of class members + return VariableTracker.build(tx, subobj, NNModuleSource(source)) + else: + unimplemented_v2( + gb_type="Unsupported nn.Module attribute type", + context=f"nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", + explanation=f"Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", + hints=[ + f"Refactor your code so that `{name}` (type `{typestr(subobj)}`) is not an attribute of `{typestr(base)}`", + "Currently supported attribute types are methods, classmethods, staticmethods, " + "properties, constants, and tensors.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return variables.GetAttrVariable(self, name, source=source) + + def call_function( + self, + tx, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + mod = tx.output.get_submodule(self.module_key) + + with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, mod + ): + is_lazy = is_lazy_module(mod) + if ( + isinstance(mod, torch.nn.Sequential) + and mod.__class__.forward is torch.nn.Sequential.forward + ): + if nnmodule_has_hooks(mod): + # We do not want to unroll sequential if it has hooks, since evaporating it + # will cause hooks to not fire! + # This terminates and restart the tracing process + self.convert_to_unspecialized(tx) + + # Unroll sequential + assert not is_lazy, ( + "Expected lazy sequential isn't a valid combination?" + ) + assert not kwargs + (arg,) = args + # TODO: Use named_children when it supports remove_duplicate=False. + for child_name, submod in mod._modules.items(): + tx.call_function( + tx.output.register_attr_or_module( + submod, + self.module_key, + child_name, + source=NNModuleSource(AttrSource(self.source, child_name)), + ), + [arg], + {}, + ) + arg = tx.pop() + return arg + + if is_lazy: + # The module type will change after it is called + if mod.cls_to_become is not None: + self.module_type = mod.cls_to_become + + # The pre-hook runs to initialize the module shapes, then deletes itself. After this, + # the module is more or less not lazy and can be treated as a normal module regardless of + # is_allowed or other variations. + initialize_lazy_module(tx, mod, args, kwargs) + + # If we are tracing the higher order op, we want Dynamo to step + # inside the module call so that Dynamo can see the underlying + # parameters and buffers and raise them as inputs to the graph. + # + # NB: torch.nn.utils.parametrize changes the class type of a + # parametrized module such that its __module__ points to + # "torch.nn.utils.parametrize". 
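+            # That is also why "torch.nn.utils.parametrize" is excluded from the call_module fast path below:
+            # parametrized modules are inlined instead of being treated as opaque leaf modules, so the
+            # parametrization code stays visible to the tracer.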
+ if ( + tx.output.is_root_tracer() + and mod.__module__.startswith(("torch.nn.", "torch.ao.")) + and mod.__module__ != "torch.nn.utils.parametrize" + ): + if nnmodule_has_hooks( + mod, check_forward_hooks=True, check_backward_hooks=True + ): + # End of fn, this bubbles up and restarts tracing. + self.convert_to_unspecialized(tx) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_module", + self.module_key, + *proxy_args_kwargs(args, kwargs), + ), + ) + else: + assert self.source, ( + "Must provide a valid source in order to inline, " + "since inlined function may have default args which must be guarded." + ) + if isinstance(mod, torch.fx.GraphModule): + # TODO: do we want to support __call__ for GM's? + # If so at least some changes are needed, we don't allow inlining + # the call_wrapped currently, and maybe other issues too + fn = mod.forward + fn_source = AttrSource(self.source, "forward") + else: + fn = mod._call_impl + fn_source = AttrSource(self.source, "_call_impl") + if istype(fn, types.MethodType): + fn = fn.__func__ + fn_source = AttrSource(fn_source, "__func__") + args = [self] + args + else: + assert istype(fn, types.FunctionType) + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + args, + kwargs, + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + constant=False, + ) -> "VariableTracker": + from . import ConstantVariable, ListIteratorVariable, TupleVariable + + key = self.module_key + module = tx.output.get_submodule(key) + + def generic_call_method_helper(name): + # Helper function to put a `call_method` node in FX graph, + # with nn.Module as the first arg. + mod_proxy = tx.output.create_proxy( + "get_attr", + self.module_key, + (), + {}, + ) + set_example_value(mod_proxy.node, module) + + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + name, + args=(mod_proxy, *proxy_args), + kwargs=proxy_kwargs, + ), + ) + + if name in ["_call_impl", "_wrapped_call_impl"]: + # Example: `self.layer.__call__(x)` + # This is used for explicit calling `__call__` in a forward function. + # Dynamo inlines `__call__`, includes hooks. + return self.call_function(tx, args, kwargs) + elif name == "forward": + # Example: `self.layer.forward(x)` + # This is used for explicit calling `forward` in a forward function. + # Dynamo puts `call_method` node in FX, doesn't trigger hooks. 
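+            # This matches eager semantics: calling mod.forward(x) bypasses __call__/_call_impl, so no forward
+            # hooks run; the graph records a call_method node on the submodule proxy built in generic_call_method_helper.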
+ with record_nn_module_stack( + self.module_key, self.get_nn_module_stack_source(), tx, module + ): + return generic_call_method_helper(name) + + if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( + inspect.getfile(module.__class__._check_input_dim) + ): + return ConstantVariable.create(True) + + if name == "_get_item_by_idx": + assert args[1].is_python_constant() + assert isinstance(args[0], TupleVariable) + mod_var = args[0].items[args[1].value] + if isinstance(mod_var, UnspecializedNNModuleVariable): + return mod_var + key = mod_var.module_key + submod = tx.output.get_submodule(key) + return tx.output.register_attr_or_module( + submod, + key, + key, + source=NNModuleSource(GetItemSource(self.source, key)), + ) + + if constant: + fn = getattr(module, name) + name = f"{module.__class__.__name__}_{name}_result" + return invoke_and_store_as_constant(tx, fn, name, args, kwargs) + + def assert_all_args_kwargs_const(): + if not all( + x.is_python_constant() for x in itertools.chain(args, kwargs.values()) + ): + unimplemented_v2( + gb_type="non-const argument in nn.Module method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support calling " + f"method `{name}` of ``nn.Module`` {module} with non-constant arguments.", + hints=[], + ) + + def get_kwargs(*names): + assert_all_args_kwargs_const() + fn = getattr(module, name) + bound_args = inspect.signature(fn).bind( + *([x.as_python_constant() for x in args]), + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + bound_args.apply_defaults() + bound_args = bound_args.arguments + return {k: bound_args[k] for k in names} + + def wrap_values(items): + result = [] + for name, submod in items: + result.append( + tx.output.register_attr_or_module( + submod, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ) + ) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + + def named_embed(name, obj): + return TupleVariable( + [ + ConstantVariable.create(name), + tx.output.register_attr_or_module( + obj, + key, + name, + source=NNModuleSource(gen_source(self.source, name)), + ), + ] + ) + + def gen_source(source, name): + name_split = name.split(".") + if name_split[0] == "": + return source + while len(name_split) > 0: + x = name_split.pop(0) + source = AttrSource(source, x) + return source + + if name == "named_children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + assert not (args or kwargs) + result = [] + for name, submod in module.named_children(): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "named_parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + result = [] + for name, param in module.named_parameters( + **get_kwargs("prefix", "recurse") + ): + result.append(named_embed(name, param)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "named_buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + result = [] + for name, buffer in module.named_buffers( + **get_kwargs("prefix", "recurse", "remove_duplicate") + ): + result.append(named_embed(name, buffer)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "named_modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + result = [] + for name, submod in module.named_modules( + **get_kwargs("memo", "prefix", 
"remove_duplicate") + ): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "children": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + assert not (args or kwargs) + return wrap_values(module.named_children()) + elif name == "modules": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_modules")) + return wrap_values(module.named_modules()) + elif name == "parameters": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_parameters")) + return wrap_values(module.named_parameters(**get_kwargs("recurse"))) + elif name == "buffers": + tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers")) + return wrap_values(module.named_buffers(**get_kwargs("recurse"))) + elif name == "keys": + assert not (args or kwargs) + result = [] + for name in module.keys(): + result.append(ConstantVariable.create(name)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "values": + assert not (args or kwargs) + return wrap_values(module.items()) + elif name == "items": + assert not (args or kwargs) + result = [] + for name, submod in module.items(): + result.append(named_embed(name, submod)) + return ListIteratorVariable(result, mutation_type=ValueMutationNew()) + elif name == "__len__": + assert not (args or kwargs) + return ConstantVariable.create(len(module)) + elif ( + name == "__contains__" + and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict)) + and args + and args[0].is_python_constant() + ): + return ConstantVariable.create( + args[0].as_python_constant() in module._modules + ) + elif name == "__getitem__": + assert not kwargs and len(args) == 1 + builtin_supported = ( + torch.nn.ModuleDict.__getitem__, + torch.nn.ModuleList.__getitem__, + torch.nn.ParameterDict.__getitem__, + torch.nn.ParameterList.__getitem__, + torch.nn.Sequential.__getitem__, + ) + + if type(module).__getitem__ not in builtin_supported: + assert isinstance(args[0], variables.ConstantVariable), typestr(args[0]) + key = args[0].as_python_constant() + assert isinstance(key, (str, int)) + fn = getattr(module, name).__func__ + + assert isinstance(fn, types.FunctionType) + + src = AttrSource(AttrSource(self.source, name), "__func__") + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=src), + [self] + list(args), + kwargs, + ) + + assert self.source + + if isinstance(args[0], SliceVariable): + # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is + # enabled for export. + if tx.output.export: + # Build a TupleVariable of NNModules + result = [] + + # Turn the slice into the list of integers + keys = list(range(len(module)))[args[0].as_python_constant()] + for idx, submod in enumerate(module[args[0].as_python_constant()]): + key = keys[idx] + src = NNModuleSource(GetItemSource(self.source, key)) + result.append( + tx.output.register_attr_or_module( + submod, + key, + source=src, + ) + ) + + new_module = module[args[0].as_python_constant()] + new_module_variable = tx.output.register_attr_or_module( + new_module, + f"{self}.__getitem__(slice)", + source=NNModuleSource( + GetItemSource(self.source, args[0].as_python_constant()) + ), + ) + return new_module_variable + else: + # slice on nn module results in a creation of new module instance, so we need to make it sourceless. + # Convert to unspecialized so that UnspecializedNNModule variable can take care of it. 
+ self.convert_to_unspecialized(tx) + + from .tensor import SymNodeVariable + + if isinstance(args[0], SymNodeVariable): + key = args[0].evaluate_expr(tx.output) + elif args[0].is_python_constant(): + key = args[0].as_python_constant() + else: + unimplemented_v2( + gb_type="Unsupported key type for nn.Module.__getitem__", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation="Dynamo does not support getitem on " + "`nn.Module` with non-constant key.", + hints=[], + ) + + submod = module[key] + return tx.output.register_attr_or_module( + submod, + self.module_key, + key, + source=NNModuleSource(GetItemSource(self.source, key)), + ) + elif ( + name == "_get_abs_string_index" + or ( + isinstance(module, torch.nn.modules.conv._ConvNd) + and name == "_conv_forward" + ) + or ( + isinstance(module, torch.nn.modules.conv._ConvTransposeNd) + and name == "_output_padding" + ) + ): + # Inline the function + fn = getattr(module, name).__func__ + fn_source = AttrSource(AttrSource(self.source, name), "__func__") + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=fn_source), + [self] + args, + kwargs, + ) + # A loose heuristic, but seems to be generally good before we drop into the + # manual handling of inputs + elif ( + name in module.__class__.__dict__ + and callable(module.__class__.__dict__[name]) + and all( + isinstance(x, variables.TensorVariable) + for x in itertools.chain(args, kwargs.values()) + ) + ): + return generic_call_method_helper(name) + else: + return super().call_method(tx, name, args, kwargs) + + +class UnspecializedNNModuleVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "value_type", + "is_state_mutated", + "nn_module_stack_source", + *UserDefinedObjectVariable._nonvar_fields, + } + + """ + The above class will specialize on the id() of a module and place + parameters on the torch.fx.GraphModule. Giving one graph per + module instance. This version treats nn.Modules() like other user + defined objects and will pass parameters into the FX graph as inputs. + Giving one graph per module class. + """ + + def __init__(self, value, **kwargs) -> None: + if type(value) is torch.jit._script.RecursiveScriptModule: + raise Unsupported( + "ScriptModules aren't supported in UnspecializedNNModuleVariable" + " because their .forward function isn't a static member of their type" + ) + if "value_type" in kwargs: + lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) + if type(value) is lazy_value_to_become: + # We may have cloned a variabletracker for a LazyModule earlier (e.g. tracking side-effects) + # and then later we called and mutated the LazyModule into a MaterializedModule. + # We do not do the mutation upon first seeing a LazyModule since we preserve eager semantics to only + # mutate upon first call, but this requires we update multiple copies of the VariableTracker post-mutation. + kwargs["value_type"] = type(value) + + super().__init__(value=value, **kwargs) + self.is_state_mutated = False + # nn_module_stack_source is used to ensure BC for nn_module_stack. + # Downstream users prefer mod.linear instead of mod._modules['linear'] + # as the module stack. When Dynamo inlines the __getattr__ method, we + # cannot use self.source for nn_module_stack because it will be similar + # to mod._modules['linear']. In these cases, we set the + # nn_module_stack_source appropriately to resemble mod.linear. 
+ self.nn_module_stack_source = self.source + + def _wrap_source(self, attr_source): + # the vt is already wrapped with UnspecializedNNModuleSource + return attr_source + + def get_nn_module_stack_source(self): + return self.nn_module_stack_source or self.source + + def set_nn_module_stack_source(self, source): + self.nn_module_stack_source = source + + @staticmethod + @functools.cache + def _nn_module_method_ids(): + # Allow __setattr__ to fall through to base class handler + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} + return { + id(x.__code__) + for x in torch.nn.Module.__dict__.values() + if hasattr(x, "__code__") and x not in supported + } + + def unpack_var_sequence(self, tx): + try: + fn = inspect.getattr_static(self.value_type, "__iter__") + except AttributeError as e: + raise NotImplementedError from e + + if fn in ( + torch.nn.ModuleList.__iter__, + torch.nn.ParameterList.__iter__, + torch.nn.Sequential.__iter__, + ): + # The program can mutate the nn module object but the saved `value` + # will not reflect the mutations. So, trace through the `__iter__` + # function to reflect any tracked mutations. + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn), + [ + self, + ], + {}, + ).unpack_var_sequence(tx) + + return super().unpack_var_sequence(tx) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + mod = self.value + # see comment on lazy module handling in NNModuleVariable.call_function for context + if is_lazy_module(mod): + if mod.cls_to_become is not None: + self.value_type = mod.cls_to_become + initialize_lazy_module(tx, mod, args, kwargs) + + if ( + not isinstance(mod, torch.fx.GraphModule) + and mod.__call__.__func__ is not unpatched_nn_module_call + ): + name = "__call__" + fn = getattr(self.value_type, name) + else: + name = "_call_impl" + fn = getattr(self.value_type, name) + + # Check if we can short circuit nn.Module._call_impl to the forward + # method. NB - This is done to reduce the compile time of Dynamo. 
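+ # The short circuit below is taken only when __call__/_call_impl are the stock nn.Module implementations, forward is not shadowed in the instance __dict__, and every instance-level and global hook dict is empty, so calling forward directly is equivalent to going through _call_impl.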
+ if ( + istype(mod.__call__, types.MethodType) + and istype(mod._call_impl, types.MethodType) + and mod.__call__.__func__ is unpatched_nn_module_call + and mod._call_impl.__func__ is unpatched_nn_module_call_impl + and "forward" not in mod.__dict__ + ): + forward_method = inspect.getattr_static(mod, "forward") + if isinstance(forward_method, types.FunctionType): + globals_vt = tx.nn_modules_globals_vt + if not ( + self.var_getattr(tx, "_backward_hooks").realize().len() + or self.var_getattr(tx, "_backward_pre_hooks").realize().len() + or self.var_getattr(tx, "_forward_hooks").realize().len() + or self.var_getattr(tx, "_forward_pre_hooks").realize().len() + or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() + or globals_vt.var_getattr(tx, "_global_backward_hooks").len() + or globals_vt.var_getattr(tx, "_global_forward_hooks").len() + or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() + ): + name = "forward" + fn = self.value_type.forward + + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + else: + source = None + + guard_to_detect_forward_monkeypatching(self.source, mod) + + ctx = ( + record_nn_module_stack( + str(id(mod)), self.get_nn_module_stack_source(), tx, mod + ) + if self.source + else nullcontext() + ) + with ctx: + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if name in ["_call_impl", "_wrapped_call_impl"]: + fn = getattr(self.value_type, name) + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + else: + source = None + + return variables.UserFunctionVariable(fn, source=source).call_function( + tx, [self] + list(args), kwargs + ) + + if name not in getattr(self.value, "__dict__", {}): + try: + method = inspect.getattr_static(type(self.value), name) + except AttributeError: + method = None + + if isinstance(method, staticmethod): + source = AttrSource( + AttrSource(AttrSource(self.source, "__class__"), name), "__func__" + ) + return tx.inline_user_function_return( + variables.UserFunctionVariable(method.__func__, source=source), + args, + kwargs, + ) + + if ( + hasattr(method, "__code__") + and id(method.__code__) in self._nn_module_method_ids() + ): + unimplemented_v2( + gb_type="UnspecializedNNModuleVariable missing method", + context=f"call_method: {self} {name} {args} {kwargs}", + explanation=f"Dynamo does not support tracing method {name} of nn.Module {self.value}", + hints=[ + "Dynamo does not really define unspecialized nn.Module very well.", + *graph_break_hints.DIFFICULT, + ], + ) + + # "_parameters" in self.value.__dict__ checks that module is initialized + if name == "__setattr__" and "_parameters" in self.value.__dict__: + # Record if mutations happens on parameters/buffers/modules. The + # mutations on these are not tracked by base class + # UserDefinedObject vt. This will be used later to graph break + # on seeing a parameters() and family calls. + # TODO(anijain2305) - This might not be needed if we let Dynamo + # inline both getattr and setattr. In that case, it should see + # the lowest level dicts - _parameters and family and + # automatically track mutations on those. Investigate if that + # can be done. + attr_name = args[0].as_python_constant() + value = args[1] + + # This is reverse engineered by looking at nn module __setattr__ + # logic. 
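+ # Mirror the branches of nn.Module.__setattr__: flag a state mutation when the assigned value or the existing attribute is a parameter, a buffer, or a submodule.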
+ if ( + isinstance(value, variables.TensorVariable) + and value.python_type() is torch.nn.Parameter + ) or attr_name in self.value.__dict__["_parameters"]: + # Handle parameters + self.is_state_mutated = True + elif attr_name in self.value.__dict__["_buffers"]: + # Handle buffers + self.is_state_mutated = True + elif ( + isinstance( + value, + ( + variables.NNModuleVariable, + variables.UnspecializedNNModuleVariable, + ), + ) + or attr_name in self.value.__dict__["_modules"] + ): + # Handle submodules + self.is_state_mutated = True + + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): + # Trace through __delattr__ to track mutations on the module + # members like `_modules``. + return tx.inline_user_function_return( + variables.UserFunctionVariable(torch.nn.Module.__delattr__), + [self, args[0]], + kwargs, + ) + + return super().call_method(tx, name, args, kwargs) + + def getattr_helper(self, tx: "InstructionTranslator", field, name_vt): + dict_vt = self.var_getattr(tx, field) + if isinstance(dict_vt, variables.ConstDictVariable): + return dict_vt.maybe_getitem_const(name_vt) + return None + + def var_getattr(self, tx: "InstructionTranslator", name): + # Allow skipping of empty hook dict guards on inbuilt nn modules + if name in ( + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_pre_hooks", + ): + # For empty hooks, make an EMPTY_NN_MODULE_HOOKS_DICT. This allows us to control the installation of empty + # hooks guard via skip_nnmodule_hook_guards + if not tx.output.side_effects.has_pending_mutation_of_attr(self, name): + hooks_dict = getattr(self.value, name) + if isinstance(hooks_dict, dict) and len(hooks_dict) == 0: + if self.source: + hooks_source = AttrSource(self.source, name) + install_guard( + hooks_source.make_guard( + GuardBuilder.EMPTY_NN_MODULE_HOOKS_DICT + ) + ) + return variables.ConstDictVariable({}) + + # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable. + # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for + # different nn module instances, because the key keeps changing (look more into RemovableHandle to understand why + # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a + # NNModuleHooksDictVariable (a subclass of ConstDictVariable) to avoid any guard on the keys. + if ( + self.source + and name + in ( + "_forward_pre_hooks", + "_forward_hooks", + ) + and not tx.output.side_effects.has_pending_mutation_of_attr(self, name) + ): + hooks_dict = getattr(self.value, name) + hooks_dict_source = AttrSource(self.source, name) + install_guard(hooks_dict_source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + tx.output.guard_on_key_order.add(hooks_dict_source) + + def build_key_value(i, k, v): + # Make key sourceless to avoid any guard on it + key = variables.ConstantVariable.create(k) + + # Instead of using dict[key] to access the value, use a dict[dict.keys()[index]] to access the + # value. This removes the reliance on the actual key value. 
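+ # ConstDictKeySource refers to the i-th key by position and DictGetItemSource reads the value through it, so no guard depends on the RemovableHandle id that serves as the actual dict key.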
+ source_key = ConstDictKeySource(hooks_dict_source, i) + source_value = DictGetItemSource(hooks_dict_source, source_key) + value = LazyVariableTracker.create(v, source_value) + return key, value + + result = dict( + build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items()) + ) + + return variables.NNModuleHooksDictVariable( + result, type(hooks_dict), source=hooks_dict_source + ) + return super().var_getattr(tx, name) + + def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name): + """ + Dynamo tracing of nn.Module __getattr__ can be expensive if the model + has deep submodule hierarchy. Since the __getattr__ is stable, we can + directly look into the underlying datastructures. This saves a lot of + compilation time. + """ + name_vt = variables.ConstantVariable(name) + out = self.getattr_helper(tx, "_parameters", name_vt) + if out is None: + out = self.getattr_helper(tx, "_modules", name_vt) + if out is None: + out = self.getattr_helper(tx, "_buffers", name_vt) + if out is None: + raise_observed_exception(AttributeError, tx) + return out + + +class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable): + """ + Differentiates between builtin nn modules (e.g. torch.nn.Linear) and user defined nn modules. + """ + + def _wrap_source(self, attr_source): + # vt is already wrapped with the UnspecializedBuiltinNNModuleSource + return attr_source + + +class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable): + """ + Tracing behavior: trace into submodules and treat them as Unspecialized, do not + register parameters to the top-level, treat them as function inputs. + + Guards behavior: if 'skip_fsdp_guards', many guards that would be installed + by a vanilla UnspecializedNNModuleVariable are simply dropped, on the basis + that a user wrapping their model in FSDP(model) is already opting into a + requirement to not modify internal model state, which would already break FSDP without + compilation. + """ + + def __init__(self, value, **kwargs) -> None: + source = kwargs.get("source", None) + assert source is not None, ( + "FSDPManagedNNModule depends on having an accurate source to control guarding." + ) + + super().__init__(value=value, **kwargs) + self.source = source + + def _wrap_source(self, attr_source): + if not isinstance( + attr_source, (FSDPNNModuleSource, UnspecializedNNModuleSource) + ): + if torch._dynamo.config.skip_fsdp_guards: + return FSDPNNModuleSource(attr_source) + else: + return UnspecializedNNModuleSource(attr_source) + return attr_source diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/optimizer.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..21edb72d28a4647aeb386370aa41bd3772d4ff5f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/optimizer.py @@ -0,0 +1,408 @@ +# mypy: ignore-errors + +""" +This module implements variable tracking for PyTorch optimizers during Dynamo tracing. 
+ +The OptimizerVariable class provides specialized handling for optimizer instances by: +- Optimizing the tracing of expensive optimizer initialization +- Managing optimizer state and parameter group tracking +- Handling tensor sources and guards for optimizer state tensors +- Supporting CUDA graph execution through static tensor address management +- Providing special handling for parameter gradients and optimizer state tensors + +Key features include: +- Efficient initialization tracing via _init_group optimization +- Automatic marking of optimizer state tensors as static for CUDA graphs +- Proper source tracking for parameter groups, gradients, and state tensors +- Guard installation for optimizer state structure +- Support for both CPU and GPU tensor handling +- Cleanup of static tensor references via finalizers + +The module integrates with Dynamo's broader tracing system while providing +optimizer-specific optimizations and safety guarantees. +""" + +import logging +import weakref +from typing import TYPE_CHECKING + +import torch +from torch._logging import getArtifactLogger +from torch.utils._pytree import tree_map_only + +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + ConstDictKeySource, + DictGetItemSource, + GetItemSource, + GlobalWeakRefSource, + GradSource, +) +from ..utils import GLOBAL_KEY_PREFIX +from .base import VariableTracker +from .constant import ConstantVariable +from .dicts import ConstDictVariable +from .lists import ListVariable +from .misc import GetAttrVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +class ArgMappingException(Exception): + pass + + +class GuardInstallException(Exception): + pass + + +perf_hint_log = getArtifactLogger(__name__, "perf_hints") + + +def _is_static_for_cudagraphs(x): + from torch._inductor.cudagraph_trees import get_manager + + if x.is_cuda: + manager = get_manager(x.device.index, False) + is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None + if manager: + return ( + is_static_address + or manager.current_node._is_cuda_graph_recorded_tensor(x) + ) + else: + return is_static_address + else: + # Don't print a warning for non-cuda tensors + return True + + +class OptimizerVariable(UserDefinedObjectVariable): + _nonvar_fields = { + "grad_to_source", + "tensor_to_source", + "static_tensor_names", + *UserDefinedObjectVariable._nonvar_fields, + } + + def __init__( + self, + value, + grad_to_source=None, + static_tensor_names=None, + tensor_to_source=None, + **kwargs, + ) -> None: + super().__init__(value, **kwargs) + self.grad_to_source = grad_to_source or {} + self.tensor_to_source = tensor_to_source or {} + self.static_tensor_names = static_tensor_names or set() + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + """This is an optimization to avoid tracing the very slow initialization of the optimizer""" + if name == "_init_group": + try: + self.graph_break_if_pending_mutation(tx) + self.move_step_if_cpu() + py_args, py_kwargs = self.get_python_args(*args, **kwargs) + ret_val = self.value._init_group(*py_args, **py_kwargs) + self.map_sources_and_install_guards(tx) + self.update_list_args(tx, args, kwargs, py_args, py_kwargs) + # stash a weak_ptr to optimizer to invalidate code + # if the optimizer object dies + mangled_name = f"__optimizer_{id(self.value)}" + 
tx.store_global_weakref_by_id(mangled_name, self.value) + self.create_finalizer(tx) + + # This is currently safe only because the only actual `ret_val`s returned + # by the `_init_group` of existing optimizers are properties that are invariant + # to the input tensors (e.g. dtype, layout). Changing these would trigger a + # recompilation and hence never result in the wrong specialization of `ret_val`. + return ConstantVariable.create(ret_val) + except (ArgMappingException, GuardInstallException) as _: + # trace normally if we can't map args or install guards correctly + pass + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + # Note: this allows us to intercept the call in call_method + # in the typical case, we return a UserMethodVariable + # which will directly inline + if name in ("_init_group", "step"): + return GetAttrVariable(self, name, source=AttrSource(self.source, name)) + + if name == "param_groups": + from ..decorators import mark_static_address + + for group in self.value.param_groups: + for p in group["params"]: + mark_static_address(p) + + self._set_capturable(tx) + + return super().var_getattr(tx, name) + + def graph_break_if_pending_mutation(self, tx): + # If there are pending mutations on a parameter (due to using closure) + # then we need to graph break to allow the python version of the parameter + # to update, so that running _init_group will initialize the states with + # the correct values + for g in self.value.param_groups: + for p in g["params"]: + side_effects = tx.output.side_effects + variable = side_effects.id_to_variable.get(id(p), None) + if variable and side_effects.has_pending_mutation(variable): + from ..exc import Unsupported + + raise Unsupported("Pending mutation on parameter") + + def _set_capturable(self, tx): + from . 
import LazyVariableTracker + + # We only set capturable if params are on cuda + # and the state is not initialized + def safe_to_set_capturable(group): + all_uninitialized = True + all_gpu = True + + for p in group.get("params", []): + all_gpu &= p.is_cuda or p.is_xpu + all_uninitialized &= p not in self.value.state + + return "capturable" in group and all_uninitialized and all_gpu + + # track indices to not set so we don't need to + # in the variable tracker realize the whole state + # we handle guarding the state specially + for group in self.value.param_groups: + if safe_to_set_capturable(group): + group["capturable"] = True + + source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, source) + ) + for param_group_vt in param_groups_vt.items: + key = ConstDictVariable._HashableTracker( + ConstantVariable.create("capturable") + ) + param_group_vt.items[key] = ConstantVariable.create(True) + + def get_python_args(self, *args, **kwargs): + """Get python values equivalent to the variable tracker args""" + + def map_arg(arg): + if isinstance(arg, ConstantVariable): + return arg.as_python_constant() + elif isinstance(arg, ListVariable) and not arg.items: + return [] + elif ( + isinstance(arg, ConstDictVariable) + and isinstance(arg.source, GetItemSource) + and isinstance(arg.source.base, AttrSource) + and arg.source.base.member == "param_groups" + ): + return self.value.param_groups[arg.source.index] + + raise ArgMappingException + + new_args = [map_arg(arg) for arg in args] + new_kwargs = {k: map_arg(v) for k, v in kwargs.items()} + + return new_args, new_kwargs + + # If users load an old state dictionary, + # it's possible that step could be on the cpu + # if this is the case, move it to the GPU + # corresponding to the parameter + # in most cases this is a no-op because the state is empty + def move_step_if_cpu(self): + for p, state in self.value.state.items(): + if "step" in state and state["step"].is_cpu: + state["step"] = state["step"].to(p.device) + + def map_sources_and_install_guards(self, tx): + from ..decorators import mark_static_address + from .lazy import LazyVariableTracker + + self.grad_to_source = {} + self.tensor_to_source = {} + + # Tracing the _init_group is expensive. But we still have to insert the + # necessary guards for _init_group. So, we manually handle insertion of + # guards. We also want to mark all the tensors inside the state dict to + # be static address. + + # Mark all the tensors in the state dict to be static address. This has + # to be done first because the variable builder relies on the static + # address annotation. + def mark_static(x): + mark_static_address(x) + + tree_map_only(torch.Tensor, mark_static, self.value.state) + + # Recursively realize the variable trackers for optim.state and + # optim.param_groups, which recursively install the necessary guards. 
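+ # Realizing these lazy variable trackers is what installs the guards that tracing _init_group would otherwise have produced.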
+ params_groups_source = self.source and AttrSource(self.source, "param_groups") + param_groups_vt = LazyVariableTracker.realize_all( + VariableTracker.build(tx, self.value.param_groups, params_groups_source) + ) + + state_source = self.source and AttrSource(self.source, "state") + + state_vt = VariableTracker.build(tx, self.value.state, state_source) + + # We need to realize the top level state dict to populate + # the guard locals + state_vt.realize() + tx.output.guard_on_key_order.add(state_source) + + # Populate self.grad_to_source and self.tensor_to_source so that we can + # manually update_list_args + for group, group_vt in zip(self.value.param_groups, param_groups_vt.items): + # we assume here that all params within a param group + # are initialized similarly + if len(group["params"]) > 0: + for param in group["params"]: + if param.grad is not None: + key_index = None + for i, k in enumerate(self.value.state.keys()): + if k is param: + key_index = i + break + if key_index: + LazyVariableTracker.realize_all( + VariableTracker.build( + tx, + self.value.state[param], + DictGetItemSource( + state_source, + ConstDictKeySource(state_source, key_index), + ), + ) + ) + break + + params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params")) + all_static = True + non_static_grads = [] + for p_ind, (p, p_vt) in enumerate( + zip(group["params"], params_vt.unpack_var_sequence(tx)) + ): + param_source = p_vt.source + self.tensor_to_source[p] = param_source + grad_source = GradSource( + param_source, + "grad", + ) + + if p.grad is not None: + self.grad_to_source[p.grad] = grad_source + if not _is_static_for_cudagraphs(p.grad): + all_static = False + non_static_grads.append(grad_source) + else: + install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + + # Note: to avoid spam logs only warn if perf hint artifact is enabled + # (NB: artifacts are only enabled at the debug or warning level) + if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): + non_static_grads = [src.name() for src in non_static_grads] + perf_hint_log.warning( + ( + "Grad tensors %s will be copied during cudagraphs execution." + "If using cudagraphs and the grad tensor addresses will be the same across runs," + " use torch._dynamo.decorators.mark_static_address to elide this copy.", + ), + non_static_grads, + ) + + # We have to again iterate over the state dict to collect the + # tensor_to_source dict. This is used for the finalizer. 
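+ # Build tensor_to_source entries for every tensor held in optimizer state; wrap_tensor reuses these sources later and marks the tensors as static addresses for cudagraphs.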
+ for idx, (p, value) in enumerate(self.value.state.items()): + p_state_source = DictGetItemSource( + state_source, ConstDictKeySource(state_source, idx) + ) + tx.output.guard_on_key_order.add(p_state_source) + for inner_idx, (k, v) in enumerate(value.items()): + if ( + isinstance(v, torch.Tensor) + and v not in self.grad_to_source + and v not in self.tensor_to_source + ): + self.tensor_to_source[v] = DictGetItemSource( + p_state_source, ConstDictKeySource(p_state_source, inner_idx) + ) + + def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): + """Wrap state tensor in a TensorVariable""" + from ..decorators import mark_static_address + + # If we have a source for a tensor already use it, + # if we have not seen a tensor before, stash and use a + # global weak ref source, since it must be an optimizer tensor + # that we have missed + + if tensor_value in self.tensor_to_source: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value) + source = self.tensor_to_source[tensor_value] + self.static_tensor_names.add(tx.output.module_key_name(source.name())) + elif tensor_value in self.grad_to_source: + source = self.grad_to_source[tensor_value] + else: + # mark these tensors as static for cudagraphs + mark_static_address(tensor_value) + + global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) + source = GlobalWeakRefSource(global_name) + self.static_tensor_names.add(tx.output.module_key_name(source.name())) + + return VariableTracker.build(tx, tensor_value, source) + + def update_list_args( + self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs + ): + """Update the args and kwargs to the traced optimizer call""" + for arg, py_arg in zip(args, py_args): + if isinstance(arg, ListVariable): + assert isinstance(py_arg, list), ( + "py_arg should be a list in optimizer variable" + ) + for i, val in enumerate(py_arg): + tx.output.side_effects.mutation(arg) + if isinstance(val, torch.Tensor): + arg.items.append(self.wrap_tensor(tx, val)) + else: + source = arg.source and GetItemSource(arg.source, i) + arg.items.append(VariableTracker.build(tx, val, source)) + + def create_finalizer(self, tx): + names_to_delete = self.static_tensor_names + value = self.value + tc = tx.output.tracing_context + + def init_finalizer(gm): + def clear_static_tensor_refs(): + for name in names_to_delete: + gm._buffers.pop(name, None) + gm._parameters.pop(name, None) + if tc.params_flat: + tc.params_flat.clear() + if tc.params_flat_unwrap_subclasses: + tc.params_flat_unwrap_subclasses.clear() + + weakref.finalize(value, clear_static_tensor_refs) + + tx.output.add_graph_finalizer(init_finalizer) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/script_object.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/script_object.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6a616ae1eda912842002dcc5f57b54370d0d5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/script_object.py @@ -0,0 +1,103 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +""" +This module implements variable tracking for TorchScript objects during Dynamo tracing. 
+ +The TorchScriptObjectVariable class provides specialized handling for TorchScript +objects with strong safety guarantees by: +- Enforcing method-call-only access to prevent unsafe attribute manipulation +- Converting graph breaks into hard errors via _raise_hard_error_if_graph_break +- Proper proxy and source tracking for TorchScript method calls +- Integration with higher-order operators for method call handling + +Key safety features: +- Strict validation that only method calls are allowed (no direct attribute access) +- Immediate error reporting for potentially unsafe operations +- Proper source tracking for debugging and guard installation +- Safe handling of TorchScript object method calls through torchbind + +The module ensures that TorchScript objects are handled safely during tracing +by limiting operations to known-safe patterns and failing fast for unsafe usage. +""" + +import functools + +import torch + +from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported +from .base import VariableTracker +from .user_defined import UserDefinedObjectVariable + + +def _raise_hard_error_if_graph_break(reason): + def deco(fn): + @functools.wraps(fn) + def graph_break_as_hard_error(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Unsupported as e: + raise UnsafeScriptObjectError(e.msg) from e + + return graph_break_as_hard_error + + return deco + + +class TorchScriptObjectVariable(UserDefinedObjectVariable): + _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} + + @classmethod + def is_matching_cls(cls, user_cls: type): + return issubclass(user_cls, torch.ScriptObject) + + @staticmethod + def create(proxy, value, **options): + return TorchScriptObjectVariable(proxy, value, **options) + + def __init__(self, proxy, value, source, **kwargs) -> None: + super().__init__(value, **kwargs) + self.proxy = proxy + self.proxy.node.meta["example_value"] = value + self.source = source + + def as_proxy(self): + return self.proxy + + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." + ) + def var_getattr(self, tx, name: str) -> VariableTracker: + from torch._higher_order_ops.torchbind import call_torchbind + + from ..source import AttrSource + from .higher_order_ops import TorchHigherOrderOperatorVariable + + method = getattr(self.value, name, None) + if method is None: + unimplemented( + f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?" + ) + + if not callable(method): + unimplemented( + "Only method calls on TorchScript objects can be supported safely." + " Please use method calls instead of attribute access." + ) + + return TorchHigherOrderOperatorVariable.make( + call_torchbind, + source=AttrSource(self.source, name), + script_obj_var=self, + method_name=name, + ) + + # We only support method calls on script objects. Interpreting the bytecodes + # should go through var_getattr then call_function instead of call_method. + # + # However, it's possible for call_method to be used directly e.g. for __setattr__. + @_raise_hard_error_if_graph_break( + "Dynamo cannot safely trace script object due to graph break." 
+ ) + def call_method(self, tx, name, args, kwargs): + unimplemented(f"call method {name} on script object is not safe.") diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/sdpa.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9bf987634e4b2b111f5f06c0ad204952538810 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/sdpa.py @@ -0,0 +1,78 @@ +# mypy: ignore-errors + +from inspect import getattr_static +from typing import TYPE_CHECKING + +from ..bytecode_transformation import create_call_function +from ..exc import Unsupported +from ..source import AttrSource +from .base import VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + +PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split() + + +class SDPAParamsVariable(VariableTracker): + """Represents the c++ params struct for scaled dot product attention. + This is a read-only container.""" + + @staticmethod + def create(tx: "InstructionTranslator", value, source): + from torch.backends.cuda import SDPAParams + + from .torch import TorchInGraphFunctionVariable + + params = [ + VariableTracker.build(tx, getattr(value, p), AttrSource(source, p)) + for p in PARAM_NAMES + ] + return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) + + def __init__(self, proxy, param_vars, **kwargs) -> None: + self.proxy = proxy + self.param_vars = param_vars + super().__init__(**kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + assert self.source is None + assert self.param_vars is not None + codegen.add_push_null( + lambda: codegen.load_import_from("torch._C", "_SDPAParams") + ) + codegen.foreach(self.param_vars) + codegen.extend_output(create_call_function(len(self.param_vars), False)) + + def as_proxy(self): + return self.proxy + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + import torch._C + + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + try: + getattr_static(torch._C._SDPAParams, name) + except AttributeError: + # Using raise from is too verbose here + raise Unsupported( + f"Unsupported torch._C._SDPAParams attribute {name}" + ) from None + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @staticmethod + def is_sdpa_params(value): + from torch.backends.cuda import SDPAParams + + return value is SDPAParams diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/tensor.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b2a4ff604305b917b28f6cebb2cb4ab1b3bac5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/tensor.py @@ -0,0 +1,1745 @@ +# mypy: ignore-errors + +""" +This module contains variable tracker classes for handling tensors and tensor-related operations in Dynamo. + +The main class is TensorVariable which represents torch.Tensor inputs and intermediate values in the FX graph. +It handles tensor operations, method calls, and maintains metadata about tensor properties like dtype, device, etc. 
+ +Other key classes include: +- SymNodeVariable: Represents symbolic scalars (int/float/bool) used for size computation and unspecialized values +- NumpyNdarrayVariable: Handles numpy array interop through torch._numpy +- UnspecializedPythonVariable: Represents unspecialized Python numeric values as 1-element tensors +- TensorSubclassVariable: Handles tensor subclasses with __torch_function__ overrides +- UntypedStorageVariable: Represents tensor storage objects +- DataPtrVariable: Handles tensor data pointer operations + +These classes work together to track tensor operations and properties during Dynamo's tracing process. +""" + +import functools +import logging +import operator +import textwrap +import traceback +import types +import unittest +from typing import TYPE_CHECKING + +import sympy + +import torch._numpy as tnp +import torch.fx +import torch.random +from torch._dynamo import compiled_autograd +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import ( + guard_scalar, + GuardOnDataDependentSymNode, + has_free_symbols, + is_symbolic, + SymTypes, +) +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config, graph_break_hints, variables +from .._trace_wrapped_higher_order_op import trace_wrapped +from ..exc import ( + unimplemented_v2, + UnknownPropertiesDuringBackwardTrace, + UserError, + UserErrorType, +) +from ..external_utils import call_hook_from_backward_state +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource +from ..utils import ( + fqn, + get_custom_getattr, + get_fake_value, + get_real_value, + guard_if_dyn, + object_has_getattribute, + product, + proxy_args_kwargs, + set_example_value, + tensortype_to_dtype, +) +from .base import AttributeMutationNew, VariableTracker +from .constant import ConstantVariable +from .lists import SizeVariable +from .user_defined import UserDefinedClassVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) + +# Ops that allow tensor tensor +supported_tensor_comparison_ops = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + "is": operator.is_, + "is not": operator.is_not, +} +# Ops that allow tensor None +supported_const_comparison_ops = { + "is": operator.is_, + "is not": operator.is_not, + "==": operator.eq, + "!=": operator.ne, +} +supported_comparison_ops = { + **supported_tensor_comparison_ops, + **supported_const_comparison_ops, +} +supported_tensor_comparison_op_values = dict.fromkeys( + supported_tensor_comparison_ops.values() +) +supported_const_comparison_op_values = dict.fromkeys( + supported_const_comparison_ops.values() +) + + +def is_bound_tensor_method(value): + return ( + callable(value) + and not torch._dynamo.utils.object_has_getattribute(value) + and hasattr(value, "__self__") + and isinstance(value.__self__, torch.Tensor) + and getattr(value.__self__, value.__name__, None) + ) + + +# instead of using inspect.getattr_static, we directly lookup the appropriate +# dicts. It is necessary to keep the torch._C.TensorBase first in the or +# operation, because the second arg takes priority in or operation when there +# are common keys. 
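+# (dict union keeps the right-hand value for duplicate keys, e.g. ({"a": 1} | {"a": 2})["a"] == 2, so Python-level torch.Tensor entries override the C-level TensorBase ones)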
+all_tensor_attrs = torch._C.TensorBase.__dict__ | torch.Tensor.__dict__ + + +class TensorVariable(VariableTracker): + """A torch.Tensor input or an intermediate value in the FX graph""" + + _nonvar_fields = { + "proxy", + "dtype", + "device", + "layout", + "ndim", + "size", + "stride", + "requires_grad", + "is_quantized", + "is_contiguous", + "is_nested", + "is_sparse", + "class_type", + "specialized_value", + "_is_name_set", + *VariableTracker._nonvar_fields, + } + + def get_real_value(self): + """ + Get the actual value represented by this variable if computation is run + using the user-provided inputs. + NOTE: this runs actual tensor computation and may be + slow and memory-intensive. + """ + return get_real_value(self.proxy.node, self.proxy.tracer) + + def __init__( + self, + proxy: torch.fx.Proxy, + *, + dtype, + device, + layout, + ndim, + requires_grad, + is_nested, + is_quantized, + is_sparse, + class_type, + has_grad_fn, + _size=None, + stride=None, + is_contiguous=None, + _is_name_set=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.proxy = proxy + self.dtype = dtype + self.device = device + self.layout = layout + self.ndim = ndim + self._size = _size # this is accessed as a property for validation + self.stride = stride + self.requires_grad = requires_grad + self.is_quantized = is_quantized + self.is_contiguous = is_contiguous + self.is_nested = is_nested + self.is_sparse = is_sparse + self.class_type = class_type + self.has_grad_fn = has_grad_fn + if _is_name_set is None: + # no need to rename inputs + _is_name_set = self.proxy.node.op == "placeholder" + self._is_name_set: bool = _is_name_set + + def debug_repr(self): + # TODO: strip off fake tensor from repr here + return repr(self.proxy.node.meta["example_value"]) + + def as_proxy(self): + return self.proxy + + def python_type(self): + return self.class_type + + @staticmethod + def specialize(value: torch.Tensor): + props = { + "dtype": value.dtype, + "device": value.device, + "layout": value.layout, + "ndim": int(value.ndim), + "requires_grad": value.requires_grad, + "is_nested": value.is_nested, + "is_quantized": value.is_quantized, + "is_sparse": value.is_sparse, + "class_type": type(value), + } + try: + props["has_grad_fn"] = value.grad_fn is not None + except Exception: + # Workaround for issues with create_parameter_op in Dynamo. Reading + # grad_fn should never cause an issue. + props["has_grad_fn"] = False + + if is_sparse_any(value) and not has_free_symbols(value): + props["_size"] = tuple( + [int(s) if is_symbolic(s) else s for s in value.size()] + ) + elif not has_free_symbols(value): + # this is a fully static shape, and the keys on props here inform specialization. + # We have to cast to int here, because these might get accessed as ConstantVariable, which has + # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant + # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for + # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and + # I'd like to keep it around for now. 
+ props["_size"] = tuple( + # the non is_symbolic case applies to the jagged layout + # NestedTensor case as singleton ints are not symbolic + [int(s) if is_symbolic(s) else s for s in value.size()] + ) + props["stride"] = tuple(value.stride()) + if torch._C._functorch.is_batchedtensor(value): + # Batched tensors does not support contiguity patterns, so + # we refrain from computing the `is_contiguous` property + props["is_contiguous"] = None + else: + props["is_contiguous"] = tuple( + [ + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ] + ) + return props + + def dynamic_getattr(self, tx: "InstructionTranslator", name): + fake_val = self.proxy.node.meta["example_value"] + # For getattrs on tensors without sources, + # we can do better than the default (creating a GetAttrVariable) + # if: + # (1) the tensor is a traceable tensor subclass + # (2) We are getattr'ing an inner tensor from that subclass + if not self.source and is_traceable_wrapper_subclass(fake_val): + attrs, _ctx = fake_val.__tensor_flatten__() + proxy = getattr(self.as_proxy(), name) + example_value = getattr(fake_val, name) + if name in attrs: + # attrs returned from tensor_flatten are always tensors + assert isinstance(example_value, torch.Tensor) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value) + # any other attributes on the subclass (that are not methods) + # are assumed to be constant metadata. + elif not callable(example_value): + return VariableTracker.build(tx, example_value) + + if not (self.source and self.source.subguards_allowed()): + raise NotImplementedError + + # For local source, we associate the real value. We use this real value + # for implementing getattr fallthrough on the variable tracker base class. + + # Note - this scope construction is mirrored in guards + # A subsequent PR will introduce a util. + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + # We raise in case we get a typerror bug w/ SuperSource. + # SuperSource has bugs in it atm, and can produce code like + # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, + # L['mod'].model.model.encoder.embed_positions)", scope) + # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. 
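+ # self.source.name() is a guard expression such as "L['mod'].weight" (illustrative); eval-ing it against the L/G scope recovers the real object this variable was built from.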
+ _input_associated_real_value = eval(self.source.name(), scope) + except Exception as exc: + raise NotImplementedError from exc + + if _input_associated_real_value is None: + raise NotImplementedError + + if object_has_getattribute(_input_associated_real_value): + raise NotImplementedError + + if get_custom_getattr(_input_associated_real_value): + raise NotImplementedError + + real_value = getattr(_input_associated_real_value, name) + + attr_source = AttrSource(self.source, name) + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) + + # Typically we'd want to use variable builder here + # but unfortunately id(real_value.__self__) is not id() + if is_bound_tensor_method(real_value): + from .misc import GetAttrVariable + + return GetAttrVariable( + self, name, source=attr_source, py_type=type(real_value) + ) + + return VariableTracker.build(tx, real_value, attr_source) + + def method_attr_ndim(self, tx): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + else: + return self.call_method(tx, "dim", [], {}) + + def method_attr_dtype(self, tx): + if self.dtype is not None: + return ConstantVariable.create(self.dtype) + + def method_attr_device(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device) + + def method_attr_layout(self, tx): + if self.layout is not None: + return ConstantVariable.create(self.layout) + + def method_attr_is_cuda(self, tx): + if self.device is not None: + return ConstantVariable.create(self.device.type == "cuda") + + def method_attr_shape(self, tx): + if self.valid_size(): + sizes = [variables.ConstantVariable.create(x) for x in self.size] + return SizeVariable(sizes) + else: + return self.call_method(tx, "size", [], {}) + + def method_attr_requires_grad(self, tx): + if self.requires_grad is not None: + return ConstantVariable.create(self.requires_grad) + + def method_attr_is_quantized(self, tx): + if self.is_quantized is not None: + return ConstantVariable.create(self.is_quantized) + + def method_attr_is_sparse(self, tx): + if self.is_sparse is not None: + return ConstantVariable.create(self.is_sparse) + + def method_attr_is_nested(self, tx): + if self.is_nested is not None: + return ConstantVariable.create(self.is_nested) + + def method_attr_retain_grad(self, tx): + unimplemented_v2( + gb_type="Tensor.retain_grad() with AOTDispatcher", + context=f"var_getattr {self} retain_grad", + explanation="`Tensor.retain_grad()` does not work with AOTDispatcher.", + hints=[], + ) + + def method_attr_data(self, tx): + return variables.TorchInGraphFunctionVariable( + torch._C._autograd._get_data_attr + ).call_function(tx, [self], {}) + + def method_attr_grad_fn(self, tx): + if self.has_grad_fn: + unimplemented_v2( + gb_type="Tensor with grad_fn()", + context=f"var_getattr {self} grad_fn", + explanation="Dynamo does not support tracing tensors with a grad_fn directly.", + hints=[], + ) + else: + return variables.ConstantVariable(None) + + def method_attr__version(self, tx): + from ..tensor_version_op import _tensor_version + + return variables.TorchInGraphFunctionVariable(_tensor_version).call_function( + tx, [self], {} + ) + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + from . import GetAttrVariable + from .builtin import BuiltinVariable + + # TODO - This is not a good solution but solves an accuracy issue. + # Today, var_getattr returns GetAttrVariable for both non-existent + # attributes and existing attributes. This is a bug and requires more + # deep dive. 
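+ # size and stride always exist on a tensor, so answer hasattr directly instead of routing through the getattr fallback below.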
+ if name in ("size", "stride"): + return ConstantVariable(True) + + try: + var = BuiltinVariable(getattr).call_function( + tx, [self, ConstantVariable(name)], {} + ) + # in the event that TensorVariable returns NotImplemented + # BuiltinVariable.call_getattr returns GetAttrVariable + ret_val = not isinstance(var, GetAttrVariable) + except AttributeError: + ret_val = False + + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + return ConstantVariable(ret_val) + + def var_getattr(self, tx: "InstructionTranslator", name): + if self.is_strict_mode(tx): + if name in self._strict_mode_banned_ops(): + unimplemented_v2( + gb_type="Strict mode banned op", + context=f"var_getattr {self} {name}", + explanation=f"Getattr invocation '{name}' in strict mode is not supported.", + hints=[ + f"Remove `{name}` from the list of banned ops by " + "setting `torch._dynamo.config._autograd_backward_strict_mode_banned_ops`.", + ], + ) + elif name in self._strict_mode_conditional_banned_ops(): + raise UnknownPropertiesDuringBackwardTrace( + f"Unknown property {name} during speculating backward, dynamo will insert contiguous call ahead and speculate it again" # noqa: B950 + ) + + if name == "__class__": + return UserDefinedClassVariable(self.python_type()) + + handler = getattr(self, f"method_attr_{name}", None) + result = handler(tx) if handler is not None else None + + # Add a guard for type matching, these guards are checked before tensor guards + # In some cases, a . guard can be evaluated first, and break if + # is later changed to another type + if ( + result is not None + and self.source + and self.source.subguards_allowed() + and not ( + name not in ("grad", "requires_grad") and result.is_python_constant() + ) + ): + install_guard(self.make_guard(GuardBuilder.TYPE_MATCH)) + result.source = AttrSource(self.source, name) + + # It's hard to get inplace view (metadata mutation) on graph input work properly across + # dynamo/aot/inductor, just fall back. + if self.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags + ): + # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. + return variables.misc.DelayGraphBreakVariable( + source=AttrSource(self.source, name), + msg="Getting an inplace view on a graph input is not supported", + ) + + # For attributes (not methods) that were not caught in the special handling above, + # (e.g. tensor.real), we handle these generically, assuming that the output type is + # a tensor. + if result is None and name != "grad": + + def try_generic_attr_handling(): + from .builder import wrap_fx_proxy + from .misc import GetAttrVariable + + static_attr = all_tensor_attrs.get(name, None) + if static_attr is None: + return None + + # Make sure this is an attribute, not a method. 
+ # type(torch.Tensor.H) should be "getset_descriptor" + # This is a because of CPython implementation, see THPVariableType: + # these attributes are implemented under tp_getset, which appear + # as `getset_descriptor`s, (compared to, say, methods which appear + # as `method_descriptor`s) + if type(static_attr) != types.GetSetDescriptorType: + return None + + proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) + if self.source is not None: + return wrap_fx_proxy( + tx=tx, proxy=proxy, source=AttrSource(self.source, name) + ) + else: + return wrap_fx_proxy(tx=tx, proxy=proxy) + + result = try_generic_attr_handling() + + if result is None: + result = self.dynamic_getattr(tx, name) + + if result is None: + raise NotImplementedError + return result + + def call_id(self, tx): + if not self.source: + unimplemented_v2( + gb_type="Unsupported call_id() without source", + context=f"call_id {self}", + explanation="call_id() not supported for sourceless TensorVariable.", + hints=[], + ) + + # For local source, we associate the real value. We use this real value + scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} + try: + _input_associated_real_value = eval(self.source.name(), scope) + except Exception as exc: + unimplemented_v2( + gb_type="Error getting associated real value", + context=f"call_id {self}", + explanation="Dynamo encountered an error while trying to " + "get the associated real value.", + hints=[], + from_exc=exc, + ) + + if _input_associated_real_value is None: + unimplemented_v2( + gb_type="call_id() without associated real value", + context=f"call_id {self}", + explanation="Dynamo could not find an associated real value for the tensor.", + hints=[], + ) + + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + id_value = id(_input_associated_real_value) + return ConstantVariable.create(id_value) + + def has_unpack_var_sequence(self, tx): + return self.ndim > 0 + + def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): + from .builder import wrap_fx_proxy_cls + + if self.valid_size(): + size_len = len(self.size) + else: + size_var = self.call_method(tx, "size", [], {}) + assert isinstance(size_var, SizeVariable) + size_len = len(size_var.items) + # Ensure we don't unpack a scalar tensor. + assert size_len != 0, "Can't unpack scalar tensors." + + if self.valid_size(): + length = self.size[0] + else: + dyn_length = self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through + # symbolic_shapes, but that end up as int/sympy.Integer + assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable)) + if isinstance(dyn_length, SymNodeVariable): + length = dyn_length.evaluate_expr(tx.output) + else: + length = dyn_length.value + + if idxes is None: + idxes = range(length) + else: + assert len(idxes) == length, ( + f"Can't unpack a tensor of {length} rows into a tuple of {len(idxes)} elements." 
+ ) + return [ + wrap_fx_proxy_cls(target_cls=type(self), tx=tx, proxy=self.as_proxy()[i]) + for i in idxes + ] + + def valid_size(self): + return self._size is not None + + @property + def size(self): + assert self._size is not None, "accessing None size in TensorVariable" + return self._size + + def _strict_mode_banned_ops(self): + return torch._dynamo.config._autograd_backward_strict_mode_banned_ops + + def _strict_mode_conditional_banned_ops(self): + return ( + torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): + unimplemented_v2( + gb_type="Illegal method invocation in strict mode", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo currently does not support this method " + f"({name}) invocation in strict mode.", + hints=[], + ) + + # Only override builtin tensor methods + # The user can manually add override handling + # with a decorator for other methods (e.g. a dispatch subclass with other methods) + static_attr = all_tensor_attrs.get(name, None) + is_base_tensor_method = static_attr is not None + + if ( + can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs) + and is_base_tensor_method + ): + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(static_attr) + else: + func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name)) + + return dispatch_torch_function( + tx, func_var, tuple([self] + list(args)), kwargs + ) + + """ + Dispatch to a method-specific handler defined below. If the + handler returns None (or doesn't exist) we put the method call + in the graph. + """ + + # This is seen in inspect signature where we check if the value is a default value + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): + return variables.ConstantVariable(False) + + # For historical reasons, these ops decompose down to syntactically + # invalid aten ops because they contain the python keyword `from`, see + # discussions in #151432 for more details. + # We graph break for now since this use case is uncommon. 
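+ # Illustrative sketch of the problem: the range form of `random_` lowers to an
+ # overload spelled roughly `aten.random_.from`, and `uniform_` accepts a keyword
+ # argument named `from`, so a generated Python call site would have to reference
+ # a name that is literally the reserved word `from`.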
+ if name == "random_": + unimplemented_v2( + gb_type="Tensor.random_ op", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Use the out-of-place version of this op", + *graph_break_hints.SUPPORTABLE, + ], + ) + elif name == "uniform_" and "from" in kwargs: + unimplemented_v2( + gb_type="Tensor.uniform_ op called with `from` keyword", + context=f"Tensor.{name}({args=}, {kwargs=})", + explanation="This is currently not supported.", + hints=[ + "Avoid using the `from` keyword.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + try: + handler_method = getattr(self, f"method_{name}") + except AttributeError: + pass + else: + try: + result = handler_method(*args, **kwargs) + if result: + return result + except TypeError as e: + unimplemented_v2( + gb_type="Unhandled args for method", + context=f"call_method {self} {name} {args} {kwargs}", + explanation="Dynamo encountered an error while calling " + f"the method `{name}`.", + hints=[], + from_exc=e, + ) + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + def method_size(self, *args, **kwargs): + return self._method_size_stride("size", *args, **kwargs) + + def method_stride(self, *args, **kwargs): + return self._method_size_stride("stride", *args, **kwargs) + + def _method_size_stride(self, name, dim=None): + dim = guard_if_dyn(dim) + + def make_const_size_variable(x, **options): + return SizeVariable( + [ConstantVariable.create(y, **options) for y in x], **options + ) + + RetVariable = ( + make_const_size_variable if name == "size" else ConstantVariable.create + ) + + # Technically, this should not be necessary, but I'm including it + # for enhanced BC, in case example_value is sometimes not set + # (it really should always be set though!) + if name != "size": + r = getattr(self, name) + elif name == "size" and self.valid_size(): + r = self.size + else: + r = None + + if r is not None: + if dim is None: + return RetVariable(r) + else: + return ConstantVariable.create(r[dim]) + + # It might still be constant! Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + if dim is None: + fake_r = getattr(fake, name)() + if not has_free_symbols(fake_r): + # int conversion for safety, in case a SymInt refined + # to constant + return RetVariable(tuple(int(r) for r in fake_r)) + else: + fake_r = getattr(fake, name)(dim) + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + def method_numel(self): + if self.valid_size(): + return ConstantVariable.create(product(self.size)) + + # It might still be constant! 
Consult the fake tensor and see + if (fake := self.proxy.node.meta.get("example_value")) is not None: + fake_r = fake.numel() + if not has_free_symbols(fake_r): + return ConstantVariable.create(int(fake_r)) + + method_nelement = method_numel + + def method_dim(self): + if self.ndim is not None: + return ConstantVariable.create(self.ndim) + + method_ndimension = method_dim + + def method_is_floating_point(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_floating_point) + + def method_is_inference(self): + if config.fake_tensor_disable_inference_mode: + unimplemented_v2( + gb_type="Encountered tensor.is_inference() during tracing", + context="", + explanation="tensor.is_inference() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + if (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create(fake.is_inference()) + + def method_is_complex(self): + if self.dtype is not None: + return ConstantVariable.create(self.dtype.is_complex) + + def method_is_contiguous(self, memory_format=None): + memory_format = ( + memory_format.as_python_constant() + if memory_format is not None + else torch.contiguous_format + ) + if self.is_contiguous is not None: + return ConstantVariable.create(memory_format in self.is_contiguous) + elif (fake := self.proxy.node.meta.get("example_value")) is not None: + return ConstantVariable.create( + fake.is_contiguous(memory_format=memory_format) + ) + + def method_type(self, dtype=None, non_blocking=False, **kwargs): + if ( + dtype is None + and self.dtype is not None + and isinstance(self.device, torch.device) + ): + tensortype = next( + k for k, v in tensortype_to_dtype.items() if self.dtype in v + ) + if self.device.type == "cpu": + return ConstantVariable.create(f"torch.{tensortype.__name__}") + else: + return ConstantVariable.create( + f"torch.{self.device.type}.{tensortype.__name__}" + ) + elif ( + dtype is not None + and fqn(type(dtype.as_python_constant())) == "torch.tensortype" + ): + # torch.FloatTensor, etc. are all of type "torch.tensortype". + # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type. 
+ # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args) + tensor_type = dtype.as_python_constant() + tensor_type_const = ConstantVariable.create(fqn(tensor_type)) + + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + if non_blocking: + kwargs = {"non_blocking": non_blocking, **kwargs} + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + "type", + *proxy_args_kwargs([self, tensor_type_const], kwargs), + ), + ) + + def method_as_subclass(self, cls): + if isinstance(cls, TensorSubclassVariable) and cls.source: + from ..symbolic_convert import InstructionTranslator + from .torch_function import TensorWithTFOverrideVariable + + tx = InstructionTranslator.current_tx() + py_cls = cls.as_python_constant() + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, self, py_cls, cls.source + ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + unimplemented_v2( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + + def method_get_device(self): + if isinstance(self.device, torch.device): + index = self.device.index if self.device.type != "cpu" else -1 + return ConstantVariable.create(index) + + def method_element_size(self): + return ConstantVariable.create(self.dtype.itemsize) + + def method_numpy(self, *, force=False): + if not config.trace_numpy: + unimplemented_v2( + gb_type="Tensor.numpy() with trace_numpy=False", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the `trace_numpy` " + "configuration was manually disabled.", + hints=[ + "Set `torch._dynamo.config.trace_numpy = True` to allow " + "Dynamo to trace through NumPy.", + ], + ) + if not np: + unimplemented_v2( + gb_type="Tensor.numpy() without NumPy installed", + context=f"call_method {self} numpy", + explanation="`Tensor.numpy()` was called, but the NumPy library " + "is not available in the current environment.", + hints=[ + "Ensure NumPy is installed in your Python environment.", + ], + ) + if self.layout != torch.strided: + raise TypeError( + f"can't convert {self.layout} layout tensor to numpy. Use Tensor.to_dense() first" + ) + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # We don't check that the tensor is on CPU when force is False, as this + # allows us to execute NumPy code on CUDA. Same for requires_grad=True + if force and force.as_python_constant(): + # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...) 
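+ # (roughly the eager-mode `tensor.detach().cpu().numpy()` path)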
+ t = self.call_method(tx, "detach", [], {}) + proxy = tx.output.create_proxy("call_method", "cpu", (t.as_proxy(),), {}) + else: + # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable + proxy = tx.output.create_proxy( + "call_method", "view_as", *proxy_args_kwargs([self, self], {}) + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def method_tolist(self): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + + def tolist(tensor, sub_proxy): + def wrap(i, sub_proxy): + # Sigh, we forgot to gate this, so this data dependent is on + # by default and is load bearing in CI + with unittest.mock.patch.object( + tx.fake_mode, "allow_scalar_outputs", True + ): + return wrap_fx_proxy( + tx, + sub_proxy.item(), + ) + + if tensor.dtype not in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + unimplemented_v2( + gb_type="Tensor.tolist() with non-integer tensor", + context=f"call_method {self} to_list", + explanation="Dynamo currently does not support tracing " + "`tolist()` on non-integer tensors.", + hints=[ + "Ensure the input tensor to `tolist()` is an integer " + "type (e.g., int8, int16, int32, int64)." + ], + ) + + if tensor.dim() == 0: + return wrap(tensor, sub_proxy) + + if tensor.dim() == 1: + return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] + + return [ + tolist(sub_tensor, sub_proxy=sub_proxy[i]) + for i, sub_tensor in enumerate(tensor) + ] + + tensor = self.as_proxy().node.meta["example_value"] + out = tolist(tensor, self.as_proxy()) + return VariableTracker.build(tx, out) + + def method_backward(self, *args, **kwargs): + unimplemented_v2( + gb_type="Unsupported Tensor.backward() call", + context=f"call_method {self} backward {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.backward()`.", + hints=[*graph_break_hints.FUNDAMENTAL], + ) + + def method_data_ptr(self, *args, **kwargs): + return DataPtrVariable(self) + + def method_item(self, *args, **kwargs): + if not config.capture_scalar_outputs: + self._warn_capture_scalar_outputs() + unimplemented_v2( + gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False", + context=f"call_method {self} item {args} {kwargs}", + explanation="Dynamo does not support tracing `Tensor.item()` " + "with config.capture_scalar_outputs=False.", + hints=[ + "Set `torch._dynamo.config.capture_scalar_outputs = True` " + "or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` " + "to include these operations in the captured graph.", + ], + ) + + def method___getitem__(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + from .builder import wrap_fx_proxy + + tx = InstructionTranslator.current_tx() + if isinstance(args[0], SymNodeVariable): + # Standard indexing will force specialization due to + # __index__. 
Rewrite as a regular torch op which will + # trace fine + fn, args = ( + torch.select, + [ + variables.ConstantVariable.create(0), + args[0], + ], + ) + else: + fn = operator.getitem + + proxy = tx.output.create_proxy( + "call_function", + fn, + *proxy_args_kwargs([self] + list(args), kwargs), + ) + + return wrap_fx_proxy(tx, proxy) + + @staticmethod + @functools.cache + def _warn_capture_scalar_outputs(): + user_stack = torch._guards.TracingContext.extract_stack() + user_stack_formatted = "".join(traceback.format_list(user_stack)) + log.warning( + textwrap.dedent( + """\ + Graph break from `Tensor.item()`, consider setting: + torch._dynamo.config.capture_scalar_outputs = True + or: + env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 + to include these operations in the captured graph. + + Graph break: from user code at: + %s + """ + ), + user_stack_formatted, + ) + + def method___len__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + + def method_addcmul_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + from .. import polyfills + + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.addcmul_inplace), + [self, tensor1, tensor2, value], + {}, + ) + + def method___setitem__(self, key, value): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + proxy = tx.output.create_proxy( + "call_function", + operator.setitem, + *proxy_args_kwargs([self, key, value], {}), + ) + + if config.use_graph_deduplication or config.track_nodes_for_deduplication: + tx.output.region_tracker.add_node_mutation(proxy.node, 0) + + return ConstantVariable.create(None) + + def method_resize_(self, *args, **kwargs): + unimplemented_v2( + gb_type="Unsupported Tensor.resize_() call", + context=f"call_method {self} resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_()`.", + hints=[], + ) + + def method_resize_as_(self, *args, **kwargs): + unimplemented_v2( + gb_type="Unsupported Tensor.resize_as_() call", + context=f"call_method {self} resize_as_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.resize_as_()`.", + hints=[], + ) + + def method_sparse_resize_(self, *args, **kwargs): + unimplemented_v2( + gb_type="Unsupported Tensor.sparse_resize_() call", + context=f"call_method {self} sparse_resize_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", + hints=[], + ) + + def method_sparse_resize_and_clear_(self, *args, **kwargs): + unimplemented_v2( + gb_type="Unsupported Tensor.sparse_resize_and_clear_() call", + context=f"call_method {self} sparse_resize_and_clear_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", + hints=[], + ) + + def method_set_(self, *args, **kwargs): + if len(args) > 1: + # torch.Tensor.set_() has several overloads. + # aten::set_.source_Tensor(Tensor) gets special handling + # in AOTAutograd and functionalization, because it is the most common + # overload and is used by FSDP. + # graph-breaking on aten::set_source_Tensor_storage_offset for now, + # unless we find that we need to make it work. 
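+ # Illustrative: the one-argument form `dst.set_(src)` corresponds to
+ # aten::set_.source_Tensor and falls through to the graph, while calls like
+ # `dst.set_(storage, storage_offset, size, stride)` take the graph break below.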
+ unimplemented_v2( + gb_type="Unsupported Tensor.set_() call", + context=f"call_method {self} set_ {args} {kwargs}", + explanation="Dynamo currently does not support tracing `Tensor.set_()` " + "overloads that include more than one argument.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def method_add_(self, other, *, alpha=None): + if alpha is not None: + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [other, alpha], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method_addcdiv_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + result = variables.TorchInGraphFunctionVariable(torch.div).call_function( + tx, [tensor1, tensor2], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, value], {} + ) + return self.call_method(tx, "add_", [result], {}) + + def method___contains__(self, arg): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + # Rewrite __contains__ here so that downstream passes can trace through + # without dealing with unbacked symbool. Roughly the code we translate is: + # def __contains__(self, x): + # return (x == self).any().item() + result = variables.TorchInGraphFunctionVariable(torch.eq).call_function( + tx, [self, arg], {} + ) + result = variables.TorchInGraphFunctionVariable(torch.any).call_function( + tx, [result], {} + ) + return result.call_method(tx, "item", [], {}) + + def method_redistribute(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def redistribute_fn_with_prim_types(x): + return x.redistribute(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + redistribute_fn_with_prim_types.__name__ = "prim_redistribute" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + redistribute_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_to_local(self, *args, **kwargs): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args] + kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} + + def to_local_fn_with_prim_types(x): + return x.to_local(*args_as_value, **kwargs_as_value) + + # attach the same function name for better debugging + to_local_fn_with_prim_types.__name__ = "prim_to_local" + + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + to_local_fn_with_prim_types, + *proxy_args_kwargs([self], {}), + ), + ) + + def method_register_hook(self, *args, **kwargs): + return self._method_register_hook("register_hook", *args, **kwargs) + + def 
method_register_post_accumulate_grad_hook(self, *args, **kwargs): + return self._method_register_hook( + "register_post_accumulate_grad_hook", *args, **kwargs + ) + + def _method_register_hook(self, name: str, hook: VariableTracker): + # Note - do not arbitrarily add hooks here - make sure they match the same contract + # see [On tensor.register_hook] + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + + if not self.source: + if not compiled_autograd.compiled_autograd_enabled: + # TODO(voz): + # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary + # python state. + # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run + # them in a compiled bwd without re-entering dynamo as compiled_autograd does. + # + # Discussion point 1 - Should we bypass this if nopython/fullgraph = True? + # No. Because this was going to be a graph break anyway - this check does not + # introduce new graph breaks where there were none. + # + # Discussion point 2 - Should we defer this check to backwards? + # No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user + # would have no recourse - their forward traces just fine, but will fail at backwards unless + # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) + # then they have nothing they can do except disable compile. + unimplemented_v2( + gb_type="Compilation of intermediate hooks requires compiled autograd", + context=f"var_getattr {self} {name}", + explanation="Dynamo must be in compiled_autograd to register hooks.", + hints=[], + ) + + hook_name, bw_state_proxy = tx.output.add_backward_state_hook(hook) + + def _register_hook_trampoline(tensor, bw_state): + register_hook = getattr(tensor, name) + register_hook( + functools.partial( + trace_wrapped, + fn=call_hook_from_backward_state, + bw_state=bw_state, + hook_name=hook_name, + ) + ) + # TODO(jansel): returning None here is wrong, it should be + # RemovableHandle, but we need some extra work to support + # this properly. + return None + + from .builder import wrap_fx_proxy + + self_proxy = self.as_proxy() + self_proxy.node.meta["has_backward_hook"] = True + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + _register_hook_trampoline, + (self_proxy, bw_state_proxy), + {}, + ), + ) + + handle_variable = variables.RemovableHandleVariable( + mutation_type=variables.base.ValueMutationNew(), + ) + tx.output.side_effects.register_hook(self, hook, handle_variable, name) + return handle_variable + + def method_requires_grad_(self, requires_grad=True): + if requires_grad is not True: + requires_grad = requires_grad.as_python_constant() + + if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: + unimplemented_v2( + gb_type="Unsupported Tensor.requires_grad_() call", + context=f"call_method {self} requires_grad_", + explanation="Dynamo does not support changes to a Tensor's " + "`requires_grad` through calling `requires_grad_()`.", + hints=[], + ) + else: + return self + + def method_new(self, *args, **kwargs): + # Convert x.new(torch.Size) into x.new_empty(torch.Size), + # as Tensor.new acts differently with a Size input versus a tuple input. 
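+ # Illustrative: `x.new(torch.Size([2, 3]))` and `x.new(2, 3)` allocate an
+ # uninitialized 2x3 tensor (hence the rewrite to `new_empty`), whereas
+ # `x.new((2, 3))` builds a tensor from the data (2, 3) and is not rewritten here.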
+ if (len(args) == 1 and isinstance(args[0], SizeVariable)) or ( + len(args) >= 1 + and all( + isinstance(a, ConstantVariable) and a.python_type() == int for a in args + ) + ): + from ..symbolic_convert import InstructionTranslator + + return self.call_method( + InstructionTranslator.current_tx(), "new_empty", args, kwargs + ) + + def method_untyped_storage(self): + return UntypedStorageVariable( + self, self.as_proxy().node.meta["example_value"].untyped_storage() + ) + + def set_name_hint(self, name: str): + if not self._is_name_set: + self.proxy.node._rename(name) + self._is_name_set = True + + +class SymNodeVariable(VariableTracker): + """ + Represents a symbolic scalar, either int, float or bool. This is most commonly used to + handle symbolic size computation, e.g., tensor.size(0), but it is also used to + handle logic like float_tensor.item() or unspecialized float inputs. + """ + + _nonvar_fields = { + "proxy", + "sym_num", + *VariableTracker._nonvar_fields, + } + + def debug_repr(self): + return repr(self.sym_num) + + @classmethod + def create(cls, tx, proxy, sym_num=None, **options): + if sym_num is None: + sym_num = get_fake_value(proxy.node, tx) + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == sym_num + set_example_value(proxy.node, sym_num) + + if isinstance(sym_num, (sympy.Integer, int, bool)): + sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num + return ConstantVariable.create(sym_num) + + return SymNodeVariable(proxy, sym_num, **options) + + def __init__(self, proxy, sym_num, **kwargs) -> None: + super().__init__(**kwargs) + self.proxy = proxy + # TODO: Should we allow non SymTypes here? Today it is allowed + self.sym_num = sym_num + self._tensor_var = None + + def python_type(self): + if isinstance(self.sym_num, SymTypes): + return self.sym_num.node.pytype + else: + return type(self.sym_num) + + def as_proxy(self): + return self.proxy + + def as_tensor(self, tx, dtype): + if self._tensor_var is None: + self._tensor_var = VariableTracker.build( + tx, torch.scalar_tensor + ).call_function(tx, [self], {"dtype": VariableTracker.build(tx, dtype)}) + return self._tensor_var + + def evaluate_expr(self, output_graph=None): + try: + return guard_scalar(self.sym_num) + except GuardOnDataDependentSymNode as e: + if torch.fx.experimental._config.no_data_dependent_graph_break: + raise + + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self, *args], kwargs), + ), + ) + + +class NumpyNdarrayVariable(TensorVariable): + """ + Represents a np.ndarray, but backed by torch Tensor via torch._numpy.ndarray. + Use this for Tensor.numpy() call. + """ + + @staticmethod + def create(tx: "InstructionTranslator", proxy, **options): + from .builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + target_cls=NumpyNdarrayVariable, + tx=tx, + proxy=proxy, + **options, + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + # NB: This INTENTIONALLY does not call super(), because there is + # no intrinsic reason ndarray properties are related to Tensor + # properties. The inheritance here is for implementation sharing. 
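+ # Sketch of the dispatch below: graph-traced attributes (`arr.T`, `arr.real`,
+ # `arr.imag`) go through `numpy_attr_wrapper`; always-constant attributes
+ # (`arr.ndim`, `arr.itemsize`) fold to Python constants; `shape`/`size` fold
+ # only when they contain no free symbols, otherwise they are put in the graph.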
+ + from ..utils import numpy_attr_wrapper + from .builder import wrap_fx_proxy + + result = None + + example_value = self.as_proxy().node.meta["example_value"] + example_ndarray = tnp.ndarray(example_value) + + def insert_into_graph(): + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {} + ), + ) + + if name in ["T", "real", "imag"]: + proxy = tx.output.create_proxy( + "call_function", + numpy_attr_wrapper, + (self.as_proxy(), name), + {}, + ) + result = NumpyNdarrayVariable.create(tx, proxy) + + # These are awkward to implement. The standard playbook for torch._numpy + # interop is to trace a call into the torch._numpy wrapper which works for + # Tensor operations. However, we don't want to do this for calls + # that don't return Tensors, because in those cases we may not want + # to trace the attribute access into the graph at all (it is sort + # of harmless to do so, because AOTAutograd will eliminate them, + # but it's best not to trace them in to begin with.) But in any + # case, tracing these into the graph is like trying to fit a square + # peg into a round hole; best not to do it. So instead we + # painstakingly implement these by hand + # + # NB: only ALWAYS specialized attributes can go here; notably, + # size/shape not allowed! + elif name in ("ndim", "itemsize"): + return ConstantVariable.create(getattr(example_ndarray, name)) + elif name in ("shape", "stride"): + if not has_free_symbols(r := getattr(example_ndarray, name)): + return ConstantVariable.create(tuple(int(r) for r in r)) + return insert_into_graph() + elif name == "size": + if not has_free_symbols(r := example_ndarray.size): + return ConstantVariable.create(int(r)) + return insert_into_graph() + elif name in ["base", "flags", "dtype"]: + unimplemented_v2( + gb_type="Unsupported ndarray attribute access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + elif name in ["__version__"]: + unimplemented_v2( + gb_type="Unsupported ndarray.__version__ access", + context=f"var_getattr {self} {name}", + explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", + hints=[], + ) + if result is None: + raise NotImplementedError + return result + + @staticmethod + def patch_args(name, args, kwargs): + if name == "clip": + kwargs_rename = {"a_min": "min", "a_max": "max"} + kwargs = {kwargs_rename.get(k, k): v for k, v in kwargs.items()} + return args, kwargs + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..utils import numpy_method_wrapper + + args, kwargs = self.patch_args(name, args, kwargs) + + if name in ["__len__", "size", "tolist"]: + # delegate back to TensorVariable + return super().call_method(tx, name, args, kwargs) + if name in ("tostring", "tobytes", "__delattr__"): + unimplemented_v2( + gb_type="Unsupported ndarray method call", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"`ndarray.{name}()` is not modelled in `torch._numpy`.", + hints=[], + ) + proxy = tx.output.create_proxy( + "call_function", + numpy_method_wrapper(name), + *proxy_args_kwargs([self] + list(args), kwargs), + ) + return NumpyNdarrayVariable.create(tx, proxy) + + def python_type(self): + return np.ndarray + + +class UnspecializedPythonVariable(TensorVariable): + """ + This is a 1-element tensor represents unspecialized python float/int. 
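+ For example (illustrative): an unspecialized Python float argument may be lifted
+ to a tensor and tracked by this class, with `raw_value` keeping the original
+ Python object so the concrete value is not baked into the compiled graph.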
+ """ + + _nonvar_fields = { + "raw_value", + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__( + self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs + ) -> None: + super().__init__(proxy, **kwargs) + self.raw_value = raw_value + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True): + # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance. + return UnspecializedPythonVariable( + **dict(tensor_variable.__dict__), + raw_value=raw_value, + need_unwrap=need_unwrap, + ) + + +class FakeItemVariable(TensorVariable): + """An unspecialized python variable which prevents access to the underlying raw value. + This is needed if item is called on a FakeTensor.""" + + _nonvar_fields = { + "need_unwrap", + *TensorVariable._nonvar_fields, + } + + def __init__(self, proxy: torch.fx.Proxy, **kwargs) -> None: + need_unwrap = kwargs.pop("need_unwrap", False) + super().__init__(proxy, **kwargs) + self.need_unwrap = need_unwrap + + @classmethod + def from_tensor_variable(cls, tensor_variable): + return FakeItemVariable(**dict(tensor_variable.__dict__)) + + +class TensorSubclassVariable(UserDefinedClassVariable): + def call_function( + self, + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable + + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if ( + len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, data, self.value, self.source + ) + else: + unimplemented_v2( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs + ) + + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. 
+ if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) + + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var + + def as_python_constant(self): + return self.value + + +class UntypedStorageVariable(VariableTracker): + _nonvar_fields = { + "example_value", + *VariableTracker._nonvar_fields, + } + + def __init__( + self, + from_tensor: TensorVariable, + example_value: torch.UntypedStorage, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + # Example_value will always have device="meta" + self.example_value = example_value + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "size": + assert not args + assert not kwargs + result = self.example_value.size() + if not has_free_symbols(result): + # avoid creating a node in the graph + return ConstantVariable.create(int(result)) + else: + from ..external_utils import untyped_storage_size + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + untyped_storage_size, + (self.from_tensor.as_proxy(),), + {}, + ), + ) + if name == "resize_" and len(args) == 1: + assert not kwargs + tx.output.create_proxy( + "call_function", + torch.ops.inductor.resize_storage_bytes_, + (self.from_tensor.as_proxy(), args[0].as_proxy()), + {}, + ) + return self + + return super().call_method(tx, name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("untyped_storage") + codegen.call_method(0) + + +class DataPtrVariable(VariableTracker): + def __init__( + self, + from_tensor: TensorVariable, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.from_tensor = from_tensor + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.from_tensor) + codegen.load_method("data_ptr") + codegen.call_method(0) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/torch.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..05a9d747c19da0fa557b577b3af711272215d1e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/torch.py @@ -0,0 +1,1650 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +""" +This module implements variable tracking for torch functions and operations during Dynamo tracing. + +It provides classes to handle different types of torch operations: + +TorchInGraphFunctionVariable: Handles torch.* functions that should be captured in the FX graph. +Provides special handling for constant folding, tensor methods, and torch function overrides. +Manages complex cases like out= variants and parameter construction. + +TorchCtxManagerClassVariable: Handles torch context managers like torch.no_grad(), autocast, etc. +Provides implementations for entering/exiting these contexts during tracing. + +DispatchKeySetVariable: Represents torch.DispatchKeySet for managing dispatch keys and +device-specific operations during tracing. 
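+
+Illustrative example: tracing `torch.sin(x)` yields a TorchInGraphFunctionVariable
+whose call emits a call_function node into the FX graph, while `with torch.no_grad():`
+is matched by TorchCtxManagerClassVariable and lowered to a grad-mode context
+manager variable.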
+ +The module includes special handling for: +- Constant folding of pure functions +- Tensor method calls +- torch.nn.Parameter construction +- __torch_function__ overrides +- Context manager state tracking +- Device and dtype management + +This is a core part of Dynamo's tracing system, translating torch operations into +traceable graph nodes while preserving correct semantics and handling edge cases. +""" + +import functools +import inspect +import logging +import math +import re +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch._C +import torch._refs +import torch.fx +import torch.nn +from torch._guards import TracingContext +from torch._logging import warning_once +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import config, graph_break_hints, polyfills, variables +from ..codegen import PyCodegen +from ..create_parameter_op import ( + can_convert_to_tracable_parameter, + new_parameter_placeholder, + tracable_create_parameter, +) +from ..device_interface import get_registered_device_interfaces +from ..exc import unimplemented, unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..source import CallFunctionNoArgsSource, SyntheticLocalSource +from ..utils import ( + check_unspec_or_constant_args, + guard_if_dyn, + has_torch_function, + hashable, + product, + proxy_args_kwargs, + unwrap_if_wrapper, +) +from .base import typestr, VariableTracker +from .ctx_manager import ( + AutocastModeVariable, + ProfilerContextVariable, + TorchFunctionDisableVariable, +) +from .dicts import ConstDictVariable +from .distributed import DistributedVariable, ProcessGroupVariable +from .lists import ListVariable, TupleVariable +from .torch_function import ( + can_dispatch_torch_function, + dispatch_torch_function, + TensorWithTFOverrideVariable, + TorchFunctionModeStackVariable, +) + + +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + +try: + from torch.distributed.fsdp._fully_shard import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +log = logging.getLogger(__name__) + +supported_ctx_manager_classes = dict.fromkeys( + [ + torch.profiler.profiler.profile, + torch.autograd.forward_ad._set_fwd_grad_enabled, + torch.autograd.forward_ad.dual_level, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + torch._C.DisableTorchFunctionSubclass, + torch._C.DisableTorchFunction, + torch._functorch.vmap.vmap_increment_nesting, + torch._functorch.eager_transforms.grad_increment_nesting, + torch._functorch.eager_transforms.jvp_increment_nesting, + torch._functorch.eager_transforms.enable_inplace_requires_grad, + torch.amp.autocast_mode.autocast, + torch.autograd.grad_mode.enable_grad, + torch.autograd.grad_mode.inference_mode, + torch.autograd.grad_mode.no_grad, + torch.autograd.grad_mode.set_grad_enabled, + torch.autograd.graph.disable_saved_tensors_hooks, + torch.cpu.amp.autocast_mode.autocast, + torch.cuda.amp.autocast_mode.autocast, + torch.nn.attention.sdpa_kernel, + torch.nn.attention._sdpa_kernel_variadic, + ] +) + + +REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys( + [ + torch._shape_as_tensor, + ] +) + +constant_fold_functions_need_guards = [ + torch.accelerator.current_device_index, + torch.cuda.current_device, + torch.cuda.is_initialized, + torch.xpu.current_device, + 
torch.xpu.is_initialized, +] + +constant_fold_functions = [ + torch._assert, + torch._utils._get_device_index, + torch._C._get_cublas_allow_tf32, + torch._C._is_any_autocast_enabled, + torch.accelerator.is_available, + torch.cuda.get_device_properties, + torch.cuda.is_available, + torch.distributed.is_available, + torch.get_autocast_dtype, + torch.get_autocast_gpu_dtype, + torch.get_default_dtype, + torch.is_autocast_cache_enabled, + torch.is_autocast_cpu_enabled, + torch.is_autocast_enabled, + torch.is_complex, + torch.is_floating_point, + torch.nn.functional._Reduction.get_enum, # type: ignore[attr-defined] + torch.promote_types, + torch._C._get_privateuse1_backend_name, + torch.autograd._is_checkpoint_valid, + torch.xpu.get_device_properties, + torch.xpu.is_available, +] + constant_fold_functions_need_guards +if torch.distributed.is_available(): + constant_fold_functions.extend( + [ + torch.distributed.is_initialized, + torch.distributed.get_rank, + torch.distributed.get_world_size, + ] + ) +# Convert to dict for O(1) access times +constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need_guards) +constant_fold_functions = dict.fromkeys(constant_fold_functions) + + +@functools.cache +def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: + # Defined as a function to avoid circular import like torch.onnx + return { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, + torch.compiler.is_exporting: True, + torch.nn.modules.activation._is_make_fx_tracing: False, + } + + +bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) + +dispatch_key_set_functions = { + torch._C._dispatch_keys, + torch._C._dispatch_tls_local_include_set, + torch._C._dispatch_tls_local_exclude_set, +} + + +@functools.cache +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + funcs = set(chain.from_iterable(get_overridable_functions_().values())) + more: set[Callable[..., Any]] = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs + + +class BaseTorchVariable(VariableTracker): + """common base for all torch.* functions, classes, modules and other things""" + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return cls(value, source=source) + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def reconstruct(self, codegen: "PyCodegen"): + try: + name = f"{self.value.__module__}.{self.value.__name__}" + except Exception: + name = f"torch_obj_{id(self.value)}" + unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name) + codegen.extend_output( + codegen.setup_globally_cached(unique_var_name, self.value) + ) + + def as_proxy(self): + return self.value + + def as_python_constant(self): + return self.value + + def call_obj_hasattr(self, tx: "InstructionTranslator", name): + result = hasattr(self.value, name) + return variables.ConstantVariable.create(result) + + def can_constant_fold_through(self): + if self.value in constant_fold_functions: + return True + return 
getattr(self.value, "__module__", None) == "math" + + +class TorchCtxManagerClassVariable(BaseTorchVariable): + """Points to a context manager class in torch.* that dynamo has implementations""" + + def __repr__(self) -> str: + return f"TorchCtxManagerClassVariable({self.value})" + + @staticmethod + def is_matching_cls(value): + # Unwrap if it's a functools.lru_cache wrapper + value = unwrap_if_wrapper(value) + # We can't do isinstance(value, type) check because some ctx managers + # are implemented as a function decorated by contextlib.contextmanager, + # E.g., torch._functorch.vmap.vmap_increment_nesting. + return ( + # Context manager type or function with @contextmanager is callable + callable(value) + and ( + hashable(value) # accesses value.__hash__() + and value in supported_ctx_manager_classes + ) + ) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ( + DisabledSavedTensorsHooksVariable, + DualLevelContextManager, + FSDPParamGroupUseTrainingStateVariable, + GradIncrementNestingCtxManagerVariable, + GradInplaceRequiresGradCtxManagerVariable, + GradModeVariable, + InferenceModeVariable, + JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, + SetFwdGradEnabledContextManager, + StreamVariable, + VmapIncrementNestingCtxManagerVariable, + ) + + if self.value is torch.no_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, False) + return ctx.call_function(tx, args, kwargs) + else: + return GradModeVariable.create(tx, False) + elif self.value is torch.enable_grad: + if len(args) == 1 and isinstance( + args[0], variables.functions.BaseUserFunctionVariable + ): + ctx = GradModeVariable.create(tx, True) + return ctx.call_function(tx, args, kwargs) + return GradModeVariable.create(tx, True) + elif self.value is torch.set_grad_enabled and len(args) == 1: + return GradModeVariable.create( + tx, args[0].as_python_constant(), initialized=True + ) + elif self.value is torch.inference_mode: + assert len(args) <= 1 and len(kwargs) == 0 + inf_mode = args[0].as_python_constant() if len(args) == 1 else True + return InferenceModeVariable.create(tx, inf_mode) + elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream): + from torch._dynamo.variables.builder import wrap_fx_proxy_cls + + return wrap_fx_proxy_cls( + StreamVariable, + tx, + tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + elif self.value in ( + torch.amp.autocast_mode.autocast, + torch.cuda.amp.autocast, + torch.cpu.amp.autocast, + ): + return AutocastModeVariable.create(self.value, args, kwargs) + elif self.value in ( + # NOTE any class added here must align with the semantic + # requirements of `ProfilerContextVariable`. 
+ torch.profiler.profile, + torch.profiler.record_function, + torch.autograd.profiler.profile, + torch.autograd.profiler.record_function, + ): + warning_once(log, "Profiler function %s will be ignored", self.value) + return ProfilerContextVariable() + elif ( + self.value is torch._C.DisableTorchFunctionSubclass + or self.value is torch._C.DisableTorchFunction + ): + assert not (args or kwargs) + return TorchFunctionDisableVariable.create( + tx, only_subclass=self.value is torch._C.DisableTorchFunctionSubclass + ) + elif self.value is torch._functorch.vmap.vmap_increment_nesting: + assert len(args) == 2 + return VmapIncrementNestingCtxManagerVariable.create( + tx, + args, + ) + elif self.value is torch._functorch.eager_transforms.jvp_increment_nesting: + assert len(args) == 0 + return JvpIncrementNestingCtxManagerVariable.create(tx) + elif self.value is torch.autograd.forward_ad._set_fwd_grad_enabled: + assert len(args) == 1 + return SetFwdGradEnabledContextManager.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.forward_ad.dual_level: + assert len(args) == 0 + return DualLevelContextManager.create(tx) + elif self.value is torch._functorch.eager_transforms.grad_increment_nesting: + assert len(args) == 0 + return GradIncrementNestingCtxManagerVariable.create(tx) + elif ( + self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad + ): + assert len(args) == 1 + return GradInplaceRequiresGradCtxManagerVariable.create( + tx, + [guard_if_dyn(x) for x in args], + ) + elif self.value is torch.autograd.graph.disable_saved_tensors_hooks: + assert len(args) == 1 + return DisabledSavedTensorsHooksVariable.create( + tx, args[0].as_python_constant() + ) + elif ( + _fsdp_param_group is not None + and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + assert len(args) == 2 + return FSDPParamGroupUseTrainingStateVariable.create( + tx, args[0], args[1].as_python_constant() + ) + elif self.value is torch.nn.attention.sdpa_kernel: + assert len(args) == 1 or (len(kwargs) == 1 and "backends" in kwargs) + backends = args[0] if len(args) == 1 else kwargs["backends"] + set_priority = kwargs["set_priority"] if "set_priority" in kwargs else False + return SDPAKernelVariable.create( + tx, backends.as_python_constant(), set_priority + ) + elif self.value is torch.nn.attention._sdpa_kernel_variadic: + return SDPAKernelVariable.create( + tx, [arg.as_python_constant() for arg in args] + ) + + return super().call_function(tx, args, kwargs) + + +class TorchInGraphFunctionVariable(BaseTorchVariable): + """Points to a torch function/method that should be put in FX graph""" + + def __init__(self, value, nonstrict_traceable=None, **kwargs) -> None: + super().__init__(value, **kwargs) + from ..trace_rules import is_nonstrict_trace_callable + + if nonstrict_traceable is None: + nonstrict_traceable = is_nonstrict_trace_callable(value) + self.nonstrict_traceable = nonstrict_traceable + + def __repr__(self) -> str: + return f"TorchInGraphFunctionVariable({self.value}, nonstrict_traceable={self.nonstrict_traceable})" + + def get_function(self): + return self.value + + @staticmethod + @functools.cache + def _get_handlers(): + """Build a dict from function -> method to handle it so that we are O(1) + in terms of the number of function with special handling.""" + handlers = {} + + def register(*fns): + def _register(handler): + for fn in fns: + assert fn not in handlers, fn + handlers[fn] = handler + return handler + + assert callable(fns[0]) + return 
_register + + from torch.backends.cuda import SDPAParams + + from . import ( + ConstantVariable, + DeterministicAlgorithmsVariable, + GradModeVariable, + StreamContextVariable, + SymNodeVariable, + TensorVariable, + UserDefinedObjectVariable, + ) + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls + + @register(*tracing_state_functions()) + def handle_tracing_state_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + # See: https://github.com/pytorch/pytorch/issues/110765 + if self.value in ( + torch._utils.is_compiling, + torch._dynamo.external_utils.is_compiling, + torch.compiler.is_compiling, + torch.compiler.is_dynamo_compiling, + torch.compiler.is_exporting, + ): + tx.mark_inconsistent_side_effects() + return ConstantVariable.create(tracing_state_functions()[self.value]) + + @register(*dispatch_key_set_functions) + def handle_dispatch_key_set_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not kwargs + if self.value in (torch._C._dispatch_keys,): + assert len(args) == 1 + assert isinstance(args[0], variables.TensorVariable) + example_value = args[0].proxy.node.meta["example_value"] + dks = self.value(example_value) + # Remove Python and PythonTLSSnapshot from the dispatch key set, + # as they originate from FakeTensor propagation. + # This should only be done if the example_value is a FakeTensor. + # However, if tensor subclasses are present, + # it is reasonable for Python to remain in the dispatch key set. + if isinstance(example_value, torch._subclasses.FakeTensor): + dks = ( + dks + - torch._C.DispatchKeySet(torch._C.DispatchKey.Python) + - torch._C.DispatchKeySet( + torch._C.DispatchKey.PythonTLSSnapshot + ) + ) + return DispatchKeySetVariable.create(dks) + else: + assert not args + return DispatchKeySetVariable.create(self.value()) + + @register(torch.overrides.get_default_nowrap_functions.__wrapped__) + def handle_get_default_nowrap_functions( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # [Note: __torch_function__] we return empty here because we restrict + # the set of functions that we trace __torch_function__ on to + # functions outside of the actual set. 
Implementing this properly will require implementing + # some variable types to track and compare tensor getset descriptors + return VariableTracker.build( + tx, torch.overrides.get_default_nowrap_functions() + ) + + @register(torch.ops.inductor.accumulate_grad_.default) + def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs): + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs + ) + + @register(math.radians) + def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs): + if not check_unspec_or_constant_args(args, kwargs): + # Use polyfill to convert math.radians(x) into math.pi * x / 180.0 + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.radians), args, kwargs + ) + + @register(torch.is_inference_mode_enabled) + def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"): + unimplemented_v2( + gb_type="Encountered torch.is_inference_mode_enabled during tracing", + context="", + explanation="torch.is_inference_mode_enabled() is not supported", + hints=[ + *graph_break_hints.FUNDAMENTAL, + *graph_break_hints.INFERENCE_MODE, + ], + ) + + @register(torch.is_tensor, torch.overrides.is_tensor_like) + def handle_is_tensor(self, tx: "InstructionTranslator", arg): + if isinstance(arg, TensorVariable) or ( + self.value is torch.overrides.is_tensor_like + and isinstance(arg, UserDefinedObjectVariable) + and hasattr(arg.value, "__torch_function__") + ): + return ConstantVariable.create(True) + else: + return ConstantVariable.create(False) + + @register( + torch.is_floating_point, + torch.is_complex, + ) + def handle_is_floating_point(self, tx: "InstructionTranslator", input): + input_arg = input + if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None: + if self.value is torch.is_floating_point: + return ConstantVariable.create(input_arg.dtype.is_floating_point) + elif self.value is torch.is_complex: + return ConstantVariable.create(input_arg.dtype.is_complex) + else: + raise AssertionError(f"calling {self.value}") + + @register(torch.numel) + def handle_numel(self, tx: "InstructionTranslator", input): + if isinstance(input, TensorVariable) and input.valid_size(): + return ConstantVariable.create(product(input.size)) + elif isinstance(input, TensorVariable): + # Workaround dynamic shapes issue + return input.call_method(tx, "numel", [], {}) + + @register(torch.compile) + def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 1: + # torch.compile is a no-op in dynamo + return args[0] + + unimplemented("torch.compile is used as a decorator in the compiled frame") + + @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) + def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): + assert isinstance(input, TensorVariable) + return input.call_method(tx, "size", [], {}) + + @register( + torch.nn.modules.utils._single, + torch.nn.modules.utils._pair, + torch.nn.modules.utils._triple, + torch.nn.modules.utils._quadruple, + torch.nn.modules.utils._ntuple, + ) + def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs): + return self._call_ntuple(tx, args, kwargs) + + @register(torch.is_grad_enabled) + def handle_is_grad_enabled(self, tx): + install_guard(GradModeVariable._guards_singleton) + return ConstantVariable.create(torch.is_grad_enabled()) + + @register(torch.use_deterministic_algorithms) + def handle_use_deterministic_algorithms( + self, tx: "InstructionTranslator", mode, 
warn_only=False + ): + if warn_only and warn_only.as_python_constant(): + unimplemented("torch.use_deterministic_algorithms(warn_only=True)") + return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) + + @register(torch.are_deterministic_algorithms_enabled) + def handle_are_deterministic_algorithms_enabled(self, tx): + install_guard(DeterministicAlgorithmsVariable._guards_singleton) + return ConstantVariable.create(torch.are_deterministic_algorithms_enabled()) + + @register(torch._C._is_torch_function_enabled) + def handle_is_torch_function_enabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + # see comment on SymbolicTorchFunctionState class as to why + # this is not a bug + return ConstantVariable.create( + tx.symbolic_torch_function_state.torch_function_subclass_enabled + ) + + @register(torch._C._is_torch_function_all_disabled) + def handle_is_torch_function_all_disabled(self, tx): + install_guard(TorchFunctionDisableVariable._guards_singleton) + return ConstantVariable.create( + not tx.symbolic_torch_function_state.torch_function_mode_enabled + ) + + @register( + torch.overrides.has_torch_function, + torch.overrides.has_torch_function_variadic, + torch.overrides.has_torch_function_unary, + ) + def handle_has_torch_function(self, tx: "InstructionTranslator", *args): + elems = ( + args[0].unpack_var_sequence(tx) + if len(args) == 1 and isinstance(args[0], TupleVariable) + else args + ) + return ConstantVariable.create( + any(has_torch_function(x) for x in elems), + ) + + @register( + *dict.fromkeys( # remove duplicates + device_interface.stream + for _, device_interface in get_registered_device_interfaces() + ) + ) + def handle_device_interface_stream(self, tx: "InstructionTranslator", stream): + return StreamContextVariable.create(tx, stream) + + @register(torch.from_numpy) + def handle_from_numpy(self, tx: "InstructionTranslator", *args): + if not config.trace_numpy: + unimplemented("torch.from_numpy. config.trace_numpy is False") + if not np: + unimplemented("torch.from_numpy. 
NumPy is not available") + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.as_tensor, + *proxy_args_kwargs(args, {}), + ), + example_value=None, + ) + + @register(torch.jit.annotate) + def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value): + return the_value + + @register(torch.backends.cudnn.is_acceptable) + def handle_cudnn_is_acceptable( + self, tx: "InstructionTranslator", tensor, *extra + ): + # is_acceptable(tensor) returns true if + # (a) tensor dtype/device are supported by cudnn + # (b) cudnn is available + # (c) some initialization has completed + # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version) + assert not extra, "Expect 1 input to cudnn.is_acceptable" + assert isinstance(tensor, TensorVariable), ( + "Expect input to cudnn.is_acceptable to be a tensor" + ) + tensor_inp = torch.tensor(0, dtype=tensor.dtype, device=tensor.device) + return ConstantVariable.create( + torch.backends.cudnn.is_acceptable(tensor_inp) + ) + + @register(torch.utils.hooks.BackwardHook) + def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs): + return variables.BackwardHookVariable.create(tx, *args, **kwargs) + + @register(torch.nn.Parameter) + def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs): + return self.call_nn_parameter(tx, *args, **kwargs) + + @register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int) + def handle_sym_size(self_, tx, self, dim=None): + # we see this when retracing already traced code + if dim is not None: + return self.call_method(tx, "size", [dim], {}) + + @register(torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int) + def handle_sym_stride(self_, tx, self, dim=None): + if dim is not None: + return self.call_method(tx, "stride", [dim], {}) + + @register(torch.addcdiv) + def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 3 and "value" in kwargs and len(kwargs) == 1: + # decompose addcdiv into constituent ops, prevents a graph break due to converting + # value to a scalar + result = TorchInGraphFunctionVariable(torch.div).call_function( + tx, [*args[1:]], {} + ) + result = TorchInGraphFunctionVariable(torch.mul).call_function( + tx, [result, kwargs["value"]], {} + ) + return TorchInGraphFunctionVariable(torch.add).call_function( + tx, [args[0], result], {} + ) + + @register(torch.full) + def handle_full(self, tx, size, fill_value, **kwargs): + if isinstance(fill_value, TensorVariable): + result = TorchInGraphFunctionVariable( + torch.ops.aten._local_scalar_dense + ).call_function(tx, [fill_value], {}) + return TorchInGraphFunctionVariable(torch.full).call_function( + tx, [size, result], kwargs + ) + + @register(torch._foreach_lerp_) + def handle_inplace_foreach_lerp_scalar( + _, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_lerp_inplace), + args, + kwargs, + ) + + @register(torch._foreach_pow) + def handle_foreach_pow_scalar(_, tx: "InstructionTranslator", *args, **kwargs): + # In eager it's more performant to call item() from within the C op implementation + # in compile, it's more performant to not graph break. 
+ if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.foreach_pow_scalar), + args, + kwargs, + ) + + @register(torch._assert) + def handle_assert(self, tx: "InstructionTranslator", condition, message): + if (condition.is_python_constant() and condition.as_python_constant()) or ( + isinstance(condition, variables.SymNodeVariable) + and condition.evaluate_expr() + ): + return ConstantVariable(None) + + @register(SDPAParams) + def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): + return wrap_fx_proxy( + tx, + proxy=tx.output.create_proxy( + "call_function", + torch._C._SDPAParams, + *proxy_args_kwargs(args, kwargs), + ), + param_vars=args, + ) + + if DistributedVariable.is_available(): + from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + _resolve_group_name_by_ranks_and_tag, + get_process_group_ranks, + ) + from torch.distributed.tensor import DTensor + + @register( + _get_group_size_by_name, + _get_group_tag, + _rank_not_in_group, + get_process_group_ranks, + _resolve_group_name_by_ranks_and_tag, + ) + def handle_constant_processgroup_functions( + self, tx: "InstructionTranslator", *args + ): + # because the input is a "ProcessGroupVariable", we'll be guarding on its + # ID_MATCH based on how it was constructed. + + # We desugar it at trace-time into ranks by directly calling util + # bake the result into the trace + if len(args) == 1: + # group or group name + assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable)) + elif len(args) == 2: + # ranks + tag + assert isinstance(args[0], ListVariable) and isinstance( + args[1], ConstantVariable + ) + else: + raise AssertionError( + f"Invalid group value ({args}) for constant pg " + f"function {self.value}" + ) + args_as_value = [arg.as_python_constant() for arg in args] + invocation_result = self.value(*args_as_value) + + # Note - while we *could* cook up sources around invocations, like a FunctionSource + # the space of invoking functions in the middle of the guard chain is very iffy. As such, + # guard propagation via options is the best we can do. 
+ return VariableTracker.build(tx, invocation_result) + + @register(DTensor.from_local) + def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): + # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function + # and rewrite args to have only proxyable args, then insert call_function + args_as_value = [x.as_python_constant() for x in args[1:]] + kwargs_as_value = { + k: v.as_python_constant() + for k, v in kwargs.items() + if k not in ["shape", "stride"] + } + kwargs_to_be_proxied = { + k: kwargs[k] for k in ["shape", "stride"] if k in kwargs + } + + def fn_with_prim_types(x, shape=None, stride=None): + return self.value( + x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride + ) + + # attach the same function name for better debugging + fn_with_prim_types.__name__ = "prim " + self.value.__name__ + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_with_prim_types, + *proxy_args_kwargs( + [args[0]], + kwargs_to_be_proxied, + ), + ), + ) + + @register(torch.nested.nested_tensor) + def handle_nested_tensor( + self, + tx: "InstructionTranslator", + tensor_list=None, + *args, + layout=None, + **kwargs, + ): + from .lists import BaseListVariable + + if layout and layout.as_python_constant() == torch.strided: + unimplemented("torch.compile does not support strided NestedTensor") + if not isinstance(tensor_list, BaseListVariable): + unimplemented("nested_tensor with non-list input") + + @register(torch.nn.functional.one_hot) + def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) == 1 or ( + len(args) == 2 + and args[1].is_python_constant() + and args[1].as_python_constant() == -1 + ): + unimplemented( + "torch.nn.functional.one_hot with data-dependent output shape" + ) + + @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) + def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_size_oblivious( + expr.sym_num + ) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_true) + def handle_guard_or_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_false) + def handle_guard_or_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.statically_known_false) + def handle_statically_known_false(self, tx: 
"InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_false( + expr.sym_num + ) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_scalar) + def guard_scalar(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif isinstance(expr, ConstantVariable): + val = expr.value + else: + raise torch._dynamo.exc.Unsupported("branch not supported") + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_scalar(val) + ) + + @register(torch.fx.experimental.symbolic_shapes.statically_known_true) + def handle_statically_known_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.statically_known_true( + expr.sym_num + ) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.sym_and) + def handle_sym_and(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_and( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.sym_or) + def handle_sym_or(self, tx: "InstructionTranslator", *terms): + if all(isinstance(x, SymNodeVariable) for x in terms): + return SymNodeVariable.create( + tx, + torch.fx.experimental.symbolic_shapes.sym_or( + *(x.as_proxy() for x in terms) + ), + sym_num=None, + ) + + @register(torch.fx.experimental.symbolic_shapes.has_static_value) + def handle_has_static_value(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + val = expr.sym_num + elif isinstance(expr, ConstantVariable): + val = expr.value + else: + return + + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.has_static_value(val) + ) + + @register(torch._C._autograd._unsafe_set_version_counter) + def handle_unsafe_set_version_counter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + from ..tensor_version_op import _unsafe_set_version_counter + + return TorchInGraphFunctionVariable( + _unsafe_set_version_counter + ).call_function(tx, [*args], kwargs) + + @register(torch._C._functorch.peek_interpreter_stack) + def handle_functorch_peek_interpreter_stack( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Wrap C++ interpreter (torch._C._functorch.CInterpreter) as UserDefinedObjectVariable, + # but Python interpreter (torch._functorch.pyfunctorch.FuncTorchInterpreter) as FuncTorchInterpreterVariable. 
+ return UserDefinedObjectVariable( + torch._C._functorch.peek_interpreter_stack() + ) + + @register(torch._functorch.pyfunctorch.coerce_cinterpreter) + def handle_functorch_pyfunctorch_coerce_cinterpreter( + self, tx: "InstructionTranslator", *args, **kwargs + ): + cinterpreter = args[0].value + return FuncTorchInterpreterVariable( + torch._functorch.pyfunctorch.coerce_cinterpreter(cinterpreter) + ) + + @register(torch.tensor) + def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs): + def check_any_unspec(x): + # NB: This includes UnspecializedPythonVariable + if isinstance(x, (TensorVariable, SymNodeVariable)): + return True + elif isinstance(x, (ListVariable, TupleVariable)): + return any(check_any_unspec(y) for y in x.items) + # TODO: there maybe other recursive structures you need to + # check + else: + return False + + data_arg = None + if args: + data_arg = args[0] + elif "data" in kwargs: + data_arg = kwargs["data"] + + # NB: OK to pass torch.tensor(tensor), this will trace fine + if not isinstance(data_arg, TensorVariable) and check_any_unspec(data_arg): + # This is slower and less canonical, so only use it if we + # have to + return TorchInGraphFunctionVariable(torch._refs.tensor).call_function( + tx, [*args], kwargs + ) + + @register(torch._C._pop_torch_function_stack) + def handle_pop_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + if not tx.symbolic_torch_function_state.mode_stack: + raise unimplemented("Popping from an empty torch function mode stack") + TorchFunctionModeStackVariable.register_mutation(tx) + return tx.symbolic_torch_function_state.pop_torch_function_mode() + + @register(torch._C._push_on_torch_function_stack) + def handle_push_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert len(args) == 1 and not kwargs + TorchFunctionModeStackVariable.register_mutation(tx) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) + return ConstantVariable.create(None) + + @register(torch._C._len_torch_function_stack) + def handle_len_torch_function( + self, tx: "InstructionTranslator", *args, **kwargs + ): + assert not args and not kwargs + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) + + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] + + @register(torch.set_default_device) + def handle_set_default_device( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # Today this is inserted in the graph, once TF mode + # handling is complete, we can trace the device context + # like any other TF mode and remove this special handling + # Insert the TF mode representing the device context at + # the bottom of the stack to match the eager semantics + # Running the graph will ensure that the DeviceContext mode is + # at the correct position in the stack + TorchFunctionModeStackVariable.register_mutation(tx) + if args[0].is_python_constant() and args[0].as_python_constant() is None: + TorchFunctionModeStackVariable.clear_default_device(tx) + else: + TorchFunctionModeStackVariable.register_device_context_insertion(tx) + + return ConstantVariable.create(None) + + return handlers + + def call_function( + self, + tx: 
"InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . import ConstantVariable, SymNodeVariable, TensorVariable + from .builder import wrap_fx_proxy + + if self.nonstrict_traceable: + import torch._higher_order_ops.flat_apply as flat_apply + from torch._higher_order_ops.flat_apply import ( + func_to_graphable, + is_graphable_type, + ) + from torch._subclasses.fake_tensor import fake_tensor_tls + from torch.utils._pytree import tree_flatten + + from .base import AsPythonConstantNotImplementedError + + # 1. Convert `args, kwargs` into pytree-flattened proxy forms. + # + # Rather than reconstructing `args, kwargs` into python objects and + # then tree_flatten them, we just let Dynamo symbolically interpret + # `tree_flatten((args, kwargs))`. This saves us from having to + # worry about the reconstruction logic, side effects, and guards. + packed_input_vt = TupleVariable.build( + tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) + ) + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( + tx, [packed_input_vt], {} + ) + assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 + flat_args_vts, input_spec_vt = out_vt.items + assert isinstance(flat_args_vts, ListVariable) + + # Handle the case when the input contains a non-graphable type. + for flat_arg_vt in flat_args_vts.items: + arg_type = flat_arg_vt.python_type() + if not is_graphable_type(arg_type): + type_name = flat_arg_vt.python_type().__qualname__ + unimplemented( + f""" +For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree: + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` +""" # NOQA: B950 + ) + + # Since we checked with `is_graphable` above, `as_proxy` on the + # flat_arg VT should always work. + proxified_flat_args = [ + flat_arg_vt.as_proxy() for flat_arg_vt in flat_args_vts.items + ] + + # The downstream `flat_apply` call requires the input spec; however, + # the spec not a graphable type, so we still have to reconstruct it + # into a python object, and store it as a constant attribute on the + # fx graph. + try: + input_spec = input_spec_vt.as_python_constant() + except AsPythonConstantNotImplementedError as e: + typ = e.vt.python_type() + type_name = typ.__qualname__ + import torch.utils._pytree as pytree + + if pytree.is_constant_class(typ): + unimplemented( + f""" +You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. + +Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. + """ # NOQA: B950 + ) + else: + unimplemented( + f""" +You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context. 
+ +Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}> + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` + +If the above doesn't work, please submit an issue to GitHub. +""" # NOQA: B950 +                    ) + +            fn = self.value + +            def patched_fn(*args, **kwargs): +                # This enables reads to global/captured tensors, and we'll just +                # treat them as constants in the graph. Note that after +                # AOTDispatcher, this logic would disappear. +                old_val = fake_tensor_tls.allow_non_fake_inputs_override +                fake_tensor_tls.allow_non_fake_inputs_override = True +                try: +                    res = fn(*args, **kwargs) +                finally:  # reset even when `fn` raises +                    fake_tensor_tls.allow_non_fake_inputs_override = old_val +                return res + +            # `flat_apply` wants a TreeSpec for the function input. +            _, f_spec = func_to_graphable(patched_fn) + +            # TreeSpec isn't graphable, so we register the function and input +            # specs as attributes on the graph module. +            f_spec_proxy = tx.output.register_static_attr_and_return_proxy( +                f"{fn.__name__}_spec", f_spec +            ) +            input_spec_proxy = tx.output.register_static_attr_and_return_proxy( +                fn.__name__ + "_input_spec", input_spec +            ) +            f_spec_proxy.node.type = type(f_spec) +            input_spec_proxy.node.type = type(input_spec) +            all_args = (f_spec_proxy, input_spec_proxy, *proxified_flat_args) + +            # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate +            # the call and wrap output into a VariableTracker. +            proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) +            out_vt = wrap_fx_proxy(tx, proxy) +            # TODO support more output types +            # Q: flat_apply will likely pytree_flatten the output for this, then +            # how do we intercept the output before flatten, and wrap those? +            # - Maybe we can have `flat_apply` return the output spec, so that +            # Dynamo can unflatten and wrap the result. + +            return out_vt + +        if self.torch_function_override_enabled(tx, args, kwargs): +            return dispatch_torch_function(tx, self, args, kwargs) + +        if self.can_constant_fold_through() and check_unspec_or_constant_args( +            args, kwargs +        ): +            # constant fold functions need to be guarded. 
+ if self.value in constant_fold_functions_need_guards: + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) + # constant fold + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + + if self.is_tensor_method(): + name = self.value.__name__ + # Guard against inplace view op on input tensor (not supported) + if args and isinstance(args[0], variables.TensorVariable): + tensor_var = args[0] + # Check if input tensor and inplace_view op specifically + if tensor_var.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view + in getattr(fn, fn.overloads()[0]).tags + ): + unimplemented_v2( + gb_type="Inplace op on input tensor", + context="", + explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + "Ensure you do not modify input tensor in place.", + ], + ) + return self.call_tensor_method(tx, args, kwargs) + + special_handler = self._get_handlers().get(self.value) + if special_handler: + result = special_handler(self, tx, *args, **kwargs) + if result: + return result + + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ +Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. +To support this behavior, we need to allow const-propping tensors that store symint data. +For now, dynamo will explicitly graph break when it encounters user code with this behavior. +""" + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op + ): + fn_ = getattr(torch, torch_sym_op) + + # TODO for each of the following check on `out=` or `requires_grad=` + # variant torch ops, the original function could come from a user + # defined `@allow_in_graph` function as well, which doesn't have the + # same semantics as the torch ops. + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. 
+ fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + # Handle e.g., `torch.ones(10, requires_grad=True)` + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. +Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" + ) + + # Handle e.g., `torch.add(a, b, out=result)` + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and torch.sigmoid + # mutate the tensors in the out field. + # + # However, it's non-trivial to update all references of the old + # `TensorVariable` to the new one returned (`result_var`), so we + # take the conservative approach to graph break on size changes, and + # assume other cases can fall through soundly. + # + # Note that although these tensor variablels would hold different + # proxies, the in-place mutation semantics is preserved in the FX + # graph, so we won't have correctness issues. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor._size + != result_tensor._size # we actually want to compare None values here + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if fake_out_shape != fake_tensor.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. 
+ if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where some of the output tensors were non-contiguous" + ) + else: + unimplemented(f"out variant of {type(kwargs['out'])}") + + return tensor_variable + + def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): + """inline behavior of torch.nn.modules.utils._ntuple""" + if self.value is torch.nn.modules.utils._ntuple: + count = args[0].as_python_constant() + else: + count = self.value.__closure__[0].cell_contents + assert isinstance(count, int) + assert not kwargs + + def handle_ntuple(value): + if value.has_unpack_var_sequence(tx): + return variables.TupleVariable( + list(value.unpack_var_sequence(tx)), + ) + elif value.is_python_constant(): + # constant prop through it + return variables.ConstantVariable.create( + torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), + ) + else: + unimplemented(f"torch.nn.modules.utils._ntuple({value})") + + if self.value is torch.nn.modules.utils._ntuple: + return variables.LambdaVariable(handle_ntuple) + else: + return handle_ntuple(args[0]) + + @classmethod + def call_nn_parameter(cls, tx, data=None, requires_grad=True): + """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented("nn parameter construction not supported with export") + + if isinstance(requires_grad, variables.VariableTracker): + try: + requires_grad = requires_grad.as_python_constant() + except NotImplementedError: + unimplemented("Parameter(requires_grad=...) 
not constant") + + if not isinstance(data, variables.TensorVariable): + unimplemented(f"Parameter(data={data}) not implemented") + + # this results in cleaner graphs, but only works for inputs + if data.source: + return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + + if isinstance( + data, TensorWithTFOverrideVariable + ) or is_traceable_wrapper_subclass_type(data.class_type): + unimplemented("Parameter constructor with tensor subclass NYI") + + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") + + try: + shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) + dtype = data.var_getattr(tx, "dtype").as_python_constant() + device = data.var_getattr(tx, "device").as_python_constant() + except NotImplementedError as e: + unimplemented(f"Parameter not python_constant: {e}") + + placeholder = tx.output.synthetic_graph_input( + new_parameter_placeholder, [shape, dtype, device, requires_grad] + ) + if data.requires_grad: + data = data.call_method(tx, "detach", [], {}) + + from .builder import wrap_fx_proxy + + result = wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", + tracable_create_parameter, + (data.as_proxy(), placeholder.as_proxy()), + {}, + ), + # In reconstruct() we should use the original parameter. The one + # returned by the graph will be an alias. + source=placeholder.source, + ) + assert isinstance(result, variables.TensorVariable) + result.class_type = torch.nn.Parameter + + # TODO(jansel/bdhirsh) - There is some issue with + # tracable_create_paramter. It does not seem to use the right + # grad_enabled. Since this is parameter, we can just override the + # has_grad_fn field to False to workaround the issue. + result.has_grad_fn = False + + # TODO(jansel): if the new param falls out of scope, currently it won't get freed until + # the end of the graph. We should fix this. + return result + + @staticmethod + def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad): + # Alternate version if we have a .source + varname = tx.output.new_var() + + # construct the nn.Parameter before the graph save it to varname + cg = PyCodegen(tx) + cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter")) + cg(data.source) + cg(variables.ConstantVariable(requires_grad)) + cg.call_function(2, False) + cg.store(varname) + tx.output.pregraph_bytecode.extend(cg.get_instructions()) + + data_node = data.as_proxy().node + if data_node.op not in ("placeholder", "get_attr"): + unimplemented( + "Unexpected type of data placeholder op for parameter construction" + ) + + # add the newly constructed nn.Parameter as a graph input + source = SyntheticLocalSource(varname) + example_value = torch.nn.Parameter( + tx.output.example_value_from_input_node(data.as_proxy().node) + ) + result = VariableTracker.build(tx, example_value, source) + # Realize the VT because we will delete the guards on it in the next line. + result = result.realize() + # No need to guard on this since we already guarded on `data`. 
+ # These guards would fail since varname doesn't exist until after the function starts + TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( + source + ) + return result + + def call_tensor_method(self, tx, args, kwargs): + return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs) + + def is_tensor_method(self): + from ..trace_rules import get_tensor_method + + return ( + inspect.ismethoddescriptor(self.get_function()) + and hasattr(self.get_function(), "__objclass__") + and self.get_function().__objclass__ == torch._C.TensorBase + ) or self.get_function() in get_tensor_method() + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) + + +class DispatchKeySetVariable(BaseTorchVariable): + """represents torch.DispatchKeySet""" + + @staticmethod + def create(value, **kwargs): + return DispatchKeySetVariable(value, **kwargs) + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.DISPATCH_KEY_SET_MATCH)) + return cls(value, source=source) + + def is_constant_fold_method(self, name): + return name in ["has"] + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if self.is_constant_fold_method(name) and check_unspec_or_constant_args( + args, kwargs + ): + method = getattr(self.value, name) + return variables.ConstantVariable.create( + method( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif name == "highestPriorityTypeId": + return variables.EnumVariable(self.value.highestPriorityTypeId()) + return super().call_method(tx, name, args, kwargs) + + +class FuncTorchInterpreterVariable(BaseTorchVariable): + """represents torch._functorch.pyfunctorch.FuncTorchInterpreter""" + + @classmethod + def create_with_source(cls, value, source): + install_guard(source.make_guard(GuardBuilder.ID_MATCH)) + return cls(value, source=source) + + def call_method( + self, + tx, + name, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "VariableTracker": + if name == "key": + return variables.EnumVariable(self.value.key()) + elif name == "process": + return tx.inline_user_function_return( + variables.UserFunctionVariable(self.value.process.__func__), + [self] + args, + kwargs, + ) + elif name in ["level", "batch_size", "randomness"]: + return variables.ConstantVariable.create(getattr(self.value, name)()) + elif name == "lower": + assert not args and not kwargs + return variables.TemporarilyPopInterpreterStackCtxManagerVariable.create( + tx, None + ) + return super().call_method(tx, name, args, kwargs) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/torch_function.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/torch_function.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffca573b9cfc9fbe328ed2b59258b6d7756ec32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/torch_function.py @@ -0,0 +1,763 @@ +# mypy: ignore-errors + +"""TorchDynamo support for __torch_function__ tensor subclasses. + +This module implements support for tensor subclasses with __torch_function__ overrides. 
+A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles +dispatching __torch_function__ on attribute accesses, method calls, and torch API calls. + +Unsupported features: +- Triggering __torch_function__ on tensor subclass non-tensor custom attributes +- Graph breaking on mutating guardable tensor properties within a __torch_function__ context + (can cause excessive recompiles in certain cases) +- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor + argument positions of Torch API calls + +Supported features: +- Static method implementations of __torch_function__ on custom objects (triggers on torch + API calls with the object as any argument) +- Triggering __torch_function__ on torch API calls with tensor subclass arguments +- __torch_function__ calls on base tensor attribute access and method calls for tensor + subclass instances +- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object + arguments in any position + +See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w +for more information on the design. +""" + +import collections +import contextlib +import functools +import inspect +import operator +from typing import TYPE_CHECKING + +import torch._C +import torch.utils._pytree as pytree +from torch._guards import Source +from torch.overrides import ( + _get_overloaded_args, + BaseTorchFunctionMode, + get_default_nowrap_functions, + TorchFunctionMode, +) +from torch.utils._device import DeviceContext + +from .. import graph_break_hints +from ..exc import unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode +from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) +from .base import VariableTracker +from .constant import ConstantVariable +from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable +from .lazy import LazyVariableTracker +from .lists import TupleVariable +from .tensor import TensorSubclassVariable, TensorVariable +from .user_defined import UserDefinedObjectVariable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +bin_ops = [ + operator.pow, + operator.mul, + operator.matmul, + operator.floordiv, + operator.truediv, + operator.mod, + operator.add, + operator.lt, + operator.gt, + operator.ge, + operator.le, + operator.ne, + operator.eq, + operator.sub, + operator.ipow, + operator.imul, + operator.imatmul, + operator.ifloordiv, + operator.itruediv, + operator.imod, + operator.iadd, + operator.isub, +] + +bin_int_ops = [ + operator.and_, + operator.or_, + operator.xor, + operator.iand, + operator.ixor, + operator.ior, +] + +un_int_ops = [operator.invert] + +tensor_and_int_ops = [ + operator.lshift, + operator.rshift, + operator.ilshift, + operator.irshift, + operator.getitem, +] + +un_ops = [ + operator.abs, + operator.pos, + operator.neg, + operator.not_, # Note: this has a local scalar dense call + operator.length_hint, +] + +BUILTIN_TO_TENSOR_FN_MAP = {} + +# These functions represent the r* versions of the above ops +# Basically, if __add__(1, Tensor) is called, it is translated +# to __radd__(Tensor, 1). 
+# In the builtin var, we check if there is a tensor in the first args position, +# if not, we swap the args and use the r* version of the op. +BUILTIN_TO_TENSOR_RFN_MAP = {} + + +def populate_builtin_to_tensor_fn_map(): + global BUILTIN_TO_TENSOR_FN_MAP + + most_recent_func = None + + class GetMethodMode(BaseTorchFunctionMode): + """ + Mode to extract the correct methods from torch function invocations + (Used to get the correct torch.Tensor methods from builtins) + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + nonlocal most_recent_func + most_recent_func = func + return func(*args, **kwargs) + + inp0 = torch.ones(1) + inp1 = torch.ones(1) + inp0_int = torch.ones(1, dtype=torch.int32) + inp1_int = torch.ones(1, dtype=torch.int32) + with GetMethodMode(): + setups_and_oplists = [ + (lambda o: o(inp0), un_ops), + (lambda o: o(inp0_int), un_int_ops), + (lambda o: o(inp0, inp1), bin_ops), + (lambda o: o(inp0_int, inp1_int), bin_int_ops), + (lambda o: o(inp0_int, 0), tensor_and_int_ops), + ] + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + setup_fn(op) + assert most_recent_func is not None + BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func + + # gather the reverse functions + rsetups_and_oplists = [ + ( + lambda o: o(1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: o(1, inp1_int), bin_int_ops), + (lambda o: o(0, inp0_int), tensor_and_int_ops), + ] + + rskips = {operator.matmul, operator.imatmul, operator.getitem} + for setup_fn, op_list in rsetups_and_oplists: + for op in op_list: + if op in rskips: + continue + setup_fn(op) + assert most_recent_func is not None + if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: + BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func + + +populate_builtin_to_tensor_fn_map() + +banned_attrs = [ + fn.__self__.__name__ + for fn in get_default_nowrap_functions() + if is_tensor_base_attr_getter(fn) +] + + +@functools.cache +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these 
and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: collections.deque[TorchFunctionModeVariable] = ( + collections.deque() + ) + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) + + +class TorchFunctionModeStackVariable(VariableTracker): + """Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation""" + + # singleton value representing the global torch function mode stack + # singleton (it exists in C++) + stack_value_singleton = object() + + # offset is used to track if we have inserted/removed a + # device context which is always placed at the bottom of the stack + # if a device context is inserted, the graph will run this mutation + # so when we want to reconstruct any other modes on the stack + # their indices should be shifted right by 1 (+1) + # Conversely, if there was a device context on the stack, and the graph + # mutates the stack to remove that context (set default device to None) + # each of the indices of other modes should be shifted left by 1 (-1) + offset = 0 + + def __init__(self, source, symbolic_stack): + self.source = source + self.symbolic_stack = symbolic_stack + + @classmethod + def reset(cls): + cls.offset = 0 + + @classmethod + def register_mutation(cls, tx: "InstructionTranslator"): + if cls.stack_value_singleton not in tx.output.side_effects: + var = cls( + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, + ) + tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) + tx.output.side_effects.mutation(var) + + @classmethod + def register_device_context_insertion(cls, tx: "InstructionTranslator"): + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + return + else: + cls.offset += 1 + stack.insert( + 0, + TorchFunctionModeVariable( + None, 
source=TorchFunctionModeStackSource(-cls.offset) + ), + ) + + @classmethod + def clear_default_device(cls, tx: "InstructionTranslator"): + stack = tx.symbolic_torch_function_state.mode_stack + if stack and cls.is_device_context(stack[0]): + stack.popleft() + cls.offset -= 1 + + @staticmethod + def is_device_context(var): + return isinstance(var.value, DeviceContext) or var.value is None + + @classmethod + def get_mode_index(cls, ind): + return ind + cls.offset + + +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the function across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ + ) + + def __init__(self, value, source=None, **kwargs): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source + + def reconstruct(self, codegen: "PyCodegen"): + # This shouldn't be called unless we have a source + assert self.source + self.source.reconstruct(codegen) + + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + get_torch_function_fn(tx, self), + fn, + types, + args, + kwargs, + ) + + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen"): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False + + +def _get_all_args(args, kwargs): + return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) + + +def _flatten_vts(vts): + from collections import deque + + from .dicts import ConstDictVariable + from .lists import ListVariable + + vts = deque(vts) + output = [] + + while vts: + vt = vts.pop() + + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): + vt.realize() + + if vt.is_realized(): + if isinstance(vt, ListVariable): + vts.extend(vt.items) + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + + 
output.append(vt) + + return output + + +def _get_subclass_type(var): + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + return var.python_type() + + +def _get_subclass_type_var(tx: "InstructionTranslator", var): + assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) + if isinstance(var, TensorWithTFOverrideVariable): + return var.class_type_var(tx) + elif isinstance(var, UserDefinedObjectVariable): + source = var.source and TypeSource(var.source) + return VariableTracker.build(tx, var.python_type(), source) + + +def _is_attr_overridden(tx: "InstructionTranslator", var, name): + import torch + + overridden = False + try: + attr_val = inspect.getattr_static(var.python_type(), name) + overridden |= attr_val != getattr(torch.Tensor, name) + except AttributeError: + pass + + return overridden + + +def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): + # This emulates calling __torch_function__, which has a signature + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # + # Also notice the `cls` is not explicitly passed in the reference + # implementations: + # 1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374 # noqa: B950 + # 2. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/overrides.py#L1741-L1743 + tf_args = [ + fn, + types, + VariableTracker.build(tx, tuple(args)), + VariableTracker.build(tx, kwargs), + ] + return torch_function_var.call_function(tx, tf_args, {}) + + +def get_torch_function_fn(tx: "InstructionTranslator", vt): + # The underlying function could be a classmethod, staticmethod, regular + # function or a function with C-implementation. It doesn't matter as long as + # they satisfy the calling convention in `call_torch_function`. 
+ from .builtin import BuiltinVariable + + args = [vt, ConstantVariable("__torch_function__")] + func_vt = BuiltinVariable(getattr).call_function(tx, args, {}) + return func_vt + + +def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): + has_overridden_args = any( + has_torch_function(arg) for arg in _get_all_args(args, kwargs) + ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) + + +def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): + """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" + + all_args = _get_all_args(args, kwargs) + overloaded_args = _get_overloaded_args( + [arg for arg in all_args if has_torch_function(arg)], + _get_subclass_type, + ) + + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + + for arg in overloaded_args: + res = arg.call_torch_function( + tx, + fn, + types, + args, + kwargs, + ) + + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + + unimplemented_v2( + gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code", + context=f"{fn=}, {args=}, {kwargs=}", + explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + + +class TensorWithTFOverrideVariable(TensorVariable): + """ + Represents a tensor subclass instance with a __torch_function__ override. + """ + + @classmethod + def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): + # [Note: __torch_function__] coerce `tensor_var` into a + # TensorWithTFOverrideVariable. In eager, this is just a type change. + import torch + + # This simulates shallow-copying the tensor object. + kwargs = dict(tensor_var.__dict__) + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" + ) + var = cls(class_type=class_type, **kwargs) + var.install_global(tx) + return var + + def install_global(self, tx): + # stash the subclass type to rewrap an output tensor if needed + # this is needed because the actual type needs to be available + # each time the compiled artifact is run and outputs a wrapped tensor. 
+ if self.global_mangled_class_name(tx) not in tx.output.global_scope: + # Safe because global_mangled_class_name figures it out + tx.output.install_global_unsafe( + self.global_mangled_class_name(tx), self.class_type + ) + + def python_type(self): + return self.class_type + + def class_type_var(self, tx): + return TensorSubclassVariable( + self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) + ) + + def global_mangled_class_name(self, tx): + return get_safe_global_name( + tx, f"__subclass_{self.class_type.__name__}", self.class_type + ) + + def var_getattr(self, tx: "InstructionTranslator", name): + # [Note: __torch_function__] We currently only support attributes that are defined on + # base tensors, custom attribute accesses will graph break. + import torch + + # I think only `_base` is breaking because we aren't modelling view + # relationship perfectly in some scenarios. + if name in banned_attrs: + unimplemented_v2( + gb_type="Unsupported tensor subclass attribute access", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # Handle non-overridden attributes inherited from `torch.Tensor`. + attr_is_overridden = _is_attr_overridden(tx, self, name) + if ( + hasattr(torch.Tensor, name) + and not attr_is_overridden + and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) + ): + args, kwargs = [self], {} + if can_dispatch_torch_function(tx, args, kwargs): + if self.source: + install_guard( + AttrSource( + AttrSource(self.source, "__class__"), name + ).make_guard(GuardBuilder.FUNCTION_MATCH) + ) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + args, + kwargs, + ) + else: + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. + try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) + if isinstance(attr, types.FunctionType): + install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = UserMethodVariable(getter, self, source=getter_source) + return getter_var.call_function(tx, [], {}) + + elif isinstance(attr, classmethod): + return UserMethodVariable( + attr.__func__, self.class_type_var(tx), source=attr_source + ) + + elif attr_is_overridden: + unimplemented_v2( + gb_type="Unsupported tensor subclass overridden attribute access", + context=f"{name}", + explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes", + hints=[ + f"Avoid accessing {name} of tensor subclass in torch.compile region", + f"Renaming attribute `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + return super().var_getattr(tx, name) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + # NOTE this assumes `__torch_function__` isn't modified during tracing. 
+ if not hasattr(self, "torch_function_fn"): + self.torch_function_fn = get_torch_function_fn(tx, self) + + return call_torch_function( + tx, + self.torch_function_fn, + fn, + types, + args, + kwargs, + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # This code block implements inlining the __torch_function__ override + # of `call_method`. + tf_args = [self] + args + if can_dispatch_torch_function(tx, tf_args, kwargs): + import torch + + if _is_attr_overridden(tx, self, name): + unimplemented_v2( + gb_type="Tensor subclass overridden method call", + context=f"{name}", + explanation="`torch.compile` currently can't trace this", + hints=[ + f"Avoid calling {name} of tensor subclass in torch.compile region", + f"Renaming method `{name}` of type {self.class_type}", + *graph_break_hints.SUPPORTABLE, + ], + ) + + # [Note: __torch_function__] Currently we only support methods that are defined on tensor + # we will graph break in other cases this will need a bigger overhaul of extracting methods/comparing them for equality + # We've established with the above check that the method is not overridden, so we guard that the method is the same + # as the impl defined on tensor and retrieve it + if self.source: + source = AttrSource(AttrSource(self.source, "__class__"), name) + value = inspect.getattr_static(self.python_type(), name) + else: + source = None + value = getattr(torch.Tensor, name) + func_var = VariableTracker.build(tx, value, source) + return dispatch_torch_function(tx, func_var, tf_args, kwargs) + else: + return super().call_method(tx, name, args, kwargs) diff --git a/phivenv/Lib/site-packages/torch/_dynamo/variables/user_defined.py b/phivenv/Lib/site-packages/torch/_dynamo/variables/user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..8df5177fe120b8f1bce02eb3e87bf2989cf0e950 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_dynamo/variables/user_defined.py @@ -0,0 +1,1823 @@ +# mypy: ignore-errors + +""" +This module contains variable classes for handling user-defined objects in Dynamo's tracing system. + +The key classes are: +- UserDefinedVariable: Base class for representing custom Python objects +- UserDefinedClassVariable: Handles Python class objects/types +- UserDefinedObjectVariable: Fallback class for instance objects, with support for method calls, + attribute access, and other Python object behaviors. +- Specialized subclasses for common patterns: + - UserDefinedDictVariable: For dict subclasses + - UserDefinedTupleVariable: For tuple subclasses + - FrozenDataClassVariable: Special handling of frozen dataclasses + - MutableMappingVariable: For collections.abc.MutableMapping subclasses + +Dynamo specializes to VariableTracker subclasses like FrozenDataClassVariable if available; if no +subclass qualifies, it falls back to UserDefinedObjectVariable. + +These classes help Dynamo track and handle arbitrary Python objects during tracing, +maintaining proper semantics while enabling optimizations where possible. 
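+
+A hedged, illustrative sketch of what these trackers model during tracing:
+
+    class Config:                  # the class object -> UserDefinedClassVariable
+        def __init__(self):
+            self.scale = 2.0       # instances -> UserDefinedObjectVariable
+
+    def f(x):
+        cfg = Config()             # attribute reads/writes are resolved through
+        return x * cfg.scale       # var_getattr and side-effect tracking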
+""" + +import _collections +import builtins +import collections +import contextlib +import dataclasses +import enum +import functools +import inspect +import itertools +import random +import sys +import threading +import types +import warnings +import weakref +from typing import TYPE_CHECKING +from typing_extensions import is_typeddict + +import torch._dynamo.config +import torch.nn +from torch._guards import TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + +from .. import polyfills, variables +from ..bytecode_transformation import create_call_function +from ..create_parameter_op import do_not_convert_to_tracable_parameter +from ..exc import ( + handle_observed_exception, + ObservedAttributeError, + raise_observed_exception, + unimplemented, +) +from ..guards import GuardBuilder, install_guard +from ..source import ( + AttrSource, + CallFunctionNoArgsSource, + DataclassFieldsSource, + GetItemSource, + RandomValueSource, + TypeSource, + UnspecializedParamBufferSource, +) +from ..utils import ( + build_checkpoint_variable, + check_constant_args, + cmp_name_to_op_mapping, + dict_methods, + get_custom_getattr, + has_torch_function, + is_frozen_dataclass, + is_lru_cache_wrapped_function, + is_namedtuple_cls, + is_utils_checkpoint, + is_wrapper_or_member_descriptor, + istype, + list_methods, + namedtuple_fields, + object_has_getattribute, + proxy_args_kwargs, + tensortype_to_dtype, + tuple_methods, + unpatched_nn_module_getattr, +) +from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker +from .dicts import DefaultDictVariable +from .lists import SizeVariable + + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +try: + from torch.utils._cxx_pytree import PyTreeSpec +except ImportError: + PyTreeSpec = type(None) + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + + +def is_standard_setattr(val): + return val in (object.__setattr__, BaseException.__setattr__) + + +def is_standard_delattr(val): + return val in (object.__delattr__, BaseException.__delattr__) + + +def is_forbidden_context_manager(ctx): + f_ctxs = [] + + try: + from _pytest.python_api import RaisesContext + from _pytest.recwarn import WarningsChecker + + f_ctxs.append(RaisesContext) + f_ctxs.append(WarningsChecker) + except ImportError: + pass + + if m := sys.modules.get("torch.testing._internal.jit_utils"): + f_ctxs.append(m._AssertRaisesRegexWithHighlightContext) + + return ctx in f_ctxs + + +class UserDefinedVariable(VariableTracker): + value: object + + +class UserDefinedClassVariable(UserDefinedVariable): + value: type[object] + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def as_proxy(self): + return self.value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value})" + + @staticmethod + @functools.cache + def _constant_fold_classes(): + return { + torch.device, + torch.finfo, + torch.iinfo, + torch.Size, + } + + @staticmethod + @functools.cache + def _in_graph_classes(): + _in_graph_class_list = { + torch.Tensor, + torch.cuda.FloatTensor, + torch.cuda.DoubleTensor, + torch.cuda.HalfTensor, + torch.cuda.BFloat16Tensor, + torch.cuda.ByteTensor, + torch.cuda.CharTensor, + torch.cuda.IntTensor, + torch.cuda.ShortTensor, + torch.cuda.LongTensor, + torch.Stream, + torch.Event, + torch.cuda.Stream, + torch.cuda.Event, + 
torch.xpu.Stream, + torch.xpu.Event, + } + if hasattr(torch, "hpu"): + _in_graph_class_list.update( + { + torch.hpu.Stream, + torch.hpu.Event, + } + ) + + return set(tensortype_to_dtype.keys()) | _in_graph_class_list + + @staticmethod + @functools.cache + def supported_c_new_functions(): + exceptions = [ + getattr(builtins, name).__new__ + for name in dir(builtins) + if isinstance(getattr(builtins, name), type) + and issubclass(getattr(builtins, name), BaseException) + ] + return { + object.__new__, + dict.__new__, + tuple.__new__, + list.__new__, + }.union(exceptions) + + @staticmethod + def is_supported_new_method(value): + # TODO(anijain2305) - Extend this to support objects with default tp_new + # functions. + return value in UserDefinedClassVariable.supported_c_new_functions() + + def can_constant_fold_through(self): + return self.value in self._constant_fold_classes() + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from . import ConstantVariable, EnumVariable + + source = AttrSource(self.source, name) if self.source is not None else None + + if name == "__name__": + return ConstantVariable.create(self.value.__name__) + elif name == "__qualname__": + return ConstantVariable.create(self.value.__qualname__) + elif name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + + # Special handling of collections.OrderedDict.fromkeys() + # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with + # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method(). + # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys), + # and we need duplicate code to handle both cases. 
+ if ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + return super().var_getattr(tx, name) + + try: + obj = inspect.getattr_static(self.value, name) + except AttributeError: + if type(self.value) is type: + raise_observed_exception(AttributeError, tx) + else: + # Cannot reason about classes with a custom metaclass + # See: test_functions::test_getattr_metaclass + obj = None + + if name == "__new__" and UserDefinedClassVariable.is_supported_new_method(obj): + return super().var_getattr(tx, name) + + if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType): + return variables.GetAttrVariable(self, name, source=source) + + if isinstance(obj, staticmethod): + return VariableTracker.build(tx, obj.__get__(self.value), source) + elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) + return variables.UserMethodVariable(obj.__func__, self, source=source) + elif isinstance(obj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static(dict, "fromkeys") + # inspect.getattr_static(itertools.chain, "from_iterable") + func = obj.__get__(None, self.value) + return VariableTracker.build(tx, func, source) + elif source: + # __mro__ is a member in < 3.12, an attribute in >= 3.12 + if inspect.ismemberdescriptor(obj) or ( + sys.version_info >= (3, 12) and name == "__mro__" + ): + return VariableTracker.build(tx, obj.__get__(self.value), source) + + if ConstantVariable.is_literal(obj): + return ConstantVariable.create(obj) + elif isinstance(obj, enum.Enum): + return EnumVariable(obj) + elif name in getattr(self.value, "__dict__", {}) or ( + self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ): + if source: + return VariableTracker.build(tx, obj, source) + + if ( + source + and not inspect.ismethoddescriptor(obj) + and not is_wrapper_or_member_descriptor(obj) + ): + return VariableTracker.build(tx, obj, source) + + return super().var_getattr(tx, name) + + def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs): + """ + functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional ctor: weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', + label_smoothing=0.0 + + non functional loss call: input, target, optional_output + """ + from . 
import ConstantVariable + + def normalize_args( + weight=ConstantVariable.create(None), + size_average=ConstantVariable.create(None), + ignore_index=ConstantVariable.create(-100), + reduce=ConstantVariable.create(None), + reduction=ConstantVariable.create("mean"), + label_smoothing=ConstantVariable.create(0.0), + ): + return ( + weight, + size_average, + ignore_index, + reduce, + reduction, + label_smoothing, + ) + + ( + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ) = normalize_args(*args, **kwargs) + + def fake_cross_entropy_loss(input, target): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.nn.functional.cross_entropy, + *proxy_args_kwargs( + [ + input, + target, + weight, + size_average, + ignore_index, + reduce_arg, + reduction, + label_smoothing, + ], + {}, + ), + ), + ) + + return variables.LambdaVariable(fake_cross_entropy_loss) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + name == "__subclasses__" + and len(args) == 0 + and not kwargs + and "__subclasses__" not in self.value.__dict__ + ): + source = self.source + if self.source: + source = AttrSource(self.source, "__subclasses__") + source = CallFunctionNoArgsSource(source) + return VariableTracker.build(tx, self.value.__subclasses__(), source) + elif ( + self.value in {collections.OrderedDict, collections.defaultdict} + and name == "fromkeys" + ): + from .builtin import BuiltinVariable + + return BuiltinVariable.call_custom_dict_fromkeys( + tx, self.value, *args, **kwargs + ) + elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value == args[0].value) + elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): + return variables.ConstantVariable(self.value != args[0].value) + elif ( + name == "__new__" + and self.value is collections.OrderedDict + and isinstance(args[0], UserDefinedClassVariable) + and args[0].value is collections.OrderedDict + ): + assert len(args) == 1 + assert len(kwargs) == 0 + return variables.ConstDictVariable( + {}, collections.OrderedDict, mutation_type=ValueMutationNew() + ) + elif name == "__new__" and UserDefinedClassVariable.is_supported_new_method( + self.value.__new__ + ): + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from ..side_effects import SideEffects + from .builder import wrap_fx_proxy + + constant_args = check_constant_args(args, kwargs) + + if self.can_constant_fold_through() and constant_args: + # constant fold + return variables.ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + elif self.value is torch.nn.CrossEntropyLoss: + return self._call_cross_entropy_loss(tx, args, kwargs) + elif self.value is contextlib.nullcontext: + # import here to avoid circular dependency + from .ctx_manager import NullContextVariable + + return NullContextVariable() + elif self.value is collections.OrderedDict: + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [self, *args], + 
kwargs, + ) + elif ( + self.value is collections.defaultdict + and len(args) <= 1 + and DefaultDictVariable.is_supported_arg(args[0]) + ): + return DefaultDictVariable( + {}, + collections.defaultdict, + args[0], + mutation_type=ValueMutationNew(), + ) + elif is_typeddict(self.value): + if self.value.__optional_keys__: + unimplemented("TypedDict with optional keys not supported") + return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) + elif self.value is collections.deque: + maxlen = variables.ConstantVariable.create(None) + if not kwargs: + if len(args) == 0: + items = [] + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + maxlen = args[1] + else: + unimplemented("deque() with more than 2 arg not supported") + elif tuple(kwargs) == ("maxlen",): + maxlen = kwargs["maxlen"] + if len(args) == 0: + items = [] + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + else: + unimplemented("deque() with more than 1 arg not supported") + else: + unimplemented("deque() with invalid kwargs not supported") + return variables.lists.DequeVariable( + items, maxlen=maxlen, mutation_type=ValueMutationNew() + ) + elif self.value is weakref.ref: + if len(args) > 1: + callback = args[1] + else: + callback = variables.ConstantVariable.create(None) + return variables.WeakRefVariable(args[0], callback) + elif self.value is functools.partial: + if not args: + unimplemented("functools.partial malformed") + # The first arg, a callable (the ctor below will assert on types) + fn = args[0] + rest_args = args[1:] + # guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the + # args and keywords + return variables.functions.FunctoolsPartialVariable( + fn, args=rest_args, keywords=kwargs + ) + elif self.value is warnings.catch_warnings and not args: + return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs) + elif self.value is torch.cuda.device and not kwargs and len(args) == 1: + assert args[0].is_python_constant() + return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant()) + elif ( + issubclass(type(self.value), type) + and hasattr( + self.value, "__enter__" + ) # TODO(voz): These can invoke user code! + and hasattr( + self.value, "__exit__" + ) # TODO(voz): These can invoke user code! + and self.is_standard_new() + and SideEffects.cls_supports_mutation_side_effects(self.value) + and self.source + and not is_forbidden_context_manager(self.value) + ): + from .functions import ( + BaseUserFunctionVariable, + FunctionDecoratedByContextlibContextManagerVariable, + ) + + # graph break on any contextlib.* that it is not contextlib.contextmanager + # Some of the APIs below are not supported because they rely on features + # that Dynamo doesn't play well today (i.e. contextlib.suppress) + if self.value in ( + contextlib._AsyncGeneratorContextManager, + contextlib.closing, + contextlib.redirect_stdout, + contextlib.redirect_stderr, + contextlib.suppress, + contextlib.ExitStack, + contextlib.AsyncExitStack, + ): + # We are not changing the behavior of Dynamo as these function were + # already ignored on trace_rules.py before #136033 landed + unimplemented( + f"{self.value} not supported. 
This may be due to its use of " + "context-specific operations that are not supported in " + "Dynamo yet (i.e. Exception handling)" + ) + + if self.value is contextlib._GeneratorContextManager and isinstance( + args[0], BaseUserFunctionVariable + ): + if not torch._dynamo.config.enable_trace_contextlib: + unimplemented("contextlib.contextmanager") + # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable + # if the function is annotated with @contextlib.contextmanager + # This shouldn't be necessary once generator functions are fully + # supported in dynamo + args = [ + FunctionDecoratedByContextlibContextManagerVariable( + args[0], source=args[0].source + ) + ] + args[1:] + + cm_obj = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), + self, + args, + ) + cm_obj.call_method(tx, "__init__", args, kwargs) + return cm_obj + elif is_namedtuple_cls(self.value): + fields = namedtuple_fields(self.value) + # check if this a quasi-namedtuple or a real one + if self.value.__module__ == "torch.return_types": + assert len(args) == 1 + assert not kwargs + items = args[0].force_unpack_var_sequence(tx) + else: + field_defaults = self.value._field_defaults + + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + var_tracker_kwargs = {} + for field_name, var_tracker in zip(fields, items): + if var_tracker is None: + if field_name in kwargs: + field_var = kwargs[field_name] + else: + assert field_name in field_defaults + field_var = VariableTracker.build( + tx, field_defaults[field_name] + ) + var_tracker_kwargs[field_name] = field_var + + for name, value in var_tracker_kwargs.items(): + assert name in fields + items[fields.index(name)] = value + + assert all(x is not None for x in items) + + return variables.NamedTupleVariable(items, self.value) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) + elif is_frozen_dataclass(self.value) and self.is_standard_new(): + fields = dataclasses.fields(self.value) + fields_source = DataclassFieldsSource(self.source) + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + default_kwargs = {} + for ind, field, var_tracker in zip(itertools.count(), fields, items): + if var_tracker is None: + if field.name in kwargs: + var_tracker = kwargs[field.name] + else: + if not field.init: + continue + + if field.default is not dataclasses.MISSING: + var_tracker = VariableTracker.build( + tx, + field.default, + source=AttrSource( + GetItemSource(fields_source, ind), "default" + ), + ) + elif field.default_factory is not dataclasses.MISSING: + factory_fn = VariableTracker.build( + tx, field.default_factory + ) + var_tracker = factory_fn.call_function(tx, [], {}) + else: + # if we are subclass, the constructor could possibly + # be missing args + continue + + default_kwargs[field.name] = var_tracker + kwargs.update(default_kwargs) + + var = tx.output.side_effects.track_new_user_defined_object( + variables.BuiltinVariable(object), self, args + ) + var.call_method(tx, "__init__", args, kwargs) + return var + elif ( + self.value in self._in_graph_classes() + or is_traceable_wrapper_subclass_type(self.value) + ): + # torch.LongTensor cannot accept a list of FakeTensors. + # So we stack the list of FakeTensors instead. 
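+ # Rough example of the rewrite below (illustrative only):
+ # torch.LongTensor([t1, t2]) is traced approximately as
+ # torch.LongTensor(torch.stack([t1, t2])).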
+ if ( + np + and self.value in tensortype_to_dtype + and len(args) == 1 + and isinstance(args[0], variables.ListVariable) + and len(args[0].items) > 1 + and all(isinstance(x, variables.TensorVariable) for x in args[0].items) + ): + # Stack FakeTensor + stacked = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + torch.stack, + *proxy_args_kwargs(args, kwargs), + ), + ) + args = [stacked] + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + *proxy_args_kwargs(args, kwargs), + ), + ) + + return tensor_variable + elif self.value is random.Random: + if len(args) == 1 and isinstance(args[0], variables.ConstantVariable): + seed = args[0].value + else: + seed = None + random_object = random.Random(seed) + return RandomVariable(random_object) + elif ( + self.value is types.MappingProxyType + and len(args) == 1 + and isinstance(args[0], variables.ConstDictVariable) + ): + # types.MappingProxyType is a read-only proxy of the dict. If the + # original dict changes, the changes are reflected in proxy as well. + return variables.MappingProxyVariable(args[0]) + elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source: + with do_not_convert_to_tracable_parameter(): + return tx.inline_user_function_return( + VariableTracker.build( + tx, polyfills.instantiate_user_defined_class_object + ), + [self, *args], + kwargs, + ) + return super().call_function(tx, args, kwargs) + + def is_standard_new(self): + """Check for __new__ being overridden""" + new_fn = inspect.getattr_static(self.value, "__new__", None) + if isinstance(new_fn, staticmethod): + new_fn = new_fn.__func__ + return new_fn is object.__new__ + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.source: + source = AttrSource(self.source, name) + install_guard(source.make_guard(GuardBuilder.HASATTR)) + return variables.ConstantVariable(hasattr(self.value, name)) + return super().call_obj_hasattr(tx, name) + + def const_getattr(self, tx: "InstructionTranslator", name): + if name == "__name__": + return self.value.__name__ + return super().const_getattr(tx, name) + + +class UserDefinedExceptionClassVariable(UserDefinedClassVariable): + @property + def fn(self): + return self.value + + def python_type(self): + return self.value + + +class NO_SUCH_SUBOBJ: + pass + + +def call_random_fn(tx, fn, args, kwargs): + from .builder import VariableBuilder + + args = [x.as_python_constant() for x in args] + kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + random_call_index = len(tx.output.random_calls) + example_value = fn(*args, **kwargs) + source = RandomValueSource(random_call_index) + tx.output.random_calls.append((fn, args, kwargs)) + # TODO: arguably, this should route to wrap_symint/wrap_symfloat + # (currently hypothetical), but I'm not going to poke my hand in + # this nest for now + return VariableBuilder(tx, source).wrap_unspecialized_primitive(example_value) + + +class UserDefinedObjectVariable(UserDefinedVariable): + """ + Mostly objects of defined type. Catch-all for something where we only know the type. 
+ """ + + _nonvar_fields = { + "value", + "value_type", + "attrs_directly_modifed_on_dict", + *UserDefinedVariable._nonvar_fields, + } + + def __init__( + self, + value, + *, + value_type=None, + cls_source=None, + base_cls_vt=None, + init_args=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.value = value + self.value_type = value_type or type(value) + assert type(value) is self.value_type + # This is used with __new__, when the new object is sourceless but the user class can be sourceful. + self.cls_source = cls_source + if cls_source is None and self.source is not None: + self.cls_source = TypeSource(self.source) + + # These attributes are used to reconstruct the user defined object. The + # pseudo code looks like this. Builtin C __new__ do not support kwargs, + # so init_args is sufficient. + # obj = base_cls.__new__(user_cls, *args) + self.base_cls_vt = base_cls_vt + self.init_args = init_args + + # This records names of the attributes that were modified via instance + # `__dict__` directly, rather than the normal setattr path. + # + # TODO consider emulating `obj.__dict__` as a `ConstDictVariable` to get + # rid of these workarounds here and in `GetAttrVariable`. + self.attrs_directly_modifed_on_dict = set() + + def __str__(self) -> str: + inner = self.value_type.__name__ + if inner in [ + "builtin_function_or_method", + "getset_descriptor", + "method_descriptor", + "method", + ]: + inner = str(getattr(self.value, "__name__", None)) + return f"{self.__class__.__name__}({inner})" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + def is_underlying_vt_modified(self, side_effects): + return False + + def python_type(self): + return self.value_type + + def as_python_constant(self): + import torch.utils._pytree as pytree + + if pytree.is_constant_class(self.value_type): + if self.source is not None: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + return self.value + # TODO else try reconstructing the object by, e.g., leveraging side + # effects and `as_python_constant`. + return super().as_python_constant() + + def guard_as_python_constant(self): + if self.source: + install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) + return self.value + return super().guard_as_python_constant() + + def torch_function_check(self): + assert has_torch_function(self), ( + f"calling torch function on object without __torch_function__ {self}" + ) + + def get_torch_fn(self, tx): + self.torch_function_check() + from .torch_function import get_torch_function_fn + + return get_torch_function_fn(tx, self) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + self.torch_function_check() + + from .torch_function import call_torch_function + + return call_torch_function( + tx, + self.get_torch_fn(tx), + fn, + types, + args, + kwargs, + ) + + @staticmethod + @functools.cache + def _supported_random_functions(): + fns = { + random.random, + random.randint, + random.randrange, + random.uniform, + } + return fns + + def _maybe_get_baseclass_method(self, name): + if name not in getattr(self.value, "__dict__", {}): + try: + return inspect.getattr_static(type(self.value), name) + except AttributeError: + pass + return None + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + from . 
import ConstantVariable, UserMethodVariable + + method = self._maybe_get_baseclass_method(name) + if method is not None: + if method is object.__init__: + return ConstantVariable.create(None) + + if is_standard_setattr(method) or isinstance(self.value, threading.local): + return self.method_setattr_standard(tx, *args, **kwargs) + + if is_standard_delattr(method): + return self.method_setattr_standard( + tx, args[0], variables.DeletedVariable() + ) + + if method is object.__eq__ and len(args) == 1 and not kwargs: + other = args[0] + if not isinstance(other, UserDefinedObjectVariable): + return variables.ConstantVariable.create(NotImplemented) + + # TODO(anijain2305) - Identity checking should already be a part + # of the cmp_eq polyfill function. + return ConstantVariable.create(self.value is other.value) + + if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( + self.value, types.GeneratorType + ): + unimplemented("Generator as graph argument is not supported") + + # check for methods implemented in C++ + if isinstance(method, types.FunctionType): + source = ( + None + if self.source is None + else AttrSource(AttrSource(self.source, "__class__"), name) + ) + # TODO(jansel): add a guard to check for monkey patching? + from ..mutation_guard import unpatched_nn_module_init + + if method is torch.nn.Module.__init__: + method = unpatched_nn_module_init + return UserMethodVariable(method, self, source=source).call_function( + tx, args, kwargs + ) + + if method is list.__len__ and self.source and not (args or kwargs): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return ConstantVariable(len(self.value)) + + return super().call_method(tx, name, args, kwargs) + + def method_setattr_standard( + self, tx: "InstructionTranslator", name, value, directly_update_dict=False + ): + try: + name = name.as_python_constant() + except NotImplementedError: + unimplemented(f"non-const setattr name: {name}") + if not tx.output.side_effects.is_attribute_mutation(self): + unimplemented(f"setattr({self}, {name}, ...)") + + if directly_update_dict: + self.attrs_directly_modifed_on_dict.add(name) + else: + tmp = self.try_get_descritor_and_setter_py_func(name) + if tmp: + descriptor, setter = tmp + # Emulate + # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1371-L1452 + desc_source = None + func_source = None + if self.cls_source: + desc_source = self.get_source_by_walking_mro(name) + # use `type(...)` to ignore instance attrs. + func_source = AttrSource(TypeSource(desc_source), "__set__") + desc_var = VariableTracker.build(tx, descriptor, desc_source) + func_var = VariableTracker.build(tx, setter, func_source) + args = [desc_var, self, value] + return func_var.call_function(tx, args, {}) + # NOTE: else we assume the descriptor (if any) has a + # side-effect-free `__set__` as far as Dynamo tracing is concerned. + + # Emulate the standard setattr on instance dict. 
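+ # (Roughly: the write is recorded as a pending side effect and replayed as
+ # `obj.<name> = value` after the compiled graph runs, rather than mutating
+ # the real object eagerly during tracing.)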
+ tx.output.side_effects.store_attr(self, name, value) + return variables.ConstantVariable(None) + + def needs_slow_setattr(self): + return not is_standard_setattr( + inspect.getattr_static(self.value, "__setattr__", None) + ) and not isinstance(self.value, threading.local) + + def unpack_var_sequence(self, tx): + if ( + self.source + and self._maybe_get_baseclass_method("__iter__") is list.__iter__ + and self._maybe_get_baseclass_method("__len__") is list.__len__ + and self._maybe_get_baseclass_method("__getitem__") is list.__getitem__ + ): + install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) + return [ + variables.LazyVariableTracker.create( + self.value[k], + source=GetItemSource(self.source, k), + ) + for k in range(len(self.value)) + ] + return super().unpack_var_sequence(tx) + + def next_variable(self, tx): + return self.call_method(tx, "__next__", [], {}) + + def is_supported_random(self): + try: + return self.value in self._supported_random_functions() + except TypeError: + # TypeError: unhashable type + return False + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if ( + self.is_supported_random() + and all(k.is_python_constant() for k in args) + and all(v.is_python_constant() for v in kwargs.values()) + ): + return call_random_fn(tx, self.value, args, kwargs) + elif istype(self.value, types.MethodType): + func = self.value.__func__ + obj = self.value.__self__ + if ( + func is torch.utils._contextlib._DecoratorContextManager.clone + and variables.TorchCtxManagerClassVariable.is_matching_cls( + obj.__class__ + ) + and not (args or kwargs) + ): + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, args, kwargs) + + if ( + func is torch.autograd.grad_mode.inference_mode.clone + and obj.__class__ is torch.autograd.grad_mode.inference_mode + ): + # simulate the inference_mode.clone implementation + var = variables.ConstantVariable(obj.mode) + return variables.TorchCtxManagerClassVariable( + obj.__class__ + ).call_function(tx, [var], kwargs) + + if self.source is None: + unimplemented( + "Sourceless UserDefinedObjectVariable method not supported" + ) + func_src = AttrSource(self.source, "__func__") + func_var = VariableTracker.build(tx, func, func_src) + obj_src = AttrSource(self.source, "__self__") + obj_var = VariableTracker.build(tx, obj, obj_src) + return func_var.call_function(tx, [obj_var] + args, kwargs) + elif callable(self.value): + if self.source: + source = AttrSource(self.cls_source, "__call__") + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return self.call_method(tx, "__call__", args, kwargs) + + return super().call_function(tx, args, kwargs) + + def _check_for_getattr(self): + return get_custom_getattr(self.value) + + def _is_c_defined_property(self, subobj): + if not isinstance(subobj, property): + return False + + # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose + # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check + + # If we have a PyCFunction, we make an assumption that there is no side effect. 
+ return isinstance( + subobj.fget, types.BuiltinFunctionType + ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget) + + def _getattr_static(self, name): + subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) + + # In some cases, we have to do dynamic lookup because getattr_static is not enough. For example, threading.local + # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup. + # NOTE we assume the following descriptors are side-effect-free as far + # as Dynamo tracing is concerned. + if not object_has_getattribute(self.value) and ( + subobj is NO_SUCH_SUBOBJ # e.g., threading.local + or inspect.ismemberdescriptor(subobj) # e.g., __slots__ + or inspect.isgetsetdescriptor(subobj) # e.g., __dict__ + or self._is_c_defined_property(subobj) + ): + # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't + # want to call getattr because it can be user-overridden. + subobj = type(self.value).__getattribute__(self.value, name) + elif object_has_getattribute(self.value) and subobj is NO_SUCH_SUBOBJ: + # If the object has an overridden getattribute method, Dynamo has + # already tried tracing it, and encountered an AttributeError. We + # call getattr_static only when the __getattribute__ tracing fails + # (check var_getattr impl). So, it is safe here to raise the + # AttributeError. + raise AttributeError + + return subobj + + def should_skip_descriptor_setter(self, attr_name): + # Check if `attr_name` corresponds to a descriptor. + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if setter: + # Skip if `__set__` was traceable (no need to redo the side effect). + if inspect.isfunction(setter): + return True + # For untraceable `__set__` we should still skip if the attribute + # was mutated via instance `__dict__`. + elif attr_name in self.attrs_directly_modifed_on_dict: + return True + return False + + def try_get_descritor_and_setter_py_func(self, attr_name): + descriptor = inspect.getattr_static(type(self.value), attr_name, None) + setter = inspect.getattr_static(type(descriptor), "__set__", None) + if inspect.isfunction(setter): + return (descriptor, setter) + return None + + def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): + if tx.output.side_effects.has_pending_mutation_of_attr(self, key): + mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) + return not isinstance(mutated_attr, variables.DeletedVariable) + + return key in self.value.__dict__ + + def get_source_by_walking_mro(self, name): + assert self.cls_source is not None + + for idx, klass in enumerate(type(self.value).__mro__): + if name in klass.__dict__: + mro_source = AttrSource(self.cls_source, "__mro__") + klass_source = GetItemSource(mro_source, idx) + dict_source = AttrSource(klass_source, "__dict__") + # TODO(anijain2305) - This is a mapping proxy object. Ideally we + # should use DictGetItemSource here. + return GetItemSource(dict_source, name) + + unimplemented(f"Could not find {name} in {type(self.value).__mro__}") + + def var_getattr(self, tx: "InstructionTranslator", name): + from .. import trace_rules + from . 
import ConstantVariable + + source = AttrSource(self.source, name) if self.source else None + + if object_has_getattribute(self.value): + getattribute_fn = inspect.getattr_static( + type(self.value), "__getattribute__" + ) + if self.source: + new_source = AttrSource(self.source, "__getattribute__") + try: + return variables.UserMethodVariable( + getattribute_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + except ObservedAttributeError: + # Pass through to __getattr__ if __getattribute__ fails + handle_observed_exception(tx) + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) + if isinstance(result, variables.DeletedVariable): + raise_observed_exception(AttributeError, tx) + return result + + if name == "__dict__": + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + + # TODO(anijain2305) - Investigate if we need specialization for more + # dunder attrs. inspect.getattr_static does not return correct value for + # them. + if name == "__class__": + cls_source = source + if cls_source is None: + cls_source = self.cls_source + options = {"source": cls_source} + return UserDefinedClassVariable(type(self.value), **options) + + try: + subobj = self._getattr_static(name) + except AttributeError: + subobj = NO_SUCH_SUBOBJ + getattr_fn = self._check_for_getattr() + if isinstance(getattr_fn, types.FunctionType): + # Dynamo is going to trace the __getattr__ function with + # args=name. Set the source accordingly. + if ( + getattr_fn is unpatched_nn_module_getattr + and isinstance(self, variables.UnspecializedNNModuleVariable) + # prevent against overwriting of params/buffers/submodules + and istype(self.value._parameters, dict) + and istype(self.value._buffers, dict) + and istype(self.value._modules, dict) + ): + # Manually trace out the nn module __getattr__ to avoid large compilation latency. + out = self.manually_trace_nn_module_getattr(tx, name) + else: + new_source = None + if self.source: + new_source = AttrSource(self.source, "__getattr__") + out = variables.UserMethodVariable( + getattr_fn, self, source=new_source + ).call_function(tx, [ConstantVariable.create(name)], {}) + + if self.source and getattr_fn is torch.nn.Module.__getattr__: + if isinstance( + out, + ( + variables.UnspecializedNNModuleVariable, + variables.NNModuleVariable, + ), + ): + # nn_module_stack source is BC surface area. Ensure that + # mod._modules["linear"] is reflected as mod.linear for + # nn_module_stack. + out.set_nn_module_stack_source( + AttrSource(self.get_nn_module_stack_source(), name) + ) + return out + + elif getattr_fn is not None: + unimplemented("UserDefined with non-function __getattr__") + + from ..mutation_guard import unpatched_nn_module_init + + if subobj is torch.nn.Module.__init__: + subobj = unpatched_nn_module_init + + if isinstance(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = AttrSource(AttrSource(self.source, "__class__"), name) + # Get the getter function + source = AttrSource(source, "fget") + return variables.UserMethodVariable( + subobj.fget, self, source=source + ).call_function(tx, [], {}) + elif isinstance(subobj, _collections._tuplegetter): + # namedtuple fields are represented by _tuplegetter, and here we + # emulate its `__get__`, which is implemented in C. 
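+ # Hedged example of what `__reduce__` yields: for `Point = namedtuple("Point", "x y")`,
+ # `Point.x.__reduce__()` is roughly `(_tuplegetter, (0, doc))`, so the first item of
+ # the second element is the positional index of the field within the tuple.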
+ _, (idx, _) = subobj.__reduce__() + # Don't go through the `__getitem__` method anymore, see + # https://github.com/python/cpython/blob/470941782f74288823b445120f6383914b659f23/Modules/_collectionsmodule.c#L2690 + assert isinstance(self, UserDefinedTupleVariable) + return self._tuple_vt.items[idx] + elif isinstance(subobj, staticmethod): + # Safe because `staticmethod.__get__` basically won't trigger user + # code and just returns the underlying `__func__`: + # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100 + func = subobj.__get__(self.value) + return VariableTracker.build(tx, func, source) + elif isinstance(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, self.var_getattr(tx, "__class__"), source=source + ) + elif isinstance(subobj, types.ClassMethodDescriptorType): + # e.g.: inspect.getattr_static({}, "fromkeys") + func = subobj.__get__(self.value, None) + return VariableTracker.build(tx, func, source) + elif is_lru_cache_wrapped_function(subobj): + # getattr_static returns the lru_wrapped function, and we cannot + # extract the underlying method from the wrapped function. To handle + # it, manually create a wrapped user method vt. + return variables.WrapperUserMethodVariable( + subobj, "__wrapped__", self, source=source + ) + elif inspect.getattr_static( + type(subobj), "__get__", NO_SUCH_SUBOBJ + ) is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor( + type(subobj).__get__ + ): + # Emulate https://github.com/python/cpython/blob/3.11/Objects/object.c#L1271-L1285 + # + # Attribute has a __get__ method. Create a user defined object vt + # for the subobj, and then trace the __get__ method. + descriptor_source = None + descriptor_get_source = None + if self.cls_source: + # To access the method descriptor from the udf object w/o using + # inspect.getattr_static, we can look into the class mro + descriptor_source = self.get_source_by_walking_mro(name) + descriptor_get_source = AttrSource( + TypeSource(descriptor_source), "__get__" + ) + descriptor_var = VariableTracker.build(tx, subobj, descriptor_source) + else: + # Sourceless Builder does not support user defined objects + descriptor_var = UserDefinedObjectVariable(subobj) + + # The arguments of the __get__ function are (self, instance, owner) + # self - descriptor_var + # instance - instance of the class, represented by self here + # owner - class object + owner_var = UserDefinedClassVariable(type(self.value)) + return variables.UserMethodVariable( + subobj.__get__.__func__, descriptor_var, source=descriptor_get_source + ).call_function(tx, [self, owner_var], {}) + elif isinstance(subobj, types.FunctionType) or ( + isinstance(subobj, types.MethodType) + and isinstance(self.value, torch.nn.Module) + ): + # Since we get subobj via self._getattr_static, which may not trigger dynamic lookup. + # Static lookup can't tell us it's a method or function correctly, + # so we trigger dynamic lookup here to get the correct type. 
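+ # For instance (illustrative): `inspect.getattr_static(obj, "forward")` returns the
+ # plain function object stored on the class, whereas `getattr(obj, "forward")` goes
+ # through the descriptor protocol and returns a bound method, which is what the
+ # branches below need to distinguish.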
+ dynamic_subobj = getattr(self.value, name) + + while dynamic_subobj is subobj and hasattr(subobj, "_torchdynamo_inline"): + subobj = subobj._torchdynamo_inline + dynamic_subobj = subobj + source = AttrSource(source, "_torchdynamo_inline") if source else None + + if isinstance(subobj, types.MethodType): + if dynamic_subobj.__self__ is not self.value: + if not isinstance(dynamic_subobj.__func__, types.FunctionType): + unimplemented( + f"Found a method whose __func__ is not of FunctionType - {dynamic_subobj}" + ) + + # Use the __self__ attribute of the method to find the + # source of the new self object. + self_source = None + if source is not None: + self_source = AttrSource(source, "__self__") + object_vt = VariableTracker.build( + tx, dynamic_subobj.__self__, self_source + ) + + return variables.UserMethodVariable( + dynamic_subobj.__func__, object_vt + ) + func = subobj.__func__ + else: + assert isinstance(subobj, types.FunctionType) + func = subobj + + if inspect.ismethod(dynamic_subobj): + return variables.UserMethodVariable(func, self, source=source) + elif inspect.isfunction(dynamic_subobj): + if is_utils_checkpoint(func): + return build_checkpoint_variable(source=source) + elif source is not None: + return trace_rules.lookup(func).create_with_source( + func, source=source + ) + else: + return trace_rules.lookup(func)(func) + + if ( + # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to + # keep Dynamo behavior compatible with no inlining, as there will be some delay to turn on the flag in + # fbcode. + ( + torch._dynamo.config.inline_inbuilt_nn_modules + or isinstance(self, variables.FSDPManagedNNModuleVariable) + ) + and source + and isinstance(self, variables.UnspecializedNNModuleVariable) + # export has some awkwardness around specialized and unspecialized modules. Skip wrapping source for export + # usecase for now. + and (not tx.output.export or torch._dynamo.config.install_free_tensors) + ): + # Recalculate source for params/buffers + if name in ("_buffers", "_parameters"): + source = UnspecializedParamBufferSource(self.source, name) + source = self._wrap_source(source) + + if subobj is not NO_SUCH_SUBOBJ: + if is_wrapper_or_member_descriptor(subobj): + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + if source: + return variables.LazyVariableTracker.create(subobj, source) + else: + # Check if the subobj is accessible from the class itself. If the class source is known, we can create a + # sourceful variable tracker. + if self.cls_source is not None: + subobj_from_class = inspect.getattr_static( + self.value.__class__, name, NO_SUCH_SUBOBJ + ) + if subobj_from_class is subobj: + src_from_class = AttrSource(self.cls_source, name) + return variables.LazyVariableTracker.create( + subobj_from_class, src_from_class + ) + + return VariableTracker.build(tx, subobj) + + # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError. 
+ raise_observed_exception(AttributeError, tx) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + try: + var_vt = self.var_getattr(tx, name) + return variables.ConstantVariable.create( + not isinstance(var_vt, variables.DeletedVariable) + ) + except ObservedAttributeError: + handle_observed_exception(tx) + return variables.ConstantVariable.create(False) + + +class FrozenDataClassVariable(UserDefinedObjectVariable): + @staticmethod + def create(tx, value, source): + from dataclasses import fields + + assert is_frozen_dataclass(value) + + field_map = {} + for field in fields(value): + if hasattr(value, field.name): + field_map[field.name] = VariableTracker.build( + tx, + getattr(value, field.name), + source and AttrSource(source, field.name), + ) + + return FrozenDataClassVariable(value, fields=field_map, source=source) + + def __init__(self, value, fields=None, **kwargs) -> None: + super().__init__(value, **kwargs) + if fields is None: + fields = {} + self.fields = fields + + def as_python_constant(self): + # NOTE: this is an intentionally limited version of + # `as_python_constant` for `nonstrict_trace` implementation. + from dataclasses import fields + + import torch.utils._pytree as pytree + + if not istype( + self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode) + ): + # TODO loosen this restriction and fix `as_proxy`. + raise NotImplementedError( + "currently can't reconstruct arbitrary frozen dataclass instances" + ) + + args = [] + kwargs = {} + for field in fields(self.value): + if field.init: + data = self.fields[field.name].as_python_constant() + if getattr(field, "kw_only", False): + kwargs[field.name] = data + else: + args.append(data) + + # This is safe because we know the TreeSpec classes constructors don't + # have external side effects. + ctor = self.python_type() + return ctor(*args, **kwargs) + + def as_proxy(self): + from dataclasses import fields + + args = [] + kwargs = {} + for field in fields(self.value): + proxy = self.fields[field.name].as_proxy() + if hasattr(field, "kw_only") and field.kw_only: + kwargs[field.name] = proxy + else: + args.append(proxy) + + # TODO this isn't really safe, because + # 1. it could invoke a user defined `__post_init__`. + # 2. it could invoke a user defined `__init__` if the class _subclasses_ + # a frozen dataclass. + # Either of the above could end up mutating external state. 
+ ctor = self.python_type() + return ctor(*args, **kwargs) + + # NB: This is called during __init__ for a frozen dataclass + # use this to accumulate the most up-to-date field values + def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + self.fields[name.as_python_constant()] = value + return super().method_setattr_standard(tx, name, value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + +class SourcelessGraphModuleVariable(UserDefinedObjectVariable): + def __init__( + self, + value, + **kwargs, + ) -> None: + super().__init__(value, **kwargs) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + fn_variable = variables.UserFunctionVariable(self.value.forward.__func__) + args = [self] + args + return tx.inline_user_function_return( + fn_variable, + args, + kwargs, + ) + + +class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable): + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.exc_vt = variables.ExceptionVariable(self.value_type, ()) + + @property + def fn(self): + return self.value_type + + def call_method(self, tx, name, args, kwargs): + if ( + name == "__init__" + and (method := self._maybe_get_baseclass_method(name)) + and inspect.ismethoddescriptor(method) + and len(kwargs) == 0 + ): + self.exc_vt.args = args + self.value.args = args + return variables.ConstantVariable(None) + if ( + name == "__setattr__" + and len(args) == 2 + and isinstance(args[0], variables.ConstantVariable) + and args[0].value + in ("__cause__", "__context__", "__suppress_context__", "__traceback__") + ): + self.exc_vt.call_setattr(tx, args[0], args[1]) + return super().call_method(tx, name, args, kwargs) + + @property + def __context__(self): + return self.exc_vt.__context__ + + def set_context(self, context: "variables.ExceptionVariable"): + return self.exc_vt.set_context(context) + + @property + def exc_type(self): + return self.exc_vt.exc_type + + +class KeyedJaggedTensorVariable(UserDefinedObjectVariable): + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torchrec.sparse.jagged_tensor") + return mod is not None and type(obj) is mod.KeyedJaggedTensor + + def __init__(self, value, **kwargs) -> None: + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + assert type(value) is KeyedJaggedTensor + super().__init__(value, **kwargs) + + def var_getattr(self, tx: "InstructionTranslator", name): + if ( + torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt + and self.source is not None + and name in ("_length_per_key", "_offset_per_key") + ): + with TracingContext.patch(force_unspec_int_unbacked_size_like=True): + return super().var_getattr(tx, name) + return super().var_getattr(tx, name) + + +class IntWrapperVariable(UserDefinedObjectVariable): + # Dummy class to check if the object is an IntWrapper, and turn it into a + # symint + @staticmethod + def is_matching_object(obj): + mod = sys.modules.get("torch.export.dynamic_shapes") + return mod is not None and type(obj) is mod._IntWrapper + + +class RemovableHandleClass: + # Dummy class to pass to python_type of RemovableHandleVariable + # Useful for isinstance check on hooks + pass + + +class RemovableHandleVariable(VariableTracker): + REMOVED = -1 + + def __init__( + self, + mutation_type=None, + # index of the registration in the side_effects owned register_hook/handle list, used during removal. 
+ idx=None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.mutation_type = mutation_type + self.idx = idx + + def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs): + if method_name == "remove": + if self.idx != self.REMOVED: + tx.output.side_effects.remove_hook(self.idx) + self.idx = self.REMOVED + return variables.ConstantVariable.create(None) + super().call_method(tx, method_name, args, kwargs) + + def reconstruct(self, codegen: "PyCodegen"): + if self.idx == self.REMOVED: + # Hook has already been removed, return a dummy handle + codegen.add_push_null( + lambda: codegen.load_import_from( + "torch._dynamo.utils", "invalid_removeable_handle" + ) + ) + codegen.extend_output(create_call_function(0, False)) + return + # unreachable due to codegen.add_cache() when the hook is installed + super().reconstruct(codegen) + + def python_type(self): + return RemovableHandleClass + + +class UserDefinedDictVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of dict/OrderedDict. + + Internally, it uses a ConstDictVariable to represent the dict part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, dict_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._dict_vt = dict_vt + if self._dict_vt is None: + assert self.source is None, ( + "dict_vt must be constructed by builder.py when source is present" + ) + self._dict_vt = variables.ConstDictVariable( + {}, mutation_type=ValueMutationNew() + ) + self._dict_methods = dict_methods + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + method = self._maybe_get_baseclass_method(name) + if method in self._dict_methods: + return self._dict_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + if type(self.value).__iter__ in ( + dict.__iter__, + collections.OrderedDict.__iter__, + ): + return self._dict_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._dict_vt) + + +class UserDefinedListVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of lists. + + Internally, it uses a ListVariable to represent the list part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. 
+ """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, list_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._list_vt = list_vt + if self._list_vt is None: + assert self.source is None, ( + "list_vt must be constructed by builder.py when source is present" + ) + self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew()) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._list_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in list_methods: + return self._list_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._list_vt is not None + if type(self.value).__iter__ is list.__iter__: + return self._list_vt.unpack_var_sequence(tx) + raise NotImplementedError + + def is_underlying_vt_modified(self, side_effects): + return side_effects.is_modified(self._list_vt) + + +class UserDefinedTupleVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of tuple. + + Internally, it uses a TupleVariable to represent the tuple part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): + super().__init__(value, init_args=init_args, **kwargs) + self._tuple_vt = tuple_vt + if self._tuple_vt is None: + assert self.source is None, ( + "tuple_vt must be constructed by builder.py when source is present" + ) + # Emulate `tuple.__new__` + # https://github.com/python/cpython/blob/3.11/Objects/tupleobject.c#L697-L710 + # + # TODO this duplicates the logic in `BuiltinVariable(tuple)` + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + elems = init_args[0].unpack_var_sequence(tx) + self._tuple_vt = variables.TupleVariable( + elems, mutation_type=ValueMutationNew() + ) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._tuple_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in tuple_methods: + return self._tuple_vt.call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._tuple_vt is not None + if type(self.value).__iter__ is tuple.__iter__: + return self._tuple_vt.unpack_var_sequence(tx) + raise NotImplementedError + + +class MutableMappingVariable(UserDefinedObjectVariable): + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + self.generic_dict_vt = variables.ConstDictVariable({}) + self.mutation_type = AttributeMutationExisting() + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # A common pattern in the init code of MutableMapping objects is to + # update the __dict__ attribute. To prevent graph break, we directly + # return a ConstDictVariable for the __dict__attr. + # + # However, users can try to add a new attribute to the class using the + # __dict__ attribute. To catch this, we save the ConstDictVariable for + # the __dict__ and then lookup into this vt for each attr lookup. 
+ if name == "get" and type(self.value).get in ( + collections.abc.Mapping.get, + dict.get, + ): + return variables.UserMethodVariable(polyfills.mapping_get, self) + elif name == "__dict__" and self.source: + self.generic_dict_vt = variables.LazyVariableTracker.create( + self.value.__dict__, AttrSource(self.source, "__dict__") + ) + return self.generic_dict_vt + elif out := self.generic_dict_vt.maybe_getitem_const( + variables.ConstantVariable(name) + ): + return out + else: + return super().var_getattr(tx, name) + + +class RandomVariable(UserDefinedObjectVariable): + pass diff --git a/phivenv/Lib/site-packages/torch/_export/__init__.py b/phivenv/Lib/site-packages/torch/_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d6bf088f4e3480a6f1b1f3fa57ece4db2ef610f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/__init__.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import functools +import io +import json +import logging +import os +import re +import sys +import types +import warnings +import weakref +import zipfile +from collections import OrderedDict +from contextlib import contextmanager +from functools import lru_cache + +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from unittest.mock import patch + +import torch +import torch.fx +import torch.utils._pytree as pytree + +from torch._dispatch.python import enable_python_dispatcher +from torch._guards import compile_context +from torch._utils_internal import log_export_usage +from torch.export._tree_utils import reorder_kwargs +from torch.export.graph_signature import ( + ArgumentSpec, + ConstantArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + SymBoolArgument, + SymFloatArgument, + TensorArgument, +) +from torch.fx import traceback as fx_traceback +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo + +from .wrappers import _wrap_submodules +from .utils import _materialize_cpp_cia_ops + +if TYPE_CHECKING: + from torch._C._aoti import AOTIModelContainerRunner + +log = logging.getLogger(__name__) + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + allow_rnn: bool = True + + +# We only want to print this once to avoid flooding logs in workflows where aot_compile_warning +# is called multiple times. +@lru_cache +def aot_compile_warning(): + from torch._inductor import config + + log.warning("+============================+") + log.warning("| !!! WARNING !!! 
|") + log.warning("+============================+") + log.warning( + "torch._export.aot_compile()/torch._export.aot_load() is being deprecated, please switch to " + "directly calling torch._inductor.aoti_compile_and_package(torch.export.export())/" + "torch._inductor.aoti_load_package() instead.") + + +def aot_compile( + f: Callable, + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[dict[str, Any]] = None, + options: Optional[dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, + same_signature: bool = True, +) -> Union[list[Any], str]: + """ + Note: this function is not stable yet + + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside, generates executable cpp code from the program, and returns + the path to the generated shared library + + Args: + f: the `nn.Module` or callable to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: Should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + options: A dictionary of options to control inductor + + disable_constraint_solver: Whether the dim constraint solver must be disabled. + + Returns: + Path to the generated shared library + """ + from torch.export._trace import _export_to_torch_ir + from torch._inductor.decomposition import select_decomp_table + from torch._inductor import config + + aot_compile_warning() + + if config.is_predispatch: + gm = torch.export._trace._export(f, args, kwargs, dynamic_shapes, pre_dispatch=True).module() + else: + # We want to export to Torch IR here to utilize the pre_grad passes in + # inductor, which run on Torch IR. + gm = _export_to_torch_ir( + f, + args, + kwargs, + dynamic_shapes, + disable_constraint_solver=disable_constraint_solver, + same_signature=same_signature, + # Disabling this flag, because instead we can rely on the mapping + # dynamo_flat_name_to_original_fqn which is coming from Dynamo. 
+ restore_fqn=False, + ) + + with torch.no_grad(): + so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type] + + return so_path + +def aot_load(so_path: str, device: str) -> Callable: + """ + Loads a shared library generated by aot_compile and returns a callable + + Args: + so_path: Path to the shared library + + Returns: + A callable + """ + aot_compile_warning() + + if device == "cpu": + runner: AOTIModelContainerRunner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) + elif device == "xpu" or device.startswith("xpu:"): + runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) + elif device == "mps" or device.startswith("mps:"): + runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): + call_spec = runner.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = runner.run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdee11d183da29d75ee8824cc37f88871d150f9e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/converter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/converter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60d95c221fca9b44c2c115f09080ffa0ba000f13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/converter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11706fa5b23d570cf9761b52659714d0db614ad3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/error.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce802478bd8673bd9aff5dd7f295772b64a5479 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e54045b2a85932dbfc9404eea8b6d5b1148bccbb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/pass_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/tools.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/tools.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d2f32bf2d0b61042f00b6df989589f127b207b03 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c91ee71992e8a5c4f6dd902c782363031023eb5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ad3616bc63694ded0b16bf40fdcb0324d605e3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/verifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..974ee849ab7daeca8241a45836e02e874c854dff Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/__pycache__/wrappers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/converter.py b/phivenv/Lib/site-packages/torch/_export/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..7d5de1d6d5954f6196476ed28e31f4d27725bbe3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/converter.py @@ -0,0 +1,1614 @@ +# mypy: allow-untyped-defs +import builtins +import logging +import operator +import typing +import warnings +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch +import torch.export._trace +from torch import _C +from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import ( + replace_quantized_ops_with_standard_ops, +) +from torch.export.dynamic_shapes import _tree_map_with_path, Dim +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import ( + ConstantArgument, + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) +from torch.fx import subgraph_rewriter + + +log = logging.getLogger(__name__) + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. 
+ # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, _inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _create_jit_graph( + model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any] +) -> tuple[torch.Graph, list["_C.IValue"], Any, Optional[torch.ScriptModule]]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + typing.cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def list_add(a, b): + return a + b + + +def list_append(container, element): + return container + [element] + + +def execute_subgraph_from_prim_loop( + subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs +): + """ + subgraph: GraphModule from sub-block. + iter_idx: The index of interation. + len_loop_local_arguments: The number of loop local arguments in args. + """ + + # Loop local variables. TS graph create those as inputs because their values + # are updated inside the loop. + loop_local_args = args[:len_loop_local_arguments] + # Global variables that are not passed in as inputs to the loop sub-blocks + # but are directly used. Most of time, their values are not updated, but + # the only exception is when there are some operations that perform inplace + # updates. 
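+ # Illustrative sketch of the calling convention (hypothetical toy values, for
+ # illustration only): with len_loop_local_arguments == 2, the flat args
+ # (local_0, local_1, global_0) are split and re-ordered so the sub-GraphModule
+ # is called as subgraph(global_0, iter_idx, local_0, local_1).
+ #
+ #     def toy_subgraph(global_0, iter_idx, local_0, local_1):
+ #         return local_0 + local_1 + iter_idx
+ #
+ #     execute_subgraph_from_prim_loop(toy_subgraph, 0, 2, 10, 20, 99)
+ #     # calls toy_subgraph(99, 0, 10, 20) and returns 30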
+ global_args = args[len_loop_local_arguments:] + return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs) + + +def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule): + def pattern(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int) + div_scalar_mode = torch.ops.aten.div.Scalar_mode( + scalar_tensor, scale, rounding_mode="trunc" + ) + int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode) + return int_tensor + + def replacement(im, dim, scale): + sym_size_int = torch.ops.aten.sym_size.int(im, dim) + return sym_size_int // scale + + subgraph_rewriter.replace_pattern(gm, pattern, replacement) + + +def is_valid_for_codegen(name): + if len(name) == 0: + raise RuntimeError("Empty argument name for codegen") + if name[0].isdigit(): + return False + return True + + +def normalize_name(name: str, prefix: str = "rename") -> str: + name = name.replace(".", "_") + if is_valid_for_codegen(name): + return name + return f"{prefix}_{name}" + + +def ir_name_to_func_name(name: str) -> str: + """prim::If -> convert_prim_If""" + name_list = name.split("::") + return "convert_" + "_".join(name_list) + + +def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph): + if is_top_level_graph: + return fx_graph.get_attr(name) + return fx_graph.placeholder(name) + + +_TORCH_DTYPE_TO_ENUM = { + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, + torch.complex32: 8, + torch.complex64: 9, + torch.complex128: 10, + torch.bool: 11, + torch.qint8: 12, + torch.quint8: 13, + torch.bfloat16: 15, +} + +_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()} + + +def get_dtype_as_int(tensor): + """ + prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of + the tensor and returns the integer corresponding to this dtype based on the + enum in ScalarType.h + """ + dtype = tensor.dtype + if dtype not in _TORCH_DTYPE_TO_ENUM: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _TORCH_DTYPE_TO_ENUM[dtype] + + +# Those operators will be automatically populated to a instance method +# of TS2FXGraphConverter with name convert__(). +# Please check __init__ for method population implementations. +kind_to_standard_operators: dict[str, Callable[..., Any]] = { + "prim::max": builtins.max, + "prim::min": builtins.min, + "prim::TupleIndex": operator.getitem, + "aten::__is__": operator.is_, + "aten::__isnot__": operator.is_not, + "aten::__not__": operator.not_, + "aten::__contains__": operator.contains, + "prim::dtype": get_dtype_as_int, + "aten::len": len, + # Mapping from specialized op to its symbolic counterpart. + # They currently do not have any other overrides. 
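+ # Illustrative sketch of how entries in this table behave (hypothetical
+ # values, for illustration only):
+ #
+ #     t = torch.zeros(2, 3, dtype=torch.float32)
+ #     get_dtype_as_int(t)                # 6, i.e. ScalarType Float
+ #     torch.ops.aten.sym_size.int(t, 0)  # 2, symbolic-shape-aware size
+ #     torch.ops.aten.sym_numel(t)        # 6, symbolic-shape-aware numel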
+ "aten::numel": torch.ops.aten.sym_numel, + "aten::size": torch.ops.aten.sym_size, + "aten::storage_offset": torch.ops.aten.sym_storage_offset, + "aten::stride": torch.ops.aten.sym_stride, +} + + +def get_ir_value_parent_name_and_attr_name(node): + irv_parent_name, irv_name = node.input().debugName(), node.output().debugName() + attr_name = node.s("name") + return irv_name, irv_parent_name, attr_name + + +def construct_fqn(ir, ref_map, name_map): + name_list = [] + while ir in ref_map: + name_list.append(name_map[ir]) + ir = ref_map[ir] + return ".".join(reversed(name_list)) + + +def get_block_to_lifted_attrs( + graph: torch._C.Graph, +) -> tuple[dict[torch._C.Block, set[str]], dict[str, str]]: + """ + Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes. + When a graph has control flow, the graph will be divided into multiple blocks. We want to convert + each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model + parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model, + we will run this pass which will: + 1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls. + 2. Process the graph bottom up to find the lifted attributes of each block by taking the union + of the attributes used in the current block, and the lifted attributes of all its child blocks. + + Returns: + A mapping of blocks to a set of FQNs of its lifted attributes, and a + mapping of node names to the FQNs of its lifted attributes. + """ + + # A map from a block to its expected to be lifted arguments. + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]] = {} + + # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a + # GetAttr node. By traversing this reference map, we can figure out the + # full IR aliasing pass and figure out the FQN of an attribute. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1" + node_to_parent_map: dict[str, str] = {} + + # Used for reconstructing the FQN of an attribute based on the reference map. + # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR + # This name map stores which attribute name is called for a src IR --> dest IR action. + # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear" + node_to_attr_name: dict[str, str] = {} + + def _dfs_get_attr_dependency(entry): + """ + First DFS path to construct reference map and name map. + """ + for node in entry.nodes(): + if node.kind() == "prim::GetAttr": + ( + irv_name, + irv_parent_name, + attr_name, + ) = get_ir_value_parent_name_and_attr_name(node) + node_to_parent_map[irv_name] = irv_parent_name + node_to_attr_name[irv_name] = attr_name + for block in node.blocks(): + _dfs_get_attr_dependency(block) + + def _map_blocks_to_lifted_attrs(entry): + """ + Walk the graph in a bottom-up fashion to build the expected to be + lifted arguments for each block. + """ + arguments: set[str] = set() + for node in entry.nodes(): + for block in node.blocks(): + # Recursively build. + arguments = arguments.union(_map_blocks_to_lifted_attrs(block)) + if node.kind() == "prim::GetAttr": + irv_name = node.output().debugName() + # Skip for intermediate GetAttr, which will anyway not result a FQN. 
+ # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"} + # node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"} + # There is only one FQN %3-->%2-->%1: self.linear.weight + # %2-->%1 is not a FQN: self.linear + if irv_name not in set(node_to_parent_map.values()): + arguments.add( + construct_fqn(irv_name, node_to_parent_map, node_to_attr_name) + ) + if not isinstance(entry, torch._C.Graph): # Skip the top level. + blocks_to_lifted_attrs[entry] = arguments + return arguments + + _dfs_get_attr_dependency(graph) + _map_blocks_to_lifted_attrs(graph) + + return blocks_to_lifted_attrs, node_to_attr_name + + +def get_attribute_fqn_from_ts_node( + name_to_attribute_fqn: dict[str, str], node: torch._C.Node +) -> str: + def get_attr(name: str): + if name in name_to_attribute_fqn: + return name_to_attribute_fqn[name] + else: + raise ValueError(f"Attribute {name} not found") + + if node.kind() == "prim::SetAttr": + input_name = next(node.inputs()).debugName() + elif node.kind() == "prim::GetAttr": + input_name = node.input().debugName() + else: + raise RuntimeError( + f"Unexpected node kind when getting attribute fqn. node: {node} " + ) + + attr_name = node.s("name") + root_attr_name = get_attr(input_name) + attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name + + return attr_fqn + + +def get_op_overload(node: torch._C.Node): + schema_str = node.schema() + assert schema_str != "(no schema)", f"got empty schema for {node}" + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) + ns, op_name = str(schema.name).split("::") + override = schema.overload_name + + try: + op_overload_mod = getattr(torch.ops, ns) + op_overload_packet = getattr(op_overload_mod, op_name) + if override: + op_overload = getattr(op_overload_packet, override) + else: + op_overload = op_overload_packet.default + except Exception as e: + raise RuntimeError( + f"Unable to find operator {node.kind()} with schema {node.schema()}" + ) from e + + return op_overload + + +class TS2FXGraphConverter: + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: dict[str, torch.Tensor], + name_to_buffer: dict[str, torch.Tensor], + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]], + name_to_non_tensor_attribute: dict[str, Any], + name_to_constant: dict[str, Any], + name_to_attribute_fqn: dict[str, str], + ): + self.ts_graph = ts_graph + # Mapping of parameter FQN to actual parameter value + self.name_to_param = name_to_param + # Mapping of buffer FQN to actual buffer value + self.name_to_buffer = name_to_buffer + + self.fx_graph: torch.fx.Graph = torch.fx.Graph() + self.input_specs: list[InputSpec] = [] + self.output_specs: list[OutputSpec] = [] + + # Mapping of TS node name to converted FX node + self.name_to_node: dict[ + str, Union[torch.fx.Node, list[torch.fx.Node], dict[Any, torch.fx.Node]] + ] = {} + # Mapping of TS node name to constant value (int, str, TorchBind obj, + # tensor constants ...) 
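+ # For example (hypothetical values, for illustration only), after a few
+ # prim::Constant conversions this map may hold entries such as
+ # {"5": 2, "7": "trunc", "lifted_tensor_9": torch.tensor([1.0, 2.0])}.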
+ self.name_to_constant: dict[str, Any] = name_to_constant + + # Mapping from torchscript node output name to attribute fully qualified name + self.name_to_attribute_fqn: dict[str, str] = name_to_attribute_fqn + + # Mapping from fully qualified name to real values or a fx graph node + # During convert, this represents the current value of a non-tensor attribute + # One use case is: + # def forward(self, x): + # c1 = self.count + # self.count += 1 + # c2 = self.count + # return x + c1 + c2 + self.name_to_non_tensor_attribute_node: dict[str, Any] = {} + + # Mapping from fully qualified name to initial real values inputs + # We separate it from self.name_to_non_tensor_attribute_node since + # we need initial real value input when we construct fx.GraphModule + self.name_to_non_tensor_attribute: dict[str, Any] = name_to_non_tensor_attribute + + self.subgraphs: dict[str, torch.fx.GraphModule] = {} + + # Mapping of block to list of attributes that need to be lifted for each + # block + self.blocks_to_lifted_attrs = blocks_to_lifted_attrs + + # Populate methods for the standard operators. + for k in kind_to_standard_operators.keys(): + handler_func_name = ir_name_to_func_name(k) + # Create an indirect function call: + # convert__ --> lambda node: _convert_standard_operator(node) + setattr( + self, + handler_func_name, + lambda node: self._convert_standard_operators(node), + ) + + # This stores a list of return results that do not appear in the original TS + # graph's outputs. The reason we maintain this is because some operations in the sub-block + # might have inplace updates to the variable defined in the parent fx graph. After + # the execution of that sub-block, the variable defined in the parent fx graph also + # needs to be updated. + self.name_update_from_subblock_to_parent: set[str] = set() + + def _is_get_attr_node(self, fqn): + return ( + fqn in self.name_to_buffer + or fqn in self.name_to_param + or ( + fqn in self.name_to_constant + and isinstance(self.name_to_constant[fqn], torch.ScriptObject) + ) + ) + + def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: list[str]): + subgraph_nodes, subgraph_converters = [], [] + for block in node.blocks(): + subgraph_converter = TS2FXGraphConverter( + block, + self.name_to_param, + self.name_to_buffer, + self.blocks_to_lifted_attrs, + {}, + self.name_to_constant, + self.name_to_attribute_fqn, + ) + + for block_arg in arguments: + normalized_block_arg_name = normalize_name(block_arg) + placeholder_node = subgraph_converter.fx_graph.placeholder( + normalized_block_arg_name + ) + subgraph_converter.name_to_node[block_arg] = placeholder_node + + subgraph = subgraph_converter.convert() + subgraph_name = self.add_subgraph(subgraph) + subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name)) + subgraph_converters.append(subgraph_converter) + return subgraph_nodes, subgraph_converters + + def _identify_inputs_as_arguments(self, entry): + """ + Identify inputs from the innermost sub-block. This is needed + for nested sub-blocks when the input is hidden in the nested sub-block. + E.g., example IR of input is hidden in the nested sub-block. + Graph[x.1] + %1 = ... + Block[] + Block[x.1] + %2 = x.1 ... 
+ """ + arguments: set[str] = set() + for block in entry.blocks(): + for block_node in block.nodes(): + for block_node_in in block_node.inputs(): + if ( + block_node_in.debugName() in self.name_to_node + and block_node_in.debugName() not in self.name_to_attribute_fqn + ): + arguments.add(block_node_in.debugName()) + arguments = arguments.union( + self._identify_inputs_as_arguments(block_node) + ) + return arguments + + def is_top_level_graph(self): + return isinstance(self.ts_graph, torch._C.Graph) + + def add_subgraph(self, subgraph) -> str: + name = f"subgraph_{len(self.subgraphs)}" + self.subgraphs[name] = subgraph + return name + + def get_args_kwargs(self, node: torch._C.Node, schema): + args = [] + kwargs = {} + for input, schema_arg in zip(node.inputs(), schema.arguments): + if schema_arg.kwarg_only: + kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input) + else: + args.append(self.get_fx_value_by_ir_value(input)) + + return tuple(args), kwargs + + def get_fx_value_by_ir_value(self, value: torch._C.Value): + value_name = value.debugName() + + if value_name in self.name_to_node: + input_node = self.name_to_node[value_name] + return input_node + elif value_name in self.name_to_constant: + if isinstance(self.name_to_constant[value_name], torch.ScriptObject): + return self.fx_graph.get_attr(value_name) + return self.name_to_constant[value_name] + elif value_name in self.name_to_attribute_fqn: + return self.get_fx_value_by_fqn(self.name_to_attribute_fqn[value_name]) + else: + raise ValueError(f"Input {value_name} not found") + + def get_fx_value_by_fqn(self, name): + if name in self.name_to_node: + fx_node = self.name_to_node[name] + elif name in self.name_to_constant: + fx_node = self.name_to_constant[name] + elif name in self.name_to_non_tensor_attribute_node: + fx_node = self.name_to_non_tensor_attribute_node[name] + elif name in self.name_to_non_tensor_attribute: + fx_node = self.name_to_non_tensor_attribute[name] + else: + raise ValueError(f"Attribute {name} not found") + return fx_node + + def convert(self) -> torch.fx.GraphModule: + self.convert_graph_inputs() + + for node in self.ts_graph.nodes(): + self.convert_node(node) + + self.convert_graph_outputs() + + # Pass parameter and buffer to the root for lookup. 
+ gm = torch.fx.GraphModule( + { + **self.subgraphs, + **self.name_to_param, + **self.name_to_buffer, + **self.name_to_non_tensor_attribute, + **self.name_to_constant, + }, + self.fx_graph, + ) + + inplace_optimize_sym_size_div(gm) + + gm.graph.lint() + + return gm + + def convert_graph_inputs(self): + for graph_input in self.ts_graph.inputs(): + name = graph_input.debugName() + + if name in self.name_to_param: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.PARAMETER, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_buffer: + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.BUFFER, + arg=TensorArgument(name=normalized_name), + target=name, + persistent=True, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif name in self.name_to_constant: + assert isinstance( + self.name_to_constant[name], torch.ScriptObject + ), "Input conversion only handles ScriptObject" + normalized_name = normalize_name(name) + self.input_specs.append( + InputSpec( + InputKind.CUSTOM_OBJ, + arg=CustomObjArgument( + name=normalized_name, class_fqn=normalized_name + ), + target=name, + persistent=False, + ) + ) + fx_node = get_node_as_placeholder_or_get_attr( + self.fx_graph, name, self.is_top_level_graph() + ) + elif isinstance(graph_input.type(), torch.ClassType): + # Directly skip inputs that are ScriptObject but not used in the graph. + continue + else: + normalized_name = normalize_name(name, prefix="input") + self.input_specs.append( + InputSpec( + InputKind.USER_INPUT, + arg=TensorArgument(name=normalized_name), + target=name, + ) + ) + fx_node = self.fx_graph.placeholder(normalized_name) + + self.name_to_node[name] = fx_node + + def convert_aten_Float(self, node: torch._C.Node): + def to_float_tensor(t): + return t.to(dtype=torch.float).item() + + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_function( + to_float_tensor, + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_aten_tensor(self, node: torch._C.Node): + """aten::tensor creates a constant tensor ad-hoc --> GetAttr""" + args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema) + + for k in kwargs: + if k == "requires_grad": + kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True + + to_tensor = ( + torch.tensor + if all(isinstance(a, int) for a in args) + else torch._refs.tensor + ) + + def target(*args, **kwargs): + if "dtype" in kwargs and kwargs["dtype"] is not None: + kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + return to_tensor(*args, **kwargs) + + # def to_dynamic_tensor(*args, **kwargs): + # if "dtype" in kwargs and kwargs["dtype"] is not None: + # kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]] + # return torch._refs.tensor(*args, **kwargs) + + output_name = node.output().debugName() + fx_node = self.fx_graph.call_function(target, args, kwargs) + self.name_to_node[output_name] = fx_node + + def convert_aten_append(self, node: torch._C.Node): + # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)" + + # inplace append to the list!! 
This is risky, as we are mutating the list in place + # This makes the converter "non-functional", and the result depends on the order in which the nodes are converted + # In a sense, the converter now becomes a stateful interpreter + warnings.warn( + "Converting aten::append.t, which is an in-place mutation of the list. " + "This makes the converter non-functional: the result depends on the order in which the append nodes are converted!" + ) + + args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs()) + fx_node = self.fx_graph.call_function(list_append, args) + self.name_to_node[node.output().debugName()] = fx_node + + # In-place mutate arg[0], which is the Python list + self.name_to_node[node.inputsAt(0).debugName()] = fx_node + + # Variables that need to be updated in the parent module. + if not self.is_top_level_graph() and args[0].op == "placeholder": + self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName()) + + def convert_prim_Constant(self, node: torch._C.Node): + name = node.output().debugName() + + value: Any = None + if node.hasAttribute("value"): + constant_kind = node.kindOf("value") + if constant_kind == "i": + value = node.i("value") + elif constant_kind == "f": + value = node.f("value") + elif constant_kind == "s": + value = node.s("value") + elif constant_kind == "t": + alias_name = ( + f"lifted_tensor_{name}" # Follow naming convention from EP tracing. + ) + fx_node = self.fx_graph.get_attr(alias_name) + self.name_to_node[name] = fx_node + name, value = alias_name, node.t("value") + elif constant_kind == "ival": + value = node.ival("value") + else: + raise ValueError(f"Unsupported constant type: {node.kindOf('value')}") + else: + value = None + + self.name_to_constant[name] = value + + def convert_prim_CallMethod(self, node: torch._C.Node): + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 + fx_node = self.fx_graph.call_method( + node.s("name"), + tuple(inp_list), + ) + self.name_to_node[node.output().debugName()] = fx_node + + def convert_prim_device(self, node: torch._C.Node): + input_type = node.input().type() + if input_type.isSubtypeOf(torch._C.TensorType.get()): + device = input_type.device() # type: ignore[attr-defined] + output_name = node.output().debugName() + self.name_to_constant[output_name] = device + else: + raise ValueError(f"Unsupported JitType ({input_type}) when getting device") + + def convert_prim_GetAttr(self, node: torch._C.Node): + # Build the fully qualified name + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = attr_fqn + + if self.is_top_level_graph(): + if self._is_get_attr_node(attr_fqn): + # We insert a get_attr node for two reasons. + # First, ts graph does not lift tensor constants as input nodes. So + # tensor constants may be ignored in convert_graph_inputs(). + # Second, attr_fqn may have been written to via SetAttr. Two + # GetAttr calls may give different values. + self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn) + else: + if attr_fqn not in self.name_to_non_tensor_attribute_node: + self.name_to_non_tensor_attribute_node[ + attr_fqn + ] = self.name_to_non_tensor_attribute[attr_fqn] + self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[ + attr_fqn + ] + else: + # Special support for if blocks which do not allow SetAttr TorchScript + # nodes and get_attr FX Graph nodes.
+ if self._is_get_attr_node(attr_fqn): + self.name_to_node[output_name] = self.name_to_node[attr_fqn] + + def convert_prim_SetAttr(self, node: torch._C.Node): + attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) + attr_value = tuple(node.inputs())[1] + ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value) + if self._is_get_attr_node(attr_fqn): + fx_attr_node = self.fx_graph.get_attr(attr_fqn) + self.fx_graph.call_function( + torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input) + ) + else: + self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input + + def convert_call_function_op(self, node: torch._C.Node): + target = get_op_overload(node) + + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + else: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + next_fx_node = self.fx_graph.call_function( + operator.getitem, (fx_node, i) + ) + self.name_to_node[output_name] = next_fx_node + + def convert_prim_TupleConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def convert_prim_ListConstruct(self, node: torch._C.Node): + self._convert_prim_iterator(node) + + def _convert_prim_iterator(self, node: torch._C.Node): + output_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_list + + def convert_prim_DictConstruct(self, node: torch._C.Node): + output_dict = {} + k, v = None, None + for i, inp in enumerate(node.inputs()): + # We assume key value are stored in pair in the DictConstruct. + # The first element is the key and the following is the value. + if i % 2 == 0: + k = self.get_fx_value_by_ir_value(inp) + else: + v = self.get_fx_value_by_ir_value(inp) + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." + + output_name = node.output().debugName() + self.name_to_node[output_name] = output_dict + + def convert_prim_ListUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def convert_prim_TupleUnpack(self, node: torch._C.Node): + self._convert_prim_unpack_iterator(node) + + def _convert_prim_unpack_iterator(self, node: torch._C.Node): + # Single input and multiple outputs for unpacking. 
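+ # Illustrative sketch of the conversion performed below (hypothetical names,
+ # for illustration only): a prim::TupleUnpack / prim::ListUnpack with two
+ # outputs taken from input `t` becomes two operator.getitem calls.
+ #
+ #     g = torch.fx.Graph()
+ #     t = g.placeholder("t")
+ #     a = g.call_function(operator.getitem, (t, 0))
+ #     b = g.call_function(operator.getitem, (t, 1))
+ #     g.output((a, b))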
+ for i, outp in enumerate(node.outputs()): + outp_name = outp.debugName() + inp = self.get_fx_value_by_ir_value(node.input()) + fx_node = self.fx_graph.call_function(operator.getitem, (inp, i)) + self.name_to_node[outp_name] = fx_node + + def convert_aten_Int(self, node: torch._C.Node): + # converts aten::Int as aten._to_copy + aten::_local_scalar_dense + target = torch.ops.aten._to_copy.default + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32}) + + fx_node = self.fx_graph.call_function( + torch.ops.aten._local_scalar_dense.default, (to_copy_node,) + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_NumToTensor(self, node: torch._C.Node): + # Converts prim::NumToTensor as aten.scalar_tensor. + # prim::NumToTensor IRs are currently triggered by: + # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950 + # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971 + # For both of those APIs, torch.jit.trace implicitly sets the output tensor type + # to be LongTensor. + target = torch.ops.aten.scalar_tensor + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + + fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long}) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_CreateObject(self, node: torch._C.Node): + output_name = node.output().debugName() + self.name_to_attribute_fqn[output_name] = "" + + def convert_aten__convolution(self, node: torch._C.Node): + # converts aten::_convolution as aten.convolution, since aten::_convolution + # doesn't have a meta function + target = torch.ops.aten.convolution.default + args, kwargs = self.get_args_kwargs(node, target._schema) + + fx_node = self.fx_graph.call_function(target, args, kwargs) + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_aten_div(self, node: torch._C.Node): + target = get_op_overload(node) + schema = target._schema + + args, kwargs = self.get_args_kwargs(node, schema) + + # converts aten::div.Tensor_mode(x, tensor_constant) + # as aten.div.Scalar_mode(x, tensor_constant.item()) + if schema.overload_name == "Tensor_mode": + arg1_name = args[1].name + if arg1_name in self.name_to_constant and isinstance( + self.name_to_constant[arg1_name], torch.Tensor + ): + tensor_constant = self.name_to_constant[arg1_name] + if tensor_constant.numel() == 1: + updated_args = list(args) + updated_args[1] = self.name_to_constant[arg1_name].item() + + fx_node = self.fx_graph.call_function( + torch.ops.aten.div.Scalar_mode, + tuple(updated_args), + kwargs, + ) + + # TODO: covnert sourceRange() into stack_trace + # fx_node.meta["stack_trace"] = node.sourceRange() + + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + return + + self.convert_call_function_op(node) + + def convert_aten___getitem__(self, node: torch._C.Node): + input_container, index = tuple( + self.get_fx_value_by_ir_value(input) for input in node.inputs() + ) + fx_node = self.fx_graph.call_function( + operator.getitem, (input_container, index) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def 
convert_aten_to(self, node: torch._C.Node): + target = get_op_overload(node) + args, _kwargs = self.get_args_kwargs(node, target._schema) + + # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op + # coz aten.to + inplace_mutation_op pattern would trigger + # "cannot mutate tensors with frozen storage" functionalization error. + # To work around the issue, we override the copy to be True, so that the output + # is for sure not an alias of input + if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype: + user_nodes = [use.user for use in node.output().uses()] + user_targets = [ + get_op_overload(user_node) + for user_node in user_nodes + if user_node.schema() != "(no schema)" + ] + has_mutable_target = any( + target._schema.is_mutable for target in user_targets + ) + + if has_mutable_target: + assert len(args) >= 4 + new_args = list(args) + new_args[3] = True # copy, override to True + fx_node = self.fx_graph.call_function( + torch.ops.aten.to.dtype, tuple(new_args) + ) + # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679 + # When this issue is fixed, the clone node would be no longer needed + clone_node = self.fx_graph.call_function( + torch.ops.aten.clone.default, (fx_node,) + ) + output_name = node.output().debugName() + self.name_to_node[output_name] = clone_node + return + + self.convert_call_function_op(node) + + def convert_aten_add(self, node: torch._C.Node): + if node.schema() == "(no schema)": + if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance( + node.inputsAt(1).type(), torch.ListType + ): + target = torch.ops.aten.add.t + else: + raise RuntimeError(f"unable to determind the target for {node}") + else: + target = get_op_overload(node) + + if target == torch.ops.aten.add.t: + # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for + # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'. + args, _kwargs = self.get_args_kwargs(node, target._schema) + output_name = node.output().debugName() + self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args) + else: + self.convert_call_function_op(node) + + def _check_prim_loop_support(self, node): + inputs = list(node.inputs()) + + # TODO: (1/N) stage. + if inputs[0].debugName() not in self.name_to_constant: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of number of iterations." + ) + + # Make sure the condition is not updated in the subblock. + subblock = next(node.blocks()) + condition_output_name = next(subblock.outputs()).debugName() + for node in subblock.nodes(): + if ( + node.outputsSize() == 1 + and node.output().debugName() == condition_output_name + ): + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + if node.outputsSize() >= 2: + for outp in node.outputs(): + if outp.debugName() == condition_output_name: + raise RuntimeError( + "prim::Loop currently cannot run with dynamic value of condition." + ) + + def convert_prim_Loop(self, node: torch._C.Node): + inputs = list(node.inputs()) + self._check_prim_loop_support(node) + + num_iterations = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + loop_local_arguments = [inp.debugName() for inp in inputs[2:]] + + global_arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. 
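+ # Illustrative sketch of the kind of TorchScript source this handler accepts
+ # (hypothetical module, for illustration only): the trip count must be a
+ # compile-time constant, and the loop is unrolled below by emitting one
+ # subgraph call per iteration.
+ #
+ #     class M(torch.nn.Module):
+ #         def forward(self, x):
+ #             out = x
+ #             for _ in range(3):  # constant trip count -> prim::Loop
+ #                 out = out + 1
+ #             return out
+ #
+ #     ts = torch.jit.script(M())  # IR contains a prim::Loop node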
+ for block in node.blocks(): + global_arguments = global_arguments.union( + self.blocks_to_lifted_attrs[block] + ) + + global_arguments = list(global_arguments) + + subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph( + node, global_arguments + ) + + assert len(subgraph_nodes) == 1 + subgraph_converter = subgraph_converters[0] + if not self.is_top_level_graph(): + self.name_update_from_subblock_to_parent = ( + self.name_update_from_subblock_to_parent.union( + subgraph_converter.name_update_from_subblock_to_parent + ) + ) + + fx_block_args = [ + self.get_fx_value_by_fqn(name) + for name in loop_local_arguments + global_arguments + ] + for iter_idx in range(num_iterations): + loop_node = self.fx_graph.call_function( + execute_subgraph_from_prim_loop, + # Check execute_node function for the expected arguments order. + ( + subgraph_nodes[0], + iter_idx, + len(loop_local_arguments), + *fx_block_args, + ), + {}, + ) + + # Update the value of loop local variables. + if node.outputsSize() >= 1: + for i, outp in enumerate(node.outputs()): + output_name = outp.debugName() + self.name_to_node[output_name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + 1, + ), # + 1 because the 0th element is the condition. + ) + fx_block_args[i] = self.name_to_node[output_name] + + # Update the value of global variables, whose values are modified inplace. + for i, name in enumerate( + subgraph_converter.name_update_from_subblock_to_parent + ): + self.name_to_node[name] = self.fx_graph.call_function( + operator.getitem, + ( + loop_node, + i + node.outputsSize() + 1, + ), # + 1 because the 0th element is the condition. + ) + global_argument_index = global_arguments.index(name) + fx_block_args[ + i + node.outputsSize() + global_argument_index + ] = self.name_to_node[name] + + def _check_set_attr_in_if_block(self, if_node: torch._C.Node): + for block in if_node.blocks(): + for node in block.nodes(): + if node.kind() == "prim::SetAttr": + raise RuntimeError( + "During converting prim::If to torch.cond, found prim::SetAttr op" + " which is not supported yet. Please file an issue if you come " + "across this error." + ) + + def convert_prim_If(self, node: torch._C.Node): + self._check_set_attr_in_if_block(node) + + inputs = list(node.inputs()) + assert len(inputs) == 1 + predicate = self.get_fx_value_by_ir_value(inputs[0]) + + # Find inputs. + arguments = self._identify_inputs_as_arguments(node) + + # Lift parameters as inputs. + for block in node.blocks(): + arguments = arguments.union(self.blocks_to_lifted_attrs[block]) + + arguments = list(arguments) + subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments) + + assert len(subgraph_nodes) == 2 + + fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments] + + args = ( + predicate, + subgraph_nodes[0], + subgraph_nodes[1], + tuple(fx_block_args), + ) + + cond_node = self.fx_graph.call_function(torch.cond, args, {}) + + # prim::If can also have zero output. 
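+ # Illustrative sketch of the construct targeted above (hypothetical
+ # functions, for illustration only): the converted graph calls torch.cond
+ # with the two traced sub-blocks as branches and the lifted values as
+ # operands.
+ #
+ #     x = torch.randn(3)
+ #
+ #     def true_fn(x):
+ #         return x + 1
+ #
+ #     def false_fn(x):
+ #         return x - 1
+ #
+ #     y = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))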
+ if node.outputsSize() == 1: + output_name = node.output().debugName() + self.name_to_node[output_name] = cond_node + elif node.outputsSize() > 1: + for i, output in enumerate(node.outputs()): + output_name = output.debugName() + getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i)) + self.name_to_node[output_name] = getitem + + def convert_aten_Bool(self, node: torch._C.Node): + self._convert_as_noop(node) + + def convert_prim_Enter(self, node: torch._C.Node): + # export generally treats prim::Enter as noop + # The only context manager export supports is aten::enable_grad. + # Unfortunately, TorchScript does not support aten::enable_grad yet. + # TODO: support aten::enable_grad in both TorchScript and Converter. + return + + def convert_prim_Exit(self, node: torch._C.Node): + # export treats prim::Exit as noop + return + + def _convert_as_noop(self, node: torch._C.Node): + # Converts the node as a no-op by mapping its output node as arg[0] + + target = get_op_overload(node) + schema = target._schema + + args, _kwargs = self.get_args_kwargs(node, schema) + + output_name = node.output().debugName() + self.name_to_node[output_name] = args[0] + + def convert_profiler__record_function_exit(self, node: torch._C.Node): + # _record_function_exit has side effect so we keep it in fx.graph + # currently, _record_function_enter_new and _record_function_exit are + # discarded during `retrace_as_exported_program`. + target = torch.ops.profiler._record_function_exit + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + self.fx_graph.call_function(target, args) + + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value_by_ir_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_prim_Uninitialized(self, node: torch._C.Node): + # `prim::Uninitialized` is inserted by the compiler when it can prove + # the value will never be used. It can be introduced by exceptions, + # breaks, continues, and returns. + # So we add a dummy constant to the graph. + output_name = node.output().debugName() + self.name_to_constant[output_name] = torch.Tensor() + + def _convert_standard_operators(self, node: torch._C.Node): + target = kind_to_standard_operators[node.kind()] + args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs()) + fx_node = self.fx_graph.call_function(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + + def convert_node(self, node: torch._C.Node): + node_kind = node.kind() + + # Get handler based on namespace and operator name. + # Provide a default node handler as well in case we don't find + # matching converter for that. + handler_func_name = ir_name_to_func_name(node_kind) + handler_func = getattr(self, handler_func_name, self.convert_call_function_op) + + # str calls print function implemented in CPP. To avoid repeating + # the entire logic here, we simply keep first line from node string (getting rid + # of sub-blocks IR prints). 
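+ # Illustrative sketch of the dispatch performed above (for illustration
+ # only): the node kind is mangled into a handler method name and looked up
+ # on the converter, falling back to the generic call_function conversion.
+ #
+ #     ir_name_to_func_name("prim::If")      # "convert_prim_If"
+ #     ir_name_to_func_name("aten::append")  # "convert_aten_append"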
+ node_str = "".join(str(node).split("\n")[:1]) + log.debug("[%s] converts [%s]", handler_func.__name__, node_str) + try: + handler_func(node) + except Exception as e: + raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e + + def convert_graph_outputs(self): + args = [] + outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list( + self.name_update_from_subblock_to_parent + ) + for output_name in outp_name_list: + if output_name in self.name_to_node: + fx_node = self.name_to_node[output_name] + # TODO: Revisit this later after HigherOrderOp design changes. + # Currently, we cannot directly return input as output. + if ( + not self.is_top_level_graph() + and isinstance(fx_node, torch.fx.Node) + and fx_node.op == "placeholder" + ): + fx_node = self.fx_graph.call_function(torch.clone, (fx_node,)) + args.append(fx_node) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=TensorArgument(name=output_name), + target=output_name, + ) + ) + elif output_name in self.name_to_constant: + args.append(self.name_to_constant[output_name]) + self.output_specs.append( + OutputSpec( + OutputKind.USER_OUTPUT, + arg=ConstantArgument( + name=output_name, value=self.name_to_constant[output_name] + ), + target=output_name, + ) + ) + else: + raise ValueError(f"Output {output_name} not found") + + if len(args) == 0: + # Sub-block of prim::If can have zero output. + self.fx_graph.output([]) + elif len(args) == 1: + self.fx_graph.output( + args[0] + ) # Get rid of an extra list wrapped around final output. + elif len(args) > 1: + self.fx_graph.output( + args + ) # For prim::Loop and prim::If with multiple outputs. + else: + # Sub-block of prim::Loop can have multiple outputs. + self.fx_graph.output(args) + + +class ExplainTS2FXGraphConverter(TS2FXGraphConverter): + """ + Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions + and provide that information to users. In order to collect all failed conversions, it + also mocks some internal attributes (e.g., name_to_node). + """ + + class _DictMock(dict): + def __init__(self, dict_data, mock_value): + super().__init__(dict_data) + self.mock_value = mock_value + + def __getitem__(self, key): + # If the original dictionary has the key, return its value. + # Otherwise, return the mock value. + if not super().__contains__(key): + return self.mock_value + return super().__getitem__(key) + + def __contains__(self, key): + return True + + def __init__( + self, + ts_graph: Union[torch._C.Graph, torch._C.Block], + name_to_param: dict[str, torch.Tensor], + name_to_buffer: dict[str, torch.Tensor], + blocks_to_lifted_attrs: dict[torch._C.Block, set[str]], + name_to_non_tensor_attribute: dict[str, Any], + name_to_constant: dict[str, Any], + name_to_attribute_fqn: dict[str, str], + ): + super().__init__( + ts_graph, + name_to_param, + name_to_buffer, + blocks_to_lifted_attrs, + name_to_non_tensor_attribute, + name_to_constant, + name_to_attribute_fqn, + ) + + # Data to keep track of unsupported nodes. + self.unsupported_node_list: list[torch._C.Node] = [] + + # Add mock to needed attributes. + self.name_to_node = ExplainTS2FXGraphConverter._DictMock( + self.name_to_node, + # Dummy node. 
+ torch.fx.Node( + None, # type: ignore[arg-type] + "mock", + "call_function", + lambda: None, + (), + {}, + ), + ) + + def explain(self): + self.convert_graph_inputs() + for node in self.ts_graph.nodes(): + self.convert_node(node) + self.convert_graph_outputs() + + def convert_node(self, node): + try: + super().convert_node(node) + except Exception: + self.unsupported_node_list.append(node) + + +@contextmanager +def disable_logging(log): + disabled = log.disabled + log.disabled = True + try: + yield + finally: + log.disabled = disabled + + +class TS2EPConverter: + # TorchScript model to ExportedProgram converter + def __init__( + self, + ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction], + sample_args: tuple[Any, ...], + sample_kwargs: Optional[dict[str, Any]] = None, + ): + self.ts_model = ts_model + self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args) + + self.sample_args = sample_args + self.sample_kwargs = sample_kwargs + + self.name_to_param: dict[str, torch.Tensor] = {} + self.name_to_buffer: dict[str, torch.Tensor] = {} + param_list = ( + list(self.ts_model.parameters()) + if not isinstance(self.ts_model, torch._C.ScriptFunction) + else [] + ) + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + # Check if tensor belongs to any parameter. + if any( + (tensor == param).all() + for param in param_list + if tensor.shape == param.shape + ): + self.name_to_param[k] = tensor + else: + self.name_to_buffer[k] = tensor + + self.name_to_non_tensor_attributes: dict[str, Any] = {} + self.name_to_constant: dict[str, Any] = {} + + self.lift_get_attr() + + def convert(self) -> ExportedProgram: + log.info( + """ +TS2EPConverter logging starts from here. + +INFO: (TORCH_LOGS="export" ) + * Log TorchScript IR. + +DEBUG: (TORCH_LOGS="+export" ), additionally + * Log conversion IR by IR in a format of [] converts []. + """ + ) + log.info("TorchScript graph\n\n%s\n", self.ts_graph) + + blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs( + self.ts_graph + ) + + graph_converter = TS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + name_to_attribute_fqn, + ) + gm = graph_converter.convert() + + # Post-proccessing step to deal with quantized operators. + replace_quantized_ops_with_standard_ops(gm) + log.info("GraphModule: %s", gm.print_readable(print_output=False)) + + ep = self.retrace_as_exported_program( + gm, + graph_converter.name_to_constant, + ) + log.info("%s", ep) + + # Post-processing step to ensure ExportedProgram has the same state_dict as + # the original TorchScript model. Throw warnings for additionally populated + # state_dict entries. + if not isinstance(self.ts_model, torch._C.ScriptFunction): + for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr] + if k not in ep.state_dict: + warnings.warn( + f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram." 
+ ) + ep.state_dict[k] = tensor + + return ep + + @disable_logging(log) + def explain(self, print_output=True): + blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs( + self.ts_graph + ) + + graph_converter = ExplainTS2FXGraphConverter( + self.ts_graph, + self.name_to_param, + self.name_to_buffer, + blocks_to_lifted_attrs, + self.name_to_non_tensor_attributes, + self.name_to_constant, + name_to_attribute_fqn, + ) + graph_converter.explain() + if len(graph_converter.unsupported_node_list) > 0: + explain_str = "Unsupported nodes are found in the following list:" + for i, n in enumerate(graph_converter.unsupported_node_list): + node_str = "".join(str(n).split("\n")[:1]) + explain_str += f"\n\n {i}. {n.kind()} [{node_str}]" + else: + explain_str = "Success!" + if print_output: + print(explain_str) + return explain_str + + def retrace_as_exported_program( + self, + gm: torch.fx.GraphModule, + name_to_constant: dict[str, Any], + ): + dynamic_shapes = _tree_map_with_path( + lambda path, x: ( + [Dim.AUTO] * x.dim() if isinstance(x, torch.Tensor) else None + ), + self.sample_args, + ) + + # TODO: adjust input orders to match GraphSignature convention + ep = torch.export._trace._export( + gm, + self.sample_args, + dynamic_shapes=dynamic_shapes, + strict=False, + pre_dispatch=True, + ) + + # Post-processing to make sure the ExportedProgram states are correct. + # Because during conversion, we set tensor constants as GetAttr, + # retracing cannot recognize them as tensor constants but instead + # treats them as buffers. We need to set them again here. + ep._constants.update( + { + k: v + for k, v in name_to_constant.items() + if isinstance(v, (torch.Tensor, torch.ScriptObject)) + } + ) + for k in name_to_constant: + ep.state_dict.pop(k, None) + + for spec in ep.graph_signature.input_specs: + # Mark as constant tensors for erroneously traced buffers. + if spec.kind == InputKind.BUFFER and spec.target in name_to_constant: + assert isinstance( + name_to_constant[spec.target], torch.Tensor + ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" + spec.kind = InputKind.CONSTANT_TENSOR + spec.persistent = None + ep.verifier().check(ep) + + return ep + + def lift_get_attr(self): + # This function lifts multiple data types. + + # 1. Tensor constant attributes (e.g., self.data = torch.tensor([2,3])) + # to buffers. Currently, when there are tensor constants, export + # would error and ask users to register tensor constants as buffers. + # Since it is hard to manually do so for TorchScript models + # (e.g., source code is missing), this function automatically + # lifts tensor constants to be buffers. + + # 2. ScriptObject to constant. It will then be converted to getattr + # in the fx graph. + # + # This function should happen in TS2EPConverter instead of + # TS2FXGraphConverter since it gets attributes from self.ts_model + # which is not accessible in TS2FXGraphConverter. It is similar to where + # we collect self.name_to_param and self.name_to_buffer.
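+ # For example, a hypothetical module with `self.data = torch.tensor([2, 3])` read in + # forward() shows up in the TorchScript IR as a chain of prim::GetAttr nodes; the DFS + # below resolves that chain to a fully qualified name (e.g., "data" or "submod.data") + # and records the tensor in self.name_to_buffer so retracing treats it as a buffer.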
+ name_to_attribute_fqn: dict[str, str] = {} + + def get_attr(fqn: str): + name = fqn.split(".") + v = self.ts_model + for n in name: + v = getattr(v, n) + return v + + def get_fqn(node: torch._C.Node): + attr_name = node.s("name") + input_name = node.input().debugName() + root_attr_name = name_to_attribute_fqn[input_name] + attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name + return attr_fqn + + def _dfs_get_attr(block): + for node in block.nodes(): + if node.kind() == "prim::CreateObject": + output_name = node.output().debugName() + name_to_attribute_fqn[output_name] = "" + + if node.kind() == "prim::GetAttr": + attr_fqn = get_fqn(node) + value = get_attr(attr_fqn) + output_name = node.output().debugName() + name_to_attribute_fqn[output_name] = attr_fqn + if isinstance(value, torch.Tensor): + if attr_fqn not in self.name_to_buffer: + # Lift tensor constants to be a buffer + self.name_to_buffer[attr_fqn] = value + elif isinstance(value, torch.ScriptObject): + if attr_fqn not in self.name_to_constant: + self.name_to_constant[attr_fqn] = value + else: + self.name_to_non_tensor_attributes[attr_fqn] = value + + for subblock in node.blocks(): + _dfs_get_attr(subblock) + + _dfs_get_attr(self.ts_graph) diff --git a/phivenv/Lib/site-packages/torch/_export/db/__init__.py b/phivenv/Lib/site-packages/torch/_export/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/phivenv/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d37deafd57a9b90ca2c5ffd372c52cc0d57b9222 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dc009fada786392570fc9265ddc83803fd428e1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/case.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0feae73121b183910fb35995e32d5c9eef2ce734 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/gen_example.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a71726745aa8430fa6bdf0935bae8fcf0c0debd7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/__pycache__/logging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/case.py b/phivenv/Lib/site-packages/torch/_export/db/case.py new file mode 100644 index 
0000000000000000000000000000000000000000..2899959427027f67bdbfd0fd4490c0f83b1eb401 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/case.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +import inspect +import re +import string +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional +from types import ModuleType + +import torch + +_TAGS: dict[str, dict[str, Any]] = { + "torch": { + "cond": {}, + "dynamic-shape": {}, + "escape-hatch": {}, + "map": {}, + "dynamic-value": {}, + "operator": {}, + "mutation": {}, + }, + "python": { + "assert": {}, + "builtin": {}, + "closure": {}, + "context-manager": {}, + "control-flow": {}, + "data-structure": {}, + "standard-library": {}, + "object-model": {}, + }, +} + + +class SupportLevel(Enum): + """ + Indicates at what stage the feature + used in the example is handled in export. + """ + + SUPPORTED = 1 + NOT_SUPPORTED_YET = 0 + + +ArgsType = tuple[Any, ...] + + +def check_inputs_type(args, kwargs): + if not isinstance(args, tuple): + raise ValueError( + f"Expecting args type to be a tuple, got: {type(args)}" + ) + if not isinstance(kwargs, dict): + raise ValueError( + f"Expecting kwargs type to be a dict, got: {type(kwargs)}" + ) + for key in kwargs: + if not isinstance(key, str): + raise ValueError( + f"Expecting kwargs keys to be a string, got: {type(key)}" + ) + +def _validate_tag(tag: str): + parts = tag.split(".") + t = _TAGS + for part in parts: + assert set(part) <= set( + string.ascii_lowercase + "-" + ), f"Tag contains invalid characters: {part}" + if part in t: + t = t[part] + else: + raise ValueError(f"Tag {tag} is not found in registered tags.") + + +@dataclass(frozen=True) +class ExportCase: + example_args: ArgsType + description: str # A description of the use case. + model: torch.nn.Module + name: str + example_kwargs: dict[str, Any] = field(default_factory=dict) + extra_args: Optional[ArgsType] = None # For testing graph generalization. + # Tags associated with the use case. (e.g dynamic-shape, escape-hatch) + tags: set[str] = field(default_factory=set) + support_level: SupportLevel = SupportLevel.SUPPORTED + dynamic_shapes: Optional[dict[str, Any]] = None + + def __post_init__(self): + check_inputs_type(self.example_args, self.example_kwargs) + if self.extra_args is not None: + check_inputs_type(self.extra_args, {}) + + for tag in self.tags: + _validate_tag(tag) + + if not isinstance(self.description, str) or len(self.description) == 0: + raise ValueError(f'Invalid description: "{self.description}"') + + +_EXAMPLE_CASES: dict[str, ExportCase] = {} +_MODULES: set[ModuleType] = set() +_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {} +_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {} + + +def register_db_case(case: ExportCase) -> None: + """ + Registers a user provided ExportCase into example bank. + """ + if case.name in _EXAMPLE_CASES: + if case.name not in _EXAMPLE_CONFLICT_CASES: + _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]] + _EXAMPLE_CONFLICT_CASES[case.name].append(case) + return + + _EXAMPLE_CASES[case.name] = case + + +def to_snake_case(name): + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +def _make_export_case(m, name, configs): + if not isinstance(m, torch.nn.Module): + raise TypeError("Export case class should be a torch.nn.Module.") + + if "description" not in configs: + # Fallback to docstring if description is missing. 
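+ # A case class that defines neither a "description" config nor a docstring is treated + # as an authoring error (see the assert below); ExportCase.__post_init__ would reject + # an empty description anyway.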
+ assert ( + m.__doc__ is not None + ), f"Could not find description or docstring for export case: {m}" + configs = {**configs, "description": m.__doc__} + return ExportCase(**{**configs, "model": m, "name": name}) + + +def export_case(**kwargs): + """ + Decorator for registering a user provided case into example bank. + """ + + def wrapper(m): + configs = kwargs + module = inspect.getmodule(m) + if module in _MODULES: + raise RuntimeError("export_case should only be used once per example file.") + + assert module is not None + _MODULES.add(module) + module_name = module.__name__.split(".")[-1] + case = _make_export_case(m, module_name, configs) + register_db_case(case) + return case + + return wrapper + + +def export_rewrite_case(**kwargs): + def wrapper(m): + configs = kwargs + + parent = configs.pop("parent") + assert isinstance(parent, ExportCase) + key = parent.name + if key not in _EXAMPLE_REWRITE_CASES: + _EXAMPLE_REWRITE_CASES[key] = [] + + configs["example_args"] = parent.example_args + case = _make_export_case(m, to_snake_case(m.__name__), configs) + _EXAMPLE_REWRITE_CASES[key].append(case) + return case + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__init__.py b/phivenv/Lib/site-packages/torch/_export/db/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04c14e2e0e1f7a3711b8e7283a8ed970aee01b4a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/__init__.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import dataclasses +import glob +import inspect +from os.path import basename, dirname, isfile, join + +import torch +from torch._export.db.case import ( + _EXAMPLE_CASES, + _EXAMPLE_CONFLICT_CASES, + _EXAMPLE_REWRITE_CASES, + SupportLevel, + export_case, + ExportCase, +) + + +def _collect_examples(): + case_names = glob.glob(join(dirname(__file__), "*.py")) + case_names = [ + basename(f)[:-3] for f in case_names if isfile(f) and not f.endswith("__init__.py") + ] + + case_fields = {f.name for f in dataclasses.fields(ExportCase)} + for case_name in case_names: + case = __import__(case_name, globals(), locals(), [], 1) + variables = [name for name in dir(case) if name in case_fields] + export_case(**{v: getattr(case, v) for v in variables})(case.model) + +_collect_examples() + +def all_examples(): + return _EXAMPLE_CASES + + +if len(_EXAMPLE_CONFLICT_CASES) > 0: + + def get_name(case): + model = case.model + if isinstance(model, torch.nn.Module): + model = type(model) + return model.__name__ + + msg = "Error on conflict export case name.\n" + for case_name, cases in _EXAMPLE_CONFLICT_CASES.items(): + msg += f"Case name {case_name} is associated with multiple cases:\n " + msg += f"[{','.join(map(get_name, cases))}]\n" + + raise RuntimeError(msg) + + +def filter_examples_by_support_level(support_level: SupportLevel): + return { + key: val + for key, val in all_examples().items() + if val.support_level == support_level + } + + +def get_rewrite_cases(case): + return _EXAMPLE_REWRITE_CASES.get(case.name, []) diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b9b5d2baa044db10c7df79373bc8a59c542b40b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f5634ac7eb5a754127d3a27c660a575c2e778d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc37a7dad9bc327b5c84e1e4453535633a6eddbb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c27c18468b5dcbf9ccab6aa03b33784c894b1c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..624640e4dbd8126f81b31b77be7279ca48dc9e71 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..073ea6f94a5b221a151b1e41be7e44a60d712f34 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb5a9d134bc2e9bc15ee766db4bb6f446effe448 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1b2aa38ed7fd9ec085d655860be0800042ca49b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c3328bd9ee91e43560139f775b7bb728afc54312 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20a1b3ea201bf00fac4d4d933ab788a7ac07f431 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e9d45caab2f426ba6fbabbff42bc537fbfed898 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1f8a07e420a28d4cab2861b4bbc8380b0f8e31c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31641d306e73902435856f83053828c973951495 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..423fecd500456c3c2e07c3d7ae5b0ba3f8ce266b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7521c81c9c845048f01f631033e6089d3a076bd3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaab98a94070bc9b19d249af9d30064522ec8c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebf98583a53e501fb2be5ccad0472974a706b47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c59f3f83fab905afb00fc2f625e4c8dc4a98bc3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c742a313f5ed4ab3c00cce6dbb4863d8d92ff4b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21e938682f2775f08118ec3cea7835bf7e85ebb2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..545724d9353f6d5006c0d1a3d201b43134b2c229 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..091b284f7dd4c1f2ab338091a72692401adb586e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..620ea8a1d5fc7e5c841106317d1aa44549ecfdde Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128c8bb5bbfb93ea126c939cc3ab08986960651a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe995c0ee06a4e898a5e6589d387aac5a9f6b1fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82b8c3c95bbe9fe4c08dd2152fe590b74f9eaad6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63e86de1253870c0f433f73281e0c0a3c651a75f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60498bb48a21b669323114848b3c11f48c63c5a0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc855d9ca0e0cbb88c8893205e517c03eb2ff45a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae8edaf08a7f535f6844251ad77583a845488f88 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7e8006ebeb085bb522d99244bebb7bb17b1a65e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95260a83e27e89944c88f4e3b70cfb77cf8e7bb9 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d79479f67e5184e864adf25bd613be0c49935562 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82257ff127a793aa87db2c8c7e26987cf86a4329 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab0447b56761604c351d8d4b6494093b062a853 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c9845b1aa6f8f609141af50fe654e60fcc7a67f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/unsupported_operator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8e278ef87c6f21a9bd5c363ea163c8d149d70b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py b/phivenv/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py new file mode 100644 index 0000000000000000000000000000000000000000..76d28d00442ceea77f52a5c3608f8ba19128b2ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/assume_constant_result.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +import torch._dynamo as torchdynamo + + +class AssumeConstantResult(torch.nn.Module): + """ + Applying `assume_constant_result` decorator to burn make non-tracable code as constant. 
+ """ + + @torchdynamo.assume_constant_result + def get_item(self, y): + return y.int().item() + + def forward(self, x, y): + return x[: self.get_item(y)] + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"torch.escape-hatch"} +model = AssumeConstantResult() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/autograd_function.py b/phivenv/Lib/site-packages/torch/_export/db/examples/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..c54bb1a2f227a8469251966b84d9e7247b916206 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/autograd_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class MyAutogradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + 1 + +class AutogradFunction(torch.nn.Module): + """ + TorchDynamo does not keep track of backward() on autograd functions. We recommend to + use `allow_in_graph` to mitigate this problem. + """ + + def forward(self, x): + return MyAutogradFunction.apply(x) + +example_args = (torch.randn(3, 2),) +model = AutogradFunction() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/class_method.py b/phivenv/Lib/site-packages/torch/_export/db/examples/class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..dafaee246c6b5a235190f171a7762d3b57de182a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/class_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class ClassMethod(torch.nn.Module): + """ + Class methods are inlined during tracing. + """ + + @classmethod + def method(cls, x): + return x + 1 + + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 2) + + def forward(self, x): + x = self.linear(x) + return self.method(x) * self.__class__.method(x) * type(self).method(x) + +example_args = (torch.randn(3, 4),) +model = ClassMethod() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py new file mode 100644 index 0000000000000000000000000000000000000000..2aeabfce38ca40653dfadf7389bfafa4a088ea92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_class_method.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class MySubModule(torch.nn.Module): + def foo(self, x): + return x.cos() + + def forward(self, x): + return self.foo(x) + +class CondBranchClassMethod(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + + This example demonstrates using class method in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def __init__(self) -> None: + super().__init__() + self.subm = MySubModule() + + def bar(self, x): + return x.sin() + + def forward(self, x): + return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchClassMethod() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..000077db07fe27731db90775cf3ceeca2baae32f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nested_function.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNestedFunction(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates using nested function in cond(). + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + def true_fn(x): + def inner_true_fn(y): + return x + y + + return inner_true_fn(x) + + def false_fn(x): + def inner_false_fn(y): + return x - y + + return inner_false_fn(x) + + return cond(x.shape[0] < 10, true_fn, false_fn, [x]) + +example_args = (torch.randn(3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNestedFunction() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..77834eb616cc4397217da20591dec5bf679f1c64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_branch_nonlocal_variables.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondBranchNonlocalVariables(torch.nn.Module): + """ + The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: + - both branches must take the same args, which must also match the branch args passed to cond. + - both branches must return a single tensor + - returned tensor must have the same tensor metadata, e.g. shape and dtype + - branch function can be free function, nested function, lambda, class methods + - branch function can not have closure variables + - no inplace mutations on inputs or global variables + + This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. + + The code below will not work because capturing closure variables is not supported. 
+ ``` + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y + my_tensor_var + my_primitive_var + + def false_fn(y): + nonlocal my_tensor_var, my_primitive_var + return y - my_tensor_var - my_primitive_var + + return cond(x.shape[0] > 5, true_fn, false_fn, [x]) + ``` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + my_tensor_var = x + 100 + my_primitive_var = 3.14 + + def true_fn(x, y, z): + return x + y + z + + def false_fn(x, y, z): + return x - y - z + + return cond( + x.shape[0] > 5, + true_fn, + false_fn, + [x, my_tensor_var, torch.tensor(my_primitive_var)], + ) + +example_args = (torch.randn(6),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondBranchNonlocalVariables() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py new file mode 100644 index 0000000000000000000000000000000000000000..86aa04a1d4eeb83fe26615c90a1ac7499507cd32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_closed_over_variable.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondClosedOverVariable(torch.nn.Module): + """ + torch.cond() supports branches closed over arbitrary variables. + """ + + def forward(self, pred, x): + def true_fn(val): + return x * 2 + + def false_fn(val): + return x - 2 + + return cond(pred, true_fn, false_fn, [x + 1]) + +example_args = (torch.tensor(True), torch.randn(3, 2)) +tags = {"torch.cond", "python.closure"} +model = CondClosedOverVariable() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_operands.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_operands.py new file mode 100644 index 0000000000000000000000000000000000000000..bd975ea92deb8d9f35e95eb854cf508b5f360e64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_operands.py @@ -0,0 +1,35 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +y = torch.randn(2) +dim0_x = Dim("dim0_x") + +class CondOperands(torch.nn.Module): + """ + The operands passed to cond() must be: + - a list of tensors + - match arguments of `true_fn` and `false_fn` + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. 
+ """ + + def forward(self, x, y): + def true_fn(x, y): + return x + y + + def false_fn(x, y): + return x - y + + return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) + +example_args = (x, y) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +extra_inputs = (torch.randn(2, 2), torch.randn(2)) +dynamic_shapes = {"x": {0: dim0_x}, "y": None} +model = CondOperands() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/cond_predicate.py b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_predicate.py new file mode 100644 index 0000000000000000000000000000000000000000..1e34f083dc47f6a877ac59e802cb55571fe94451 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/cond_predicate.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import cond + +class CondPredicate(torch.nn.Module): + """ + The conditional statement (aka predicate) passed to cond() must be one of the following: + - torch.Tensor with a single element + - boolean expression + + NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. + """ + + def forward(self, x): + pred = x.dim() > 2 and x.shape[2] > 10 + + return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) + +example_args = (torch.randn(6, 4, 3),) +tags = { + "torch.cond", + "torch.dynamic-shape", +} +model = CondPredicate() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py b/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py new file mode 100644 index 0000000000000000000000000000000000000000..eda0a4e3e019d5267327684d8ced81c54ff723b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_size_example.py @@ -0,0 +1,25 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsSizeExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check and torch._check_is_size APIs. + torch._check_is_size is used for values that NEED to be used for constructing + tensor. + """ + + def forward(self, x): + a = x.item() + torch._check_is_size(a) + torch._check(a <= 5) + return torch.zeros((a, 5)) + + +example_args = (torch.tensor(4),) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsSizeExample() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py b/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py new file mode 100644 index 0000000000000000000000000000000000000000..893cd2b183c0f79ca1781ffc90328e336b7ff7f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/constrain_as_value_example.py @@ -0,0 +1,28 @@ +# mypy: allow-untyped-defs +import torch + + +class ConstrainAsValueExample(torch.nn.Module): + """ + If the value is not known at tracing time, you can provide hint so that we + can trace further. Please look at torch._check and torch._check_is_size APIs. + torch._check is used for values that don't need to be used for constructing + tensor. 
+ """ + + def forward(self, x, y): + a = x.item() + torch._check(a >= 0) + torch._check(a <= 5) + + if a < 6: + return y.sin() + return y.cos() + + +example_args = (torch.tensor(4), torch.randn(5, 5)) +tags = { + "torch.dynamic-value", + "torch.escape-hatch", +} +model = ConstrainAsValueExample() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/decorator.py b/phivenv/Lib/site-packages/torch/_export/db/examples/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..0f893d0c1d58d7a95987cc965a99c61fbab31160 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/decorator.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import functools + +import torch + +def test_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + 1 + + return wrapper + +class Decorator(torch.nn.Module): + """ + Decorators calls are inlined into the exported function during tracing. + """ + + @test_decorator + def forward(self, x, y): + return x + y + +example_args = (torch.randn(3, 2), torch.randn(3, 2)) +model = Decorator() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dictionary.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..3d537afb4103056d13e8f3ff8527d8f072879a37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dictionary.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class Dictionary(torch.nn.Module): + """ + Dictionary structures are inlined and flattened along tracing. + """ + + def forward(self, x, y): + elements = {} + elements["x2"] = x * x + y = y * elements["x2"] + return {"y": y} + +example_args = (torch.randn(3, 2), torch.tensor(4)) +tags = {"python.data-structure"} +model = Dictionary() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..a9637b3e19279d7f50032faa6532b9812361b665 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_assert.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeAssert(torch.nn.Module): + """ + A basic usage of python assertion. + """ + + def forward(self, x): + # assertion with error message + assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" + # assertion without error message + assert x.shape[0] > 1 + return x + +example_args = (torch.randn(3, 2),) +tags = {"python.assert"} +model = DynamicShapeAssert() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9a5b7507c862b8dfed21a075327abe16bab3c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_constructor.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeConstructor(torch.nn.Module): + """ + Tensor constructors should be captured with dynamic shape inputs rather + than being baked in with static shape. 
+ """ + + def forward(self, x): + return torch.zeros(x.shape[0] * 2) + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeConstructor() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py new file mode 100644 index 0000000000000000000000000000000000000000..8afa65a5adf30efaeb8a069fd01781031c425b7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_if_guard.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeIfGuard(torch.nn.Module): + """ + `if` statement with backed dynamic shape predicate will be specialized into + one particular branch and generate a guard. However, export will fail if the + the dimension is marked as dynamic shape from higher level API. + """ + + def forward(self, x): + if x.shape[0] == 3: + return x.cos() + + return x.sin() + +example_args = (torch.randn(3, 2, 2),) +tags = {"torch.dynamic-shape", "python.control-flow"} +model = DynamicShapeIfGuard() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5b4e2130e06c228ee72d6c0cb6ec713f8f6677 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_map.py @@ -0,0 +1,19 @@ +# mypy: allow-untyped-defs +import torch + +from functorch.experimental.control_flow import map + +class DynamicShapeMap(torch.nn.Module): + """ + functorch map() maps a function over the first tensor dimension. + """ + + def forward(self, xs, y): + def body(x, y): + return x + y + + return map(body, xs, y) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"torch.dynamic-shape", "torch.map"} +model = DynamicShapeMap() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py new file mode 100644 index 0000000000000000000000000000000000000000..3654bb9ea21c49c9aea1debecb1c591091f5aa81 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_round.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import torch + +from torch._export.db.case import SupportLevel +from torch.export import Dim + +class DynamicShapeRound(torch.nn.Module): + """ + Calling round on dynamic shapes is not supported. + """ + + def forward(self, x): + return x[: round(x.shape[0] / 2)] + +x = torch.randn(3, 2) +dim0_x = Dim("dim0_x") +example_args = (x,) +tags = {"torch.dynamic-shape", "python.builtin"} +support_level = SupportLevel.NOT_SUPPORTED_YET +dynamic_shapes = {"x": {0: dim0_x}} +model = DynamicShapeRound() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..d30efba0a04606635760719a02c83aa5b851e689 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_slicing.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeSlicing(torch.nn.Module): + """ + Slices with dynamic shape arguments should be captured into the graph + rather than being baked in. 
+ """ + + def forward(self, x): + return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeSlicing() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py new file mode 100644 index 0000000000000000000000000000000000000000..7df20d4c0a439e7c061cf85caac3635b1eb86fd3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/dynamic_shape_view.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class DynamicShapeView(torch.nn.Module): + """ + Dynamic shapes should be propagated to view arguments instead of being + baked into the exported graph. + """ + + def forward(self, x): + new_x_shape = x.size()[:-1] + (2, 5) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1) + +example_args = (torch.randn(10, 10),) +tags = {"torch.dynamic-shape"} +model = DynamicShapeView() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py b/phivenv/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2fee9aa1d165bc55b8c112e8aeac4e61dbb0be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/fn_with_kwargs.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + +class FnWithKwargs(torch.nn.Module): + """ + Keyword arguments are not supported at the moment. + """ + + def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): + out = pos0 + for arg in tuple0: + out = out * arg + for arg in myargs: + out = out * arg + out = out * mykw0 + out = out * mykwargs["input0"] * mykwargs["input1"] + return out + +example_args = ( + torch.randn(4), + (torch.randn(4), torch.randn(4)), + *[torch.randn(4), torch.randn(4)] +) +example_kwargs = { + "mykw0": torch.randn(4), + "input0": torch.randn(4), + "input1": torch.randn(4), +} +tags = {"python.data-structure"} +model = FnWithKwargs() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/list_contains.py b/phivenv/Lib/site-packages/torch/_export/db/examples/list_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..5371c12362e08793cb27a84ea1951d45e1f09a94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/list_contains.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + +class ListContains(torch.nn.Module): + """ + List containment relation can be checked on a dynamic shape or constants. + """ + + def forward(self, x): + assert x.size(-1) in [6, 2] + assert x.size(0) not in [4, 5, 6] + assert "monkey" not in ["cow", "pig"] + return x + x + +example_args = (torch.randn(3, 2),) +tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"} +model = ListContains() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/list_unpack.py b/phivenv/Lib/site-packages/torch/_export/db/examples/list_unpack.py new file mode 100644 index 0000000000000000000000000000000000000000..57e84300a6a5dd70e943a6b58f84f5e2e2e96d77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/list_unpack.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs + +import torch + +class ListUnpack(torch.nn.Module): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. 
+ """ + + def forward(self, args: list[torch.Tensor]): + """ + Lists are treated as static construct, therefore unpacking should be + erased after tracing. + """ + x, *y = args + return x + y[0] + +example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],) +tags = {"python.control-flow", "python.data-structure"} +model = ListUnpack() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py b/phivenv/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..f1573a175f8cf776883c5a7e8caa497df07b4436 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/model_attr_mutation.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class ModelAttrMutation(torch.nn.Module): + """ + Attribute mutation is not supported. + """ + + def __init__(self) -> None: + super().__init__() + self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)] + + def recreate_list(self): + return [torch.zeros(3, 2), torch.zeros(3, 2)] + + def forward(self, x): + self.attr_list = self.recreate_list() + return x.sum() + self.attr_list[0].sum() + + +example_args = (torch.randn(3, 2),) +tags = {"python.object-model"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = ModelAttrMutation() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/nested_function.py b/phivenv/Lib/site-packages/torch/_export/db/examples/nested_function.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c31100f23a0a20a7e30acaad9a1e22aa64e9e8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/nested_function.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +class NestedFunction(torch.nn.Module): + """ + Nested functions are traced through. Side effects on global captures + are not supported though. + """ + + def forward(self, a, b): + x = a + b + z = a - b + + def closure(y): + nonlocal x + x += 1 + return x * y + z + + return closure(x) + +example_args = (torch.randn(3, 2), torch.randn(2)) +tags = {"python.closure"} +model = NestedFunction() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/null_context_manager.py b/phivenv/Lib/site-packages/torch/_export/db/examples/null_context_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..03ab3402aecf4761f7b71dc2de48b26a23438ba1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/null_context_manager.py @@ -0,0 +1,21 @@ +# mypy: allow-untyped-defs +import contextlib + +import torch + +class NullContextManager(torch.nn.Module): + """ + Null context manager in Python will be traced out. + """ + + def forward(self, x): + """ + Null context manager in Python will be traced out. 
+ """ + ctx = contextlib.nullcontext() + with ctx: + return x.sin() + x.cos() + +example_args = (torch.randn(3, 2),) +tags = {"python.context-manager"} +model = NullContextManager() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/optional_input.py b/phivenv/Lib/site-packages/torch/_export/db/examples/optional_input.py new file mode 100644 index 0000000000000000000000000000000000000000..38e41a09179dc4d000834122c5330582a840d491 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/optional_input.py @@ -0,0 +1,20 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class OptionalInput(torch.nn.Module): + """ + Tracing through optional input is not supported yet + """ + + def forward(self, x, y=torch.randn(2, 3)): + if y is not None: + return x + y + return x + + +example_args = (torch.randn(2, 3),) +tags = {"python.object-model"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = OptionalInput() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py b/phivenv/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..6a963cad3f1a0508c0c8f40278c23e6449b3ac3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/pytree_flatten.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +from torch.utils import _pytree as pytree + +class PytreeFlatten(torch.nn.Module): + """ + Pytree from PyTorch can be captured by TorchDynamo. + """ + + def forward(self, x): + y, _spec = pytree.tree_flatten(x) + return y[0] + 1 + +example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), +model = PytreeFlatten() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/scalar_output.py b/phivenv/Lib/site-packages/torch/_export/db/examples/scalar_output.py new file mode 100644 index 0000000000000000000000000000000000000000..b233b450dc295b83280e81cf2d392bcc69514a2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/scalar_output.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +import torch + +from torch.export import Dim + +x = torch.randn(3, 2) +dim1_x = Dim("dim1_x") + +class ScalarOutput(torch.nn.Module): + """ + Returning scalar values from the graph is supported, in addition to Tensor + outputs. Symbolic shapes are captured and rank is specialized. + """ + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x.shape[1] + 1 + +example_args = (x,) +tags = {"torch.dynamic-shape"} +dynamic_shapes = {"x": {1: dim1_x}} +model = ScalarOutput() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py b/phivenv/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..ea36d41376bf1a43369e411aee2eb77254104326 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/specialized_attribute.py @@ -0,0 +1,26 @@ +# mypy: allow-untyped-defs +from enum import Enum + +import torch + +class Animal(Enum): + COW = "moo" + +class SpecializedAttribute(torch.nn.Module): + """ + Model attributes are specialized. 
+ """ + + def __init__(self) -> None: + super().__init__() + self.a = "moo" + self.b = 4 + + def forward(self, x): + if self.a == Animal.COW.value: + return x * x + self.b + else: + raise ValueError("bad") + +example_args = (torch.randn(3, 2),) +model = SpecializedAttribute() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/static_for_loop.py b/phivenv/Lib/site-packages/torch/_export/db/examples/static_for_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..b81bc560d72b2be37ef5728fa8a00e404afa88b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/static_for_loop.py @@ -0,0 +1,16 @@ +# mypy: allow-untyped-defs +import torch + +class StaticForLoop(torch.nn.Module): + """ + A for loop with constant number of iterations should be unrolled in the exported graph. + """ + + def forward(self, x): + # constant + ret = [i + x for i in range(10)] + return ret + +example_args = (torch.randn(3, 2),) +tags = {"python.control-flow"} +model = StaticForLoop() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/static_if.py b/phivenv/Lib/site-packages/torch/_export/db/examples/static_if.py new file mode 100644 index 0000000000000000000000000000000000000000..cd02275c2bb6ca83d018865618a7d96bb862cf2f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/static_if.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch + +class StaticIf(torch.nn.Module): + """ + `if` statement with static predicate value should be traced through with the + taken branch. + """ + + def forward(self, x): + if len(x.shape) == 3: + return x + torch.ones(1, 1, 1) + + return x + +example_args = (torch.randn(3, 2, 2),) +tags = {"python.control-flow"} +model = StaticIf() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py b/phivenv/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py new file mode 100644 index 0000000000000000000000000000000000000000..54b10e043398ef676355d3f510249f6385f068b3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/tensor_setattr.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +import torch + + +class TensorSetattr(torch.nn.Module): + """ + setattr() call onto tensors is not supported. + """ + def forward(self, x, attr): + setattr(x, attr, torch.randn(3, 2)) + return x + 4 + +example_args = (torch.randn(3, 2), "attr") +tags = {"python.builtin"} +model = TensorSetattr() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py b/phivenv/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5ffac0e914334d3e94daea52c1966b9af173e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/type_reflection_method.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch + +class A: + @classmethod + def func(cls, x): + return 1 + x + +class TypeReflectionMethod(torch.nn.Module): + """ + type() calls on custom objects followed by attribute accesses are not allowed + due to its overly dynamic nature. 
+ """ + + def forward(self, x): + a = A() + return type(a).func(x) + + +example_args = (torch.randn(3, 4),) +tags = {"python.builtin"} +model = TypeReflectionMethod() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/unsupported_operator.py b/phivenv/Lib/site-packages/torch/_export/db/examples/unsupported_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..a60e257422367df2dbd80e74cbee07897c104d75 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/unsupported_operator.py @@ -0,0 +1,18 @@ +# mypy: allow-untyped-defs +import torch +from torch._export.db.case import SupportLevel + + +class TorchSymMin(torch.nn.Module): + """ + torch.sym_min operator is not supported in export. + """ + + def forward(self, x): + return x.sum() + torch.sym_min(x.size(0), 100) + + +example_args = (torch.randn(3, 2),) +tags = {"torch.operator"} +support_level = SupportLevel.NOT_SUPPORTED_YET +model = TorchSymMin() diff --git a/phivenv/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py b/phivenv/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py new file mode 100644 index 0000000000000000000000000000000000000000..32221c81688070b64d13c01845f444f6353e1783 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/examples/user_input_mutation.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +import torch + + +class UserInputMutation(torch.nn.Module): + """ + Directly mutate user input in forward + """ + + def forward(self, x): + x.mul_(2) + return x.cos() + + +example_args = (torch.randn(3, 2),) +tags = {"torch.mutation"} +model = UserInputMutation() diff --git a/phivenv/Lib/site-packages/torch/_export/db/gen_example.py b/phivenv/Lib/site-packages/torch/_export/db/gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..0522ab99c72bbe2018a67bd75bcc4e87163a6f24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/gen_example.py @@ -0,0 +1,21 @@ +import os +import sys + +import torch._export.db.examples as examples + +TEMPLATE = '''import torch + +def {case_name}(x): + """ + """ + + return +''' + +if __name__ == "__main__": + assert len(sys.argv) == 2 + root_dir = examples.__name__.replace(".", "/") + assert os.path.exists(root_dir) + with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f: + print("Writing to", f.name, "...") + f.write(TEMPLATE.format(case_name=sys.argv[1])) diff --git a/phivenv/Lib/site-packages/torch/_export/db/logging.py b/phivenv/Lib/site-packages/torch/_export/db/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..016cc580e2e0c8bfb9cf7271837bda190f8edc33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/db/logging.py @@ -0,0 +1,46 @@ +from typing import Optional + +def exportdb_error_message(case_name: str) -> str: + from .examples import all_examples + from torch._utils_internal import log_export_usage + + ALL_EXAMPLES = all_examples() + # Detect whether case_name is really registered in exportdb. + if case_name in ALL_EXAMPLES: + url_case_name = case_name.replace("_", "-") + return f"See {case_name} in exportdb for unsupported case. \ + https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}" + else: + log_export_usage( + event="export.error.casenotregistered", + message=case_name, + ) + return f"{case_name} is unsupported." + + +def get_class_if_classified_error(e: Exception) -> Optional[str]: + """ + Returns a string case name if the export error e is classified. + Returns None otherwise. 
+ """ + + from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError + + ALWAYS_CLASSIFIED = "always_classified" + DEFAULT_CLASS_SIGIL = "case_name" + + # add error types that should be classified, along with any attribute name + # whose presence acts like a sigil to further distinguish which errors of + # that type should be classified. If the attribute name is None, then the + # error type is always classified. + _ALLOW_LIST = { + Unsupported: DEFAULT_CLASS_SIGIL, + UserError: DEFAULT_CLASS_SIGIL, + TorchRuntimeError: None, + } + if type(e) in _ALLOW_LIST: + attr_name = _ALLOW_LIST[type(e)] + if attr_name is None: + return ALWAYS_CLASSIFIED + return getattr(e, attr_name, None) + return None diff --git a/phivenv/Lib/site-packages/torch/_export/error.py b/phivenv/Lib/site-packages/torch/_export/error.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4c9eb33ca2e35758d57e3f402b2e75ccca5cd0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/error.py @@ -0,0 +1,56 @@ +from enum import Enum + + +class ExportErrorType(Enum): + # User providing invalid inputs to either tracer, or other public facing APIs + INVALID_INPUT_TYPE = 1 + + # User returning values from their models that we don't support. + INVALID_OUTPUT_TYPE = 2 + + # Generated IR does not conform to Export IR Specification. + VIOLATION_OF_SPEC = 3 + + # User's code contains types and functionalities we don't support. + NOT_SUPPORTED = 4 + + # User's code didn't provide necessary details for us to successfully trace and export. + # For example, we use a lot of decorators and ask users to annotate their model. + MISSING_PROPERTY = 5 + + # User is using an API without proper initialization step. + UNINITIALIZED = 6 + + +def internal_assert(pred: bool, assert_msg: str) -> None: + """ + This is exir's custom assert method. It internally just throws InternalError. + Note that the sole purpose is to throw our own error while maintaining similar syntax + as python assert. + """ + + if not pred: + raise InternalError(assert_msg) + + +class InternalError(Exception): + """ + Raised when an internal invariance is violated in EXIR stack. + Should hint users to report a bug to dev and expose the original + error message. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExportError(Exception): + """ + This type of exception is raised for errors that are directly caused by the user + code. In general, user errors happen during model authoring, tracing, using our public + facing APIs, and writing graph passes. 
+ """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/phivenv/Lib/site-packages/torch/_export/non_strict_utils.py b/phivenv/Lib/site-packages/torch/_export/non_strict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b21d8e41021d50710c6076dc15ea60bc28535841 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/non_strict_utils.py @@ -0,0 +1,1054 @@ +# mypy: allow-untyped-defs +import builtins +import contextlib +import functools +import inspect +import logging +import math +from collections import defaultdict +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.source import ( + AttrSource, + GetItemSource, + LocalSource, + TensorProperty, + TensorPropertySource, +) +from torch._dynamo.variables.builder import TrackedFake +from torch._export.passes.lift_constants_pass import ConstantAttrMap +from torch._export.utils import _fakify_params_buffers +from torch._guards import Source +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Constraint +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _DimHint, + _DimHintType, + _IntWrapper, + _process_dynamic_shapes, + _RelaxedConstraint, + _tree_map_with_path, +) +from torch.export.graph_signature import CustomObjArgument +from torch.fx.experimental import _config as config +from torch.fx.experimental.symbolic_shapes import ( + _find_user_code_frame, + _suggest_fixes_for_data_dependent_error_non_strict, + ConstraintViolationError, + DimDynamic, + EqualityConstraint, + GuardOnDataDependentSymNode, + RelaxedUnspecConstraint, + ShapeEnv, + StatelessSymbolicContext, + SymIntSymbolicContext, + ValueRanges, +) +from torch.utils._pytree import ( + GetAttrKey, + KeyPath, + MappingKey, + SequenceKey, + tree_map_with_path, +) +from torch.utils._sympy.numbers import int_oo + + +if TYPE_CHECKING: + from sympy import Symbol + + +log = logging.getLogger(__name__) + + +class _KeyPath: + """ + Wraps `KeyPath` to aid `isinstance` checks. + """ + + def __init__(self, kp: KeyPath): + self.kp = kp + + +class _KeyPathTrie: + """ + Builds a trie of `KeyPath` prefixes mapping to `Source` leaves. 
+ """ + + def __init__(self): + self.root = {} + + def add(self, kp: KeyPath, src: Source): + assert len(kp) > 0 + *path, leaf = kp + node = self.root + for k in path: + if k not in node: + node[k] = {} + node = node[k] + node[leaf] = src + + def get(self, kp: KeyPath) -> tuple[Source, KeyPath]: + node = self.root + while not isinstance(node, Source): + assert len(kp) > 0 + k, *kp = kp # type: ignore[assignment] + node = node[k] + return node, kp + + +def make_sourced_prefixes(nn_module, args, kwargs) -> _KeyPathTrie: + kp_args, kp_kwargs = tree_map_with_path( + lambda kp, _: _KeyPath(kp), + (tuple(None for _ in args), {k: None for k in kwargs}), # noqa: C420 + ) + kp_combined_args = _combine_args(nn_module, kp_args, kp_kwargs) + + sourced_prefixes = _KeyPathTrie() + for name, struct in kp_combined_args.items(): + src = LocalSource(name) + + if isinstance(struct, _KeyPath): + sourced_prefixes.add(struct.kp, src) + elif isinstance(struct, tuple): + for i, prefix in enumerate(struct): + assert isinstance(prefix, _KeyPath) + sourced_prefixes.add(prefix.kp, GetItemSource(src, i)) + elif isinstance(struct, dict): + for k, prefix in struct.items(): + assert isinstance(prefix, _KeyPath) + sourced_prefixes.add(prefix.kp, GetItemSource(src, k)) + + return sourced_prefixes + + +def key_path_to_source( + kp: KeyPath, sourced_prefixes: Optional[_KeyPathTrie] = None +) -> Source: + """ + Given a key path, return the source for the key path. + """ + if sourced_prefixes is None: + source: Source = LocalSource("args") + else: + source, kp = sourced_prefixes.get(kp) + for k in kp: + if isinstance(k, SequenceKey): + source = GetItemSource(source, k.idx) + elif isinstance(k, MappingKey): + source = GetItemSource(source, k.key) + elif isinstance(k, GetAttrKey): + source = AttrSource(source, k.name) + else: + raise ValueError(f"Unknown KeyEntry {k}") + + return source + + +def _is_constant_argument(t): + return t is None or isinstance(t, (float, bool, str)) + + +def fakify( + mode: FakeTensorMode, + kp: KeyPath, + t: Any, + t_constraints: dict[int, dict[int, Constraint]], + sources: dict[tuple[int, int], list[Source]], + sourced_prefixes: Optional[_KeyPathTrie] = None, +): + source = key_path_to_source(kp, sourced_prefixes=sourced_prefixes) + if _is_constant_argument(t) or isinstance(t, (torch.ScriptObject, torch.nn.Module)): + return t + + if isinstance(t, _IntWrapper): + if t.dynamism is not None and t.dynamism.type in (_DimHintType.DYNAMIC, _DimHintType.AUTO): # type: ignore[union-attr] + symint = mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] + t.val, source, DimDynamic.DYNAMIC + ) + context = ( + SymIntSymbolicContext( + constraint=RelaxedUnspecConstraint(warn_only=False) + ) + if t.dynamism.type == _DimHintType.DYNAMIC # type: ignore[union-attr] + else None + ) + mode.shape_env.tracked_fakes.append( # type: ignore[union-attr] + TrackedFake(symint, source, context) + ) + return symint + else: + return t.val + + if not isinstance(t, torch.Tensor): + raise ValueError( + f"Unsupported input type {type(t)}. " + "Export only supports pytree containers of basic types (Tensor, int, float, ...) as input. " + "To register a custom dataclass, use torch.export.register_dataclass. " + "To register a custom container type, use torch.utils._pytree.register_pytree_node. 
" + "To register a constant input, use torch.utils._pytree.register_constant" + ) + + n_dims = len(t.shape) + dynamic_sizes = [] + constraint_sizes = [None] * n_dims + for i in range(n_dims): + if i in getattr(t, "_dynamo_weak_dynamic_indices", {}): + dynamic_sizes.append(DimDynamic.DYNAMIC) + elif i in getattr(t, "_dynamo_dynamic_indices", {}): + # bit annoying, but we need to replicate process in _dynamo/variables/builder.py + # where a RelaxedUnspecConstraint is created for Dim.DYNAMIC, so constraint violations + # are raised when specializing. + dynamic_sizes.append(DimDynamic.DYNAMIC) + constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload] + else: + dynamic_sizes.append(DimDynamic.STATIC) + symbolic_context: StatelessSymbolicContext = ( # make mypy happy + StatelessSymbolicContext( + dynamic_sizes=dynamic_sizes, + constraint_sizes=constraint_sizes, # type: ignore[arg-type] + ) + ) + t_id = id(t) + assert mode.shape_env is not None + if t_id in t_constraints: + for i, constraint in t_constraints[t_id].items(): + src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) + sources[(t_id, i)].append(src) + if isinstance(constraint, _RelaxedConstraint): + continue + symbolic_context.constraint_sizes[i] = constraint.constraint_range + mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] + fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) + mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr] + return fake + + +def _is_unbacked_symint(symbol): + if not isinstance(symbol, torch.SymInt): + return False + + return symbol.node.shape_env.is_unbacked_symint(symbol.node.expr) + + +def _tensor_min_max(*args, real_callable, tensor_callable, **kwargs): + """ + This logic is replicated from dynamo/variables/builtin.py + """ + if len(args) == 2 and not kwargs: + arg1, arg2 = args + + # Case 1: Both are tensors + if isinstance(arg1, torch.Tensor) and isinstance(arg2, torch.Tensor): + return tensor_callable(arg1, arg2) + + # Case 2: One tensor, one scalar + elif isinstance(arg1, torch.Tensor) or isinstance(arg2, torch.Tensor): + if not isinstance(arg1, torch.Tensor): + arg1, arg2 = arg2, arg1 + + if isinstance(arg2, (int, float)): + kwarg = {"min" if tensor_callable is torch.maximum else "max": arg2} + return torch.clamp(arg1, **kwarg) # type: ignore[call-overload] + else: + return real_callable(arg1, arg2) + + # Case 3: SymInts + elif isinstance(arg1, torch.SymInt) or isinstance(arg2, torch.SymInt): + return ( + torch.sym_max(arg1, arg2) + if tensor_callable is torch.maximum + else torch.sym_min(arg1, arg2) + ) + + # Fallback + else: + return real_callable(arg1, arg2) + + # Single iterable argument handling + if len(args) == 1 and not kwargs: + iterable = args[0] + + if isinstance(iterable, torch.Tensor): + return tensor_callable(iterable) + try: + iterator = iter(iterable) + except TypeError: + pass + else: + items = list(iterator) + if not items: + raise ValueError(f"{real_callable.__name__}() arg is an empty sequence") + + return functools.reduce( + lambda a, b: _tensor_min_max( + a, b, real_callable=real_callable, tensor_callable=tensor_callable + ), + items, + ) + + # Fallback to original callable + return real_callable(*args, **kwargs) + + +@contextmanager +def _override_builtin_ops(): + original_max = builtins.max + original_min = builtins.min + original_pow = math.pow + + builtins.max = functools.partial( + 
_tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum + ) + + builtins.min = functools.partial( + _tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum + ) + + math.pow = lambda x, y: x**y # type: ignore[operator] + + try: + yield + finally: + builtins.max = original_max + builtins.min = original_min + math.pow = original_pow + + +def make_fake_inputs( + nn_module, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, + allow_complex_guards_as_runtime_asserts=False, +): + """ + Given an nn module, example inputs, and constraints, return a new fake mode, + fake inputs created in that mode whose dynamic shape dimensions are constrained + by the given ranges, and sources for pairs of dynamic shape dimensions that are + constrained to be equal. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following pre-tracing steps: + # - Fakify inputs. + # - Process input shape equalities. + # In strict, these steps are spread across multiple files: + # - output_graph.py fakifies inputs. + # - [post-tracing] guards.py processes input shape equalities. + import torch._functorch.config as _config + + # Map ints to a wrapper structure to help us mark it as dynamic, if it is + # dynamic. We will unwrap ints in fakify later. + args, kwargs = pytree.tree_map_only(int, lambda a: _IntWrapper(a), (args, kwargs)) + + combined_args = _combine_args(nn_module, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) + t_constraints: dict[int, dict[int, Constraint]] = defaultdict(dict) + for constraint in constraints: + t_constraints[constraint.t_id][constraint.dim] = constraint + + context = torch._guards.TracingContext.try_get() + if context is not None: + # This occurs when we are exporting within dynamo. There already exists + # a toplevel TracingContext with a fake mode, so we do not want to + # create another fake mode. + fake_mode = context.fake_mode + elif not _is_torch_jit_trace: + if isinstance(nn_module.forward, functools.partial): + # functools handles nesting by itself, no need to recurse + code = nn_module.forward.func.__code__ + else: + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + co_fields=co_fields, + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + trace_asserts=True, + ), + allow_non_fake_inputs=True, + export=True, + ) + else: + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + trace_asserts=True, + ), + allow_non_fake_inputs=True, + ) + if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: + raise ValueError( + "Detected fake_mode does not have a shape_env with tracked fakes. 
" + "If you constructed the module under a FakeTensorMode, " + "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" + ) + + with fake_mode: + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None + sources: dict[tuple[int, int], list[Source]] = defaultdict(list) + sourced_prefixes = make_sourced_prefixes(nn_module, args, kwargs) + fake_args, fake_kwargs = tree_map_with_path( + lambda kp, val: fakify( + fake_mode, + kp, + val, + t_constraints, + sources, + sourced_prefixes=sourced_prefixes, + ), + (args, kwargs), + ) + + names: dict[str, tuple[int, int]] = {} + source_pairs: list[tuple[Source, Source]] = [] + derived_equalities: list[tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: dict[str, Symbol] = {} + relaxed_sources: set[Source] = set() + for constraint in constraints: + torch.export.dynamic_shapes._process_equalities( + constraint, + lambda t_id, dim: sources[(t_id, dim)], + fake_mode.shape_env, + names, + source_pairs, + derived_equalities, + phantom_symbols, + relaxed_sources, + ) + + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + relaxed_sources=relaxed_sources, + warn_only=False, + ) + return ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + dynamic_shapes, + ) + + +def _flatten_dynamic_shapes( + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], +) -> list[Any]: + flat_shapes = [] + + def _tree_map_helper(path, t, shape): + nonlocal flat_shapes + flat_shapes.append(shape) + + _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes) + return flat_shapes + + +def _clean_dynamic_markers(tensor: torch.Tensor) -> None: + for attr in [ + "_dynamo_weak_dynamic_indices", + "_dynamo_dynamic_indices", + "_dynamo_dynamic_range", + "_dynamo_static_indices", + "_dynamo_unbacked_indices", + ]: + if hasattr(tensor, attr): + delattr(tensor, attr) + + +def produce_guards_and_solve_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + equalities_inputs: EqualityConstraint, + original_signature: inspect.Signature, + _is_torch_jit_trace=False, +): + """ + Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, + and a graph module, produce guards on the fake mode's shape env (raising constraint + violations if any), solve (to suggest simplifications or fixes). + Dynamo already performs this, so this is for non-strict mode. 
+ + Additional inputs: + equalities_inputs: the equality constraints to use for guards + original_signature: the signature of the forward method + """ + shape_env = fake_mode.shape_env + assert shape_env is not None + assert shape_env.tracked_fakes is not None + + placeholders = [tf.fake for tf in shape_env.tracked_fakes] + sources = [tf.source for tf in shape_env.tracked_fakes] + input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] + constraint_violation_error = None + try: + shape_env.produce_guards( + placeholders, + sources, + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + ignore_static=False, + ) + except ConstraintViolationError as e: + constraint_violation_error = e + + shape_env.frozen = True + dim_constraints = shape_env.dim_constraints + if dim_constraints is None: + # Expected when shape_env.produce_guards throws an early constraint violation error. + # There is nothing to solve for in this case. + # TODO(avik): Maybe record the constraint violation error instead and replay later? + assert constraint_violation_error + raise constraint_violation_error + dim_constraints.solve() + forced_specializations = dim_constraints.forced_specializations() + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, # type: ignore[arg-type] + constraint_violation_error, + forced_specializations, # type: ignore[arg-type] + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" + if constraint_violation_error: + constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + elif forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + if constraint_violation_error: + raise constraint_violation_error + + +def is_int(x: object) -> bool: + return isinstance(x, int) or (isinstance(x, torch.SymInt) and x.node.expr.is_number) + + +def _constrain_user_specified_dimhint_range( + symint: torch.SymInt, + hint: int, + dim: _DimHint, + range_constraints, + shape_env, + keypath: KeyPath, + i: Optional[int] = None, +) -> Optional[str]: + trace_vr = ( + range_constraints[symint.node.expr] + if not is_int(symint) + else ValueRanges(int(symint), int(symint)) + ) + + # warn on 0/1 specialization for Dim.AUTO; not an actual error + if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1): + pathstr = f"inputs{pytree.keystr(keypath)}" + if i is not None: + pathstr += f".shape[{i}]" + msg = ( + f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along " + + f"with a sample input with hint = {hint}." 
+ ) + log.warning(msg) + + try: + user_vr = ValueRanges( + lower=0 if dim.min is None else dim.min, + upper=int_oo if dim.max is None else dim.max, + ) + if is_int(symint): + out_vr = trace_vr & user_vr + else: + range_constraints[symint.node.expr] &= user_vr + shape_env.var_to_range[symint.node._expr] &= user_vr + out_vr = range_constraints[symint.node.expr] + + # check for Dim.DYNAMIC specializations; special case error message on 0/1 + if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton(): + path = f"inputs{pytree.keystr(keypath)}" + if i is not None: + path += f".shape[{i}]" + if ( + trace_vr.is_singleton() + and hint in (0, 1) + and not torch.fx.experimental._config.backed_size_oblivious + ): + msg = ( + f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + f"but export 0/1 specialized due to hint of {hint} for dimension {path}." + ) + else: + msg = ( + f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + f"but tracing inferred a static shape of {out_vr.lower} for dimension {path}." + ) + return msg + + except torch.utils._sympy.value_ranges.ValueRangeError: + path = f"inputs{pytree.keystr(keypath)}" + if i is not None: + path += f".shape[{i}]" + msg = ( + f"- Received user-specified min/max range of [{dim.min}, {dim.max}], " + f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], " + f"for {path}." + ) + return msg + + return None + + +def make_constraints( + fake_mode: FakeTensorMode, + gm: torch.fx.GraphModule, + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + num_lifted_inputs: int, +): + """ + Given a fake mode's shape env and user-specified dynamic shapes, + return the resulting range constraints and equality constraints. + + Additional args: + num_lifted_inputs: the number of non-user-input placeholder nodes in the graph + (used only to enumerate the user-input nodes) + """ + + shape_env = fake_mode.shape_env + assert shape_env is not None + inline_constraints = gm.meta.get("inline_constraints", []) + range_constraints = defaultdict(lambda: ValueRanges(0, int_oo)) | inline_constraints + if not dynamic_shapes: + return dict(range_constraints) + + # clean up dynamic markers from tensors + flat_paths, flat_args = zip(*pytree.tree_flatten_with_path(combined_args)[0]) + for arg in flat_args: + if isinstance(arg, torch.Tensor): + _clean_dynamic_markers(arg) + + # get individual dynamic shapes spec for each input + if not isinstance(dynamic_shapes, dict): + assert isinstance(dynamic_shapes, (tuple, list)) + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) + + # check number of shapes vs. 
number of inputs + num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True) + assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs + + free_symbols = set() + range_violations = [] + for input_index, node in enumerate(gm.graph.nodes): + meta_val = node.meta.get("val") + + if ( + input_index < num_lifted_inputs + or node.op != "placeholder" + or meta_val is None + ): + continue + + elif _is_constant_argument(meta_val) or isinstance(meta_val, CustomObjArgument): + continue + + shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] + keypath = flat_paths[input_index - num_lifted_inputs] + flat_arg = flat_args[input_index - num_lifted_inputs] + + if isinstance(meta_val, int) or ( + isinstance(meta_val, torch.SymInt) and meta_val.node.expr.is_number + ): + pass + + elif isinstance(meta_val, torch.SymInt): + if shape_spec is not None and isinstance(shape_spec, _DimHint): + hint = flat_arg + range_constraints[meta_val.node.expr] &= shape_env.bound_sympy( + meta_val.node._expr + ) + violation = _constrain_user_specified_dimhint_range( + meta_val, + hint, + shape_spec, + range_constraints, + shape_env, + keypath, + None, + ) + if violation: + range_violations.append(violation) + else: + raise RuntimeError("nyi") + free_symbols.update(meta_val.node.expr.free_symbols) + + elif isinstance(meta_val, torch.Tensor): + for i, d in enumerate(node.meta["val"].shape): + dim = None + if isinstance(shape_spec, (list, tuple)): + dim = shape_spec[i] + elif isinstance(shape_spec, dict): + dim = shape_spec.get(i) + if not is_int(d): + # Compute the range constraint for the symbolic expression corresponding + # to this shape dimension and store it. + if dim is None or isinstance(dim, _DimHint): + range_constraints[d.node.expr] &= shape_env.bound_sympy( + d.node.expr + ) + else: + range_constraints[d.node.expr] &= ValueRanges( + lower=dim.min, upper=dim.max + ) + + free_symbols.update(d.node.expr.free_symbols) + + # check user-specified min/max range for DimHints; + # we might want to do this even if model tracing inferred a static dimension. + if isinstance(dim, _DimHint): + hint = flat_arg.shape[i] + violation = _constrain_user_specified_dimhint_range( + d, hint, dim, range_constraints, shape_env, keypath, i + ) + if violation: + range_violations.append(violation) + else: + raise RuntimeError(f"Unfamiliar meta val: {meta_val}") + + if range_violations: + prefix = "Found the following conflicts between user-specified ranges and inferred ranges from model tracing:\n" + raise ValueError(prefix + "\n".join(range_violations)) + + for symbol in free_symbols: + if symbol not in range_constraints: + # Placeholders can have symbolic shapes that are derived expressions. + # The above code will record direct range constraints for them + # so that we can do runtime assertions. In addition, for serde checks + # we want to record range constraints for their root symbols. + range_constraints[symbol] = shape_env.var_to_range[symbol] + + return dict(range_constraints) + + +def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: + """Search the module hierarchy, gathering up all tensor and ScriptObject constants. + + Returns a dictionary mapping hash(value) to the name of the constant. We + have to abuse `hash` here unfortunately, see: [ScriptObject hash]. 
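# A minimal sketch of what _gather_constant_attrs collects; `WithConst` is a
# made-up toy module used only for illustration.
import torch

class WithConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(2))
        self.register_buffer("b", torch.zeros(2))
        self.c = torch.ones(2)  # plain tensor attribute -> treated as a constant

constants = _gather_constant_attrs(WithConst())
# only `c` is recorded; parameters and buffers are filtered out by the traversal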
+ """ + constants = ConstantAttrMap() + buffers_parameters = set(m.buffers()) + buffers_parameters.update(m.parameters()) + + def inner(m: torch.nn.Module, prefix_atoms: list[str], constants): + for k, v in m.__dict__.items(): + if isinstance( + v, + ( + torch.Tensor, + torch.ScriptObject, + FakeScriptObject, + ), + ): + if v in buffers_parameters: + # filter out buffers and parameters, leaving only constants + continue + + fqn = ".".join(prefix_atoms + [k]) + constants.add(v, fqn) + for k, v in m.named_children(): + inner(v, prefix_atoms + [k], constants) + + inner(m, [], constants) + return constants + + +def _get_graph_inputs_of_type_nn_module( + args: Optional[tuple[tuple[Any], dict[Any, Any]]], +) -> set[type[torch.nn.Module]]: + if args is None: + return set() + module_types = set() + for arg in pytree.tree_leaves(args): + if isinstance(arg, torch.nn.Module): + module_types.add(type(arg)) + return module_types + + +def _enter_enable_graph_inputs_of_type_nn_module( + module_types: set[type[torch.nn.Module]], +) -> None: + for t in module_types: + torch._export.utils.register_module_as_pytree_input_node(t) + + +def _exit_enable_graph_inputs_of_type_nn_module( + module_types: set[type[torch.nn.Module]], +) -> None: + for t in module_types: + torch._export.utils.deregister_module_as_pytree_input_node(t) + + +@contextlib.contextmanager +def _enable_graph_inputs_of_type_nn_module( + args: Optional[tuple[tuple[Any], dict[Any, Any]]], +): + if args is None: + yield + return + + module_types = _get_graph_inputs_of_type_nn_module(args) + _enter_enable_graph_inputs_of_type_nn_module(module_types) + try: + yield + finally: + _exit_enable_graph_inputs_of_type_nn_module(module_types) + + +@contextlib.contextmanager +def _fakify_module_inputs( + args: tuple[Any], + kwargs: dict[Any, Any], + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, +): + # This context manager is used to fakify module inputs. + # Inputs: + # args, kwargs: the args and kwargs containing module inputs that haven't been fakified. + # fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors. + + ctxs = [_enable_graph_inputs_of_type_nn_module((args, kwargs))] + for arg in pytree.tree_leaves((args, kwargs)): + if isinstance(arg, torch.nn.Module): + fake_params_buffers = _fakify_params_buffers(fake_mode, arg) + ctxs.append( + torch.nn.utils.stateless._reparametrize_module( + arg, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ) + ) + with contextlib.ExitStack() as stack: + for ctx in ctxs: + stack.enter_context(ctx) + yield + + +@contextlib.contextmanager +def _fakify_script_objects( + mod: torch.nn.Module, + args: Sequence[Any], + kwargs: dict[Any, Any], + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, +): + # This context manager is used to fakify script objects into FakeScriptObject. + # Inputs: + # mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified. + # args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified. + # fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors. + # + # Returns: + # mod: the patched module, its (and its recursive submodules) script object attrs have been fakified. + # fake_args, fake_kwargs: new fakified args and kwargs. + # Script object inputs have been fakified. Don't touch the tensors. 
+ # fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object. + # fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching. + + constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod) + assert not any( + isinstance(obj, FakeScriptObject) for obj in constant_attrs.values() + ), "Mod shouldn't contain any FakeScriptObject." + assert not pytree.tree_any( + lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs) + ), "args and kwargs shouldn't contain any FakeScriptObject." + + patched_attr = {} + fake_constant_attrs = ConstantAttrMap() + fake_to_real = {} + + def _maybe_fakify_obj(obj): + fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj) + fake_to_real[fake_obj] = obj + return fake_obj + + def _leaf_mod_and_attr( + mod: torch.nn.Module, attr_fqn: str + ) -> tuple[torch.nn.Module, str]: + *prefix_attr, last_attr = attr_fqn.split(".") + cur_mod = mod + for attr in prefix_attr: + cur_mod = getattr(cur_mod, attr) + return cur_mod, last_attr + + try: + for obj, fqns in constant_attrs.items(): + if torch._library.fake_class_registry._is_script_object(obj): + fake_script_obj = _maybe_fakify_obj(obj) + for fqn in fqns: + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + assert obj is getattr(cur_mod, attr) + setattr(cur_mod, attr, fake_script_obj) + fake_constant_attrs.add(fake_script_obj, fqn) + patched_attr[fqn] = obj + else: + for fqn in fqns: + fake_constant_attrs.add(obj, fqn) + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.ScriptObject, _maybe_fakify_obj, (args, kwargs) + ) + yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real) + finally: + for fqn, orig_obj in patched_attr.items(): + cur_mod, attr = _leaf_mod_and_attr(mod, fqn) + setattr(cur_mod, attr, orig_obj) + + +class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): + """ + 1. Handles data-dependent errors raised by torch function calls in non-strict. + + Any data-dependent error is due to some condition on unbacked symints + that cannot be resolved. A mechanical way of fixing the error is to use + a torch._check() call to assert either that condition or its negation. + The handler suggests these options as code and points to the location + of the torch function call that raised the error as part of the error + message shown to the user, who can then simply select and copy-paste + a suggested fix at that location. + + NOTE: Not all data-dependent errors are raised by torch function calls. + In particular, conditions on unbacked symints can appear outside such + calls, and as such are not handled here. + + 2. Overrides torch functions that are known to cause problems in non-strict. + + Certain Python features, such as indexing/slicing, cannot be intercepted + in non-strict. Likewise, certain legacy ops, such as distributed collectives, + may need to be mapped to other ops. When there is special handling in Dynamo + for such things, tracing can fail in non-strict (while succeeding in strict). + Fortunately, redirecting to other torch functions can often fix such issues. + + 3. Handles line-of-code logging for each torch function call in non-strict. + + Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ... 
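# A minimal sketch of the TorchFunctionMode interception pattern this handler
# is built on, assuming nothing beyond public torch APIs; `LoggingMode` is a
# made-up illustrative mode, not part of export.
import torch

class LoggingMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted", getattr(func, "__name__", func))
        return func(*args, **kwargs)

with LoggingMode():
    torch.add(torch.ones(2), torch.ones(2))  # prints: intercepted add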
+ """ + + def _override(self, func, args, kwargs): + if torch.distributed.is_available(): + from torch.distributed._functional_collectives import ( + REDUCE_OP_TO_STR, + traceable_collective_remaps, + ) + + if func in traceable_collective_remaps: + # Redirect to a corresponding functional collective, following Dynamo. + # See torch/distributed/_functional_collectives.py for details. + # The following is an adaptation of CollectiveFunctionRewriteVariable. + mapped_func = traceable_collective_remaps[func] + signature = inspect.signature(func) + kwargs = dict(signature.bind(*args, **kwargs).arguments) + args = () + if func in ( + torch.distributed.all_reduce, + torch.distributed.reduce_scatter_tensor, + torch.distributed._reduce_scatter_base, + ): + if "op" in kwargs: + kwargs["op"] = REDUCE_OP_TO_STR[kwargs["op"]] + return mapped_func, args, kwargs + if func is torch.tensor: + # Redirect to Python implementation of torch.tensor for data with symints. + # NOTE(avik): We don't unconditionally redirect to this implementation + # because it has some known incompletenesses, e.g., it doesn't support + # empty data. See https://github.com/pytorch/pytorch/issues/143216 + if any( + isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)) + for a in pytree.tree_flatten(args[0])[0] + ): + return torch._refs.tensor, args, kwargs + if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor): + + def rewrite(dim, item): + # Redirect to torch.select for indexing. + if isinstance(item, (int, torch.SymInt)): + return dim, (torch.select, [dim, item]) + # Redirect to torch.ops.aten.slice for slicing. + if isinstance(item, slice): + return dim + 1, ( + torch.ops.aten.slice, + [dim, item.start, item.stop, item.step or 1], + ) + # Otherwise do nothing. + + items = args[1] if isinstance(args[1], tuple) else (args[1],) + dim = 0 + # Sequence rewrites. + sequence = [] + for item in items: + if (r := rewrite(dim, item)) is None: + return func, args, kwargs + dim, call_spec = r + sequence.append(call_spec) + + def run(): + # Run sequence. 
+ t = args[0] + for _method, _args in sequence: + t = _method(t, *_args) + return t + + return run, [], {} + + return func, args, kwargs + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if torch.compiler.is_dynamo_compiling(): + return func(*args, **kwargs) + + if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + frame = _find_user_code_frame() + if frame is not None: + log.debug( + "%s called at %s:%s in %s", + func.__qualname__, + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + + func, args, kwargs = self._override(func, args, kwargs) + try: + return func(*args, **kwargs) + except GuardOnDataDependentSymNode as e: + _suggest_fixes_for_data_dependent_error_non_strict(e) + raise diff --git a/phivenv/Lib/site-packages/torch/_export/pass_base.py b/phivenv/Lib/site-packages/torch/_export/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3dfd36341cd1e7fc611fb15533aedd398810aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/pass_base.py @@ -0,0 +1,478 @@ +# mypy: allow-untyped-defs +import operator +import traceback +import typing +from contextlib import nullcontext +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._higher_order_ops.map import _unstack_pytree +from torch._subclasses import FakeTensor, UnsupportedFakeTensorException +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx import traceback as fx_traceback +from torch.fx.experimental.proxy_tensor import PythonKeyTracer +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + PropagateUnbackedSymInts, +) +from torch.fx.graph import CodeGen +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils import _pytree as pytree + + +__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] + + +Argument = Any +Value = Any +Fn = Callable[..., Any] +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +_TORCH_SYM_OPS: set[Callable] = { + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, +} + + +class ExportPassBaseError(RuntimeError): + pass + + +class _ExportPassBaseDeprecatedDoNotUse(PassBase): + """ + Interpreter-based pass class to help users maintain the IR spec while writing + transformations. 
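# A minimal sketch of an interpreter-based pass built on this (deprecated)
# base class; `ReplaceAddWithMul` is made up for illustration only.
import torch

class ReplaceAddWithMul(_ExportPassBaseDeprecatedDoNotUse):
    def call_operator(self, op, args, kwargs, meta):
        # rewrite aten.add.Tensor calls into aten.mul.Tensor, keep everything else
        if op is torch.ops.aten.add.Tensor:
            op = torch.ops.aten.mul.Tensor
        return super().call_operator(op, args, kwargs, meta)

# usage sketch: result = ReplaceAddWithMul()(graph_module)  # returns a PassResult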
+ """ + + @staticmethod + def _create_dummy_node_metadata(): + return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) + + class ExportTracer(PythonKeyTracer): + def __init__( + self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen + ) -> None: + super().__init__() + self.callback = callback + self.root = torch.nn.Module() + self.graph = torch.fx.Graph() + self.graph.set_codegen(codegen) + self.tensor_attrs: dict[str, torch.Tensor] = {} # type: ignore[assignment] + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.submodules: dict[torch.nn.Module, str] = {} + + def trace(self) -> None: # type: ignore[override] + raise ExportPassBaseError("ExportTracer doesn't support trace().") + + def create_arg(self, a: Argument) -> torch.fx.Node: + if isinstance(a, torch.nn.Module): + if a not in self.submodules: + name_submodule = f"submodule_{len(self.submodules)}" + self.root.add_module(name_submodule, a) + self.submodules[a] = name_submodule + elif isinstance(a, FakeTensor): + if not hasattr(a, "constant") or a.constant is None: + raise ExportPassBaseError(f"Cannot add {a} to graph.") + a = a.constant + node = super().create_arg(a) + if ( + isinstance(a, torch.Tensor) + and isinstance(node, torch.fx.Node) + and node.op == "get_attr" + ): + self.set_metadata(node, a) + self.callback.on_attr(ProxyValue(a, node)) + return node + + def set_metadata( + self, + node: torch.fx.Node, + value: Argument, + ) -> None: + # propagate the fake tensor or sym nodes + def make_val( + x: Argument, + ) -> Union[ + FakeTensor, + torch.SymInt, + torch.SymFloat, + torch.SymBool, + int, + float, + bool, + str, + None, + ]: + if isinstance(x, FakeTensor): + return x + elif isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + # TODO we should allocate static shapes + # for param/buffer values + if isinstance(x, torch.nn.Parameter): + fake_tensor = self.fake_tensor_mode.from_tensor( + x, static_shapes=True + ) + else: + fake_tensor = self.fake_tensor_mode.from_tensor(x) + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + print( + "Fakeifying a Tensor subclass is not supported \ + right now. Instead a TensorMetadata is used." 
+ ) + fake_tensor = None + return fake_tensor + elif isinstance( + x, + ( + torch.SymInt, + torch.SymFloat, + torch.SymBool, + int, + float, + bool, + str, + ), + ): + return x + else: + return None + + node.meta["val"] = pytree.tree_map(make_val, value) + + # Set the tensor_metadata for values that do not have a corresponding FakeTensor + def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: + if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): + if x.is_quantized: + # TODO (tmanlaibaatar) properly support Quantized FakeTensor + x = torch.dequantize(x) + + try: + assert self.fake_tensor_mode is not None + _ = self.fake_tensor_mode.from_tensor(x) + tensor_meta = None + except UnsupportedFakeTensorException: + # TODO: This is just a workaround to get over the + # x.as_subclass error + tensor_meta = _extract_tensor_metadata(x) + return tensor_meta + else: + return None + + node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) + + class ExportInterpreter(fx.Interpreter): + def __init__( + self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule + ) -> None: + super().__init__(gm) + self.callback = callback + self.node: torch.fx.Node = next(iter(gm.graph.nodes)) + + def placeholder( + self, + target: str, # type: ignore[override] + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + ) -> ProxyValue: + arg = super().placeholder(target, args, kwargs) + return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) + + def output( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + ) -> ProxyValue: + return self.callback.output(args[0], NodeMetadata(self.node.meta)).data # type: ignore[return-value] + + def call_function( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + ) -> ProxyValue: + meta = NodeMetadata(self.node.meta) + + if target == operator.getitem: + value, key = args + return self.callback.call_getitem(value, key, meta) + elif getattr(target, "__module__", None) in { + "_operator", + "builtins", + "math", + }: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif target in _TORCH_SYM_OPS: + assert callable(target) + return self.callback.call_sym(target, args, meta) + elif isinstance( + target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) + ): + return self.callback.call_operator( + target, + args, + kwargs, + meta, + ) + elif target == torch.ops.higher_order.cond: + pred, true_fn, false_fn, inputs = args + return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) + elif target == torch.ops.higher_order.map_impl: + f, mapped_args, operands = args # type: ignore[assignment] + return self.callback.call_map(f, mapped_args, operands, meta) + # For other unregistered HigherOrderOps, just interpret them blindly + elif isinstance(target, torch._ops.HigherOrderOperator): + return self.callback._fx( + "call_function", + target, + args, + kwargs, + meta, + ) + else: + raise ExportPassBaseError(f"Unsupported target type: {target}") + + def get_attr( + self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] + ) -> Argument: + return super().get_attr(target, args, kwargs) + + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + ) -> None: + raise ExportPassBaseError("call_module is not supported.") + + def call_method( + self, target: str, args: tuple[Argument, 
...], kwargs: dict[str, Argument] # type: ignore[override] + ) -> None: + raise ExportPassBaseError("call_method is not supported.") + + def run_node(self, n: torch.fx.Node) -> Argument: + self.node = n + self.callback.node_debug_str = n.format_node() + return super().run_node(n) + + def __init__(self) -> None: + self.interpreter = PropagateUnbackedSymInts( + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + self.tracer = self.ExportTracer(self, CodeGen()) + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self._initialized = True + self.node_debug_str: typing.Optional[str] = None + + def _fx( + self, + kind: str, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + args_data, kwargs_data = pytree.tree_map_only( + ProxyValue, lambda x: x.data, (args, kwargs) + ) + res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) + args_proxy, kwargs_proxy = pytree.tree_map_only( + ProxyValue, lambda x: x.proxy, (args, kwargs) + ) + + name = None + if isinstance(target, torch._ops.OpOverload): + name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) + + res_proxy = self.tracer.create_proxy( + kind, target, args_proxy, kwargs_proxy, name=name + ) + res_proxy.node.meta.update(meta.data) + if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): + if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): + res_proxy.node.meta["unbacked_bindings"] = symbol_to_path + self.tracer.set_metadata(res_proxy.node, res_data) + return ProxyValue(res_data, res_proxy) + + def inputs(self, graph_module: torch.fx.GraphModule) -> list[Argument]: + # TODO(angelayi): Update this with what we decide to do for metadata in + # the exported graph module + if (args := graph_module.meta.get("args", None)) is not None: + return list(args) + + def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: + if "val" in node.meta: + fake = node.meta["val"] + if hasattr(fake, "constant") and fake.constant is not None: + return fake.constant + return fake + elif tensor_meta := node.meta.get("tensor_meta"): + assert self.fake_tensor_mode is not None + return FakeTensor( + self.fake_tensor_mode, + torch.empty( + tensor_meta.shape, + dtype=tensor_meta.dtype, + device="meta", + requires_grad=tensor_meta.requires_grad, + memory_format=tensor_meta.memory_format, + ), + torch.device("cpu"), + ) + elif len(node.users) == 0: + return None + raise ExportPassBaseError( + f"Cannot construct an input for graph module: {graph_module}.", + ) + + return [ + extract_input(node) + for node in graph_module.graph.nodes + if node.op == "placeholder" + ] + + def on_attr(self, attr: ProxyValue) -> None: + pass + + def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: + arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) + arg_proxy.node.meta = meta.data + self.tracer.set_metadata(arg_proxy.node, arg) + return ProxyValue(arg, arg_proxy) + + def call_operator( + self, + op, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", op, args, kwargs, meta) + + def call_sym( + self, + target: Fn, + args: tuple[Argument, ...], + meta: NodeMetadata, + ) -> ProxyValue: + return self._fx("call_function", target, args, {}, meta) + + def call_cond( + self, + pred: ProxyValue, + true_fn: torch.fx.GraphModule, + false_fn: torch.fx.GraphModule, + inputs: list[Argument], + meta: 
NodeMetadata, + ) -> ProxyValue: + true_branch = self.call_submodule(true_fn, tuple(inputs)) + false_branch = self.call_submodule(false_fn, tuple(inputs)) + assert true_branch is not None + assert false_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.cond, + (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), + {}, + meta, + ) + + def call_map( + self, + f: torch.fx.GraphModule, + mapped_args: list[ProxyValue], + operands: list[ProxyValue], + meta: NodeMetadata, + ) -> ProxyValue: + xs = _unstack_pytree([arg.data for arg in mapped_args])[0] + f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) + assert f_branch is not None + return self._fx( + "call_function", + torch.ops.higher_order.map_impl, + (f_branch.graph_module, mapped_args, operands), + {}, + meta, + ) + + def call_getitem( + self, value: ProxyValue, key: int, meta: NodeMetadata + ) -> ProxyValue: + return self._fx("call_function", operator.getitem, (value, key), {}, meta) + + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: + return self._fx("output", "output", (results,), {}, meta) + + def call_submodule( + self, graph_module: fx.GraphModule, inputs: tuple[Argument, ...] + ) -> PassResult: + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment] + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + ) + inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) + with fx_traceback.preserve_node_meta(): + interpreter.run(*inputs_data) + + new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) + + self.tracer = prev_tracer + self.interpreter = prev_interpreter + return PassResult( + new_graph_module, + True, + ) + + def call(self, graph_module: fx.GraphModule) -> PassResult: + if not getattr(self, "_initialized", False): + raise ExportPassBaseError( + "ExportPass is not initialized with __init__().", + ) + + inputs = self.inputs(graph_module) + + fake_tensor_mode = None + for i in inputs: + if isinstance(i, FakeTensor): + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." 
+ fake_tensor_mode = i.fake_mode + if fake_tensor_mode is None: + self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) + fake_tensor_mode = nullcontext() # type: ignore[assignment] + dispatcher_mode = nullcontext() # type: ignore[assignment] + else: + fake_tensor_mode.allow_non_fake_inputs = True + self.tracer.fake_tensor_mode = fake_tensor_mode + dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] + self.fake_tensor_mode = self.tracer.fake_tensor_mode + + with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] + result = self.call_submodule(graph_module, tuple(inputs)) + + return result diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/__init__.py b/phivenv/Lib/site-packages/torch/_export/pass_infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..774fa07b2e5de2d9385957b79599a633ea7dc5fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae15f1cb81e5301a7d29c918ae90036fad540912 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b74c1625f5ae45099874d8e43a5c4a6ee40669b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/node_metadata.py b/phivenv/Lib/site-packages/torch/_export/pass_infra/node_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..f743434cfc8ab49bf0b0bb7205161c800aa988bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/pass_infra/node_metadata.py @@ -0,0 +1,32 @@ +from typing import Any + + +NodeMetadataValue = Any + + +PROTECTED_KEYS: set[str] = { + "val", + "stack_trace", + "nn_module_stack", + "debug_handle", + "tensor_meta", +} + + +class NodeMetadata: + def __init__(self, data: dict[str, Any]) -> None: + self.data: dict[str, Any] = data.copy() + + def __getitem__(self, key: str) -> NodeMetadataValue: + return self.data[key] + + def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: + if key in PROTECTED_KEYS: + raise RuntimeError(f"Could not override node key: {key}") + self.data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self.data + + def copy(self) -> "NodeMetadata": + return NodeMetadata(self.data.copy()) diff --git a/phivenv/Lib/site-packages/torch/_export/pass_infra/proxy_value.py b/phivenv/Lib/site-packages/torch/_export/pass_infra/proxy_value.py new file mode 100644 index 0000000000000000000000000000000000000000..55f0e038f4394d5c65929fc0a76f547274386060 --- 
/dev/null +++ b/phivenv/Lib/site-packages/torch/_export/pass_infra/proxy_value.py @@ -0,0 +1,45 @@ +# pyre-strict +from collections.abc import Iterable, Iterator +from typing import Generic, TypeVar, Union + +import torch + + +_T = TypeVar("_T") + + +class ProxyValue(Generic[_T]): + # pyre-ignore + def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]): + # pyre-ignore + self.data = data + self.proxy_or_node = proxy + + @property + def node(self) -> torch.fx.Node: + if isinstance(self.proxy_or_node, torch.fx.Node): + return self.proxy_or_node + assert isinstance(self.proxy_or_node, torch.fx.Proxy) + return self.proxy_or_node.node + + @property + def proxy(self) -> torch.fx.Proxy: + if not isinstance(self.proxy_or_node, torch.fx.Proxy): + raise RuntimeError( + f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" + ) + return self.proxy_or_node + + def to_tensor(self) -> torch.Tensor: + assert isinstance(self.data, torch.Tensor) + return self.data + + def is_tensor(self) -> bool: + return isinstance(self.data, torch.Tensor) + + # pyre-ignore + def __iter__(self) -> Iterator[_T]: + yield from self.data + + def __bool__(self) -> bool: + return bool(self.data) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__init__.py b/phivenv/Lib/site-packages/torch/_export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ad040cae5672be1b58bfe523d4fb57e41d2344 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/__init__.py @@ -0,0 +1 @@ +from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae5e2f19ac3c32a6f2bc24835db8376a2375ac9a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e73cc6b06398c51671e57c929b610a5c4d0f4ba6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ab24c7893bc79fd79e29d0106ac44fa5212403c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ae565c33a81ace3ca1ab7853641afd53c9c887c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c4154d4865489b9939cebebb1075c400205207 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b2bb4fc6e9677903671d792ebd34a60b1147d3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53f2c973f821d13042636ece6490c64bb55e83fd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bff6f0e0c5d23bafb60d55e5da9cc1bb8c20f72 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de0d6ed5b0c572a5c4d0d2306b81eef0a325dfd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ccb30f1f26e3bbdf7fed39e647ef42799833538 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e350f5e77d39633b353625a9bbf757a0ccb93e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e94c02d0cd437a2f2390784610b5d2d256b70c72 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa34174548ae4a2cb07525d346093b898ccf065 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dc39cced7affad36a7022b9af409310e9cc3ff2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/passes/_node_metadata_hook.py b/phivenv/Lib/site-packages/torch/_export/passes/_node_metadata_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..3808abc4d7b7d39c90cb780f74c2d2a6fe106c8f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/_node_metadata_hook.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Optional + +import torch +from torch.fx.graph_module import GraphModule + + +_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" + + +def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) -> None: + """ + Hook for adding the appropriate metadata to nodes that are created during a + pass using graph.create_node. An example of how to use it: + + ``` + with _set_node_metadata_hook(gm, + functools.partial(_node_metadata_hook, stack_trace="file") + ): + pass(gm) + ``` + + This hook should not work for all generic cases -- specifically it assumes + that nodes being added are only call_function nodes, and copies over the + first argument node's nn_module_stack. + """ + assert node.op == "call_function" and callable(node.target) + + arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)] + assert len(arg_meta) >= 1 + arg_meta = arg_meta[0] + + if ( + isinstance(node.target, torch._ops.OpOverload) + and len(node.target._schema.returns) == 0 + ): + node.meta["val"] = None + else: + fake_args = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + fake_res = node.target(*fake_args) + node.meta["val"] = fake_res + + node.meta["stack_trace"] = stack_trace + node.meta["nn_module_stack"] = arg_meta.get( + "nn_module_stack", + { + _EMPTY_NN_MODULE_STACK_KEY: ( + _EMPTY_NN_MODULE_STACK_KEY, + _EMPTY_NN_MODULE_STACK_KEY, + ) + }, + ) + node.meta["torch_fn"] = ( + f"{node.target.__name__}_0", + f"{node.target.__class__.__name__}.{node.target.__name__}", + ) + + +@contextlib.contextmanager +def _set_node_metadata_hook(gm: torch.fx.GraphModule, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "node_metadata_hook must be a callable." 
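# Illustrative sketch (hypothetical helper, not part of the diffed file): run a
# graph mutation under the hook so every node created via graph.call_function
# picks up "val", "stack_trace", "nn_module_stack" and "torch_fn" metadata.
# `gm` is assumed to come from torch.export, so existing nodes carry meta["val"].
import functools

import torch
from torch._export.passes._node_metadata_hook import (
    _node_metadata_hook,
    _set_node_metadata_hook,
)


def _append_clone_after_adds(gm: torch.fx.GraphModule) -> None:
    with _set_node_metadata_hook(
        gm, functools.partial(_node_metadata_hook, stack_trace="<synthetic>")
    ):
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with gm.graph.inserting_after(node):
                    # The newly created node is populated by the hook above.
                    gm.graph.call_function(torch.ops.aten.clone.default, (node,))
    gm.recompile()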
+ + # Add the hook to all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._register_create_node_hook(f) + try: + yield + finally: + # Restore hook for all submodules + for m in gm.modules(): + if isinstance(m, GraphModule): + m._unregister_create_node_hook(f) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e2ec6432eb54075c79c42f59d5c2f77774cc0f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +import math +import operator +import traceback +from functools import partial +from typing import Callable, NamedTuple + +import sympy + +import torch +import torch.fx +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + + +__all__ = ["InputDim"] + + +class InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return math.inf + if val in (-sympy.oo, -int_oo): + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError("Export constraints cannot be non-integer expressions") + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase): + def __init__( + self, + range_constraints: dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, node, lower, upper, assert_msg): + last_node = node + if lower > -math.inf: + last_node = self._insert_assert_async( + last_node, operator.ge, node, lower, assert_msg + ) + + if upper < math.inf: + last_node = self._insert_assert_async( + last_node, operator.le, node, upper, assert_msg + ) + + def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. 
+ """ + self.counter += 1 + graph = last_node.graph + with graph.inserting_after(last_node): + cmp = graph.call_function(op, (lower, upper), {}) + with graph.inserting_after(cmp): + cmp_tensor = graph.call_function( + torch.ops.aten.scalar_tensor.default, (cmp,), {} + ) + with graph.inserting_after(cmp_tensor): + assert_async = graph.call_function( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + ) + return assert_async + + def call(self, graph_module) -> PassResult: + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if "val" not in node.meta: + continue + + val = node.meta["val"] + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. + # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. + + def add_assertions(val): + call_backs: list[Callable] = [] + messages: list[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols( + symbol + ): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." + call_backs.append( + partial( + self._assert_range_constraint, + lower=min_val, + upper=max_val, + ) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + + def sym_size_cb(node, assert_msg, dim): + with node.graph.inserting_after(node): + dim_node = module.graph.call_function( + torch.ops.aten.sym_size.int, + (node, dim), + {}, + ) + cb(node=dim_node, assert_msg=assert_msg) + + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(node=node, assert_msg=f"{node}" + msg) + + module.recompile() + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. 
+ if ( + self.counter == 0 + and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass + ): + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in graph_module.graph.nodes: + if not node.meta.get("stack_trace", None) and node.op not in [ + "placeholder", + "output", + ]: + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + return PassResult(graph_module, True) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: dict[sympy.Symbol, ValueRanges], +) -> dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_scalar.default: + continue + + compare_arg = node.args[0] + if not ( + isinstance(compare_arg, torch.fx.Node) + and compare_arg.op == "call_function" + and compare_arg.target in (operator.le, operator.ge) + and len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + lhs, rhs = compare_arg.args + + def maybe_get_symint(x): + if ( + isinstance(x, torch.fx.Node) + and "val" in x.meta + and isinstance(x.meta["val"], torch.SymInt) + ): + return x.meta["val"].node.expr + return x + + lhs = maybe_get_symint(lhs) + rhs = maybe_get_symint(rhs) + + if compare_op == operator.ge: + lhs, rhs = rhs, lhs + + if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): + symint = lhs + scalar = rhs + elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int): + symint = rhs + scalar = lhs + else: + continue + + if symint not in range_constraints: + raise RuntimeError( + f"Unable to find symint {symint} in {range_constraints}" + ) + + previous_range = existing_inline_assertions.get( + symint, ValueRanges(-math.inf, math.inf) + ) + + if symint is lhs: + bounds = ValueRanges(-math.inf, scalar) + else: + bounds = ValueRanges(scalar, math.inf) + existing_inline_assertions[symint] = previous_range & bounds + + return existing_inline_assertions diff --git a/phivenv/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..565dce96bbb0ff594b285584857d4316d4a4bed2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/collect_tracepoints_pass.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch.export.exported_program import ConstantArgument, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +if TYPE_CHECKING: + from torch.export.exported_program import ModuleCallSignature + from torch.export.graph_signature import ExportGraphSignature + + +__all__ = ["CollectTracepointsPass"] + + +class CollectTracepointsPass(PassBase): + """ + Performs constant folding and constant propagation. 
+ """ + + def __init__( + self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature + ) -> None: + super().__init__() + self.specs = specs + self.sig = sig + + def call(self, gm: torch.fx.GraphModule) -> Optional[PassResult]: + def get_arg_spec(arg) -> Union[TensorArgument, ConstantArgument]: + if isinstance(arg, torch.fx.Node): + if isinstance(arg.meta.get("val"), torch.Tensor): + return TensorArgument(name=arg.name) + else: + raise AssertionError( + "Symint input is not implemented yet for submodule call signature." + ) + else: + return ConstantArgument(name="", value=arg) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + nn_module_stack = None + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_outputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_inputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + nn_module_stack = None + for node in reversed(module.graph.nodes): + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + kind = node.kwargs["kind"] + if kind == "module_call_inputs": + nn_module_stack = node.meta["nn_module_stack"] + elif kind == "module_call_outputs": + nn_module_stack = None + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + elif node.meta["nn_module_stack"] == nn_module_stack: + node.meta["nn_module_stack"].popitem() + else: + nn_module_stack = None + + def copy_sig(sig) -> ModuleCallSignature: + from torch.export.exported_program import ModuleCallSignature + + return ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=sig.in_spec, + out_spec=sig.out_spec, + forward_arg_names=None, + ) + + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.higher_order._export_tracepoint: + # There's some subtlety worth noting. Here fqn corresponds to + # the call name, whereas path corresponds to the module name. + # They are not necessarily the same! When a submodule is shared + # through different aliases, there are as many _export_tracepoint + # markers as there are aliases, since the shared submodule is + # wrapped once for each alias. + path = node.kwargs["path"] + fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) + + module_key = next(reversed(node.meta["nn_module_stack"])) + if "@" in module_key: + suffix = module_key.split("@")[-1] + path = f"{path}@{suffix}" + + call_fqn = f"{fqn}@{suffix}" + if call_fqn not in self.specs: + self.specs[call_fqn] = copy_sig(self.specs[fqn]) + fqn = call_fqn + + kind = node.kwargs["kind"] + for i, arg in enumerate(node.args): + # We only update the signature of the alias used to call + # the submodule. Otherwise the signatures of all aliases + # would get conflated; the inputs/outputs of every call + # would be recorded in every other call as well. 
+ if fqn == path: + if kind == "module_call_inputs": + self.specs[path].inputs.append(get_arg_spec(arg)) + elif kind == "module_call_outputs": + self.specs[path].outputs.append(get_arg_spec(arg)) + else: + raise AssertionError(f"Unknown tracepoint kind: {kind}") + if isinstance(arg, torch.fx.Node): + for user in node.users: + assert user.op == "call_function" + assert user.target == operator.getitem + assert isinstance(user.args[1], int) + if user.args[1] == i: + user.replace_all_uses_with(arg) + self.sig.replace_all_uses(user.name, arg.name) + break + users = list(node.users) + for user in users: + assert len(user.users) == 0 + gm.graph.erase_node(user) + gm.graph.erase_node(node) + return PassResult(gm, True) + + return None diff --git a/phivenv/Lib/site-packages/torch/_export/passes/constant_folding.py b/phivenv/Lib/site-packages/torch/_export/passes/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..66debb1ce81f19ba9033dff423d050daf861dcab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/constant_folding.py @@ -0,0 +1,303 @@ +# mypy: allow-untyped-defs +import collections +from collections import defaultdict +from typing import Any, Callable, Optional + +import torch +import torch.utils._pytree as pytree + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. +# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + + +def replace_node_with_constant(gm, node, constant, name=None): + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + ): + super().__init__(gm) + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + + def is_impure(self, node: torch.fx.Node) -> bool: + if ( + node.target == torch.ops.prims.convert_element_type.default + and node.args[0].op == "get_attr" # type: ignore[union-attr] + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ): + # For int8_weight -> dq -> bf16_weight + return True + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.pt2e_quant.dequantize_affine, + ]: + # For the 
pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + return False + + def node_to_last_non_output_use(self): + last_non_output_use = collections.defaultdict(list) + seen_uses = set() + output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] + + for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if node.target == "output": + continue + + def add_use(inp): + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node): + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg): + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and node.op != "get_attr" + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + out = super().run_node(node) + + if node.op != "get_attr" and isinstance(out, torch.Tensor): + if out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + 
self.node_replacements[node] = tensor + + def run(self): # type: ignore[override] + env = {} + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + env[n] = self.unknown_value + return super().run(initial_env=env) + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +): + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + # Get all attr users by looking up the graph instead from node.users, because in this case + # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor. + + # opcode name target args kwargs + # ------------- ------------------- ---------------- --------------------------- -------- + # placeholder arg0_1 arg0 () {} + # get_attr _tensor_constant0 state () {} + # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {} + # get_attr _tensor_constant0_1 state () {} + # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {} + # output output output ([add],) {} + + get_attr_node_users = defaultdict(list) + for node in gm.graph.nodes: + if node.op == "get_attr": + get_attr_node_users[node.target].extend(node.users.keys()) + for node in gm.graph.find_nodes(op="get_attr"): + if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag(gm: torch.fx.GraphModule) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node in gm.graph.nodes: + if ( + node.op == "get_attr" + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag(gm) + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. 
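# Illustrative sketch (hypothetical module): constant_fold() replaces the
# weight-only `self.w * 2` subgraph below with a single frozen get_attr, while
# the input-dependent add stays in the graph.
import torch
from torch._export.passes.constant_folding import constant_fold


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(4))

    def forward(self, x):
        # `self.w * 2` has no runtime inputs, so it is a folding candidate.
        return x + self.w * 2


gm = torch.fx.symbolic_trace(M())
constant_fold(gm)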
+ for node in gm.graph.find_nodes(op="get_attr"): + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + + new_graph = torch.fx.Graph() + + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/phivenv/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9d8cc4286a1ad0b460fb2a978a46468f640fcb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -0,0 +1,99 @@ +import copy +from typing import Optional + +import torch +from torch._export.pass_base import ( + _ExportPassBaseDeprecatedDoNotUse, + Argument, + PassResult, +) +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._ops import OpOverload + + +aten = torch.ops.aten + +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { + aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default, + aten._assert_async.msg: aten._functional_assert_async.msg, +} + + +class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Functionalize ops with side effect in graph module by replacing the op with + functional version of it. A new dependency token (`dep_token`) will be + created and propagated through functional ops to output. + For example: + ``` + def f(x): + sym_constrain_range(x.shape[0], min=1, max=3) + return x.add(3) + ``` + Will be transformed to: + ``` + def f(x): + dep_token0 = _make_dep_token() + dep_token1 = _functional_sym_constrain_range( + x.shape[0], min=1, max=3, dep_token=dep_token0 + ) + + return x.add(3), dep_token1 + ``` + """ + + def __init__(self) -> None: + super().__init__() + self._dep_token: Optional[ProxyValue] = None + self._next_dep_token_index: Optional[int] = None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Early return if no non-functional assertions. 
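# Illustrative sketch (hypothetical graph module): the pass is invoked like any
# other pass built on _ExportPassBaseDeprecatedDoNotUse; when the graph contains
# aten._assert_async.msg or aten.sym_constrain_range nodes, the returned graph's
# outputs carry the trailing dep_token described in the class docstring above.
import torch
from torch._export.passes.functionalize_side_effectful_ops_pass import (
    _FunctionalizeSideEffectfulOpsPass,
)


def _functionalize_side_effects(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    result = _FunctionalizeSideEffectfulOpsPass()(gm)
    return result.graph_module if result is not None else gm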
+ if not any( + n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS + for n in graph_module.graph.nodes + ): + return PassResult(graph_module=graph_module, modified=False) + + gm = copy.deepcopy(graph_module) + self._dep_token = None + self._next_dep_token_index = None + return super().call(gm) + + def call_operator( + self, + op: OpOverload, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: + return super().call_operator(op, args, kwargs, meta) + + if self._dep_token is None: + self._dep_token = super().call_operator( + aten._make_dep_token, + args=(), + kwargs={}, + meta=self._create_dummy_node_metadata(), + ) + self._dep_token.node.name = "dep_token0" + self._next_dep_token_index = 1 + + self._dep_token = super().call_operator( + _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], + args=args, + kwargs={**kwargs, "dep_token": self._dep_token}, + meta=meta, + ) + assert self._next_dep_token_index is not None + self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" + self._next_dep_token_index += 1 + + return self._dep_token + + def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: + assert self._dep_token is not None + + return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/_export/passes/insert_custom_op_guards.py b/phivenv/Lib/site-packages/torch/_export/passes/insert_custom_op_guards.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6a99a8faa2cd9332d6390c53cf24534775c115 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/insert_custom_op_guards.py @@ -0,0 +1,78 @@ +import functools +from collections import defaultdict + +import torch +from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, +) +from torch._library.fake_profile import OpProfile, TensorMetadata + + +def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> None: + """ + This is used by draft_export to insert guards in front of calls to custom + operators which have a generated fake kernel. + """ + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, stack_trace=node.meta.get("stack_trace") + ), + ), gm.graph.inserting_before(node): + for arg in (*node.args, *node.kwargs.values()): + if isinstance(arg, torch.fx.Node) and isinstance( + arg.meta.get("val"), torch.Tensor + ): + val = arg.meta["val"] + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(arg,), + kwargs={ + "dtype": val.dtype, + "device": val.device, + "layout": val.layout, + }, + ) + + gm.recompile() + + +def get_op_profiles( + gm: torch.fx.GraphModule, ops_to_guard: set[str] +) -> dict[str, set[OpProfile]]: + """ + This is used by draft_export to get a list of custom operator profiles so + that we can generate fake kernels. 
+ """ + + def _get_op_profile(node: torch.fx.Node) -> OpProfile: + args_profile = tuple( + [ + TensorMetadata.maybe_from_tensor(arg.meta.get("val")) + if isinstance(arg, torch.fx.Node) + else None + for arg in (*node.args, *node.kwargs.values()) + ] + ) + + out_profile = None + meta = node.meta.get("val") + assert meta is not None + if isinstance(meta, torch.Tensor): + out_profile = TensorMetadata.maybe_from_tensor(meta) + elif isinstance(meta, (list, tuple)): + out_profile = tuple([TensorMetadata.maybe_from_tensor(m) for m in meta]) # type: ignore[assignment] + assert out_profile is not None + + return OpProfile(args_profile, out_profile) # type: ignore[arg-type] + + op_profiles: dict[str, set[OpProfile]] = defaultdict(set) + + for node in gm.graph.nodes: + if node.op == "call_function" and str(node.target) in ops_to_guard: + op_profiles[str(node.target)].add(_get_op_profile(node)) + + return op_profiles diff --git a/phivenv/Lib/site-packages/torch/_export/passes/lift_constants_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/lift_constants_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f26ea98bac71fd20cb583ef80b9ebb73df5011fd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/lift_constants_pass.py @@ -0,0 +1,414 @@ +# mypy: allow-untyped-defs +import collections +import logging +from typing import Any, Optional, Union + +import torch +from torch._export.verifier import SpecViolationError +from torch._guards import detect_fake_mode +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export.exported_program import ( + ArgumentSpec, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + TensorArgument, +) +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.graph_module import _get_attr + + +log = logging.getLogger(__name__) + + +class ConstantAttrMap(collections.abc.MutableMapping): + """A mapping class that understands how to use module constants (tensors, + ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, + but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to + the same underlying value (but we guarantee that they will `hash()` to the same value + if that's the case). + """ + + def __init__(self) -> None: + # Underlying dict that we use to implement this mapping. + self._constant_attrs: dict[ + Union[int, torch.Tensor, FakeScriptObject, torch.utils._pytree.TreeSpec], + list[Any], + ] = {} + # Map from the hash(ScriptObject) to the ScriptObject itself. Used for + # APIs like `__iter__` that should look like they're returning the + # original ScriptObjects. + self._script_object_map: dict[int, torch.ScriptObject] = {} + + def __getitem__(self, key: _ConstantAttributeType) -> Any: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject)) + return self._constant_attrs[real_key] + + def __setitem__(self, key: _ConstantAttributeType, value): + # we shouldn't actually call this, should go to add() instead to handle aliasing + raise NotImplementedError( + """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead. 
+The same key can be mapped to multiple values, for handling constant aliasing.""" + ) + + def add(self, key: _ConstantAttributeType, value: Any) -> None: + if isinstance(key, torch.ScriptObject): + if hash(key) not in self._constant_attrs: + self._constant_attrs[hash(key)] = [] + self._constant_attrs[hash(key)].append(value) + self._script_object_map[hash(key)] = key + elif isinstance(key, (torch.Tensor, FakeScriptObject)): + if key not in self._constant_attrs: + self._constant_attrs[key] = [] + self._constant_attrs[key].append(value) + else: + raise TypeError( + f"Expected key to be a tensor or ScriptObject, got {type(key)}" + ) + + def __delitem__(self, key: _ConstantAttributeType): + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + + del self._constant_attrs[real_key] + + def __iter__(self): + for key in self._constant_attrs: + if isinstance(key, int): + yield self._script_object_map[key] + else: + yield key + + def __len__(self): + return len(self._constant_attrs) + + def __contains__(self, key: object) -> bool: + real_key = hash(key) if isinstance(key, torch.ScriptObject) else key + return real_key in self._constant_attrs + + +def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str: + # The FQN of the constant tensor in the state dict should + # correspond to the module where the constant tensor was + # originally used. + if len(node.meta["nn_module_stack"]) == 0: + return constant_name + parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0] + if len(parent_fqn) > 0: + return f"{parent_fqn}.{constant_name}" + else: + return constant_name + + +def _get_first_fqn( + const_attrs: ConstantAttrMap, + key: _ConstantAttributeType, +) -> Any: + fqns = const_attrs.get(key) + return fqns[0] if fqns else None + + +def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]: + """ + If there is a tensor constant created while tracing, here is how the graph + looks like: + + %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0] + %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,)) + %detach_ : [num_users=?] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,)) + + To check to see if the tensor constant is being used, we want to traverse to + the detach node to see if it's actually being used. + + This function returns None if this constant is being used, otherwise it returns the + lift_fresh and detach node to be removed later. + """ # noqa: B950 + if len(node.users) > 1: + return None + + lift_fresh_node = next(iter(node.users.keys())) + if not ( + lift_fresh_node.op == "call_function" + and lift_fresh_node.target + in ( + torch.ops.aten.lift_fresh.default, + torch.ops.aten.lift_fresh_copy.default, + ) + ): + return None + + if len(lift_fresh_node.users) > 1: + return None + + detach_node = next(iter(lift_fresh_node.users.keys())) + if not ( + detach_node.op == "call_function" + and detach_node.target + in ( + torch.ops.aten.detach_.default, + torch.ops.aten.detach.default, + ) + ): + return None + + if len(detach_node.users) > 0: + return None + else: + return [detach_node, lift_fresh_node, node] + + +def lift_constants_pass( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> dict[str, _ConstantAttributeType]: + """ + Takes a graph module, graph signature, and modifies them implace to lift any + constants (tensors or custom classes) as inputs to the graph. 
Returns a + dictionary of names to constants. + + Arguments: + gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift. + graph_signature (ExportGraphSignature): This graph signature will be + mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs. + constant_attrs (ConstantAttr): A mapping from a constant value to its + fully-qualified path in `gm`. This is used to maintain consistent + location of constants between the original module and the exported + version. + + Returns: + A dictionary of fqn => constant value. + """ + all_constants: dict[str, _ConstantAttributeType] = {} + + inputs = graph_signature.input_specs + num_custom_obj = sum( + input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs + ) + num_tensor_constants = sum( + input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs + ) + + fake_mode = detect_fake_mode( + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") + ) + + first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes)) + used_target_names = set() + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.name in graph_signature.user_inputs: + first_user_input = node + break + used_target_names.add(inputs[first_user_input_loc].target) + first_user_input_loc += 1 + # If we ever hit here, it means that + # there was no user input so the constants + # should be inserted right before the first + # non-placeholder node. + if node.op != "placeholder": + first_user_input = node + break + + lifted_objs = ConstantAttrMap() + renamed_targets = {} + for node in list(gm.graph.nodes): + if node.op == "get_attr": + if nodes_to_remove := _unused_constant(node): + # Remove the node if it's not being used + for node_rm in nodes_to_remove: + gm.graph.erase_node(node_rm) + continue + + constant_val = _get_attr(gm, node.target) + # These are not hashable and not gonna be lifted + # so we can skip them earlier + if isinstance(constant_val, torch.fx.GraphModule): + continue + if "LoweredBackendModule" in type(constant_val).__name__: + continue + if "AOTInductorRunnerWrapper" in type(constant_val).__name__: + continue + if isinstance(constant_val, torch.utils._pytree.TreeSpec): + continue + + if constant_val in lifted_objs: + # We already lifted this constant elsewhere. Just rewrite uses + # of this get_attr to point to the already-existing placeholder + # node. + const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) + node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + renamed_targets[node.name] = const_placeholder_node.name + continue + + # For ScriptObject, Tensor and FakeScriptObject constants: + # First check if the constant was an attribute on some module by + # consulting `constant_attrs` map. If it is, use the fqn that keeps + # its location consistent with the eager module. + # + # If it's not in the `constant_attrs` map, that means it's an inline + # constant (e.g. x + torch.tensor(0)), and thus did not have a + # specific location in the eager module. In that case, just generate + # some name and attach it to the module in which it was used. 
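# Illustrative sketch (hypothetical module): an inline constant, i.e. a tensor
# that was never registered as a buffer or parameter, is exactly the case that
# gets a generated name such as "lifted_tensor_0" when it is lifted to a
# placeholder; the printed names below are examples, not guaranteed output.
import torch


class M(torch.nn.Module):
    def forward(self, x):
        # Created inline during tracing, so it has no FQN in the eager module.
        return x + torch.tensor([1.0, 2.0])


ep = torch.export.export(M(), (torch.randn(2),))
print(ep.constants)        # e.g. {"lifted_tensor_0": tensor([1., 2.])}
print(ep.graph_signature)  # includes a CONSTANT_TENSOR input spec for it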
+ if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)): + constant_kind = InputKind.CUSTOM_OBJ + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_custom_obj += 1 + constant_name = f"lifted_custom_{num_custom_obj}" + constant_fqn = get_constant_fqn(node, constant_name) + num_custom_obj += 1 + elif isinstance(constant_val, torch.Tensor): + # Remove the parameterness of constant_val + if isinstance(constant_val, torch.nn.Parameter): + log.debug( + "%s created when tracing %s is a parameter. But " + "it's not registered with register_parameter(). export will treat it as a constant tensor", + str(node.target), + str(node.meta.get("stack_trace", "")), + ) + # We get the real data out of the parameter by disabling the surrounding fake mode. + with unset_fake_temporarily(): + constant_val = constant_val.data + constant_kind = InputKind.CONSTANT_TENSOR + constant_fqn = _get_first_fqn(constant_attrs, constant_val) + if constant_fqn is not None: + constant_name = constant_fqn.replace(".", "_") + else: + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + while constant_fqn in used_target_names: + num_tensor_constants += 1 + constant_name = f"lifted_tensor_{num_tensor_constants}" + constant_fqn = get_constant_fqn(node, constant_name) + num_tensor_constants += 1 + else: + raise SpecViolationError( + f"getattr node {node} referencing unsupported type {type(constant_val)}" + ) + + with gm.graph.inserting_before(first_user_input): + # Insert the constant node before the first user input + const_placeholder_node = gm.graph.placeholder(constant_name) + # match target name with its node name in case there is name collision + # and suffix is added to node name in fx + const_placeholder_node.target = const_placeholder_node.name + + for k, v in node.meta.items(): + const_placeholder_node.meta[k] = v + + # Once the FQN has been used, remove nn_module_stack, stack_trace + const_placeholder_node.meta.pop("nn_module_stack") + const_placeholder_node.meta.pop("stack_trace", None) + + input_spec_arg: ArgumentSpec + if isinstance(constant_val, torch.Tensor): + if fake_mode is not None: + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_val, static_shapes=True + ) + const_placeholder_node.meta["val"].constant = constant_val + else: + const_placeholder_node.meta["val"] = constant_val + input_spec_arg = TensorArgument(name=const_placeholder_node.name) + elif isinstance(constant_val, torch._C.ScriptObject): + class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined] + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, class_fqn=class_fqn + ) + elif isinstance(constant_val, FakeScriptObject): + class_fqn = constant_val.script_class_name + const_placeholder_node.meta["val"] = CustomObjArgument( + constant_fqn, class_fqn, constant_val + ) + input_spec_arg = CustomObjArgument( + name=const_placeholder_node.name, + class_fqn=class_fqn, + fake_val=constant_val, + ) + else: + raise SpecViolationError( + f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}" + ) + + lifted_objs.add(constant_val, const_placeholder_node) + 
node.replace_all_uses_with(const_placeholder_node) + gm.graph.erase_node(node) + + renamed_targets[node.name] = const_placeholder_node.name + + # Add the constant as a buffer to the graph signature + graph_signature.input_specs.insert( + first_user_input_loc, + InputSpec( + kind=constant_kind, + arg=input_spec_arg, + target=constant_fqn, + ), + ) + if constant_val in constant_attrs: + for fqn in constant_attrs[constant_val]: + all_constants[fqn] = constant_val + else: + all_constants[constant_fqn] = constant_val + first_user_input_loc += 1 + + for spec in graph_signature.output_specs: + if spec.arg.name in renamed_targets: + spec.arg.name = renamed_targets[spec.arg.name] + + return all_constants + + +def rewrite_script_object_meta( + gm: torch.fx.GraphModule, +) -> dict[str, _ConstantAttributeType,]: + """When tracing, we produce a graph with FakeScriptObject in the + meta["val"]. + + For now, we rewrie meta["val"] to be a placeholder CustomObjArgument + """ + constants: dict[ + str, + _ConstantAttributeType, + ] = {} + for node in gm.graph.nodes: + if "val" not in node.meta: + continue + + old_meta = node.meta["val"] + + if isinstance(old_meta, torch.ScriptObject): + class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + elif isinstance(old_meta, FakeScriptObject): + class_fqn = old_meta.script_class_name # type: ignore[attr-defined] + new_meta = CustomObjArgument(node.name, class_fqn, old_meta) + constants[node.name] = old_meta + node.meta["val"] = new_meta + + return constants + + +def _materialize_and_lift_constants( + gm: torch.fx.GraphModule, + export_graph_signature: ExportGraphSignature, + constant_attrs: ConstantAttrMap, +) -> dict[str, _ConstantAttributeType]: + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + return constants diff --git a/phivenv/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py b/phivenv/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py new file mode 100644 index 0000000000000000000000000000000000000000..021a08d63180f31bf0eb0ae1f05aca22ee7512d4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/remove_runtime_assertions.py @@ -0,0 +1,36 @@ +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class _RemoveRuntimeAssertionsPass(PassBase): + """ + Remove runtime assertions inserted by the + _AddRuntimeAssertionsForInlineConstraintsPass. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target in [ + torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten._assert_tensor_metadata.default, + ]: + assert_async_node = node + if len(assert_async_node.users) > 0: + continue + module.graph.erase_node(assert_async_node) + # the upstream scalar_tensor <- {le, ge} <- sym_size + # linear chain of nodes of nodes is removed by the + # downstream dead code elimination + modified = True + + # We don't necessarily want to run DCE here because it could affect + # nodes that are in the module_call_graph attribute of the exported + # program. 
We will leave it to the pass caller to call DCE. + return PassResult(graph_module, modified) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fa6da610b3e169a0b252a013441a2162e31f81 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch._higher_order_ops.wrap import wrap_with_autocast + +from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from torch.export.graph_signature import ExportGraphSignature + + +def _is_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: + return ( + node + and node.op == "call_function" + and node.target + in [ + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ] + ) + + +def _is_enter_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: + return ( + node + and node.op == "call_function" + and node.target == torch.amp.autocast_mode._enter_autocast + ) + + +def _is_exit_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: + return ( + node + and node.op == "call_function" + and node.target == torch.amp.autocast_mode._exit_autocast + ) + + +def _is_autocast_sub_mod(node: torch.fx.Node) -> bool: + """ + Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`. + """ + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target == torch.amp.autocast_mode._enter_autocast + ): + # TODO: check if current auto-cast type is the same as the args of + # _enter_autocast. If so, return False, i.e. do not create a submodule. 
+ return True + return False + + +def _check_valid_autocast_block( + enter_autocast_node: torch.fx.Node, exit_autocast_node: torch.fx.Node +) -> None: + assert _is_enter_autocast_node(enter_autocast_node) + assert _is_exit_autocast_node(exit_autocast_node) + assert exit_autocast_node.args[0] == enter_autocast_node + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node) + if len(autocast_nodes) > 0: + assert len(autocast_nodes) > 1 # need at least an enter node and an exist node + enter_autocast_node = autocast_nodes[0] + exit_autocast_node = autocast_nodes[-1] + _check_valid_autocast_block(enter_autocast_node, exit_autocast_node) + + _replace_with_hop_helper(node, enter_autocast_node, wrap_with_autocast) + sub_graph.erase_node(exit_autocast_node) + sub_graph.erase_node(enter_autocast_node) + + +def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + split_autocast creates a new graph module that splits the input graph module into multiple submodules + based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module. + + Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are splitted + into a submodule. Nested autocast regions are not splitted. + `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well. + + Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph + module. Nodes marked with the same number are grouped into the same submodule. + A # 0 + enter_autocast # 1 + B # 1 + exit_autocast # 1 + C # 2 + enter_autocast # 3 + D # 3 + exit_autocast # 3 + E # 4 + """ + enter_autocast_node_stack: list[torch.fx.Node] = [] + first_node_after_outer_most_exit: bool = False + + def node_call_back(node: torch.fx.Node) -> bool: + nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit + increment_id = False + if first_node_after_outer_most_exit or ( + len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node) + ): + assert len(enter_autocast_node_stack) == 0 + first_node_after_outer_most_exit = False + increment_id = True + if _is_enter_autocast_node(node): + enter_autocast_node_stack.append(node) + elif _is_exit_autocast_node(node): + assert len(enter_autocast_node_stack) > 0 + last_enter_autocast_node = enter_autocast_node_stack.pop() + assert node.args[0] == last_enter_autocast_node + if len(enter_autocast_node_stack) == 0: + # next node should be in the next submodule since + # autocast block ends + first_node_after_outer_most_exit = True + return increment_id + + return sequential_split(gm, node_call_back) + + +def _sequential_split_and_maybe_inline_subgraphs( + gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Helper function for replace_autocast_with_hop_pass(). + Split the graph module into multiple subgraphs based on the autocast nodes. + For each subgraph, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module. + Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered + as a subgraph. 
+ """ + need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes) + if not need_replacing: + return gm, graph_signature + + # split_autocast returns a new graph module that could have different output + # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`. + new_gm = _split_autocast(gm) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node) -> None: + if _is_autocast_sub_mod(node): + _replace_with_hop(node) + else: + assert node.op == "call_module" + assert isinstance(node.target, str) + node_inline_(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_autocast_with_hop_pass( + gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + return _replace_with_hop_pass_helper( + gm, + graph_signature, + _sequential_split_and_maybe_inline_subgraphs, + ) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..130360c7d195f0bb8de7c6ece90573537573e609 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py @@ -0,0 +1,673 @@ +# mypy: allow-untyped-defs +import logging +import operator +from typing import Optional, Union + +import torch +import torch.export._trace +from torch._ops import OpOverload +from torch.ao.quantization.fx._decomposed import ( + dequantize_per_channel, + dequantize_per_tensor, + quantize_per_tensor, +) +from torch.ao.quantization.utils import calculate_qmin_qmax +from torch.fx.graph_module import _assign_attr + + +log = logging.getLogger(__name__) + +# Those values will need to be carried over multiple operators. +_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None +_SCALE: Optional[Union[float, torch.fx.Node]] = None +_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None + + +def int_to_valid_dtype(val: int) -> torch.dtype: + from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import. 
+ + if isinstance(val, torch.dtype): + return val + dtype = _TORCH_ENUM_TO_DTYPE[val] + if dtype == torch.quint8: + return torch.uint8 + elif dtype == torch.qint8: + return torch.int8 + return dtype + + +def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node: + return gm.graph.call_function(int_to_valid_dtype, (val,)) + + +def insert_quantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + return gm.graph.call_function( + quantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + + +def get_dequantized( + val: torch.Tensor, + scale: Union[float, torch.Tensor], + zero_point: Union[float, torch.Tensor], + qmin: Union[float, int], + qmax: Union[float, int], + dtype: torch.dtype, + axis: Optional[int], + qscheme: Optional[torch.qscheme], +) -> torch.Tensor: + if qscheme is torch.per_tensor_affine: + return dequantize_per_tensor( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + elif qscheme is torch.per_channel_affine: + return dequantize_per_channel( + val, + scale, # type: ignore[arg-type] + zero_point, # type: ignore[arg-type] + axis, # type: ignore[arg-type] + qmin, # type: ignore[arg-type] + qmax, # type: ignore[arg-type] + dtype, + ) + else: + raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def insert_dequantized_node( + gm: torch.fx.GraphModule, + val_node: torch.fx.Node, + scale_node: Union[float, torch.fx.Node], + zero_point_node: Union[float, torch.fx.Node], + qmin_node: Union[float, int, torch.fx.Node], + qmax_node: Union[float, int, torch.fx.Node], + dtype_node: Union[torch.dtype, torch.fx.Node], + axis_node: Optional[Union[int, torch.fx.Node]], + qscheme: Optional[torch.qscheme], +) -> torch.fx.Node: + if qscheme is torch.per_tensor_affine: + return gm.graph.call_function( + dequantize_per_tensor, + ( + val_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + elif qscheme is torch.per_channel_affine: + return gm.graph.call_function( + dequantize_per_channel, + ( + val_node, + scale_node, + zero_point_node, + axis_node, + qmin_node, + qmax_node, + dtype_node, + ), + ) + else: + raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") + + +def get_qmin_qmax(dtype: torch.dtype) -> tuple[Union[int, float], Union[int, float]]: + return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type] + + +def insert_qmin_qmax_node( + gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node] +) -> tuple[torch.fx.Node, torch.fx.Node]: + q_min_max_node = gm.graph.call_function( + calculate_qmin_qmax, (None, None, False, dtype_node, False) + ) + qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0)) + qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1)) + return qmin_node, qmax_node + + +def get_script_object( + gm: torch.nn.Module, node: torch.fx.Node +) -> torch._C.ScriptObject: + assert isinstance(node, torch.fx.Node) + assert node.op == "get_attr" + attr_name = node.target + assert isinstance(attr_name, str) + + mod = gm + for attr in attr_name.split("."): + 
mod = getattr(mod, attr) + assert isinstance(mod, torch._C.ScriptObject) + return mod + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm: torch.fx.GraphModule, + param_node: torch.fx.Node, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + """Directly inline tensor from a get_attr fx node.""" + mod = get_script_object(gm, param_node) + w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined] + w_attr_name, b_attr_name = ( + f"dequantized_{param_node.target}_w", + f"dequantized_{param_node.target}_b", + ) + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm: torch.fx.GraphModule, + get_attr_to_weight_node: torch.fx.Node, + get_attr_to_bias_node: Optional[torch.fx.Node], +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + assert isinstance(get_attr_to_weight_node.target, str) + w_qtensor = getattr(gm, get_attr_to_weight_node.target) + w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w" + + if get_attr_to_bias_node is not None: + assert isinstance(get_attr_to_bias_node.target, str) + b_qtensor = getattr(gm, get_attr_to_bias_node.target) + b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b" + else: + b_qtensor, b_attr_name = None, "" + + return insert_weight_and_bias_get_attr_node( + gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name + ) + + +def insert_weight_and_bias_get_attr_node( + gm: torch.fx.GraphModule, + w_qtensor: torch.Tensor, + b_qtensor: Optional[torch.Tensor], + w_attr_name: str, + b_attr_name: str, +) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]: + w_tensor = get_tensor_from_qtensor(w_qtensor) + _assign_attr(w_tensor, gm, w_attr_name) + w_tensor_attr = gm.graph.get_attr(w_attr_name) + + if b_qtensor is not None: + b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False) + _assign_attr(b_tensor, gm, b_attr_name) + b_tensor_attr = gm.graph.get_attr(b_attr_name) + else: + b_tensor_attr = None + + return w_tensor_attr, b_tensor_attr + + +def get_tensor_from_qtensor( + qtensor: torch.Tensor, dequant: bool = True +) -> torch.Tensor: + # Manual conversion because qint8 is not used anymore. + if qtensor.dtype in [torch.qint8, torch.quint8]: + tensor = qtensor.int_repr() + else: + tensor = qtensor + + # Weights need dequantization with scaling and zero_point adjustment, but + # bias does not need that. 
+ if dequant: + qscheme = qtensor.qscheme() + if qscheme == torch.per_channel_affine: + scale, zero_point, axis = ( + qtensor.q_per_channel_scales(), + qtensor.q_per_channel_zero_points(), + qtensor.q_per_channel_axis(), + ) + else: + scale, zero_point, axis = ( + qtensor.q_scale(), # type: ignore[assignment] + qtensor.q_zero_point(), # type: ignore[assignment] + None, + ) + dtype = tensor.dtype + qmin, qmax = get_qmin_qmax(dtype) + return get_dequantized( + tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme + ) + return tensor + + +def insert_fused_activation_node( + gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node +) -> torch.fx.Node: + if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]: + fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,)) + return fx_node + + +def _conv1d_op_with_squeeze( + inp: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, +) -> torch.Tensor: + # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze + # operations before and after the conv2d operation to match the dimension of weights. + # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950 + s_inp = torch.ops.aten.unsqueeze(inp, 2) + conv1d_res = torch.ops.aten.conv2d( + s_inp, + weight, + bias, + stride, + padding, + dilation, + groups, + ) + uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2) + return uns_conv1d_res + + +def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Conv specfic transformation function.""" + assert isinstance(node.target, torch._ops.OpOverload) + opname = node.target._opname + scale_node, zero_point_node = node.args[2], node.args[3] + + op_f = ( + torch.ops.aten.conv2d + if opname in ["conv2d", "conv2d_relu"] + else _conv1d_op_with_squeeze + ) + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using Conv2dPrepackParam from conv_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + op_f, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using ConvPrepackedParam. 
+ param = get_script_object(gm, param_node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + op_f, + ( + inp_node, + param_0, + param_1, + param.stride(), # type: ignore[attr-defined] + param.padding(), # type: ignore[attr-defined] + param.dilation(), # type: ignore[attr-defined] + param.groups(), # type: ignore[attr-defined] + ), + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Linear specfic transformation function.""" + scale_node, zero_point_node = node.args[2], node.args[3] + + inp_node, param_node = node.args[0], node.args[1] + assert isinstance(inp_node, torch.fx.Node) + assert isinstance(param_node, torch.fx.Node) + + if param_node.op == "call_function": + # Using LinearPrepackParam from linear_prepack. + # We directly skip the packing call and inline weights and bias. + w_node, b_node = param_node.args[0], param_node.args[1] + assert isinstance(w_node, torch.fx.Node) + assert b_node is None or isinstance(b_node, torch.fx.Node) + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( + gm, w_node, b_node + ) + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:]) + ) + else: + # Using LinearPackedParams. + ( + param_0, + param_1, + ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( + gm, param_node + ) # type: ignore[assignment] + op_res_node = gm.graph.call_function( + torch.ops.aten.linear, (inp_node, param_0, param_1) + ) + return op_res_node, scale_node, zero_point_node + + +def _transform_op_where_last_two_arguments_are_scale_and_zero_point( + gm: torch.fx.GraphModule, node: torch.fx.Node +): + """ + This transformation function can be used for function where the last two + parameters are scale and zero point. Additionally, the function's parameters + do not need any unpacking. + """ + to_standard_op = { + "mul": torch.ops.aten.mul, + "mul_relu": torch.ops.aten.mul, + "add": torch.ops.aten.add, + "add_relu": torch.ops.aten.add, + "softmax": torch.ops.aten.softmax, + "cat": torch.ops.aten.cat, + "hardswish": torch.ops.aten.hardswish, + } + + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2])) + return op_res_node, scale_node, zero_point_node + + +def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node): + """Transform scalar overload for basic arithmetic.""" + to_standard_op = { + "mul": torch.ops.aten.mul.Scalar, + "add": torch.ops.aten.add.Scalar, + } + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_res_node = gm.graph.call_function(to_standard_op[opname], args) + return op_res_node, _SCALE, _ZERO_POINT + + +def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node): + """ + Transformation for functions under prepacked namespace, where they share + the same handling logic that [...]OpContext contains all parameters. 
+ """ + assert isinstance(node.target, torch._ops.OpOverload) + opname, args = node.target._opname, node.args + op_f = None + if opname == "conv2d_clamp_run": + op_f = torch.ops.aten.conv2d + elif opname == "linear_clamp_run": + op_f = torch.ops.aten.linear + else: + raise RuntimeError(f"Invalid operator {opname}") + + assert isinstance(args[1], torch.fx.Node) + so = get_script_object(gm, args[1]) + + func_args = [] + func_args += [args[0]] + func_args += so.unpack()[:2] # type: ignore[attr-defined] + if opname == "conv2d_clamp_run": + func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:] + + op_res_node = gm.graph.call_function(op_f, tuple(func_args)) + return op_res_node + + +def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node): + args = node.args + scale_node, zero_point_node = args[-2], args[-1] + op_res_node = gm.graph.call_function( + torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3]) + ) + op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0)) + return op_res_node, scale_node, zero_point_node + + +def fx_transform_quantized_op_to_standard_op( + gm: torch.fx.GraphModule, node: torch.fx.Node +) -> torch.fx.Node: + global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE + + assert isinstance(node.target, torch._ops.OpOverload) + opname, overload = node.target._opname, node.target._overloadname + + key = f"{opname}.{overload}" + opname_to_transform_f = { + "conv1d.new": _transform_conv_with_packedparam, + "conv1d_relu.new": _transform_conv_with_packedparam, + "conv1d.default": _transform_conv_with_packedparam, + "conv1d_relu.default": _transform_conv_with_packedparam, + "conv2d.new": _transform_conv_with_packedparam, + "conv2d_relu.new": _transform_conv_with_packedparam, + "conv2d.default": _transform_conv_with_packedparam, + "conv2d_relu.default": _transform_conv_with_packedparam, + "linear.default": _transform_linear_with_packedparam, + "linear_relu.default": _transform_linear_with_packedparam, + "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, + "batch_norm2d.default": _transform_batch_norm, + "mul.Scalar": _transform_scalar_arithmetic, + "add.Scalar": _transform_scalar_arithmetic, + } + + if f"{key}" not in opname_to_transform_f: + raise RuntimeError(f"Unsupported quantized op during transformation: {key}") + + op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node) + + # Add fused activation layer. 
+ op_res_node = insert_fused_activation_node(gm, opname, op_res_node) + _SCALE, _ZERO_POINT = scale_node, zero_point_node + + assert _INPUT_Q_DTYPE is not None + qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE) + q_fx_node = insert_quantized_node( + gm, + op_res_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + _INPUT_Q_DTYPE, + None, + torch.per_tensor_affine, + ) + return dq_fx_node + + +def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): + """ + Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with + PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv). + + Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y + + After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y + + (qd == quantized_decomposed library, q = quantize, dq = dequantize) + ^ + | + getattr(w), getattr(b) from Conv2dParamPrepack + + During each iteration, the transformation spits out the transformed operator, its quantized output, + and its dequantized value together. We did this because dequantization need to use the + scale and zero point parameters from the quantization to recover the approximate original value. After each + iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear). + + For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject. + During the transformation, we unpack those objects, get their dequantized tensor, populate those + as attributes to the module, and use getattr to access them. + + One exception in the transformation is conv_prepack and linear_prepack. Those calls pack + weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls. + During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the + quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters + to the operator by converting them to a getattr fx.node. + + For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear + without the need of doing de/quantization. + + Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization + data type, which is the same across the entire program, but it only shows up in the very first quantization + call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar. 
+ """ + + global _INPUT_Q_DTYPE + + quantized = False + + last_quantized_node = None + for node in gm.graph.nodes: + if isinstance(node.target, OpOverload): + with gm.graph.inserting_before(node): + namespace, opname = node.target.namespace, node.target._opname + if namespace == "quantized" and opname not in [ + "conv_prepack", + "linear_prepack", + ]: + quantized = True + fx_node = fx_transform_quantized_op_to_standard_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "prepacked": + quantized = True + fx_node = _transform_prepacked_op(gm, node) + node.replace_all_uses_with(fx_node) + last_quantized_node = fx_node + elif namespace == "aten" and opname == "quantize_per_tensor": + inp_node, scale_node, zero_point_node, dtype_node = node.args + dtype_node = fx_enum_to_dtype(gm, dtype_node) + _INPUT_Q_DTYPE = dtype_node + qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node) + q_fx_node = insert_quantized_node( + gm, + inp_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + torch.per_tensor_affine, + ) + dq_fx_node = insert_dequantized_node( + gm, + q_fx_node, + scale_node, + zero_point_node, + qmin_node, + qmax_node, + dtype_node, + None, + torch.per_tensor_affine, + ) + node.replace_all_uses_with(dq_fx_node) + last_quantized_node = dq_fx_node + elif namespace == "aten" and opname == "dequantize": + assert last_quantized_node is not None + node.replace_all_uses_with(last_quantized_node) + else: + last_quantized_node = node + + # Post-processing again to remove legacy ScriptObjects and quantizated tensors + # stored as attributes or in the buffer. This is used to clean up the GraphModule + # to not trigger tracing errors like missing __obj_flatten__ functions. + def _clean_attr(mod: torch.nn.Module): + for submod in mod.modules(): + attr_names_to_clean = set() + for k, v in submod.__dict__.items(): + if isinstance(v, torch.ScriptObject): + attr_names_to_clean.add(k) + if k == "_buffers": + buffer_name_to_clean = set() + for b_name, b_value in v.items(): + if isinstance(b_value, torch.Tensor) and b_value.dtype in [ + torch.qint8, + torch.quint8, + ]: + buffer_name_to_clean.add(b_name) + for b_name in buffer_name_to_clean: + v.pop(b_name, None) + for attr_name in attr_names_to_clean: + delattr(submod, attr_name) + + if quantized: + """ + TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily + bypass test cases. + + The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing + will throw errors. However, the current way of SetAttr does inplace update to attributes, so + this pass regard them as dead code and remove them. Below is an example of GraphModule before + and after the dead code elimination pass. 
+ + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data = self.data; data = None + data_1 = self.data + add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None + data_2 = self.data + copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + + class GraphModule(torch.nn.Module): + def forward(self, x_1): + # No stacktrace found for following nodes + data_3 = self.data + add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None + return add_tensor_1 + """ + gm.graph.eliminate_dead_code() + _clean_attr(gm) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..45177ad51a86bdcb3e7de87463f20a0e1d3e9660 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled + +from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split +from .replace_with_hop_pass_util import ( + _replace_with_hop_helper, + _replace_with_hop_pass_helper, + _sequential_split_and_maybe_inline_subgraphs_helper, +) + + +if TYPE_CHECKING: + from torch.export.graph_signature import ExportGraphSignature + + +def _is_set_grad_enabled_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: + return ( + node + and node.op == "call_function" + and node.target == torch._C._set_grad_enabled + ) + + +def _is_set_grad_enabled_sub_mod( + node: torch.fx.Node, omit_if_same_with_ambient: bool = False +) -> Union[bool, torch.Tensor]: + if node.op == "call_module": + assert isinstance(node.target, str) + subgm = getattr(node.graph.owning_module, node.target) + first_non_ph = nodes_first( + subgm.graph.nodes, lambda node: node.op != "placeholder" + ) + if ( + first_non_ph + and first_non_ph.op == "call_function" + and first_non_ph.target == torch._C._set_grad_enabled + ): + return ( + first_non_ph.args[0] != torch.is_grad_enabled() + if omit_if_same_with_ambient + else True + ) + return False + + +def _replace_with_hop(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node) + if len(set_grad_nodes) > 0: + assert len(set_grad_nodes) == 1 + set_grad_node = set_grad_nodes[0] + _replace_with_hop_helper(node, set_grad_node, wrap_with_set_grad_enabled) + sub_graph.erase_node(set_grad_node) + + +def _remove_set_grad_and_inline(node: torch.fx.Node) -> None: + assert node.op == "call_module" + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + sub_graph = sub_gm.graph + nodes_map( + sub_graph.nodes, + lambda n: sub_graph.erase_node(n) if 
_is_set_grad_enabled_node(n) else n, + ) + node_inline_(node) + + +def _sequential_split_and_maybe_inline_subgraphs( + gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Helper function for replace_set_grad_with_hop_pass(). + Split the graph module into multiple subgraphs based on the set_grad_enabled nodes. + For each subgraph, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module. + """ + need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes) + if not need_replacing: + return gm, graph_signature + + # sequential_split returns a new graph module that could have different output + # args names. We need to fix the graph signature. + new_gm = sequential_split(gm, _is_set_grad_enabled_node) + + def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): + if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True): + _replace_with_hop(node) + else: + _remove_set_grad_and_inline(node) + + return _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm, graph_signature, _maybe_inline_or_replace_with_hop + ) + + +def replace_set_grad_with_hop_pass( + gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + return _replace_with_hop_pass_helper( + gm, + graph_signature, + _sequential_split_and_maybe_inline_subgraphs, + ) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/phivenv/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..faa045b256cdd017d9cac57394f1c2ac7c76df10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch._export.error import InternalError +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse +from torch._ops import HigherOrderOperator, OpOverload + + +__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] + + +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = { + torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, +} + + +def is_view_op(schema: torch._C.FunctionSchema) -> bool: + if len(schema.arguments) == 0: + return False + alias_info = schema.arguments[0].alias_info + return (alias_info is not None) and (not alias_info.is_write) + + +def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: + if is_view_op(schema) and schema.name.startswith("aten::"): + view_op_name = schema.name.split("::")[1] + view_op_overload = ( + schema.overload_name if schema.overload_name != "" else "default" + ) + view_copy_op_name = view_op_name + "_copy" + if not hasattr(torch.ops.aten, view_copy_op_name): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name) + + if not hasattr(view_copy_op_overload_packet, view_op_overload): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + return getattr(view_copy_op_overload_packet, view_op_overload) + + return None + + +class 
ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Our backend expects pure functional operators. For efficiency + purposes, we keep view ops around while functionalizing the exported + program. This pass replaces view ops with view copy ops for backends that + need AOT memory planning. + """ + + def call_operator(self, op, args, kwargs, meta): + if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: + return super().call_operator( + (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta + ) + + if isinstance(op, HigherOrderOperator): + return super().call_operator(op, args, kwargs, meta) + + if view_copy_op := get_view_copy_of_view_op(op._schema): + return super().call_operator(view_copy_op, args, kwargs, meta) + + return super().call_operator(op, args, kwargs, meta) diff --git a/phivenv/Lib/site-packages/torch/_export/passes/replace_with_hop_pass_util.py b/phivenv/Lib/site-packages/torch/_export/passes/replace_with_hop_pass_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c60119919aa4f54f267d97cda12775f0e556f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/passes/replace_with_hop_pass_util.py @@ -0,0 +1,187 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import copy +import operator +from typing import Callable, Optional, TYPE_CHECKING + +import torch + +from ..utils import node_replace_, nodes_map + + +if TYPE_CHECKING: + from torch._ops import HigherOrderOperator + from torch.export.graph_signature import ExportGraphSignature + + +def _replace_with_hop_helper( + node: torch.fx.Node, + enter_block_node: torch.fx.Node, + wrap_hoo: HigherOrderOperator, +) -> None: + graph: torch.fx.Graph = node.graph + assert graph.owning_module is not None + gm: torch.fx.GraphModule = graph.owning_module + assert isinstance(node.target, str) + sub_gm = getattr(gm, node.target) + + def set_hoo_node_meta(call_func_node): + call_func_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + call_func_node.meta["torch_fn"] = ( + f"{wrap_hoo.__name__}", + f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}", + ) + if isinstance(output_args, (tuple, list)): + call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args) + elif isinstance(output_args, torch.fx.Node): + call_func_node.meta["val"] = (output_args.meta["val"],) + + with graph.inserting_before(node): + get_attr_node = graph.get_attr(node.target) + get_attr_node.meta["nn_module_stack"] = copy.copy( + enter_block_node.meta.get("nn_module_stack", {}) + ) + output_node = next(iter(reversed(sub_gm.graph.nodes)), None) + # Split_module pass intentially doesn't add output node + # if the graph doesn't return anything. 
+ # TODO (tmanlaibaatar) Figure out if this is right behaviour + # for split_module + if isinstance(output_node, torch.fx.Node) and output_node.op != "output": + output_node = None + if output_node is not None: + assert len(output_node.args) == 1 + output_args = output_node.args[0] + enter_block_node_args = enter_block_node.args + if isinstance(output_args, (tuple, list)): + call_func_node = graph.call_function( + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + ) + # Create the metadata + set_hoo_node_meta(call_func_node) + node_replace_(node, call_func_node) + + # Rename the name of getitem nodes to the actual name of its contents + # for passing verifier and better readability, also propagate metadata + for get_item_node in call_func_node.users.keys(): + idx: int = get_item_node.args[1] # type: ignore[assignment] + output_node = output_args[idx] + get_item_node._rename(output_node.name) + get_item_node.meta = output_node.meta + + elif isinstance(output_args, torch.fx.Node): + call_func_node = graph.create_node( + "call_function", + wrap_hoo, + (*enter_block_node_args, get_attr_node, *node.args), + {}, + output_args.name, + ) + # Modify the subgraph to output a singleton list. + output_node.args = ((output_args,),) + # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph. + get_item_node = graph.create_node( + "call_function", + operator.getitem, + (call_func_node, 0), + {}, + ) + # Create the metadata + get_item_node.meta = output_args.meta + set_hoo_node_meta(call_func_node) + node_replace_(node, get_item_node) + else: + raise NotImplementedError( + f"replace_with_hop_pass doesn't support output type {type(output_args)}" + ) + else: + # TODO (shangdiy): remove this line, since the export graph can be non-functional + node.graph.erase_node(node) + + +def _sequential_split_and_maybe_inline_subgraphs_helper( + new_gm: torch.fx.GraphModule, + graph_signature: Optional[ExportGraphSignature], + maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None], +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Helper function for replacing graph nodes with higher order nodes. + For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls + back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`. + """ + # new_gm is a new graph module that could have different output args names. + # We need to fix the graph signature. + replace_ctx = contextlib.nullcontext() + new_signature = None + if graph_signature is not None: + # Cannot deep copy a real ScriptObject, which is referenced + # in the FakeScriptObject. Copy should be good enough to guard + # against accidental mutation to original graph_signature.
+ new_signature = copy.copy(graph_signature) + new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output"))) + assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len( + new_signature.output_specs + ) + for arg_node, out_spec in zip( + new_gm_out_node.args[0], new_signature.output_specs + ): + if arg_node is None: + assert out_spec.arg.value is None # type: ignore[union-attr] + elif ( + isinstance(arg_node, torch.fx.Node) + and out_spec.arg.name != arg_node.name + ): + out_spec.arg.name = arg_node.name + + replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook()) # type: ignore[assignment] + + with replace_ctx: + nodes_map( + list(new_gm.graph.nodes), + lambda node: ( + maybe_inline_or_replace_with_hop(node) + if node.op == "call_module" + else node + ), + ) + new_gm.recompile() + new_gm.graph.lint() + return new_gm, new_signature + + +def _replace_with_hop_pass_helper( + gm: torch.fx.GraphModule, + graph_signature: Optional[ExportGraphSignature], + sequential_split_and_maybe_inline_subgraphs: Callable[ + [torch.fx.GraphModule, Optional[ExportGraphSignature]], + tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]], + ], +) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + """ + Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and + then recursively call itself on each of the submodules. + """ + new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs( + gm, graph_signature + ) + # recursively call + for node in new_gm.graph.nodes: + if node.op == "get_attr": + subgm = getattr(new_gm, node.target) + if not isinstance(subgm, torch.fx.GraphModule): + continue + new_subgm, _ = _replace_with_hop_pass_helper( + subgm, + None, + sequential_split_and_maybe_inline_subgraphs, + ) + setattr(new_gm, node.target, new_subgm) + + new_gm.recompile() + new_gm.graph.lint() + return new_gm, new_signature diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__init__.py b/phivenv/Lib/site-packages/torch/_export/serde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bcd8e23c3be644c8bda5503b36e87c68b86186c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c61b0b60a9ee5cedcfa52bed4211eeb50695a60e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e8cbb7248e30dc129ab79c4c29351eba6cf8b5b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d980f43196b9d2c35b67710138c2753ebe669f98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f861ea6f9099c0b1b2b0a882bfd14817df7647d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/serialize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65577f43adc6695ff41775917923edf264a8de83 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_export/serde/__pycache__/union.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_export/serde/dynamic_shapes.py b/phivenv/Lib/site-packages/torch/_export/serde/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c9bab0473bebb8389b32c62aed038a2e2c5dcf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/dynamic_shapes.py @@ -0,0 +1,322 @@ +import dataclasses +from typing import Any, Optional, Union + +import torch +from torch._dynamo.exc import UserError, UserErrorType +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _DerivedDim, + _DimHint, + _tree_map_with_path, + Dim, +) +from torch.utils._pytree import tree_map + +from .serialize import _dataclass_to_dict + + +@dataclasses.dataclass +class RootDim: + """ + This represents a Dim object. + """ + + min: int + max: Union[int, None] + derived: list[str] + + +@dataclasses.dataclass +class DynamicShapesSpec: + """ + This stores a dynamic_shapes spec for de/serialization. + """ + + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] + dims: dict[str, RootDim] + + +def _postprocess_serialized_shapes( + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + dims: dict[str, dict[str, Union[int, list[str], None]]], + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, dict[str, Any]]: + """ + Sorts dims and dumps to dictionary format. + """ + from torch.utils._sympy.numbers import int_oo + + dims = { + k: RootDim( + min=v["min"], # type: ignore[arg-type] + max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type] + derived=sorted(v["derived"]), # type: ignore[arg-type] + ) + for k, v in sorted(dims.items()) + } + spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) + if to_dict: + return _dataclass_to_dict(spec) + else: + return spec + + +def _dump_dynamic_shapes( + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, dict[str, Any]]: + """ + Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. + Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". + Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones). 
+ + dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export(): + - Each tensor input is represented with a list of values, non-tensor inputs with None. + - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings. + - static dimensions are represented with ints. + + dims: A dictionary mapping each symbol name to the min/max range and derived dim names. + + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + 'dynamic_shapes': ( + [ + ['dx', 4], + ['dx + 1', 4], + ], + ['_DimHint.STATIC'], + ['_DimHint.STATIC', '_DimHint.STATIC'], + None, + ), + 'dims': { + 'dx': { + 'min': 4, + 'max': 16, + 'derived': ['dx + 1'], + }, + }, + } + ``` + """ + dims: dict[str, dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, Dim] + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." + val.type.name + + assert isinstance(val, Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, # type: ignore[attr-defined,union-attr] + "max": root.max, # type: ignore[attr-defined,union-attr] + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? 
+ _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[dict[str, Any], tuple[Any], list[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). + """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim # type: ignore[assignment, operator] + 
if remainder != 0: + ddim = ddim + int(remainder) # type: ignore[assignment, operator] + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str] + ) -> Union[None, int, Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO() + elif val == "_DimHint.DYNAMIC": + return _DimHint.DYNAMIC() + elif val == "_DimHint.STATIC": + return _DimHint.STATIC() + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] # type: ignore[return-value] + + return tree_map(deserialize_shape, dynamic_shapes) diff --git a/phivenv/Lib/site-packages/torch/_export/serde/export_schema.thrift b/phivenv/Lib/site-packages/torch/_export/serde/export_schema.thrift new file mode 100644 index 0000000000000000000000000000000000000000..59dae5e36ce4a7085f0c07cb04017b2644ae6059 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/export_schema.thrift @@ -0,0 +1,362 @@ +// @generated by update_schema.py +// checksum<> + +namespace py3 torch._export +namespace cpp2 torch._export.schema + +enum ArgumentKind { + UNKNOWN = 0, + POSITIONAL = 1, + KEYWORD = 2, +} + + +enum Layout { + Unknown = 0, + SparseCoo = 1, + SparseCsr = 2, + SparseCsc = 3, + SparseBsr = 4, + SparseBsc = 5, + _mkldnn = 6, + Strided = 7, +} + + +enum MemoryFormat { + Unknown = 0, + ContiguousFormat = 1, + ChannelsLast = 2, + ChannelsLast3d = 3, + PreserveFormat = 4, +} + + +enum ScalarType { + UNKNOWN = 0, + BYTE = 1, + CHAR = 2, + SHORT = 3, + INT = 4, + LONG = 5, + HALF = 6, + FLOAT = 7, + DOUBLE = 8, + COMPLEXHALF = 9, + COMPLEXFLOAT = 10, + COMPLEXDOUBLE = 11, + BOOL = 12, + BFLOAT16 = 13, + UINT16 = 28, + FLOAT8E4M3FN = 29, + FLOAT8E5M2 = 30, +} + + +struct Device { + 10: string type; + 20: optional i64 index; +} + +union SymExprHint { + 10: i64 as_int; + 20: bool as_bool; + 30: double as_float; +} + +struct SymExpr { + 10: string expr_str; + 20: optional SymExprHint hint; +} + +union SymInt { + 10: SymExpr as_expr; + 20: i64 as_int; +} + +union SymFloat { + 10: SymExpr as_expr; + 20: double as_float; +} + +union SymBool { + 10: SymExpr as_expr; + 20: bool as_bool; +} + +struct TensorMeta { + 10: ScalarType dtype; + 20: list sizes; + 30: bool requires_grad; + 40: Device device; + 50: list strides; + 60: SymInt storage_offset; + 70: Layout layout; +} + +union SymIntArgument { + 10: string as_name; + 20: i64 as_int; +} + +union SymFloatArgument { + 10: string as_name; + 20: double as_float; +} + +union SymBoolArgument { + 10: string as_name; + 20: bool as_bool; +} + +struct TensorArgument { + 10: string name; +} + +struct TokenArgument { + 10: string name; +} + +union OptionalTensorArgument { + 20: TensorArgument as_tensor; + 10: bool as_none; +} + +struct GraphArgument { + 10: string name; + 20: Graph graph; +} + +struct CustomObjArgument { + 10: string name; + 20: string class_fqn; +} + +union Argument { + 10: bool as_none; + 20: TensorArgument as_tensor; + 30: list as_tensors; + 50: i64 as_int; + 70: list as_ints; + 80: double as_float; + 90: list as_floats; + 100: string as_string; + 101: list as_strings; + 110: SymIntArgument as_sym_int; + 120: 
list as_sym_ints; + 130: ScalarType as_scalar_type; + 140: MemoryFormat as_memory_format; + 150: Layout as_layout; + 160: Device as_device; + 170: bool as_bool; + 180: list as_bools; + 182: SymBoolArgument as_sym_bool; + 184: list as_sym_bools; + 200: GraphArgument as_graph; + 190: list as_optional_tensors; + 210: CustomObjArgument as_custom_obj; + 220: string as_operator; + 230: SymFloatArgument as_sym_float; + 240: list as_sym_floats; + 250: OptionalTensorArgument as_optional_tensor; +} + +struct NamedArgument { + 10: string name; + 20: Argument arg; + 30: optional ArgumentKind kind; +} + +struct Node { + 10: string target; + 20: list inputs; + 30: list outputs; + 40: map metadata; + 50: optional bool is_hop_single_tensor_return; +} + +struct Graph { + 10: list inputs; + 20: list outputs; + 30: list nodes; + 40: map tensor_values; + 50: map sym_int_values; + 60: map sym_bool_values; + 70: bool is_single_tensor_return; + 80: map custom_obj_values; + 90: map sym_float_values; +} + +struct UserInputSpec { + 10: Argument arg; +} + +union ConstantValue { + 10: bool as_none; + 20: i64 as_int; + 30: double as_float; + 40: string as_string; + 50: bool as_bool; +} + +struct InputToConstantInputSpec { + 10: string name; + 20: ConstantValue value; +} + +struct InputToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct InputToBufferSpec { + 10: TensorArgument arg; + 20: string buffer_name; + 30: bool persistent; +} + +struct InputToTensorConstantSpec { + 10: TensorArgument arg; + 20: string tensor_constant_name; +} + +struct InputToCustomObjSpec { + 10: CustomObjArgument arg; + 20: string custom_obj_name; +} + +struct InputTokenSpec { + 10: TokenArgument arg; +} + +union InputSpec { + 10: UserInputSpec user_input; + 20: InputToParameterSpec parameter; + 30: InputToBufferSpec buffer; + 40: InputToTensorConstantSpec tensor_constant; + 50: InputToCustomObjSpec custom_obj; + 70: InputTokenSpec token; + 60: InputToConstantInputSpec constant_input; +} + +struct UserOutputSpec { + 10: Argument arg; +} + +struct LossOutputSpec { + 10: TensorArgument arg; +} + +struct BufferMutationSpec { + 10: TensorArgument arg; + 20: string buffer_name; +} + +struct GradientToParameterSpec { + 10: TensorArgument arg; + 20: string parameter_name; +} + +struct GradientToUserInputSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct UserInputMutationSpec { + 10: TensorArgument arg; + 20: string user_input_name; +} + +struct OutputTokenSpec { + 10: TokenArgument arg; +} + +union OutputSpec { + 10: UserOutputSpec user_output; + 20: LossOutputSpec loss_output; + 30: BufferMutationSpec buffer_mutation; + 40: GradientToParameterSpec gradient_to_parameter; + 50: GradientToUserInputSpec gradient_to_user_input; + 60: UserInputMutationSpec user_input_mutation; + 70: OutputTokenSpec token; +} + +struct GraphSignature { + 10: list input_specs; + 20: list output_specs; +} + +struct RangeConstraint { + 10: optional i64 min_val; + 20: optional i64 max_val; +} + +struct ModuleCallSignature { + 10: list inputs; + 20: list outputs; + 30: string in_spec; + 40: string out_spec; + 50: optional list forward_arg_names; +} + +struct ModuleCallEntry { + 10: string fqn; + 30: optional ModuleCallSignature signature; +} + +struct NamedTupleDef { + 10: list field_names; +} + +struct GraphModule { + 10: Graph graph; + 50: GraphSignature signature; + 60: list module_call_graph; + 40: map metadata; + 70: map treespec_namedtuple_fields; +} + +struct SchemaVersion { + 10: i64 major; + 20: i64 minor; +} + 
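+// (Editor's illustrative note, not part of the generated file.)
+// SchemaVersion mirrors SCHEMA_VERSION in schema.py (currently (8, 8)).
+// Per the version checks in schema_check.py later in this diff, removing a
+// field or adding one without a default bumps `major`, while additions that
+// carry a default only bump `minor`.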
+struct ExportedProgram {
+  10: GraphModule graph_module;
+  20: map<string, i64> opset_version;
+  30: map<string, RangeConstraint> range_constraints;
+  60: SchemaVersion schema_version;
+  70: list<string> verifiers;
+  80: string torch_version;
+}
+
+struct Program {
+  200: map<string, ExportedProgram> methods;
+}
+
+struct Model {
+  10: string name;
+  20: map<string, string> tensorPaths;
+  40: Program program;
+  50: map<string, Program> delegates;
+  60: map<string, string> deviceAllocationMap;
+  70: map<string, string> constantPaths;
+}
+
+struct AOTInductorModelPickleData {
+  1: string library_basename;
+  2: list<string> input_names;
+  3: list<string> output_names;
+  4: optional i64 floating_point_input_dtype;
+  5: optional i64 floating_point_output_dtype;
+  6: optional bool aot_inductor_model_is_cpu;
+}
+
+struct ExternKernelNode {
+  10: string name;
+  20: Node node;
+}
+
+struct ExternKernelNodes {
+  10: list<ExternKernelNode> nodes;
+}
diff --git a/phivenv/Lib/site-packages/torch/_export/serde/schema.py b/phivenv/Lib/site-packages/torch/_export/serde/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..555006ecf80b7e44c2a8822b50e793c47abb684a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/_export/serde/schema.py
@@ -0,0 +1,505 @@
+# NOTE: This is a placeholder for iterating on export serialization schema design.
+# Anything is subject to change and no guarantee is provided at this point.
+
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Annotated, Optional
+
+from torch._export.serde.union import _Union
+
+
+# NOTE: Please update this value if any modifications are made to the schema
+SCHEMA_VERSION = (8, 8)
+TREESPEC_VERSION = 1
+
+
+# NOTE: If you updated the schema, please run `scripts/export/update_schema.py`
+# to update the auto generated files.
+class ScalarType(IntEnum):
+    UNKNOWN = 0
+    BYTE = 1
+    CHAR = 2
+    SHORT = 3
+    INT = 4
+    LONG = 5
+    HALF = 6
+    FLOAT = 7
+    DOUBLE = 8
+    COMPLEXHALF = 9
+    COMPLEXFLOAT = 10
+    COMPLEXDOUBLE = 11
+    BOOL = 12
+    BFLOAT16 = 13
+    UINT16 = 28
+    FLOAT8E4M3FN = 29
+    FLOAT8E5M2 = 30
+
+
+class Layout(IntEnum):
+    Unknown = 0
+    SparseCoo = 1
+    SparseCsr = 2
+    SparseCsc = 3
+    SparseBsr = 4
+    SparseBsc = 5
+    _mkldnn = 6
+    Strided = 7
+
+
+class MemoryFormat(IntEnum):
+    Unknown = 0
+    ContiguousFormat = 1
+    ChannelsLast = 2
+    ChannelsLast3d = 3
+    PreserveFormat = 4
+
+
+@dataclass
+class Device:
+    type: Annotated[str, 10]
+    index: Annotated[Optional[int], 20] = None
+
+
+@dataclass(repr=False)
+class SymExprHint(_Union):
+    as_int: Annotated[int, 10]
+    as_bool: Annotated[bool, 20]
+    as_float: Annotated[float, 30]
+
+
+# This is for storing the symbolic expressions behind symints/symfloats/symbools
+# For example, we can get something like
+# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
+# if we also have the hint that s0 and s1 are both 2.
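+# (Editor's illustrative note.) Because these are tagged unions, only the
+# active field appears when serialized; following the single-key convention of
+# the generated to_json/from_json helpers further below, a backed symbolic
+# value roughly serializes as
+#   {"as_expr": {"expr_str": "s0 + s1", "hint": {"as_int": 4}}}
+# while a concrete value 4 serializes as {"as_int": 4}.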
+@dataclass +class SymExpr: + expr_str: Annotated[str, 10] + hint: Annotated[Optional[SymExprHint], 20] = None + + +@dataclass(repr=False) +class SymInt(_Union): + as_expr: Annotated[SymExpr, 10] + as_int: Annotated[int, 20] + + +@dataclass(repr=False) +class SymFloat(_Union): + as_expr: Annotated[SymExpr, 10] + as_float: Annotated[float, 20] + + +@dataclass(repr=False) +class SymBool(_Union): + as_expr: Annotated[SymExpr, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorMeta: + dtype: Annotated[ScalarType, 10] + sizes: Annotated[list[SymInt], 20] + requires_grad: Annotated[bool, 30] + device: Annotated[Device, 40] + strides: Annotated[list[SymInt], 50] + storage_offset: Annotated[SymInt, 60] + layout: Annotated[Layout, 70] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymInts. +# The "as_int" field is used in the case where we have a list containing a mix +# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to +# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints +# to the "as_int" field. +@dataclass(repr=False) +class SymIntArgument(_Union): + as_name: Annotated[str, 10] + as_int: Annotated[int, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymFloats. +# The "as_float" field is used in the case where we have a list containing a mix +# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to +# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints +# to the "as_float" field. +@dataclass(repr=False) +class SymFloatArgument(_Union): + as_name: Annotated[str, 10] + as_float: Annotated[float, 20] + + +# In most cases we will use the "as_name" field to store arguments which are +# SymBools. +# The "as_bool" field is used in the case where we have a list containing a mix +# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to +# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools +# to the "as_bool" field. +@dataclass(repr=False) +class SymBoolArgument(_Union): + as_name: Annotated[str, 10] + as_bool: Annotated[bool, 20] + + +@dataclass +class TensorArgument: + name: Annotated[str, 10] + + +@dataclass +class TokenArgument: + name: Annotated[str, 10] + + +# This is use for storing the contents of a list which contain optional tensors +# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the +# type List[OptionalTensorArgument], with tensor values seiralized to the +# "as_tensor" field, and None values serialized to the "as_none" field. 
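+# (Editor's illustrative note.) For example, a Tensor?[] argument [t1, None, t2]
+# would roughly serialize as
+#   [OptionalTensorArgument.create(as_tensor=TensorArgument(name="t1")),
+#    OptionalTensorArgument.create(as_none=True),
+#    OptionalTensorArgument.create(as_tensor=TensorArgument(name="t2"))]
+# where "t1"/"t2" are placeholder tensor names used only for illustration.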
+@dataclass(repr=False) +class OptionalTensorArgument(_Union): + as_tensor: Annotated[TensorArgument, 20] + as_none: Annotated[bool, 10] + + +@dataclass +class GraphArgument: + name: Annotated[str, 10] + graph: Annotated["Graph", 20] + + +@dataclass +class CustomObjArgument: + name: Annotated[str, 10] + class_fqn: Annotated[str, 20] + + +# This is actually a union type +@dataclass(repr=False) +class Argument(_Union): + as_none: Annotated[bool, 10] + as_tensor: Annotated[TensorArgument, 20] + as_tensors: Annotated[list[TensorArgument], 30] + as_int: Annotated[int, 50] + as_ints: Annotated[list[int], 70] + as_float: Annotated[float, 80] + as_floats: Annotated[list[float], 90] + as_string: Annotated[str, 100] + as_strings: Annotated[list[str], 101] + as_sym_int: Annotated[SymIntArgument, 110] + as_sym_ints: Annotated[list[SymIntArgument], 120] + as_scalar_type: Annotated[ScalarType, 130] + as_memory_format: Annotated[MemoryFormat, 140] + as_layout: Annotated[Layout, 150] + as_device: Annotated[Device, 160] + as_bool: Annotated[bool, 170] + as_bools: Annotated[list[bool], 180] + as_sym_bool: Annotated[SymBoolArgument, 182] + as_sym_bools: Annotated[list[SymBoolArgument], 184] + as_graph: Annotated[GraphArgument, 200] + as_optional_tensors: Annotated[list[OptionalTensorArgument], 190] + as_custom_obj: Annotated[CustomObjArgument, 210] + as_operator: Annotated[str, 220] + as_sym_float: Annotated[SymFloatArgument, 230] + as_sym_floats: Annotated[list[SymFloatArgument], 240] + as_optional_tensor: Annotated[OptionalTensorArgument, 250] + + +class ArgumentKind(IntEnum): + UNKNOWN = 0 + POSITIONAL = 1 + KEYWORD = 2 + + +@dataclass +class NamedArgument: + # Argument name from the operator schema + name: Annotated[str, 10] + arg: Annotated[Argument, 20] + kind: Annotated[Optional[ArgumentKind], 30] = None + + +@dataclass +class Node: + target: Annotated[str, 10] + inputs: Annotated[list[NamedArgument], 20] + outputs: Annotated[list[Argument], 30] + metadata: Annotated[dict[str, str], 40] + is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None + + +@dataclass +class Graph: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + nodes: Annotated[list[Node], 30] + tensor_values: Annotated[dict[str, TensorMeta], 40] + sym_int_values: Annotated[dict[str, SymInt], 50] + sym_bool_values: Annotated[dict[str, SymBool], 60] + # This is for deserializing the submodule graphs from higher order ops + # (ex. cond, map) where single tensor returns will just return a single + # tensor, rather than following export schema and returning a singleton + # list. 
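+    # (Editor's illustrative note.) E.g. a torch.cond branch submodule that
+    # returns a single tensor is serialized with is_single_tensor_return=True
+    # and a lone Argument in `outputs`, rather than a one-element list.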
+ is_single_tensor_return: Annotated[bool, 70] = False + custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field( + default_factory=dict + ) + sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict) + + +@dataclass +class UserInputSpec: + # Actually, only tensors and SymInts are allowed here + arg: Annotated[Argument, 10] + + +@dataclass(repr=False) +class ConstantValue(_Union): + as_none: Annotated[bool, 10] + as_int: Annotated[int, 20] + as_float: Annotated[float, 30] + as_string: Annotated[str, 40] + as_bool: Annotated[bool, 50] + + +@dataclass +class InputToConstantInputSpec: + name: Annotated[str, 10] + value: Annotated[ConstantValue, 20] + + +@dataclass +class InputToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class InputToBufferSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + persistent: Annotated[bool, 30] + + +@dataclass +class InputToTensorConstantSpec: + arg: Annotated[TensorArgument, 10] + tensor_constant_name: Annotated[str, 20] + + +@dataclass +class InputToCustomObjSpec: + arg: Annotated[CustomObjArgument, 10] + custom_obj_name: Annotated[str, 20] + + +@dataclass +class InputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@dataclass(repr=False) +class InputSpec(_Union): + user_input: Annotated[UserInputSpec, 10] + parameter: Annotated[InputToParameterSpec, 20] + buffer: Annotated[InputToBufferSpec, 30] + tensor_constant: Annotated[InputToTensorConstantSpec, 40] + custom_obj: Annotated[InputToCustomObjSpec, 50] + token: Annotated[InputTokenSpec, 70] + constant_input: Annotated[InputToConstantInputSpec, 60] + + +@dataclass +class UserOutputSpec: + arg: Annotated[Argument, 10] + + +@dataclass +class LossOutputSpec: + arg: Annotated[TensorArgument, 10] + + +@dataclass +class BufferMutationSpec: + arg: Annotated[TensorArgument, 10] + buffer_name: Annotated[str, 20] + + +@dataclass +class GradientToParameterSpec: + arg: Annotated[TensorArgument, 10] + parameter_name: Annotated[str, 20] + + +@dataclass +class GradientToUserInputSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class UserInputMutationSpec: + arg: Annotated[TensorArgument, 10] + user_input_name: Annotated[str, 20] + + +@dataclass +class OutputTokenSpec: + arg: Annotated[TokenArgument, 10] + + +@dataclass(repr=False) +class OutputSpec(_Union): + user_output: Annotated[UserOutputSpec, 10] + loss_output: Annotated[LossOutputSpec, 20] + buffer_mutation: Annotated[BufferMutationSpec, 30] + gradient_to_parameter: Annotated[GradientToParameterSpec, 40] + gradient_to_user_input: Annotated[GradientToUserInputSpec, 50] + user_input_mutation: Annotated[UserInputMutationSpec, 60] + token: Annotated[OutputTokenSpec, 70] + + +@dataclass +class GraphSignature: + input_specs: Annotated[list[InputSpec], 10] + output_specs: Annotated[list[OutputSpec], 20] + + +@dataclass +class RangeConstraint: + min_val: Annotated[Optional[int], 10] + max_val: Annotated[Optional[int], 20] + + +@dataclass +class ModuleCallSignature: + inputs: Annotated[list[Argument], 10] + outputs: Annotated[list[Argument], 20] + + # These are serialized by calling pytree.treespec_loads + # And deserialized by calling pytree.treespec_dumps + in_spec: Annotated[str, 30] + out_spec: Annotated[str, 40] + + # This field is used to prettify the graph placeholders + # after we ser/der and retrace + forward_arg_names: Annotated[Optional[list[str]], 50] = None + + +@dataclass +class 
ModuleCallEntry: + fqn: Annotated[str, 10] + signature: Annotated[Optional[ModuleCallSignature], 30] = None + + +@dataclass +class NamedTupleDef: + field_names: Annotated[list[str], 10] + + +@dataclass +class GraphModule: + graph: Annotated[Graph, 10] + signature: Annotated[GraphSignature, 50] + # This is used for unflattening, by tracking the calling structure of all of + # the modules in order to unflatten the modules back to the eager calling + # conventions. + module_call_graph: Annotated[list[ModuleCallEntry], 60] + metadata: Annotated[dict[str, str], 40] = field(default_factory=dict) + # Mapping of namedtuple types to namedtuple field names, used for BC + treespec_namedtuple_fields: Annotated[dict[str, NamedTupleDef], 70] = field( + default_factory=dict + ) + + +# Invariant: Every time a change is made to the schema, one of the versions +# should be upadted. +@dataclass +class SchemaVersion: + major: Annotated[ + int, 10 + ] # Major version number is bumped every time a breaking change is made. + minor: Annotated[ + int, 20 + ] # Minor version number is bumped when a compatible change is made. + + +@dataclass +class ExportedProgram: + graph_module: Annotated[GraphModule, 10] + # Key is the opset namespace (ex. aten), and value is the version number + opset_version: Annotated[dict[str, int], 20] + range_constraints: Annotated[dict[str, RangeConstraint], 30] + schema_version: Annotated[SchemaVersion, 60] + verifiers: Annotated[list[str], 70] = field(default_factory=list) + torch_version: Annotated[str, 80] = "<=2.4" + + +######################################################################### +# Container types for inference tasks, not being used directly for export. +######################################################################### + + +@dataclass +class Program: + methods: Annotated[dict[str, ExportedProgram], 200] + + +# This is the top-level model definition that be will serialized into the package +@dataclass +class Model: + # unique identifier of the model in the package, e.g. local, remote, merge + name: Annotated[str, 10] + # key is the FQN of tensor in exported program + # value is the archive path of tensor payloads + # e.g. "L__self__linear.weight" : "/data/tensor/L__self__linear.weight" + tensorPaths: Annotated[dict[str, str], 20] + # program exported from torch.export() + program: Annotated[Program, 40] + # Backend-specialized Lowered GraphModule + # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate + delegates: Annotated[dict[str, Program], 50] + deviceAllocationMap: Annotated[dict[str, str], 60] + # key is the FQN of constant in exported program (constant tensor or torchbind objs) + # value is the archive path of serialized constants + constantPaths: Annotated[dict[str, str], 70] + + +# +# The structure is used to serialize instances of AOTInductorModel to pass +# them from the publishing pipeline to the predictor. +# +# All new fields should be marked as optional. +# +@dataclass +class AOTInductorModelPickleData: + # Base name of an associated .so AOTInductor library. Typically looks like: + # "abc.so". + library_basename: Annotated[str, 1] + + # AOTInductor engine input names. + input_names: Annotated[list[str], 2] + + # AOTInductor engine output names. + output_names: Annotated[list[str], 3] + + # These fields tell whether floating point inputs/outputs should be converted to + # a certain type. If None, the dtypes that the AOTInductor engine inferred from the sample + # inputs are used. 
+ floating_point_input_dtype: Annotated[Optional[int], 4] = None + floating_point_output_dtype: Annotated[Optional[int], 5] = None + + # Whether AOTInductor runtime is for CPU. + aot_inductor_model_is_cpu: Annotated[Optional[bool], 6] = None + + +@dataclass +class ExternKernelNode: + # name is not the unique identifier of the node + name: Annotated[str, 10] + node: Annotated[Node, 20] + + +@dataclass +class ExternKernelNodes: + nodes: Annotated[list[ExternKernelNode], 10] diff --git a/phivenv/Lib/site-packages/torch/_export/serde/schema.yaml b/phivenv/Lib/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7880c58f2989710f372ae355cdee9dd0e5b89d6a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,536 @@ +# @generated by update_schema.py +# checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +AOTInductorModelPickleData: + kind: struct + fields: + library_basename: + type: str + input_names: + type: List[str] + output_names: + type: List[str] + floating_point_input_dtype: + type: Optional[int] + default: None + floating_point_output_dtype: + type: Optional[int] + default: None + aot_inductor_model_is_cpu: + type: Optional[bool] + default: None +Argument: + kind: union + fields: + as_none: + type: bool + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str + as_sym_float: + type: SymFloatArgument + as_sym_floats: + type: List[SymFloatArgument] + as_optional_tensor: + type: OptionalTensorArgument +ArgumentKind: + kind: enum + fields: + UNKNOWN: 0 + POSITIONAL: 1 + KEYWORD: 2 +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +ConstantValue: + kind: union + fields: + as_none: + type: bool + as_int: + type: int + as_float: + type: float + as_string: + type: str + as_bool: + type: bool +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str +Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + verifiers: + type: List[str] + default: '[]' + torch_version: + type: str + default: <=2.4 +ExternKernelNode: + kind: struct + fields: + name: + type: str + node: + type: Node +ExternKernelNodes: + kind: struct + fields: + nodes: + type: List[ExternKernelNode] +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] 
+ outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + default: '{}' + sym_float_values: + type: Dict[str, SymFloat] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] + metadata: + type: Dict[str, str] + default: '{}' + treespec_namedtuple_fields: + type: Dict[str, NamedTupleDef] + default: '{}' +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec + token: + type: InputTokenSpec + constant_input: + type: InputToConstantInputSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToConstantInputSpec: + kind: struct + fields: + name: + type: str + value: + type: ConstantValue +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str +InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +InputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +Model: + kind: struct + fields: + name: + type: str + tensorPaths: + type: Dict[str, str] + program: + type: Program + delegates: + type: Dict[str, Program] + deviceAllocationMap: + type: Dict[str, str] + constantPaths: + type: Dict[str, str] +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str + forward_arg_names: + type: Optional[List[str]] + default: None +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument + kind: + type: Optional[ArgumentKind] + default: None +NamedTupleDef: + kind: struct + fields: + field_names: + type: List[str] +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] + is_hop_single_tensor_return: + type: Optional[bool] + default: None +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: TensorArgument + as_none: + type: bool +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: 
GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec + token: + type: OutputTokenSpec +OutputTokenSpec: + kind: struct + fields: + arg: + type: TokenArgument +Program: + kind: struct + fields: + methods: + type: Dict[str, ExportedProgram] +RangeConstraint: + kind: struct + fields: + min_val: + type: Optional[int] + max_val: + type: Optional[int] +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 + UINT16: 28 + FLOAT8E4M3FN: 29 + FLOAT8E5M2: 30 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_bool: + type: bool + as_float: + type: float +SymFloat: + kind: union + fields: + as_expr: + type: SymExpr + as_float: + type: float +SymFloatArgument: + kind: union + fields: + as_name: + type: str + as_float: + type: float +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: SymInt + layout: + type: Layout +TokenArgument: + kind: struct + fields: + name: + type: str +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 8 +- 8 +TREESPEC_VERSION: 1 diff --git a/phivenv/Lib/site-packages/torch/_export/serde/schema_check.py b/phivenv/Lib/site-packages/torch/_export/serde/schema_check.py new file mode 100644 index 0000000000000000000000000000000000000000..8575bb762c7be5830e4ef786cb2c6546fc0cc89a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/schema_check.py @@ -0,0 +1,738 @@ +# mypy: allow-untyped-defs +import dataclasses +import hashlib +import inspect +import re +import typing +from enum import IntEnum +from typing import Annotated, Any, ForwardRef, Optional, Union + +from torch._export.serde import schema +from torch._export.serde.union import _Union + + +class SchemaUpdateError(Exception): + pass + + +def _check(x, msg): + if not x: + raise SchemaUpdateError(msg) + + +_CPP_TYPE_MAP = { + str: "std::string", + int: "int64_t", + float: "F64", + bool: "bool", +} + +_THRIFT_TYPE_MAP = { + str: "string", + int: "i64", + float: "double", + bool: "bool", +} + + +def _staged_schema(): + yaml_ret: dict[str, Any] = {} + defs = {} + cpp_enum_defs: dict[str, str] = {} + cpp_class_defs: dict[str, str] = {} + cpp_type_decls: list[str] = [] + cpp_json_defs: list[str] = [] + thrift_enum_defs: list[str] = [] + thrift_type_defs: dict[str, str] = {} + + def _handle_aggregate(ty) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + def dump_type(t, level: int) -> tuple[str, str, str]: + if 
getattr(t, "__name__", None) in cpp_enum_defs: + return t.__name__, "int64_t", t.__name__ + elif t in _CPP_TYPE_MAP: + return (t.__name__, _CPP_TYPE_MAP[t], _THRIFT_TYPE_MAP[t]) + elif isinstance(t, str): + assert t in defs + assert t not in cpp_enum_defs + assert "[" not in t + return t, f"ForwardRef<{t}>", t + elif isinstance(t, ForwardRef): + return ( + t.__forward_arg__, + f"ForwardRef<{t.__forward_arg__}>", + t.__forward_arg__, + ) + elif o := typing.get_origin(t): + # Lemme know if there's a better way to do this. + if o == list: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "List", + "std::vector", + "list<", + ">", + ) + elif o == dict: + yaml_head, cpp_head, thrift_head, thrift_tail = ( + "Dict", + "std::unordered_map", + "map<", + ">", + ) + elif o == Union: + assert level == 0, "Optional is only supported at the top level." + args = typing.get_args(t) + assert len(args) == 2 and args[1] == type(None) + yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1) + return ( + f"Optional[{yaml_type}]", + f"std::optional<{cpp_type}>", + f"optional {thrift_type}", + ) + elif o == Annotated: + return dump_type(t.__origin__, level) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + yaml_arg_types, cpp_arg_types, thrift_arg_types = zip( + *[dump_type(x, level + 1) for x in typing.get_args(t)] + ) + return ( + (f"{yaml_head}[{', '.join(yaml_arg_types)}]"), + (f"{cpp_head}<{', '.join(cpp_arg_types)}>"), + f"{thrift_head}{', '.join(thrift_arg_types)}{thrift_tail}", + ) + elif isinstance(t, type): + return (t.__name__, t.__name__, t.__name__) + else: + raise AssertionError(f"Type {t} is not supported in export schema.") + + def dump_cpp_value(v) -> str: + if v is None: + return "std::nullopt" + elif v is True: + return "true" + elif v is False: + return "false" + elif v == {}: + return "{}" + elif v == []: + return "{}" + elif v == (): + return "{}" + elif isinstance(v, str): + return f'"{v}"' + else: + raise AssertionError( + f"Default value {v} is not supported yet in export schema." + ) + + def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: + t, cpp_type, thrift_type = dump_type(f.type, 0) + ret = {"type": t} + cpp_default: Optional[str] = None + assert ( + typing.get_origin(f.type) == Annotated + ), f"Field {f.name} must be annotated with an integer id." + thrift_id = f.type.__metadata__[0] + assert ( + type(thrift_id) is int + ), f"Field {f.name} must be annotated with an integer id." + + value = dataclasses.MISSING + if f.default is not dataclasses.MISSING: + value = f.default + elif f.default_factory is not dataclasses.MISSING: + value = f.default_factory() + + if value is not dataclasses.MISSING: + default = str(value) + ret["default"] = default + cpp_default = dump_cpp_value(value) + + if t.startswith("Optional[") and value is not None: + raise AssertionError( + f"Optional field {ty.__name__}.{f.name} must have default value to be None." + ) + + return ret, cpp_type, cpp_default, thrift_type, thrift_id + + yaml_ret = {} + cpp_ret = {} + thrift_ret = {} + thrift_ids = set() + for f in dataclasses.fields(ty): + yaml_res, cpp_type, cpp_default, thrift_type, thrift_id = dump_field(f) + yaml_ret[f.name] = yaml_res + cpp_ret[f.name] = {"cpp_type": cpp_type, "cpp_default": cpp_default} + thrift_ret[f.name] = {"thrift_type": thrift_type, "thrift_id": thrift_id} + if thrift_id in thrift_ids: + raise AssertionError( + f"Duplicate thrift id {thrift_id} for field {f.name} in {ty.__name__}." 
+ ) + thrift_ids.add(thrift_id) + return yaml_ret, cpp_ret, thrift_ret + + def _handle_int_enum(name, ty): + yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} + cpp_enum_defs[ + name + ] = f""" +enum class {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}}; + +inline std::string_view printEnum(const {name}& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::{x.name}: return {chr(34)}{x.name}{chr(34)};" for x in ty])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.name}{chr(34)}) {{ t = {name}::{x.name}; return; }}" for x in ty])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} +""" + thrift_enum_defs.append( + f""" +enum {name} {{ +{chr(10).join([f" {x.name} = {x.value}," for x in ty])} +}} +""" + ) + + def _handle_struct(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "struct", "fields": fields} + field_decls = "\n".join( + f" {f['cpp_type']} {name}{' = ' + f['cpp_default'] if f['cpp_default'] is not None else ''};" + for name, f in cpp_fields.items() + ) + + def accessor(name, ty): + type_name = fields[name]["type"] + if type_name in cpp_enum_defs: + return f""" + {type_name} get_{name}() const {{ + return static_cast<{type_name}>({name}); + }} + + void set_{name}({type_name} def) {{ + {name} = static_cast(def); + }} +""" + return f""" + const {ty}& get_{name}() const {{ + return {name}; + }} + + void set_{name}({ty} def) {{ + {name} = std::move(def); + }} +""" + + to_json_decl = f"void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t)" + to_json_def = f"""{{ +{chr(10).join([f' nlohmann_json_j["{name}"] = nlohmann_json_t.{name};' for name, f in cpp_fields.items()])} +}} +""" + from_json_decl = f"void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t)" + + from_json_def = f"""{{ + {name} nlohmann_json_default_obj; +{chr(10).join( + [f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items()])} +}} +""" + cpp_class_defs[ + name + ] = f""" +class {name} {{ + private: +{field_decls} + + public: +{"".join([accessor(name, f["cpp_type"]) for name, f in cpp_fields.items()])} + friend {to_json_decl}; + friend {from_json_decl}; +}}; +""" + cpp_json_defs.append(f"inline {to_json_decl} {to_json_def}") + cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[ + name + ] = f""" +struct {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + def _handle_union(name, ty): + fields, cpp_fields, thrift_fields = _handle_aggregate(ty) + yaml_ret[name] = {"kind": "union", "fields": fields} + + def accessor(name, ty, idx): + return f""" + const {ty}& get_{name}() const {{ + return std::get<{idx + 1}>(variant_); + }} + + void set_{name}({ty} def) {{ + variant_.emplace<{idx + 1}>(std::move(def)); + tag_ = Tag::{name.upper()}; + }} +""" + + to_json_branches = "".join( + [ + f""" + if (nlohmann_json_t.tag_ == Tag::{name.upper()}) {{ + nlohmann_json_j["{name}"] = nlohmann_json_t.get_{name}(); + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + from_json_branches = "".join( + [ + f""" + if (nlohmann_json_j.contains("{name}")) {{ + nlohmann_json_t.variant_.emplace<{idx + 
1}>(nlohmann_json_j.at("{name}").template get<{f["cpp_type"]}>()); + nlohmann_json_t.tag_ = Tag::{name.upper()}; + return; + }}""" + for idx, (name, f) in enumerate(cpp_fields.items()) + ] + ) + + cpp_class_defs[ + name + ] = f""" +class {name} {{ + struct Void {{}}; + + public: + enum class Tag {{ + {", ".join([name.upper() for name in cpp_fields])} + }}; + + private: + std::variant variant_; + Tag tag_; + + public: + Tag tag() const {{ + return tag_; + }} +{"".join([accessor(name, f["cpp_type"], idx) for idx, (name, f) in enumerate(cpp_fields.items())])} + friend void to_json(nlohmann::json& nlohmann_json_j, const {name}& nlohmann_json_t) {{ +{to_json_branches} + }} + + friend void from_json(const nlohmann::json& nlohmann_json_j, {name}& nlohmann_json_t) {{ +{from_json_branches} + }} +}}; + +inline std::string_view printEnum(const {name}::Tag& e) {{ + switch (e) {{ +{chr(10).join([f" case {name}::Tag::{x.upper()}: return {chr(34)}{x.upper()}{chr(34)};" for x in cpp_fields])} + default: + throw std::runtime_error("Unknown enum value"); + }} +}} + +inline void parseEnum(std::string_view s, {name}::Tag& t) {{ +{chr(10).join([f" if (s == {chr(34)}{x.upper()}{chr(34)}) {{ t = {name}::Tag::{x.upper()}; return; }}" for x in cpp_fields])} + throw std::runtime_error("Unknown enum value: " + std::string{{s}}); +}} + +""" + cpp_type_decls.append(f"class {name};") + + thrift_type_defs[ + name + ] = f""" +union {name} {{ +{chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} +}}""" + + for name in dir(schema): + if name.startswith("_"): + continue + + value = getattr(schema, name) + + if hasattr(value, "__module__") and value.__module__ != schema.__name__: + continue + + defs[name] = value + + class_ordering = {} + for name, value in defs.items(): + if isinstance(value, type): + if issubclass(value, IntEnum): + _handle_int_enum(name, value) + elif dataclasses.is_dataclass(value): + class_ordering[name] = inspect.findsource(value)[1] + if issubclass(value, _Union): + _handle_union(name, value) + else: + _handle_struct(name, value) + else: + raise AssertionError(f"Unknown schema type {name}: {value}") + elif isinstance(value, (int, tuple)): + assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") + else: + raise AssertionError(f"Unknown variable {name}: {value}") + + yaml_ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) + assert all(x > 0 for x in yaml_ret["SCHEMA_VERSION"]) + yaml_ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] + assert yaml_ret["TREESPEC_VERSION"] > 0 + + cpp_header = f""" +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN namespace nlohmann {{ +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END }} +#endif + +// https://github.com/nlohmann/json/pull/2117 +NLOHMANN_JSON_NAMESPACE_BEGIN +template +struct adl_serializer> {{ + static void to_json(json& j, const std::optional& opt) {{ + if (opt == std::nullopt) {{ + j = nullptr; + }} else {{ + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! 
+ }} + }} + + static void from_json(const json& j, std::optional& opt) {{ + if (j.is_null()) {{ + opt = std::nullopt; + }} else {{ + opt = j.template get(); // same as above, but with + // adl_serializer::from_json + }} + }} +}}; +NLOHMANN_JSON_NAMESPACE_END + +namespace torch {{ +namespace _export {{ + +template +class ForwardRef {{ + static_assert(!std::is_reference_v, "ForwardRef cannot be a reference type"); + + public: + ForwardRef(): ptr_(std::make_unique()) {{}} + ForwardRef(ForwardRef&&); + ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {{}} + ForwardRef& operator=(ForwardRef&&); + ForwardRef& operator=(const ForwardRef& other) {{ + ptr_ = std::make_unique(*other.ptr_); + return *this; + }} + const T& operator*() const {{ + return *ptr_; + }} + + const T* operator->() const {{ + return ptr_.get(); + }} + + void emplace(T&& t) {{ + ptr_ = std::make_unique(std::move(t)); + }} + + private: + std::unique_ptr ptr_; +}}; + +template +void to_json(nlohmann::json& j, const ForwardRef& p) {{ + j = *p; +}} + +template +void from_json(const nlohmann::json& j, ForwardRef& p) {{ + p.emplace(j.template get()); +}} + +class F64 {{ + public: + double get() const {{ + return value_; + }} + + void set(double value) {{ + value_ = value; + }} + + private: + double value_; +}}; + +inline void to_json(nlohmann::json& j, const F64& f) {{ + if (std::isinf(f.get())) {{ + j = "Infinity"; + }} else if (std::isinf(-f.get())) {{ + j = "-Infinity"; + }} else if (std::isnan(f.get())) {{ + j = "NaN"; + }} else {{ + j = f.get(); + }} +}} + +inline void from_json(const nlohmann::json& j, F64& f) {{ + if (j == "Infinity") {{ + f.set(std::numeric_limits::infinity()); + }} else if (j == "-Infinity") {{ + f.set(-std::numeric_limits::infinity()); + }} else if (j == "NaN") {{ + f.set(std::numeric_limits::quiet_NaN()); + }} else {{ + f.set(j.get()); + }} +}} + +{chr(10).join(cpp_type_decls)} +{"".join(cpp_enum_defs.values())} +{"".join(dict(sorted(cpp_class_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +{chr(10).join(cpp_json_defs)} + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +}} // namespace _export +}} // namespace torch +""" + thrift_schema = f""" +namespace py3 torch._export +namespace cpp2 torch._export.schema +{chr(10).join(thrift_enum_defs)} +{chr(10).join(dict(sorted(thrift_type_defs.items(), key=lambda x: class_ordering[x[0]])).values())} +""" + return yaml_ret, cpp_header, thrift_schema + + +def _diff_schema(dst, src): + additions = {key: src[key] for key in src.keys() - dst.keys()} + subtractions = {key: dst[key] for key in dst.keys() - src.keys()} + + common_keys = src.keys() & dst.keys() + + versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} + common_keys -= versions + + for key in common_keys: + src_kind = src[key]["kind"] + src_fields = src[key]["fields"] + dst_kind = dst[key]["kind"] + dst_fields = dst[key]["fields"] + _check( + src_kind == dst_kind, + f"Type {key} changed kind from {dst_kind} to {src_kind}", + ) + assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) + added_fields = { + key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() + } + subtracted_fields = { + key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() + } + common_fields = src_fields.keys() & dst_fields.keys() + + for field in common_fields: + src_field = src_fields[field] + dst_field = dst_fields[field] + if src_kind == "struct": + _check( + src_field["type"] == 
dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + if "default" in src_field and "default" not in dst_field: + added_fields[field] = {} + added_fields[field]["default"] = src_field["default"] + if "default" not in src_field and "default" in dst_field: + subtracted_fields[field] = {} + subtracted_fields[field]["default"] = dst_field["default"] + elif src_kind == "enum": + _check( + src_field == dst_field, + f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", + ) + elif src_kind == "union": + _check( + src_field["type"] == dst_field["type"], + f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", + ) + else: + raise AssertionError(f"Unknown kind {src_kind}: {key}") + if len(added_fields) > 0: + assert key not in additions + additions[key] = {} + additions[key]["fields"] = added_fields + if len(subtracted_fields) > 0: + assert key not in subtractions + subtractions[key] = {} + subtractions[key]["fields"] = subtracted_fields + + return additions, subtractions + + +def _hash_content(s: str): + return hashlib.sha256(s.strip().encode("utf-8")).hexdigest() + + +@dataclasses.dataclass +class _Commit: + result: dict[str, Any] + checksum_next: str + yaml_path: str + additions: dict[str, Any] + subtractions: dict[str, Any] + base: dict[str, Any] + checksum_head: Optional[str] + cpp_header: str + cpp_header_path: str + thrift_checksum_head: Optional[str] + thrift_checksum_real: Optional[str] + thrift_checksum_next: str + thrift_schema: str + thrift_schema_path: str + + +def update_schema(): + import importlib.resources + + if importlib.resources.is_resource(__package__, "schema.yaml"): + content = importlib.resources.read_text(__package__, "schema.yaml") + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) + _check(match is not None, "checksum not found in schema.yaml") + assert match is not None + checksum_head = match.group(1) + + thrift_content = importlib.resources.read_text( + __package__, "export_schema.thrift" + ) + match = re.search("checksum<<([A-Fa-f0-9]{64})>>", thrift_content) + _check(match is not None, "checksum not found in export_schema.thrift") + assert match is not None + thrift_checksum_head = match.group(1) + thrift_content = thrift_content.splitlines() + assert thrift_content[0].startswith("// @" + "generated") + assert thrift_content[1].startswith("// checksum<<") + thrift_checksum_real = _hash_content("\n".join(thrift_content[2:])) + + from yaml import load, Loader + + dst = load(content, Loader=Loader) + assert isinstance(dst, dict) + else: + checksum_head = None + thrift_checksum_head = None + thrift_checksum_real = None + dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} + + src, cpp_header, thrift_schema = _staged_schema() + additions, subtractions = _diff_schema(dst, src) + yaml_path = __package__.replace(".", "/") + "/schema.yaml" + thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift" + torch_prefix = "torch/" + assert yaml_path.startswith(torch_prefix) # sanity check + assert thrift_schema_path.startswith(torch_prefix) # sanity check + + return _Commit( + result=src, + checksum_next=_hash_content(repr(src)), + yaml_path=yaml_path, + additions=additions, + subtractions=subtractions, + base=dst, + checksum_head=checksum_head, + cpp_header=cpp_header, + cpp_header_path=torch_prefix + "csrc/utils/generated_serialization_types.h", + thrift_checksum_head=thrift_checksum_head, + 
thrift_checksum_real=thrift_checksum_real, + thrift_checksum_next=_hash_content(thrift_schema), + thrift_schema=thrift_schema, + thrift_schema_path=thrift_schema_path, + ) + + +def check(commit: _Commit, force_unsafe: bool = False): + next_version = None + reason = "" + # Step 1: Detect major schema updates. + if len(commit.additions) > 0: + for k, v in commit.additions.items(): + if k not in commit.base: + continue + kind = commit.result[k]["kind"] + fields = v["fields"] + for f, d in fields.items(): + if kind == "struct" and "default" not in d: + reason += ( + f"Field {k}.{f} is added to schema.py without a default value as an incomparible change " + + "which requires major version bump.\n" + ) + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + if k not in commit.result: + continue + for f in v["fields"]: + reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" + next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] + + if force_unsafe: + reason += "--force-unsafe is used." + next_version = commit.result["SCHEMA_VERSION"] + else: + # Step 2: Detect minor schema updates. + if next_version is None and len(commit.additions) > 0: + for k, v in commit.additions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is added to schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + if next_version is None and len(commit.subtractions) > 0: + for k, v in commit.subtractions.items(): + for f in v["fields"]: + reason += ( + f"Field {k}.{f} is removed from schema.py as an compatible change " + + "which still requires minor version bump.\n" + ) + next_version = [ + commit.base["SCHEMA_VERSION"][0], + commit.base["SCHEMA_VERSION"][1] + 1, + ] + + return next_version, reason diff --git a/phivenv/Lib/site-packages/torch/_export/serde/serialize.py b/phivenv/Lib/site-packages/torch/_export/serde/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..df39df6f414549490b455766c5da3ea7f834e5b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/serialize.py @@ -0,0 +1,3556 @@ +# mypy: allow-untyped-defs +import base64 +import copy +import copyreg +import dataclasses +import heapq +import inspect +import io +import json +import keyword +import logging +import math +import operator +import traceback +import typing +from collections import namedtuple, OrderedDict +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from typing import Annotated, Any, Callable, cast, final, Optional, Union + +import sympy + +import torch +import torch.export.exported_program as ep +from torch._export.non_strict_utils import _enable_graph_inputs_of_type_nn_module +from torch._export.verifier import load_verifier +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.experimental import symbolic_shapes +from torch.utils import _pytree as pytree +from torch.utils._pytree import treespec_dumps, treespec_loads +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import prefix_str, SymT +from torch.utils._sympy.value_ranges import ValueRanges +from torch.utils._traceback import 
CapturedTraceback + +from ..utils import remove_proxy_from_state_dict +from .schema import ( # type: ignore[attr-defined] + Argument, + ArgumentKind, + BufferMutationSpec, + ConstantValue, + CustomObjArgument, + Device, + ExportedProgram, + GradientToParameterSpec, + GradientToUserInputSpec, + Graph, + GraphArgument, + GraphModule, + GraphSignature, + InputSpec, + InputToBufferSpec, + InputToConstantInputSpec, + InputToCustomObjSpec, + InputTokenSpec, + InputToParameterSpec, + InputToTensorConstantSpec, + Layout, + LossOutputSpec, + MemoryFormat, + ModuleCallEntry, + ModuleCallSignature, + NamedArgument, + NamedTupleDef, + Node, + OptionalTensorArgument, + OutputSpec, + OutputTokenSpec, + RangeConstraint, + ScalarType, + SCHEMA_VERSION, + SchemaVersion, + SymBool, + SymBoolArgument, + SymExpr, + SymExprHint, + SymFloat, + SymFloatArgument, + SymInt, + SymIntArgument, + TensorArgument, + TensorMeta, + TokenArgument, + TREESPEC_VERSION, + UserInputMutationSpec, + UserInputSpec, + UserOutputSpec, +) +from .union import _Union + + +__all__ = [ + "serialize", + "GraphModuleSerializer", + "ExportedProgramSerializer", + "GraphModuleDeserializer", + "ExportedProgramDeserializer", +] + +log = logging.getLogger(__name__) + + +class SerializeError(RuntimeError): + pass + + +def _reverse_map(d: dict[Any, Enum]): + return {v.value: k for k, v in d.items()} + + +MetaType = Union[ + FakeTensor, + int, + torch.SymInt, + float, + torch.SymFloat, + bool, + torch.SymBool, + ep.CustomObjArgument, +] + +DEFAULT_PICKLE_PROTOCOL = 2 + +ST_DELIMITER = ";" + +_TORCH_TO_SERIALIZE_DTYPE = { + torch.uint8: ScalarType.BYTE, + torch.int8: ScalarType.CHAR, + torch.uint16: ScalarType.UINT16, + torch.int16: ScalarType.SHORT, + torch.int32: ScalarType.INT, + torch.int64: ScalarType.LONG, + torch.float16: ScalarType.HALF, + torch.float32: ScalarType.FLOAT, + torch.float64: ScalarType.DOUBLE, + torch.complex32: ScalarType.COMPLEXHALF, + torch.complex64: ScalarType.COMPLEXFLOAT, + torch.complex128: ScalarType.COMPLEXDOUBLE, + torch.bool: ScalarType.BOOL, + torch.bfloat16: ScalarType.BFLOAT16, + torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, + torch.float8_e5m2: ScalarType.FLOAT8E5M2, +} + + +_SERIALIZE_TO_TORCH_DTYPE = _reverse_map(_TORCH_TO_SERIALIZE_DTYPE) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_LAYOUT = { + torch.sparse_coo: Layout.SparseCoo, + torch.sparse_csr: Layout.SparseCsr, + torch.sparse_csc: Layout.SparseCsc, + torch.sparse_bsr: Layout.SparseBsr, + torch.sparse_bsc: Layout.SparseBsc, + torch._mkldnn: Layout._mkldnn, # type: ignore[attr-defined] + torch.strided: Layout.Strided, +} + + +_SERIALIZE_TO_TORCH_LAYOUT = _reverse_map(_TORCH_TO_SERIALIZE_LAYOUT) # type: ignore[arg-type] + + +_TORCH_TO_SERIALIZE_MEMORY_FORMAT = { + torch.contiguous_format: MemoryFormat.ContiguousFormat, + torch.channels_last: MemoryFormat.ChannelsLast, + torch.channels_last_3d: MemoryFormat.ChannelsLast3d, + torch.preserve_format: MemoryFormat.PreserveFormat, +} + + +_SERIALIZE_TO_TORCH_MEMORY_FORMAT = _reverse_map(_TORCH_TO_SERIALIZE_MEMORY_FORMAT) # type: ignore[arg-type] + +_SYM_OPS = { + operator.eq, + operator.ne, + operator.le, + operator.ge, + operator.lt, + operator.gt, + operator.neg, + operator.pos, + operator.and_, + operator.or_, + math.trunc, + torch.sym_not, + operator.mul, + operator.add, + operator.sub, + operator.floordiv, + operator.mod, + operator.pow, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_sqrt, + operator.truediv, + operator.and_, +} + + +assert not 
any(isinstance(op, torch._ops.OpOverload) for op in _SYM_OPS) + + +@dataclass +class SerializedArtifact: + exported_program: bytes + state_dict: bytes + constants: bytes + example_inputs: bytes + + +@dataclass +class _SerializedProgram: + exported_program: ExportedProgram + state_dict: bytes + constants: bytes + example_inputs: bytes + + +def deserialize_device(d: Device) -> torch.device: + if d.index is None: + return torch.device(type=d.type) # type: ignore[call-overload] + return torch.device(type=d.type, index=d.index) + + +def _print_sympy(s: Union[torch.SymInt, torch.SymBool, torch.SymFloat, sympy.Expr]): + if isinstance(s, (torch.SymInt, torch.SymBool, torch.SymFloat)): + s = s.node.expr + return sympy.printing.repr.srepr(s) + + +def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt: + if isinstance(s, (torch.SymInt, sympy.Symbol, int)): + if symbolic_shapes.is_concrete_int(s): + return SymInt.create(as_int=int(s)) + else: + assert isinstance(s, (torch.SymInt, sympy.Symbol)) + if s.node.hint is None: + return SymInt.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymInt.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_int=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymInt should be either symbol or int, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_float(s: Union[float, torch.SymFloat]) -> SymFloat: + if isinstance(s, (torch.SymFloat, sympy.Symbol, float)): + if symbolic_shapes.is_concrete_float(s): + return SymFloat.create(as_float=float(s)) + else: + assert isinstance(s, (torch.SymFloat, sympy.Symbol)) + if s.node.hint is None: + return SymFloat.create(as_expr=SymExpr(_print_sympy(s))) + else: + return SymFloat.create( + as_expr=SymExpr( + _print_sympy(s), + hint=SymExprHint.create(as_float=s.node.hint), + ) + ) + else: + raise SerializeError( + f"SymFloat should be either symbol or float, got `{s}` of type `{type(s)}`" + ) + + +def serialize_sym_bool(s: Union[bool, torch.SymBool]) -> SymBool: + if isinstance(s, (torch.SymBool, bool)): + if symbolic_shapes.is_concrete_bool(s): + return SymBool.create(as_bool=bool(s)) + else: + return SymBool.create(as_expr=SymExpr(expr_str=_print_sympy(s))) + else: + raise SerializeError( + f"SymBool should be either symbol or bool, got `{s}` of type `{type(s)}`" + ) + + +def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: + """ + Extract a TensorMeta describing `t`. + """ + return TensorMeta( + dtype=_TORCH_TO_SERIALIZE_DTYPE[t.dtype], + sizes=[serialize_sym_int(s) for s in t.shape], + requires_grad=t.requires_grad, + device=Device(type=t.device.type, index=t.device.index), + strides=[serialize_sym_int(s) for s in t.stride()], + storage_offset=serialize_sym_int(0), # TODO needs to be fixed. 
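+        # (Editor's illustrative note.) serialize_sym_int above maps a concrete
+        # size such as 4 to SymInt.create(as_int=4), and a backed symbolic size
+        # to SymInt.create(as_expr=SymExpr(<expr string>, hint=...)), so the
+        # sizes and strides lists may mix both forms.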
+ layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], + ) + + +_CURRENT_DESERIALIZER: Optional["GraphModuleDeserializer"] = None + + +def _reduce_fake_tensor(fake_tensor: FakeTensor): + is_parameter = isinstance(fake_tensor, torch.nn.Parameter) + tensor_meta = serialize_tensor_meta(fake_tensor) + tensor_meta_bytes = json.dumps( + _dataclass_to_dict(tensor_meta), cls=EnumEncoder + ).encode("utf-8") + return _reconstruct_fake_tensor, (tensor_meta_bytes, is_parameter) + + +def _reconstruct_fake_tensor( + serialized_tensor_meta: bytes, is_parameter: bool +) -> FakeTensor: + # Deserialize the bytes into a TensorMeta + json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) + tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) + # Find the current fake mode + assert ( + _CURRENT_DESERIALIZER is not None + ), "Need access to current deserializer state" + fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) + if is_parameter: + fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] + return fake_tensor + + +def serialize_torch_artifact( + artifact: Optional[Any], pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL +) -> bytes: + if artifact is None: + return b"" + + assert ( + FakeTensor not in copyreg.dispatch_table + ), "Refusing to stomp on existing FakeTensor reducer" + try: + copyreg.pickle(FakeTensor, _reduce_fake_tensor) + buffer = io.BytesIO() + # This is a workaround for backend's tensor deserialization problem: + # unpickleTensor() always create a tensor on the device where it was originally saved + # This behavior is bad for multi-gpu training, as we wish to directly load the tensor + # on the designated device. + # For now, we simply move the tensor to cpu before saving. + # TODO: this should be fixed by deserialization instead. + torch.save(artifact, buffer, pickle_protocol=pickle_protocol) + return buffer.getvalue() + finally: + del copyreg.dispatch_table[FakeTensor] + + +def deserialize_torch_artifact( + serialized: Union[dict[str, Any], tuple[Any, ...], bytes] +): + if isinstance(serialized, (dict, tuple)): + return serialized + if len(serialized) == 0: + return {} + buffer = io.BytesIO(serialized) + buffer.seek(0) + # weights_only=False as we want to load custom objects here (e.g. ScriptObject) + artifact = torch.load(buffer, weights_only=False) + assert isinstance(artifact, (tuple, dict)) + return artifact + + +def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]: + # Convert simple sympy Integers into concrete int + if val in (sympy.oo, int_oo): + return None + if val in (-sympy.oo, -int_oo): + return None + if isinstance(val, sympy.Integer): + return int(val) + + # TODO: Remove this adjustment when Ed gets rid of fractional ranges + log.warning( + "Export constraints cannot be non-integer expressions. Found " + "type %s, and value %s. 
We will attempt to %s " + "this value.", + type(val), + val, + adjust, + ) + + if adjust == "floor": + return math.floor(val) + elif adjust == "ceil": + return math.ceil(val) + else: + raise RuntimeError(f"Got invalid adjustment {adjust}") + + +def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr: + # Convert concrete int into simple sympy Integers + if val is None: + return default + if val in [-int_oo, int_oo]: + return val + if val == math.inf: + return int_oo + if val == -math.inf: + return -int_oo + return sympy.Integer(val) + + +def _symbol_index(sym: sympy.Symbol, sym_type: SymT): + return int(str(sym)[len(prefix_str[sym_type]) :]) + + +def serialize_range_constraints( + range_constraints: dict[sympy.Symbol, ValueRanges] +) -> dict[str, RangeConstraint]: + return { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, "ceil"), # type: ignore[arg-type] + _sympy_int_to_int(v.upper, "floor"), # type: ignore[arg-type] + ) + for k, v in range_constraints.items() + } + + +def _get_schema_from_target(target): + if isinstance(target, torch._ops.OpOverload): + return target._schema + elif type(target) in _serialization_registry: + return _serialization_registry[type(target)].op_schema(target) + raise RuntimeError(f"Cannot find schema for {type(target)}") + + +@dataclass +class GraphState: + inputs: list[Argument] = field(default_factory=list) + outputs: list[Argument] = field(default_factory=list) + nodes: list[Node] = field(default_factory=list) + tensor_values: dict[str, TensorMeta] = field(default_factory=dict) + sym_int_values: dict[str, SymInt] = field(default_factory=dict) + sym_bool_values: dict[str, SymBool] = field(default_factory=dict) + sym_float_values: dict[str, SymFloat] = field(default_factory=dict) + is_single_tensor_return: bool = False + custom_obj_values: dict[str, CustomObjArgument] = field(default_factory=dict) + + +class Final(type): + def __new__(metacls, name, bases, classdict): + for b in bases: + if isinstance(b, Final): + raise TypeError(f"type '{b.__name__}' is not an acceptable base type") + return type.__new__(metacls, name, bases, dict(classdict)) + + +@final +class GraphModuleSerializer(metaclass=Final): + def __init__( + self, + graph_signature: ep.ExportGraphSignature, + module_call_graph: list[ep.ModuleCallEntry], + ): + self.graph_state = GraphState() + self.graph_signature = graph_signature + self.module_call_graph = module_call_graph + self.custom_objs: dict[str, torch._C.ScriptObject] = {} + self.duplicate_getitem_nodes: dict[str, str] = {} + self.treespec_namedtuple_fields: dict[str, NamedTupleDef] = {} + + @contextmanager + def save_graph_state(self): + saved = self.graph_state + self.graph_state = GraphState() + try: + yield + finally: + self.graph_state = saved + + def handle_placeholder(self, node: torch.fx.Node): + assert node.op == "placeholder" + val = node.meta["val"] + log.debug("[handle_placeholder] %s: %s", node.name, val) + if isinstance(val, torch.Tensor): + graph_input = Argument.create( + as_tensor=self.serialize_tensor_output(node.name, val) + ) + elif isinstance(val, torch.SymInt): + graph_input = Argument.create( + as_sym_int=self.serialize_sym_int_output(node.name, val) + ) + elif isinstance(val, torch.SymFloat): + raise AssertionError("SymFloat graph input is not implemented yet.") + elif isinstance(val, (int, bool, str, float, type(None))): + graph_input = self.serialize_input(val) + elif isinstance(val, ep.CustomObjArgument): + class_fqn = val.class_fqn + graph_input = Argument.create( + 
as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) + ) + self.graph_state.custom_obj_values[ + node.name + ] = self.serialize_script_obj_meta(val) + else: + raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") + self.graph_state.inputs.append(graph_input) + + def handle_output(self, node: torch.fx.Node): + assert node.op == "output" + assert len(node.args) == 1, "FX.Node's args should have one arg" + node_args = node.args[0] + log.debug("[handle_output] %s: %s", node.name, node_args) + if isinstance(node_args, torch.fx.Node): + # For singleton tensor returns + self.graph_state.is_single_tensor_return = True + self.graph_state.outputs = [self.serialize_input(node_args)] + else: + assert isinstance(node_args, (tuple, list)) + self.graph_state.outputs = [self.serialize_input(arg) for arg in node_args] + + def serialize_operator(self, target) -> str: + if isinstance(target, str): + return target + elif target.__module__.startswith("torch._ops"): + # TODO(zhxchen17) Maybe provide a function name helper in FX. + # From torch.fx.node._get_qualified_name + module = target.__module__.replace("torch._ops", "torch.ops") + return f"{module}.{target.__name__}" + else: # TODO(zhxchen17) Don't catch all here. + return f"{target.__module__}.{target.__name__}" + + def handle_call_function(self, node: torch.fx.Node): + assert node.op == "call_function" + meta_val = node.meta.get("val") + log.debug( + "[handle_call_function] %s: %s(%s, {%s}) -> %s", + node.name, + node.target, + node.args, + node.kwargs, + meta_val, + ) + + # getitem has been handled in the producer node, skip it here + if node.target is operator.getitem: + return + + if node.target in _SYM_OPS or ( + meta_val is not None + and isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat)) + ): + assert len(node.kwargs) == 0 + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_sym_op_inputs(node.target, node.args), + outputs=[self.serialize_output(node.name, meta_val)], + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.OpOverload): + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + # TODO: create a new tensor_values here, meta might have faketensor info + metadata=self.serialize_metadata(node), + ) + elif isinstance(node.target, torch._ops.HigherOrderOperator): + + def _is_hop_single_tensor_return(node) -> bool: + assert isinstance(node.target, torch._ops.HigherOrderOperator) + # HOP schema is not always available, so we look at node.meta["val"] + meta_val = node.meta.get("val", None) + return meta_val is not None and isinstance(meta_val, torch.Tensor) + + # Special handle serialization for aoti_call_delegate + if node.target is torch._higher_order_ops.aoti_call_delegate: + serializable_args = list(node.args) + + # AOTI lowered module is not serializable, serialize the aoti_path instead + lowered_module_name: str = node.args[0].name # type: ignore[assignment, no-untyped-def, union-attr] + assert hasattr(node.graph.owning_module, lowered_module_name) + lowered_module = getattr(node.graph.owning_module, lowered_module_name) # type: ignore[no-untyped-def] + serializable_args[0] = lowered_module.aoti_path + + # AOTI compiled graph module in node.args[0] is stateful, and will fail the verifier check + # Skip serializing original_gm as a workaround + serializable_args[1] = None + + serializable_weight_nodes 
= [] + if serializable_args[2] is not None and isinstance( + serializable_args[2], Iterable + ): + for weight_node in serializable_args[2]: + # skip passing custom objs into the weight arg, as a hack + # The schema of the weight input is a list of Tensors. + # Downstream runtime is not actively consuming the weights arg for anything meaningful. + if isinstance(weight_node, torch.fx.Node) and isinstance( + weight_node.meta.get("val", None), ep.CustomObjArgument + ): + continue + serializable_weight_nodes.append(weight_node) + serializable_args[2] = serializable_weight_nodes + + def serialize_tensor_list_output(node): + meta_val = node.meta.get("val", None) + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(serializable_args, node.kwargs), + outputs=serialize_tensor_list_output(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=False, + ) + else: + ex_node = Node( + target=self.serialize_operator(node.target), + inputs=self.serialize_hoo_inputs(node.args, node.kwargs), + outputs=self.serialize_hoo_outputs(node), + metadata=self.serialize_metadata(node), + is_hop_single_tensor_return=_is_hop_single_tensor_return(node), + ) + elif type(node.target) in _serialization_registry: + # Sanity check for unhandled serialization. + assert ( + type(node.target) in _serialization_registry + ), f"{type(node.target)} is not supported in export serialization." + + handler = _serialization_registry[type(node.target)] + namespace = handler.namespace() + op_name = handler.to_op_name(node.target) + assert isinstance(namespace, str) and isinstance(op_name, str) + assert ":" not in namespace and ":" not in op_name + ex_node = Node( + target=f"#{namespace}:{op_name}", + inputs=self.serialize_inputs(node.target, node.args, node.kwargs), + outputs=self.serialize_outputs(node), + metadata=self.serialize_metadata(node), + ) + else: + raise SerializeError(f"Serializing {node.target} is not supported") + + self.graph_state.nodes.append(ex_node) + + def handle_get_attr(self, node): + log.debug("[handle_get_attr] %s", node.name) + + def _output_node_at_index(self, node, index) -> Optional[torch.fx.Node]: + user_node = None + for user in node.users: + assert user.target is operator.getitem, f"{user} is not a getitem node" + if index == user.args[1]: + if user_node is None: + user_node = user + else: + # We want to deduplicate getitem nodes that are trying to + # index to the same index + self.duplicate_getitem_nodes[user.name] = user_node.name + return user_node + + def _output_node_name_at_index(self, node, index) -> str: + user_node = self._output_node_at_index(node, index) + if user_node is None: + return f"{node.name}_unused_{index}" + else: + return user_node.name + + def serialize_metadata(self, node: torch.fx.Node) -> dict[str, str]: + ret = {} + + if stack_trace := node.meta.get("stack_trace"): + ret["stack_trace"] = stack_trace + + if nn_module_stack := node.meta.get("nn_module_stack"): + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + + assert isinstance(path, str) + assert isinstance(ty, str) + + return path + "," + ty + + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] +
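# Illustrative note (not part of the original source): each serialized entry above has + # the form "key,path,type", e.g. "l__self___linear,mod.linear,torch.nn.Linear"; the + # entries are then joined with ST_DELIMITER below. +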
ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) + + if source_fn_st := node.meta.get("source_fn_stack"): + source_fn_list = [ + f"{source_fn[0]},{self.serialize_operator(source_fn[1])}" + for source_fn in source_fn_st + ] + ret["source_fn_stack"] = ST_DELIMITER.join(source_fn_list) + + if torch_fn := node.meta.get("torch_fn"): + ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn)) + + if custom := node.meta.get("custom"): + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for node {node.name} with error {e}" + ) from e + + return ret + + def serialize_script_obj_meta( + self, script_obj_meta: ep.CustomObjArgument + ) -> CustomObjArgument: + log.debug("[serialize_script_obj_meta] %s", script_obj_meta) + return CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def serialize_sym_op_inputs(self, op, args) -> list[NamedArgument]: + if isinstance(op, torch._ops.OpOverload): + args_names = [arg.name for arg in op._schema.arguments] + else: + assert op in _SYM_OPS + args_names = list(inspect.signature(op).parameters.keys()) + serialized_args = [] + for args_name, arg in zip(args_names, args): + serialized_args.append( + NamedArgument( + name=args_name, + arg=self.serialize_input(arg), + kind=ArgumentKind.POSITIONAL, + ) + ) + return serialized_args + + def serialize_inputs( + self, + target: Any, # torch._ops.OpOverload and other custom operator types. + args, + kwargs=None, + ) -> list[NamedArgument]: + schema = None + serialized_args = [] + + if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): + obj = args[0] + method = args[1] + schema = target.schema(obj, method) + else: + assert isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + schema = _get_schema_from_target(target) + assert schema is not None + kwargs = kwargs or {} + + for i, schema_arg in enumerate(schema.arguments): + if schema_arg.name in kwargs: + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input( + kwargs[schema_arg.name], schema_arg.type + ), + kind=ArgumentKind.KEYWORD, + ) + ) + elif not schema_arg.kwarg_only and i < len(args): + serialized_args.append( + NamedArgument( + name=schema_arg.name, + arg=self.serialize_input(args[i], schema_arg.type), + kind=ArgumentKind.POSITIONAL, + ) + ) + else: + # We intentionally don't serialize the missing arguments + # with default values + pass + + return serialized_args + + def serialize_hoo_inputs(self, args, kwargs) -> list[NamedArgument]: + """ + For serializing HOO inputs since HOOs do not have a schema. + """ + inputs = [ + NamedArgument( + name="", arg=self.serialize_input(a), kind=ArgumentKind.POSITIONAL + ) + for a in args + ] + inputs.extend( + [ + NamedArgument( + name=name, + arg=self.serialize_input(a), + kind=ArgumentKind.KEYWORD, + ) + for name, a in kwargs.items() + ] + ) + return inputs + + def is_inductor_sym_int_arg(self, arg) -> bool: + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node and should be + # verified with is_sym_int_arg() + return type(arg) is int or isinstance(arg, torch.SymInt) + + def is_sym_int_arg(self, arg) -> bool: + return type(arg) is int or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_int_values + ) + + def is_sym_float_arg(self, arg) -> bool: + return isinstance(arg, float) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_float_values + ) + + def is_sym_bool_arg(self, arg) -> bool: + return isinstance(arg, bool) or ( + isinstance(arg, torch.fx.Node) + and arg.name in self.graph_state.sym_bool_values + ) + + # should be torch._C.JitType but that annotation is busted + def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: + import torch._inductor.ir as inductor_ir + + inductor_tensor_buffers = ( + inductor_ir.Buffer, + inductor_ir.ReinterpretView, + ) + + if isinstance(arg, torch.fx.Node): + if arg.op == "get_attr": + assert isinstance(arg.target, str) + attr = getattr(arg.graph.owning_module, arg.target) + + if isinstance(attr, torch.Tensor): + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + elif isinstance(attr, torch.fx.GraphModule): + with self.save_graph_state(): + graph = self.serialize_graph(attr) + return Argument.create( + as_graph=GraphArgument(name=arg.target, graph=graph) + ) + else: + raise SerializeError( + f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" + ) + elif self.is_sym_int_arg(arg): + return Argument.create( + as_sym_int=SymIntArgument.create(as_name=arg.name) + ) + elif self.is_sym_float_arg(arg): + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=arg.name) + ) + elif self.is_sym_bool_arg(arg): + return Argument.create( + as_sym_bool=SymBoolArgument.create(as_name=arg.name) + ) + elif isinstance(arg.meta["val"], ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument( + name=arg.name, class_fqn=arg.meta["val"].class_fqn + ) + ) + elif arg.name in self.duplicate_getitem_nodes: + dedup_name = self.duplicate_getitem_nodes[arg.name] + return Argument.create(as_tensor=TensorArgument(name=dedup_name)) + else: + return Argument.create(as_tensor=TensorArgument(name=arg.name)) + elif isinstance(arg, inductor_tensor_buffers): + # Other branches are for arguments in fx node. + # This is a special branch for handling buffers (representing tensor arguments) + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + return Argument.create(as_tensor=TensorArgument(name=arg_name)) + elif isinstance(arg, inductor_ir.TorchBindObject): + # This is a special branch for handling TorchBindObject + # for inductor's ExternalFallbackNode + # export_extern_kernel_node() is using this function to serialize arguments + arg_name = arg.get_name() + assert arg_name is not None, "Buffer must have valid name" + arg_val = arg.get_real_obj() + class_fqn = arg_val._type().qualified_name() + self.custom_objs[arg_name] = arg_val + return Argument.create(as_custom_obj=CustomObjArgument(arg_name, class_fqn)) + elif isinstance(arg, torch.SymInt): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_int_arg(arg) being true + return Argument.create(as_sym_int=SymIntArgument.create(as_name=str(arg))) + elif isinstance(arg, torch.SymFloat): + # This is a special branch for handling SymFloat args in inductor's + # ExternalFallbackNode. + # For regular FX graph, SymInt arg should be a fx.Node with + # self.is_sym_float_arg(arg) being true + return Argument.create( + as_sym_float=SymFloatArgument.create(as_name=str(arg)) + ) + elif type(arg) is bool: + return Argument.create(as_bool=arg) + elif type(arg) is str: + return Argument.create(as_string=arg) + elif type(arg) is int: + return Argument.create(as_int=arg) + elif type(arg) is float: + return Argument.create(as_float=arg) + elif arg is None: + return Argument.create(as_none=True) + elif isinstance(arg, (list, tuple)): + if len(arg) == 0: + if arg_type is not None: + if isinstance(arg_type, torch.OptionalType): + arg_type = arg_type.getElementType() # type: ignore[assignment] + assert isinstance(arg_type, torch.ListType) + elem_type = arg_type.getElementType() + if isinstance(elem_type, torch.OptionalType): + elem_type = elem_type.getElementType() + + if isinstance(elem_type, torch.BoolType): + return Argument.create(as_bools=[]) + elif isinstance(elem_type, torch.IntType): + return Argument.create(as_ints=[]) + elif isinstance(elem_type, torch.FloatType): + return Argument.create(as_floats=[]) + elif isinstance(elem_type, torch.StringType): + return Argument.create(as_strings=[]) + elif isinstance(elem_type, torch.TensorType): + return Argument.create(as_tensors=[]) + else: + # I believe empty symint lists default to ints, but + # please file an issue if this is not the case + raise SerializeError(f"Empty list with type {elem_type} nyi.") + else: + # We could serialize this by default to a tensor list. This + # is needed in the HOO case + log.warning( + "Unsure how to serialize the given empty list, " + "as we don't know what is the type of this argument. " + "Serializing it as a tensor list by default." + ) + return Argument.create(as_tensors=[]) + + if all(type(a) is bool for a in arg): + return Argument.create(as_bools=list(arg)) + elif all(type(a) is int for a in arg): + return Argument.create(as_ints=list(arg)) + elif all(type(a) is float for a in arg): + return Argument.create(as_floats=list(arg)) + elif all(type(a) is str for a in arg): + return Argument.create(as_strings=list(arg)) + elif all(self.is_inductor_sym_int_arg(a) for a in arg): + # This is a special branch for handling SymInt args in inductor's + # ExternalFallbackNode. 
+ # For regular FX graph, SymInt arg should be a fx.Node + values = [] + for a in arg: + if isinstance(a, torch.SymInt): + values.append(SymIntArgument.create(as_name=str(a))) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(isinstance(a, torch.SymFloat) for a in arg): + return Argument.create( + as_sym_floats=[SymFloatArgument.create(as_name=str(a)) for a in arg] + ) + elif all(self.is_sym_int_arg(a) for a in arg): + # list of sym_ints + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymIntArgument.create(as_name=a.name)) + elif type(a) is int: + values.append(SymIntArgument.create(as_int=a)) + return Argument.create(as_sym_ints=values) + elif all(self.is_sym_float_arg(a) for a in arg): + # list of sym_float + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymFloatArgument.create(as_name=a.name)) + elif isinstance(a, float): + values.append(SymFloatArgument.create(as_float=a)) + return Argument.create(as_sym_floats=values) + elif all(self.is_sym_bool_arg(a) for a in arg): + # list of sym_bools + values = [] + for a in arg: + if isinstance(a, torch.fx.Node): + values.append(SymBoolArgument.create(as_name=a.name)) + elif isinstance(a, bool): + values.append(SymBoolArgument.create(as_bool=a)) + return Argument.create(as_sym_bools=values) + elif all(isinstance(a, torch.fx.Node) for a in arg): + # list of tensors + arguments = [] + for a in arg: + if a.op == "get_attr": + raise SerializeError( + "getattr nodes containing tensors should not appear in the graph" + ) + arguments.append(TensorArgument(name=a.name)) + return Argument.create(as_tensors=arguments) + elif all(isinstance(a, (torch.fx.Node, type(None))) for a in arg): + # list of optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, torch.fx.Node): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.name) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + elif all(isinstance(a, inductor_tensor_buffers) for a in arg): + # list of inductor buffers + return Argument.create( + as_tensors=[TensorArgument(name=a.get_name()) for a in arg], + ) + elif all( + isinstance(a, (*inductor_tensor_buffers, type(None))) for a in arg + ): + # list of inductor buffers as optional tensors + def serialize_optional_tensor_args(a): + if a is None: + return OptionalTensorArgument.create(as_none=True) + elif isinstance(a, inductor_tensor_buffers): + return OptionalTensorArgument.create( + as_tensor=TensorArgument(name=a.get_name()) + ) + else: + raise SerializeError(f"Unsupported list/tuple argument: {a}") + + return Argument.create( + as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) + ) + else: + raise SerializeError( + f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" + ) + elif isinstance(arg, torch.dtype): + return Argument.create(as_scalar_type=_TORCH_TO_SERIALIZE_DTYPE[arg]) + elif isinstance(arg, torch.device): + return Argument.create(as_device=Device(type=arg.type, index=arg.index)) + elif isinstance(arg, torch.memory_format): + return Argument.create( + as_memory_format=_TORCH_TO_SERIALIZE_MEMORY_FORMAT[arg] + ) + elif isinstance(arg, torch.layout): + return Argument.create(as_layout=_TORCH_TO_SERIALIZE_LAYOUT[arg]) + elif isinstance(arg, 
torch._C.ScriptObject): + if not ( + arg._has_method("__getstate__") # type: ignore[attr-defined] + and arg._has_method("__setstate__") # type: ignore[attr-defined] + ): + raise SerializeError( + f"Unable to serialize custom class {arg}. Please define " + "serialization methods via def_pickle()." + ) + # Custom objects through torchbind are serializable with pickle, + # by implementing the .def_pickle function. This should result + # in the object containing __getstate__ and __setstate__ + # serialize/deserialize functions. + custom_obj_name = f"_custom_obj_{len(self.custom_objs)}" + self.custom_objs[custom_obj_name] = arg + class_fqn = arg._type().qualified_name() # type: ignore[attr-defined] + return Argument.create( + as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) + ) + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return Argument.create(as_operator=self.serialize_operator(arg)) + else: + raise SerializeError( + f"Unsupported argument type: {type(arg)} with schema arg_type {arg_type}" + ) + + def serialize_tensor_output(self, name, meta_val) -> TensorArgument: + assert name not in self.graph_state.tensor_values + self.graph_state.tensor_values[name] = serialize_tensor_meta(meta_val) + return TensorArgument(name=name) + + def serialize_sym_int_output(self, name, meta_val) -> SymIntArgument: + assert name not in self.graph_state.sym_int_values + self.graph_state.sym_int_values[name] = serialize_sym_int(meta_val) + return SymIntArgument.create(as_name=name) + + def serialize_sym_float_output(self, name, meta_val) -> SymFloatArgument: + assert name not in self.graph_state.sym_float_values + self.graph_state.sym_float_values[name] = serialize_sym_float(meta_val) + return SymFloatArgument.create(as_name=name) + + def serialize_sym_bool_output(self, name, meta_val) -> SymBoolArgument: + assert name not in self.graph_state.sym_bool_values + self.graph_state.sym_bool_values[name] = serialize_sym_bool(meta_val) + return SymBoolArgument.create(as_name=name) + + def serialize_input_spec(self, spec: ep.InputSpec) -> InputSpec: + log.debug("[serialize_input_spec] %s", spec) + if spec.kind == ep.InputKind.USER_INPUT: + if isinstance(spec.arg, ep.ConstantArgument): + if type(spec.arg.value) is int: + constant_spec = ConstantValue.create(as_int=spec.arg.value) + elif type(spec.arg.value) is bool: + constant_spec = ConstantValue.create(as_bool=spec.arg.value) + elif type(spec.arg.value) is str: + constant_spec = ConstantValue.create(as_string=spec.arg.value) + elif type(spec.arg.value) is float: + constant_spec = ConstantValue.create(as_float=spec.arg.value) + elif spec.arg.value is None: + constant_spec = ConstantValue.create(as_none=True) + else: + raise SerializeError( + f"Unhandled constant input {spec.arg.value} to serialize" + ) + return InputSpec.create( + constant_input=InputToConstantInputSpec( + name=spec.arg.name, value=constant_spec + ) + ) + else: + return InputSpec.create( + user_input=UserInputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.InputKind.PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + parameter=InputToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.BUFFER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + assert spec.persistent is not None + return InputSpec.create( + buffer=InputToBufferSpec(
arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + persistent=spec.persistent, + ) + ) + elif spec.kind == ep.InputKind.CONSTANT_TENSOR: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return InputSpec.create( + tensor_constant=InputToTensorConstantSpec( + arg=TensorArgument(name=spec.arg.name), + tensor_constant_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.CUSTOM_OBJ: + assert spec.target is not None + assert isinstance(spec.arg, ep.CustomObjArgument) + return InputSpec.create( + custom_obj=InputToCustomObjSpec( + arg=CustomObjArgument( + name=spec.arg.name, class_fqn=spec.arg.class_fqn + ), + custom_obj_name=spec.target, + ) + ) + elif spec.kind == ep.InputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return InputSpec.create( + token=InputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: + log.debug("[serialize_output_spec] %s", spec) + if spec.kind == ep.OutputKind.USER_OUTPUT: + return OutputSpec.create( + user_output=UserOutputSpec(arg=self.serialize_argument_spec(spec.arg)) + ) + elif spec.kind == ep.OutputKind.LOSS_OUTPUT: + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + loss_output=LossOutputSpec(arg=TensorArgument(name=spec.arg.name)) + ) + elif spec.kind == ep.OutputKind.BUFFER_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + buffer_mutation=BufferMutationSpec( + arg=TensorArgument(name=spec.arg.name), + buffer_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_parameter=GradientToParameterSpec( + arg=TensorArgument(name=spec.arg.name), + parameter_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.GRADIENT_TO_USER_INPUT: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + gradient_to_user_input=GradientToUserInputSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.USER_INPUT_MUTATION: + assert spec.target is not None + assert isinstance(spec.arg, ep.TensorArgument) + return OutputSpec.create( + user_input_mutation=UserInputMutationSpec( + arg=TensorArgument(name=spec.arg.name), + user_input_name=spec.target, + ) + ) + elif spec.kind == ep.OutputKind.TOKEN: + assert isinstance(spec.arg, ep.TokenArgument) + return OutputSpec.create( + token=OutputTokenSpec( + arg=TokenArgument(name=spec.arg.name), + ) + ) + else: + raise AssertionError(f"Unknown argument kind: {spec}") + + def serialize_signature(self, sig: ep.ExportGraphSignature) -> GraphSignature: + log.debug("\n[serialize_signature]") + return GraphSignature( + input_specs=[self.serialize_input_spec(s) for s in sig.input_specs], + output_specs=[self.serialize_output_spec(s) for s in sig.output_specs], + ) + + def serialize_argument_spec(self, x: ep.ArgumentSpec) -> Argument: + if isinstance(x, ep.TensorArgument): + return Argument.create(as_tensor=TensorArgument(name=x.name)) + elif isinstance(x, ep.SymIntArgument): + return Argument.create(as_sym_int=SymIntArgument.create(as_name=x.name)) + elif isinstance(x, ep.SymFloatArgument): + return Argument.create(as_sym_float=SymFloatArgument.create(as_name=x.name)) + 
elif isinstance(x, ep.ConstantArgument): + return self.serialize_input(x.value) + elif isinstance(x, ep.CustomObjArgument): + return Argument.create( + as_custom_obj=CustomObjArgument(name=x.name, class_fqn=x.class_fqn) + ) + else: + raise AssertionError("TODO") + + def serialize_treespec(self, treespec): + # We want to additionally save all the field names of the namedtuples in + # case users want to check that the treespec types are equivalent + def store_namedtuple_fields(ts): + if ts.type is None: + return + if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): + serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ + ts.context + ].serialized_type_name + if serialized_type_name in self.treespec_namedtuple_fields: + field_names = self.treespec_namedtuple_fields[ + serialized_type_name + ].field_names + if field_names != ts.context._fields: + raise SerializeError( + f"The given TreeSpec's namedtuple type {ts.context} " + f"was found to have field names {ts.context._fields} " + f"but somehow previously was found to have field names {field_names}." + ) + else: + self.treespec_namedtuple_fields[ + serialized_type_name + ] = NamedTupleDef(field_names=ts.context._fields) + + for child in ts.children_specs: + store_namedtuple_fields(child) + + serialized_treespec = treespec_dumps(treespec, TREESPEC_VERSION) + store_namedtuple_fields(treespec) + return serialized_treespec + + def serialize_module_call_signature( + self, module_call_signature: ep.ModuleCallSignature + ) -> ModuleCallSignature: + log.debug("[serialize_module_call_signature] %s", module_call_signature) + return ModuleCallSignature( + inputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.serialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=self.serialize_treespec(module_call_signature.in_spec), + out_spec=self.serialize_treespec(module_call_signature.out_spec), + forward_arg_names=names + if (names := module_call_signature.forward_arg_names) + else None, + ) + + def serialize_module_call_graph( + self, module_call_graph: list[ep.ModuleCallEntry] + ) -> list[ModuleCallEntry]: + log.debug("\n[serialize_module_call_graph]") + return [ + ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.serialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + def serialize_outputs(self, node: torch.fx.Node) -> list[Argument]: + """For a given node, return the dataclass representing its output values. + + [NOTE: Multiple outputs] We handle aggregates differently than FX. For + FX, it looks like: + + x = call_function("multiple_return", ...) + element0 = call_function(getitem, x, 0) + foo = call_function("use_output", element0) + + We do not want the intermediate `getitem` call, so our serialized thing looks like: + + element0, element1, element2 = call_function("multiple_return", ...) + foo = call_function("use_output", element0) + + We want names to be consistent across these two schemes, so that we can + mostly reuse the names coming from FX. This function computes a mapping from + the FX representation to our representation, preserving the names. 
+ """ + + def _is_single_tensor_list_return(target: Any) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + + if len(returns) != 1: + return False + return_type = returns[0].real_type + return isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ) + + assert node.op == "call_function" and isinstance( + node.target, (torch._ops.OpOverload, *_registered_extension_types()) + ) + + schema = _get_schema_from_target(node.target) + returns = schema.returns + + if len(returns) == 0: + return [] + + meta_val = node.meta["val"] + + # Check single value return + if _is_single_tensor_list_return(node.target): + # e.g "-> Tensor[]" + tensor_args = [] + for idx, meta in enumerate(meta_val): + name = self._output_node_name_at_index(node, idx) + tensor_args.append(self.serialize_tensor_output(name, meta)) + return [Argument.create(as_tensors=tensor_args)] + elif len(returns) == 1: + return [self.serialize_output(node.name, meta_val)] + + # There are a two possibilities at this point: + # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)" + # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])" + # + # Either way, start by gathering a list of TensorArguments with the correct names. + # For consistent naming with FX, consult the downstream `getitem` node and + # make sure our outputs have the same name. + + output_arguments = [] + for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)): + if meta is None: + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + # When the return type is annoated as Tensor type, the op can also return an + # undefined Tensor which will be implicitly converted to None in Python. + output_arguments.append(Argument.create(as_none=True)) + elif isinstance(meta, FakeTensor): + assert isinstance( + return_schema.real_type, (torch.OptionalType, torch.TensorType) + ) + name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(name, meta)) + elif isinstance(meta, list): + # for List[Tensor] return type + assert isinstance( + return_schema.real_type, torch.ListType + ) and isinstance( + return_schema.real_type.getElementType(), torch.TensorType + ) + user_node = self._output_node_at_index(node, idx) + assert user_node is not None + + args = [] + for i, m in enumerate(meta): + if m is None: + continue + sub_user_node_name = self._output_node_name_at_index(user_node, i) + args.append(self.serialize_tensor_output(sub_user_node_name, m)) + output_arguments.append(Argument.create(as_tensors=args)) + elif isinstance(meta, (int, SymInt, float, SymFloat)): + user_node_name = self._output_node_name_at_index(node, idx) + output_arguments.append(self.serialize_output(user_node_name, meta)) + else: + raise ValueError( + f"Unhandled output type {type(meta)} from node {node.format_node()}" + ) + + return output_arguments + + def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: + """ + For serializing HOO outputs since HOOs do not have a schema. 
+ """ + meta_val = node.meta["val"] + + if isinstance(meta_val, tuple): + outputs = [] + for i, element_meta_val in enumerate(meta_val): + user_node = self._output_node_at_index(node, i) + if isinstance(element_meta_val, list): + # e.g "-> Tensor[]" + assert user_node is not None + + tensors = [] + for j, m in enumerate(element_meta_val): + if not isinstance(m, torch.Tensor): + raise SerializeError( + f"Serialize list output with type {type(m)} nyi" + ) + + name = self._output_node_name_at_index(user_node, j) + tensors.append(self.serialize_tensor_output(name, m)) + outputs.append(Argument.create(as_tensors=tensors)) + + else: + name = ( + user_node.name + if user_node is not None + else f"{node.name}_unused_{i}" + ) + + outputs.append(self.serialize_output(name, element_meta_val)) + + return outputs + else: + return [self.serialize_output(node.name, meta_val)] + + def serialize_output(self, name: str, meta_val: Any) -> Argument: + # Check single value return + if meta_val is None: + return Argument.create(as_none=True) + if isinstance(meta_val, torch.Tensor): + # e.g "-> Tensor" + return Argument.create( + as_tensor=self.serialize_tensor_output(name, meta_val) + ) + elif isinstance(meta_val, (bool, torch.SymBool)): + # e.g "-> SymBool" + return Argument.create( + as_sym_bool=self.serialize_sym_bool_output(name, meta_val) + ) + elif isinstance(meta_val, (int, torch.SymInt)): + # e.g "-> SymInt" + assert not isinstance(meta_val, bool) + return Argument.create( + as_sym_int=self.serialize_sym_int_output(name, meta_val) + ) + elif isinstance(meta_val, (float, torch.SymFloat)): + # e.g "-> SymFloat" + return Argument.create( + as_sym_float=self.serialize_sym_float_output(name, meta_val) + ) + + # list outputs should've been handled earlier + raise SerializeError(f"Unable to serialize output {meta_val}") + + def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: + meta_val = node.meta["val"] + + idx_to_name = {} + for user in node.users: + assert ( + user.target is operator.getitem + ), f"User node {user} of {node} is incorrect" + idx_to_name[user.args[1]] = user.name + + for idx, _ in enumerate(meta_val): + # FX does not emit a getitem node for any outputs that are unused. + # However, we need a name for them so that the number of outputs will + # correctly match the schema. Just assign a dummy name. 
+ if idx not in idx_to_name: + idx_to_name[idx] = f"{node.name}_unused_{idx}" + + arg_list = [] + for i, element_meta_val in enumerate(meta_val): + arg_list.append( + self.serialize_tensor_output(idx_to_name[i], element_meta_val) + ) + + return arg_list + + def serialize_graph(self, graph_module: torch.fx.GraphModule) -> Graph: + assert isinstance(graph_module, torch.fx.GraphModule) + log.debug( + "[serialize_graph]\n\n%s", graph_module.print_readable(print_output=False) + ) + + for node in graph_module.graph.nodes: + try: + getattr(self, f"handle_{node.op}")(node) + except Exception as e: + raise SerializeError( + f"Failed serializing node {node} in graph: {node.format_node()}\n Original exception {traceback.format_exc()}" + ) from e + + return Graph( + inputs=self.graph_state.inputs, + nodes=self.graph_state.nodes, + tensor_values=self.graph_state.tensor_values, + sym_int_values=self.graph_state.sym_int_values, + sym_float_values=self.graph_state.sym_float_values, + sym_bool_values=self.graph_state.sym_bool_values, + custom_obj_values=self.graph_state.custom_obj_values, + outputs=self.graph_state.outputs, + is_single_tensor_return=self.graph_state.is_single_tensor_return, + ) + + def serialize_graph_module_metadata(self, meta: dict[str, Any]): + ret = {} + if custom := meta.get("custom"): + log.debug("\n[serialize_graph_module_metadata] %s", custom) + try: + ret["custom"] = json.dumps(custom) + except Exception as e: + raise SerializeError( + f"Failed to serialize custom metadata for graph with error {e}" + ) from e + + return ret + + def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule: + log.debug("\n[serialize]") + graph = self.serialize_graph(graph_module) + + return GraphModule( + graph=graph, + signature=self.serialize_signature(self.graph_signature), + module_call_graph=self.serialize_module_call_graph(self.module_call_graph), + metadata=self.serialize_graph_module_metadata(graph_module.meta), + treespec_namedtuple_fields=self.treespec_namedtuple_fields, + ) + + +@final +class ExportedProgramSerializer(metaclass=Final): + def __init__( + self, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, + ): + self.opset_version: dict[str, int] = {} + if opset_version: + self.opset_version.update(opset_version) + if "aten" not in self.opset_version: + self.opset_version["aten"] = torch._C._get_max_operator_version() + + self.pickle_protocol = pickle_protocol + + def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: + """ + Args: + exported_program: Exported Program to serialize + """ + exported_program.validate() + + gm_serializer = GraphModuleSerializer( + exported_program.graph_signature, exported_program.module_call_graph + ) + serialized_graph_module = gm_serializer.serialize(exported_program.graph_module) + serialized_range_constraints = serialize_range_constraints( + exported_program.range_constraints + ) + + # TODO: Directly serialize exported_program.constants once + # CustomClassHolders get stored in the ExportedProgram rather than in + # the graph + constants: dict[str, Any] = gm_serializer.custom_objs.copy() + for n, t in exported_program.constants.items(): + assert n not in constants + constants[n] = t + + serialized_ep = ExportedProgram( + graph_module=serialized_graph_module, + opset_version=self.opset_version, + range_constraints=serialized_range_constraints, + schema_version=SchemaVersion( + major=SCHEMA_VERSION[0], + minor=SCHEMA_VERSION[1], + ), + verifiers=[v.dialect for v in 
exported_program.verifiers], + torch_version=torch.__version__, + ) + + # Test canonical form is well defined. + canonicalize(serialized_ep, set(constants.keys())) + + # Proxy cannot be dumped, so we remove them. + new_state_dict = remove_proxy_from_state_dict( + exported_program.state_dict, in_place=False + ) + return _SerializedProgram( + serialized_ep, + serialize_torch_artifact(new_state_dict, self.pickle_protocol), + serialize_torch_artifact(constants, self.pickle_protocol), + serialize_torch_artifact( + exported_program.example_inputs, self.pickle_protocol + ), + ) + + +@final +class GraphModuleDeserializer(metaclass=Final): + @dataclasses.dataclass + class Result: + graph_module: torch.fx.GraphModule + signature: ep.ExportGraphSignature + module_call_graph: list[ep.ModuleCallEntry] + names_to_symbols: dict[str, sympy.Symbol] + state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]] + constants: dict[str, _ConstantAttributeType] + example_inputs: Optional[tuple[tuple[torch.Tensor, ...], dict[str, Any]]] + + def __init__(self) -> None: + self.serialized_name_to_node: dict[str, torch.fx.Node] = {} + self.serialized_name_to_meta: dict[str, MetaType] = {} + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + + @contextmanager + def save_graph_module(self) -> Iterator[None]: + saved = ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + self.unbacked_symbols, + ) + self.graph = torch.fx.Graph() + self.module = torch.nn.Module() + self.serialized_name_to_node = {} + self.serialized_name_to_meta = {} + self.unbacked_symbols: set[sympy.Symbol] = set() + try: + yield + finally: + ( + self.graph, + self.module, + self.serialized_name_to_node, + self.serialized_name_to_meta, + self.unbacked_symbols, + ) = saved + + def deserialize_extension_operator(self, serialized_target: str): + namespace, op_name = serialized_target.split(":") + namespace = namespace[1:] # starting with # + handler = _deserialization_registry[namespace] + return handler.from_op_name(op_name) + + def deserialize_operator(self, serialized_target: str): + if serialized_target.startswith( + "_operator" + ): # TODO(zhxchen17) Follow up on this. + module = operator + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("torch"): + module = torch # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("math"): + module = math # type: ignore[misc] + serialized_target_names = serialized_target.split(".")[1:] + elif serialized_target.startswith("#"): + return self.deserialize_extension_operator(serialized_target) + else: # TODO(zhxchen17) Don't catch all here. + return serialized_target + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + def _parse_sym_expr( + self, expr_str: str, hint: Optional[Union[int, bool, float]] = None + ) -> sympy.Expr: + """ + Parses and does bottom-up processing of sympy.Expr nodes, + populating ShapeEnv & caching symbols as needed. 
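+ For example (an illustrative note, not part of the original source): parsing + "2*s0 + 1" first recurses into the leaf "s0", caching that symbol and constraining + it with any recorded value range, and then caches the composite expression under + its string form so that later references reuse the same sympy objects.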
+ """ + + def _process_sym_expr( + sym: sympy.Expr, hint: Optional[Union[int, bool, float]] = None + ) -> sympy.Expr: + if sym.is_Integer or sym.is_Float or sym.is_Boolean: # base case + return sym + else: # recursive case + # important to use str(expr) and not _print_sympy(), + # str(expr) is key for self.symbol_name_to_range + expr_str = str(sym) + for arg in sym.args: + self._parse_sym_expr(arg) + # symbol caching + if expr_str in self.symbol_name_to_symbol: + sym = self.symbol_name_to_symbol[expr_str] + else: + self.symbol_name_to_symbol[expr_str] = sym + if isinstance(sym, sympy.Symbol) and symbolic_shapes.symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + self.unbacked_symbols.add(sym) + # hints + if hint is not None and sym not in self.shape_env.var_to_val: + self.shape_env.add_var_to_val(sym, hint) # type: ignore[arg-type] + # ValueRanges + if vr := self.symbol_name_to_range.get(expr_str): + self.shape_env.constrain_symbol_range( + sym, + compiler_min=vr.lower, # type: ignore[arg-type] + compiler_max=vr.upper, # type: ignore[arg-type] + ) + # ShapeEnv meta + if isinstance(sym, sympy.Symbol): + self.shape_env.var_to_stack[sym] = CapturedTraceback.extract(skip=1) + return sym + + expr = sympy.sympify( + expr_str, + locals={**self.sympy_functions, **self.symbol_name_to_symbol}, + ) + return _process_sym_expr(expr, hint) + + def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: + val = s.value + if s.type == "as_expr": + if val.hint is None: + hint = None + else: + assert val.hint.type == "as_int" + hint = val.hint.value + + sym = self._parse_sym_expr(val.expr_str, hint) + return self.shape_env.create_symintnode(sym, hint=hint) + elif s.type == "as_int": + assert type(val) is int + return val + else: + raise SerializeError( + f"SymInt has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_float(self, s: SymFloat) -> Union[float, torch.SymFloat]: + val = s.value + if s.type == "as_expr": + hint = val.hint.as_float if val.hint else None + sym = self._parse_sym_expr(val.expr_str, hint) + return self.shape_env.create_symfloatnode(sym, hint=hint) + elif s.type == "as_float": + assert isinstance(val, float) + return val + else: + raise SerializeError( + f"SymFloat has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]: + val = s.value + if s.type == "as_expr": + expr = self._parse_sym_expr(val.expr_str) + return self.shape_env.create_symboolnode(expr) + elif s.type == "as_bool": + assert isinstance(val, bool) + return val + else: + raise SerializeError( + f"SymBool has invalid field type {s.type} with value {s.value}" + ) + + def deserialize_tensor_meta( + self, + tensor_meta: TensorMeta, + ) -> FakeTensor: + with self.fake_tensor_mode: + return cast( + FakeTensor, + torch.empty_strided( + tuple(self.deserialize_sym_int(val) for val in tensor_meta.sizes), # type: ignore[misc] + tuple(self.deserialize_sym_int(val) for val in tensor_meta.strides), # type: ignore[misc] + device=deserialize_device(tensor_meta.device), + dtype=_SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype], + requires_grad=tensor_meta.requires_grad, + ), + ) + + def deserialize_script_obj_meta( + self, script_obj_meta: CustomObjArgument + ) -> ep.CustomObjArgument: + return ep.CustomObjArgument( + name=script_obj_meta.name, + class_fqn=script_obj_meta.class_fqn, + ) + + def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]: + if output.type == "as_tensor": + 
return self.serialized_name_to_node[output.as_tensor.name] + elif output.type == "as_sym_int": + return self.serialized_name_to_node[output.as_sym_int.as_name] + elif output.type == "as_sym_bool": + return self.serialized_name_to_node[output.as_sym_bool.as_name] + elif output.type == "as_sym_float": + return self.serialized_name_to_node[output.as_sym_float.as_name] + elif output.type == "as_int": + return output.as_int + elif output.type == "as_float": + return output.as_float + elif output.type == "as_bool": + return output.as_bool + elif output.type == "as_none": + return None + else: + raise SerializeError(f"Unable to deserialize output node {output}") + + def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: + log.debug("\n[deserialize_graph]") + + # Handle the tensor metas. + for name, tensor_value in serialized_graph.tensor_values.items(): + log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value) + meta_val = self.deserialize_tensor_meta(tensor_value) + log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val) + self.serialized_name_to_meta[name] = meta_val + + for name, sym_int_value in serialized_graph.sym_int_values.items(): + log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value) + int_val = self.deserialize_sym_int(sym_int_value) + log.debug("[deserialize_sym_int] %s (output): %s", name, int_val) + self.serialized_name_to_meta[name] = int_val + + for name, sym_float_value in serialized_graph.sym_float_values.items(): + log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value) + float_val = self.deserialize_sym_float(sym_float_value) + log.debug("[deserialize_sym_float] %s (output): %s", name, float_val) + self.serialized_name_to_meta[name] = float_val + + for name, sym_bool_value in serialized_graph.sym_bool_values.items(): + log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value) + bool_val = self.deserialize_sym_bool(sym_bool_value) + log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val) + self.serialized_name_to_meta[name] = bool_val + + for name, script_obj_meta in serialized_graph.custom_obj_values.items(): + log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) + self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( + script_obj_meta + ) + + log.debug("\n[deserialize graph nodes]") + # Inputs: convert to placeholder nodes in FX. + for i, input_ in enumerate(serialized_graph.inputs): + log.debug("[deserialize input] %s", input_) + if input_.type in ("as_tensor", "as_custom_obj"): + node_name = input_.value.name + placeholder_node = self.graph.placeholder(node_name) + # FX might declare a name illegal (e.g. some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + elif input_.type == "as_sym_int": + if input_.value.type == "as_name": + node_name = input_.value.as_name + placeholder_node = self.graph.placeholder(node_name) + # FX might declare a name illegal (e.g. 
some nn.Modules use "input" as forward() arguments) + # we will overwrite it + placeholder_node.name = node_name + self.sync_fx_node(node_name, placeholder_node) + else: + raise SerializeError( + f"Deserializing a constant symint {input_.value} as an input" + ) + elif input_.type in ( + "as_int", + "as_float", + "as_bool", + "as_none", + "as_string", + ): + node_name = self.signature.input_specs[i].arg.name or f"arg{i}" + placeholder_node = self.graph.placeholder(node_name) + placeholder_node.meta["val"] = self.deserialize_input(input_) + else: + raise SerializeError(f"Invalid input type {input_}") + + # Nodes: convert to call_function nodes. + for serialized_node in serialized_graph.nodes: + try: + target = self.deserialize_operator(serialized_node.target) + self.deserialize_node(serialized_node, target) + + except Exception as e: + raise SerializeError( + f"Failed deserializing node {serialized_node}\n Original exception {traceback.format_exc()}" + ) from e + + # Outputs: convert to a single `output` node. + outputs = [] + for output in serialized_graph.outputs: + log.debug("[deserialize output] %s", output) + outputs.append(self.deserialize_graph_output(output)) + + if serialized_graph.is_single_tensor_return: + assert len(outputs) == 1 + outputs = outputs[0] # type: ignore[assignment] + else: + outputs = tuple(outputs) # type: ignore[assignment] + + output_node = self.graph.output(outputs) + + if serialized_graph.is_single_tensor_return: + output_node.meta["val"] = output_node.args[0].meta["val"] + else: + output_node.meta["val"] = tuple( + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ) + + # recompute unbacked bindings + for node in self.graph.nodes: + if (val := node.meta.get("val")) is not None and ( + unbacked_bindings := symbolic_shapes._free_unbacked_symbols_with_path( + val, + (), + shape_env=self.shape_env, + pending=self.unbacked_symbols, + simplify=True, + ) + ): + node.meta["unbacked_bindings"] = unbacked_bindings + + assert len(self.unbacked_symbols) == 0 + return self.graph + + def deserialize_node(self, serialized_node: Node, target: Callable) -> None: + def _is_single_tensor_return(target) -> bool: + schema = _get_schema_from_target(target) + returns = schema.returns + return len(returns) == 1 and isinstance( + returns[0].real_type, torch.TensorType + ) + + if ( + target in _SYM_OPS + or target + == torch.ops.aten.item.default # this can produce either SymInt or SymBool + ): + name = serialized_node.outputs[0].value.as_name + args = self.deserialize_sym_op_inputs(serialized_node.inputs) + + fx_node = self.graph.create_node("call_function", target, args, {}, name) + self.deserialize_sym_op_outputs(serialized_node, fx_node) + + elif isinstance(target, torch._ops.HigherOrderOperator): + args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) + metadata = self.deserialize_metadata(serialized_node.metadata) + for x in (*args, *kwargs.values()): + if isinstance(x, torch.fx.Node) and x.op == "get_attr": + # this means that we have deserialized a graph argument, but + # unfortunately the schema for it does not include metadata; + # so we reuse the metadata of the HOP call for such arguments + x.meta.update(metadata) + # If a serialized HOP node has a length=1 outputs of type `as_tensor``. 
+ # There could be two cases: + # (1) The HOP node returns a single tensor + # (2) The HOP node returns a tuple containing a single tensor + # We distinguish (1) and (2) by the `is_hop_single_tensor_return` + # field in the schema of Node + # For BC, getattr() will return True if `is_hop_single_tensor_return` doesn't + # exist. This is because prior to adding `is_hop_single_tensor_return`, + # only (1) could happen as we handle (2) with type `as_tensors` + name = ( + serialized_node.outputs[0].as_tensor.name + if len(serialized_node.outputs) == 1 + and hasattr(serialized_node.outputs[0], "as_tensor") + and getattr(serialized_node, "is_hop_single_tensor_return", True) + else None + ) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + fx_node.meta.update(metadata) + + elif isinstance( + target, (torch._ops.OpOverload, *_registered_extension_types()) + ): + # For convenience: if this node returns a single tensor, name the + # newly-created node after it. This ensures that these tensor values + # have names that are consistent with the serialized form. + name = ( + serialized_node.outputs[0].as_tensor.name + if _is_single_tensor_return(target) + else None # FX will generate a name for us. + ) + args, kwargs = self.deserialize_inputs(target, serialized_node) + fx_node = self.graph.create_node( + "call_function", target, args, kwargs, name + ) + self.deserialize_outputs(serialized_node, fx_node) + else: + _additional_msg = ( + ( + f"We failed to resolve {target} to an operator. " + + "If it's a custom op/custom triton op, this is usually because the custom op is not registered" + + " when deserializing. Please import the custom op to register it before deserializing." + + " Otherwise, please file an issue on GitHub." + ) + if isinstance(target, str) + else "" + ) + raise SerializeError( + _additional_msg + + f" Unsupported target type for node {serialized_node}: {type(target)}."
+ ) + + fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata)) + log.debug( + "[deserialize_node] %s: %s(%s, {%s}) -> %s", + fx_node.name, + fx_node.target, + fx_node.args, + fx_node.kwargs, + fx_node.meta.get("val"), + ) + if ( + fx_node.op not in ["placeholder", "output"] + and "nn_module_stack" not in fx_node.meta + ): + fx_node.meta[ + "nn_module_stack" + ] = {} # serialization throws away empty dicts + + def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: + log.debug("[deserialize_input_spec] %s", i) + if i.type == "user_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=self.deserialize_argument_spec(i.user_input.arg), + target=None, + ) + elif i.type == "parameter": + return ep.InputSpec( + kind=ep.InputKind.PARAMETER, + arg=ep.TensorArgument(name=i.parameter.arg.name), + target=i.parameter.parameter_name, + ) + elif i.type == "buffer": + return ep.InputSpec( + kind=ep.InputKind.BUFFER, + arg=ep.TensorArgument(name=i.buffer.arg.name), + target=i.buffer.buffer_name, + persistent=i.buffer.persistent, + ) + elif i.type == "tensor_constant": + return ep.InputSpec( + kind=ep.InputKind.CONSTANT_TENSOR, + arg=ep.TensorArgument(name=i.tensor_constant.arg.name), + target=i.tensor_constant.tensor_constant_name, + ) + elif i.type == "custom_obj": + return ep.InputSpec( + kind=ep.InputKind.CUSTOM_OBJ, + arg=ep.CustomObjArgument( + name=i.custom_obj.arg.name, class_fqn=i.custom_obj.arg.class_fqn + ), + target=i.custom_obj.custom_obj_name, + ) + elif i.type == "token": + return ep.InputSpec( + kind=ep.InputKind.TOKEN, + arg=ep.TokenArgument(name=i.token.arg.name), + target=None, + ) + elif i.type == "constant_input": + return ep.InputSpec( + kind=ep.InputKind.USER_INPUT, + arg=ep.ConstantArgument( + name=i.constant_input.name, + value=self.deserialize_constant_input(i.constant_input.value), + ), + target=None, + ) + else: + raise AssertionError(f"Unknown input spec {i}") + + def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: + log.debug("[deserialize_output_spec] %s", o) + if o.type == "user_output": + return ep.OutputSpec( + kind=ep.OutputKind.USER_OUTPUT, + arg=self.deserialize_argument_spec(o.user_output.arg), + target=None, + ) + elif o.type == "loss_output": + return ep.OutputSpec( + kind=ep.OutputKind.LOSS_OUTPUT, + arg=ep.TensorArgument(name=o.loss_output.arg.name), + target=None, + ) + elif o.type == "buffer_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.BUFFER_MUTATION, + arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), + target=o.buffer_mutation.buffer_name, + ) + elif o.type == "gradient_to_parameter": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_PARAMETER, + arg=ep.TensorArgument(name=o.gradient_to_parameter.arg.name), + target=o.gradient_to_parameter.parameter_name, + ) + elif o.type == "gradient_to_user_input": + return ep.OutputSpec( + kind=ep.OutputKind.GRADIENT_TO_USER_INPUT, + arg=ep.TensorArgument(name=o.gradient_to_user_input.arg.name), + target=o.gradient_to_user_input.user_input_name, + ) + elif o.type == "user_input_mutation": + return ep.OutputSpec( + kind=ep.OutputKind.USER_INPUT_MUTATION, + arg=ep.TensorArgument(name=o.user_input_mutation.arg.name), + target=o.user_input_mutation.user_input_name, + ) + elif o.type == "token": + return ep.OutputSpec( + kind=ep.OutputKind.TOKEN, + arg=ep.TokenArgument(name=o.token.arg.name), + target=None, + ) + else: + raise AssertionError(f"Unknown output spec {o}") + + def deserialize_signature(self, sig: GraphSignature) -> 
ep.ExportGraphSignature: + log.debug("\n[deserialize_signature]") + return ep.ExportGraphSignature( + input_specs=[self.deserialize_input_spec(i) for i in sig.input_specs], + output_specs=[self.deserialize_output_spec(o) for o in sig.output_specs], + ) + + def deserialize( + self, + serialized_graph_module: GraphModule, + serialized_state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, Any], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + symbol_name_to_range: Optional[dict[str, symbolic_shapes.ValueRanges]] = None, + ) -> Result: + global _CURRENT_DESERIALIZER + assert _CURRENT_DESERIALIZER is None + _CURRENT_DESERIALIZER = self + try: + log.debug("\n[deserialize]") + self.shape_env = symbolic_shapes.ShapeEnv(assume_static_by_default=True) + self.fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=True, + shape_env=self.shape_env, + ) + self.sympy_functions = { + # all torch.utils._sympy.functions should go here + # TODO(avik): find a better way to keep this collection in sync; + # e.g.., `exec('from torch.utils._sympy.functions import *', ...)` + # would work as long as the public API of that module is complete + "FloorDiv": torch.utils._sympy.functions.FloorDiv, + "ModularIndexing": torch.utils._sympy.functions.ModularIndexing, + "Where": torch.utils._sympy.functions.Where, + "PythonMod": torch.utils._sympy.functions.PythonMod, + "Mod": torch.utils._sympy.functions.Mod, + "CleanDiv": torch.utils._sympy.functions.CleanDiv, + "CeilToInt": torch.utils._sympy.functions.CeilToInt, + "FloorToInt": torch.utils._sympy.functions.FloorToInt, + "CeilDiv": torch.utils._sympy.functions.CeilDiv, + "LShift": torch.utils._sympy.functions.LShift, + "RShift": torch.utils._sympy.functions.RShift, + "PowByNatural": torch.utils._sympy.functions.PowByNatural, + "FloatPow": torch.utils._sympy.functions.FloatPow, + "FloatTrueDiv": torch.utils._sympy.functions.FloatTrueDiv, + "IntTrueDiv": torch.utils._sympy.functions.IntTrueDiv, + "IsNonOverlappingAndDenseIndicator": torch.utils._sympy.functions.IsNonOverlappingAndDenseIndicator, + "TruncToFloat": torch.utils._sympy.functions.TruncToFloat, + "TruncToInt": torch.utils._sympy.functions.TruncToInt, + "RoundToInt": torch.utils._sympy.functions.RoundToInt, + "RoundDecimal": torch.utils._sympy.functions.RoundDecimal, + "ToFloat": torch.utils._sympy.functions.ToFloat, + "Identity": torch.utils._sympy.functions.Identity, + } + self.symbol_name_to_symbol: dict[str, sympy.Symbol] = {} + self.constants = deserialize_torch_artifact(constants) + self.signature = self.deserialize_signature( + serialized_graph_module.signature + ) + + # deserialization does analysis with checks on 0/1, so we create fake range constraints and + # restore the original range constraints afterwards + self.symbol_name_to_range = {} + # we also need to bump unbacked sym[float,int] counters in the + # shape env to accommodate unbacked symbols in the exported program + self.unbacked_symbols = set() + count_unbacked_symfloat, count_unbacked_symint = -1, -1 + unbacked_symfloat_prefix, unbacked_symint_prefix = ( + prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT] + ) + if symbol_name_to_range: + for k, vr in symbol_name_to_range.items(): + lower = vr.lower + if vr.upper >= 2: # max is >= 2, not sym bool range + lower = max(2, lower) + self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( + _int_to_sympy_int(lower, -int_oo), vr.upper + ) + if 
k.startswith(unbacked_symfloat_prefix): + i = int(k[len(unbacked_symfloat_prefix) :]) + count_unbacked_symfloat = max(count_unbacked_symfloat, i) + elif k.startswith(unbacked_symint_prefix): + i = int(k[len(unbacked_symint_prefix) :]) + count_unbacked_symint = max(count_unbacked_symint, i) + + # TODO(pianpwk): if we can clean up unused symbols in range_constraints, + # then this logic can just be handled with self.unbacked_symbols alone + for _ in range(count_unbacked_symfloat + 1): + next(self.shape_env.unbacked_symfloat_counter) + for _ in range(count_unbacked_symint + 1): + next(self.shape_env.unbacked_symint_counter) + + if example_inputs is not None and len(example_inputs) > 0: + self.example_inputs = deserialize_torch_artifact(example_inputs) + else: + self.example_inputs = None + self.deserialize_graph(serialized_graph_module.graph) + + with _enable_graph_inputs_of_type_nn_module(self.example_inputs): + module_call_graph = self.deserialize_module_call_graph( + serialized_graph_module.module_call_graph + ) + graph_module = ep._create_graph_module_for_export(self.module, self.graph) + meta = {} + if custom := serialized_graph_module.metadata.get("custom"): + meta["custom"] = json.loads(custom) + if hasattr(serialized_graph_module, "treespec_namedtuple_fields"): + meta["treespec_namedtuple_fields"] = {} + for ( + type_, + fields, + ) in serialized_graph_module.treespec_namedtuple_fields.items(): + meta["treespec_namedtuple_fields"][type_] = fields.field_names + graph_module.meta = meta + return GraphModuleDeserializer.Result( + graph_module=graph_module, + signature=self.signature, + module_call_graph=module_call_graph, + names_to_symbols=self.symbol_name_to_symbol, + state_dict=deserialize_torch_artifact(serialized_state_dict), + constants=self.constants, + example_inputs=self.example_inputs, + ) + finally: + _CURRENT_DESERIALIZER = None + + def sync_fx_node(self, name: str, fx_node: torch.fx.Node): + if name in self.serialized_name_to_node: + raise SerializeError(f"Node {name} has already been deserialized before.") + # overwrite name + fx_node.name = name + self.serialized_name_to_node[name] = fx_node + assert "val" not in fx_node.meta + fx_node.meta["val"] = self.serialized_name_to_meta[name] + + def deserialize_sym_op_inputs(self, inputs): + return tuple(self.deserialize_input(input.arg) for input in inputs) + + def deserialize_inputs(self, target, serialized_node: Node): + schema_args = _get_schema_from_target(target).arguments + argument_kinds = {input.name: input.kind for input in serialized_node.inputs} + actual_args = { + input.name: self.deserialize_input(input.arg) + for input in serialized_node.inputs + } + args = [] + kwargs: OrderedDict[str, Any] = OrderedDict() + for schema_arg in schema_args: + if schema_arg.name in actual_args: + arg = actual_args[schema_arg.name] + kind = argument_kinds[schema_arg.name] + if kind == ArgumentKind.POSITIONAL: + args.append(arg) + continue + elif kind == ArgumentKind.KEYWORD and not keyword.iskeyword( + schema_arg.name + ): + kwargs[schema_arg.name] = arg + continue + + # If there's no ArgumentKind found, fallback to the old cases. 
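+ # Legacy layout (payloads serialized before ArgumentKind was recorded):
+ # schema arguments without a default value that are not kwarg-only are
+ # treated as positional; schema argument names that collide with Python
+ # keywords cannot be passed by keyword, so they are appended positionally;
+ # everything else is passed as a keyword argument.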
+ is_positional = ( + not schema_arg.has_default_value() and not schema_arg.kwarg_only + ) + if is_positional: + args.append(actual_args[schema_arg.name]) + elif keyword.iskeyword(schema_arg.name): + assert not schema_arg.kwarg_only + if len(kwargs) > 0: + kwargs = OrderedDict() + args.extend(list(kwargs.values())) + args.append(actual_args[schema_arg.name]) + else: + if schema_arg.name in actual_args: + kwargs[schema_arg.name] = actual_args[schema_arg.name] + return tuple(args), kwargs + + def deserialize_hoo_inputs(self, inputs: list[NamedArgument]): + """ + For deserializing HOO inputs since HOOs do not have a schema. + """ + args = [] + kwargs = {} + for input_ in inputs: + if input_.name != "": + kwargs[input_.name] = self.deserialize_input(input_.arg) + else: + args.append(self.deserialize_input(input_.arg)) + return (tuple(args), kwargs) + + def deserialize_input(self, inp: Argument) -> Any: + value = inp.value + typ_ = inp.type + if typ_ == "as_none": + # None should converted as None, but is encoded as bool in serialized + # Convert serialized object to torch equivalent + return None + elif typ_ == "as_tensor": + return self.serialized_name_to_node[inp.as_tensor.name] + elif typ_ == "as_scalar_type": + return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type] + elif typ_ == "as_memory_format": + return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format] + elif typ_ == "as_layout": + return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout] + elif typ_ == "as_graph": + assert isinstance(value, GraphArgument) + with self.save_graph_module(): + self.deserialize_graph(value.graph) + submodule = ep._create_graph_module_for_export(self.module, self.graph) + self.module.register_module(value.name, submodule) + return self.graph.create_node( + "get_attr", + value.name, + name=value.name, + ) + elif typ_ == "as_device": + return deserialize_device(inp.as_device) + elif typ_ == "as_int": + return inp.as_int + elif typ_ == "as_float": + return inp.as_float + elif typ_ == "as_bool": + return inp.as_bool + elif typ_ == "as_string": + return inp.as_string + elif typ_ == "as_sym_int": + return self.deserialize_sym_argument(inp.as_sym_int) + elif typ_ == "as_sym_float": + return self.deserialize_sym_argument(inp.as_sym_float) + elif typ_ == "as_sym_bool": + return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, list): + if len(value) == 0: + return [] + elif typ_ == "as_tensors": + result = [self.serialized_name_to_node[arg.name] for arg in value] + return result + elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): + # convert from serialized.python.types.List to python list + return list(value) + elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"): + return [self.deserialize_sym_argument(arg) for arg in value] + elif typ_ == "as_optional_tensors": + + def deserialize_optional_tensor_args(a): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return self.serialized_name_to_node[a.value.name] + else: + raise SerializeError(f"Unhandled argument {inp}") + + return list(map(deserialize_optional_tensor_args, value)) + else: + raise SerializeError(f"Unhandled argument {inp}") + elif typ_ == "as_custom_obj": + if inp.as_custom_obj.name in self.serialized_name_to_node: + # Custom object has been lifted as an input + return self.serialized_name_to_node[inp.as_custom_obj.name] + return self.constants[inp.as_custom_obj.name] + elif typ_ == "as_operator": + return self.deserialize_operator(inp.as_operator) + else: + raise SerializeError(f"Unhandled 
argument {inp}") + + def deserialize_constant_input(self, inp: ConstantValue) -> Any: + if inp.type == "as_int": + return int(inp.as_int) + elif inp.type == "as_float": + return float(inp.as_float) + elif inp.type == "as_string": + return str(inp.as_string) + elif inp.type == "as_bool": + return bool(inp.as_bool) + elif inp.type == "as_none": + return None + else: + raise SerializeError(f"Unhandled constant argument {inp} to deserialize") + + def deserialize_sym_argument(self, sym_arg): + if isinstance(sym_arg, SymIntArgument): + if sym_arg.type == "as_int": + return sym_arg.as_int + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymFloatArgument): + if sym_arg.type == "as_float": + return sym_arg.as_float + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + elif isinstance(sym_arg, SymBoolArgument): + if sym_arg.type == "as_bool": + return sym_arg.as_bool + elif sym_arg.type == "as_name": + return self.serialized_name_to_node[sym_arg.as_name] + raise SerializeError(f"Unknown symbolic argument type: {sym_arg}") + + def deserialize_sym_op_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + + def deserialize_outputs(self, serialized_node: Node, fx_node: torch.fx.Node): + # Check single value return + if len(serialized_node.outputs) == 0: + return + + if ( + len(serialized_node.outputs) == 1 + and "torch.ops.higher_order" in serialized_node.target + and not getattr(serialized_node, "is_hop_single_tensor_return", True) + and serialized_node.outputs[0].type != "as_none" + ): + + def _deserialize_hop_with_single_return(serialized_node, fx_node): + meta_val: list[Any] = [] + arg = None + if serialized_node.outputs[0].type == "as_tensor": + arg = serialized_node.outputs[0].as_tensor + elif isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, SymBoolArgument, SymFloatArgument), + ): + arg = serialized_node.outputs[0].value + deserialized_metadata = self.deserialize_metadata( + serialized_node.metadata + ) + assert arg is not None + self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + return + + return _deserialize_hop_with_single_return(serialized_node, fx_node) + + if ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_tensor" + ): + self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node) + return + elif len(serialized_node.outputs) == 1 and isinstance( + serialized_node.outputs[0].value, + (SymIntArgument, SymBoolArgument, SymFloatArgument), + ): + self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node) + return + elif ( + len(serialized_node.outputs) == 1 + and serialized_node.outputs[0].type == "as_none" + ): + # manually rename the node to a unused name to avoid naming conflicts + fx_node.meta["val"] = None + fx_node._rename(f"{self.graph._target_to_str(fx_node.target)}_unused") + return + + self.deserialize_multiple_outputs(serialized_node, fx_node) + + def generate_getitem( + self, + meta_val, + fx_node: torch.fx.Node, + arg: Union[TensorArgument, SymIntArgument, SymFloatArgument], + idx: int, + deserialized_metadata: dict[str, Any], + ): + if isinstance(arg, TensorArgument): + name = arg.name + elif isinstance(arg, SymIntArgument): + name = arg.as_name + elif isinstance(arg, SymFloatArgument): + name = arg.as_name + else: 
+ raise AssertionError( + f"generate_getitem got unknown argument type {type(arg)}" + ) + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name=name, + ) + self.sync_fx_node(name, individual_output) + meta_val.append(self.serialized_name_to_meta[name]) + # The derived `getitem` nodes should have the same stacktrace as the + # original `fx_node` + individual_output.meta.update(deserialized_metadata) + + def generate_getitems( + self, + meta_val, + fx_node: torch.fx.Node, + args, + deserialized_metadata: dict[str, Any], + ): + for idx, arg in enumerate(args): + if isinstance(arg, (TensorArgument, SymIntArgument, SymFloatArgument)): + self.generate_getitem( + meta_val, fx_node, arg, idx, deserialized_metadata + ) + continue + + assert isinstance(arg, Argument) + if arg.type in ("as_tensor", "as_sym_int", "as_sym_float"): + self.generate_getitem( + meta_val, fx_node, arg.value, idx, deserialized_metadata + ) + elif arg.type in ( + "as_tensors", + "as_sym_ints", + "as_sym_floats", + "as_ints", + "as_floats", + "as_strings", + "as_bools", + "as_sym_bools", + ): + list_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + ) + meta_val.append([]) + self.generate_getitems( + meta_val[-1], list_output, arg.value, deserialized_metadata + ) + list_output.meta.update(deserialized_metadata) + list_output.meta["val"] = meta_val[-1] + elif arg.type == "as_none": + individual_output = self.graph.create_node( + "call_function", + operator.getitem, + (fx_node, idx), + name="as_none", + ) + meta_val.append(None) + individual_output.meta["val"] = None + individual_output.meta.update(deserialized_metadata) + else: + raise NotImplementedError(f"Unimplemented node output type: {arg}") + + def deserialize_multiple_outputs( + self, serialized_node: Node, fx_node: torch.fx.Node + ) -> None: + deserialized_metadata = self.deserialize_metadata(serialized_node.metadata) + + # Convert multiple return types to FX format. + # In FX, each node only returns one value. So in order to represent + # multiple return values, we have to emit a `getitem` node for each + # return value. 
+ # This performs the inverse mapping of the `serialize_outputs` call in + # serialization, see [NOTE: Multiple outputs] + meta_val: list[Any] = [] + if len(serialized_node.outputs) == 1: + assert isinstance(serialized_node.outputs[0].value, list) + assert isinstance(serialized_node.outputs[0].value[0], TensorArgument) + self.generate_getitems( + meta_val, + fx_node, + serialized_node.outputs[0].as_tensors, + deserialized_metadata, + ) + else: + self.generate_getitems( + meta_val, fx_node, serialized_node.outputs, deserialized_metadata + ) + + # also update the metaval for `fx_node` to be a list(meta) + fx_node.meta["val"] = tuple(meta_val) + self.serialized_name_to_node[fx_node.name] = fx_node + + def deserialize_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + ret: dict[str, Any] = {} + if stack_trace := metadata.get("stack_trace"): + ret["stack_trace"] = stack_trace + + def deserialize_meta_func(serialized_target: str): + module = None + if serialized_target.startswith("torch.nn"): + module = torch.nn + serialized_target_names = serialized_target.split(".")[2:] + elif serialized_target.startswith("torch"): + module = torch + serialized_target_names = serialized_target.split(".")[1:] + else: + return self.deserialize_operator(serialized_target) + + target = module + for name in serialized_target_names: + if not hasattr(target, name): + return serialized_target + else: + target = getattr(target, name) + return target + + if nn_module_stack_str := metadata.get("nn_module_stack"): + # Originally serialized to "key,orig_path,type_str" + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function to split string by commas, accounting for nested parentheses/brackets + def metadata_split(metadata): + out = [] + start, n = 0, 0 + a, b = "[(", ")]" + for end, c in enumerate(metadata): + if c in a: + n += 1 + elif c in b: + n -= 1 + elif c == "," and n == 0: + out.append(metadata[start:end]) + start = end + 1 + out.append(metadata[start:]) + assert len(out) == 3 + return out + + nn_module_stack = dict( + import_nn_module_stack(*metadata_split(item)) + for item in nn_module_stack_str.split(ST_DELIMITER) + ) + ret["nn_module_stack"] = nn_module_stack + + if source_fn_st_str := metadata.get("source_fn_stack"): + # Originally serializes to "fx_node_name,op_str" + source_fn_st = [] + for source_fn_str in source_fn_st_str.split(ST_DELIMITER): + name, target_str = source_fn_str.split(",") + source_fn_st.append((name, deserialize_meta_func(target_str))) + ret["source_fn_stack"] = source_fn_st + + if torch_fn_str := metadata.get("torch_fn"): + ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER)) + + if custom_str := metadata.get("custom"): + ret["custom"] = json.loads(custom_str) + + return ret + + def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec: + log.debug("[deserialize_argument_spec] %s", x) + if x.type == "as_tensor": + return ep.TensorArgument(name=x.as_tensor.name) + elif x.type == "as_sym_int": + return ep.SymIntArgument(name=x.as_sym_int.as_name) + elif x.type == "as_sym_float": + return ep.SymFloatArgument(name=x.as_sym_float.as_name) + elif x.type == "as_custom_obj": + return ep.ConstantArgument( + name=x.as_custom_obj.name, value=self.deserialize_input(x) + ) + else: + return ep.ConstantArgument(name="", value=self.deserialize_input(x)) + + def deserialize_module_call_signature( + self, module_call_signature: ModuleCallSignature + ) -> ep.ModuleCallSignature: + return ep.ModuleCallSignature( + inputs=[ + 
self.deserialize_argument_spec(x) for x in module_call_signature.inputs + ], + outputs=[ + self.deserialize_argument_spec(x) for x in module_call_signature.outputs + ], + in_spec=treespec_loads(module_call_signature.in_spec), + out_spec=treespec_loads(module_call_signature.out_spec), + forward_arg_names=names + if (names := module_call_signature.forward_arg_names) + else None, + ) + + def deserialize_module_call_graph( + self, module_call_graph: list[ModuleCallEntry] + ) -> list[ep.ModuleCallEntry]: + log.debug("\n[deserialize_module_call_graph]") + return [ + ep.ModuleCallEntry( + fqn=entry.fqn, + signature=( + self.deserialize_module_call_signature(entry.signature) + if entry.signature + else None + ), + ) + for entry in module_call_graph + ] + + +@final +class ExportedProgramDeserializer(metaclass=Final): + def __init__(self, expected_opset_version: Optional[dict[str, int]] = None): + self.expected_opset_version: dict[str, int] = {} + if expected_opset_version: + self.expected_opset_version.update(expected_opset_version) + if "aten" not in self.expected_opset_version: + self.expected_opset_version["aten"] = torch._C._get_max_operator_version() + + def deserialize_range_constraints( + self, + symbol_name_to_range: dict[str, symbolic_shapes.ValueRanges], + symbol_name_to_symbol: dict[str, sympy.Symbol], + ) -> dict[sympy.Symbol, ValueRanges]: + log.debug("\n[deserialize_range_constraints]") + range_constraints = {} + for k, v in symbol_name_to_range.items(): + if symbol := symbol_name_to_symbol.get(k): + log.debug("[deserialize_range_constraints] %s -> %s", k, v) + range_constraints[symbol] = v # type: ignore[arg-type] + else: + log.warning( + "Symbol %s did not appear in the graph that was deserialized", k + ) + return range_constraints + + def deserialize( + self, + exported_program: ExportedProgram, + state_dict: Union[dict[str, torch.Tensor], bytes], + constants: Union[dict[str, torch.Tensor], bytes], + example_inputs: Optional[ + Union[tuple[tuple[torch.Tensor, ...], dict[str, Any]], bytes] + ] = None, + *, + _unsafe_skip_version_check=False, + ) -> ep.ExportedProgram: + assert isinstance(exported_program, ExportedProgram) + version = exported_program.schema_version + + # TODO(zhxchen17) blocked on thrift schema refactor + if version.major != SCHEMA_VERSION[0] and not ( + version.major == 0 and version.minor == 0 + ): + if not _unsafe_skip_version_check: + raise SerializeError( + f"Serialized schema version {exported_program.schema_version} " + f"does not match our current schema version {SCHEMA_VERSION}." 
+ ) + + symbol_name_to_range = { + k: symbolic_shapes.ValueRanges( + _int_to_sympy_int(v.min_val, -int_oo), + _int_to_sympy_int(v.max_val, int_oo), + ) + for k, v in exported_program.range_constraints.items() + } + res = GraphModuleDeserializer().deserialize( + exported_program.graph_module, + state_dict, + constants, + example_inputs, + symbol_name_to_range, + ) + range_constraints = self.deserialize_range_constraints( + symbol_name_to_range, + res.names_to_symbols, + ) + + result = ep.ExportedProgram( + root=res.graph_module, + graph=res.graph_module.graph, + graph_signature=res.signature, + state_dict=res.state_dict, # type: ignore[arg-type] + range_constraints=range_constraints, + module_call_graph=res.module_call_graph, + example_inputs=res.example_inputs, + constants=res.constants, + verifiers=[load_verifier(v) for v in exported_program.verifiers], + ) + log.debug("\n[deserialize]: %s", result) + return result + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + return super().default(obj) + + +def _dataclass_to_dict(obj): + if isinstance(obj, _Union): + return {obj.type: _dataclass_to_dict(obj.value)} + elif dataclasses.is_dataclass(obj): + return { + f.name: _dataclass_to_dict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + } + elif isinstance(obj, list): + return [_dataclass_to_dict(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(_dataclass_to_dict(x) for x in obj) + elif isinstance(obj, dict): + return {k: _dataclass_to_dict(v) for k, v in obj.items()} + elif isinstance(obj, float): + if obj == math.inf: + return "Infinity" + elif obj == -math.inf: + return "-Infinity" + elif obj == math.nan: + return "NaN" + else: + return obj + else: + return obj + + +def _to_json_bytes(obj: Any) -> bytes: + return json.dumps(_dataclass_to_dict(obj), cls=EnumEncoder, allow_nan=False).encode( + "utf-8" + ) + + +def serialize( + exported_program: ep.ExportedProgram, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> SerializedArtifact: + with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs): + serialized_program = ExportedProgramSerializer( + opset_version, pickle_protocol + ).serialize(exported_program) + assert isinstance(serialized_program.exported_program, ExportedProgram) + + json_bytes = _to_json_bytes(serialized_program.exported_program) + artifact = SerializedArtifact( + json_bytes, + serialized_program.state_dict, + serialized_program.constants, + serialized_program.example_inputs, + ) + return artifact + + +def _dict_to_dataclass(cls, data): + assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." 
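+ # Walk the parsed JSON and rebuild the typed schema objects: unwrap
+ # Annotated, recurse into Optional fields, reconstruct _Union subclasses
+ # from their single {tag: value} entry via cls.create(), build dataclasses
+ # and then convert each field according to its type hint, and recurse
+ # element-wise into lists and dicts.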
+ if typing.get_origin(cls) == Annotated: + return _dict_to_dataclass(cls.__origin__, data) + if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): + if data is None: + return None + ty_args = typing.get_args(cls) + assert len(ty_args) == 2 + return _dict_to_dataclass(ty_args[0], data) + elif isinstance(cls, type) and issubclass(cls, _Union): + assert isinstance(data, dict) + assert len(data) == 1 + _type = next(iter(data.keys())) + _value = next(iter(data.values())) + assert isinstance(_type, str) + field_type = cls.__annotations__[_type] + return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) + elif dataclasses.is_dataclass(cls): + obj = cls(**data) # type: ignore[assignment,operator] + type_hints = typing.get_type_hints(cls) + for f in dataclasses.fields(cls): + name = f.name + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) + setattr(obj, name, new_field_obj) + return obj + elif isinstance(data, list): + if len(data) == 0: + return data + d_type = typing.get_args(cls)[0] + return [_dict_to_dataclass(d_type, d) for d in data] + elif isinstance(data, dict): + v_type = typing.get_args(cls)[1] + return {k: _dict_to_dataclass(v_type, v) for k, v in data.items()} + elif cls == float: + return float(data) + return data + + +def deserialize( + artifact: SerializedArtifact, + expected_opset_version: Optional[dict[str, int]] = None, + *, + _unsafe_skip_version_check=False, +) -> ep.ExportedProgram: + assert isinstance(artifact.exported_program, bytes) + exported_program_str = artifact.exported_program.decode("utf-8") + exported_program_dict = json.loads(exported_program_str) + serialized_exported_program = _dict_to_dataclass( + ExportedProgram, exported_program_dict + ) + return ExportedProgramDeserializer(expected_opset_version).deserialize( + serialized_exported_program, + artifact.state_dict, + artifact.constants, + artifact.example_inputs, + _unsafe_skip_version_check=_unsafe_skip_version_check, + ) + + +def _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants +) -> tuple[Graph, dict[str, str]]: + def _get_argument(a: Argument): + if a.type == "as_none": + return None + elif a.type == "as_tensor": + return a.as_tensor + elif a.type == "as_tensors": + return a.as_tensors + elif a.type == "as_int": + return None + elif a.type == "as_ints": + return None + elif a.type == "as_float": + return None + elif a.type == "as_floats": + return None + elif a.type == "as_string": + return None + elif a.type == "as_strings": + return None + elif a.type == "as_sym_int": + return a.as_sym_int + elif a.type == "as_sym_ints": + return a.as_sym_ints + elif a.type == "as_sym_float": + return a.as_sym_float + elif a.type == "as_sym_floats": + return a.as_sym_floats + elif a.type == "as_scalar_type": + return None + elif a.type == "as_memory_format": + return None + elif a.type == "as_layout": + return None + elif a.type == "as_device": + return None + elif a.type == "as_bool": + return None + elif a.type == "as_bools": + return None + elif a.type == "as_sym_bool": + return a.as_sym_bool + elif a.type == "as_sym_bools": + return a.as_sym_bools + elif a.type == "as_graph": + return None + elif a.type == "as_optional_tensors": + return a.as_optional_tensors + elif a.type == "as_custom_obj": + return a.as_custom_obj + elif a.type == "as_operator": + return None + else: + raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") + + # Stage 1: Reorder named items. 
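+ # Nodes are topologically sorted with a priority queue keyed by the node
+ # target and the ranks of its inputs, so that graphs with the same
+ # structure come out in the same deterministic order regardless of their
+ # original node names.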
+ def for_args(f, a): + assert isinstance(a, Argument) + pytree.tree_map(f, _get_argument(a)) + + def sort_nodes(nodes): + @dataclass + class Edges: + outs: list[int] + ins: int + + graph_inputs: set[str] = set() + def_table: dict[str, int] = {} + edges: dict[int, Edges] = {} + candidates: list[tuple[str, list[tuple[str, list[int]]], int]] = [] + rank: dict[str, int] = {} + ret: list[Node] = [] + + def get_name(a) -> Optional[str]: + if a is None: + return None + if isinstance(a, TensorArgument): + return a.name + elif isinstance(a, (SymIntArgument, SymBoolArgument, SymFloatArgument)): + if a.type == "as_name": + return a.as_name + elif a.type in ("as_int", "as_bool", "as_float"): + return None + else: + raise AssertionError(f"Unknown argument type: {a}") + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + return a.as_tensor.name + elif a.type == "as_none": + return None + else: + raise AssertionError(f"Unknown optional tensor type: {a}") + elif isinstance(a, CustomObjArgument): + return a.name + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + + def add_input(a): + if s := get_name(a): + graph_inputs.add(s) + + for_args(add_input, i) + + for idx, node in enumerate(nodes): + + def add_def(a): + if s := get_name(a): + assert s not in def_table + def_table[s] = idx + + for o in node.outputs: + for_args(add_def, o) + + edges[idx] = Edges([], 0) + + for idx, user in enumerate(nodes): + + def add_edge(a): + if s := get_name(a): + if s in constants: + return + if s not in def_table: + assert s in graph_inputs + return + src = def_table[s] + edges[src].outs.append(idx) + edges[idx].ins += 1 + + for i in user.inputs: + for_args(add_edge, i.arg) + + def add_rank(a): + if s := get_name(a): + assert s not in rank + rank[s] = len(rank) + + def get_rank(a): + s = get_name(a) + if s and s not in constants: + return rank[s] + else: + return -1 + + for i in sorted_inputs: + for_args(add_rank, i) + + def add_candidate(idx: int): + def get_ranks(i): + ranks = [] + for_args(lambda x: ranks.append(get_rank(x)), i) + return ranks + + node = nodes[idx] + args_rank = [(a.name, get_ranks(a.arg)) for a in node.inputs] + heapq.heappush(candidates, (node.target, args_rank, idx)) + + for idx, e in edges.items(): + if e.ins == 0: + add_candidate(idx) + + while len(candidates) > 0: + _, _, idx = heapq.heappop(candidates) + node = nodes[idx] + for o in node.outputs: + for_args(add_rank, o) + ret.append(node) + assert idx in edges + for user in edges[idx].outs: + e = edges[user] + assert e.ins > 0 + e.ins -= 1 + if e.ins == 0: + add_candidate(user) + edges[idx].outs.clear() + + return ret + + sorted_nodes = sort_nodes(graph.nodes) + assert len(sorted_nodes) == len(graph.nodes) + + # Stage 2: Rename nodes. 
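+ # Values are renamed to sequential names (_0, _1, ...) in definition order
+ # (graph inputs first, then node outputs); replace_use then rewrites every
+ # reference through name_table, which is also returned so the caller can
+ # remap the graph signature.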
+ name_table: dict[str, str] = {} + + def rename_def(a): + def _rename(arg_name, values): + new_name = f"_{len(name_table)}" + assert arg_name not in name_table + name_table[arg_name] = new_name + assert arg_name in values + values[new_name] = values.pop(arg_name) + return new_name + + if a is None: + return + if isinstance(a, TensorArgument): + a.name = _rename(a.name, graph.tensor_values) + elif isinstance(a, SymIntArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_int_values) + elif isinstance(a, SymFloatArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_float_values) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = _rename(a.as_name, graph.sym_bool_values) + elif isinstance(a, CustomObjArgument): + a.name = _rename(a.name, graph.custom_obj_values) + else: + raise AssertionError(f"Unknown argument type: {a}") + + def replace_use(a): + if a is None: + return + if isinstance(a, TensorArgument): + a.name = name_table.get(a.name, a.name) + elif isinstance(a, (SymIntArgument, SymFloatArgument)): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, SymBoolArgument): + if a.type == "as_name": + a.as_name = name_table.get(a.as_name, a.as_name) + elif isinstance(a, OptionalTensorArgument): + if a.type == "as_tensor": + a.as_tensor.name = name_table.get(a.as_tensor.name, a.as_tensor.name) + elif isinstance(a, CustomObjArgument): + a.name = name_table.get(a.name, a.name) + else: + raise AssertionError(f"Unknown argument type: {a}") + + for i in sorted_inputs: + for_args(rename_def, i) + + for n in sorted_nodes: + for o in n.outputs: + for_args(rename_def, o) + + for n in sorted_nodes: + for i in n.inputs: + for_args(replace_use, i.arg) + + for o in sorted_outputs: + for_args(replace_use, o) + + # Stage 3: Remove unstable fields. + for n in sorted_nodes: + n.metadata.clear() + + # Stage 4: Aggregate values. + sorted_tensor_values = dict( + sorted(graph.tensor_values.items(), key=operator.itemgetter(0)) + ) + sorted_sym_int_values = dict( + sorted(graph.sym_int_values.items(), key=operator.itemgetter(0)) + ) + sorted_sym_float_values = dict( + sorted(graph.sym_float_values.items(), key=operator.itemgetter(0)) + ) + sorted_sym_bool_values = dict( + sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0)) + ) + sorted_custom_obj_values = dict( + sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0)) + ) + + # Stage 5: Recurse in subgraphs. + counter = 0 + for node in sorted_nodes: + for i in node.inputs: + a = i.arg + if a.type == "as_graph": + a.as_graph.graph, _ = _canonicalize_graph( + a.as_graph.graph.inputs, + a.as_graph.graph.outputs, + a.as_graph.graph, + constants, + ) + a.as_graph.name = f"_g{counter}" + counter += 1 + + graph = Graph( + inputs=sorted_inputs, + outputs=sorted_outputs, + nodes=sorted_nodes, + tensor_values=sorted_tensor_values, + sym_int_values=sorted_sym_int_values, + sym_float_values=sorted_sym_float_values, + sym_bool_values=sorted_sym_bool_values, + is_single_tensor_return=graph.is_single_tensor_return, + custom_obj_values=sorted_custom_obj_values, + ) + return graph, name_table + + +def canonicalize( + ep: ExportedProgram, constants: Optional[set[str]] = None +) -> ExportedProgram: + """ + Normalize a serialized ExportedProgram, so that different eager program which + shares the same semantics can get a single representation on disk. + + This function canonicalizes an ExportedProgram by: + + 1. 
Sorting nodes in topological order. + 2. Rename nodes to have unique names. + 3. Remove unstable fields. + 4. Aggregate the above program fields. + 5. Recurse in subgraphs. + + Args: + ep (ExportedProgram): The ExportedProgram to canonicalize. + constants (Optional[set[str]]): Set of constants names + + Returns: + ExportedProgram: The canonicalized exported program. + """ + ep = copy.deepcopy(ep) + constants: set[str] = constants or set() + + opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0))) + range_constraints = dict( + sorted(ep.range_constraints.items(), key=operator.itemgetter(0)) + ) + module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) + signature = ep.graph_module.signature + graph = ep.graph_module.graph + + assert len(graph.inputs) == len(signature.input_specs) + assert len(graph.outputs) == len(signature.output_specs) + + def rank_input(inp) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = inp + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + return 5, None, idx + elif spec.type == "parameter": + return 1, spec.parameter.parameter_name, idx + elif spec.type == "buffer": + return 2, spec.buffer.buffer_name, idx + elif spec.type == "tensor_constant": + return 3, spec.tensor_constant.tensor_constant_name, idx + elif spec.type == "custom_obj": + return 4, spec.custom_obj.custom_obj_name, idx + elif spec.type == "token": + return 0, None, idx + elif spec.type == "constant_input": + return 6, spec.constant_input.name, idx + else: + raise AssertionError(f"Unknown input type: {spec}") + + def rank_output(out) -> tuple[int, Optional[str], int]: + idx, (_arg, spec) = out + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + return 3, None, idx + elif spec.type == "loss_output": + return 3, None, idx + elif spec.type == "buffer_mutation": + return 1, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 4, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 5, None, idx + elif spec.type == "user_input_mutation": + return 2, None, idx + elif spec.type == "token": + return 0, None, idx + else: + raise AssertionError(f"Unknown output type: {spec}") + + sorted_ins = sorted( + enumerate(zip(graph.inputs, signature.input_specs)), key=rank_input + ) + + if len(sorted_ins) > 0: + sorted_inputs, input_specs = zip(*(i for idx, i in sorted_ins)) # type: ignore[assignment] + else: + sorted_inputs = () + input_specs = () + + sorted_outs = sorted( + enumerate(zip(graph.outputs, signature.output_specs)), key=rank_output + ) + sorted_outputs, output_specs = zip(*(i for idx, i in sorted_outs)) # type: ignore[assignment] + + sorted_graph, replace_table = _canonicalize_graph( + sorted_inputs, sorted_outputs, graph, constants + ) + + def replace_input(spec): + assert isinstance(spec, InputSpec) + if spec.type == "user_input": + arg = spec.user_input.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ( + "as_none", + "as_bool", + 
"as_int", + "as_float", + "as_string", + "as_custom_obj", + ): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "parameter": + t = spec.parameter.arg + t.name = replace_table[t.name] + elif spec.type == "buffer": + t = spec.buffer.arg + t.name = replace_table[t.name] + elif spec.type == "tensor_constant": + t = spec.tensor_constant.arg + t.name = replace_table[t.name] + elif spec.type == "custom_obj": + t_custom_obj = spec.custom_obj.arg + t_custom_obj.name = replace_table[t_custom_obj.name] + return + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + elif spec.type == "constant_input": + return + else: + raise AssertionError(f"Unknown input type: {spec}") + + def replace_output(out): + assert isinstance(spec, OutputSpec) + if spec.type == "user_output": + arg = spec.user_output.arg + if arg.type == "as_tensor": + t = arg.as_tensor + t.name = replace_table[t.name] + elif arg.type == "as_sym_int": + s = arg.as_sym_int + if s.type == "as_name": + s.as_name = replace_table[s.as_name] + elif s.type == "as_int": + pass + else: + raise AssertionError(f"Unknown sym_int type: {s}") + elif arg.type == "as_sym_float": + f = arg.as_sym_float + if f.type == "as_name": + f.as_name = replace_table[f.as_name] + elif f.type == "as_float": + pass + else: + raise AssertionError(f"Unknown sym_float type: {f}") + elif arg.type in ("as_none", "as_bool", "as_int", "as_float", "as_string"): + return + else: + raise AssertionError(f"Unknown input type: {arg}") + elif spec.type == "loss_output": + t = spec.loss_output.arg + t.name = replace_table[t.name] + elif spec.type == "buffer_mutation": + t = spec.buffer_mutation.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_parameter": + t = spec.gradient_to_parameter.arg + t.name = replace_table[t.name] + elif spec.type == "gradient_to_user_input": + g = spec.gradient_to_user_input + g.arg.name = replace_table[g.arg.name] + g.user_input_name = replace_table[g.user_input_name] + elif spec.type == "user_input_mutation": + u = spec.user_input_mutation + u.arg.name = replace_table[u.arg.name] + u.user_input_name = replace_table[u.user_input_name] + elif spec.type == "token": + tok = spec.token.arg + tok.name = replace_table[tok.name] + else: + raise AssertionError(f"Unknown output type: {spec}") + + for spec in input_specs: + replace_input(spec) + + for spec in output_specs: + replace_output(spec) + + return ExportedProgram( + graph_module=GraphModule( + graph=sorted_graph, + signature=GraphSignature( + input_specs=list(input_specs), + output_specs=list(output_specs), + ), + module_call_graph=module_call_graph, + ), + opset_version=opset_version, + range_constraints=range_constraints, + schema_version=ep.schema_version, + verifiers=ep.verifiers, + torch_version=ep.torch_version, + ) + + +class ExtensionHandler: + """ + Base class for handling extension operators. 
+ """ + + @classmethod + def namespace(cls) -> str: + raise NotImplementedError(f"{cls.__class__} namespace() must be implemented") + + @classmethod + def to_op_name(cls, op) -> str: + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def from_op_name(cls, name: str): + raise NotImplementedError(f"{cls.__class__} op_name() must be implemented") + + @classmethod + def op_schema(cls, op) -> torch.FunctionSchema: + raise NotImplementedError(f"{cls.__class__} op_schema() must be implemented") + + +def register_extension( + op_type: type[Any], + extension_handler: type[ExtensionHandler], +): + """Register custom de/serialization method for a node with non-standard type.""" + assert issubclass( + extension_handler, ExtensionHandler + ), f"Expected ExtensionHandler, got {extension_handler}." + assert op_type not in _serialization_registry, f"{op_type} is already registered." + assert isinstance(op_type, type) # Maybe a good idea to enforce this first. + assert not ( + op_type.__module__.startswith("torch") + or op_type.__module__.startswith("builtins") + ) + assert extension_handler.namespace() not in _deserialization_registry + _serialization_registry[op_type] = extension_handler + _deserialization_registry[extension_handler.namespace()] = extension_handler + + +def _registered_extension_types(): + return tuple(_serialization_registry.keys()) + + +# Registry to store all custom serialization implementations. +# The registry maps a operation to its serialization function (a callable), in their own +# namespace to avoid conflicts. +# Serialization: Op type --> custom handler. +# De-serialization: Namespace --> custom handler. +_serialization_registry: dict[type[Any], type[ExtensionHandler]] = {} +_deserialization_registry: dict[str, type[ExtensionHandler]] = {} diff --git a/phivenv/Lib/site-packages/torch/_export/serde/union.py b/phivenv/Lib/site-packages/torch/_export/serde/union.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4984d2f86e7720b47e2b3c4c729676a10bdd76 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/serde/union.py @@ -0,0 +1,71 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Hashable +from dataclasses import fields + + +class _UnionTag(str): + __slots__ = ("_cls",) + _cls: Hashable + + @staticmethod + def create(t, cls): + tag = _UnionTag(t) + assert not hasattr(tag, "_cls") + tag._cls = cls + return tag + + def __eq__(self, cmp) -> bool: + assert isinstance(cmp, str) + other = str(cmp) + assert other in _get_field_names( + self._cls + ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" + return str(self) == other + + def __hash__(self): + return hash(str(self)) + + +@functools.cache +def _get_field_names(cls) -> set[str]: + return {f.name for f in fields(cls)} + + +class _Union: + _type: _UnionTag + + @classmethod + def create(cls, **kwargs): + assert len(kwargs) == 1 + obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] + obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) + return obj + + def __post_init__(self): + assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] + + @property + def type(self) -> str: + try: + return self._type + except AttributeError as e: + raise RuntimeError( + f"Please use {type(self).__name__}.create to instantiate the union type." 
+ ) from e + + @property + def value(self): + return getattr(self, self.type) + + def __getattribute__(self, name): + attr = super().__getattribute__(name) + if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] + raise AttributeError(f"Field {name} is not set.") + return attr + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" diff --git a/phivenv/Lib/site-packages/torch/_export/tools.py b/phivenv/Lib/site-packages/torch/_export/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..0e20be2814d12ed3530fc593b67ddb6b9d44a0a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/tools.py @@ -0,0 +1,147 @@ +# mypy: allow-untyped-defs +import logging +import warnings +from collections.abc import Iterable +from typing import Any, Optional + +import torch +import torch.export +import torch.export._trace +from torch._utils_internal import log_export_usage + + +log = logging.getLogger(__name__) + +__all__ = ["report_exportability"] + + +def _generate_inputs_for_submodules( + model: torch.nn.Module, + target_submodules: Iterable[str], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, +) -> dict[str, tuple[Any, Any]]: + """ + Generate inputs for the target submodules in the given model. Note that if two submodule names refer to the same object, this + function doesn't work. + + Args: + model: root model. + target_submodules: submodules that we want to generate inputs for. + args: positional arguments to the root model. + kwargs: keyword arguments to the root model. + + Returns: + A dict that maps from submodule name to its inputs. + """ + kwargs = kwargs or {} + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_args, module_kwargs): + results[submodule_to_names[module]] = (module_args, module_kwargs) + + try: + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append( + mod.register_forward_pre_hook(pre_forward, with_kwargs=True) + ) + model(*args, **kwargs) + except Exception as e: + warnings.warn( + f"Failed to generate submodule inputs because of the following error:\n{e}" + ) + finally: + for h in handles: + h.remove() + return results + + +def report_exportability( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + strict: bool = True, + pre_dispatch: bool = False, +) -> dict[str, Optional[Exception]]: + """ + Report exportability issues for a module in one shot. + + Args: + mod: root module. + args: args to the root module. + kwargs: kwargs to the root module. + Returns: + A dict that maps from submodule name to the exception that was raised when trying to export it. + `None` means the module is exportable without issue.
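+
+     Example (the module and inputs below are purely illustrative):
+
+         report = report_exportability(MyModel(), (torch.randn(4, 8),))
+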
+ Sample output: + { + '': UnsupportedOperatorException(func=), + 'submod_1': UnsupportedOperatorException(func=), + 'submod_2': None + } + """ + + log_export_usage(event="export.report_exportability") + + kwargs = kwargs or {} + + all_submod_names = [name for name, _ in mod.named_modules() if name != ""] + submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) + + tried_module_types = set() + report: dict[str, Optional[Exception]] = {} + + def try_export(module, module_name, args, kwargs): + nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types + + if type(module) in tried_module_types: + return + tried_module_types.add(type(module)) + + if args is not None or kwargs is not None: + try: + torch.export._trace._export( + module, + args, + kwargs, + strict=strict, + pre_dispatch=pre_dispatch, + ) + report[module_name] = None + log.info("Successfully exported `%s`", module_name) + return + except Exception as e: + short_msg = repr(e).split("\n")[0] + log.warning( + "Failed exporting `%s` with exception: %s", module_name, short_msg + ) + report[module_name] = e + + for name, submod in module.named_children(): + sub_module_name = name if module_name == "" else f"{module_name}.{name}" + + submod_args, submod_kwargs = submod_inputs.get( + sub_module_name, (None, None) + ) + + try_export(submod, sub_module_name, submod_args, submod_kwargs) + + return + + try_export(mod, "", args, kwargs) + + unique_issues = set() + for exception in report.values(): + if exception is not None: + key = repr(exception).split("\\n")[0] + unique_issues.add(key) + + log.warning("Found %d export issues:", len(unique_issues)) + for issue in unique_issues: + log.warning(issue) + + return report diff --git a/phivenv/Lib/site-packages/torch/_export/utils.py b/phivenv/Lib/site-packages/torch/_export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc3e3ab9fc3c582be52481c75e7e6cb968051405 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/utils.py @@ -0,0 +1,1510 @@ +# mypy: allow-untyped-defs +import ast +import copy +import dataclasses +import functools +import inspect +import json +import math +import operator +import re +from collections.abc import Iterable +from contextlib import contextmanager +from inspect import ismethod, Parameter +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + +if TYPE_CHECKING: + from torch._export.passes.lift_constants_pass import ConstantAttrMap + from torch._ops import OperatorBase + from torch.export import ExportedProgram + from torch.export.graph_signature import ExportGraphSignature + +from torch.export.graph_signature import CustomObjArgument, InputKind, OutputKind +from torch.fx._pytree import ( + _deregister_pytree_flatten_spec, + register_pytree_flatten_spec, +) +from torch.utils._pytree import ( + _deregister_pytree_node, + _register_pytree_node, + Context, + FlattenFunc, + FromDumpableContextFn, + GetAttrKey, + KeyPath, + keystr, + MappingKey, + SequenceKey, + ToDumpableContextFn, + tree_flatten_with_path, + UnflattenFunc, +) + + +placeholder_prefixes = { + InputKind.USER_INPUT: "", + InputKind.PARAMETER: "p_", + InputKind.BUFFER: "b_", + 
InputKind.CONSTANT_TENSOR: "c_", + InputKind.CUSTOM_OBJ: "obj_", + InputKind.TOKEN: "token", +} + +_DISABLE_ATEN_TO_ASSERTION_PASS = False + + +def _collect_and_set_constant_attrs( + graph_signature, constants, mod +) -> "ConstantAttrMap": + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. This is intended to only be used + # in run_decompositions where we still have access to original EP. + from torch._export.passes.lift_constants_pass import ConstantAttrMap + + constant_attrs = ConstantAttrMap() + non_persistent_buffers = { + spec.target + for spec in graph_signature.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + for name, value in constants.items(): + if name in non_persistent_buffers: + continue + # recursive getattr + _mod = mod + *atoms, attr = name.split(".") + for atom in atoms: + _mod = getattr(_mod, atom) + # remove as buffer, reassign as constant/non-persistent buffer + _mod._buffers.pop(attr, None) + setattr(_mod, attr, value) + constant_attrs.add(value, name) + return constant_attrs + + +def _register_constants_as_buffers( + mod: torch.fx.GraphModule, state_dict, non_persistent_buffers +): + # TODO some annoying circular dependency issue + from torch.export.unflatten import _assign_attr, _AttrKind + + temp_registered_constants = set() + + for node in mod.graph.nodes: + if node.op == "get_attr": + target = torch.fx.graph_module._get_attr(mod, node.target) + if isinstance(target, torch.Tensor): + # Make sure we also check if the original buffer is + # non persistent as well. + if (node.target not in state_dict) and ( + node.target not in non_persistent_buffers + ): + torch.fx.graph_module._del_attr(mod, node.target) + _assign_attr(target, mod, node.target, _AttrKind.BUFFER, False) + temp_registered_constants.add(node.target) + + mod.recompile() + + return temp_registered_constants + + +def _override_graph_signature_for_temp_registered_constants( + sig: "ExportGraphSignature", temp_registered_constants +): + for spec in sig.input_specs: + if spec.target in temp_registered_constants: + spec.kind = InputKind.CONSTANT_TENSOR + spec.persistent = None + + for spec in sig.output_specs: + if ( + spec.kind == OutputKind.BUFFER_MUTATION + and spec.target in temp_registered_constants + ): + raise RuntimeError( + f"Constant {spec.target} is mutated in the forward method. Pls register it as buffer" + ) + + return sig + + +def _overwrite_signature_for_non_persistent_buffers( + old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" +): + # overwrite signature for non-persistent buffers + non_persistent_buffers = { + spec.target + for spec in old_sig.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + + for spec in new_sig.input_specs: + if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: + spec.persistent = False + return new_sig + + +def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> dict[str, Any]: + """ + Param/buffer metadata needs to be saved before lowering to aten IR + because aten IR lifts them, as a result, automatic preservation doesn't work. + This is intended to be called on the strict mode tracing right before lowering to + aten IR OR run_decomposition pass. 
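+     The returned dict maps dotted parameter/buffer names to the fx node
+     metadata recorded for them (for illustration, something like
+     {"linear.weight": {"nn_module_stack": ..., "source_fn_stack": ...}});
+     `_populate_param_buffer_metadata_to_new_gm` later copies this metadata
+     back onto the placeholders of the retraced graph.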
+ """ + params_buffers_to_node_meta = {} + + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + for node in mod.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = _getattr(mod, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = _getattr(mod, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. + # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + # the custom field should not be copied + if entry == "custom": + continue + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + return params_buffers_to_node_meta + + +def _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta: dict[str, Any], + gm: torch.fx.GraphModule, + new_sig: "ExportGraphSignature", +) -> None: + """ + Given that we collected param'buffer metadata before, we put them back in + newly traced graph module + """ + # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes + for metadata in params_buffers_to_node_meta.values(): + metadata.pop("nn_module_stack", None) + metadata.pop("stack_trace", None) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in new_sig.inputs_to_parameters: + param_name = new_sig.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in new_sig.inputs_to_buffers: + buffer_name = new_sig.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + +def _get_shape_env_from_gm(gm: torch.fx.GraphModule): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + + fake_mode = _detect_fake_mode_from_gm(gm) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _rename_without_collisions( + name_map: dict[str, str], + orig_name: str, + name: str, + is_placeholder: bool = False, +): + """ + Renames nodes to avoid name collisions, with suffixing. + name_map: map from original name to new name + orig_name: mapping key + name: candidate name (potentially suffixed, e.g. 
mul_2) + is_placeholder: if the node is a placeholder, avoid detecting suffix + """ + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name + return name_map[orig_name] + + +def get_keystr(key_path: KeyPath) -> str: + """For a given index into the flat_args, return a human readable string + describing how to access it, e.g. "*args["foo"][0].bar" + """ + # Prefix the keypath with "*args" or "**kwargs" to make it clearer where + # the arguments come from. Ultimately we ought to serialize the + # original arg names for the best error message here. + args_kwargs_key_path = key_path[0] + assert isinstance(args_kwargs_key_path, SequenceKey) + if args_kwargs_key_path.idx == 0: + return f"*args{keystr(key_path[1:])}" + else: + kwarg_key = key_path[1] + assert isinstance(kwarg_key, MappingKey) + name = str(kwarg_key)[1:-1] # get rid of the enclosed [] + return f"{name}{keystr(key_path[2:])}" + + +def _check_symint( + symint: Union[int, torch.SymInt], + arg: int, + range_constraints, + unification_map, + keypath: KeyPath, + i: Optional[int] = None, +) -> None: + from torch.export.dynamic_shapes import _IntWrapper + + if ( + isinstance(arg, torch.SymInt) + and not arg.node.expr.is_number + or isinstance(arg, _IntWrapper) + ): + # This can happen when, say, arg is a fake tensor. + # We do not run checks on symbolic shapes of fake inputs as + # such checks can affect the shape env. + return + + import sympy + + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, + ) + from torch.utils._sympy.solve import try_solve + + if isinstance(symint, torch.SymInt) and len(symint.node.expr.free_symbols) == 1: + symbol = next(iter(symint.node.expr.free_symbols)) + if symbol in unification_map: + existing_dim = symint.node.expr.subs(unification_map) + if arg != existing_dim: + path = get_keystr(keypath) + if i is not None: + path += f".shape[{i}]" + raise RuntimeError( + f"Expected input at {path} to be equal to {existing_dim}, but got {arg}", + ) + else: + if isinstance(symint.node.expr, sympy.Symbol): + # Short cut for try_solve below. Also useful in cases where + # sympy.Eq(symint.node.expr, arg) would evaluate to False + # purely because symbol is constrained to be size-like, + # e.g., when symint.node.expr = symbol and arg = 0. 
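+                # e.g. with expr = s0 and a runtime arg of 4, this records
+                # unification_map[s0] = 4, so any later input dimension annotated
+                # with the same s0 must also equal 4 (checked by the
+                # `symbol in unification_map` branch above).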
+ unification_map[symbol] = int(arg) + else: + solution = try_solve(sympy.Eq(symint.node.expr, arg), symbol) + if solution is None: + path = get_keystr(keypath) + if i is not None: + path += f".shape[{i}]" + raise RuntimeError( # noqa: B904 + f"Expected input {path} = {arg} to be " + f"of the form {symint.node.expr}, where {symbol} is an integer" + ) + else: + unification_map[symbol] = int(solution[1]) + + if symint.node.expr in range_constraints: + min_val, max_val = _convert_range_to_int( + range_constraints[symint.node.expr] + ) + # NOTE: we allow dimensions to be 0/1 at runtime + if min_val > 2: + if arg < min_val: + path = get_keystr(keypath) + if i is not None: + path += f".shape[{i}]" + raise RuntimeError( + f"Expected input at {path} to be >= {min_val}, but got {arg}", + ) + if max_val < math.inf: + if arg > max_val: + path = get_keystr(keypath) + if i is not None: + path += f".shape[{i}]" + raise RuntimeError( + f"Expected input at {path} to be <= {max_val}, but got {arg}", + ) + elif isinstance(symint, torch.SymInt) and not symint.node.expr.is_number: + # this means we deferred a guard from export analysis to runtime, let this pass + # we'll add a runtime assert checking equality to this replacement expression + pass + elif arg != symint: + path = get_keystr(keypath) + if i is not None: + path += f".shape[{i}]" + raise RuntimeError( + f"Expected input at {path} to be equal to {symint}, but got {arg}. " + "If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes " + "(e.g. with Dim.DYNAMIC)" + ) + + +def _check_input_constraints_for_graph( + input_placeholders: list[torch.fx.Node], flat_args_with_path, range_constraints +) -> None: + import sympy # noqa: TC002 + + if len(flat_args_with_path) != len(input_placeholders): + raise RuntimeError( + "Unexpected number of inputs " + f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" + ) + # NOTE: export already guarantees that the same symbol is used in metadata + # for all InputDims related by equality constraints, so we can just unify + # symbols with given input dimension values to check equality constraints. 
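+    # e.g. for two inputs exported with shapes (s0, 3) and (s0,), s0 unifies to
+    # the first runtime value seen, so passing tensors of shape (4, 3) and (5,)
+    # fails below because 5 does not match the recorded s0 = 4.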
+ unification_map: dict[sympy.Symbol, Any] = {} + for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): + node_val = node.meta.get("val") + if isinstance(node_val, FakeTensor): + if not isinstance(arg, torch.Tensor): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", + ) + + if len(node_val.shape) != len(arg.shape): + raise RuntimeError( + f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " + f"(expected {node_val.shape}, got {arg.shape})" + ) + + for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): + _check_symint( + node_dim, arg_dim, range_constraints, unification_map, key_path, j + ) + + elif isinstance(node_val, (int, float, str)): + if type(arg) != type(node_val) or arg != node_val: + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + elif isinstance(node_val, torch.SymInt): + _check_symint( + node_val, arg, range_constraints, unification_map, key_path, None + ) + + +def register_dataclass_as_pytree_node( + cls: type[Any], + flatten_fn: Optional[FlattenFunc] = None, + unflatten_fn: Optional[UnflattenFunc] = None, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + return_none_fields: bool = False, +) -> None: + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" + + def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for f in dataclasses.fields(obj): + name, val = f.name, getattr(obj, f.name) + if val is not None or return_none_fields: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, (flat_names, _none_names) = flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn + unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." + ) + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=default_flatten_fn_with_keys, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a parameter within the exported program + """ + + return node.name in program.graph_signature.inputs_to_parameters + + +def get_param( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.nn.Parameter]: + """ + Returns the parameter associated with the given node in the exported program. 
+ Returns None if the node is not a parameter within the exported program + """ + + if is_param(program, node): + parameter_name = program.graph_signature.inputs_to_parameters[node.name] + return program.state_dict[parameter_name] + + return None + + +def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: + """ + Checks if the given node is a buffer within the exported program + """ + + return node.name in program.graph_signature.inputs_to_buffers + + +def get_buffer( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the buffer associated with the given node in the exported program. + Returns None if the node is not a buffer within the exported program + """ + + if is_buffer(program, node): + buffer_name = program.graph_signature.inputs_to_buffers[node.name] + if buffer_name in program.graph_signature.non_persistent_buffers: + return program.constants[buffer_name] + else: + return program.state_dict[buffer_name] + + return None + + +def is_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> bool: + """ + Checks if the given node is a lifted tensor constant within the exported program + """ + + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def get_lifted_tensor_constant( + program: "ExportedProgram", + node: torch.fx.Node, +) -> Optional[torch.Tensor]: + """ + Returns the lifted tensor constant associated with the given node in the exported program. + Returns None if the node is not a lifted tensor constant within the exported program + """ + + if is_lifted_tensor_constant(program, node): + lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ + node.name + ] + return program.constants[lifted_tensor_name] + + return None + + +def sequential_split( + gm: torch.fx.GraphModule, + node_call_back: Callable[[torch.fx.Node], Union[torch.fx.Node, bool]], +) -> torch.fx.GraphModule: + """ + sequential_split creates a new graph module that splits the input graph module into multiple submodules + based on the node_call_back. It doesn't mutate the input graph module. The node_call_back should return + True if the node is a delimiter. Delimiter will be the first node in the next submodule. + """ + from torch.fx.passes.split_module import split_module + + split_map = {} + split_id = 0 + for node in gm.graph.nodes: + if node_call_back(node): + split_id += 1 + split_map[node] = split_id + + new_gm = split_module( + gm, + gm, + lambda node: split_map[node], + keep_original_order=True, + keep_original_node_name=True, + ) + # Keep the codegen from original graph module to preserve e.g. pytree info. 
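+    # Usage sketch (illustrative): sequential_split(gm, lambda n: n.target is
+    # torch.ops.aten._assert_tensor_metadata.default) starts a new submod_<i>
+    # at every such delimiter node while leaving gm itself untouched.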
+ new_gm.graph._codegen = gm.graph._codegen + new_gm.recompile() + return new_gm + + +def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]: + """Returns the nodes that match the node_call_back as a list.""" + return [node for node in nodes if node_call_back(node)] + + +@contextmanager +def _disable_aten_to_metadata_assertions(): + global _DISABLE_ATEN_TO_ASSERTION_PASS + orig_val = _DISABLE_ATEN_TO_ASSERTION_PASS + _DISABLE_ATEN_TO_ASSERTION_PASS = True + try: + yield + finally: + _DISABLE_ATEN_TO_ASSERTION_PASS = orig_val + + +def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if _DISABLE_ATEN_TO_ASSERTION_PASS: + return + + aten_to_variants = [ + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + torch.ops.aten.to.dtype_layout, + ] + for node in gm.graph.nodes: + if node.target in aten_to_variants: + if ( + node.prev.target == torch.ops.aten._assert_tensor_metadata.default + and node.args[0] == node.prev.args[0] + ): + # skip if already guarded + continue + + if (tensor_val := node.args[0].meta.get("val")) is not None: + with gm.graph.inserting_before(node), _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + ), + ): + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(node.args[0],), + kwargs={ + "dtype": tensor_val.dtype, + "device": tensor_val.device, + "layout": tensor_val.layout, + }, + ) + + +def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + shape_env = _get_shape_env_from_gm(gm) + if shape_env: + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # insert runtime assertions for aten.to nodes + _insert_aten_to_metadata_assert_pass(gm) + + # update output specs + gm.recompile() + graph_signature.user_outputs = _graph_output_names(gm) + return gm, graph_signature + + +def nodes_first( + nodes: list[torch.fx.Node], node_call_back=None +) -> Optional[torch.fx.Node]: + """ + Returns the first node that matches the node_call_back. If no node matches, returns None. + When node_call_back is None, returns the first node in the node list. + """ + ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) + if len(ret) > 0: + return ret[0] + return None + + +def nodes_count(nodes: list[torch.fx.Node], node_call_back) -> int: + """Returns the number of nodes that match the node_call_back.""" + return len(nodes_filter(nodes, node_call_back)) + + +def nodes_map(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.Node]: + """ + Sequentially visit the nodes list and invoke node_call_back on each element. + Returns the nodes list after the node_call_back is invoked on each element. 
+ """ + for node in nodes: + node_call_back(node) + return nodes + + +def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: + """ + Replace all uses of old_node with new_node. + """ + old_node.replace_all_uses_with(new_node) + old_node.users.clear() + old_node.graph.erase_node(old_node) + + +def _update_gm_meta_if_possible(gm: torch.fx.GraphModule, mod: torch.nn.Module) -> None: + if ( + isinstance(mod, torch.fx.GraphModule) + and hasattr(mod, "meta") + and "custom" in mod.meta + ): + gm.meta.update({"custom": mod.meta["custom"]}) + + +def node_inline_(call_mod_node: torch.fx.Node) -> Optional[torch.fx.GraphModule]: + """ + Inline the submodule of the given node into the parent module. + Note: we only support the case where submodule takes tensors inputs. + """ + assert call_mod_node.op == "call_module" + gm = call_mod_node.graph.owning_module + assert gm is not None + + assert isinstance(call_mod_node.target, str) + sub_gm = getattr(gm, call_mod_node.target) + + phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") + body = ( + node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") + ) + output = [node for node in sub_gm.graph.nodes if node.op == "output"] + + for ph, arg in zip(phs, call_mod_node.args): + assert isinstance(arg, torch.fx.Node) + node_replace_(ph, arg) + + with gm.graph.inserting_before(call_mod_node): + for node in body: + new_node = gm.graph.node_copy(node) + if node.op == "get_attr": + new_target_name = new_node.target + if hasattr(gm, new_target_name): + # Loop through and find the "submod_{i}" that have no name collision + i = 1 + new_target_name = f"submod_{i}" + while hasattr(gm, new_target_name): + i += 1 + new_target_name = f"submod_{i}" + new_node.target = new_target_name + setattr(gm, new_node.target, getattr(sub_gm, node.target)) + node_replace_(node, new_node) + + if len(output) > 0: + assert len(output) == 1 and len(output[0].args) == 1 + new_output = output[0].args[0] + + if isinstance(new_output, torch.fx.Node): + # Clear the users of the output node and set + # the users to be the users of original call_module node. + new_output.users.clear() + node_replace_(call_mod_node, new_output) + elif isinstance(new_output, (list, tuple)): + # Pop subgraph output node from users. + for node in new_output: + node.users.pop(output[0]) + + # Inline the get_item calls for the output node. + get_item_users = nodes_filter( + list(call_mod_node.users.keys()), + lambda node: node.op == "call_function" + and node.target == operator.getitem, + ) + # get_item_node.args[1] is the idx referring to new_output[idx] + nodes_map( + get_item_users, + lambda get_item_node: node_replace_( + get_item_node, + new_output[get_item_node.args[1]], + ), + ) + call_mod_node.graph.erase_node(call_mod_node) + else: + raise NotImplementedError( + f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." + ) + else: + call_mod_node.graph.erase_node(call_mod_node) + + gm.delete_all_unused_submodules() + gm.recompile() + return gm + + +def _get_torch_jit_trace_forward_signature(mod: torch.nn.Module) -> inspect.Signature: + """ + Get source code and parse argument names using AST. The function returns + a signature of the forward() function. + + # TODO: Directly provide inspect.signature compatible TS-d module. 
+ """ + ast_mod = ast.parse(mod.code) # type: ignore[call-overload] + ast_func_def: ast.FunctionDef = ast_mod.body[0] + + # FIXME(jiashenc): TorchScript should only allow positional or keywords arguments. + arg_type_map = {"args": Parameter.POSITIONAL_OR_KEYWORD} + + # Traverse all argument types in AST tree and create associated parameters. + param_list = [] + for arg_type, param_type in arg_type_map.items(): + arg_name_list = [a.arg for a in getattr(ast_func_def.args, arg_type)] + for arg_name in arg_name_list: + if arg_name == "self": + continue # Skip self argument. + param_list.append(inspect.Parameter(arg_name, param_type)) + + return inspect.Signature(parameters=param_list) + + +def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): + if isinstance(mod, (torch.jit.ScriptModule, torch.jit.TracedModule)): + sig = _get_torch_jit_trace_forward_signature(mod) + + # Sanity check for placeholder names coming from TorchScript. + assert len(sig.parameters) == len(fake_args) + len(fake_kwargs), ( + "Arguments other than POSITIONAL_OR_KEYWORD kinds in forward() " + "are not supported in _get_torch_jit_trace_forward_signature" + ) + else: + sig = inspect.signature(mod.forward) + + # Rather than binding both fake_args and fake_kwargs to sig names, we + # (partially) bind only fake_args, while reusing fake_kwarg names. This + # ensures that fake_kwargs do not get reordered, which is important to + # match flattened user inputs. + return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs} + + +def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + """ + Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, + and handle collisions with non-placeholders by count suffixing. + Different HOO subgraph types have different input schemas, so we first enumerate them + and gather the top-level named placeholder nodes. 
+ """ + # gather all HOO subgraphs and their top-level named placeholder nodes + subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = [] + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + # HOO subgraphs have varying input schemas, so we enumerate them there + if node.target._name == "cond": + _, true_graph, false_graph, cond_args = node._args + subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) + subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) + elif node.target._name == "wrap_with_set_grad_enabled": + subgraph, phs = node._args[1], node._args[2:] + subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) + elif node.target._name == "map_impl": + body_graph, array, args = node._args + subgraph_ph_tuples.append( + (getattr(gm, body_graph.target), array + args) + ) + + # propagate names + for subgraph, hoo_phs in subgraph_ph_tuples: + name_map: dict[str, str] = {} + for i, node in enumerate(subgraph.graph.nodes): + if i < len(hoo_phs): # placeholder, retain name + name_map[node.name] = hoo_phs[i].name + node.name = node.target = hoo_phs[i].name + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # recurse and recompile + _name_hoo_subgraph_placeholders(subgraph) + subgraph.recompile() + + +def placeholder_naming_pass( + gm: torch.fx.GraphModule, + export_graph_signature: "ExportGraphSignature", + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constants: dict[str, Any], +) -> None: + """ + This pass is run at the end of _export_non_strict() to assign better placeholder node names: + - User inputs: + These follow the signature of mod.forward(), e.g. forward(x, y) produces nodes x, y. + For nested inputs from dictionaries, lists, tuples, or dataclasses, + the names are a concatenation of the path to the tensor. + e.g. x = { + 'a': torch.randn(), + 'b': [torch.randn(), torch.randn()] + } + produces nodes x_a, x_b_0, x_b_1. + - Parameters/buffers/constants/custom objects: + These follow the FQN of the object, prefixed by "p", "b", "c", "obj" respectively. + e.g. self.bar.l0.weight produces "p_bar_l0_weight". + - Effect tokens: + These are named token, token_1, ... 
+ """ + + custom_meta: dict[str, Any] = {} + if isinstance(mod, torch.fx.GraphModule): + for node in mod.graph.nodes: + if "custom" in node.meta: + custom_meta[node.name] = node.meta["custom"] + + def _strip_name(x): + if x.startswith("L__self___"): + x = x[len("L__self___") :] + elif x.startswith("self_"): + x = x[len("self_") :] + x = re.sub(r"[^a-zA-Z0-9]", "_", x) + return x + + def _extract_pytree_key(x): + if isinstance(x, MappingKey): + x = re.sub(r"[^a-zA-Z0-9]", "_", str(x.key)) + return x + elif isinstance(x, SequenceKey): + return str(x.idx) + elif isinstance(x, GetAttrKey): + return x.name + else: + raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") + + name_map: dict[str, str] = {} + + # map user input names with mod.forward() signature + combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) + + flat_args_with_path, _ = tree_flatten_with_path(combined_args) + user_input_names = [ + spec.arg.name + for spec in export_graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + + # use pytree path to name nested user inputs + for (arg_path, _arg), user_input_name in zip(flat_args_with_path, user_input_names): + if user_input_name: + _rename_without_collisions( + name_map, + user_input_name, + placeholder_prefixes[InputKind.USER_INPUT] + + "_".join(_extract_pytree_key(x).lower() for x in arg_path), + is_placeholder=True, + ) + + # use graph signature input specs to map param/buffer/constant names + # name effect tokens as token, token_1, ... (these aren't visible to user) + for spec in export_graph_signature.input_specs: + if spec.kind == InputKind.USER_INPUT: + continue + if spec.kind == InputKind.TOKEN: + base_name = "" + else: + base_name = _strip_name(spec.target).lower() + base_name = re.sub(r"[^a-zA-Z0-9]", "_", base_name) + + _rename_without_collisions( + name_map, + spec.arg.name, + placeholder_prefixes[spec.kind] + base_name, + is_placeholder=True, + ) + if base_name in custom_meta: + # the keys in custom_meta are node names from `mod`, + # which is the base_name here. + # we need the re-mapped name for lookup later + custom_meta[name_map[spec.arg.name]] = custom_meta[base_name] + del custom_meta[base_name] + + # handle naming collisions with call_function/get_attr inputs. + # here, we want to prioritize user input names over call_function names + # e.g. 
not have forward(self, mul): lead to a placeholder node called mul_13, + # so we increment the suffix of call_function nodes as needed + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + _rename_without_collisions(name_map, node.name, node.name) + + # assign new node names + for node in gm.graph.nodes: + if node.op == "placeholder": + assert node.name in name_map + node.name = node.target = name_map[node.name] + if node.name in custom_meta: + if node.meta.get("custom") is None: + node.meta["custom"] = custom_meta[node.name] + else: + assert node.meta["custom"] == custom_meta[node.name] + # if the constant obj is an input, we also need to update meta["val"] + # because this is created before the placeholder naming pass + if isinstance(node.meta["val"], CustomObjArgument): + node.meta["val"].name = node.name + elif node.name in name_map: + node.name = name_map[node.name] + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # re-generate graph module code + gm.recompile() + + # modify graph signature (input specs, output specs, user input mutations) + for spec in export_graph_signature.input_specs: + assert spec.arg.name in name_map + spec.arg.name = name_map[spec.arg.name] + if ( # handle targets for custom objects + spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map + ): + spec.target = name_map[spec.target][4:] # strip obj_ prefix + + for spec in export_graph_signature.output_specs: + if spec.arg.name in name_map: + spec.arg.name = name_map[spec.arg.name] + if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map: + spec.target = name_map[spec.target] + + # rename keys in constants dict for custom objects + for name in list(constants.keys()): + constant = constants[name] + if name in name_map and not isinstance( + constant, torch.Tensor + ): # rename custom objects with generic names + new_name = name_map[name] + if ( + new_name != name + and re.match(r"arg(\d+)_1", name) + and new_name != placeholder_prefixes[InputKind.CUSTOM_OBJ] + name + ): + constants[new_name] = constant + del constants[name] + + +def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict: + """ + If `in_place` is false, return a new copy of `state_dict` with "proxy" removed from `v.__dict__`. + `v` is the values in the dictionary. + If `in_place` is true, modify `state_dict` in place. + """ + if in_place: + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + delattr(state_dict[k], "proxy") + return state_dict + else: + new_state_dict = {} + for k, v in state_dict.items(): + if hasattr(v, "proxy"): + new_state_dict[k] = v.detach().clone() + else: + new_state_dict[k] = v + return new_state_dict + + +def _detect_fake_mode_from_gm( + gm: torch.fx.GraphModule, +) -> torch._subclasses.fake_tensor.FakeTensorMode: + """ + For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. + Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. + If no fake mode is found, we return None for fake_mode. 
+ """ + + fake_inps: list[torch.Tensor] = [] + fake_vals: list[torch.Tensor] = [] + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_inps.append(fake_val) + elif len(fake_inps) == 0 and ( + "example_value" in node.meta or "val" in node.meta + ): + fake_val = None + if "example_value" in node.meta: + fake_val = node.meta["example_value"] + elif "val" in node.meta: + fake_val = node.meta["val"] + if fake_val is not None and isinstance(fake_val, torch.Tensor): + fake_vals.append(fake_val) + + return detect_fake_mode(fake_inps + fake_vals) + + +@contextmanager +def _disable_load_state_dict_hooks(mod: torch.nn.Module): + state_dict_hooks: dict[int, Callable] = dict(mod._state_dict_hooks) + state_dict_pre_hooks: dict[int, Callable] = dict(mod._state_dict_pre_hooks) + mod._state_dict_hooks.clear() + mod._state_dict_pre_hooks.clear() + try: + yield + finally: + mod._state_dict_hooks = state_dict_hooks + mod._state_dict_pre_hooks = state_dict_pre_hooks + + +def _is_cia_op(op: "OperatorBase") -> bool: + return ( + torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ) + or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels + ) + + +def _is_preservable_cia_op(op: "OperatorBase") -> bool: + return _check_valid_to_preserve(op) and _is_cia_op(op) + + +def _is_aten_op(op: "OperatorBase") -> bool: + return op.name().split("::")[0] == "aten" + + +def _is_custom_op(op: "OperatorBase") -> bool: + return not _is_aten_op(op) + + +# We can't cache this because custom op registry API in python can still +# add entries to the C++ dispatcher. +def _materialize_cpp_cia_ops() -> None: + """ + Utility function to query C++ dispatcher to get the all + possible CIA ops and populate them into torch.ops namespace + """ + cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( + "CompositeImplicitAutograd" + ) + + # Materialize all CIA ops + for op in cia_ops: + namespace, op_name = tuple(op.split("::")) + split_list = op_name.split(".") + # Sometime overload could be missing + assert len(split_list) == 1 or len(split_list) == 2 + op_name = split_list[0] + op_overload_name = "default" + if len(split_list) == 2: + op_overload_name = split_list[1] + + _ = getattr(getattr(getattr(torch.ops, namespace), op_name), op_overload_name) + + +def _special_op_to_preserve_cia(*args, **kwargs): + """ + This is an special marker that tells our infra that we shouldn't decompose this op. + """ + return NotImplemented + + +# Our strategy for deciding if we can preserve a op is following: +# 1. The op should be known statically that it is functional +# 2. If it is maybe aliasing, we decompose because we must know if an op +# is mutating or aliasing. 
+def _check_valid_to_preserve(op_overload: "OperatorBase"): + from torch._decomp import _should_decompose_because_unsafe_op + + if _should_decompose_because_unsafe_op(op_overload): + return False + if op_overload in FunctionalTensor.metadata_fns: + return False + + if not hasattr(op_overload, "_schema"): + return False + + alias_info = len( + [i for i in op_overload._schema.arguments if i.alias_info is not None] + ) + + is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable + + if is_mutating_or_aliasing: + return False + + if not torch._C._dispatch_has_kernel(op_overload.name()): + return False + + return True + + +@functools.lru_cache(maxsize=1) +def _collect_all_valid_cia_ops_for_aten_namespace() -> set["OperatorBase"]: + return _collect_all_valid_cia_ops_for_namespace(torch.ops.aten) + + +def _collect_all_valid_cia_ops_for_namespace( + op_namespace: torch._ops._OpNamespace, +) -> set["OperatorBase"]: + # Step 1: Materialize all ops from C++ dispatcher + _materialize_cpp_cia_ops() + + # Step 2: Query all ops from python dispatcher + cia_ops = set() + for op in op_namespace: + op_packet = getattr(op_namespace, op) + for overload in op_packet.overloads(): + op_overload = getattr(op_packet, overload) + if _is_preservable_cia_op(op_overload): + cia_ops.add(op_overload) + return cia_ops + + +def _collect_all_valid_cia_ops() -> set["OperatorBase"]: + """ + This is an util function that gets the all CIA functional ops. + + The algorithm is in 2 steps: + 1. We first query C++ dispatcher to get the list of CIA ops + and then we call getattr on torch.ops.aten to lazily populate + them. + + 2. Sometimes, handful of ops have CIA registered in python dispatcher + but not on the C++ side, these can't be caught at the first step. + So we walk again to get the final list. + + Note that the output of this function should never be modified + """ + cia_ops = set() + for op_namespace_name in torch.ops._dir: + # The reason we split here is because aten ops are safe to cache. + if op_namespace_name != "aten": + assert hasattr(torch.ops, op_namespace_name) + op_namespace = getattr(torch.ops, op_namespace_name) + if isinstance(op_namespace, torch._ops._OpNamespace): + cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace) + else: + cia_ops |= _collect_all_valid_cia_ops_for_aten_namespace() + return cia_ops + + +def _get_decomp_for_cia(op: "OperatorBase"): + # [NOTE] Seperating out func.decompose + # Ideally we should be able to just register func.decompose but + # we can't as this decomp is gonna be registered to the py_impl. + # As a result it will infinitely recurse. So we first check if the op + # has py_impl entry for CIA and if it is we use that first. If not, + # we register C++ query to py_impl. 
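+    # i.e. a py_impl(CompositeImplicitAutograd) override registered in Python
+    # takes precedence; otherwise the partial returned below redispatches to the
+    # C++ CIA kernel through kernel._op_dk (asserting such a kernel exists).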
+ dk = torch._C.DispatchKey.CompositeImplicitAutograd + if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): + return op.py_kernels[dk] + + def _special_op_to_decompose_cia(*args, **kwargs): + kernel = kwargs["kernel"] + del kwargs["kernel"] + # Can't call kernel.decompose due to infinite recursion as + # we register this kernel to py_impl directly + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if torch._C._dispatch_has_kernel_for_dispatch_key( + kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + return kernel._op_dk(dk, *args, **kwargs) + else: + raise AssertionError( + f"Expected {kernel} to have CompositeImplicitAutograd kernel" + ) + + return functools.partial(_special_op_to_decompose_cia, kernel=op) + + +@contextmanager +def _compiling_state_context(): + old_compiling_flag = torch.compiler._is_compiling_flag + old_exporting_flag = torch.compiler._is_exporting_flag + try: + torch.compiler._is_compiling_flag = True + torch.compiler._is_exporting_flag = True + yield + finally: + torch.compiler._is_compiling_flag = old_compiling_flag + torch.compiler._is_exporting_flag = old_exporting_flag + + +def _fakify_params_buffers( + fake_mode: FakeTensorMode, + mod: torch.nn.Module, +) -> dict[str, Union[torch.Tensor, torch.nn.Parameter]]: + params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + faked_params_buffers = {} + memo: dict[int, FakeTensor] = {} + for key, value in params_buffers.items(): + if id(value) in memo: + fake_tensor = memo[id(value)] + else: + fake_tensor = fake_mode.from_tensor(value, static_shapes=True) + memo[id(value)] = fake_tensor + faked_params_buffers[key] = fake_tensor + return faked_params_buffers # type: ignore[return-value] + + +def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: + """ + Registers a module as a valid input type for :func:`torch.export.export`. + + Args: + mod: the module instance + serialized_type_name: The serialized name for the module. This is + required if you want to serialize the pytree TreeSpec containing this + module. 
+ + Example:: + + import torch + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + + torch._export.utils.register_module_as_pytree_node(InputDataClass) + + class Mod(torch.nn.Module): + def forward(self, x, m): + return m(x) + x + + ep = torch.export.export(Mod(), (torch.randn(3), Module())) + print(ep) + + """ + assert issubclass(cls, torch.nn.Module) + + import weakref + + class PrototypeModule(weakref.ref): + def __init__(self, m, *args, **kwargs): + super().__init__(m, *args, **kwargs) # type: ignore[call-arg] + assert isinstance(m, torch.nn.Module) + assert not hasattr(self, "_proto_cls") + self._proto_cls = cls + + def __eq__(self, other): + return self._proto_cls == other._proto_cls + + def __deepcopy__(self, memo): + return PrototypeModule(self()) + + def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: + named_parameters = dict(obj.named_parameters()) + named_buffers = dict(obj.named_buffers()) + params_buffers = {**named_parameters, **named_buffers} + return list(params_buffers.values()), [ + list(params_buffers.keys()), + PrototypeModule(obj), + ] + + def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, ref = context + if ref is None or ref() is None: + raise RuntimeError("Module has been garbage collected") + obj = ref() + assert flatten_fn is not None + flattened, _ = flatten_fn(obj) + + # NOTE: This helper function will replicate an nn.Module in the exactly same + # structure to be used together with _reparametrize_module. This will + # create a clone of the module with the new parameters and buffers without + # affecting the original module. + def copy_module(mod: torch.nn.Module): + ret = copy.copy(mod) + ret.__dict__ = {copy.copy(k): copy.copy(v) for k, v in mod.__dict__.items()} + for name, child in ret.named_children(): + setattr(ret, name, copy_module(child)) + return ret + + if any(v is not o for v, o in zip(values, flattened)): + with torch.nn.utils.stateless._reparametrize_module( + obj, dict(zip(flat_names, values)), tie_weights=True, strict=True + ): + ret = copy_module(obj) + else: + ret = obj + return ret + + def default_flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, [flat_names, *args] = flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], [ + flat_names, + *args, + ] + + flatten_fn = default_flatten_fn + unflatten_fn = default_unflatten_fn + + serialized_type_name = cls.__module__ + "." 
+ cls.__qualname__ + + def to_dumpable_context(context): + keys, *_ = context + return json.dumps([keys, *([None] * len(_))]) + + def from_dumpable_context(dumpable): + s = json.loads(dumpable) + s[1] = PrototypeModule(torch.nn.Module()) + return s + + _register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=default_flatten_fn_with_keys, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + def default_flatten_fn_spec(obj, spec) -> list[Any]: + flats, context = flatten_fn(obj) + assert context == spec.context + return flats + + register_pytree_flatten_spec( + cls, + default_flatten_fn_spec, + ) + + +def deregister_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: + _deregister_pytree_node(cls) + _deregister_pytree_flatten_spec(cls) + + +def _sync_state(src, dst): + assert isinstance( + src, + torch.nn.Module, + ), f"Expected {src} to be a nn.Module" + assert isinstance( + dst, + torch.nn.Module, + ), f"Expected {dst} to be a nn.Module" + # Share state (params, buffers) between modules. + # This ensures that state mutations are visible across them. + # Since tensor constants are not mutable, copying (without sharing) is OK. + # Also, primitive constants are specialized, so copying (without sharing) is OK. + dst._parameters = src._parameters + dst._buffers = src._buffers + + +def sync_state(*wrapped_method_modules): + """ + Sync state between exported modules corresponding to wrapped methods. + This might be necessary after serializing/deserializing due to copying. + """ + if wrapped_method_modules: + m, *other_ms = wrapped_method_modules + for other_m in other_ms: + _sync_state(m, other_m) + + +class _WrappedMethod(torch.nn.Module): + def __init__(self, method): + super().__init__() + # share state of method's self module + _sync_state(method.__self__, self) + # redirect forward to method + self.forward = method + + +def wrap_method(method): + """ + Wrap a method as a module so that it can be exported. + The wrapped module's forward points to the method, and + the method's original module state is shared. 
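+
+    Example (illustrative sketch; ``M`` is a hypothetical module with an
+    ``encode`` method that takes one tensor)::
+
+        m = M()
+        ep = torch.export.export(wrap_method(m.encode), (torch.randn(3),))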
+ """ + assert ismethod( + method, + ), f"Expected {method} to be a method" + return _WrappedMethod(method) diff --git a/phivenv/Lib/site-packages/torch/_export/verifier.py b/phivenv/Lib/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..edced32698462e2badc870e19126e8813ab88b36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/verifier.py @@ -0,0 +1,502 @@ +# mypy: allow-untyped-defs +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, final, TYPE_CHECKING + +import torch +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, + TokenArgument, +) +from torch.fx import GraphModule + + +if TYPE_CHECKING: + from torch.export.exported_program import ExportedProgram + + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + def _check_correct_val(val): + if val is None: + return True + elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance( + val, (torch.memory_format, torch.dtype, torch.device, torch.layout) + ): + return True + elif isinstance( + val, (FakeTensor, torch.Tensor) + ): # TODO(zhxchen17) Remove Tensor. + return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +def _check_torch_fn(node: torch.fx.Node) -> None: + torch_fn = node.meta.get("torch_fn") + if torch_fn is None: + raise SpecViolationError( + f"Unable to find torch_fn metadata for node {node.name}" + ) + if ( + not isinstance(torch_fn, tuple) + and isinstance(torch_fn[0], str) + and isinstance(torch_fn[1], str) + ): + raise SpecViolationError( + f"Node.meta {node.name} has invalid torch_fn field {torch_fn}" + ) + + +class _VerifierMeta(type): + _registry: dict[str, type["Verifier"]] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + return ret + + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in 
enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> list: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + operator.lshift, + operator.rshift, + math.ceil, + math.floor, + math.trunc, + round, + ] + + def allowed_op_types(self) -> tuple[type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> tuple[type[Any], ...]: + return (torch.fx.GraphModule, torch.utils._pytree.TreeSpec) + + def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]: + # subgm in HOP's argument could has have getattr(weight) nodes, thus stateful + return ( + torch.fx.GraphModule, + torch.nn.parameter.Parameter, + torch.Tensor, # for buffer and constant tensor + torch.utils._pytree.TreeSpec, + ) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. + """ + + @final + def check(self, ep: "ExportedProgram") -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types(is_toplevel_gm) -> tuple[type[Any], ...]: + if is_toplevel_gm: + ret = self.allowed_getattr_types() + else: + ret = self.allowed_getattr_types_for_subgm() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> list: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> tuple[type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. 
+ # These will be modeled as HOO later + torch._C._set_grad_enabled, + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless, + ) + + if not isinstance(op, _allowed_op_types()): + if ( + op not in _allowed_builtin_ops() + and op not in _allowed_torch_functions + ): + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + # TODO (tmanlaibaatar) more proper way is needed here + if self.dialect != "TRAINING" and not is_functional(op): + raise SpecViolationError(f"operator '{op}' is not functional") + self.check_valid_op(op) + + for mod in gm.modules(): + is_toplevel_gm = mod is gm + + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + + if type(attr).__name__ == "LoweredBackendModule": + if ( + _is_type("backend_id", str) + and _is_type("processed_bytes", bytes) + and _is_type("compile_specs", list) + and hasattr(attr, "original_module") + ): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + elif type(attr).__name__ == "AOTInductorEPModule": + continue + + elif type(attr).__name__ == "AOTInductorRunnerWrapper": + continue + + if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)} on target {node.target}. \n" + f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}" + ) + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +class TrainingIRVerifier(Verifier): + dialect = "TRAINING" + + +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = {node.name for node in exported_program.graph.nodes} + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." 
+ ) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [ + node.name for node in exported_program.graph.nodes if node.op == "placeholder" + ] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.input_specs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance( + input_spec.arg, + (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument), + ): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError(f"Parameter {param} is not in the state dict.") + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if ( + input_spec.persistent is True + and buffer not in exported_program.state_dict + ): + raise SpecViolationError(f"Buffer {buffer} is not in the state dict.") + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." 
+ ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TokenArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + else: + raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.") + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. \n" + ) + + num_tokens = len(gs.output_tokens) + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens + mutate_nodes: list[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end : end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n" + ) + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. 
" + ) + + +def load_verifier(dialect: str) -> type[Verifier]: + if dialect == "ATEN" or dialect == "": + return _VerifierMeta._registry.get(dialect, Verifier) + return _VerifierMeta._registry[dialect] diff --git a/phivenv/Lib/site-packages/torch/_export/wrappers.py b/phivenv/Lib/site-packages/torch/_export/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..4df8d98e8a35a30c74935942f313b0109b6549c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_export/wrappers.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +import torch +import torch._custom_ops +from torch._C import DispatchKey +from torch._higher_order_ops.flat_apply import ( + _ConstantFunction, + flat_apply, + to_graphable, +) +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + get_proxy_slot, + PreDispatchTorchFunctionMode, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type + + +class ExportTracepoint(HigherOrderOperator): + def __init__(self): + super().__init__("_export_tracepoint") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +_export_tracepoint = ExportTracepoint() + + +@_export_tracepoint.py_impl(ProxyTorchDispatchMode) +def export_tracepoint_dispatch_mode(mode, *args, **kwargs): + p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs)) + proxy = mode.tracer.create_proxy( + "call_function", _export_tracepoint, p_args, p_kwargs + ) + return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer) + + +@_export_tracepoint.py_impl(FakeTensorMode) +def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs): + with mode: + return args + + +@_export_tracepoint.py_functionalize_impl +def export_tracepoint_functional(ctx, *args, **kwargs): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + _export_tracepoint(*unwrapped_args, **unwrapped_kwargs) + return args + + +_export_tracepoint.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(_export_tracepoint, deferred_error=True) +) + + +@_export_tracepoint.py_impl(DispatchKey.CPU) +def export_tracepoint_cpu(*args, **kwargs): + return args + + +def _wrap_submodule(mod, path, module_call_specs): + assert isinstance(mod, torch.nn.Module) + assert path != "" + submodule = torch.fx.graph_module._get_attr(mod, path) + + def update_module_call_signatures(path, in_spec, out_spec): + if path in module_call_specs: + assert module_call_specs[path]["in_spec"] == in_spec + assert module_call_specs[path]["out_spec"] == out_spec + module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} + + def check_flattened(flat_args): + for a in flat_args: + if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): + raise AssertionError( + f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" + ) + + def pre_hook(module, args, kwargs): + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + check_flattened(flat_args) + flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) + args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + return args, kwargs + + def 
post_hook(module, args, kwargs, res): + _, in_spec = pytree.tree_flatten((args, kwargs)) + flat_res, out_spec = pytree.tree_flatten(res) + check_flattened(flat_res) + flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) + update_module_call_signatures(path, in_spec, out_spec) + return pytree.tree_unflatten(flat_res, out_spec) + + pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) + post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) + return pre_handle, post_handle + + +@contextmanager +def _wrap_submodules(f, preserve_signature, module_call_signatures): + handles = [] + + try: + for path in preserve_signature: + handles.extend(_wrap_submodule(f, path, module_call_signatures)) + yield + finally: + for handle in handles: + handle.remove() + + +def _mark_strict_experimental(cls): + def call(self, *args): + return strict_mode(self, args) + + cls.__call__ = call + return cls + + +def _register_subclass_spec_proxy_in_tracer(tracer, name, spec): + """ + This is a wrapper utility method on top of tracer to cache the + already registered subclass spec attribute. This is useful because + Subclass.__init__ will be the same for each subclass. By default, fx will + create multiple attributes/proxies for a given attribute. + """ + fx_name = name + "0" + if hasattr(tracer.root, fx_name): + assert getattr(tracer.root, fx_name) == spec + return tracer.create_proxy("get_attr", fx_name, (), {}) + + qualname = tracer.get_fresh_qualname(name) + setattr(tracer.root, qualname, spec) + return tracer.create_proxy("get_attr", qualname, (), {}) + + +def mark_subclass_constructor_exportable_experimental(constructor_subclass): + """ + Experimental decorator that makes a tensor subclass traceable in export + with pre-dispatch IR. To make your subclass traceable in export, you need to: + 1. Implement an __init__ method for your subclass (see the DTensor implementation) + 2. Decorate your __init__ method with mark_subclass_constructor_exportable_experimental + 3. Apply the torch._dynamo_disable decorator to prevent dynamo from peeking into its impl + + Example: + + class FooTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, *, requires_grad=False): + # ... + return torch.Tensor._make_subclass(cls, elem, requires_grad=requires_grad) + + @torch._dynamo_disable + @mark_subclass_constructor_exportable_experimental + def __init__(self, elem, ...): + # ... + """ + + def _is_init(fn): + return callable(fn) and fn.__name__ == "__init__" + + if not _is_init(constructor_subclass): + raise RuntimeError( + f"torch._export.wrappers.mark_subclass_constructor_exportable_experimental can only be applied on a subclass tensor's __init__. " + f"But you are adding it on {constructor_subclass.__name__}, which is not supported. " + f"If __init__ doesn't exist on your subclass, please add it. Look at the DTensor.__init__ implementation for an example." + ) + + def wrapper(*args, **kwargs): + if not is_traceable_wrapper_subclass_type(type(args[0])): + assert constructor_subclass.__qualname__.endswith("__init__") + obj_name = constructor_subclass.__qualname__[: -len("__init__")] + raise RuntimeError( + f"Applying mark_subclass_constructor_exportable_experimental on {obj_name} is not valid as it is not a traceable " + f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
+ ) + constructor_subclass(*args, **kwargs) + if not torch._C._is_torch_function_mode_enabled(): + return + torch_function_mode_stack = torch.overrides._get_current_function_mode_stack() + + pre_dispatch_tf_modes = [ + mode + for mode in torch_function_mode_stack + if isinstance(mode, PreDispatchTorchFunctionMode) + ] + assert ( + len(pre_dispatch_tf_modes) <= 1 + ), f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}" + + if len(pre_dispatch_tf_modes) == 0: + return + + mode = pre_dispatch_tf_modes[0] + + tracer = mode.tracer + subclass = args[0] + + flat_args, in_spec = to_graphable((tuple(args[1:]), kwargs)) + + constructor_spec_name = "_".join( + constructor_subclass.__qualname__.lower().split(".") + ) + qualname = tracer.get_fresh_qualname(constructor_spec_name) # type: ignore[union-attr] + setattr(tracer.root, qualname, in_spec) # type: ignore[union-attr] + spec_proxy = tracer.create_proxy("get_attr", qualname, (), {}) + flat_proxy_args = pytree.tree_map_only( + torch.Tensor, lambda x: get_proxy_slot(x, tracer).proxy, flat_args + ) + + _, func_spec = torch.utils._pytree.tree_flatten( + _ConstantFunction(type(subclass)) + ) + + # We actually don't want to create a new spec for each instance + # In fx graph, it will look like dtensor_const_func_spec + # We can't directly shove DTensor.__init__ into fx as it is not + # allowed type. + fxable_constructor_call_spec_name = ( + type(subclass).__name__.lower() + "_const_func_spec" + ) + + # We should try to reuse the constructor call spec as it is guaranteed to be same + # for each subclass type. This is different from proxy-ing the init arguments which + # can't be reused because for example, DTensor can receive different DeviceMesh etc + # as it's arguments + func_spec_proxy = _register_subclass_spec_proxy_in_tracer( + tracer, fxable_constructor_call_spec_name, func_spec + ) + + inner_proxy = tracer.create_proxy( + "call_function", + flat_apply, + (func_spec_proxy, spec_proxy, *flat_proxy_args), + {}, + ) + track_tensor_tree(subclass, inner_proxy, constant=None, tracer=tracer) + return + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/_functorch/__init__.py b/phivenv/Lib/site-packages/torch/_functorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
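# A minimal usage sketch (not taken from the files above; assumptions noted inline): the
# pre/post forward hooks installed by _wrap_submodule/_wrap_submodules appear to be the
# machinery behind the preserve_module_call_signature option of torch.export.export. The
# module `Net` and the submodule path "linear" are hypothetical, and the exact export API
# surface may differ across torch versions.
import torch
from torch.export import export, unflatten


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Record the (args, kwargs) pytree signature of the "linear" submodule; internally this is
# what the pre_hook/post_hook pair above captures via _export_tracepoint.
ep = export(Net(), (torch.randn(2, 4),), preserve_module_call_signature=("linear",))
# unflatten() rebuilds a module hierarchy whose "linear" submodule keeps its original signature.
unflat = unflatten(ep)
out = unflat(torch.randn(2, 4))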
diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbd3f21c7127768ccf913ba97b6e740dc0059be2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2658a3a5fbb3773fb003e656a75449ba1ae109c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d12385cd1cc4bbb1f3bf32da543ba867fd9d5ab4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/apis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b18462e04ca88be1e95baaef60e46b0f471630 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/autograd_function.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb6d4e6d756b7716ad386a3f2f5ded6e1f82b35 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7337c60c6a505604ce7766d2e2444f81f6a3ff6e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07ded977baafd91046e0cb77561723f0a1a76cc6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compile_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18ad0cb7d76093067502c820681913a2c722e690 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/compilers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9730badae7332ced2ff16aebf3632b2fc688edbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d747a63bd710e121443c744bb3d3ed532b9bd467 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/deprecated.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543bdec5b71cab4a38c6e52f7b559d4f94e8b1ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7739a27d072428ff2b0df38ad0e3490ef6592a10 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/functional_call.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a4903ff118a94304a96e24a43095fd07b15cdcd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b4d49c88769c31ec9e56dad9c5fb03e3fdc1c8f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/make_functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5cdbe3a41f970ee3aeb56fc5c24eac0b2a41e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/partitioners.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3acc604ebd7ed432036604f70deeeb6076e9e81 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78621fc3bb45ecafd235851157f72e9845d0fdfa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/python_key.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c575417febbd229727041ba0db8cd83562f0bcf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2025941332e9b3ffc97a5407b04a6b44d17789 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6934dbad78dfd9add35b78ba9dffd431fa89294e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..599473adb4fdc85ac332dcd105a940d0fbbc808e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/__pycache__/vmap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__init__.py b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58dab11d49acbbc71a0e8b011e4da26221749902 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7103ab494cf3ffb2e8aa13e872cd4ed407a8d0d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b10da376a9cbdef252561f21eb6ddd0f997307cc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f117eca6ec94342cce39460a14bead6ae7300ab Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892be522eed9b1d2b17dd7586a3bef0d47758ce0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b15033fe64aacef3c0ff6ff75147d21569e5a0bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/ac_logging_utils.py @@ -0,0 +1,145 @@ +import json +import logging +from typing import Any + +from torch._logging import trace_structured +from torch.fx import Graph, Node + + +log: logging.Logger = logging.getLogger(__name__) + + +def create_joint_graph_node_information( + joint_graph: Graph, + recomputable_node_info: dict[str, int], +) -> dict[str, Any]: + joint_graph_node_information: dict[str, Any] = {} + + for i, joint_graph_node in enumerate(joint_graph.nodes): + is_recomputable_candidate: bool = ( + joint_graph_node.name in recomputable_node_info + ) + tensor_meta = joint_graph_node.meta.get("tensor_meta") + shape = getattr(tensor_meta, "shape", []) if tensor_meta else [] + + 
node_info: dict[str, Any] = { + "index": i, + "name": joint_graph_node.name, + "is_recomputable_candidate": is_recomputable_candidate, + "target": str(joint_graph_node.target), + "shape": str(shape), + "input_arguments": [inp.name for inp in joint_graph_node.all_input_nodes], + "stack_trace": joint_graph_node.meta.get("stack_trace", ""), + } + + if is_recomputable_candidate: + idx: int = recomputable_node_info[joint_graph_node.name] + node_info["recomputable_candidate_info"] = { + "recomputable_node_idx": idx, + } + + joint_graph_node_information[joint_graph_node.name] = node_info + + return joint_graph_node_information + + +def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]: + joint_graph_edges: list[tuple[str, str]] = [ + (inp.name, node.name) + for node in joint_graph.nodes + for inp in node.all_input_nodes + ] + return joint_graph_edges + + +def create_activation_checkpointing_logging_structure_payload( + joint_graph: Graph, + joint_graph_node_information: dict[str, Any], + joint_graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[Node], + expected_runtime: float, + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> dict[str, Any]: + activation_checkpointing_logging_structure_payload: dict[str, Any] = { + "Joint Graph Size": len(joint_graph.nodes), + "Joint Graph Edges": { + "Total": len(joint_graph_edges), + "Edges": joint_graph_edges, + }, + "Joint Graph Node Information": joint_graph_node_information, + "Recomputable Banned Nodes Order": [ + node.name for node in all_recomputable_banned_nodes + ], + "Expected Runtime": expected_runtime, + "Knapsack Saved Nodes": saved_node_idxs, + "Knapsack Recomputed Nodes": recomputable_node_idxs, + "Knapsack Input Memories": memories_banned_nodes, + "Knapsack Input Runtimes": runtimes_banned_nodes, + "Min Cut Solution Saved Values": [node.name for node in min_cut_saved_values], + } + return activation_checkpointing_logging_structure_payload + + +def create_structured_trace_for_min_cut_info( + joint_graph: Graph, + all_recomputable_banned_nodes: list[Node], + saved_node_idxs: list[int], + recomputable_node_idxs: list[int], + expected_runtime: float, + memories_banned_nodes: list[float], + runtimes_banned_nodes: list[float], + min_cut_saved_values: list[Node], +) -> None: + recomputable_node_info: dict[str, int] = { + node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes) + } + joint_graph_node_information = create_joint_graph_node_information( + joint_graph, recomputable_node_info + ) + + for node_name, node_info in joint_graph_node_information.items(): + if node_info["is_recomputable_candidate"]: + idx = recomputable_node_info[node_name] + node_info["recomputable_candidate_info"]["memory"] = memories_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["runtime"] = runtimes_banned_nodes[ + idx + ] + node_info["recomputable_candidate_info"]["is_saved"] = ( + idx in saved_node_idxs + ) + node_info["recomputable_candidate_info"]["is_recomputed"] = ( + idx in recomputable_node_idxs + ) + + joint_graph_edges = create_joint_graph_edges(joint_graph) + activation_checkpointing_logging_structure_payload = ( + create_activation_checkpointing_logging_structure_payload( + joint_graph, + joint_graph_node_information, + joint_graph_edges, + all_recomputable_banned_nodes, + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + 
memories_banned_nodes, + runtimes_banned_nodes, + min_cut_saved_values, + ) + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "min_cut_information", + "encoding": "json", + }, + payload_fn=lambda: json.dumps( + activation_checkpointing_logging_structure_payload + ), + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f036cab49fc324e62f0a324d798692edb7a62e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/graph_info_provider.py @@ -0,0 +1,321 @@ +from typing import Any, Optional + +import networkx as nx + +from torch.fx import Graph, Node + + +class GraphInfoProvider: + """ + This class provides information about the graph, such as the nodes, edges, and their runtime and memory requirements. + It also provides methods to create graphs from the information provided. + """ + + __RECOMPUTABLE_NODE_ONLY_GRAPH = "recomputable_node_only_graph" + __RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT = ( + "recomputable_node_only_graph_with_larger_graph_context" + ) + __FULL_NX_JOINT_GRAPH = "full_nx_joint_graph" + __SIMPLIFIED_FX_JOINT_GRAPH = "fx_joint_graph" + + def __init__( + self, + graph_nodes_in_order: list[str], + graph_edges: list[tuple[str, str]], + all_recomputable_banned_nodes: list[str], + all_node_runtimes: Optional[dict[str, float]] = None, + all_node_memories: Optional[dict[str, float]] = None, + recorded_knapsack_input_memories: Optional[list[float]] = None, + recorded_knapsack_input_runtimes: Optional[list[float]] = None, + joint_graph: Optional[Graph] = None, + ): + self.graph_nodes_in_order = graph_nodes_in_order + self.graph_edges = graph_edges + self.all_node_runtimes: dict[str, float] = dict() + if all_node_runtimes is None: + if recorded_knapsack_input_runtimes is None: + raise ValueError( + "Either all_node_runtimes or recorded_knapsack_input_runtimes must be provided." + ) + self.all_node_runtimes = { + node: recorded_knapsack_input_runtimes[i] + for i, node in enumerate(all_recomputable_banned_nodes) + } + else: + self.all_node_runtimes.update(all_node_runtimes) + self.all_node_memories: dict[str, float] = dict() + if all_node_memories is None: + if recorded_knapsack_input_memories is None: + raise ValueError( + "Either all_node_memories or recorded_knapsack_input_memories must be provided." 
+ ) + self.all_node_memories = { + node: recorded_knapsack_input_memories[i] + for i, node in enumerate(all_recomputable_banned_nodes) + } + else: + self.all_node_memories.update(all_node_memories) + self.all_recomputable_banned_nodes = all_recomputable_banned_nodes + self.all_recomputable_banned_nodes_set = set(all_recomputable_banned_nodes) + self.recorded_knapsack_input_memories = recorded_knapsack_input_memories + self.recorded_knapsack_input_runtimes = recorded_knapsack_input_runtimes + self._lazily_initialized_graphs: dict[str, Any] = { + self.__RECOMPUTABLE_NODE_ONLY_GRAPH: None, + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT: None, + self.__FULL_NX_JOINT_GRAPH: None, + self.__SIMPLIFIED_FX_JOINT_GRAPH: None, + } + + @classmethod + def inialize_from_graph( + cls, + joint_graph: Graph, + all_recomputable_banned_nodes: list[Node], + recorded_knapsack_input_memories: list[float], + recorded_knapsack_input_runtimes: list[float], + ) -> "GraphInfoProvider": + """ + Enables initialization from a joint graph. + """ + graph_nodes_in_order = [node.name for node in joint_graph.nodes] + graph_edges = [ + (node.name, user.name) for node in joint_graph.nodes for user in node.users + ] + all_recomputable_banned_node_names = [ + node.name for node in all_recomputable_banned_nodes + ] + return cls( + graph_nodes_in_order=graph_nodes_in_order, + graph_edges=graph_edges, + all_recomputable_banned_nodes=all_recomputable_banned_node_names, + recorded_knapsack_input_memories=recorded_knapsack_input_memories, + recorded_knapsack_input_runtimes=recorded_knapsack_input_runtimes, + joint_graph=joint_graph, + ) + + @property + def recomputable_node_only_graph(self) -> nx.DiGraph: + if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None: + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH + ] = self._create_recomputable_node_only_graph() + return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] + + @property + def recomputable_node_only_graph_with_larger_graph_context(self) -> nx.DiGraph: + if ( + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] + is None + ): + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] = self._create_recomputable_node_only_graph_with_larger_graph_context() + return self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT + ] + + @property + def full_joint_nx_graph(self) -> nx.DiGraph: + if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None: + self._lazily_initialized_graphs[ + self.__FULL_NX_JOINT_GRAPH + ] = self._create_full_joint_graph() + return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] + + @property + def simplified_fx_joint_graph(self) -> Graph: + if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None: + self._lazily_initialized_graphs[ + self.__SIMPLIFIED_FX_JOINT_GRAPH + ] = self._recreate_psuedo_joint_graph() + return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] + + def get_non_ac_peak_memory(self) -> float: + return sum( + self.all_node_memories[node_name] + for node_name in self.all_recomputable_banned_nodes_set + ) + + def get_theoretical_max_runtime(self) -> float: + return sum( + self.all_node_runtimes[node_name] + for node_name in self.all_recomputable_banned_nodes_set + ) + + def get_knapsack_memory_input(self) -> list[float]: + return ( + 
self.recorded_knapsack_input_memories + if self.recorded_knapsack_input_memories + else [ + self.all_node_memories[node_name] + for node_name in self.all_recomputable_banned_nodes + ] + ) + + def get_knapsack_runtime_input(self) -> list[float]: + return ( + self.recorded_knapsack_input_runtimes + if self.recorded_knapsack_input_runtimes + else [ + self.all_node_runtimes[node_name] + for node_name in self.all_recomputable_banned_nodes + ] + ) + + def _create_recomputable_node_only_graph(self) -> nx.DiGraph: + graph = nx.DiGraph() + for recomputable_node in self.all_recomputable_banned_nodes: + graph.add_node(recomputable_node) + + for a, b in self.graph_edges: + if ( + a in self.all_recomputable_banned_nodes_set + and b in self.all_recomputable_banned_nodes_set + ): + graph.add_edge(a, b) + return graph + + def _create_recomputable_node_only_graph_with_larger_graph_context( + self, + ) -> nx.DiGraph: + # Create a dictionary to store the reachable nodes for each node + all_recomputable_banned_nodes_set = set(self.all_recomputable_banned_nodes) + + reachable_nodes = {} + for node in all_recomputable_banned_nodes_set: + # Use BFS to find all reachable nodes + predecessors = dict(nx.bfs_predecessors(self.full_joint_nx_graph, node)) + reachable_recomputable_nodes = set(predecessors.keys()).intersection( + all_recomputable_banned_nodes_set + ) + reachable_nodes[node] = reachable_recomputable_nodes + # Create the candidate graph + candidate_graph = nx.DiGraph() + candidate_graph.add_nodes_from(all_recomputable_banned_nodes_set) + for node1 in all_recomputable_banned_nodes_set: + for node2 in reachable_nodes[node1]: + # Check if there is an overlapping path + overlapping_path = False + for intermediate_node in reachable_nodes[node1]: + if ( + intermediate_node != node2 + and node2 in reachable_nodes[intermediate_node] + ): + overlapping_path = True + break + if not overlapping_path: + candidate_graph.add_edge(node1, node2) + return candidate_graph + + def _create_full_joint_graph(self) -> nx.DiGraph: + graph = nx.DiGraph() + for node in self.graph_nodes_in_order: + if node == "output": + continue + graph.add_node(node) + + for a, b in self.graph_edges: + if a == "output" or b == "output": + continue + graph.add_edge(a, b) + return graph + + def _recreate_psuedo_joint_graph(self) -> Graph: + # Create a dictionary to store the dependencies of each node + node_dependencies: dict[str, list[str]] = { + node: [] for node in self.graph_nodes_in_order + } + for a, b in self.graph_edges: + if a not in node_dependencies or b not in node_dependencies: + raise ValueError(f"Edge ({a}, {b}) references a non-existent node.") + node_dependencies[b].append(a) + + joint_graph = Graph() + # Create nodes in the graph + nodes: dict[str, Node] = {} + for node_name in self.graph_nodes_in_order: + input_nodes = [nodes[dep] for dep in node_dependencies[node_name]] + if input_nodes: + node = joint_graph.call_function(lambda *x: x, tuple(input_nodes)) + node.name = node_name + else: + node = joint_graph.placeholder(node_name) + nodes[node_name] = node + return joint_graph + + def _visualize_recomputable_candidate_graph_with_larger_context( + self, + layout_k: float = 0.5, + layout_iterations: int = 30, + ) -> None: + """ + Visualize the recomputable candidate graph with larger context. 
+ """ + from matplotlib import cm, colors as mcolors, pyplot as plt + + pos = nx.spring_layout( + self.recomputable_node_only_graph_with_larger_graph_context, + k=layout_k, + iterations=layout_iterations, + ) + # pos = nx.spectral_layout(graph_with_indirect_edges) + plt.figure(figsize=(20, 15)) + + # Create a dictionary for node labels using the index + labels = { + node: self.recomputable_node_only_graph_with_larger_graph_context.nodes[ + node + ].get("index", node) + for node in self.recomputable_node_only_graph_with_larger_graph_context.nodes + } + + # Extract memory values and normalize them + norm = mcolors.Normalize( + vmin=min(self.get_knapsack_memory_input()), + vmax=max(self.get_knapsack_memory_input()), + ) + cmap = cm.viridis # type: ignore[attr-defined] + + # Assign colors based on memory + node_colors = [ + cmap( + norm( + float( + self.recomputable_node_only_graph_with_larger_graph_context.nodes[ + node + ][ + "memory" + ] + ) + ) + ) + for node in self.recomputable_node_only_graph_with_larger_graph_context.nodes + ] + + # Draw the graph with parsed nodes only + nx.draw_networkx_nodes( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + node_color=node_colors, + node_size=300, + label="Parsed Nodes", + ) + nx.draw_networkx_edges( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + arrows=True, + arrowsize=10, + ) + nx.draw_networkx_labels( + self.recomputable_node_only_graph_with_larger_graph_context, + pos, + labels=labels, + font_size=8, + font_weight="bold", + ) + + plt.title("Memory Colour Coded Dependency Graph for Recomputable Nodes") + plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), label="Memory") + plt.show() diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ebaf15692a2b5757d44a2795aa169dbdfdb8a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack.py @@ -0,0 +1,121 @@ +import torch + + +def greedy_knapsack( + memory: list[float], runtimes: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + n = len(runtimes) + items = list(range(n)) + + # Sort items based on the ratio of runtime to memory in descending order + items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) + + total_memory = 0.0 + total_runtime = 0.0 + items_to_save = [] + items_to_allow_recomputing = [] + + for i in items: + if total_memory + memory[i] <= max_memory: + total_memory += memory[i] + total_runtime += runtimes[i] + items_to_save.append(i) + else: + items_to_allow_recomputing.append(i) + return total_runtime, items_to_save, items_to_allow_recomputing + + +def ilp_knapsack( + memory: list[float], runtimes: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + import numpy as np + + try: + from scipy.optimize import Bounds, LinearConstraint, milp + except ImportError: + raise RuntimeError( + "To use the ILP for memory budget checkpointing you need to install scipy" + ) from None + + np_memory = np.array(memory) + np_runtimes = np.array(runtimes) + c = -np_runtimes # type: ignore[operator] + + memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) + constraints = [memory_constraint] + + integrality = np.ones_like(c) + res = milp( + c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) + ) + if not res.success: + 
raise RuntimeError("Somehow scipy solving failed") + + items_to_save = [] + items_to_allow_recomputing = [] + for idx, i in enumerate(res.x): + if i == 1: + items_to_save.append(idx) + else: + items_to_allow_recomputing.append(idx) + return -res.fun, items_to_save, items_to_allow_recomputing + + +def dp_knapsack( + memory: list[float], runtime: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # Quantize the memory weights + quantized_memory = torch.tensor( + [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" + ) + runtimes = torch.tensor(runtime, dtype=torch.float32, device="cpu") + + # Quantized pseudopolynomial DP for 0-1 Knapsack + quantized_max_memory = int(round(max_memory * S)) + + n = len(memory) + + # Initialize the DP table + # TODO(chilli): I think if needed, this memory can be optimized with sliding + # window trick + Hirschberg trick: + # https://codeforces.com/blog/entry/47247?#comment-316200 + dp = torch.zeros( + (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" + ) + + for i in range(1, n + 1): + current_memory = quantized_memory[i - 1] + current_runtime = runtimes[i - 1] + + # Copy the previous row + dp[i, :] = dp[i - 1, :] + + # Update dp[i, j] for all j >= current_memory + if current_memory == 0: + dp[i, :] = dp[i - 1, :] + current_runtime + else: + dp[i, current_memory:] = torch.maximum( + dp[i - 1, current_memory:], + dp[i - 1, :-current_memory] + current_runtime, + ) + + # Backtrack to find the items included in the knapsack + saved_items = [] + recomputable_items = [] + j: int = quantized_max_memory + for i in range(n, 0, -1): + if dp[i][j] != dp[i - 1][j]: + saved_items.append(i - 1) # Include this item (indexing from 0) + j -= int(quantized_memory[i - 1].item()) + else: + recomputable_items.append(i - 1) + + saved_items.reverse() # To get items in the order they were added + + # The maximum runtime that can be achieved within the max_memory constraint + max_runtime = dp[n][quantized_max_memory].item() + + return max_runtime, saved_items, recomputable_items diff --git a/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..94da354c7090ab1619b67250679a4ef321033487 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_activation_checkpointing/knapsack_evaluator.py @@ -0,0 +1,273 @@ +import operator +from collections import deque +from typing import Callable + +import networkx as nx + +from torch._functorch._activation_checkpointing.graph_info_provider import ( + GraphInfoProvider, +) + + +class KnapsackEvaluator: + """ + This class evaluates the theoretical runtime and peak memory usage of a given checkpointing strategy. + It takes in a graph and a list of nodes that are saved and recomputed, and then simulates the + backward pass to calculate the peak memory usage. + """ + + def __init__( + self, + graph_info_provider: GraphInfoProvider, + ) -> None: + self._graph_info_provider = graph_info_provider + + def _get_backward_memory_from_topologically_sorted_graph( + self, + node_graph: nx.DiGraph, + node_memories: dict[str, float], + saved_nodes_set: set[str], + peak_memory_after_forward_pass: float, + ) -> list[tuple[float, str]]: + """ + Simulates the backward pass and keeps track of the peak memory usage. 
+ + High Level Steps: + 1. Set Initial Peak/Current Memory + Allows you to set the peak memory after the forward pass, but typically this is + the sum of the estimated memory of the saved nodes. + 2. Perform a reverse topological sort of the node_graph. + If a full graph is defined, the full graph is sorted and only the subset + of nodes in the node_graph is processed. + 3. Iterate through the sorted graph nodes. + If the node is saved, just drop its memory from current memory. + If the node is not saved, add its memory to current memory and then traverse its + predecessors to simulate the recomputation chain. The running peak is updated after all + predecessors are processed. + + Args: + node_graph (nx.DiGraph): A directed graph representing the recomputable forward nodes. + saved_nodes_set (Set[str]): A set of node names that are saved. + peak_memory_after_forward_pass (float): The peak memory usage after the forward pass. + """ + current_memory = [ + (peak_memory_after_forward_pass, "Initial Peak/Current Memory") + ] + already_computed = set() + sorted_nodes = list(reversed(list(nx.topological_sort(node_graph)))) + dependencies_computed = set() + + for node in sorted_nodes: + if node in saved_nodes_set or node in already_computed: + current_memory.append( + ( + current_memory[-1][0] - node_memories[node], + f"Dropping Node(already saved): {node}", + ) + ) + continue + + already_computed.add(node) + current_memory.append( + ( + current_memory[-1][0] + node_memories[node], + f"Recomputing Node: {node}", + ) + ) + # Create a queue of dependencies required for recomputation + predecessor_queue = deque( + [ + dependency + for dependency, v in node_graph.in_edges(node) + if dependency not in already_computed + ] + ) + while predecessor_queue: + dep = predecessor_queue.popleft() + already_computed.add(dep) + dependencies_computed.add(dep) + current_memory.append( + ( + current_memory[-1][0] + node_memories[dep], + f"Recomputing Predecessor of {node}: {dep}", + ) + ) + # Add predecessors of the predecessor to the queue if they haven't been recomputed yet + for dependency_of_dependency, _ in node_graph.in_edges(dep): + if ( + dependency_of_dependency in already_computed + or dependency_of_dependency in saved_nodes_set + or dependency_of_dependency in predecessor_queue + ): + continue + predecessor_queue.append(dependency_of_dependency) + dependencies_computed.clear() + current_memory.append( + (current_memory[-1][0] - node_memories[node], f"Dropping Node: {node}") + ) + return current_memory + + def _validate_all_indexes_accounted_for_in_provided_output( + self, saved_nodes_idxs: list[int], recomputable_node_idxs: list[int] + ) -> None: + """ + Validate that all indexes are accounted for in the provided output. + This function checks that the union of saved nodes and recomputable nodes + covers all candidate nodes without any overlaps. 
+ """ + recomputable_node_idxs_set = set(recomputable_node_idxs) + saved_nodes_idxs_set = set(saved_nodes_idxs) + all_candidate_nodes_idxs = set( + range(len(self._graph_info_provider.all_recomputable_banned_nodes)) + ) + # Check that there are no overlaps between saved nodes and recomputable nodes + assert ( + len(recomputable_node_idxs_set.intersection(saved_nodes_idxs_set)) == 0 + ), "Saved nodes and recomputable nodes cannot have any overlaps" + # Check that all candidate nodes are accounted for + assert ( + recomputable_node_idxs_set.union(saved_nodes_idxs_set) + == all_candidate_nodes_idxs + ), "All candidate nodes must be accounted for in the provided output" + + def evaluate_knapsack_output( + self, + saved_nodes_idxs: list[int], + recomputable_node_idxs: list[int], + account_for_backward_pass: bool = False, + ) -> dict[str, float]: + """ + Evaluate the theoretical runtime and peak memory usage of a given checkpointing strategy. + Args: + - saved_nodes_idxs (List[int]): The indices of nodes that are saved. + - recomputable_node_idxs (List[int]): The indices of nodes that need to be recomputed. + """ + self._validate_all_indexes_accounted_for_in_provided_output( + saved_nodes_idxs, recomputable_node_idxs + ) + recomputation_runtime = sum( + self._graph_info_provider.all_node_runtimes[ + self._graph_info_provider.all_recomputable_banned_nodes[node] + ] + for node in recomputable_node_idxs + ) + if account_for_backward_pass: + memory_list = self._get_backward_memory_from_topologically_sorted_graph( + node_graph=self._graph_info_provider.recomputable_node_only_graph_with_larger_graph_context, + saved_nodes_set={ + self._graph_info_provider.all_recomputable_banned_nodes[i] + for i in saved_nodes_idxs + }, + node_memories=self._graph_info_provider.all_node_memories, + peak_memory_after_forward_pass=sum( + self._graph_info_provider.all_node_memories[ + self._graph_info_provider.all_recomputable_banned_nodes[i] + ] + for i in saved_nodes_idxs + ), + ) + peak_memory = max(memory_list, key=operator.itemgetter(0))[0] + else: + peak_memory = sum( + self._graph_info_provider.all_node_memories[ + self._graph_info_provider.all_recomputable_banned_nodes[node] + ] + for node in saved_nodes_idxs + ) + return { + "peak_memory": peak_memory, + "recomputation_runtime": recomputation_runtime, + "non_ac_peak_memory": self._graph_info_provider.get_non_ac_peak_memory(), + "theoretical_max_runtime": self._graph_info_provider.get_theoretical_max_runtime(), + "percentage_of_theoretical_peak_memory": peak_memory + / self._graph_info_provider.get_non_ac_peak_memory(), + "percentage_of_theoretical_peak_runtime": recomputation_runtime + / self._graph_info_provider.get_theoretical_max_runtime(), + } + + def evaluate_distribution_of_results_for_knapsack_algo( + self, + knapsack_algo: Callable[ + [list[float], list[float], float], tuple[float, list[int], list[int]] + ], + memory_budget_values: list[float], + ) -> list[dict[str, float]]: + """ + Evaluates the distribution of results for a given knapsack algorithm. + Args: + knapsack_algo (Callable): The knapsack algorithm to use for evaluation. + memory_budget_values (List[float]): A list of memory budgets to evaluate. 
+ """ + results = list() + for memory_budget in memory_budget_values: + _, saved_nodes, recomputed_nodes = knapsack_algo( + self._graph_info_provider.get_knapsack_memory_input(), + self._graph_info_provider.get_knapsack_runtime_input(), + memory_budget, + ) + result = self.evaluate_knapsack_output( + saved_nodes_idxs=saved_nodes, + recomputable_node_idxs=recomputed_nodes, + ) + result["memory_budget"] = memory_budget + results.append(result) + return results + + def get_knee_point_memory_budget( + self, + knapsack_algo: Callable[ + [list[float], list[float], float], tuple[float, list[int], list[int]] + ], + max_mem_budget: float = 0.1, + min_mem_budget: float = 0.001, + iterations: int = 100, + ) -> float: + """ + Finds the memory budget at the knee point in the Pareto frontier. + + The knee point is defined as the point where the trade-off between + runtime and memory usage is optimal. + + Args: + knapsack_algo (callable): Knapsack algorithm to use for evaluation. + max_mem_budget (float, optional): Maximum memory budget. Defaults to 0.1. + min_mem_budget (float, optional): Minimum memory budget. Defaults to 0.001. + iterations (int, optional): Number of memory budgets to evaluate. Defaults to 100. + + Returns: + float: Memory budget at the knee point. + """ + results = self.evaluate_distribution_of_results_for_knapsack_algo( + knapsack_algo=knapsack_algo, + memory_budget_values=[ + min_mem_budget + + i * (max_mem_budget - min_mem_budget) / (iterations - 1) + for i in range(iterations) + ], + ) + runtime_values = [ + result["percentage_of_theoretical_peak_runtime"] for result in results + ] + memory_values = [ + result["percentage_of_theoretical_peak_memory"] for result in results + ] + runtime_range = max(runtime_values) - min(runtime_values) + memory_range = max(memory_values) - min(memory_values) + if runtime_range == 0 or memory_range == 0: + return max_mem_budget + + # Normalize values + runtime_min = min(runtime_values) + memory_min = min(memory_values) + runtime_norm = [ + (value - runtime_min) / runtime_range for value in runtime_values + ] + memory_norm = [(value - memory_min) / memory_range for value in memory_values] + # Calculate Euclidean distance + distances = [ + (runtime_norm[i] ** 2 + memory_norm[i] ** 2) ** 0.5 + for i in range(len(runtime_norm)) + ] + # Find the knee point(shortest distance from the origin) + knee_index = distances.index(min(distances)) + return results[knee_index]["memory_budget"] diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6272b424658450437a313fc71bedbce73da3205 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
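# A toy end-to-end sketch of the activation-checkpointing helpers added above (all node names
# and numbers are invented; networkx must be installed since graph_info_provider imports it):
# greedy_knapsack picks which activations to save under a memory budget, and
# GraphInfoProvider/KnapsackEvaluator score that choice and search for a knee-point budget.
from torch._functorch._activation_checkpointing.graph_info_provider import GraphInfoProvider
from torch._functorch._activation_checkpointing.knapsack import greedy_knapsack
from torch._functorch._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator

memories = [0.4, 0.3, 0.2, 0.1]  # per-node activation memory (arbitrary units)
runtimes = [5.0, 1.0, 3.0, 2.0]  # per-node recomputation cost (arbitrary units)

# Save the nodes with the best runtime/memory ratio until the 0.5 budget is exhausted.
saved_runtime, saved_idxs, recompute_idxs = greedy_knapsack(memories, runtimes, 0.5)

provider = GraphInfoProvider(
    graph_nodes_in_order=["a", "b", "c", "d", "out"],
    graph_edges=[("a", "b"), ("b", "c"), ("c", "d"), ("d", "out")],
    all_recomputable_banned_nodes=["a", "b", "c", "d"],
    recorded_knapsack_input_memories=memories,
    recorded_knapsack_input_runtimes=runtimes,
)
evaluator = KnapsackEvaluator(provider)

# Peak memory and recomputation runtime implied by this particular save/recompute split.
print(evaluator.evaluate_knapsack_output(saved_nodes_idxs=saved_idxs, recomputable_node_idxs=recompute_idxs))
# Memory budget at the knee of the runtime/memory trade-off curve for the greedy solver.
print(evaluator.get_knee_point_memory_budget(greedy_knapsack, max_mem_budget=1.0, min_mem_budget=0.05, iterations=20))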
diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9cda1c4bbfa1cb04a8613c3d2159b9b44da849 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57d7f8104f47bde12fee55ed74f145e02ac3d3cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92c1fad5d0ec10d2178858e1a5d628fe1c8ade34 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb98a2a7624d5bd995d8e5cb994fb75fad325994 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395479a12e51634f789c9bc635dcece27865b376 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1332ba96b115b5044c93c684306bbdb4c47db0ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca60e71304c0c5ed314b94814932a3a7c2c55258 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..30c01cb93741cef2970f371526562411d1ea2888 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd3431438245d272134beb43a213fc7cfa2e60d8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148446d131499425ecfa51db7ae67d922e17954d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96b62c8d2f627218a18b2912076a712733873d9d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d94eca22c83289c946bf20e4cfd0fbab5254cea Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf2b87dd3394a62c4e9c1b4d0c2f5eaa4cca0f0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2435713201c3c333e3b72ec9d3c6ede4ceefecc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..adc81906026962217b347f8d54d13d3acc2edfcb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py @@ -0,0 +1,1442 @@ +# mypy: allow-untyped-defs +""" +Utils for caching the outputs of AOTAutograd +""" +from __future__ import annotations + +import base64 +import contextlib +import 
functools +import json +import logging +import os +import pickle +import shutil +import time +import traceback +from abc import ABC, abstractmethod +from copy import copy +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import override + +import torch +from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext +from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions +from torch._dynamo.utils import ( + chromium_event_log_active, + CompileEventLogger, + counters, + dynamo_timed, +) +from torch._functorch import config +from torch._inductor.codecache import ( + _ident, + add_ephemeral_timeout_increase_for_distributed, + BypassFxGraphCache, + create_cache, + extract_tensor_metadata_for_cache_key, + FxGraphCache, + FxGraphCachePickler, + FxGraphHashDetails, + GuardedCache, + sha256_hash, + write_atomic, +) +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + OutputCode, +) +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import should_use_remote_fx_graph_cache +from torch._logging import LazyString +from torch._utils_internal import log_cache_bypass +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.fx.experimental.symbolic_shapes import hint_int +from torch.utils._triton import has_triton_package +from torchgen.utils import dataclass_repr + +from .runtime_wrappers import ( + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + CachedAutogradLazyBackwardCompileInfo, + CompilerWrapper, + FunctionalizedRngRuntimeWrapper, + post_compile, + RuntimeWrapper, + SubclassMeta, +) +from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noqa: F401 + + +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.remote_cache import JsonDataTy, RemoteCache + from torch._inductor.utils import BoxedBool + from torch.fx.node import Node + +log = logging.getLogger(__name__) + + +class BypassAOTAutogradCache(Exception): + pass + + +# Used to signify when FXGraphCache missed when AOTAutogradCache uses it +class FXGraphCacheMiss(BypassAOTAutogradCache): + pass + + +def should_use_remote_autograd_cache(): + if torch._inductor.config.force_disable_caches: + return False + if config.enable_remote_autograd_cache is not None: + return config.enable_remote_autograd_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk_name = "pytorch/remote_cache:aot_autograd_cache_version" + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + + +def should_use_local_autograd_cache(): + if torch._inductor.config.force_disable_caches: + return False + return config.enable_autograd_cache + + +def check_node_safe(node: Node): + """ + Checks that the node only uses supported operators. We are starting with very + conservative cacheability constraints, and incrementally adding more support as we expand. 
+ + [Note: AOTAutograd Cacheability checks] + - Our cache key is computed from the FX graph produced by Dynamo and the input example values + - A node is "safe" if the same cache key results in a compiled artifact that has the same behavior + (i.e, the set of inputs that go into our cache key is sufficient to distinguish its behavior) + + To accomplish this safety check, we consider the following functions to be safe: + - Public functions under modules torch, torch.functional, and torch.nn.functional: these are + allowed in the graph by dynamo, so we can assume they are safe to cache. + - method calls on base tensor types + - Any call_module that dynamo deemed safe to allow AOTAutograd to trace + - Non callable nodes, such as placeholder, output, get_attr + + The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully cover/specify this behavior. + """ + SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional") + SAFE_TORCH_FUNCTIONS = ( + "torch.Size", + "torch.Tensor", + "torch.sym_int", + "torch._sym_sqrt", + "torch.sym_float", + "torch.sym_sum", + ) + SAFE_NON_TORCH_FUNCTIONS = ( + "einops.einops.rearrange", + "einops.einops.repeat", + ) + + def is_public_torch_api(target): + # Don't blindly allow private functions in the torch namespace + is_private = target.__name__.startswith("_") + + return ( + getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private + ) + + def is_safe_torch_function(target): + """Allowlisted torch functions""" + function_name = f"{target.__module__}.{target.__name__}" + # Allow torch.autograd.function.FunctionCtx if custom autograd functions are allowed + if function_name == "torch.autograd.function.FunctionCtx": + return ( + torch._functorch.config.autograd_cache_allow_custom_autograd_functions + ) + + # Functions in torch_non_c_binding_in_graph_functions + # are guaranteed to be cache safe. + # See NOTE: [Cacheability of in-graph torch functions] + return ( + function_name in torch_non_c_binding_in_graph_functions + or function_name in SAFE_TORCH_FUNCTIONS + or function_name in torch._inductor.config.unsafe_marked_cacheable_functions + ) + + def is_cacheable_function(target): + if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): + return True + if is_public_torch_api(target): + return True + # Technically, FXGraphCache._check_for_hop already checks this, + # but better to error earlier anyway + if isinstance(target, torch._ops.HigherOrderOperator): + return target.cacheable() + is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method" + if is_builtin_fun_or_type: + return True + if is_safe_torch_function(target): + return True + function_name = f"{target.__module__}.{target.__name__}" + if function_name in SAFE_NON_TORCH_FUNCTIONS: + return True + return False + + def is_tensor(target): + # Tensors always have example values in meta field + return "example_value" in target.meta + + # I'd love to use a match statement here, but it wasn't introduced until py3.10 + if node.op == "call_function": + if node.meta and node.meta.get("is_wrapped", False): + # This is fx.wrap function + # By default we BypassAOTAutogradCache for unknown functions, + # But if user explicitly specified cache hash - allow to cache it. 
+ if node.meta.get("user_cache_hash", None): + return + + if not is_cacheable_function(node.target): + module = getattr(node.target, "__module__", None) + name = getattr(node.target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_function target {node.target}. \n Function module: {module}, \nFunction name: {name}" + ) + elif node.op == "call_method": + method_name = node.target + method_target = node.args[0] + # Only support method calls on base tensors + if not is_tensor(method_target): + module = getattr(method_target, "__module__", None) + name = getattr(method_target, "__name__", None) + raise BypassAOTAutogradCache( + f"Unsupported call_method target {method_target}. \nMethod module: {module}, \nMethod name: {name}" + ) + if ( + type(method_name) != str + and type(method_name).__name__ != "method_descriptor" + ): + raise BypassAOTAutogradCache( + f"Unsupported call_method method {node.target}: {method_name}" + ) + # Cache safe + elif node.op in ("placeholder", "get_attr", "call_module", "output"): + # Assumption today for call_module being a safe op: + # (1) today the only call_module ops that can show up in a graph come from "built-in-nn-modules" + # that dynamo assumes are safe to trace. If dynamo assumes they are safe to blindly trace, then + # they should be safe to cache as well. + # (2) in the steady-state (some time in H2?) we shouldn't see these anymore, once we inline builtin nn modules by default + # (3) We do not allow user-made nn modules in the graph today, only function calls. + pass + else: + raise BypassAOTAutogradCache(f"Unsupported node op {node.op}") + + +def check_cacheable(gm: torch.fx.GraphModule): + """ + Checks that the graph module only uses supported operators + """ + nodes = gm.graph.nodes + if torch._inductor.config.freezing: + raise BypassAOTAutogradCache("Cannot cache a graph with freezing enabled") + + if not ( + torch._inductor.config.fx_graph_cache or should_use_remote_fx_graph_cache() + ): + raise BypassAOTAutogradCache("FX graph cache is not enabled") + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context and tracing_context.fakify_first_call: + raise BypassAOTAutogradCache( + "Won't cache a graph with fakify_first_call enabled" + ) + for node in nodes: + check_node_safe(node) + + # Saved tensors hooks are globally set subgraphs, + # that are not used explicitly in the main graph. + # They are inlined in aot_autograd graphs. + # Subgraphs are only used for caching logic. + if hasattr(gm, "saved_tensors_hooks_pack_0"): + check_cacheable(gm.saved_tensors_hooks_pack_0) # type: ignore[arg-type] + # We have a guarantee of unpack subgraph existence if the pack subgraph exists + check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] + + +def check_metadata_cacheable(metadata: ViewAndMutationMeta): + """ + When view replay is turned on, we bypass autograd cache if + the output is aliased. + """ + if config.view_replay_for_aliased_outputs: + for info in metadata.output_info: + if info.functional_tensor is not None: + raise BypassAOTAutogradCache( + "Cannot cache a graph with functional tensor" + ) + + +class AOTAutogradCacheDetails(FxGraphHashDetails): + """ + Object to capture all the details for a dynamo graph module relevant to computing + a safe and stable cache key for AOTAutograd. + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs, + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ): + # FxGraphHashDetails contains all the keys related to inductor.
Also includes some system info + self.aot_config = aot_config + self.grad_enabled = torch.is_grad_enabled() + self.disable_amp = torch._C._is_any_autocast_enabled() + self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled() + self.autograd_config = config.save_config() + self.saved_tensors_hooks_fx_wrap_cache_hashes: tuple[list[str], list[str]] = ( + [], + [], + ) + + if hasattr(gm, "saved_tensors_hooks_pack_0"): + + def _add_wrapped_user_cache_hashes(_gm, _l): + for node in _gm.graph.nodes: + if node.meta and node.meta.get("is_wrapped", False): + _l.append(node.meta["user_cache_hash"]) + + _add_wrapped_user_cache_hashes( + gm.saved_tensors_hooks_pack_0, + self.saved_tensors_hooks_fx_wrap_cache_hashes[0], + ) + _add_wrapped_user_cache_hashes( + gm.saved_tensors_hooks_unpack_0, + self.saved_tensors_hooks_fx_wrap_cache_hashes[1], + ) + + try: + # FXGraphCache has constraints on what can be pickled in its inductor + # config. Check that the gm is cacheable by inductor first, + # and if it raises an exception, also bypass on our end. + FxGraphCache._check_can_cache(gm) + super().__init__(gm, example_inputs, fx_config, []) + except BypassFxGraphCache as e: + # Sometimes inductor configs are unpickleable and can fail + raise BypassAOTAutogradCache(str(e)) from e + + +class AOTAutogradCachePickler(FxGraphCachePickler): + def __init__(self, gm: torch.fx.GraphModule): + super().__init__(gm) + self.dispatch_table: dict + self.dispatch_table.update( + { + AOTConfig: functools.partial(self._reduce_aot_config), + torch.Tensor: functools.partial(self._reduce_tensor), + } + ) + + def _reduce_aot_config(self, aot_config: AOTConfig): + """ + Reduce the config to a stable key for caching. + """ + return ( + _ident, + ( + aot_config.num_params_buffers, + aot_config.keep_inference_input_mutations, + aot_config.is_export, + aot_config.no_tangents, + aot_config.dynamic_shapes, + aot_config.aot_autograd_arg_pos_to_source, + aot_config.enable_log, + aot_config.pre_dispatch, + ), + ) + + def _reduce_tensor(self, tensor): + """ + Reduce the tensor to a stable key for caching. + """ + metadata = extract_tensor_metadata_for_cache_key(tensor) + return (_ident, (metadata,)) + + +def autograd_cache_key( + gm: torch.fx.GraphModule, + example_inputs, + config: AOTConfig, + fx_config: _CompileFxKwargs, + # TODO: add args and parameters +) -> tuple[str, list[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + check_cacheable(gm) + if has_triton_package(): + # Due to https://github.com/triton-lang/triton/issues/3729, + # if triton is < 3.2.0, AOTAutogradCache may cause us to + # attempt to load a cache entry without initializing + # the CUDA context on the autograd thread. + + # Without caching, we naturally do this initialization when + # tracing through the graph with the autograd engine. 
+ import triton + + if triton.__version__ < "3.2.0": + raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") + + details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) + pickler = AOTAutogradCachePickler(gm) + # The prefix distinguishes among the other kinds of objects we cache + key = "a" + pickler.get_hash(details) + debug_lines = pickler.debug_lines(details) + log.debug( + "Autograd graph cache hash details for key %s:\n%s", + key, + LazyString(lambda: "\n".join(debug_lines)), + ) + return key, debug_lines + + +TOut = TypeVar("TOut", bound=OutputCode) + + +class InductorOutput(Generic[TOut], ABC): + """ + Class representing a single inductor output + """ + + @abstractmethod + def pre_save(self) -> None: + ... + + @abstractmethod + def load(self, example_inputs) -> TOut: + ... + + @abstractmethod + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: + ... + + +@dataclass +class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]): + """ + A full compiled fx graph that doesn't need to lookup the FxGraphCache + to run + """ + + result: CompiledFxGraph + + def pre_save(self) -> None: + disk_compiled_graph = copy(self.result) + disk_compiled_graph.prepare_for_serialization() + self.result = disk_compiled_graph + return + + def load(self, example_inputs) -> CompiledFxGraph: + self.example_inputs = example_inputs + + return self.result + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + constants = CompiledFxGraphConstants() + # Cache hit specific post compile + graph, cache_info = FxGraphCache.cache_hit_post_compile(result, {}, constants) + if graph is None: + raise BypassAOTAutogradCache("Failed to reload cache entry from disk") + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_bundled_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + counters["inductor"]["fxgraph_cache_hit"] += 1 + # Run normal post compile + graph.post_compile(self.example_inputs, constants, fx_config) + return graph + + +@dataclass +class FxGraphCacheLoadable(InductorOutput[CompiledFxGraph]): + fx_graph_cache_info: tuple[str, list[str]] + fx_graph_guard_expr: Optional[str] + + def pre_save(self): + return + + def _is_backward(self) -> bool: + return False + + def load(self, example_inputs) -> CompiledFxGraph: + # [Note: AOTAutogradCache and FXGraphCache Guard interactions] + # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. + # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. + # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # the same as the ones it passes to inductor, for both the forward and backward passes. + # (This does not mean that the tensor values passed in are the same: only that their symints are). + # That is, AOTAutograd and Inductor never create new guards based on symints with different sources + # than those passed to it by inductor. 
+ + # We pass the post compile function, which sets various fx_config boxed values, + # so we can call it only after we're sure both forward and backward have cache hit + + # Clear CompiledTritonKernels before loading from FXGraphCache + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + remote_cache = None + constants = CompiledFxGraphConstants() + if should_use_remote_fx_graph_cache(): + remote_cache = FxGraphCache.get_remote_cache() + (cache_key, debug_lines) = self.fx_graph_cache_info + + def check_exact_guard_match(guard_expr, _hints): + """ + AOTAutogradCache tracks its own guards, so we just need to treat these guard expressions as a second + cache key of sorts: we just check for equality, i.e. the FXGraphCache entry with + the exact same guards as we originally saved into the cache. + """ + return guard_expr == self.fx_graph_guard_expr + + result, cache_info = FxGraphCache.load_with_key( + cache_key, + debug_lines, + example_inputs, + local=True, + remote_cache=remote_cache, + is_backward=self._is_backward(), + constants=constants, + evaluate_guards=check_exact_guard_match, + ) + if result is None: + log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_info) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_miss", # always a miss + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + + raise FXGraphCacheMiss + + # No need to log chromium event because AOTAutograd will log that immediately for us + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + self.example_inputs = example_inputs + self.constants = constants + return result + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + """ + Called after FXGraphCacheLoadable.load, mutates fx_config + """ + result.post_compile(self.example_inputs, self.constants, fx_config) + return result + + +@dataclass +class CompiledForward(FxGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + def _is_backward(self) -> bool: + return False + + +@dataclass +class GenericCompiledBackward(InductorOutput[TOut]): + # Used by AOTDispatchAutograd.post_compile + backward_state_indices: list[int] + num_symints_saved_for_bw_: int + + +@dataclass +class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable): + """ + Cacheable entry for a backward function + """ + + def _is_backward(self) -> bool: + return True + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value] + + +# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence +class BundledCompiledForward(CompiledFxGraphLoadable): + pass + + +@dataclass +class BundledCompiledBackward( + GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable +): + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, 
fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value] + + +@dataclass +class SerializedGraphModule: + fn: Callable[[dict[Any, Any], str], torch.nn.Module] + args: tuple[Any, ...] + + def __init__(self, gm: torch.fx.GraphModule): + self.fn, self.args = gm.__reduce__() + + def deserialize(self) -> torch.fx.GraphModule: + gm = self.fn(*self.args) + assert isinstance(gm, torch.fx.GraphModule) + return gm + + +def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: + # NOTE: mutates the graph module + gm.meta = {} + for node in gm.graph.nodes: + node.meta = {} + return SerializedGraphModule(gm) + + +TForward = TypeVar("TForward", bound=InductorOutput) +TBackward = TypeVar("TBackward", bound=GenericCompiledBackward) + + +@dataclass +class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]): + """A single entry into the cache, genericized by Forward and Backward types. + + A TForward is always an InductorOutput of some sort, which represents the + forward graph of the compile. + A TBackward is an InductorOutput + metadata about the backward, useful for specific + backward-only wrappers. This type is encapsulated by GenericCompiledBackward. + + Each AOTAutogradCacheEntry is essentially parameterized by 1. the method of loading + from the cache (either Bundled or UnBundled), and 2. The type of the output. For now, + the only type of output we support is Python Wrapper output, i.e. OutputCode.CompiledFxGraph, + but the same technique works for C++ wrapper code; we'd just add an extra InductorOutput type. + """ + + # Forward and Backward info + compiled_fw: TForward + compiled_bw: Optional[TBackward] + + # Code of the joint graph using print_readable() + # Used for logging purposes + aot_joint_graph_str: Optional[str] + aot_forward_graph_str: Optional[str] + aot_backward_graph_str: Optional[str] + + # Runtime_metadata saved right before compilation + runtime_metadata: ViewAndMutationMeta + + # Wrappers that run after each aot_dispatch_* function + dispatch_wrappers: list[CompilerWrapper] + + # Used by AOTSubclassWrapper + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + # Used by RuntimeWrapepr + indices_of_inps_to_detach: list[int] + + # Time taken to trace/compile the forward + # forward_time_taken includes AOTAutograd tracing time + inductor compilation time + # backward_time_taken is essentially just the time inductor took to compile + forward_time_taken_ns: int + backward_time_taken_ns: int + + # Used by standalone_compile + sanitized_aot_config: AOTConfig + + guards_expr: Optional[str] + + # Used by Compiled Autograd + serialized_bw_module: Optional[SerializedGraphModule] + + def pre_save(self): + """ + Perform any preparations to make the cache entry ready for serialization. 
+ """ + check_metadata_cacheable(self.runtime_metadata) + self.compiled_fw.pre_save() + if self.compiled_bw is not None: + self.compiled_bw.pre_save() + + # Turn cache entry into the original callable + def wrap_post_compile( + self, + args: list[torch.Tensor], + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ) -> Callable: + """ + This function takes a cache entry and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. + + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + # Log the output of AOTAutogradCache + if aot_config.enable_log: + # TODO: maybe also log to aot_graphs_log + # Unfortunately aot_graphs_log uses + # slightly different formatting though + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + + if self.aot_forward_graph_str is not None: + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.runtime_metadata), + ) + if self.maybe_subclass_meta is not None: + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), + ) + + # It's called an inference graph if not running with autograd + name = ( + "aot_forward_graph" + if self.aot_backward_graph_str is not None + else "aot_inference_graph" + ) + torch._logging.trace_structured( + name, payload_fn=lambda: self.aot_forward_graph_str + ) + + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) + + # Wrap the forward function in post 
compile wrappers + compiled_fw_func = AOTDispatchSubclassWrapper( + trace_joint=needs_autograd, + fw_only=None, + maybe_subclass_meta=self.maybe_subclass_meta, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + req_subclass_dispatch = self.maybe_subclass_meta is not None + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + # In autograd case, functionalizedRngWrapper should not modify outs + return_new_outs = not needs_autograd + compiled_fw_func = FunctionalizedRngRuntimeWrapper( + return_new_outs=return_new_outs + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + disable_amp = torch._C._is_any_autocast_enabled() + + if needs_autograd: + assert self.compiled_bw is not None + + cached_lazy_backward = None + if self.serialized_bw_module is not None: + cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( + self.serialized_bw_module.deserialize + ) + # This function is run on both cache miss and cache hit, either here + # or in aot_dispatch_autograd. On a cache hit, + # 1. the bw is already compiled + # 2. we don't need to save to the cache again + # so those corresponding arguments are set to None. + compiled_function = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + self.maybe_subclass_meta, + self.compiled_bw.num_symints_saved_for_bw_, + self.compiled_bw.backward_state_indices, + disable_amp, + self.indices_of_inps_to_detach, + cached_lazy_backward, + aot_config, + fw_metadata=self.runtime_metadata, + try_save_cache_entry=None, + ) + else: + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + compiled_function, _ = post_compile( + self.dispatch_wrappers, + compiled_function, + aot_config, + runtime_metadata=self.runtime_metadata, + ) + + # Now that we're pretty sure it's a successful load, add guards + # to the existing shape environment from the cache + if self.guards_expr: + symints = AOTAutogradCache._filter_backed_symints(args) + check = bool(AOTAutogradCache.evaluate_guards(self.guards_expr, symints)) + assert check is True + + return compiled_function + + +class AOTAutogradCacheEntry( + GenericAOTAutogradCacheEntry[CompiledForward, CompiledBackward] +): + """ + Regular AOTAutogradCacheEntry: saves the forward/backward FxGraphCache keys + and looks them up in FxGraphCache on load + """ + + +class BundledAOTAutogradCacheEntry( + GenericAOTAutogradCacheEntry[BundledCompiledForward, BundledCompiledBackward] +): + """ + AOTAutogradCacheEntry where we save the entire CompiledFxGraph instead + of relying on cache keys from FxGraphCache + """ + + +@contextlib.contextmanager +def sanitize_gm_for_cache(gm: torch.fx.GraphModule): + """ + Clears a few fields in a dynamo supplied Graph Module that are not stable between graph inputs, but don't + affect inductor or aotdispatch correctness. + + These fields **can** be used by code calling into aotdispatch (namely, dynamo), so we can't null them out completely. + + To ensure that these fields are not accessed by inductor or aotdispatch, we clear them during AOTAutogradCache.load, + and then put them back before returning. 
This way, we generate a cache key based off of a canonical graph + without these fields, and also guarantee they aren't used to affect the cache's output. + """ + IGNORED_FIELDS = ( + "meta", # metadata used by export + "compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior + "_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id", + ) + saved_fields = {} + for field in IGNORED_FIELDS: + saved_fields[field] = getattr(gm, field, None) + # Clear the field + setattr(gm, field, None) + try: + yield + finally: + # Put the fields back after dispatch_and_compile is complete + for field, value in saved_fields.items(): + setattr(gm, field, value) + + +@CacheArtifactFactory.register +class AOTAutogradCacheArtifact(CacheArtifact): + @override + def populate_cache(self): + AOTAutogradCache._write_to_local_cache(self.key, self.content) + + @override + @staticmethod + def type(): + return "aot_autograd" + + +@CacheArtifactFactory.register +class BundledAOTAutogradCacheArtifact(PrecompileCacheArtifact[Callable]): + @override + @staticmethod + def type(): + return "precompile_aot_autograd" + + @override + def after_deserialization(self) -> Callable: + entry = pickle.loads(self.content) + # In the precompile use case, guards are already serialized + # by dynamo, so we don't need to add them to the environment + entry.guards_expr = None + # TODO: this isn't exactly right, because cudagraphs needs to be a shared config + # which is set by compile_fx. But in precompile, we never actually call compile_fx + # so we don't have a place to track cudagraphs here. + cudagraphs = torch._inductor.config.triton.cudagraphs + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, {"cudagraphs": cudagraphs} + ) + + # TODO: this ignores flat_params, which can exist + # if inline_builtin_nn_modules=False + def forward(*runtime_args: tuple[Any]): + return compiled_fn(list(runtime_args)) + + return forward + + +class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): + """ + Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas + AOTAutogradCacheEntry handles the wrapping/unwrapping logic. + + Cache Inputs (AOTAutogradCacheDetails) + - AOTAutogradCache takes in the following inputs, which are analogous to inputs given + to AOTAutograd by dynamo: + - A fx graph module generated by dynamo + - A list of args, which consists of: + - Symint inputs to the graph, generated by dynamo + - The **real tensor** inputs, which inductor uses for cudagraphs + - Notably, the real tensor inputs don't have symints in their metadata. + AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution. + - A set of global configurations that affect AOTAutograd or Inductor behavior. + + It then generates a cache key given these values. Notably, this means AOTAutogradCache currently + specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on. + In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates + based on the real tensor inputs, which can contain symints. 
+ + # Cache Outputs (AOTAutogradCacheEntry) + - AOTAutogradCache caches the following values: + - The compiled forward and backward functions from inductor, via keys to the FXGraphCache + - Metadata to reconstruct the AOTModule from the compiled inductor artifacts + - See AOTAutogradCacheEntry for more info + + [Note: Caching guards generated by AOTAutograd and Inductor] + AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each + compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions + from FXGraphCache, giving it new symint arguments from the input args. + FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards. + **No new guards are generated into the shape env after inductor finishes compiling**, so the guards + saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches. + """ + + @staticmethod + def clear(): + """Clear the cache""" + try: + shutil.rmtree(AOTAutogradCache._get_tmp_dir()) + except FileNotFoundError: + pass + + @staticmethod + def load( + dispatch_and_compile: Callable, + mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], + args, + aot_config: AOTConfig, + cudagraphs: BoxedBool, + boxed_forward_device_index: Optional[BoxedDeviceIndex], + local: bool, + remote: bool, + ) -> Callable: + """ + Load a result from the cache, and reconstruct a runtime wrapper around the object + """ + gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod + with sanitize_gm_for_cache(gm): + compiled_fn = None + cache_info: dict[str, Any] = {} + cache_key = None + debug_lines: list[str] = [] + cache_event_time = time.time_ns() + cache_state = None + fx_config: _CompileFxKwargs = { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + } + try: + cache_key, debug_lines = autograd_cache_key( + gm, args, aot_config, fx_config + ) + entry: Optional[ + GenericAOTAutogradCacheEntry + ] = AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config + ) + if entry is not None: + compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) + log.info("AOTAutograd cache hit for key %s", cache_key) + + counters["aot_autograd"]["autograd_cache_hit"] += 1 + cache_state = "hit" + cache_event_time = time.time_ns() + forward_time_saved = entry.forward_time_taken_ns // 1e6 + backward_time_saved = entry.backward_time_taken_ns // 1e6 + cache_info.update( + { + "forward_time_saved_ms": forward_time_saved, + "backward_time_saved_ms": backward_time_saved, + "time_saved_ms": forward_time_saved + backward_time_saved, + } + ) + time_saved_ns = ( + entry.forward_time_taken_ns + entry.backward_time_taken_ns + ) + # TODO: should we use the same field for remote cache time saved for both + # FXGraphCache and AOTAutogradCache? + # get_metrics_context().increment(...) 
+ if ( + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + + if compiled_fn is None: + log.info("AOTAutograd cache miss for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + cache_event_time = time.time_ns() + # Count missing the FXGraphCache as a miss not a bypass + except FXGraphCacheMiss as e: + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + if config.strict_autograd_cache: + raise e + # Most often this is BypassAOTAutogradCache, but + # if there's ever different reason we can't cache, + # we still never want to hard throw an exception, since + # we can always fallback to a cache bypass. + # As an example, if the user calls autograd via + # standalone inductor, we will sometimes get a GraphModule + # that doesn't actually have a `.graph` on it. Instead + # of checking every single case, we safely catch the exception + # in those cases. + except Exception as e: + cache_key = None + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) + cache_state = "bypass" + cache_event_time = time.time_ns() + cache_info["cache_bypass_reason"] = str(e) + cache_info["cache_bypass_exception_type"] = type(e).__name__ + cache_info["cache_bypass_traceback"] = traceback.format_exc().split( + "\n" + ) + # TODO: this gets logged implicitly by cache_bypass_reason, + # and here we explicitly log it into tlparse. + # We may want to log this as an extra column in Scuba, though. + cache_info["cache_bypass_hard_exception"] = not isinstance( + e, BypassAOTAutogradCache + ) + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + if config.strict_autograd_cache: + raise e + if compiled_fn is None: + # Set the cache key so we can save a cache result later + symints = AOTAutogradCache._filter_backed_symints(args) + if cache_key is not None: + aot_config.cache_info = AOTAutogradCacheInfo( + cache_key, + time.time_ns(), + forward_symints=symints, + ) + compiled_fn = dispatch_and_compile() + + cache_info.update( + { + "key": cache_key, + "cache_state": cache_state, + "components": debug_lines, + } + ) + if chromium_event_log_active(): + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"aotautograd_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + return compiled_fn + + @classmethod + def generate_guards_expression( + cls: type[AOTAutogradCache], cache_info: AOTAutogradCacheInfo + ) -> Optional[str]: + shape_env = cls._get_shape_env() + assert shape_env is not None + symints = cache_info.forward_symints + guards = shape_env.get_pruned_guards(symints) + return shape_env.produce_guards_expression(placeholders=symints, guards=guards) + + @classmethod + def _get_tmp_dir(cls: type[AOTAutogradCache]) -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. 
+ """ + return os.path.join(cache_dir(), "aotautograd") + + @classmethod + def _get_tmp_dir_for_key(cls: type[AOTAutogradCache], key) -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cls._get_tmp_dir(), key) + + @staticmethod + def evaluate_guards(guard_expr: str, hints: Union[list[int], list[torch.SymInt]]): + if torch._inductor.config.unsafe_skip_cache_dynamic_shape_guards: + return True + shape_env = AOTAutogradCache._get_shape_env() + assert shape_env is not None + result = shape_env.evaluate_guards_expression(guard_expr, hints) + return result + + @staticmethod + def _lookup( + key: str, + local: bool, + remote: bool, + args: list[Any], + cache_info: dict[str, Any], + aot_config: Optional[AOTConfig], + ) -> Optional[GenericAOTAutogradCacheEntry]: + """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" + remote_cache: Optional[RemoteCache[JsonDataTy]] = None + if remote: + remote_cache = AOTAutogradCache.get_remote_cache() + + symints = AOTAutogradCache._filter_backed_symints(args) + hints = [hint_int(s) for s in symints] + entry = None + try: + ( + entry, + pickled_content, + guard_info, + ) = AOTAutogradCache.find_guarded_entry( + key, local, remote_cache, AOTAutogradCache.evaluate_guards, hints + ) + + if entry is None and guard_info["cache_status_detailed"] == "guard_miss": + counters["aot_autograd"]["autograd_cache_guard_miss"] += 1 + cache_info.update(guard_info) + if pickled_content is not None: + CacheArtifactManager.record_artifact( + AOTAutogradCacheArtifact.type(), key, pickled_content + ) + if ( + config.bundled_autograd_cache + and aot_config is not None + and aot_config.precompile_backend_id is not None + ): + # NB: We don't want to use the cached aot_config.precompile_backend_id + # 1. because we set it to None on save 2. even if we didn't, this new run + # that cache hit has a *new* backend id associated with it. + PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact.type(), + aot_config.precompile_backend_id, + pickled_content, + ) + except Exception as e: + log.info("AOTAutograd cache unable to load compiled graph: %s", e) + if config.strict_autograd_cache: + raise e + return entry + + @staticmethod + def _write_to_local_cache(key: str, content: bytes): + """Write an entry to the local cache.""" + subdir = AOTAutogradCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized entry to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + log.info("Writing AOTAutograd cache entry to %s", path) + write_atomic(path, content) + + @staticmethod + def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): + """Save a single entry into the cache.""" + try: + entry.pre_save() + content = pickle.dumps(entry) + CacheArtifactManager.record_artifact( + AOTAutogradCacheArtifact.type(), key, content + ) + if ( + config.bundled_autograd_cache + and entry.sanitized_aot_config.precompile_backend_id is not None + ): + precompile_key = entry.sanitized_aot_config.precompile_backend_id + # Now that we're saving it, the precompile_backend_id field is no longer + # useful, remove it from the entry. 
+ entry.sanitized_aot_config.precompile_backend_id = None + PrecompileContext.record_artifact( + BundledAOTAutogradCacheArtifact.type(), precompile_key, content + ) + AOTAutogradCache._write_to_local_cache(key, content) + counters["aot_autograd"]["autograd_cache_saved"] += 1 + except BypassAOTAutogradCache as e: + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + return None + except Exception as e: + log.info("AOTAutograd cache unable to serialize compiled graph: %s", e) + if remote: + log_cache_bypass( + "bypass_aot_autograd", "Unable to serialize: " + str(e) + ) + if config.strict_autograd_cache: + raise e + return None + + if remote: + remote_cache: Optional[ + RemoteCache[JsonDataTy] + ] = AOTAutogradCache.get_remote_cache() + if remote_cache is not None: + time_taken_ms = int( + (entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6 + ) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + + @staticmethod + @functools.cache + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. + """ + cache_id = "autograd-experimental" + return create_cache( + cache_id, + config.is_fbcode(), + "FbRemoteAOTAutogradCache", + "RemoteAOTAutogradCache", + ) + + @staticmethod + def make_entry( + compiled_fw_func: CompiledFxGraph, + compiled_bw_func: Optional[CompiledFxGraph], + aot_joint_graph_str: Optional[str], + aot_forward_graph_str: Optional[str], + aot_backward_graph_str: Optional[str], + runtime_metadata: ViewAndMutationMeta, + dispatch_wrappers: list[CompilerWrapper], + maybe_subclass_meta: Optional[SubclassMeta], + num_fw_outs_saved_for_bw: Optional[int], + indices_of_inps_to_detach: list[int], + forward_time_taken_ns: int, + backward_time_taken_ns: int, + sanitized_aot_config: AOTConfig, + guards_expr: Optional[str], + backward_state_indices: Optional[list[int]], + num_symints_saved_for_bw: Optional[int], + serialized_bw_module: Optional[SerializedGraphModule], + ) -> GenericAOTAutogradCacheEntry: + if config.bundled_autograd_cache: + # Helper function to unwrap all the wrappers we added during aotdispatch + # They get reapplied on cache load + def unwrap_compiled_fx_graph(obj): + while hasattr(obj, "__wrapped__"): + obj = obj.__wrapped__ + assert isinstance(obj, CompiledFxGraph) + return obj + + compiled_fw_graph = unwrap_compiled_fx_graph(compiled_fw_func) + bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph) + bundled_compiled_backward = None + if compiled_bw_func is not None: + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_bw_graph = unwrap_compiled_fx_graph(compiled_bw_func) + bundled_compiled_backward = BundledCompiledBackward( + compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw + ) + + return BundledAOTAutogradCacheEntry( + compiled_fw=bundled_compiled_forward, + compiled_bw=bundled_compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + 
forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, + ) + + else: + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + fw_debug_lines = getattr( + compiled_fw_func, "_fx_graph_cache_debug_lines", [] + ) + + assert fw_key is not None + compiled_forward = CompiledForward( + fx_graph_cache_info=(fw_key, fw_debug_lines), + fx_graph_guard_expr=getattr(compiled_fw_func, "guards_expr", None), + ) + compiled_backward = None + if compiled_bw_func is not None: + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + bw_debug_lines = getattr( + compiled_bw_func, "_fx_graph_cache_debug_lines", [] + ) + assert bw_key is not None + assert backward_state_indices is not None + assert num_symints_saved_for_bw is not None + compiled_backward = CompiledBackward( + fx_graph_cache_info=(bw_key, bw_debug_lines), + fx_graph_guard_expr=getattr(compiled_bw_func, "guards_expr", None), + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw_=num_symints_saved_for_bw, + ) + + return AOTAutogradCacheEntry( + compiled_fw=compiled_forward, + compiled_bw=compiled_backward, + aot_joint_graph_str=aot_joint_graph_str, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=aot_backward_graph_str, + runtime_metadata=runtime_metadata, + dispatch_wrappers=dispatch_wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + indices_of_inps_to_detach=indices_of_inps_to_detach, + forward_time_taken_ns=forward_time_taken_ns, + backward_time_taken_ns=backward_time_taken_ns, + sanitized_aot_config=sanitized_aot_config, + guards_expr=guards_expr, + serialized_bw_module=serialized_bw_module, + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb622c83db84e17300da95b45a02da23b75e151 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -0,0 +1,813 @@ +# mypy: allow-untyped-defs +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. + +In particular, the analysis here constructs view and mutation metadata from running +a functionalized version of the graph under compilation. 
+""" + +import collections +import contextlib +import logging +from functools import wraps +from typing import Callable, Optional + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._guards import detect_fake_mode +from torch._logging import getArtifactLogger +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch._subclasses.meta_utils import safe_is_leaf +from torch.fx.experimental.symbolic_shapes import is_concrete_int +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + +from .functional_utils import ( + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + MetadataKey, + to_fun, + was_inductor_storage_resized, +) +from .schemas import ( + FunctionalTensorMetadataEq, + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .subclass_utils import create_subclass_meta +from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip + + +zip = strict_zip + +log = logging.getLogger(__name__) +static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs") + + +# Note [Tangents memory format] +# We assume tangents memory format to be similar to corresponding output's memory_format. +# The idea is that we are technically making a guess about the strides of our tangents, +# while we trace out the joint. +# If runtime specfied tangents will not have the same memory format as predicted traced tangents, +# we coerce them at runtime to traced tangents memory format. + + +# Coercing and collecting traced tangents memory format in one recursive traversal +# mypy: ignore-errors +def coerce_tangent_and_suggest_memory_format(x: Tensor): + updated = False + if not isinstance(x, Tensor): + return x, None, updated + + out = x.detach() + + is_subclass = is_traceable_wrapper_subclass(out) + + memory_format = MemoryFormatMeta.from_tensor(out) + + if memory_format.memory_format is not None: + was = out + out = out.contiguous(memory_format=memory_format.memory_format) + updated = was is not out + + # For subclass we keep memory format of outer strides at the beggining of the list + out_memory_format = [memory_format] if is_subclass else memory_format + + # Note [Tangents memory format, Part 2] + # In the same way that "what strides do we assigns to our tangents" is a question + # that we can not answer (and therefore have to guess) as we trace the backward ahead-of-time, + # The same applies to any tensor subclass metadata, when we have tangents that are subclasses. + # To handle this situation, we have two new methods that a tensor subclass can implement: + # (1) __coerce_tangent_metadata__(self) + # Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata. + # The main example here is a DTensor with the "_Partial" placement. + # If we have a forward output with a _Partial placement, and corresponding tangent + # with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement. + # This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never + # have a tangent with "problematic" metadata, that we cannot convert to. 
+ # (1) __coerce_same_metadata_as_tangent__(self, metadata) + # Given a subclass, and a target differing metadata, + # convert self to have the same metadata as the target. + # With DTensor being the main example, we can use this to convert a DTensor with a Replicate() + # placement into one with a Shard() placement, in the case that we "guessed wrong", + # and traced tangents with a Shard() placement at compile time. + # + if is_subclass and hasattr(out, "__coerce_tangent_metadata__"): + out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined] + + if is_subclass: + attrs = out.__tensor_flatten__()[0] + + for attr in attrs: + elem = getattr(out, attr) + ( + new_elem, + new_elem_memory_format, + elem_updated, + ) = coerce_tangent_and_suggest_memory_format(elem) + out_memory_format.append(new_elem_memory_format) + if elem_updated: + setattr(out, attr, new_elem) + + return out, out_memory_format, updated + + +# This is a version of functionalization that is specifically designed +# for the AOTAutograd use case. +# +# Unlike functorch's variant, this doesn't use the functorch level system, +# instead it directly uses PyTorch's conventional dispatcher to hit the +# functionalization key. In particular, this means that FunctionalTensorWrapper +# can have autograd data stored directly on it. +# +# In typical AOTAutograd usage, the dispatch key order will look like: +# +# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor +# outer tensor inner tensor +# +# Returns: +# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and +# The list of outputs from the forward, but **only** the outputs that we need +# to pass in as tangents into the backward. +# Specifically, aliased outputs from the forward get regenerated, and don't participate +# in the compiled backward function. +def run_functionalized_fw_and_collect_metadata( + f, + *, + keep_input_mutations: bool, + # TODO: refactor to kill this flag + is_train: bool = False, + # Note: this is guaranteed to be set when running under dynamo + static_input_indices: Optional[list[int]] = None, + pre_dispatch: bool = False, + # is_export is technically only needed to avoid using functionalization V2 + # during analysis + is_export: bool = False, +) -> Callable[..., ViewAndMutationMeta]: + memo: dict[Tensor, Tensor] = {} + + def _to_fun(t): + if isinstance(t, Tensor): + if t in memo: + return memo[t] + r = to_fun(t) + memo[t] = r + return r + else: + return t + + @wraps(f) + def inner(*flat_args): + # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. 
+ assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) + + input_info: list[InputAliasInfo] = [] + output_info: list[OutputAliasInfo] = [] + + prior_grad_enabled = torch.is_grad_enabled() + prior_autocast_states = _get_autocast_states() + + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + # It doesn't matter if we run this under predispatch or not because it is + # only for figuring out metadata + mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export) + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: + # precondition: The passed in function already handles unflattening inputs + flattening outputs + flat_f_args = pytree.tree_map(_to_fun, flat_args) + flat_f_outs = f(*flat_f_args) + # We didn't do any tracing, so we don't need to process the + # unbacked symbols, they will just disappear into the ether. + # Also, prevent memoization from applying. + if fake_mode: + fake_mode.epoch += 1 + fake_mode.reset_nt_tensor_id_counter() + + if prior_autocast_states != _get_autocast_states(): + raise RuntimeError( + "AOTAutograd does not support tracing graphs that mutate the autocast state. " + "Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, " + "which will unwind all of their mutations to autocast state before the graph exits. " + "If you encounter this error while using torch.compile, please file a bug." + ) + + # Inspect the state of the input tensor functional wrapper to detect input mutation info + # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version + for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)): + # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in + # strides between the functionalized arg inner tensors and non-functionalized arg inner + # tensors. This is a problem as the inner tensor stride change may not be reflected + # correctly in the outer tensor, so disallow this for now. 
+ mutates_data = has_data_mutation(f_arg) + mutates_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=False + ) + if mutates_metadata and is_traceable_wrapper_subclass(arg): + raise RuntimeError( + "Metadata mutations are currently not allowed on tensor subclasses" + ) + mutates_storage_metadata = has_metadata_mutation( + f_arg, arg, check_only_storage_mutation=True + ) + mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd( + f_arg + ) + mutations_under_no_grad_or_inference_mode = ( + mutates_data + and are_all_mutations_under_no_grad_or_inference_mode(f_arg) + ) + mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg) + + if mutates_storage_metadata: + mutates_data = False + + requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad + + input_info.append( + InputAliasInfo( + is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg), + mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=mutations_hidden_from_autograd, + mutates_storage_metadata=mutates_storage_metadata, + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + mutation_inductor_storage_resize=mutation_inductor_storage_resize, + requires_grad=requires_grad, + keep_input_mutations=keep_input_mutations, + ) + ) + + # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate, + # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view + # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad + # on the base tensor, but we are obligated to properly set requires-gradness on the real output. + + inp_storage_refs = { + StorageWeakRef(inpt.untyped_storage()): idx + for idx, inpt in enumerate(flat_f_args) + if isinstance(inpt, Tensor) + } + + # We need inp tensor id's to be able to tell if an outputs **are** inputs. + inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)} + # We need output tensor id's to tell if any output._base` attributes **are** other outputs. + # (This is also a dict because we need to know that output's index, so we can regenerate + # the alias from it). + out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)} + + # Keep track of which outputs alias other outputs + out_tensor_alias_counts: collections.defaultdict = collections.defaultdict(int) + # This tells us, for a given group of outputs that alias each other, + # whether they e.g. all came from an unbind call + num_aliased_tensors_that_are_multi_output_views: collections.defaultdict = ( + collections.defaultdict(int) + ) + + out_storage_to_metadata_key_to_tensors: collections.defaultdict[ + Optional[StorageWeakRef], + collections.defaultdict[MetadataKey, set[torch.Tensor]], + ] = collections.defaultdict(lambda: collections.defaultdict(set)) + + curr_storage = None + for o in flat_f_outs: + if isinstance(o, torch.Tensor): + curr_storage = StorageWeakRef(o.untyped_storage()) + out_tensor_alias_counts[curr_storage] += 1 + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # This is an optimization on top of the "alias of intermediates" logic, + # which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!] + # + # Before describing the optimization: this is important for AOTAutograd to have good + # perf around, multi-output views. 
HOWEVER: + # - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case, + # around using pre-dispatch tracing to partition out a graph so we can faithfully replay all + # views without having to regenerate them at runtime. + # - It's loosely described in this doc (more details will be added soon): + # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit + # - Once that change lands, we should just rip out this "optimization", since: + # (1) It will be fully unnecessary + # (2) Although it is only a few lines of code, it is a bit difficult to reason about + # its correctness with the autograd engine in all cases. + # + # + # What is this optimization? Consider the below case: + # def f(x): + # intermediate = x.mul(2) + # # x and intermediate here require grad + # o1, o2, ... o10 = intermediate.unbind(-1) + # return intermediate, o1, o2, ... o10 + # Now, the "intermediate base" handling in AOTAutograd implies that we must do the following: + # (1) return "intermediate as an extra output of the compiled graph + # (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function. + # The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know + # that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function, + # this information will be hidden. + # In particular, mutating one alias might require autograd to update autograd metadata on the other aliases + # (like their grad_fn, for example, when the autograd engine needs to do view-replay). + # + # However, intermediate_base logic can be bad for backward performance (we sometimes generate + # as_strided calls during the intermediate base logic, which can have a slow backward formula). + # Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd? + # + # For a set of outputs of the graph that alias each other, o_1...o_k, consider: + # (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0) + # (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate), + # **at most** 1 can escape from the graph (e.g. there is not some other graph input/output + # o_other, that aliases these outputs) + # (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad. + # This condition is important because it's what causes slowness in the intermediate_base + # codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and + # aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn. + # "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward. + # In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta + # of the other aliases? + # + # Claim: No! Consider a few example (which I'm pretty sure cover all cases of mutation w.r.t. autograd): + # (a) What happens if we mutate any of o_1 through o_k directly? + # Autograd raises an error: + # "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is + # the output of a function that returns multiple views. Such functions do not allow the output + # views to be modified inplace. You should replace the inplace operation by an out-of-place one." + # (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)? 
+ # Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views.
+ # (c) What if we mutate o_k under no_grad?
+ # Autograd raises the same error
+ # (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)?
+ # Autograd allows this, *but* autograd updates all aliases' grad_fns to be error functions when accessed.
+ # Autograd raises the same error
+ # (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view?
+ # We promised that there is at most **one** such alias, e.g. intermediate in the example above.
+ # You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k
+ # to be error fn's.
+ # Since intermediate was the *only* non-multi-output-alias, there are no other aliases
+ # of `intermediate` around that were produced by the compiled fn and have a valid grad_fn.
+ #
+ # Coming back to this optimization:
+ # Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias
+ # without causing an error in eager mode, we will simply hide the aliasing from autograd during torch.compile
+ # if all of the above conditions are met.
+ # This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on
+ # in eager mode but not under torch.compile, but it has the benefit that this code has much better performance.
+ # NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here:
+ # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit,
+ # then this optimization will probably matter less and might be ok to remove.
+ is_cur_tensor_multi_out_view = isinstance(
+ o, FunctionalTensor
+ ) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined]
+ o.elem
+ )
+ if is_cur_tensor_multi_out_view:
+ num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1
+ if o.requires_grad:
+ out_storage_to_metadata_key_to_tensors[curr_storage][
+ MetadataKey.make(o)
+ ].add(o)
+
+ # maps the id of an intermediate base to its index in the output of the compiled forward
+ intermediate_base_tensor_id_to_output_idx: dict[int, int] = {}
+ intermediate_bases: list[torch.Tensor] = []
+ # Why Do We Care If Storage Changed?
+ # It's important to understand the implications of storage changes in complex scenarios. Take this example:
+ #
+ # def f(x):
+ # x_storage = x.untyped_storage()
+ # non_leaf_tensor = torch.ones(4, requires_grad=True).clone()
+ #
+ # # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation
+ # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
+ # x.set_(non_leaf_tensor.untyped_storage())
+ #
+ # out = x.view(-1)
+ #
+ # # Restoring x to its original storage, again simulating .data = operation
+ # with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
+ # x.set_(x_storage)
+ #
+ # return out
+ #
+ # In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing.
+ # However, because set_() (and, more specifically, the way set_() is functionalized) is defined to preserve eager semantics,
+ # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'.
+ # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated,
+ # which could lead to issues later in the code.
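+ #
+ # Rough sketch of how the loop below classifies outputs (hypothetical example,
+ # assuming x requires grad):
+ #   def f(x):
+ #       y = x.mul(2)
+ #       return x, x.view(-1), y, y.view(-1)
+ #   x          -> OutputType.is_input        (the output *is* a graph input)
+ #   x.view(-1) -> OutputType.alias_of_input  (aliases a graph input's storage)
+ #   y          -> OutputType.non_alias       (a plain graph output)
+ #   y.view(-1) -> an alias of an intermediate, handled by the intermediate-base
+ #                 logic further down (here y is itself a user output, so this
+ #                 becomes alias_of_intermediate_base_is_user_output).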
+ for o in flat_f_outs: + functional_tensor_storage_changed = isinstance( + o, FunctionalTensor + ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] + o.elem + ) + curr_storage = ( + None + if not isinstance(o, torch.Tensor) + else StorageWeakRef(o.untyped_storage()) + ) + outs_with_identical_metadata_that_require_grad = ( + [] + if not isinstance(o, Tensor) + else [ + curr + for curr in out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ] + if o is not curr + ] + ) + + # See Note [Accessing .grad_fn on FunctionalTensor] + # In-place operations on views will trigger a lazy rebase of the autograd graph; + # this runs during access to the .grad_fn. The rebase logic will invoke view ops + # on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure + # these op calls succeed. + grad_fn = None + if isinstance(o, Tensor): + with FunctionalTensorMode(): + grad_fn = o.grad_fn + + is_result_of_custom_autograd_fn = False + # Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) + # autograd fns + if type(grad_fn).__name__ == "CppFunction": + is_result_of_custom_autograd_fn = True + if isinstance(grad_fn, torch.autograd.function.BackwardCFunction): + is_result_of_custom_autograd_fn = True + + if not isinstance(o, Tensor): + output_type = OutputType.non_alias + base_idx = None + elif ( + curr_storage in inp_storage_refs + and grad_fn is not None + and is_result_of_custom_autograd_fn + ): + output_type = OutputType.custom_function_view + base_idx = None + elif ( + curr_storage in inp_storage_refs + and not functional_tensor_storage_changed + ): + base_idx = inp_storage_refs[curr_storage] + is_input_tensor = id(o) in inp_tensor_ids + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + if ( + grad_fn is not None + and num_aliased_outs_that_are_not_multi_output_views == 0 + ): + # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + # In particular, given: + # def f(x): + # return list(x.unbind(0)) + # The main reason we ordinarily try to regenerate these output aliases outside of the + # compiled autograd.Function is because if any of the outputs are later mutated, + # autograd needs to perform view-replay to regenerate them. + # However, autograd does not allow users to mutate multi-output views + # in any way that can change the autograd metadata of other aliases. + # So we hide this aliasing from autograd here. + log.debug( + "Encountered AOTAutograd case: differentiable outputs that \ +alias each other from a multi-output view call" + ) + output_type = OutputType.non_alias + elif is_input_tensor: + output_type = OutputType.is_input + else: + output_type = OutputType.alias_of_input + elif functional_tensor_storage_changed and id(o) in inp_tensor_ids: + # When there is a set_() on an input, we cannot rely on checking storages + # to detect if we are returning an input (since the inputs storage is different) + assert curr_storage is not None + base_idx = inp_storage_refs[curr_storage] + output_type = OutputType.is_input + + # We only need to handle the intermediate base case when both + # the intermediate base and the output require gradients. + # See Note [AOT Autograd: outputs aliasing inputs or intermediates!] 
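+ # Hypothetical illustration of the intermediate-base branches handled below
+ # (assuming x requires grad):
+ #   def f(x):
+ #       intermediate = x.mul(2)
+ #       return intermediate[0], intermediate[1:]
+ # Both outputs alias `intermediate`, which is not itself returned. The first one
+ # is marked alias_of_intermediate_save_as_output (forcing `intermediate` to be
+ # added as an extra compiled-forward output), the second alias_of_intermediate;
+ # both are then regenerated from the saved base outside the autograd.Function.
+ # If only a single alias of `intermediate` escaped, the cheaper unsafe_view_alias
+ # path below would be taken instead.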
+ elif o._base is not None and o.requires_grad and o._base.requires_grad: + num_aliased_outs = out_tensor_alias_counts[curr_storage] + num_multi_output_view_outs = ( + num_aliased_tensors_that_are_multi_output_views[curr_storage] + ) + num_aliased_outs_that_are_not_multi_output_views = ( + num_aliased_outs - num_multi_output_view_outs + ) + # Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] + if ( + out_tensor_alias_counts[curr_storage] == 1 + or num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + # Note [Intermediate Bases Optimization] + # Normally if we have an output that aliases an intermediate, + # we need to add the extra "intermediate base" logic further down + # to prevent autograd from yelling at us if the user later tries to + # mutate that output. + # However, the common case here is if we have an output that aliases an intermediate, + # but doesn't alias any other outputs. + # In that case, autograd shouldn't have to worry about the aliasing at all + # (if that output is mutated, there are no other live aliases for autograd to worry about). + # The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs. + # So as an optimization, we won't do intermediate base handling in this case. + # Instead, we'll hide the aliasing from autograd using aten._unsafe_view(). + if ( + out_tensor_alias_counts[curr_storage] != 1 + and num_aliased_outs_that_are_not_multi_output_views <= 1 + ): + log.debug( + "Encountered AOTAutograd case: differentiable outputs that alias each other \ +from a multi-output view call" + ) + output_type = OutputType.unsafe_view_alias + base_idx = None + else: + # First, check if o's ._base is an existing output + maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None) + if maybe_existing_out_idx is not None: + # Special case where the output is an alias of a graph intermediate, but that intermediate + # is itself also a user output. + output_type = ( + OutputType.alias_of_intermediate_base_is_user_output + ) + base_idx = maybe_existing_out_idx + else: + # Next, check if o's ._base is an intermediate base that we already returned + maybe_existing_base_output_idx = ( + intermediate_base_tensor_id_to_output_idx.get( + id(o._base), None + ) + ) + if maybe_existing_base_output_idx is not None: + output_type = OutputType.alias_of_intermediate + base_idx = maybe_existing_base_output_idx + else: + # Otherwise, take o._base and explicitly return it as an output in the compiled graph + new_out_idx = len(intermediate_bases) + base_idx = new_out_idx + # Indicate to the logic later on (when we trace the joint) + # that this particular output should get it's ._base appended to the forward graph outputs + output_type = ( + OutputType.alias_of_intermediate_save_as_output + ) + intermediate_base_tensor_id_to_output_idx[ + id(o._base) + ] = new_out_idx + intermediate_bases.append(o._base) + elif ( + # See https://github.com/pytorch/pytorch/issues/100348 for this case. 
+ # This protects against the specific case where a user fn returns (output, output.detach())
+ out_tensor_alias_counts[curr_storage] > 1
+ and len(outs_with_identical_metadata_that_require_grad) > 0
+ and not o.requires_grad
+ ):
+ # In theory we could use any of these tensors to regenerate the aliased outputs from,
+ # since they all alias each other and have identical metadata
+ out_alias = outs_with_identical_metadata_that_require_grad[0]
+ existing_out_idx = out_tensor_ids[id(out_alias)]
+ output_type = OutputType.alias_of_intermediate_base_is_user_output
+ base_idx = existing_out_idx
+ else:
+ output_type = OutputType.non_alias
+ base_idx = None
+
+ if isinstance(o, torch.Tensor):
+ dynamic_dims = {
+ i for i, s in enumerate(o.shape) if not is_concrete_int(s)
+ }
+ else:
+ dynamic_dims = None
+
+ # Save the current FunctionalTensor output.
+ #
+ # This will be used at runtime for reconstructing output views from
+ # their respective base tensors.
+ #
+ # The FunctionalTensor will be saved if one of the 2 conditions below
+ # is true:
+ functional_tensor = None
+ if (
+ # 1. If the output_type is one of:
+ # (i) alias_of_intermediate;
+ # (ii) alias_of_intermediate_save_as_output; or
+ # (iii) alias_of_intermediate_base_is_user_output.
+ #
+ # No need to worry about in-place view operations here, since
+ # this functionalization step eliminates mutations.
+ #
+ # i.e. we have access to the actual base tensor, before the
+ # in-place operation was applied.
+ output_type
+ in (
+ OutputType.alias_of_intermediate,
+ OutputType.alias_of_intermediate_save_as_output,
+ OutputType.alias_of_intermediate_base_is_user_output,
+ )
+ ) or (
+ # 2. If the output_type is alias_of_input, and no in-place view
+ # operation was run on the input (base tensor).
+ #
+ # In this case, we need to check for metadata mutation because
+ # the runtime explicitly reconstructs the inputs, before actually
+ # reconstructing the outputs. Due to in-place view operations, the
+ # fully reconstructed input may not be this output base tensor
+ # anymore.
+ output_type == OutputType.alias_of_input
+ and base_idx is not None
+ and not input_info[base_idx].mutates_metadata
+ ):
+ if isinstance(o, FunctionalTensor):
+ functional_tensor = FunctionalTensorMetadataEq(o.elem)
+
+ out_info = OutputAliasInfo(
+ output_type=output_type,
+ raw_type=type(o),
+ base_idx=base_idx,
+ dynamic_dims=dynamic_dims,
+ requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
+ functional_tensor=functional_tensor,
+ )
+ output_info.append(out_info)
+
+ # See Note [AOT Autograd: Views to avoid tangents aliasing inputs]
+ def view_avoid_dupes_with_primals(t):
+ if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
+ return transform_subclass(
+ t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t)
+ )
+ if isinstance(t, Tensor):
+ return t.view(t.shape)
+ return t
+
+ # This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
+ # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
+ # is *regenerated* later, and not used directly in the autograd graph
+ def _plain_fake_tensor_like_subclass(x):
+ with detect_fake_mode():
+ return torch.empty(
+ x.shape, dtype=x.dtype, device=x.device, layout=x.layout
+ )
+
+ def _is_subclass_mutated_input_tangent_always_subclass(inp):
+ return (
+ isinstance(inp, torch.nested._internal.nested_tensor.NestedTensor)
+ or torch._functorch.config.disable_guess_zero_tangent_for_mutated_input_subclass
+ )
+
+ f_input_tangents = [
+ # Note: [AOTAutograd Tangent Subclassness for mutated inputs]
+ # Generally, when creating tangents to trace with, we assume that tangents will have
+ # the same subclass-ness as their forward outs.
+ # However: for tangents that correspond to input mutations, in practice it is more likely
+ # that these tangents will be plain tensors of zeros at runtime, so we tweak our guess
+ # to assume that these tangents should always be plain tensors.
+ # Example:
+ # def f(x):
+ # x.mul_(2)
+ # return x + 1
+ # out = f(x)
+ # out.sum().backward()
+ # In the above code, we will have a tangent "x_updated_tangent",
+ # which will be a plain tensor of zeros, *unless* x is used in some compute after executing f
+ #
+ # However, there are exceptions to this logic. If a view is created from a mutated input and is used in backward,
+ # the tangent for this subclass input will be a subclass tensor.
+ # Example:
+ # def f(a, b):
+ # a.mul_(2)
+ # b.mul_(3)
+ # return b.view(b.shape), a + b
+ # a_out, b_out = f(..., Subclass)
+ # (a * b).sum().backward()
+ #
+ # We cannot easily deduce this now, so we introduce a debug config to turn this behavior off for specific cases.
+ # NJT is guaranteed to have its tangent be an NJT, because it has dedicated integration in Autograd.
+ # See torch/csrc/autograd/python_function.cpp, use_zeros_like.
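+ # For cases like the second example above, this guess can be disabled through the
+ # debug config referenced in _is_subclass_mutated_input_tangent_always_subclass,
+ # e.g. (subject to change, since it is a debug knob):
+ #   torch._functorch.config.disable_guess_zero_tangent_for_mutated_input_subclass = True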
+ ( + _plain_fake_tensor_like_subclass(inp) + if is_traceable_wrapper_subclass(inp) + and not _is_subclass_mutated_input_tangent_always_subclass(inp) + else inp + ) + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + and info.mutates_data + and info.requires_grad + ] + f_output_tangents = [ + o + for o, info in zip(flat_f_outs, output_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases are also included in the backward graph + f_tangents = f_input_tangents + f_output_tangents + intermediate_bases + traced_tangents = pytree.tree_map(from_fun, f_tangents) + traced_tangents = pytree.tree_map( + view_avoid_dupes_with_primals, traced_tangents + ) + + traced_tangents = [ + coerce_tangent_and_suggest_memory_format(tt)[0] + for i, tt in enumerate(traced_tangents) + ] + nonlocal static_input_indices + static_input_indices = static_input_indices or [] + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + passed_indices = set(static_input_indices) + static_input_indices = [ + i + for i, arg in enumerate(flat_args) + if (isinstance(arg, torch.nn.Parameter) or i in passed_indices) + ] + + static_input_logger.debug( + "static input indices metadata analysis: %s", static_input_indices + ) + + f_mutated_inputs = [ + inp + for inp, info in zip(flat_f_args, input_info) + if info.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + f_metadata_mutated_inputs = [ + inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata + ] + # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. + # When handling subclasses, we need info about **all** outputs of compiled forward graph, + # so we know precisely which graph outputs to wrap back into tensor subclasses + # Ideally we would refactor this so not have an is_train flag, and have the separate + # inference and training paths decide which inputs/output to ask for subclass info on. + # However, we currently stash indexing information on each SubclassMeta about its order + # in the graph outputs list. + f_fw_graph_outs = list(flat_f_outs) + if is_train or not keep_input_mutations: + f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs + else: + # even when "keep_input_mutations" is True, + # we never keep metadata-only mutations in the fw graph + f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs + if is_train: + f_fw_graph_outs = f_fw_graph_outs + intermediate_bases + fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) + + grad_enabled_mutation = None + if torch.is_grad_enabled() != prior_grad_enabled: + grad_enabled_mutation = torch.is_grad_enabled() + torch.set_grad_enabled( + prior_grad_enabled + ) # Restore the prior state after tracing it + log.debug( + ( + "grad_mode mutation encountered in graph. 
" + "Will emit mutation epilogue, to set grad_mode=%s" + ), + grad_enabled_mutation, + ) + + metadata = ViewAndMutationMeta( + input_info=input_info, + output_info=output_info, + num_intermediate_bases=len(intermediate_bases), + keep_input_mutations=keep_input_mutations, + traced_tangents=traced_tangents, + subclass_inp_meta=create_subclass_meta(flat_args), + subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), + subclass_tangent_meta=create_subclass_meta( + traced_tangents, count_symints=False, with_memory_format=True + ), + is_train=is_train, + grad_enabled_mutation=grad_enabled_mutation, + static_input_indices=static_input_indices, + tokens=mode._tokens, + ) + return metadata + + return inner diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..a3435dfd8c9f5745a4e750b3cefefba63f45de66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -0,0 +1,338 @@ +# mypy: allow-untyped-defs +""" +This module dispatches the graphs to either the forward-only or joint compilation +pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. +""" + +import dataclasses +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torchgen.utils import dataclass_repr + +from .. import config +from .functional_utils import ( + assert_functional_graph, + propagate_input_mutation_stacktraces, +) +from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta +from .traced_function_transforms import ( + aot_dispatch_subclass, + create_functionalized_fn, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, + handle_effect_tokens_fn, +) +from .utils import ( + copy_fwd_metadata_to_bw_nodes, + register_buffer_assignment_hook, + root_module_when_exporting_non_strict, + unlift_tokens, +) + + +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + + +def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule: + # FunctionalTensorMode must be enabled here. + # See Note [Accessing .grad_fn on FunctionalTensor] + with enable_python_dispatcher(), FunctionalTensorMode( + pre_dispatch=aot_config.pre_dispatch, + export=aot_config.is_export, + # Allow token discovery for joint fn tracing as tokens can be used in backward. 
+ _allow_token_discovery=True, + ): + fx_g = make_fx( + f, + decomposition_table=aot_config.decompositions, + record_module_stack=True, + pre_dispatch=aot_config.pre_dispatch, + )(*args) + + return fx_g + + +# TODO: Refactor the following code so detach() persists item_memo +def _detach_and_copy_item_memo(t): + detached_t = t.detach() + if hasattr(t, "item_memo"): + detached_t.item_memo = t.item_memo + return detached_t + + +def aot_dispatch_base_graph( + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, list[Any], Optional[SubclassMeta]]: + # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. + # The cases that aot_dispatch_base doesn't need to handle include: + # - outputs that are aliases of graph intermediates + # - outputs that are aliases of graph inputs + # While cases that it does need to handle include: + # - input mutations (including when inputs are aliases of each other) + # - input metadata mutations + fn_to_trace = fn_input_mutations_to_outputs( + flat_fn, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + + fn_to_trace, updated_flat_args = create_functionalized_fn( + fn_to_trace, + flat_args, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=False, + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + maybe_subclass_meta, + ) = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + is_joint_structure=False, + meta=fw_metadata, + fw_only=flat_fn, + ) + + (fn_to_trace, updated_flat_args_subclasses_desugared) = handle_effect_tokens_fn( + fn_to_trace, + updated_flat_args_subclasses_desugared, + meta=fw_metadata, + trace_joint=False, + ) + + aot_graphs_log.debug( + "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(maybe_subclass_meta), + ) + + # We track buffer assignments when exporting in non-strict mode. + # (In contrast, strict mode errors on any attribute assignment.) + mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn) + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be added as a buffer mutation output. + assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + mod_when_exporting_non_strict, assigned_buffers + ) + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, + _detach_and_copy_item_memo, + updated_flat_args_subclasses_desugared, + ) + else: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared + ) + + fw_module = _create_graph( + fn_to_trace, + updated_flat_args_subclasses_desugared, + aot_config=aot_config, + ) + + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # We update metadata to consider any assigned buffers as buffer mutations. 
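+ # Hypothetical example of the kind of assignment this covers in non-strict export:
+ #   class M(torch.nn.Module):
+ #       def __init__(self):
+ #           super().__init__()
+ #           self.register_buffer("counter", torch.zeros(1))
+ #       def forward(self, x):
+ #           self.counter = self.counter + 1  # buffer assignment recorded by the hook
+ #           return x + self.counter
+ # The assigned buffer is then treated below as a data-mutated input.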
+ i = len(dict(mod_when_exporting_non_strict.named_parameters())) + for name, _ in mod_when_exporting_non_strict.named_buffers(): + if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data: # type: ignore[possibly-undefined] + fw_metadata.input_info[i] = dataclasses.replace( + fw_metadata.input_info[i], mutates_data=True + ) + fw_metadata.num_mutated_inp_runtime_indices += 1 + i += 1 + + # We add nodes corresponding to buffer assignments as output nodes in the graph. + add_nodes = [] + output_node = list(fw_module.graph.nodes)[-1] + for name in assigned_buffers.values(): # type: ignore[possibly-undefined] + for node in fw_module.graph.nodes: + if node.name == name: + add_nodes.append(node) + node.users[output_node] = None + output_node.args = ((*add_nodes, *output_node.args[0]),) + + hook.remove() # type: ignore[possibly-undefined] + + # As long as we opted to remove input mutations, then + # there should be *NO* mutating ops in the graph at this point. + copy_count = assert_functional_graph(fw_module.graph) + fw_module.graph.eliminate_dead_code() + fw_module.recompile() + + copy_count2 = assert_functional_graph(fw_module.graph) + propagate_input_mutation_stacktraces(fw_module.graph) + + # See Note [Side-Effectful Tokens in AOTAutograd] + num_tokens = len(fw_metadata.tokens) + if num_tokens != 0 and config.unlift_effect_tokens: + unlift_tokens(fw_module, fw_metadata, aot_config) + saved_updated_flat_args_subclasses_desugared = ( + saved_updated_flat_args_subclasses_desugared[num_tokens:] + ) + + assert copy_count == copy_count2 + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_inference_graph", + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + # TODO: should factor this into a separate function for export that always only returns just the graph. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd_graph( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, tuple[list[Any], list[Any]], Optional[SubclassMeta]]: + # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations. 
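+ # Rough example of what traced_tangents typically contains (hypothetical,
+ # assuming x and w require grad):
+ #   def f(x, w):
+ #       x.mul_(2)                       # data-mutated input
+ #       return (x * w).sum(), x.view(-1)
+ # traced_tangents holds a tangent for the mutated input x and one for the scalar
+ # output, but nothing for x.view(-1): aliases of inputs are regenerated at runtime
+ # instead of flowing through the compiled backward.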
+ joint_inputs = (flat_args, fw_metadata.traced_tangents) + + fn_prepared_for_autograd = fn_prepped_for_autograd( + flat_fn, + fw_metadata, + ) + joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config) + + joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn( + joint_fn_to_trace, + joint_inputs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + subclass_tracing_info = aot_dispatch_subclass( + joint_fn_to_trace, + updated_joint_inputs, + is_joint_structure=True, + meta=fw_metadata, + fw_only=flat_fn, + ) + + joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_joint_inputs = subclass_tracing_info.plain_tensor_args + + (joint_fn_to_trace, updated_joint_inputs) = handle_effect_tokens_fn( + joint_fn_to_trace, + updated_joint_inputs, + meta=fw_metadata, + trace_joint=True, + ) + + # When we call _create_graph, this may mutate the metadata of joint + # inputs. But callers are expecting to get the original joint inputs. So + # we make aliases of all the inputs to make sure we have a copy that + # doesn't get modified. + # + # This destroys requires_grad/grad_fn information. However, backends + # beneath AOTAutograd are indifferent to this information, so it doesn't + # matter. + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, _detach_and_copy_item_memo, updated_joint_inputs + ) + else: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_joint_inputs + ) + maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta + + fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) + + # There should be *NO* mutating ops in the graph at this point. + assert_functional_graph(fx_g.graph) + + # Redundant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + fx_g.graph.eliminate_dead_code() + copy_fwd_metadata_to_bw_nodes(fx_g) + fx_g.recompile() + + # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect + # when we need to manually detach() some inputs in the forward. + # Higher order ops might eventually need to do the same. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fx_g, saved_updated_joint_inputs, maybe_subclass_meta diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d217ff25cafa243389018b1d2e9c2e5ad30ac2d4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/functional_utils.py @@ -0,0 +1,523 @@ +# mypy: allow-untyped-defs +""" +This file contains utilities related to functionalization in AOTAutograd: +1. converting to/from functional tensors +2. detecting Tensor mutations - both metadata and Tensor value +3. regenerating/replaying views from their base +4. checking if a graph is functional i.e. 
whether it contains any mutation ops +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq, SymIntEqByExpr +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + transform_subclass, +) + + +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") + + +def to_fun(t): + if isinstance(t, Tensor): + if is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + else: + return FunctionalTensor.to_functional(t) + else: + return t + + +def sync_functional_tensor(t): + if is_traceable_wrapper_subclass(t): + attrs, _ctx = t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + sync_functional_tensor(getattr(t, attr)) + else: + torch._sync(t) + + +# When subclasses are involved, t here will usually look something like: +# SubclassA(SubclassB(FunctionalTensor(_to_fun_tensor(FakeTensor)))) +def from_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. + # recurse here, so we can support nested wrapper subclasses + out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t)) + torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined] + return out + + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) # type: ignore[attr-defined] + return t + sync_functional_tensor(t) + return torch._from_functional_tensor(t.elem) + + +def is_fun(t): + if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t): + # See Note [Functionalization always runs last] + # This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper + # goes at the bottom. 
+ # recurse here, so we can support nested wrapper subclasses + t_attrs, _ = t.__tensor_flatten__() # type: ignore[attr-defined] + t_inners = [getattr(t, attr) for attr in t_attrs] + any_fun = any(is_fun(x) for x in t_inners) + all_fun = all(is_fun(x) for x in t_inners) + assert any_fun == all_fun + return any_fun + + return isinstance(t, FunctionalTensor) + + +# t here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +def has_data_mutation(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + return any(has_data_mutation(getattr(t, attr)) for attr in attrs) + else: + if isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_has_data_mutation(t.elem) # type: ignore[attr-defined] + return False + + +def are_all_mutations_hidden_from_autograd(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + # If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd. + return all( + are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs + ) + elif isinstance(t, torch.Tensor): + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem) + else: + return False + + +def are_all_mutations_under_no_grad_or_inference_mode(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + return all( + are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr)) + for attr in attrs + ) + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode( + t.elem + ) + + +def was_inductor_storage_resized(t): + if is_traceable_wrapper_subclass(t): + attrs, _ = t.__tensor_flatten__() + if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs): + raise RuntimeError( + f"storage resizing is not supported on tensor subclass: {type(t)}" + ) + elif not isinstance(t, torch.Tensor): + return False + else: + assert isinstance(t, FunctionalTensor) + return torch._functionalize_was_inductor_storage_resized(t.elem) + + +# f_arg here is either +# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor)) +# (2) A traceable tensor subclass that holds a FunctionalTensor +# (3) Not a tensor +# Assumption: arg promises to be the "original" tensor wrapped by f_arg +# Note: "storage mutations" coming from set_() are a type of metadata mutation. 
So: +# - check_only_storage_mutation=True: only return true if there was a storage mutation +# - check_only_storage_mutation=Flse: return true if there was any metadata mutation (including a storage mutation) +def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool): + if is_traceable_wrapper_subclass(f_arg): + attrs, _ = f_arg.__tensor_flatten__() + # A tensor subclass was updated if any of its inner elements were updated + f_inner_ts = [getattr(f_arg, attr) for attr in attrs] + inner_ts = [getattr(arg, attr) for attr in attrs] + return any( + has_metadata_mutation( + f_inner_t, + inner_t, + check_only_storage_mutation=check_only_storage_mutation, + ) + for f_inner_t, inner_t in zip(f_inner_ts, inner_ts) + ) + else: + if not isinstance(f_arg, torch.Tensor): + assert not isinstance(arg, torch.Tensor) + return False + assert isinstance(f_arg, FunctionalTensor) + assert isinstance(arg, FakeTensor) + + arg_after = torch._from_functional_tensor(f_arg.elem) + # This is true if the current tensor experienced at least one set_() call + maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) # type: ignore[attr-defined] + # However, multiple set_() calls can cancel out. So we also check whether the + # storage of the tensor has changed. + # Note: if an input experienced two set_() calls that cancel out, **and** + # it experiences an data mutation, we pessimistically think that the set_() + # call is necessary here. We could in theory fix this, but this will + # hopefully never happen in user code, and is not needed for fsdp. + if is_sparse_any(arg): + # TODO:add sparse tensors support to functionalization + same_storages = False + else: + same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef( + arg_after.untyped_storage() + ) + has_storage_metadata_mutation = maybe_storage_changed and not same_storages + if check_only_storage_mutation: + return has_storage_metadata_mutation + + # storage metadata mutation is a type of metadata mutation, so return true if we saw one + if has_storage_metadata_mutation: + return True + + maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) # type: ignore[attr-defined] + # This is true if the current tensor experienced at least one metadata mutation. + # So if false, we know there was no metadata mutation + if not maybe_metadata_mutated: + return False + + # However, multi metadata mutations can cancel out. + # So we also check if the concrete sizes/strides on the tensor have changed. + same_sizes = arg.shape == arg_after.shape + same_strides = arg.stride() == arg_after.stride() + same_offsets = arg.storage_offset() == arg_after.storage_offset() + has_metadata_mutation_ = maybe_metadata_mutated and not ( + same_sizes and same_strides and same_offsets + ) + # We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call. + return has_metadata_mutation_ + + +def gen_alias_from_base( + aliased_base_tensor, + target_meta_tensor, + target_requires_grad, + target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, + *, + replay_views, +): + # Patch the correct requires_grad field of the output tensor, depending on whether: + # (i) the reconstructed output (out) was came from a tensor that requires grad or not; + # and (ii) the concrete returned output does require grad or not. 
+ def patch_requires_grad(out): + if aliased_base_tensor.requires_grad and not target_requires_grad: + out = out.detach() + elif not aliased_base_tensor.requires_grad and target_requires_grad: + out.requires_grad_(True) + return out + + # If provided, use the target functional tensor for replaying the views. + # + # In summary, we use the fact that FunctionalTensorWrapper saves the view + # functions applied to itself (collected during functionalization) so as + # to replay them (view functions) on the aliased_base_tensor. + if ( + replay_views + and target_functional_tensor is not None + and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) + ): + functional_tensor = target_functional_tensor.tensor + + out = torch._functionalize_apply_view_metas( + functional_tensor, aliased_base_tensor + ) + # If re-applying the ViewMeta sequence succeeded, there should be no more + # problems going forward. We just check we got to the target shape and + # patch requires_grad flag. + assert out.shape == target_meta_tensor.shape, ( + "incorrect out shape after application of ViewMeta sequence: " + f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)" + ) + return patch_requires_grad(out) + + # Try to do view-replay if possible. + # fall back to .as_strided() if we can't. + if target_meta_tensor._base is not None: + # The base that we want to replay our view off of might have a different shape than the view's original base. + b = target_meta_tensor._base + abt = aliased_base_tensor + # Don't unnecessarily call as_strided if nothing changed; as_strided's + # backward is poorly implemented and slow + if abt is not b and ( + abt.size() != b.size() + or abt.stride() != b.stride() + or abt.storage_offset() != b.storage_offset() + ): + reshaped_base_tensor = aliased_base_tensor.as_strided( + b.size(), b.stride(), b.storage_offset() + ) + else: + reshaped_base_tensor = aliased_base_tensor + out = target_meta_tensor._view_func(reshaped_base_tensor) + # This shape mismatch can happen due to a bug in inplace/view handling in autograd. + # Try putting a breakpoint here and running + # `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types` + # Also, https://github.com/pytorch/pytorch/issues/49825 + # + # As a stopgap, we'll fall back to as_strided. + if out is not None and out.shape == target_meta_tensor.shape: + return patch_requires_grad(out) + + size = target_meta_tensor.size() + stride = target_meta_tensor.stride() + storage_offset = target_meta_tensor.storage_offset() + if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex(): + aliased_out = torch.view_as_real(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex(): + aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided( + size, stride, storage_offset + ) + else: + aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset) + # For outputs aliasing inputs, we need to check if the requires-gradness has changed. + aliased_out = patch_requires_grad(aliased_out) + # For outputs aliasing inputs, we need to check if the dtype has changed. 
+ # as_strided() is the "most generic" view, but it does not cover cross-dtype views
+ if aliased_out.dtype != target_meta_tensor.dtype:
+ aliased_out = aliased_out.view(target_meta_tensor.dtype)
+ return aliased_out
+
+
+def has_same_metadata(t1, t2):
+ return (
+ guard_or_false(sym_eq(t1.size(), t2.size()))
+ and guard_or_false(t1.layout == t2.layout)
+ and (
+ is_sparse_any(t1)
+ or (
+ guard_or_false(sym_eq(t1.stride(), t2.stride()))
+ and guard_or_false(t1.storage_offset() == t2.storage_offset())
+ )
+ )
+ and t1.is_conj() == t2.is_conj()
+ and t1.is_neg() == t2.is_neg()
+ )
+
+
+@dataclass(frozen=True)
+class MetadataKey:
+ """
+ This should be equal whenever has_same_metadata would return True
+ """
+
+ size: tuple[SymIntEqByExpr, ...]
+ layout: torch.layout
+ is_sparse: bool
+ # these are empty when is_sparse
+ stride: Optional[tuple[SymIntEqByExpr, ...]]
+ storage_offset: Optional[SymIntEqByExpr]
+ is_conj: bool
+ is_neg: bool
+
+ @staticmethod
+ def make(t):
+ is_sparse = is_sparse_any(t)
+ return MetadataKey(
+ size=tuple(SymIntEqByExpr(s) for s in t.size()),
+ layout=t.layout,
+ is_sparse=is_sparse,
+ stride=None if is_sparse else tuple(SymIntEqByExpr(s) for s in t.stride()),
+ storage_offset=None if is_sparse else SymIntEqByExpr(t.storage_offset()),
+ is_conj=t.is_conj(),
+ is_neg=t.is_neg(),
+ )
+
+
+# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
+# after applying all the ViewMeta operations.
+class FunctionalTensorMetadataEq:
+ def __init__(self, tensor: torch.Tensor) -> None:
+ assert torch._is_functional_tensor(tensor)
+ self.tensor = tensor
+
+ def __eq__(self, other: object) -> bool:
+ # If other is None, then it probably means that we weren't able to recreate
+ # the FunctionalTensorMetadataEq. One of these cases is when we update the
+ # view metadata by calling: create_synthetic_base_metadata.
+ if other is None:
+ return True
+
+ # Comparison against any other type is not implemented.
+ if not isinstance(other, FunctionalTensorMetadataEq):
+ return NotImplemented
+
+ return has_same_metadata(self.tensor, other.tensor)
+
+
+# new_arg and arg here are either:
+# (1) both a FakeTensor
+# (2) both a traceable tensor subclass that holds a FakeTensor
+# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
+# When we run functionalization and wrap our inputs into FunctionalTensors,
+# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed
+#
+# Normally it would be enough just to check whether arg is new_arg, which suffices for functionalization
+# to confirm that inputs were not mutated when running the user's model with functionalization on.
+# But when we have subclass inputs, we can't rely on that: +# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs +# a brand new subclass instance: we are calling __tensor_unflatten__, and going +# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor)) +def was_tensor_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg + + +# new_arg and arg here are either: +# (1) both a FakeTensor +# (2) both a traceable tensor subclass that holds a FakeTensor +# Pre-condition: the two args are the "old" and "new" inputs from running functionalization. +# When we run functionalization and wrap our inputs into FunctionalTensors, +# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed, +# but shares storage with the old input +def was_tensor_metadata_updated(arg, new_arg): + if is_traceable_wrapper_subclass(arg): + assert is_traceable_wrapper_subclass(new_arg) + attrs, _ = arg.__tensor_flatten__() + new_attrs, _ = new_arg.__tensor_flatten__() + assert attrs == new_attrs + # A tensor subclass was updated if any of its inner elements were updated + return any( + was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) + for attr in attrs + ) + else: + return arg is not new_arg and StorageWeakRef( + arg.untyped_storage() + ) == StorageWeakRef(new_arg.untyped_storage()) + + +# Returns the number of detected copy_ +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + allowed_mutation_ops = [ + torch.ops.aten.copy_.default, + torch.ops.aten.set_.source_Tensor, + ] + if hasattr(torch.ops.fsdp, "copy_"): + allowed_mutation_ops.append(torch.ops.fsdp.copy_.default) + + placeholders = set() + mutation_count = 0 + # NB: It would also be nice to verify that the mutations all happen at the + # end, but we also do some administrative views after mutations so this + # isn't actually true. (TODO: Could this cause problems for Inductor?) + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target in allowed_mutation_ops: + # Can only copy_/set_ into an input + # this is mostly a hack to avoid failing XLA tests. 
+ # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 + if "set_buffer_donor_" not in str(n.args[0]): + assert ( + n.args[0] in placeholders + ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + mutation_count += 1 + else: + assert ( + not n.target._schema.is_mutable + ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return mutation_count + + +def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: + placeholders = set() + for n in fx_g.nodes: + if n.op == "placeholder": + placeholders.add(n) + if isinstance(n.target, torch._ops.OpOverload): + if n.target is torch.ops.aten.copy_.default: + # Can only copy_ into an input, and can only do so once + if "set_buffer_donor_" not in str(n.args[0]): + assert ( + n.args[0] in placeholders + ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + placeholders.remove(n.args[0]) + copy_from_node = n.args[1] + # Pre-condition: every node has a "stack_trace" field in its meta, + # but copy_() nodes do not (since we manually added them during functionalization). + # Instead, we manually propagate here. + if "stack_trace" in copy_from_node.meta: + n.meta["stack_trace"] = copy_from_node.meta["stack_trace"] + + +def _check_if_mutation_can_be_in_graph( + keep_input_mutations: bool, + mutates_data, + mutates_metadata, + mutations_hidden_from_autograd, + mutations_under_no_grad_or_inference_mode, + mutates_storage_metadata, + mutation_inductor_storage_resize, + requires_grad, +): + if keep_input_mutations: + in_graph = ( + mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize + ) and ( + (not mutates_metadata and not requires_grad) + or mutations_hidden_from_autograd + or mutations_under_no_grad_or_inference_mode + ) + else: + in_graph = False + # See Note [set_() Input Mutations in AOTAutograd] + # If there was a `set_()`, we require that all mutations were under no_grad, + # so we can (safely) emit the set_() in the graph at runtime + # resize_() gets the same treatment + if mutation_inductor_storage_resize or mutates_storage_metadata: + op_name = "resize_" if mutation_inductor_storage_resize else "set_" + assert in_graph, f"""\ +Encountered a {op_name} on a graph input, but the input has other mutations that we cannot +keep in the graph. This is not supported today. Current state: + keep_input_mutations={keep_input_mutations} + mutates_data={mutates_data} + mutates_metadata={mutates_metadata} + mutations_hidden_from_autograd={mutations_hidden_from_autograd} + mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode} + mutation_inductor_storage_resize={mutation_inductor_storage_resize} + requires_grad={requires_grad}""" + return in_graph diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..aae55f4ed81b4458905039130aa28e4d2993f82a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -0,0 +1,438 @@ +# mypy: allow-untyped-defs +""" +This module is one of the analysis modules - it takes as input a function or graph +and some preexisting properties, and returns some data that is useful for deciding +how to further proceed with compilation or construct runtime wrappers. 
+ +In particular, the following analyses are provided: +1. Refine the view and mutation metadata collected previously - removing duplicate + inputs or mapping views to their bases. +2. We also analyze the function signature for export graphs. +""" + +import contextlib +import itertools +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C._dynamo.guards import compute_overlapping_tensors +from torch._functorch._aot_autograd.schemas import PlainTensorMeta +from torch._guards import StorageOverlap +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental.symbolic_shapes import is_concrete_int + +from .collect_metadata_analysis import coerce_tangent_and_suggest_memory_format +from .schemas import ( + BackwardSignature, + GraphSignature, + InputAliasInfo, + MemoryFormatMeta, + OutputAliasInfo, + OutputType, + ViewAndMutationMeta, +) +from .utils import strict_zip + + +zip = strict_zip + + +def remove_dupe_metadata( + m: ViewAndMutationMeta, + keep_arg_mask: list[bool], + add_dupe_map: list[int], +) -> ViewAndMutationMeta: + assert len(m.input_info) == len(keep_arg_mask) + # Easy invariant: the first argument should never be a dupe (it will be kept) + assert len(keep_arg_mask) > 0 and keep_arg_mask[0] + + # Filter dupe'd mutated inputs out of traced_tangents + num_data_mutations = len([x for x in m.input_info if x.mutates_data]) + other_traced_tangents = m.traced_tangents[num_data_mutations:] + inp_traced_tangents = m.traced_tangents[:num_data_mutations] + filtered_inp_traced_tangents = [ + # See Note [Tangents memory format] + x + for i, x in enumerate(inp_traced_tangents) + if keep_arg_mask[m.mutated_inp_runtime_indices[i]] + ] + traced_tangents = filtered_inp_traced_tangents + other_traced_tangents + + assert m.subclass_tangent_meta is not None + subclass_tangent_meta = [ + PlainTensorMeta( + 0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format) + ) + ] * len(filtered_inp_traced_tangents) + m.subclass_tangent_meta[num_data_mutations:] + + return ViewAndMutationMeta( + input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]], + # For outputs that are views of inputs, we store the index of the input that the output + # was generated from. Need to update that index to account for removed dupes. + output_info=[ + OutputAliasInfo( + output_type=o.output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], + requires_grad=o.requires_grad, + functional_tensor=o.functional_tensor, + ) + for o in m.output_info + ], + num_intermediate_bases=m.num_intermediate_bases, + keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + # We are guaranteed not to get here, since dupes are not supported today with subclass inputs. + subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=subclass_tangent_meta, + is_train=m.is_train, + ) + + +# Given our ViewAndMutation metadata, this fn constructs a new set of metadata, +# after adding synthetic base arguments to the function. +# Most of the work in this fn is slogging through all of the metadata corresponding to inputs, +# and updating it with our synthetic base calling convention. +# +# When config.debug_assert is set, we automatically regenerate the metadata +# and compare it to this output for sanity. 
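+# +# Illustrative example (hypothetical caller, not part of this module): if the +# compiled function is called as f(a, b) where a and b are two views of the +# same tensor and one of them receives a data mutation, AOTAutograd traces a +# function that takes a single synthetic base covering their shared storage and +# regenerates a and b as views of that base at runtime.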
+# +# In addition to the updated metadata, also return the list of input indices +# that will need to be updated in the synthetic base epilogue +def create_synthetic_base_metadata( + m: ViewAndMutationMeta, + # Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a + # synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata) + synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]], + outer_args: list[Any], + inner_args: list[Any], +) -> tuple[ViewAndMutationMeta, list[int]]: + # maps inner arg indices to outer arg indices + synthetic_base_to_indices: dict[int, list[int]] = {} + for inner_idx in range(len(inner_args)): + outer_aliased_indices_of_current_base_arg = [ + outer_idx + for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info) + if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx) + or ( + isinstance(inner_idx_or_tuple, tuple) + and inner_idx_or_tuple[0] == inner_idx + ) + ] + synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg + + # given the requires_grad info on mutated inputs, + # generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases. + input_infos = [] + for outer_indices in synthetic_base_to_indices.values(): + # leaf-ness should be all-or-nothing for aliased tensor. + # (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf) + any_leaf = any(m.input_info[x].is_leaf for x in outer_indices) + all_leaf = all(m.input_info[x].is_leaf for x in outer_indices) + assert any_leaf == all_leaf + + mutates_data = ( + True + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_data + ) + mutates_metadata = ( + False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_metadata + ) + requires_grad = any(m.input_info[x].requires_grad for x in outer_indices) + mutations_under_no_grad_or_inference_mode = all( + m.input_info[x].mutations_under_no_grad_or_inference_mode + for x in outer_indices + ) + + mutation_inductor_storage_resize = all( + m.input_info[x].mutation_inductor_storage_resize for x in outer_indices + ) + + inpt_info = InputAliasInfo( + # If len(outer_indices) > 1, then this input is a synthetic base. + # The invariant is that to the rest of aot autograd, synthetic bases only show up if + # one of their aliases gets a data mutation. And if any of their aliases get metadata + # mutations, they will be hidden from the rest of aot autograd. 
+ mutates_data=mutates_data, + mutates_metadata=mutates_metadata, + mutations_hidden_from_autograd=all( + m.input_info[x].mutations_hidden_from_autograd for x in outer_indices + ), + mutates_storage_metadata=( + False + if len(outer_indices) > 1 + else m.input_info[outer_indices[0]].mutates_storage_metadata + ), + mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode, + mutation_inductor_storage_resize=mutation_inductor_storage_resize, + is_leaf=any_leaf, + requires_grad=requires_grad, + keep_input_mutations=m.keep_input_mutations, + ) + input_infos.append(inpt_info) + + # Find any inputs that fulfill the following criteria: + # (1) They are part of a synthetic base (because they alias another input, + # and at least one input experiences a data mutation) + # (2) They experience a metadata mutation + outer_aliased_arg_idx_with_metadata_mutations = [ + outer_idx + for outer_idx, inpt_info in enumerate(m.input_info) + if inpt_info.mutates_metadata + and not isinstance(synthetic_base_info[outer_idx], int) + ] + + # grab the original requires grad info on the outputs, except the ones from the mutated inputs + input_metadata_output_info = [ + OutputAliasInfo( + output_type=OutputType.alias_of_input, + raw_type=FunctionalTensor, + dynamic_dims={ + i + for i, s in enumerate(outer_args[outer_idx].shape) + if not is_concrete_int(s) + }, + base_idx=synthetic_base_info[outer_idx][0], # type: ignore[index] + requires_grad=outer_args[outer_idx].requires_grad, + ) + for outer_idx in outer_aliased_arg_idx_with_metadata_mutations + ] + existing_output_infos = [] + for o in m.output_info: + new_base_idx = ( + None + if o.base_idx is None + else ( + synthetic_base_info[o.base_idx] + if isinstance(synthetic_base_info[o.base_idx], int) + else synthetic_base_info[o.base_idx][0] # type: ignore[index] + ) + ) + # If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change + new_output_type = ( + OutputType.alias_of_input + if o.output_type == OutputType.is_input and o.base_idx != new_base_idx + else o.output_type + ) + existing_output_infos.append( + OutputAliasInfo( + output_type=new_output_type, + raw_type=o.raw_type, + dynamic_dims=o.dynamic_dims, + # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases + base_idx=new_base_idx, # type: ignore[arg-type] + requires_grad=o.requires_grad, + functional_tensor=o.functional_tensor, + ) + ) + + inner_mutated_tangents_and_memory_formats = [ + # See Note [Tangents memory format] + coerce_tangent_and_suggest_memory_format(x) + for inner_idx, x in enumerate(inner_args) + if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad + ] + inner_mutated_tangents = [x[0] for x in inner_mutated_tangents_and_memory_formats] + inner_mutated_tangents_memory_formats = [ + x[1] for x in inner_mutated_tangents_and_memory_formats + ] + + output_info = existing_output_infos + input_metadata_output_info + # Regenerate traced tangents to include mutated inputs including synthetic bases + traced_tangents = ( + inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :] + ) + assert m.subclass_tangent_meta is not None + subclass_tangent_meta = [ + PlainTensorMeta(0, memory_format=x) + for x in inner_mutated_tangents_memory_formats + ] + m.subclass_tangent_meta[len(inner_mutated_tangents) :] + + return ( + ViewAndMutationMeta( + input_info=input_infos, + output_info=output_info, + num_intermediate_bases=m.num_intermediate_bases, + 
keep_input_mutations=m.keep_input_mutations, + traced_tangents=traced_tangents, + # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs. + subclass_inp_meta=[], + subclass_fw_graph_out_meta=[], + subclass_tangent_meta=subclass_tangent_meta, + is_train=m.is_train, + ), + outer_aliased_arg_idx_with_metadata_mutations, + ) + + +def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices): + num_aliases = len(aliased_input_indices) + + shape_env = None + maybe_suppress_guards = contextlib.nullcontext + tracing_context = torch._guards.TracingContext.try_get() + + if tracing_context is not None: + shape_env = tracing_context.fake_mode.shape_env + + # Check whether we can actually get the dynamo sources from within AOTAutograd. + if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None: + maybe_suppress_guards = shape_env.suppress_guards + + # Check whether there are any symbolic values being used. + # We do this for 2 reasons: + # 1. StorageOverlap guard is only issued whenever dynamic shapes is turned on + # 2. Triggers the fast-path for computing storage overlapping + symbolic = any( + isinstance(x, torch.SymInt) + for i in aliased_input_indices + for x in [ + *fwd_inputs[i].shape, + *fwd_inputs[i].stride(), + fwd_inputs[i].storage_offset(), + ] + ) + + if torch._inductor.config.is_fbcode(): + if symbolic and num_aliases > 400: + from torch._subclasses.fake_tensor import ( + UnsupportedMutationAliasingException, + ) + from torch._utils_internal import justknobs_check + + msg = f"Encountered {num_aliases} dynamic, aliased/mutated inputs, consider setting dynamic=False" + + if justknobs_check( + "pytorch/compiler:aliased_inputs_with_mutation_and_dyn_shapes_killswitch", + False, + ): + raise UnsupportedMutationAliasingException(msg) + + with maybe_suppress_guards(): + aliased_fwd_inputs = [fwd_inputs[i] for i in aliased_input_indices] + actual_aliased_indices = { + aliased_input_indices[i] + for i in compute_overlapping_tensors(aliased_fwd_inputs, symbolic=symbolic) + } + + # Add the StorageOverlap AOTAutograd guard only if we are actually keeping track of + # dynamo sources inside AOTAutograd. + if ( + tracing_context is not None + # Make sure dynamic shapes is currently being used. + and symbolic + # We check that we have more than 1 aliased tensor, which should be true at + # this point, anyway. 
+ and num_aliases > 1 + and aot_config.aot_autograd_arg_pos_to_source + ): + no_overlap_indices = list(set(aliased_input_indices) - actual_aliased_indices) + + overlapping_sources = [ + aot_config.aot_autograd_arg_pos_to_source[i] for i in actual_aliased_indices + ] + non_overlapping_sources = [ + aot_config.aot_autograd_arg_pos_to_source[i] for i in no_overlap_indices + ] + + tracing_context.guards_context.aotautograd_guards.append( + StorageOverlap(overlapping_sources, non_overlapping_sources) + ) + + return actual_aliased_indices + + +def _graph_input_names(gm): + return [node.name for node in gm.graph.find_nodes(op="placeholder")] + + +def _graph_output_names(gm): + output_node = next(iter(reversed(gm.graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + return [getattr(return_arg, "name", None) for return_arg in return_args] + + +def create_graph_signature( + fx_g: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + *, + user_args_flat: list[Tensor], + params_and_buffers_flat: list[Tensor], + param_names: list[str], + buffer_names: list[str], + trace_joint: bool, + num_user_fw_outs: Optional[int], + loss_index: Optional[int], +) -> GraphSignature: + # Retrieve graph input names + graph_input_names = _graph_input_names(fx_g) + # Retrieve graph output names + graph_output_names = _graph_output_names(fx_g) + + num_params_buffers = len(param_names) + len(buffer_names) + num_tokens = len(fw_metadata.tokens) + # We have enough restrictions on the graph (no de-duping, synthetic bases, etc), + # Such that # graph inps = # user inps + # params + # buffers + num_user_args = len(graph_input_names) - num_params_buffers - num_tokens + + if trace_joint: + assert num_user_fw_outs is not None + num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices + backward_output_names = graph_output_names[num_fw_outs:] + + grad_index = itertools.count(0) + gradients_to_parameters = { + backward_output_names[next(grad_index)]: param_names[i] + for i, param in enumerate(params_and_buffers_flat) + if param.requires_grad + } + + gradients_to_user_inputs = { + backward_output_names[next(grad_index)]: graph_input_names[ + i + len(params_and_buffers_flat) + ] + for i, user_input in enumerate(user_args_flat) + if user_input.requires_grad + } + + assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len( + backward_output_names + ) + + # Check that we have fully accounted for all graph outputs + backward_signature = BackwardSignature( + gradients_to_parameters, + gradients_to_user_inputs, + graph_output_names[loss_index], + ) + else: + backward_signature = None + num_user_fw_outs = ( + len(graph_output_names) + - fw_metadata.num_mutated_inp_runtime_indices + - num_tokens + ) + + return GraphSignature.from_tracing_metadata( + in_spec=in_spec, + out_spec=out_spec, + graph_input_names=graph_input_names, + graph_output_names=graph_output_names, + view_mutation_metadata=fw_metadata, + named_parameters=param_names, + named_buffers=buffer_names, + num_user_inputs=num_user_args, + num_user_outputs=num_user_fw_outs, + loss_index=loss_index, + backward_signature=backward_signature, + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py new file mode 100644 index 
0000000000000000000000000000000000000000..761936710da7d194f13301c3f2d47328af026252 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -0,0 +1,1845 @@ +# mypy: allow-untyped-defs +""" +Functions in this module do most of the "work" of AOTAutograd. +An aot_dispatch_* function: +- Takes in the input flat_fn, flat_args, and some metadata +- Runs a set of pre compile wrappers (e.g. argument deduping) +- Runs the actual compiler +- Wraps the returned callable in a set of post compile wrappers +- Returns the wrapped callable and metadata. +""" + +import copy +import dataclasses +import itertools +import logging +import operator +import time +import traceback +from collections import defaultdict +from contextlib import nullcontext +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code +from torch._guards import CompileContext, TracingContext +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses import FakeTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals +from torch.fx.graph_module import GraphModule +from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars +from torch.multiprocessing.reductions import StorageWeakRef +from torch.types import py_sym_types +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torchgen.utils import dataclass_repr + +from .. import config +from .autograd_cache import ( + AOTAutogradCache, + serialize_graph_module, + should_use_remote_autograd_cache, +) +from .dispatch_and_compile_graph import ( + aot_dispatch_autograd_graph, + aot_dispatch_base_graph, +) +from .logging_utils import track_graph_compiling +from .runtime_wrappers import ( + AOTDedupeWrapper, + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + AOTSyntheticBaseWrapper, + AutogradLazyBackwardCompileInfo, + CompilerWrapper, + DebugAssertWrapper, + EffectTokensWrapper, + FakifiedOutWrapper, + FunctionalizedRngRuntimeWrapper, + make_runtime_safe, + post_compile, + pre_compile, + RuntimeWrapper, +) +from .schemas import AOTConfig, MutationType, ViewAndMutationMeta +from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta +from .utils import ( + _get_symint_hints, + contain_metadata_mutation_ops, + get_cuda_generator_meta_val, + make_boxed_func, + strict_zip, + unlift_tokens, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + +zip = strict_zip + +log = logging.getLogger(__name__) +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + +aten = torch.ops.aten + +# Returns a Callable and a ViewAndMutationMeta. +# Currently, only export needs the ViewAndMutationMeta after this function. 
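+# (Note) aot_dispatch_export, aot_dispatch_base and aot_dispatch_autograd below +# are all annotated to return this pair; the caller (typically the AOTAutograd +# dispatcher) picks between them based on whether the trace needs autograd and +# whether we are exporting.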
+DispatchReturn = tuple[Callable, ViewAndMutationMeta] + + +def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]: + """ + Wrappers that run on every dispatch function + """ + return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] + + +# Export's dispatching logic is unique in a few ways: it only needs the "graph" +# bits of aot_autograd, and doesn't need to do any specific wrapping. +def aot_dispatch_export( + flat_fn: Callable, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + needs_autograd: bool, +) -> DispatchReturn: + wrappers = _create_wrappers_for_dispatch(needs_autograd) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, + flat_fn, + flat_args, + aot_config, + fw_metadata=fw_metadata, + ) + if needs_autograd and not aot_config.pre_dispatch: + graph, _, _ = aot_dispatch_autograd_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + else: + graph, _, _ = aot_dispatch_base_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + # NB: the wrappers that run in pre_compile for export are + # either a no-op, because they're not needed, or will raise a runtime error, + # since they don't support export. + # We still run these wrappers to make sure that they're not needed pre compile, + # but we technically don't need to run them post compile at all here. + compiled_fn, fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=fw_metadata + ) + + # Therefore, since no wrappers run, we don't get back a callable - we get back the raw fx graph + # (either a joint or an inference-only graph) + assert isinstance(compiled_fn, torch.fx.GraphModule) + return compiled_fn, fw_metadata + + +def sanitize_aot_config(input: AOTConfig) -> AOTConfig: + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + inference_compiler=None, + num_params_buffers=input.num_params_buffers, + aot_id=input.aot_id, + keep_inference_input_mutations=input.keep_inference_input_mutations, + is_export=input.is_export, + no_tangents=input.no_tangents, + aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source, + dynamic_shapes=input.dynamic_shapes, + enable_log=input.enable_log, + static_input_indices=input.static_input_indices, + pre_dispatch=input.pre_dispatch, + cache_info=None, + precompile_backend_id=input.precompile_backend_id, + ) + + +def aot_dispatch_base( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> DispatchReturn: + """ + Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler.
+ """ + wrappers = _create_wrappers_for_dispatch(needs_autograd=False) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + # Save the forward_graph_str right after aot_dispatch_base_graph, + # to save in the cache + aot_forward_graph_str = None + if aot_config.cache_info is not None: + aot_forward_graph_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + ) + + fakified_out_wrapper = FakifiedOutWrapper() + ( + fw_module, + updated_flat_args, + fw_metadata, + ) = fakified_out_wrapper.pre_compile( + fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata + ) + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper() + ( + fw_module, + updated_flat_args, + fw_metadata, + ) = functionalized_rng_wrapper.pre_compile( + fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata + ) + assert isinstance(fw_module, GraphModule) + + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_config_copy(), + ) + + disable_amp = torch._C._is_any_autocast_enabled() + context = torch._C._DisableAutocast if disable_amp else nullcontext + + with context(), track_graph_compiling(aot_config, "inference"): + compiler = ( + aot_config.inference_compiler + if aot_config.inference_compiler is not None + else aot_config.fw_compiler + ) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + + with TracingContext.report_output_strides() as fwd_output_strides: + fake_mode = detect_fake_mode() + if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode) + compiled_fw = compiler(fw_module, updated_flat_args) + + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + # However, RuntimeWrapper does not expect the rng offsets in the + # output. So, we have to create another wrapper and take out the offset. As + # a result, we have to account for not boxed_call compilers as well. 
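+ # (Descriptive note) A "boxed" callable receives its inputs as a single list + # rather than as positional arguments; make_boxed_func adapts a + # positional-argument callable to that convention and marks it with + # `_boxed_call = True`, which is what the check below looks for.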
+ if not getattr(compiled_fw, "_boxed_call", False): + compiled_fw = make_boxed_func(compiled_fw) + + # Create a wrapper to set up the rng functionalize and fakified out bits + compiled_fw = functionalized_rng_wrapper.post_compile( + compiled_fw, aot_config, runtime_metadata=fw_metadata + ) + cache_info = aot_config.cache_info + if cache_info is not None: + if hasattr(compiled_fw, "_fx_graph_cache_key"): + time_taken_ns = time.time_ns() - cache_info.start_time_ns + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + entry = AOTAutogradCache.make_entry( + compiled_fw_func=compiled_fw, # type: ignore[arg-type] + compiled_bw_func=None, + aot_joint_graph_str=None, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=None, + runtime_metadata=fw_metadata, + dispatch_wrappers=wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + indices_of_inps_to_detach=[], + forward_time_taken_ns=time_taken_ns, + backward_time_taken_ns=0, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=None, + num_symints_saved_for_bw=None, + serialized_bw_module=None, + ) + AOTAutogradCache.save( + cache_info.cache_key, entry, remote=should_use_remote_autograd_cache() + ) + + compiled_fw = fakified_out_wrapper.post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw = EffectTokensWrapper().post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + # Why do we need to pass in num_fw_outs_saved_for_bw? + # See Note: [Partitioner handling for Subclasses, Part 2] + compiled_fw = AOTDispatchSubclassWrapper( + trace_joint=False, + # TODO: once we use pre_compile this will be flat_fn at the top of this function + fw_only=None, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + ).post_compile( + compiled_fw, + aot_config, # not used + runtime_metadata=fw_metadata, + ) + + if not getattr(compiled_fw, "_boxed_call", False): + compiled_fw = make_boxed_func(compiled_fw) + + compiled_fn = RuntimeWrapper( + indices_of_inps_to_detach=[], + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fn = post_compile( + wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata + ) + return compiled_fn + + +def collect_fw_donated_buffer_idxs( + fw_ins: list[Optional[FakeTensor]], + user_fw_outs: list[Optional[FakeTensor]], + bw_outs: list[Optional[FakeTensor]], + saved_tensors: list[FakeTensor], +) -> list[int]: + """ + Checks if the saved tensors are donated buffers, which means a saved tensor is not + an alias of any tensors in fw_ins, user_fw_outs, and bw_outs. 
+ """ + + storage_refs = set() + for t in itertools.chain(fw_ins, user_fw_outs, bw_outs): + # Only access storage if a tensor has storage (not sparse) + if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t): + storage_refs.add(StorageWeakRef(t.untyped_storage())) + + num_saved_tensor = len(saved_tensors) + donated_buffer_idxs = [] + for i in range(num_saved_tensor): + t = saved_tensors[i] + if ( + t is not None + and not is_sparse_any(t) + and StorageWeakRef(t.untyped_storage()) not in storage_refs + ): + donated_buffer_idxs.append(i) + + return donated_buffer_idxs + + +def collect_bw_donated_buffer_idxs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, +) -> list[int]: + """ + Collects backward donated buffer indexes from fw_module and bw_module. + """ + + # [Note: Metadata mutation in proxy tracing] + # node.meta["val"] is a snapshot of the tensor value when tracing a graph, + # instead of the final state after the graph has run. node.meta["val"] is + # not updated even if later there is a metadata mutation op. + # See: https://github.com/pytorch/pytorch/pull/141308#issuecomment-2495798947 + # + # Currently, metadata mutation op happens only for sacrificial parameter + # specifically the `set_` op. This motivates banning metadata mutation from + # proxy tracing. + # + # Since node.meta["val"] is used to detect donated buffer, we return an empty + # list if there exists metadata mutation op. + if contain_metadata_mutation_ops(fw_module) or contain_metadata_mutation_ops( + bw_module + ): + return [] + + fw_ins = fw_module.graph.find_nodes(op="placeholder") + bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0] + fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0] + + fw_ins = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_ins + ] + fw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_outs + ] + bw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in bw_outs + ] + + user_fw_outs = fw_outs[: fw_metadata.num_forward] + saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice] + + fw_donated_buffer = collect_fw_donated_buffer_idxs( + fw_ins, + user_fw_outs, + bw_outs, + saved_tensors, + ) + + assert fw_metadata.num_symints_saved_for_bw is not None + return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer] + + +@dataclasses.dataclass +class InvokeSubgraphHopGraphs: + """ + A data structure to hold all the information needed to partition the + `joint_hop_gm` and joint graph and the restitch the `new_fw_hop_gm` and + `new_bw_hop_gm` into the bigger `joint_gm`. + """ + + # To avoid re-partitioning subgraphs + partitioning_done: bool = False + old_num_fw_outputs: Optional[int] = None + old_num_fw_inputs: Optional[int] = None + + new_fw_hop_gm: Optional[torch.fx.GraphModule] = None + new_bw_hop_gm: Optional[torch.fx.GraphModule] = None + new_num_sym_nodes: Optional[int] = None + new_num_saved_nodes: Optional[int] = None + + +def run_joint_graph_passes_on_hops( + joint_gm: torch.fx.GraphModule, + joint_inputs: Any, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + """ + This pass runs the joint graph passes on the HOP graph. In torch.compile, we + typically have many passes which work on the joint graph and then end with a + partitioner. + + + The partitioner part is quite mechanical to handle. 
HOPs have their own + forward and backward graphs. The process can be broken into the following steps: + + 1) Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` + 2) Run joint graph passes on the `joint_hop_gm` to get `new_fw_hop_gm` and `new_bw_hop_gm` + 3) Stitch the `new_fw_hop_gm` and `new_bw_hop_gm` back into the `joint_gm`. + + The terminology used in the code is + `joint_graph/joint_gm` : Refers to the main graph. This may contain many HOPs which have their own `hop_graph` + `fw_hop_graph/fw_hop_gm` : Refers to the forward graph associated with a HOP. + `bw_hop_graph/bw_hop_gm` : Refers to the backward graph associated with a HOP. + `joint_hop_graph/joint_hop_gm` : Refers to the subgraph associated with the HOP like invoke_subgraph. + `new_fw_hop_graph/new_fw_hop_gm` : Refers to the forward graph after partitioning is applied to `joint_hop_gm`. + `new_bw_hop_graph/new_bw_hop_gm` : Refers to the backward graph after partitioning is applied to `joint_hop_gm`. + + NB: This pass works for invoke_subgraph today because we took extra care in + the Autograd.Dispatch key of invoke_subgraph to vastly simplify Step 1. + """ + from torch._higher_order_ops import invoke_subgraph + + def num_outputs(mod): + return len(mod.graph.find_nodes(op="output")[0].args[0]) + + def num_inputs(mod): + return len(mod.graph.find_nodes(op="placeholder")) + + def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder( + f"primals_{next(primals_counter)}" + ) + else: + env[node] = new_graph.placeholder( + f"tangents_{next(tangents_counter)}" + ) + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( + lambda: InvokeSubgraphHopGraphs() + ) + + # Step 1 - Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm`. This is + # easy to do for `invoke_subgraph` HOP. During the Autograd dispatch key + # tracing, we have put the joint_hop_graph in the backward hop graph itself. + # So to recover the joint_hop_gm, we just have to look at the backward + # HOP graphs. + # So we will merge step 1 and step 2 in this next section. + + # Save the fw and bwd hop nodes. We will later in-place modify the graph + # using these nodes.
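+ # Illustration (hypothetical node shapes): a forward hop node looks roughly + # like invoke_subgraph(%fw_subgraph, 'fw<id>', *primals) and its backward + # counterpart like invoke_subgraph(%bw_subgraph, 'bw<id>', *primals, *tangents); + # args[1] is the string identifier used below to tell forward hop nodes from + # backward ones and to group repeated subgraphs.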
+ fw_hop_nodes = [] + bw_hop_nodes = [] + for node in joint_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is invoke_subgraph + and isinstance(node.args[1], str) + ): + if node.args[1].startswith("fw"): + fw_hop_nodes.append(node) + elif node.args[1].startswith("bw"): + bw_hop_nodes.append(node) + + if not bw_hop_nodes: + return joint_gm + + assert len(fw_hop_nodes) == len(bw_hop_nodes) + + # Create a bw to fw hop node mapping. This helps us identify the bw and + # fw subgraph pairs without relying on the identifier. This is important + # because we can have different bwd subgraphs for the same fwd subgraph + # because of differing strides in the backward. + bw_to_fw_hop_node = dict(zip(list(reversed(bw_hop_nodes)), fw_hop_nodes)) + + for node in bw_hop_nodes: + identifier = node.args[1].removeprefix("bw") + + # If partitioning is already done for this identifier, skip. This saves + # redundant joint graph passes for the same subgraphs. + if new_hop_graphs[identifier].partitioning_done: + continue + + # Collect some information from the forward hop graph + fw_hop_node = bw_to_fw_hop_node[node] + fw_hop_gm = getattr(joint_gm, fw_hop_node.args[0].target) + assert isinstance(fw_hop_gm, torch.fx.GraphModule) + num_fw_inputs = num_inputs(fw_hop_gm) + num_fw_outputs = num_outputs(fw_hop_gm) + new_hop_graphs[identifier].old_num_fw_inputs = num_fw_inputs + new_hop_graphs[identifier].old_num_fw_outputs = num_fw_outputs + + # Step 1) - Get the `joint_hop_gm`. As mentioned earlier, the + # backward graph is the joint graph. + joint_hop_gm = getattr(joint_gm, node.args[0].target) + assert isinstance(joint_hop_gm, torch.fx.GraphModule) + + # Prepare the graph for the partitioner + joint_hop_gm = prepare_for_partitioner( + joint_hop_gm, num_fw_inputs, num_fw_outputs + ) + + # TODO: invoke_subgraph should track which of its inputs are static indices + # so it can propagate them to the partitioner (and use in cudagraphs) + static_lifetime_input_indices: list[int] = [] + # Step 2) and 3) - Run joint graph passes and partitioner + new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn( + joint_hop_gm, + [], + num_fwd_outputs=num_fw_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + # Save the new forward and backward graph modules + new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm + new_hop_graphs[identifier].new_bw_hop_gm = new_bw_hop_gm + + # Save the number of symints and saved tensors + new_fw_out_nodes = new_fw_hop_gm.graph.find_nodes(op="output")[0].args[0] + extra_outputs = new_fw_out_nodes[num_fw_outputs:] + symint_outputs = [n for n in extra_outputs if is_sym_node(n)] + + new_hop_graphs[identifier].new_num_sym_nodes = len(symint_outputs) + new_hop_graphs[identifier].new_num_saved_nodes = len(extra_outputs) - len( + symint_outputs + ) + + new_hop_graphs[identifier].partitioning_done = True + + # Step 3) Restitch the new fw and bw graphs back into the main graph. + # + # This is a very mechanical process. There are quite a few pieces that we + # need to connect together to make it work. Let's try to understand the + # problem statement first. + # + # For the forward graph, the signature of the old_fw_hop_gm is + # inputs - (*primals) + # outputs - (*fw_outs) + # Now the signature of the new_fw_hop_gm is + # inputs - (*primals) -- This is the same + # outputs - (*fw_outs, *saved_tensors) - This is different + # At a high level, this is an easy transformation: in the new graph we just + # have to replace the old_fw_hop_gm with the new_fw_hop_gm.
Everything else + # falls into place, because the input signature (i.e. args) is the same. And + # even though the output signature is different, fw_outs are still at the same + # indexes as before. So the forward of the `joint_gm` works nicely. + # + # Now, let's look at the backward hop graph. Old signature + # inputs - (*primals, *tangents) + # outputs - (*grad_outs, *fw_outs) + # New signature + # inputs - (*saved_tensors, *tangents) -- Different + # outputs - (*grad_outs) -- Different + # Here both input and output signature change. The output signature handling + # is quite easy because the grads_out are sitting at the right place, so we + # don't have to do anything. + # + # For the input signature, we have to collect the saved tensors from the + # corresponding forward graph output. We collect all saved_tensors when we + # see the forward graph, save them into a map, and then later use them during + # the backward. + + # The stack of fw_nodes for invoke_subgraph HOP. There is an implicit + # assumption about the graph structure, i.e., if we have hop1, hop2, hop3, + # ... in the forward part of the joint graph, we will have ..., hop3, hop2, + # hop1 order for the backward. This structure allows us to just use a stack + # to collect all the information that we need to pass from the forward hop + # node to the corresponding backward node. + + already_added_new_hop_mods = set() + + def add_new_hop_gm(new_subgraph_mod, name): + new_subgraph_attr_name = f"partitioned_{name}" + if new_subgraph_attr_name in already_added_new_hop_mods: + return new_subgraph_attr_name + + joint_gm.register_module(new_subgraph_attr_name, new_subgraph_mod) + already_added_new_hop_mods.add(new_subgraph_attr_name) + return new_subgraph_attr_name + + def propagate_meta_info(new_hop_gm, new_call_function_node, old_call_function_node): + # Copy all the fields from the old call_function node. And then override + # the `val` meta field with the outputs of new_hop_gm. + new_call_function_node.meta = copy.copy(old_call_function_node.meta) + + output = new_hop_gm.graph.find_nodes(op="output")[0] + out_example_vals = [n.meta["val"] if n else None for n in output.args[0]] + new_call_function_node.meta["val"] = tuple(out_example_vals) + + for bw_node in reversed(bw_hop_nodes): + identifier = bw_node.args[1].removeprefix("bw") + + # Make changes to the corresponding fw and bw node pair simultaneously. + # This removes the need for any bookkeeping. + + # Fw node changes + # Insert the new_fw_hop_gm. This is straightforward. Get the + # new_fw_hop_gm, insert the hop_gm as a get_attr fw_node, and then + # add a call_function fw_node.
Additionally, also use getitem + # call_functions to collect the saved_tensor nodes + + fw_node = bw_to_fw_hop_node[bw_node] + new_fw_hop_gm = new_hop_graphs[identifier].new_fw_hop_gm + assert new_fw_hop_gm is not None + + old_num_fw_outputs = new_hop_graphs[identifier].old_num_fw_outputs + new_num_sym_nodes = new_hop_graphs[identifier].new_num_sym_nodes + new_num_saved_nodes = new_hop_graphs[identifier].new_num_saved_nodes + assert old_num_fw_outputs is not None + assert new_num_sym_nodes is not None + assert new_num_saved_nodes is not None + total_outputs = old_num_fw_outputs + new_num_saved_nodes + new_num_sym_nodes + + extra_fw_outputs = [] + + # Insert the new_fw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(fw_node): + new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}") + new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name) + + # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors) + with joint_gm.graph.inserting_after(new_fw_mod_attr): + new_fw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_fw_mod_attr, + new_fw_mod_attr_name, + *fw_node.args[2:], + ), + ) + propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node) + + # old_num_fw_outputs = (*fw_outs) + # new_num_fw_outputs = (*fw_outs, *saved_tensors, *sym_nodes) + with joint_gm.graph.inserting_after(new_fw_node): + for fw_out_idx in range(old_num_fw_outputs, total_outputs): + saved_tensor_node = joint_gm.graph.call_function( + the_function=operator.getitem, args=(new_fw_node, fw_out_idx) + ) + saved_tensor_node.meta = copy.copy(new_fw_node.meta) + saved_tensor_node.meta["val"] = new_fw_node.meta["val"][fw_out_idx] + extra_fw_outputs.append(saved_tensor_node) + + fw_node.replace_all_uses_with(new_fw_node) + joint_gm.graph.erase_node(fw_node) + + # Bw node changes + # Prepare the operands for the bwd graph + # Old bw graph signature : (*primals, *tangents) + # New signature will be : (*sym_nodes, *saved_tensors, *tangents) + # We have already collected the saved_tensors in the forward hop processing. + + # extra_fw_outputs are in the order (*saved_nodes, *sym_nodes). + # Partitioner has this quirk where the backward wants sym_nodes + # first. So extract the sym and saved nodes. + + new_bw_hop_gm = new_hop_graphs[identifier].new_bw_hop_gm + assert new_bw_hop_gm is not None + + saved_tensor_nodes = extra_fw_outputs[:new_num_saved_nodes] + sym_nodes = extra_fw_outputs[new_num_saved_nodes:] + + num_primals = new_hop_graphs[identifier].old_num_fw_inputs + assert num_primals is not None + tangents = list(bw_node.args[2 + num_primals :]) + operands = sym_nodes + saved_tensor_nodes + tangents + + # Insert the new_bw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(bw_node): + new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1]) + new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name) + + with joint_gm.graph.inserting_after(new_bw_mod_attr): + new_bw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_bw_mod_attr, + new_bw_mod_attr_name, + *operands, + ), + ) + propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node) + # Since the partitioner is run after the graph passes, we have lost + # the eager information and cannot faithfully extract the eager + # inputs for the new partitioned backward graph. For the forward + # graph, it was fine because the input signature remains same. 
+ new_bw_node.meta.pop("eager_input_vals", None) + + bw_node.replace_all_uses_with(new_bw_node) + joint_gm.graph.erase_node(bw_node) + + joint_gm.graph.eliminate_dead_code() + joint_gm.graph.lint() + joint_gm.recompile() + return joint_gm + + +def maybe_log_graph( + gm, + graph_name, + aot_config, + structured_log_prefix_fn, + out_structured_logs: Optional[list[str]] = None, +): + if not aot_config.enable_log: + return + aot_graphs_log.debug( + "%s", + lazy_format_graph_code( + f"{graph_name}", + gm, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + def gm_str_fn() -> str: + return gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + if out_structured_logs is not None: + out_structured_logs.append(f"{structured_log_prefix_fn()}:{gm_str_fn()}") + else: + trace_structured( + f"{structured_log_prefix_fn()}", + payload_fn=lambda: gm_str_fn(), + ) + + +def create_wrap_fn(fn, args): + from functools import wraps + + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + + from .functional_utils import from_fun, has_data_mutation, to_fun + + def assert_no_mutation(t): + assert not has_data_mutation( + t + ), "Saved tensors hooks with inputs mutations are not allowed" + + @wraps(fn) + def _wrapper(*args): + with maybe_enable_thunkify(): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + f_args = pytree.tree_map(to_fun, args) + f_outs = fn(*f_args) + pytree.tree_map(assert_no_mutation, f_args) + return pytree.tree_map(from_fun, f_outs) + + return _wrapper, args + + +def prepare_hook_gm(aot_config, fn, args): + from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph + + fn, args = create_wrap_fn(fn, args) + gm = _create_graph(fn, args, aot_config=aot_config) + return gm + + +# Inline Autograd saved_tensors_hooks into epilogue of forward graph +# and prologue of backward graph. +# This changes forward graph outputs and inputs. +# Pack hook can return tensors, sym scalars, constants. +# All tensors to save for backward will be grouped together at front. +# Sym scalars grouped on another end. Constants are inlined in the graph. 
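+# Illustrative sketch (hypothetical hooks, not part of this module): the hooks +# being inlined are the ones installed via torch.autograd.graph.saved_tensors_hooks, +# e.g. a pack hook that downcasts saved activations and an unpack hook that +# restores them: +# +#     def pack(x): +#         return x.to(torch.bfloat16) +# +#     def unpack(x): +#         return x.to(torch.float32) +# +#     with torch.autograd.graph.saved_tensors_hooks(pack, unpack): +#         out = compiled_model(inp) +# +# When such hooks can be traced into graph modules, this pass splices the pack +# graph into the forward epilogue and the unpack graph into the backward prologue.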
+def maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + static_input_indices, +): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if not are_inline_hooks(hooks): + return + + pack_hook_gm, unpack_hook_gm = hooks + + structured_logs: list[str] = [] + maybe_log_graph( + fw_module, + "Forward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_forward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_backward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + fw_g = fw_module.graph + bw_g = bw_module.graph + + fw_g_names = {node.name for node in fw_g.nodes} + bw_g_names = {node.name for node in bw_g.nodes} + + def _gen_unused_name(candidate: str): + c = candidate + i = 0 + while c in fw_g_names or c in bw_g_names: + c = f"{candidate}_{i}" + i = i + 1 + return c + + bw_g_inputs = bw_g.find_nodes(op="placeholder") + + fw_out_n = fw_g.output_node() + fw_outs = fw_out_n.args[0] # type: ignore[var-annotated] + fw_outs_inner_set = set(fw_outs[:num_inner_fwd_outputs]) + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + fw_outs_packed_tensors = [] # type: ignore[var-annotated] + fw_outs_packed_syms = [] # type: ignore[var-annotated] + + # The main use case for saved_tensors_hooks is activation quantization, + # for memory usage optimization. + # Desired behavior is to quantize saved activations to free the original saved tensor. + # Saved nodes may include forward inputs, outputs, parameters. + # They may be held by something else and will not be deallocated after quantization. + # Donated buffers are intermediates in the graph invisible for the user, + # this guarantees that they can be deallocated. + # Using this as a default behavior to select saved nodes to apply hooks. + # There is also a config to apply hooks for all saved nodes without any filtering. + # The plan is to propagate meta about the source of the saved node to the user hook function. + mode = torch._functorch.config.saved_tensors_hooks_filtering_mode + allow_set = None + exclude_set = None + + if mode == "donated": + # collect_bw_donated_buffer_idxs requires inner_meta to have num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = len( + [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + ) + bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + fw_donated_idxs = [ + i - inner_meta.num_symints_saved_for_bw for i in bw_donated_idxs + ] + allow_set = {fw_outs_saved_for_bw[i].name for i in fw_donated_idxs} + elif mode == "no_static": + fw_g_inputs = fw_g.find_nodes(op="placeholder") + exclude_set = {fw_g_inputs[i].name for i in static_input_indices} + + if (allow_set is not None) and (not allow_set): + # This means we have empty whitelist, + # No donated (intermediate) saved. 
+ # Do not do anything in this case + return + + if aot_config.enable_log: + structured_logs.append(f"fw_outs_saved_for_bw:{fw_outs_saved_for_bw}") + structured_logs.append(f"mode:{mode}") + structured_logs.append(f"allow_set:{allow_set}") + structured_logs.append(f"exclude_set:{exclude_set}") + + for saved in fw_outs_saved_for_bw: + if ((allow_set is not None) and (saved.name not in allow_set)) or ( + (exclude_set is not None) and (saved.name in exclude_set) + ): + if isinstance(saved.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(saved) + continue + + val = saved.meta["val"] + if not isinstance(val, torch.Tensor): + continue + + pack_out_val = pack_hook_gm(val) + + requires_sc_handling = any( + is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) + ) + if requires_sc_handling: + raise NotImplementedError( + "Tensor subclasses in GraphModule saved tensors hooks are not supported" + "You can workaround it by manually returning subclass's inner tensors" + " in the pack hook, and reconstructing the subclass in the unpack hook" + ) + + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) + + # Install pack hook graph as epilogue of fw_module. + # Saved tensor output becomes input of pack hook graph. + # Replace saved tensor output with pack hook graph output. + # Output symbolic scalars and tensors are accumulated separately. + # Then they are installed in the forward outputs and backward inputs in the order + # sym_scalars, packed_saved_tensors. + # Keeping all tensors together allows us to preserve + # the same identification at runtime, + # updating only the number of saved sym_scalars and tensors. + pack_g_inputs = pack_g.find_nodes(op="placeholder") + assert len(pack_g_inputs) == 1 + env = {pack_g_inputs[0]: saved} + fw_pack_out_args = None + with fw_g.inserting_before(fw_out_n): + for node in pack_g.nodes: + if node.op == "placeholder": + continue + new_n = fw_g.node_copy(node, lambda n: env[n]) + fw_g_names.add(new_n.name) + env[node] = new_n + # Output node is temporarily copied to have remapped arguments. + # Removed in the end. + if node.op == "output": + fw_pack_out_args = new_n.args[0] + fw_g.erase_node(new_n) + + env.clear() + assert fw_pack_out_args + fw_outs_bw_ins_node_names = [] + for out_idx, _n in enumerate(pytree.tree_leaves(fw_pack_out_args)): + if not isinstance(_n, torch.fx.Node): + fw_outs_bw_ins_node_names.append("") + continue + + # This happens when the hook is a no-op and it is either a user input or a user output. + # Do not do anything with this node. + if _n.op == "placeholder" or _n in fw_outs_inner_set: + # This means the hook returned input primals unchanged + # Do not rename in this case. + n = _n + new_node_name = _n.name + fw_outs_bw_ins_node_names.append(new_node_name) + else: + # We cannot specify the desired name in node_copy. + # Copying the node manually to set a specific name, + # to have matching fw_outs, bw_inputs names.
+ new_node_name = _gen_unused_name(f"{saved.name}_hook_{out_idx}") + with fw_g.inserting_before(_n): + n = fw_g.create_node( + _n.op, + _n.target, + _n.args, + _n.kwargs, + name=new_node_name, + ) + assert n.name == new_node_name + fw_outs_bw_ins_node_names.append(new_node_name) + n.meta = copy.copy(_n.meta) + _n.replace_all_uses_with(n) + fw_g.erase_node(_n) + if isinstance(n.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(n) + elif is_sym_node(n): + fw_outs_packed_syms.append(n) + + # Install unpack hook graph as a prologue of backward graph + # Saved tensors inputs are replaced with packed tensors and packed sym scalars. + # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) + + def find_saved_in_bw_inputs(bw_inputs): + for n in bw_inputs: + if n.name == saved.name: + return n + + bw_g_input = find_saved_in_bw_inputs(bw_g_inputs) + assert bw_g_input + original_bw_g_input_users = list(bw_g_input.users.keys()) + bw_g_input_used_directly = False + + # Replace backward graph saved tensor input with copy of pack graph outputs + # All non-Tensor, non-symscalars outputs are constanted. + + unpack_g_inputs = unpack_g.find_nodes(op="placeholder") + env = {} + for out_idx, (unp_in_n, out_n, val) in enumerate( + zip( + unpack_g_inputs, + pytree.tree_leaves(fw_pack_out_args), + pytree.tree_leaves(pack_out_val), + ) + ): + is_sym = isinstance(val, py_sym_types) + if isinstance(val, torch.Tensor) or is_sym: + # We want forward_outputs names to match backward_inputs, + # Potentially backward may already have "{saved.name}_hook_{idx}", + # In this case fx.Graph will add suffix. + new_node_name = fw_outs_bw_ins_node_names[out_idx] + if bw_g_input.name == new_node_name: + env[unp_in_n] = bw_g_input + bw_g_input_used_directly = True + else: + # Backward calling convention: ctx_symints,ctx_saved_tensors + # Inserting packed sym scalars before first saved tensor input. + # Inserting packed tensors before last saved tensor input. + # Saved tensor inputs between them will be removed. + with bw_g.inserting_before( + bw_g_inputs[0] + ) if is_sym else bw_g.inserting_before(bw_g_input): + new_n = bw_g.placeholder(new_node_name) + assert new_n.name == new_node_name + new_n.meta = copy.copy(out_n.meta) + env[unp_in_n] = new_n + else: + # Inline values of non-Tensor, non-SymScalars + env[unp_in_n] = val + + # Inserting unpack hook after placeholders. + bw_unpack_out_n = None + with bw_g.inserting_before(bw_g_inputs[-1].next): + for node in unpack_g.nodes: + if node.op == "placeholder": + continue + new_n = bw_g.node_copy(node, lambda n: env[n]) + bw_g_names.add(new_n.name) + env[node] = new_n + # Temporary insert output, to have remapped by node_copy args. + # Removed in the end. + if node.op == "output": + bw_unpack_out_n = new_n + + assert bw_unpack_out_n + _leaves = pytree.tree_leaves(bw_unpack_out_n.args) + assert len(_leaves) == 1 + unpack_saved_tensor_n = _leaves[0] + + if not bw_g_input_used_directly: + bw_g_input.replace_all_uses_with(unpack_saved_tensor_n) + bw_g.erase_node(bw_g_input) + else: + # Keep usages of bw_g_input in inserted unpacked hook graph. + # Replace other usages of bw_g_input with unpack_saved_tensor_n. 
+ from torch._C import _fx_map_arg + + def maybe_replace_node(n): + return unpack_saved_tensor_n if n == bw_g_input else n + + for use_node in original_bw_g_input_users: + new_args = _fx_map_arg(use_node.args, maybe_replace_node) + new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node._update_args_kwargs(new_args, new_kwargs) + bw_g.erase_node(bw_unpack_out_n) + + # Changing forward graph outputs, + # Inserting packed_tensors and packed_syms on the place of saved tensors. + # Packed sym_scalars are together with saved symints + symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + fw_new_outs = pytree.tree_leaves( + ( + fw_outs[:num_inner_fwd_outputs], + fw_outs_packed_tensors, + fw_outs_packed_syms, + symint_outs_saved_for_bw, + ) + ) + fw_out_n.args = (tuple(fw_new_outs),) + + # Assert that saved tensors and symints in forward outputs are aligned with backward inputs + _fw_n = num_inner_fwd_outputs + _fw_num_t = len(fw_outs_packed_tensors) + _fw_num_s = len(fw_outs_packed_syms) + len(symint_outs_saved_for_bw) + fw_outs_saved_tensors = fw_new_outs[_fw_n : _fw_n + _fw_num_t] + fw_outs_saved_syms = fw_new_outs[_fw_n + _fw_num_t :] + bw_new_ins = list(bw_g.find_nodes(op="placeholder")) + bw_ins_saved_syms = bw_new_ins[:_fw_num_s] + bw_ins_saved_tensors = bw_new_ins[_fw_num_s : _fw_num_s + _fw_num_t] + + fw_t_names = [n.name for n in fw_outs_saved_tensors] + bw_t_names = [n.name for n in bw_ins_saved_tensors] + fw_s_names = [n.name for n in fw_outs_saved_syms] + bw_s_names = [n.name for n in bw_ins_saved_syms] + + def _log_structured_logs(): + if not aot_config.enable_log: + return + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_saved_tensors_hooks_graphs", + "encoding": "string", + }, + payload_fn=lambda: "\n".join(structured_logs), + ) + + if aot_config.enable_log: + structured_logs.append( + f"fw_outs[:num_inner_fwd_outputs]:{fw_outs[:num_inner_fwd_outputs]}" + ) + structured_logs.append(f"fw_outs_packed_tensors:{fw_outs_packed_tensors}") + structured_logs.append(f"fw_t_names:{fw_t_names}") + structured_logs.append(f"bw_t_names:{bw_t_names}") + structured_logs.append(f"fw_s_names:{fw_s_names}") + structured_logs.append(f"bw_s_names:{bw_s_names}") + structured_logs.append(f"\nfw_g_pre_assert:{fw_g}") + structured_logs.append(f"\nbw_g_pre_assert:{bw_g}") + maybe_log_graph( + fw_module, + "Forward graph after transform pre-assert", + aot_config, + lambda: "aot_forward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph after transform pre-assert", + aot_config, + lambda: "aot_backward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + _log_structured_logs() + + assert fw_t_names == bw_t_names + assert fw_s_names == bw_s_names + + fw_g.lint() + bw_g.lint() + fw_module.recompile() + bw_module.recompile() + + +def aot_dispatch_autograd( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> DispatchReturn: + """ + Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, + and returns a wrapped torch.autograd.Function with a forward and backward. 
+ """ + wrappers = _create_wrappers_for_dispatch(needs_autograd=True) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, + flat_fn, + flat_args, + aot_config, + fw_metadata=fw_metadata, + ) + + fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Copied from aot_dispatch_autograd_graph. + disable_amp = torch._C._is_any_autocast_enabled() + joint_graph_str = None + if aot_config.enable_log: + aot_joint_log.info( + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + joint_graph_str = fx_g.print_readable( + print_output=False, include_stride=True, include_device=True + ) + trace_structured( + "aot_joint_graph", + payload_fn=lambda: joint_graph_str, + ) + + with torch.no_grad(): + inner_meta = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + with track_graph_compiling(aot_config, "joint"): + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = ( + compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + ) + num_tokens = len(fw_metadata.tokens) + num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) + num_inner_fwd_outputs = ( + num_mutated_inp_runtime_indices + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] + ) + fake_mode = detect_fake_mode() + fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) + + # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes. 
+ if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode) + + static_lifetime_input_indices = fw_metadata.static_input_indices + fw_module, bw_module = aot_config.partition_fn( + fx_g, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if config.unlift_effect_tokens and ( + num_tokens > 0 or fw_metadata.num_backward_tokens > 0 + ): + unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) + + num_inner_fwd_outputs -= num_tokens + joint_inputs = ( + joint_inputs[0][num_tokens:], + joint_inputs[1], + ) + + maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + fw_metadata.static_input_indices, + ) + static_lifetime_input_indices = fw_metadata.static_input_indices + + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + symint_outs_saved_for_bw = [] + for idx, node in enumerate(fw_outs_saved_for_bw): + if is_sym_node(node): + symint_outs_saved_for_bw.append(node) + elif ( + isinstance(node, torch.fx.Node) + and "val" in getattr(node, "meta", {}) + and isinstance(node.meta["val"], FakeTensor) + ): + # record dynamic tensor activations + dynamic_dims: set[int] = { + dim + for dim, size in enumerate(node.meta["val"].shape) + if not isinstance(size, int) + } + if dynamic_dims: + fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims + + fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + if torch._functorch.config.donated_buffer: + fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs + + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_config_copy(), + ) + aot_graphs_log.info( + "aot_config id: %s, fw_metadata=%s, inner_meta=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(inner_meta), + ) + + # Note [Detaching inputs that never need gradients] + # See https://github.com/pytorch/pytorch/issues/97745 + # Suppose we have a function like this that we want to compile: + # + # def f(x, y): + # return torch.mul(x, y.detach()) + # + # What gradients should we compute for x and y? + # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, + # and so we'll compute: + # x_grad_input = y + # y_grad_input = None + # Does this preserve the semantics of eager mode? + # Unfortunately, no. + # Doing the above will cause autograd to **continue** to backprop the autograd tape + # that was generated from constructing y. 
+ #
+ # This is **different** from what would have happened in eager mode.
+ # In eager mode, if we backprop through the output of this function, autograd will only traverse
+ # the bit of the autograd tape corresponding to "x".
+ # In particular, if a user had previously backpropped through y's autograd tape,
+ # and then they try to backprop through the output of the above function,
+ # then we'll hit the dreaded "Trying to backward through the graph a second time" error.
+ #
+ # You might think: If autograd sees that a gradient is None, shouldn't it stop early,
+ # instead of continuing the backprop through the ancestors of that node in the graph?
+ #
+ # Autograd has two passes:
+ # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed
+ # (2) a second pass that actually goes ahead and executes each node when it becomes ready,
+ #     propagating gradients
+ # By the time we're executing a node and we see that it produces a None, the set of nodes to execute
+ # is already locked-in.
+ #
+ # The fix: instead, we can recognize statically that the graph we're compiling will never contribute
+ # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.
+ # We can do this by manually detach'ing y before sending it through the `CompiledFunction`.
+ #
+ # Note that this solution is not bulletproof.
+ # It's possible to construct a case where eager may or may not have tried to autograd through y,
+ # depending on the actual grad_outputs that were passed in during the backward.
+ # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,
+ # allowing autograd to re-use the graph.
+ #
+ # An example of this case is:
+ #     def f(x):
+ #         return x.detach() * 2, x * 3
+ # If we were to only backprop through outs[0] in eager, we shouldn't send a grad through x.
+ # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3
+ # and we will end up with a zero grad at x.
+ # If we later backprop through the second output, this will also require backprop'ing through x,
+ # meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
+ _indices_of_inps_to_detach: list[int] = []
+
+ # reversed() since we expect output at end of graph
+ bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
+ bw_outs: Sequence[torch.fx.Node] = bw_output.args[0]  # type: ignore[assignment]
+
+ # TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
+ # optimization even if we have subclass inputs/outputs (we do not handle this today).
+ # Computing which of our inputs get None gradients is a bit more complicated,
+ # if any of our inputs are subclasses. Why?
+ # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses.
+ # (b) The grad_outputs that we AOT computed in our backward graph are the desugared dense tensors,
+ # so we need to figure out which subclass fw inputs they map to.
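# --------------------------------------------------------------------------
# Editor's aside (illustrative sketch only, not part of the patched file):
# the eager-mode behavior that the Note [Detaching inputs that never need
# gradients] above wants to preserve can be reproduced with a few lines of
# plain autograd code; all names below are example-only.
#
#     import torch
#
#     def f(x, y):
#         return torch.mul(x, y.detach())
#
#     w = torch.randn(3, requires_grad=True)
#     y = w * 2                  # y is a non-leaf with a tape back to w
#     y.sum().backward()         # consumes (and frees) that tape
#
#     x = torch.randn(3, requires_grad=True)
#     out = f(x, y)
#     out.sum().backward()       # eager: fine, only x's tape is traversed
#
# A compiled autograd.Function that reports a gradient for every input that
# requires grad would instead make the engine traverse y's already-used tape
# here, hitting the "Trying to backward through the graph a second time"
# error described above; detaching y up front avoids that.
# --------------------------------------------------------------------------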
+ if maybe_subclass_meta is None: + num_backward_tokens: int = inner_meta.num_backward_tokens + assert ( + len(bw_outs) + == len(fw_metadata.input_info) + + inner_meta.num_outputs_rng_offset + + num_backward_tokens + ) + bw_outs_no_rng_no_tokens = bw_outs + if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: + bw_outs_no_rng_no_tokens = bw_outs[ + : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) + ] + assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info) + + for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens): + # If our input experiences a metadata mutation inside the graph (e.g. set_()), + # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation + metadata_mutation_in_graph = ( + fw_metadata.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + and fw_metadata.input_info[i].mutates_storage_metadata + ) + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: + _indices_of_inps_to_detach.append(i) + + fw_module_str = None + bw_module_str = None + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + fw_module_str = fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + bw_module_str = bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module_str, + ) + trace_structured( + "aot_backward_graph", + payload_fn=lambda: bw_module_str, + ) + + # AMP is already traced out in joint graph. we do not wish to reapply it accidentally + # in the compiler. + with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast(): + # flat_args at this point might still be subclasses- + # make sure to pass the unwrapped fake tensors into the compiler! 
+ adjusted_flat_args = joint_inputs[0] + + fakified_out_wrapper = FakifiedOutWrapper() + ( + fw_module, + adjusted_flat_args, + fw_metadata, + ) = fakified_out_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( + return_new_outs=False + ) + + if rng_states: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + + ( + fw_module, + adjusted_flat_args, + fw_metadata, + ) = functionalized_rng_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = inner_meta + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args) + + if not getattr(compiled_fw_func, "_boxed_call", False): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + compiled_fw_func = EffectTokensWrapper().post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = AOTDispatchSubclassWrapper( + fw_only=None, + trace_joint=False, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, + aot_config, # not used + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = functionalized_rng_wrapper.post_compile( + compiled_fw_func, aot_config, runtime_metadata=fw_metadata + ) + compiled_fw_func = fakified_out_wrapper.post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + # NB: It's important to compile backwards ahead of time, as this may + # add extra guards which we need to apply to the Dynamo cache at + # forwards + with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast(): + placeholder_list = fx_placeholder_vals(bw_module) + + forward_saved_for_backwards_strides = None + if fwd_output_strides is not None: + forward_saved_for_backwards_strides = fwd_output_strides[ + inner_meta.tensors_saved_for_backwards_slice + ] + + # saved activations can have different stride to eager if + # the compiler does layout optimization. We should restride the + # tensor passed in for compiling the backward graph using the + # saved tensor's stride. + for i in range(len(placeholder_list)): + ph_arg = placeholder_list[i] + if not isinstance(ph_arg, torch.Tensor): + continue + + if forward_saved_for_backwards_strides is None: + continue + + real_stride = None + # Per all_args calling convention + j = i - num_symints_saved_for_bw + if 0 <= j < len(forward_saved_for_backwards_strides): + real_stride = forward_saved_for_backwards_strides[j] + if real_stride is None: + continue + + # Comparing ph_arg.stride() with real_stride directly may + # cause dynamic dimensions in ph_arg being specialized to static + # value. Using the hints to avoid that. + if _get_symint_hints(ph_arg.stride()) != real_stride: + # Note that here we use the stride of the real tensor to + # restride a FakeTensor. This does not cause trouble + # for dynamic shape since this code path only get + # executed if layout optimization is enabled. 
And we + # disable layout optimization for dynamic shape right + # now. + # + # A solution that decide stride order based on real + # tensor's stride and then apply that stride order to + # the FakeTensor does not work smoothly since some + # tensor's layout is not 'dense'. E.g. mixnet_l has a + # tensor with size [8, 64, 112, 112] and strides + # (2408448, 1, 21504, 192). The solution mentioned will + # decide a stride of (802816, 1, 7168, 64) for this + # tensor which is wrong. + placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride) + + compiled_bw_func = None + if num_symints_saved_for_bw > 0: + try: + # See Note: [Backward graph lazy lowering] + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + # If bw_module contains lifted constants, they will be real tensors stored as + # GraphModule. Deepcopying tensors under fake mode is not supported and will + # raise when attempting to set storage. + bw_module_copy = copy.deepcopy(bw_module) + compiled_bw_func = aot_config.bw_compiler( + bw_module_copy, placeholder_list + ) + del bw_module_copy + except Exception as e: + exc = e + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "eager_compile_backwards_failure", + "encoding": "string", + }, + payload_fn=lambda: "\n".join( + traceback.format_exception( + type(exc), exc, exc.__traceback__ + ) + ), + ) + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) + # Compiled autograd will run the bw_module in the backward pass, + # so recompilation need happen anyway if the backward pass is ever + # called. + # + # The reason we do the GraphModule recompilation here is because + # the lazy recompilation will cause issue in the backward pass + # with compiled autograd. + # + # Do the _LazyGraphModule.force_recompile here rather than when + # bw_module is first generated by the partitioner because the bw_module.recompile + # may be called in some code path later and cause the _LazyGraphModule.forward + # becomes the lazy version again. One example is when dynamic shape is enabled + # upfront, the bw_compiler will be called above which can cause extra + # graph module recompilation on bw_module. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(bw_module) + + saved_context = TracingContext.try_get() + saved_compile_context = CompileContext.try_get() + + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + lazy_backward_info = AutogradLazyBackwardCompileInfo( + bw_module, + placeholder_list, + saved_context, + saved_compile_context, + ) + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + try_save_cache_entry: Optional[Callable] = None + + if aot_config.cache_info is not None: + forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns + + # NB: aot_config here is technically not needed as an argument: we could just + # close over aot_config.cache_info, since aot_config never changes. + # But closing over random variables is confusing IMO, so I'm leaving it. 
+ def try_save_cache_entry( # noqa: F811 + compiled_bw_func: Callable, + bw_module: torch.fx.GraphModule, + _fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + ): + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + cache_info = aot_config.cache_info + if cache_info is not None and fw_key and bw_key: + assert forward_time_taken_ns is not None + # TODO: technically, AOTAutograd does a *little* bit of post processing work + # in the backward that isn't measured here. But it's small enough that it's not worth + # the complexity of threading a bunch of times through the code, so we + # use the compiled_bw_func's inductor compile time instead. + # It's possible this changes in the future, in which case we should + # update backward_time_taken_ns to be more inclusive + backward_time_taken_ns = getattr(compiled_bw_func, "_time_taken_ns", 0) + + aot_forward_graph_str: Optional[str] = fw_module_str + aot_backward_graph_str: Optional[str] = bw_module_str + aot_joint_graph_str: Optional[str] = joint_graph_str + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + + entry = AOTAutogradCache.make_entry( + compiled_fw_func, # type: ignore[arg-type] + compiled_bw_func, # type: ignore[arg-type] + aot_joint_graph_str, + aot_forward_graph_str, + aot_backward_graph_str, + _fw_metadata, + wrappers, + maybe_subclass_meta, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + forward_time_taken_ns, + backward_time_taken_ns, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw=num_symints_saved_for_bw, + serialized_bw_module=serialize_graph_module(bw_module), + ) + remote = should_use_remote_autograd_cache() + AOTAutogradCache.save(cache_info.cache_key, entry, remote) + + if compiled_bw_func is not None: + # If we already compiled the backward, we save its cache entry now + try_save_cache_entry(compiled_bw_func, bw_module, fw_metadata, aot_config) + try_save_cache_entry = None + + compiled_fn = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + maybe_subclass_meta, + num_symints_saved_for_bw, + backward_state_indices, + disable_amp, + _indices_of_inps_to_detach, + lazy_backward_info, + aot_config, + fw_metadata=fw_metadata, + try_save_cache_entry=try_save_cache_entry, + ) + + if config.debug_assert: + flat_requires_grad: list[Optional[bool]] = [ + a.requires_grad if isinstance(a, Tensor) else None for a in flat_args + ] + compiled_fn = DebugAssertWrapper( + flat_requires_grad=flat_requires_grad + ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata) + + compiled_fn = post_compile( + wrappers, + compiled_fn, + aot_config, + runtime_metadata=fw_metadata, + ) + return compiled_fn diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a6f4ae6436c1846095e9df6c410f08e250a760 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/logging_utils.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +""" +Contains utils for logging in AOTAutograd, including managing the names of the graphs under +compilation, capturing user-friendly tracebacks, and debug messages. 
+""" + +import collections +from contextlib import contextmanager + +import torch +import torch.fx.traceback as fx_traceback + + +# This is a list since looking forward, we can have this arbitrarily nested. +graph_being_compiled: list[str] = [] +# TODO: It would be nice to reset the numbering every time aot_id goes +# up, but this is annoying to do right now (because we don't know if +# an aot_id will come back from the dead), so right now this also happens +# to be a globally unique number too (at the cost of wobbling if you change +# how the graphs compile) +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_aot_compilation_context() -> tuple[list[str], str, int]: + return list(graph_being_compiled), model_name, nth_graph + + +def get_aot_graph_name() -> str: + """ + Returns the name of the graph being compiled. + """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}" + + +get_graph_being_compiled = get_aot_graph_name + + +@contextmanager +def track_graph_compiling(aot_config, graph_name): + global graph_being_compiled + # TODO: Don't shove the aot_id in here; set it in the context + graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"] + old_name = None + if tracing_context := torch._guards.TracingContext.try_get(): + old_name = tracing_context.aot_graph_name + tracing_context.aot_graph_name = graph_being_compiled + has_tracing_context = True + else: + has_tracing_context = False + try: + yield + finally: + global nth_graph + nth_graph += 1 + graph_being_compiled = [] + if has_tracing_context: + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.aot_graph_name = old_name + + +# Set up hooks so that during backward the fx's stack_trace is properly set +callback_set = False + + +def setup_stacktrace_preservation_hooks(roots: list): + def iter_graph(roots): + if not roots: + return + seen = set() + q = collections.deque() # type: ignore[var-annotated] + for node in roots: + if node is not None and node not in seen: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _idx in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def get_callback(saved_stack_): + def callback(): + global callback_set + fx_traceback.set_stack_trace(saved_stack_) + callback_set = False + + return callback + + def get_prehook(stack_, seq_nr): + def prehook(grad_output): + global callback_set + + if not callback_set: + torch.autograd.variable.Variable._execution_engine.queue_callback( # type: ignore[attr-defined] + get_callback(fx_traceback.format_stack()) + ) + callback_set = True + + fx_traceback.set_stack_trace(stack_) + fx_traceback.set_grad_fn_seq_nr(seq_nr) + + return prehook + + def get_posthook(special_stack_, seq_nr): + def posthook(grad_input, grad_output): + fx_traceback.set_stack_trace(special_stack_) + fx_traceback.reset_grad_fn_seq_nr() + + return posthook + + for node in iter_graph(roots): + forward_node_stack = node.metadata.get("traceback_", []) + node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr())) + + special_stack = forward_node_stack.copy() + special_stack.append( + "Gradient addition node due to multiple use of tensor around:" + ) + node.register_hook(get_posthook(special_stack, node._sequence_nr())) + + +def describe_input(i, aot_config): + if i < aot_config.num_params_buffers: + return 
f"parameter/buffer {i}" + else: + return f"input {i - aot_config.num_params_buffers}" + + +def format_guard_bug_msg(aot_config, expected): + return ( + f"At compilation time, graph {aot_config.aot_id} was compiled under the " + f"assumption that {expected}, but at runtime this was not the case. " + "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch." + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..8e22e20e22af534fe7a87284a6ee2fd2e6b20b5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -0,0 +1,2485 @@ +# mypy: allow-untyped-defs +""" +This module defines runtime wrappers, which, based on previous analysis attempts to: +1. process the inputs and outputs +2. apply mutations +3. handle functionalized randomness +4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) +""" +import builtins +import collections +import contextlib +import copy +import itertools +import pprint +from contextlib import AbstractContextManager, nullcontext +from dataclasses import dataclass, field +from functools import wraps +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo import config as dynamo_config +from torch._dynamo.callback import callback_handler, CallbackTrigger +from torch._dynamo.utils import CompileEventLogger, dynamo_timed, get_metrics_context +from torch._guards import ( + compile_context, + CompileContext, + detect_fake_mode, + DuplicateInputs, + tracing, + TracingContext, +) +from torch._prims_common import CUDARngStateHelper +from torch._subclasses import FakeTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .functional_utils import gen_alias_from_base +from .input_output_analysis import ( + compute_overlapping_inputs, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling +from .schemas import ( + AOTConfig, + InputAliasInfo, + MemoryFormatMeta, + MutationType, + OutputType, + PlainTensorMeta, + SubclassCreationMeta, + SubclassMeta, + TensorAlias, + ViewAndMutationMeta, +) +from .subclass_utils import ( + requires_subclass_dispatch, + runtime_unwrap_tensor_subclasses, + wrap_tensor_subclasses, +) +from .traced_function_transforms import aot_dispatch_subclass +from .utils import ( + call_func_at_runtime_with_args, + make_boxed_func, + partial_flatten_asdict, + strict_zip, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +zip = strict_zip + + +class CompilerWrapper: + """ + A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: + + 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) + 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) + + Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate + caching on the compiled output, and re-wrapping the output via epilogues. 
+ Extra metadata that is needed to compute pre or post compile can be passed in via attributes. + """ + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return flat_fn, flat_args, fw_metadata + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic +# that needs to run after the compiled function. +# +# This function accepts a trace_joint flag, indicating whether or not we're generating the runtime +# epilogue for a forward-only inference graph, or for an autograd.Function.apply function. +# This is because there are some minor differences in how we treat these cases at runtime: +# - resize_() is currently handled in the inference case, but not fully handled in the autograd case. +# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs +@dataclass +class RuntimeWrapper(CompilerWrapper): + indices_of_inps_to_detach: list[int] + trace_joint: bool + disable_amp: bool + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + return _create_runtime_wrapper( + compiled_fn, + runtime_metadata=runtime_metadata, + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=self.trace_joint, + keep_input_mutations=aot_config.keep_inference_input_mutations, + disable_amp=self.disable_amp, + ) + + +class NoopAliasHandler: + def __init__(self, info, runtime_metadata, trace_joint): + pass + + def __call__(self, orig_inputs, fw_outs, out): + return out + + +def _unwrap_tensoralias(x): + assert isinstance(x, TensorAlias) + return x.alias + + +def _identity(x): + return x + + +class AliasOfInputHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self.base_idx = info.base_idx + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + self.requires_grad = info.requires_grad + self.functional_tensor = info.functional_tensor + self.replay_views = config.view_replay_for_aliased_outputs + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = orig_inputs[self.base_idx] + return gen_alias_from_base( + aliased_base_tensor, + self.unwrap_out(out), + self.requires_grad, + self.functional_tensor, + replay_views=self.replay_views, + ) + + +class IsInputHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self.base_idx = info.base_idx + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = 
orig_inputs[self.base_idx] + return aliased_base_tensor + + +class AliasOfIntermediateHandler: + def __init__(self, info, runtime_metadata, trace_joint): + self._unwrap_aliased_base_tensor = _identity + if info.output_type in ( + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + ): + num_user_outputs = len(runtime_metadata.output_info) + self.base_idx = info.base_idx + num_user_outputs + else: + self.base_idx = info.base_idx + if self.base_idx in runtime_metadata.aliased_out_indices: + self._unwrap_aliased_base_tensor = _unwrap_tensoralias + + self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity + self.requires_grad = info.requires_grad + self.functional_tensor = info.functional_tensor + self.replay_views = config.view_replay_for_aliased_outputs + + def __call__(self, orig_inputs, fw_outs, out): + aliased_base_tensor = fw_outs[self.base_idx] + return gen_alias_from_base( + self._unwrap_aliased_base_tensor(aliased_base_tensor), + self.unwrap_out(out), + self.requires_grad, + self.functional_tensor, + replay_views=self.replay_views, + ) + + +_HANDLER_MAP = { + OutputType.non_alias: NoopAliasHandler, + OutputType.unsafe_view_alias: NoopAliasHandler, + OutputType.custom_function_view: NoopAliasHandler, + OutputType.alias_of_input: AliasOfInputHandler, + OutputType.is_input: IsInputHandler, + OutputType.alias_of_intermediate: AliasOfIntermediateHandler, + OutputType.alias_of_intermediate_save_as_output: AliasOfIntermediateHandler, + OutputType.alias_of_intermediate_base_is_user_output: AliasOfIntermediateHandler, +} + + +def make_output_handler(info, runtime_metadata, trace_joint): + handler_type = _HANDLER_MAP[info.output_type] + return handler_type(info, runtime_metadata, trace_joint) + + +# not sure why AOTDispatcher needs to manually set this +def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]): + if hasattr(t, "_dynamo_weak_dynamic_indices"): + t._dynamo_weak_dynamic_indices |= dims + else: + t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined] + + +def _should_disable_saved_tensors_hooks(): + # Compiled autograd is not supported yet, to be added in future. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return False + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if are_inline_hooks(hooks): + return True + + return False + + +def _create_runtime_wrapper( + compiled_fn, + *, + runtime_metadata: ViewAndMutationMeta, + indices_of_inps_to_detach: list[int], + trace_joint: bool, + keep_input_mutations: bool, + disable_amp: bool, +): + if not getattr(compiled_fn, "_boxed_call", False): + compiled_fn = make_boxed_func(compiled_fn) + + # Note [Inputs needed in runtime epilogue after list clearing] + # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to + # wrap the input arguments in a list, and clear the list from within the function. + # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`. + # + # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early. + # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs + # **after** the compiled function has finished running. 
There are two main cases: + # (1) Input mutations: If there are an input mutations that we must run outside of the graph, we need access to the input. + # (2) Output aliasing: Outputs that aliases graph inputs generally must be regenerated outside of the `autograd.Function`, + # and doing so requires us accessing the corresponding input after the compiled artifact has run. + epilogue_args_idx = [] + epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices) + for info in runtime_metadata.output_info: + if ( + info.output_type == OutputType.alias_of_input + or info.output_type == OutputType.is_input + ): + assert isinstance(info.base_idx, int) + epilogue_args_idx.append(info.base_idx) + + if config.unlift_effect_tokens: + assert len(runtime_metadata.tokens) == 0 + + if runtime_metadata.num_outputs_aliased > 0: + output_handlers = tuple( + make_output_handler(info, runtime_metadata, trace_joint) + for info in runtime_metadata.output_info + ) + + def record_runtime_wrapper_prologue_enter() -> ( + Optional[AbstractContextManager[None]] + ): + if ( + torch.autograd.profiler._is_profiler_enabled + and dynamo_config.record_runtime_overhead + ): + cm = torch._C._profiler._RecordFunctionFast( + "AOTDispatcher Runtime Wrapper Prologue" + ) + cm.__enter__() + return cm + return None + + def record_runtime_wrapper_prologue_exit( + cm: Optional[AbstractContextManager[None]], + ) -> None: + if cm is not None: + cm.__exit__(None, None, None) + + def runtime_wrapper(args: list[Any]): + # Create context manager for profiler + cm = record_runtime_wrapper_prologue_enter() + + # stash a ref to each input tensor we plan to use after the compiled function + orig_inputs = {i: args[i] for i in epilogue_args_idx} + + if keep_input_mutations: + mutated_args = ( + args[i] + for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd + ) + torch.autograd.graph.increment_version(mutated_args) + + if trace_joint: + args_ = list(args) + # See Note [Detaching inputs that never need gradients] + for idx in indices_of_inps_to_detach: + if isinstance(args_[idx], torch.Tensor): + args_[idx] = args_[idx].detach() + + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. + with ( + torch.autograd._force_original_view_tracking(True), + torch.enable_grad(), + ): + record_runtime_wrapper_prologue_exit(cm) + all_outs = call_func_at_runtime_with_args( + compiled_fn, args_, disable_amp=disable_amp, steal_args=True + ) + else: + # When we have an inference graph, we run with grad disabled. 
+ # It's possible to get an inference graph with inputs that require grad, + # in which case we want to make sure autograd is disabled + # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) + # NOTE: We use _set_grad_enabled directly to reduce runtime overhead + grad_enabled = torch.is_grad_enabled() + try: + if grad_enabled: + torch._C._set_grad_enabled(False) + record_runtime_wrapper_prologue_exit(cm) + all_outs = call_func_at_runtime_with_args( + compiled_fn, args, disable_amp=disable_amp, steal_args=True + ) + finally: + if grad_enabled: + torch._C._set_grad_enabled(True) + del args + + num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices + num_intermediate_bases = runtime_metadata.num_intermediate_bases + + assert ( + len(all_outs) + == num_mutated_runtime_inps + + runtime_metadata.num_outputs + + num_intermediate_bases + ) + + # Step 3: After running the compiled fw, apply updates to mutated inputs + num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices + if num_mutations_to_apply > 0: + updated_inputs = all_outs[:num_mutations_to_apply] + fw_outs = all_outs[num_mutations_to_apply:] + + for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): + meta = runtime_metadata.input_info[inpt_idx] + if not meta.mutates_data and not meta.mutates_metadata: + continue + original_inpt = orig_inputs[inpt_idx] + updated_inpt = updated_inputs[i] + if meta.mutates_storage_metadata: + # See Note [set_() Input Mutations in AOTAutograd] + # mutates_storage_metadata means our input saw a x.set_(y) call. + # What if x **also** saw a data and/or a metadata mutation? + # (1) If the [meta]data mutation occurred after the set_(), + # then there is no need to copy_() the data. + # When we perform x.set_(x_updated), we are guaranteed that + # x_updated already has the final version of the data/metadata + # (2) If a data mutation occurred before the set_(). + # This case seems very difficult to support. + # TODO: discuss on the PR and decide if we want to tr to + # either support it, or detect and ban it. + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + with torch.no_grad(): + original_inpt.set_(updated_inpt) + continue + if meta.mutates_metadata and not meta.mutates_data: + if trace_joint: + assert isinstance(updated_inpt, TensorAlias) + updated_inpt = updated_inpt.alias + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to mutate the metadata of the input + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + if meta.mutates_data and meta.mutates_metadata: + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + assert meta.mutates_data + if meta.is_leaf and original_inpt.requires_grad: + # We can hit this situation in this case: + # def f(x): + # x.detach().mul_(2) + # return x + 1 + # AOTAutograd will see a mutation in the above case, and try to + # apply a copy_() here, in the epilogue. + # But if x required gradients, and is a leaf, then autograd + # will yell at us for trying to mutate it. + # However, it's only possible to end up in this scenario (like the above) + # if all of the mutations to the leaf input were non-autograd-tracking mutations + # (aka mutations under no_grad(), or on detached views). 
+ # In that case, we fully want to hide the mutation from autograd, so detaching is ok. + original_inpt.detach().copy_(updated_inpt) + else: + original_inpt.copy_(updated_inpt) + else: + fw_outs = all_outs + + # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of + # compiling them. + if runtime_metadata.num_outputs_aliased > 0: + # The compiled forward also returned intermediate bases. We don't want to return them to the user. + expect_num_outputs = ( + len(output_handlers) + runtime_metadata.num_intermediate_bases + ) + assert len(fw_outs) == expect_num_outputs + ret_outs = [ + handler(orig_inputs, fw_outs, out) + for out, handler in builtins.zip(fw_outs, output_handlers) + ] + else: + ret_outs = fw_outs + + if runtime_metadata.dynamic_outputs: + for t, o in zip(ret_outs, runtime_metadata.output_info): + if o.dynamic_dims is None: + continue + maybe_mark_dynamic_helper(t, o.dynamic_dims) + if runtime_metadata.grad_enabled_mutation is not None: + torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) + return ret_outs + + if not (trace_joint and _should_disable_saved_tensors_hooks()): + return runtime_wrapper + + # Disabling saved tensors hooks + def _runtime_wrapper(*args, **kwargs): + with _disable_saved_tensors_hooks(): + return runtime_wrapper(*args, **kwargs) + + return _runtime_wrapper + + +@dataclass +class FunctionalizedRngRuntimeWrapper(CompilerWrapper): + # TODO: I would love to get rid of this argument, but it's + # Wrapped pretty tightly around our aot_dispatch_autograd logic. + # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices + # for setting placeholder strides(which is done before runtime, before this wrapper runs) + # and for saving tensors for backward (which is done during runtime, after this wrapper runs) + # So in aot_dispatch_autograd, this wrapper can't edit the set of outs without making one + # of those two indices incorrect. 
+ return_new_outs: bool = True + + def pre_compile( + self, + flat_fn, + flat_args, + aot_config, + *, + fw_metadata, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + if config.functionalize_rng_ops: + # Update example inputs for the fw_compiler + fake_mode = detect_fake_mode() + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) + flat_args.extend([seed, offset]) + # We are not clearing flat_args here because + # 1) There is a check in the debug compiler at the end + # 2) It does not matter as these are fake tensors + return flat_fn, flat_args, fw_metadata + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + @wraps(compiled_fn) + def wrapper(runtime_args: list[Any]): + if runtime_metadata.is_rng_op_functionalized: + # Add the seed and offset to args + seed, offset = CUDARngStateHelper.get_torch_state_as_tuple() + runtime_args.extend([seed, offset]) + out = compiled_fn(runtime_args) + out = self._functionalized_rng_runtime_epilogue( + runtime_metadata, + out, + # TODO: this won't be right for the backward when we convert the call_compiled_backward to use the wrapper + runtime_metadata.num_forward_returns, + ) + return out + return compiled_fn(runtime_args) + + return wrapper + + # Calling convention: If we are running functionalized RNG, then outs consists + # of (user_outs, rng_offset) + def _functionalized_rng_runtime_epilogue( + self, + metadata: ViewAndMutationMeta, + outs, + offset_index, + ): + if metadata.is_rng_op_functionalized: + assert metadata.num_outputs_rng_offset == 1 + new_rng_offset = outs[offset_index] + CUDARngStateHelper.set_new_offset(new_rng_offset) + if self.return_new_outs: + user_outs = outs[:offset_index] + outs[offset_index + 1 :] + return user_outs + else: + return outs + + return outs + + +@dataclass +class FakifiedOutWrapper(CompilerWrapper): + out_metas: list[torch.Tensor] = field(default_factory=list) + # TracingContext.fwd_output_strides + # Generated from actually doing compile + # NB: an entry is None if it's not a Tensor + fwd_output_strides: Optional[list[Optional[list[int]]]] = None + needs_post_compile: bool = True + + def pre_compile( + self, + fw_module, # Must be fw_module from aot_dispatch_*_graph + flat_args, + aot_config, + *, + fw_metadata, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context and tracing_context.fakify_first_call: + self.out_metas = [ + n.meta["val"] for n in (list(fw_module.graph.nodes)[-1].args[0]) + ] + else: + self.needs_post_compile = False + return fw_module, flat_args, fw_metadata + + def _compute_output_meta_with_inductor_strides(self): + out = self.out_metas + fwd_output_strides = self.fwd_output_strides + if not fwd_output_strides: + return out + + from torch.fx.experimental.symbolic_shapes import statically_known_true + + for i in range(len(out)): + if not isinstance(out[i], Tensor): + continue + strides = fwd_output_strides[i] + # fwd_output_strides is best effort by Inductor. When an output + # Tensor has unbacked SymInts, Inductor may sometimes be unable + # to compute what the output stride would be. If Inductor doesn't + # have any clear direction on the layout, we don't have to run + # as_strided. 
To repro without this, run: + # + # python test/distributed/test_dynamo_distributed.py + # TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding + if strides is None: + continue + if all( + statically_known_true(s1 == s2) + for s1, s2 in zip(out[i].stride(), strides) + ): + continue + out[i] = out[i].as_strided(out[i].shape, strides) + return out + + # To be called post compile + def set_fwd_output_strides(self, fwd_output_strides): + self.fwd_output_strides = fwd_output_strides + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if self.needs_post_compile: + assert self.fwd_output_strides is not None + fakified_out = self._compute_output_meta_with_inductor_strides() + + @wraps(compiled_fn) + def wrapper(runtime_args): + nonlocal fakified_out + if fakified_out is not None: + out = fakified_out + fakified_out = None + return out + return compiled_fn(runtime_args) + + return wrapper + # If we don't need to fakify, we can just return the original compiled function + return compiled_fn + + +# This wrapper handles the AOTDispatch runtime logic for tensor subclasses. +# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor, +# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs). +# This function handles the wrapping and unwrapping of tensor subclasses at runtime. +@dataclass +class AOTDispatchSubclassWrapper(CompilerWrapper): + trace_joint: bool + fw_only: Optional[Callable] # Not cached, only used in pre_compile + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ): + (new_flat_fn, new_flat_args, subclass_meta) = aot_dispatch_subclass( + flat_fn, + flat_args, + is_joint_structure=self.trace_joint, + meta=fw_metadata, + fw_only=self.fw_only, # type: ignore[arg-type] + ) + self.maybe_subclass_meta = subclass_meta + return new_flat_fn, new_flat_args, fw_metadata + + def post_compile( + self, + compiled_fn, + _aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if self.maybe_subclass_meta is None: + return compiled_fn + + subclass_metas = runtime_metadata.subclass_fw_graph_out_meta + + @wraps(compiled_fn) + def inner_fn(args: list[Any]): + unwrapped_args = runtime_unwrap_tensor_subclasses( + args, + subclass_metas=runtime_metadata.subclass_inp_meta, + append_symints=True, + ) + args.clear() + # expectation: runtime_fn is a boxed fn + unwrapped_outs = compiled_fn(unwrapped_args) + wrapped_outs = wrap_tensor_subclasses( + unwrapped_outs, + subclass_metas=subclass_metas, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + is_runtime=True, + included_subclass_symints=True, + ) + return wrapped_outs + + # box it + inner_fn._boxed_call = True # type: ignore[attr-defined] + return inner_fn + + +@dataclass +class EffectTokensWrapper(CompilerWrapper): + def post_compile( + self, + compiled_fn, + _aot_config, + *, + runtime_metadata: ViewAndMutationMeta, + ): + num_tokens = len(runtime_metadata.tokens) + + @wraps(compiled_fn) + def inner_fn(args: list[Any]): + if num_tokens > 0: + # Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + old_args = args + args = [*([None] * num_tokens), *args] + old_args.clear() + + outs = compiled_fn(args) + + # Inductor cache DummyModule can return None + 
if outs is None: + return None + # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd]) + return outs[num_tokens:] if num_tokens != 0 else outs + + # box it + inner_fn._boxed_call = True # type: ignore[attr-defined] + return inner_fn + + +# MOTIVATION: +# +# When tracing functions for future execution, one must be careful not to pass +# in the same input tensor multiple times (e.g., f(x, x), as this can result +# in graphs that are ONLY valid if you later pass a new tensor in exactly the +# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct +# tensors that alias each other is a different situation that is covered by +# aot_dispatch_deduplicated_autograd). Here are two examples: +# +# (1) Suppose you have a function: +# +# def f(x, y): +# return x + y +# +# If you make_fx(f)(x, x), you will trace out: +# +# def f(x, y): +# return y + y +# +# Oops! +# +# (2) For most tensors x and y, you can compute f's gradient with respect to +# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However, +# if x is y, you will trace out a program that gets incorrect gradients: +# +# >>> x = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + x, (x, x)) +# (tensor([2.]), tensor([2.])) +# +# In other words, the gradient is double-counted. Deduplicating the arguments +# gives you an appropriate gradient: +# +# >>> y = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + y, (x, y)) +# (tensor([1.]), tensor([1.])) +# +# HOW TO DEDUPLICATE: +# +# There are a few strategies, in order of preference: +# +# 1. For every duplicate argument to the function, detach it into +# a separate leaf tensor, so that it is no longer duplicated. +# +# PRO: The resulting compiled graph works for any configuration +# of duplicated arguments. +# +# CON: It does not (naively) work if you mutate the metadata of inputs: +# +# def f(x, y): +# x.transpose_(0, 1) +# y.transpose_(0, 2) +# +# x = torch.randn(2, 3, 4) +# f(x, x) +# +# The ordering of the transposes inside f dictates whether or not +# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute +# what metadata mutations should get applied to each input; you need to +# assume they aren't duplicates (what we do today) or preserve +# the original metadata mutations exactly in order, so that they work +# for any duplicate configuration. +# +# CON: It does not (naively) work if you mutate the data of inputs. +# In particular, leaf tensors that require grad cannot be mutated, +# this makes it impossible to differentiate with respect to the original +# base. +# +# 2. For every duplicate argument to the function, remove it, so it is +# no longer part of the "true" signature: +# +# PRO: Implemented naively, it still works for metadata/data mutation. +# +# CON: The resulting compiled graph is duplicate-specialized: it only +# works if future calls duplicate arguments in exactly the same way. +# Horribly, Dynamo doesn't guard on this at the moment. But even if +# it did, you could still end up recompiling a bunch of each duplicate. +# +# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if +# Dynamo's guards are not enough. In practice, this seems to cover +# everything. 
+# +@dataclass +class AOTDedupeWrapper(CompilerWrapper): + keep_arg_mask: list[bool] = field(default_factory=list) + add_dupe_map: list[int] = field(default_factory=list) + old_input_metadata: list[InputAliasInfo] = field(default_factory=list) + needs_post_compile: bool = True + + # NB: Hot path, avoid set lookups here + # TODO: Can avoid the zip here too, probably + def remove_dupe_args(self, args): + return [t for t, keep in zip(args, self.keep_arg_mask) if keep] + + def add_dupe_args(self, args): + return [args[i] for i in self.add_dupe_map] + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + # Use information about whether or not flat_fn mutates its arguments + # or not to handle dupe args + + # Strategy 1: For any input that is not mutated, we can leafify it if we + # need to remove a duplicate. + leaf_flat_args = [] + args_set = set() + ok = True + + for i, a in enumerate(flat_args): + if not isinstance(a, torch.Tensor): + leaf_flat_args.append(a) + elif a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) + elif ( + not fw_metadata.input_info[i].mutates_data + and not fw_metadata.input_info[i].mutates_metadata + ): + leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) + else: + ok = False + break + + if ok: + self.needs_post_compile = False + return flat_fn, leaf_flat_args, fw_metadata + + if requires_subclass_dispatch(leaf_flat_args, fw_metadata): + raise RuntimeError( + """\ + Encountered duplicate inputs that are mutated in the graph, but at least one input/output + to the graph is a tensor subclass. This is not supported today. You can try to + remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + # export path: ban duplicate inputs for now, add later if requested. + if aot_config.is_export: + raise RuntimeError( + f"""\ + Encountered duplicated inputs that are mutated in the graph you are trying to export. + This functionality is currently not supported. If needed, please file a github issue. + + fw_metadata={str(fw_metadata)} + """ + ) + + # Strategy 2: Duplicate specialization + # + # When we have duplicate arguments in a function call, we need to handle them specially. + # For example, if we have a function call f(a, b, a, c), we need to: + # + # 1. Remove duplicates to get a deduplicated list [a, b, c] + # 2. Compile our function to work with this deduplicated list + # 3. At runtime, convert incoming arguments with duplicates to the deduplicated form + # 4. 
Pass the deduplicated arguments to our compiled function + # + # To do this, we need two helper functions: + # + # - remove_dupe_args: Converts [a, b, a, c] -> [a, b, c] + # - add_dupe_args: Converts [a, b, c] -> [a, b, a, c] + # + # For our example [a, b, a, c], we track: + # + # - seen_args = {a: 0, b: 1, c: 2} (maps each unique arg to its first position) + # - add_dupe_map = [0, 1, 0, 2] (tells us how to reconstruct the original list) + # - keep_arg_mask = [True, True, False, True] (tells us which args to keep when deduplicating) + + seen_args: dict[Tensor, int] = {} + # Implicitly map duped arg position (list index) to de-duped arg position + keep_arg_mask: list[bool] = [] + add_dupe_map: list[int] = [] + duped_arg_len = len(flat_args) + + j = 0 # index into deduped_flat_args + for t in flat_args: + if isinstance(t, torch.Tensor): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map.append(seen_args[t]) + continue + seen_args[t] = j + + keep_arg_mask.append(True) + add_dupe_map.append(j) + j += 1 + assert ( + len(add_dupe_map) == duped_arg_len + ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" + + self.keep_arg_mask = keep_arg_mask + self.add_dupe_map = add_dupe_map + + deduped_flat_args = self.remove_dupe_args(flat_args) + + # Update our input metadata to remove duped input metadata. + updated_fw_metadata = remove_dupe_metadata( + fw_metadata, keep_arg_mask, add_dupe_map + ) + + if ( + tracing_context := TracingContext.try_get() + and aot_config.aot_autograd_arg_pos_to_source + ): + # TODO(voz): This structure is 1:1, we could consider an alternate structure like + # kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there, + # which feels like needless complexity for a tiny bit of efficiency at this point. + for dupe_arg_pos, (kept_pos, keep_arg) in enumerate( + zip(add_dupe_map, keep_arg_mask) + ): + if not keep_arg: + dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[ + dupe_arg_pos + ] + kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[ + kept_pos + ] + tracing_context.guards_context.aotautograd_guards.append( # type: ignore[attr-defined] + DuplicateInputs(kept_arg_source, dupe_arg_source) + ) + + @wraps(flat_fn) + def wrapped_flat_fn(*args): + return flat_fn(*self.add_dupe_args(args)) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*deduped_flat_args) + assert ( + ref_fw_metadata == updated_fw_metadata + ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + + return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if not self.needs_post_compile: + return compiled_fn + + @wraps(compiled_fn) + def wrapped_compiled_fn(args: list[Any]): + deduped_args = self.remove_dupe_args(args) + args.clear() + return compiled_fn(deduped_args) + + wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + # This can be uncommented when we properly guard for duplicates, + # but right now we must not do it. 
+ # if not config.debug_assert: + # return wrapped_compiled_fn + + @wraps(wrapped_compiled_fn) + def debugged_compiled_fn(args): + # Test that the computed remove/add arg functions are an inverse + new_args = self.add_dupe_args(self.remove_dupe_args(args)) + seen: dict[Any, None] = {} + for i, (x, y) in enumerate(zip(new_args, args)): + seen[y] = None + assert x is y, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would be a duplicate of " + f"{describe_input(self.add_dupe_map[i], aot_config)}", + ) + # This is only an error if there is metadata mutation on both of + # the duped arguments; in this case, we need to know what order + # the metadata mutation applies in. You'll get the correct result + # otherwise, because a graph that assumes distinct inputs works if + # you dupe the inputs (the gradient contributions from each input + # will get summed up appropriately.) + # + # TODO: work out how to setup this assert correctly + """ + assert len(seen) == unique_args, format_guard_bug_msg(aot_config, + f"there would be {unique_args} distinct arguments" + ) + """ + return wrapped_compiled_fn(args) + + debugged_compiled_fn._boxed_call = True # type: ignore[attr-defined] + + return debugged_compiled_fn + + +# This layer handles the situation where you have two inputs that alias each other, +# and one of the inputs is mutated. +# We need to take special care to ensure that the mutation is applied to the other aliases in the graph. +# +# pre-condition: AOTDedupWrapper has already run. +# (This function will in theory work if there are duplicate args. +# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs +# would cause us to hit that path more frequently). +@dataclass +class AOTSyntheticBaseWrapper(CompilerWrapper): + # Currently, the only reason we need to plumb this bool is because + # the synthetic base code prohibits more cases in the autograd case than the inference case. + trace_joint: bool # TODO: refactor trace_joint + needs_post_compile: bool = True + aliased_arg_idx_with_metadata_mutations: list[int] = field(default_factory=list) + + def pre_compile( + self, + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + is_inference = not self.trace_joint + flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + aot_config, + flat_args, + fw_metadata.input_info, + is_inference=is_inference, + ) + + # Happy path: we don't need synthetic bases + if synthetic_base_info is None: + self.needs_post_compile = False + return flat_fn, flat_args, fw_metadata + + # export path: ban synthetic bases for now, add later if requested. + if requires_subclass_dispatch(flat_args, fw_metadata): + raise RuntimeError( + """\ + Encountered aliased inputs that are mutated in the graph, but at least one input/output + to the graph is a tensor subclass. This is not supported today. You can try to + remove the aliasing yourself as a workaround, or otherwise file an issue on github.""" + ) + + if aot_config.is_export: + raise RuntimeError( + f"""\ + Encountered aliased inputs that are mutated in the graph you are trying to export. + This functionality is currently not supported. If needed, please file a github issue. 
+ + synthetic_base_info={str(synthetic_base_info)} + + fw_metadata={str(fw_metadata)} + """ + ) + + assert len(fw_metadata.input_info) == len(synthetic_base_info) + + # Update our forward metadata to take synthetic bases into account + ( + fw_metadata_updated, + aliased_arg_idx_with_metadata_mutations, + ) = create_synthetic_base_metadata( + fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases + ) + # Save old input args for post-compile + self.old_input_info = fw_metadata.input_info + + self.aliased_arg_idx_with_metadata_mutations = ( + aliased_arg_idx_with_metadata_mutations + ) + replay_views = config.view_replay_for_aliased_outputs + + def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]: + f_args_inner = [] + for inner_idx_or_tuple in synthetic_base_info: + if isinstance(inner_idx_or_tuple, int): + f_args_inner.append(primals[inner_idx_or_tuple]) + else: + inner_base_idx, view_tensor = inner_idx_or_tuple + base = primals[inner_base_idx] + view_arg = gen_alias_from_base( + base, + view_tensor, + view_tensor.requires_grad, + replay_views=replay_views, + ) + f_args_inner.append(view_arg) + return f_args_inner + + @wraps(flat_fn) + def wrapped_flat_fn(*args): + unpacked_args = _unpack_synthetic_bases(args) + # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases) + # is to relieve the downstream logic from having to reason about mutations on inputs that alias + # each other, by replacing aliased inputs with a synthetic base. + # One area where this breaks down a bit however is if one of those aliased inputs + # experienced a metadata mutation. + # We are now obligated to reapply the metadata mutation directly to the user's input; + # it isn't enough to apply mutations back to the synthetic base in the downstream logic. + # + # The way we handle this is by pretending that those aliased inputs that experience metadata mutations + # are additional outputs in the user's forward function. + # The downstream logic will just treat these as "user outputs that alias inputs". + # However, we will manually grab them at runtime here, use them to reapply the metadata mutation + # to the user inputs, and not return them to the user. 
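
As a sketch of the runtime reapplication described above (made-up plain tensors; the wrapper itself does this with inp.as_strided_ in post_compile below), a metadata-only mutation can be copied from a returned tensor back onto the user's input like so:

import torch

inp = torch.randn(2, 3)
# Stand-in for the extra forward output that carries the mutated metadata
# (e.g. the input after a transpose-style metadata mutation inside the traced fn).
mutated = inp.t()

# Reapply the metadata mutation directly onto the user's input tensor.
inp.as_strided_(mutated.size(), mutated.stride(), mutated.storage_offset())
assert inp.shape == (3, 2)
assert torch.equal(inp, mutated)
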
+ aliased_args_with_metadata_mutations = [ + x + for i, x in enumerate(unpacked_args) + if i in self.aliased_arg_idx_with_metadata_mutations + ] + if len(aliased_args_with_metadata_mutations) > 0: + return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations + else: + return flat_fn(*unpacked_args) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=fw_metadata.keep_input_mutations, + is_train=fw_metadata.is_train, + )(*flat_args_with_synthetic_bases) + assert ref_fw_metadata == fw_metadata_updated, ( + f"ref_metadata={pprint.pformat(partial_flatten_asdict(ref_fw_metadata))}, " + f"\nactual_metadata={pprint.pformat(partial_flatten_asdict(fw_metadata_updated))}" + ) + return ( + wrapped_flat_fn, + flat_args_with_synthetic_bases, + fw_metadata_updated, + ) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + if not self.needs_post_compile: + return compiled_fn + + is_inference = not self.trace_joint + + @wraps(compiled_fn) + def wrapped_compiled_fn(args): + args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + aot_config, args, self.old_input_info, is_inference=is_inference + ) + assert synthetic_base_info is not None + aliased_args_w_metadata_mutations = [ + args[i] for i in self.aliased_arg_idx_with_metadata_mutations + ] + num_aliased_args_with_metadata_mutations = len( + aliased_args_w_metadata_mutations + ) + args.clear() + outs = compiled_fn(args_with_synthetic_bases) + if num_aliased_args_with_metadata_mutations > 0: + # This code does not handle **all** input metadata mutations. + # Instead, it only handles metadata mutations on inputs that were converted into synthetic bases + # (which only happens if at least one aliased input experienced a data mutation). + # e.g: + # def f(a, b): + # a.mul_(2) + # b.t_(1, 0) + # f(x.view(2, 2), x.view(2, 2)) + mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:] + user_outs = outs[:-num_aliased_args_with_metadata_mutations] + for inp, mutated_inp in zip( + aliased_args_w_metadata_mutations, mutated_metadata_inps + ): + inp.as_strided_( + mutated_inp.size(), + mutated_inp.stride(), + mutated_inp.storage_offset(), + ) + return user_outs + return outs + + return wrapped_compiled_fn + + +# Note [Handling mutations on an input that aliases other inputs] +# The easiest example to show-case this edge case is here: +# +# def f(a, b): +# a.mul_(2) +# out = a + b +# return out +# b = torch.ones(...) +# a = b.view(-1) +# f(a, b) +# +# In this situation, if a and b happened to be aliased, we need to trace something different! +# Suppose we had b = a.view(-1) +# (In this case, that means that `a._base is b`) +# +# We need to ensure that the aliasing relationship between a and b is preserved. +# We do that detecting the specific situation above (mutate an input that aliases another input), +# and when we do that, we create a synthetic base argument. Then inside of the traced forward, +# we regenerate a and b off of that base. +# The complete example of the transformed function looks like this: +# +# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views +# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph +# def traced_forward(base): +# a = base.as_strided(...) +# b = base.as_strided(...) 
+# a_updated = a.mul(2) +# base_updated = torch.as_strided_scatter(base, a_updated, ...) +# b_updated = base_updated.as_strided(...) +# out = a_updated + b_updated +# return a_updated, out +# +# def compiled_fn(a, b): +# // we detect that a is the "differentiable base" here +# base = a +# // In other situations, we might do either: +# // (1) a and b are both views off of some larger differentiable base +# // assert a._base is b._base and a._base is not None +# // base = a._base +# // (2) a and b both don't require gradients. Create a base from the storage +# // assert a._base is None and b._base is None +# // base = torch.Tensor(a.storage()) +# a_updated, out = traced_forward(base) +# a.copy_(a_updated) +# return out +# +# This function: +# (1) Merges input views into a synthetic base argument, when any of those input views are mutated +# (2) Returns metadata telling the autograd.Function how to modify their arguments properly, +# to respect the new calling convention. +# +# The calling convention is as follows. +# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base. +# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN], +# Where the ordering of the bases is determined from the ordering of the original view args. +# baseA will come before baseB if the earliest original argument coming from baseA +# showed up earlier in the argument list than the earliest original argument coming from baseB. +# +# Example, given some tensors a, b, c, d +# call site: +# f(a, c.view(-1), b.view(-1), b, c, d) +# Modified argument list: +# c_base comes first because the first c view came earlier in arg list than the first b view +# a and d still show up in the modified arg list, but b and c don't- they're regenerated from their bases +# b_base = torch.Tensor(b.storage()) +# c_base = torch.Tensor(c.storage()) +# f(c_base, b_base, a, d) +def merge_view_inputs( + aot_config: AOTConfig, + fwd_inputs: list[Any], + mutated_input_info: list[InputAliasInfo], + *, + # The autograd case currently has more restrictions than the inference case. + is_inference: bool, +) -> tuple[list[Any], Optional[list[Union[int, tuple[int, torch.Tensor]]]]]: + def _are_differentiable_views(view1, view2): + if view1 is view2: + return True + if view1._base is None and view2._base is None: + return False + if view1._base is view2._base or view1._base is view2 or view1 is view2._base: + return True + return False + + def _same_dtype_views(view1, view2): + if view1.dtype != view2.dtype: + return False + if view1._base is not None and view1.dtype != view1._base.dtype: + return False + if view2._base is not None and view2.dtype != view2._base.dtype: + return False + return True + + assert len(fwd_inputs) == len(mutated_input_info) + if not [info for info in mutated_input_info if info.mutates_data]: + # Return early when there are no mutations. + return fwd_inputs, None + + storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list) + base_args = [] + other_args = [] + for i, inpt in enumerate(fwd_inputs): + if isinstance(inpt, Tensor): + storage_ref = StorageWeakRef(inpt.untyped_storage()) + storage_ref_to_idx[storage_ref].append(i) + else: + other_args.append(inpt) + # Note [Synthetic Base Info Metadata] + # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. 
+ # It's either: + # - another int (corresponding to the index in the argument list of the element from the outer calling convention) + # - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx]) + # idx corresponds to which synthetic base from the outer calling context to view + inner_calling_convention_meta: dict[int, Union[int, tuple[int, torch.Tensor]]] = {} + for aliased_input_indices in storage_ref_to_idx.values(): + if len(aliased_input_indices) <= 1 or not any( + # We only care about mutations that affect all aliases, + # so metadata mutations on an input doesn't require us to do synthetic base handling. + mutated_input_info[inpt_idx].mutates_data + for inpt_idx in aliased_input_indices + ): + other_args.extend( + fwd_inputs[curr_idx] for curr_idx in aliased_input_indices + ) + continue + + # Here, we attempt to do a more complicated check to detect false aliasing + # (e.g. if all the tensors have the same storage, but don't actually overlap) + # In theory, we could have a large group of tensors that all share storages, where only *some* of them + # have overlapping memory. + # I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair + # of tensors in the current group that shares a storage is non-overlapping. + aliased_input_indices_no_false_sharing = compute_overlapping_inputs( + aot_config, fwd_inputs, aliased_input_indices + ) + if len(aliased_input_indices_no_false_sharing) <= 1: + other_args.extend( + fwd_inputs[curr_idx] for curr_idx in aliased_input_indices + ) + continue + + # We detected an input that was mutated, AND aliases with another input. + # we need to replace this set of aliased inputs with a single synthetic base. + # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases + # and error out. We can fix them later. + # These checks are transitive, so we don't need to check every pair. + for idx1, idx2 in zip( + aliased_input_indices, aliased_input_indices[1:], strict=False + ): + view1 = fwd_inputs[idx1] + view2 = fwd_inputs[idx2] + # The "inputs that are aliased but have different differentiable bases" case + # is more complicated and hopefully pretty rare. Not currently handled. + if not is_inference: + assert _are_differentiable_views( + view1, view2 + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + # Regenerating views when reinterpreting complex / real tensors seems non-trivial, + # not handling for now + assert _same_dtype_views( + view1, view2 + ), "aot_autograd() does not yet handle input mutations on views with different dtypes." + non_none_bases = [ + fwd_inputs[i]._base + for i in aliased_input_indices + if fwd_inputs[i]._base is not None + ] + aliases_with_none_bases = [ + fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None + ] + if len(non_none_bases) == 0: + # Case where none of the aliases have a ._base + # we generate a synthetic base without gradients, and generate views off of it + # We hit this case when we have input tensors to the graph that share a storage, + # but do not have a ._base field. + # Wondering when we hit this case? + # The _base field simply says that autograd knows about the aliasing relationship, + # but sometimes we create tensors which are aliased out of the same storage but guaranteed + # to be disjoint. 
In these cases, we will skip setting up the _base relationship + # for performance reasons (because the fact that the tensors share the same storage + # is unobservable unless you (1) do naughty things with resize_/as_strided + # or (2) look at the storage--as we are doing here.) + # One particular example of this is optimizer steps on the LSTM module: + # LSTM parameters are packed into a contiguous storage for efficiency reasons when + # calling cuDNN kernels, so when these parameters get passed to the optimizer we will + # find they share the same storage, but do not have _base set since they are all disjoint. + # + # NOTE: There is one case where this is unsafe: + # torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily + # the same shape as the "actual" base that the tensor came from. + # For the most part this is fine, because we always use as_strided() + # to generate the original aliased inputs again. + # If we were to use view-replay though, this could cause the aliased views + # to have incorrect sizes. + example_idx = aliased_input_indices[0] + example_alias = fwd_inputs[example_idx] + # Note that this function is re-used at both trace time and runtime. + # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor. + synthetic_base = torch.empty( + (0,), dtype=example_alias.dtype, device=example_alias.device + ) + # We don't actually have a convenient way of going from storage -> tensor, + # So using set_() here (we suffer some minor overhead, but this case is rare). + synthetic_base.set_(example_alias.untyped_storage()) + else: + # Case where all of the aliases require gradients, and have the same _base. + synthetic_base = non_none_bases[0] + for other_base in non_none_bases[1:]: + assert ( + other_base is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + for alias in aliases_with_none_bases: + assert ( + alias is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + base_args.append(synthetic_base) + for curr_view_idx in aliased_input_indices: + curr_view = fwd_inputs[curr_view_idx] + base_idx = len(base_args) - 1 + # We store just enough info here so that we can regenerate the view later. + # Regeneration: curr_view._view_func(args[base_idx]) + inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view) + if len(base_args) == 0: + assert len(other_args) == len(fwd_inputs) + # If no synthetic bases are necessary, just return the original inputs. + return fwd_inputs, None + else: + from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr + + def make_hashable(arg): + if isinstance(arg, torch.SymInt): + # Since only nested SymInt objects can be hashed, we wrap them with + # SymIntEqByExpr, which is a hashable wrapper of SymInts. + return SymIntEqByExpr(arg) + return arg + + # Otherwise, return: + # (1) The new args according to the updated calling convention: (synthetic_bases, other_args) + # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. + # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. + args_to_functionalization = base_args + other_args + + # Map each argument into its old index. + # There may be some repeated arguments, so we collect their indices in a list. 
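
For reference, a small sketch of the storage -> synthetic-base trick used above (plain CPU tensors with no autograd; the aliases here are made up), including regenerating the original views from the base with as_strided:

import torch

buf = torch.randn(8)                                # holder for the shared storage
a = buf.as_strided((4,), (1,), 0)                   # two disjoint aliases of the same storage
b = buf.as_strided((4,), (1,), 4)

synthetic_base = torch.empty((0,), dtype=a.dtype, device=a.device)
synthetic_base.set_(a.untyped_storage())            # storage -> 1-D tensor, as in merge_view_inputs

# Regenerate the original views off the synthetic base.
a2 = synthetic_base.as_strided(a.size(), a.stride(), a.storage_offset())
b2 = synthetic_base.as_strided(b.size(), b.stride(), b.storage_offset())

a2.mul_(2)                                          # a data mutation through a regenerated view...
assert torch.equal(a2, a) and torch.equal(b2, b)    # ...is visible through the original aliases
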
+ arg_to_old_idx_map = collections.defaultdict(list) + for i, arg in enumerate(fwd_inputs): + arg_to_old_idx_map[make_hashable(arg)].append(i) + # Reverse the list of each argument, so that we can easily pop them one-after-the-other in order. + for hashable_arg in arg_to_old_idx_map: + arg_to_old_idx_map[hashable_arg] = list( + reversed(arg_to_old_idx_map[hashable_arg]) + ) + + for i, other_arg in enumerate(other_args): + new_idx = len(base_args) + i + old_idx = arg_to_old_idx_map[make_hashable(other_arg)].pop() + inner_calling_convention_meta[old_idx] = new_idx + + # post process into a list + post_processed_calling_convention_meta: list[ + Union[int, tuple[int, torch.Tensor]] + ] = [-1 for _ in range(len(inner_calling_convention_meta))] + for k, v in inner_calling_convention_meta.items(): + post_processed_calling_convention_meta[k] = v + # Quick assert: every argument in the inner calling convention should be accounted for. + for x in post_processed_calling_convention_meta: + assert x != -1 + return args_to_functionalization, post_processed_calling_convention_meta + + +# Note: [Backward graph lazy lowering] +# After AOTDispatch traces the backward for graphs requiring autograd, we will lower the graph lazily, +# unless we suspect that inductor might specialize and insert additional guards. When we do lazy +# lowering, we stash the AOT backward graph (bw_module) in this class. +# +# Lowering passes are performed on a deepcopy of this bw_module due to compatbility +# with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645. +@dataclass +class AutogradLazyBackwardCompileInfo: + bw_module: Callable + placeholder_list: list[Any] + saved_context: Optional[TracingContext] + saved_compile_context: Optional[CompileContext] + + +# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually +# no need to keep information around for a new lazy compilation. Except for compiled autograd, +# which wants to retrace this backward into a larger graph, and it needs the graph module to do so. +@dataclass +class CachedAutogradLazyBackwardCompileInfo: + bw_module_fn: Callable + + +def _raise_if_functorch_active(): + # not ideal but prevent the user from seeing a nasty traceback - See #138422 + stack = torch._C._functorch.peek_interpreter_stack() + torch._check( + stack is None, + lambda: ( + "It looks like you're trying to call a compiled backward function within vmap/grad/vjp, " + "which isn't supported. Try wrapping vmap inside torch.compile, or skip compiling the " + "backward function." + ), + ) + + +# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. +def _backward_prologue_functional( + ctx_saved_tensors, ctx_symints, metadata, maybe_subclass_metadata, *flat_args +): + # Calling convention: we expect a grad_out passed to the backward: + # - for every output of the fw that does *not* alias an input or graph intermediate + # - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations) + # - for every graph intermediate that we need to use to generate an output later. + # The other outputs in the autograd.Function.forward that do *not* show up in the backward include: + # - outputs that alias inputs or graph intermediates + # - updated inputs due to metadata-only mutations. 
+ # We need to return them in the forward, but ensure that they all do not get gradients in the backward, + # and we filter them out here before passing the remaining grad_outputs into the compiled backward. + _raise_if_functorch_active() + + num_intermediate_bases = metadata.num_intermediate_bases + num_mutated_runtime_inps = metadata.num_mutated_inp_runtime_indices + expected_grad_outs = ( + metadata.num_outputs + num_mutated_runtime_inps + num_intermediate_bases + ) + deterministic = metadata.deterministic + global_deterministic = torch.are_deterministic_algorithms_enabled() + if deterministic is not None: + torch._check( + not (not deterministic and global_deterministic), + lambda: ( + "This compiled backward function is being run with " + "torch.use_deterministic_algorithms(True), " + "but it was previously generated during the forward function while " + "torch.use_deterministic_algorithms(False) was set." + ), + ) + + assert len(flat_args) == expected_grad_outs + out_info = metadata.output_info + + inp_tangents, out_tangents, intermediate_base_tangents = ( + flat_args[:num_mutated_runtime_inps], + flat_args[ + num_mutated_runtime_inps : num_mutated_runtime_inps + metadata.num_outputs + ], + flat_args[num_mutated_runtime_inps + metadata.num_outputs :], + ) + # input_info contains info on *every* input, + # But in the backward(), we are only given grad outputs for every mutated input + # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad + input_info = metadata.input_info + inp_tangents_filtered = [ + x + for x, info_idx in zip( + inp_tangents, + metadata.mutated_inp_runtime_indices, + ) + if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad + ] + # We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates + out_tangents_filtered = [ + x + for x, info in zip(out_tangents, out_info) + if info.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + and issubclass(info.raw_type, torch.Tensor) + and info.requires_grad + ] + # intermediate bases always require gradients, and always participate in the backward graph. + flat_bw_args_with_grads = [ + *inp_tangents_filtered, + *out_tangents_filtered, + *intermediate_base_tangents, + ] + num_flat_bw_args_with_grads = len(flat_bw_args_with_grads) + + # sanity asserts + # metadata_only_inps = [ + # x for x, info_idx in zip(inp_tangents, mutated_inp_indices) + # if not input_info[info_idx].mutates_data + # ] + # aliased_outputs = [ + # x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias] + # assert all(x is None for x in metadata_only_inps) + # assert all(x is None for x in aliased_outputs) + # TODO: replace this with FunctionalizedRngRuntimeWrapper + rng_args = [] + if metadata.is_rng_op_functionalized: + # Add the seed and offset to args + rng_args = CUDARngStateHelper.get_torch_state_as_tuple() + + bw_tokens = [None] * metadata.num_backward_tokens + + # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first + # in the bw output order. + + # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls + # There are tests that count these calls, saving to var. 
+ num_ctx_saved_tensors = len(ctx_saved_tensors) + all_args = [ + *ctx_symints, + *ctx_saved_tensors, + *flat_bw_args_with_grads, + *bw_tokens, + *rng_args, + ] + del ctx_saved_tensors + + # Note: [AOTAutograd Backward Guards] + # During AOTDispatch, we eagerly create and trace out a joint fw-bw graph. + # Doing so requires us to "guess" about some of the metadata of our grad_outputs. + # + # In particular: if an output to the forward is a plain tensor or a subclass, + # its corresponding grad_output in the backward **may or may not** be + # a plain tensor or a subclass. The main cases are: + # (1) If an output is a plain tensor, its grad_out will also be a plain tensor, + # *unless* the output is used in some subclass compute later in the forward graph, + # which will cause its grad_output to become a subclass + # (2) If an output is a subclass, its grad_out will also be a subclass, + # *unless* the output of the forward did not actually participate in the gradient computation, + # in which case autograd will insert a plain tensor of zeros for the grad_output. + # We could avoid this case with `torch.autograd.Function.set_materialize_grads`, + # although this is not turned on today in AOTAutgrad and would require more work. + # + # Today, we make a guess on subclass-ness based on the above examples, + # and hard-error in the backward if we guessed wrong. + # + # In the future, we should add backward guards that would allow us to + # properly handle this case instead of erroring: we would need to retrace the backward graph, + # since we might produce an entirely different trace if our grad_outputs are subclass or not. + del flat_bw_args_with_grads + + tangents_start_idx = ( + len(all_args) - num_flat_bw_args_with_grads - len(rng_args) - len(bw_tokens) + ) + assert tangents_start_idx == len(ctx_symints) + num_ctx_saved_tensors + tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens) + + # TODO: figure out how to refactor the backward properly + # so I can use aot_dispatch_subclass_wrapper() here. + if maybe_subclass_metadata is not None: + tangents = all_args[tangents_start_idx:tangents_end_idx] + + if len(tangents) != len(metadata.subclass_tangent_meta): + raise RuntimeError( + "The grad inputs should be same number as forward output tangents" + ) + + flat_processed_tangents = list( + itertools.chain.from_iterable( + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + m, + )[1] + ) + for t, m in zip( + tangents, + metadata.subclass_tangent_meta, + ) + ) + ) + + all_args = ( + runtime_unwrap_tensor_subclasses( + all_args[:tangents_start_idx], + # SymInts that are inputs to the backward graph are + # already included in the "all_args" list. + # Any symints coming from tensor subclasses should always + # come from primals, and so they will show up as extra + # arguments to the forward graph, and they will be saved + # as activation in the backward graph. + append_symints=False, + ) + + flat_processed_tangents + + runtime_unwrap_tensor_subclasses( + all_args[tangents_end_idx:], + append_symints=False, + ) + ) + else: + all_args = [ + ( + AOTDispatchAutograd.process_runtime_tangent( + t, + metadata.subclass_tangent_meta[i - tangents_start_idx], + )[0] + if (tangents_start_idx <= i < tangents_end_idx) + else t + ) + for i, t in enumerate(all_args) + ] + + # Backward with forward inputs mutations is not supported in double backward. 
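
The tangents_start_idx / tangents_end_idx arithmetic above can be sanity-checked on a toy argument list (every count below is made up purely for illustration):

symints, saved_tensors = ["s0"], ["t0", "t1"]
tangents, bw_tokens, rng_args = ["g0", "g1", "g2"], [None], ["seed", "offset"]
all_args = [*symints, *saved_tensors, *tangents, *bw_tokens, *rng_args]

tangents_start = len(all_args) - len(tangents) - len(rng_args) - len(bw_tokens)
tangents_end = len(all_args) - len(rng_args) - len(bw_tokens)
assert tangents_start == len(symints) + len(saved_tensors) == 3
assert all_args[tangents_start:tangents_end] == tangents
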
+ if ( + torch.is_grad_enabled() + and metadata.indices_of_inputs_that_requires_grad_with_mutations_in_bw + ): + raise RuntimeError( + "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True" + ) + + return all_args + + +def initialize_rng_states( + num_rng: int, + graphsafe_idx: int, + fwd_rng_states: list[torch.Generator], + bwd_rng_states: list[torch.Generator], +): + """ + Initialize the cudagraph safe rng states. + + Initialization of rng states should have a few properties: + - the initialization for each rng state should be independent + - the initialization should be deterministic + - the initialization should be based off current rng state, so that independent graphs do not + have equal rng behavior + + We defer initialization of rng states until runtime because compilation is wrapped + with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations + do not give equal randomness. + """ + with torch.utils._python_dispatch._disable_current_modes(): + seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu") + fwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + bwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + + +# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. +def _backward_epilogue_functional( + metadata, maybe_subclass_metadata, out, *, make_subclass_override=None +): + # Toss out the backward output tokens + num_bw_tokens = metadata.num_backward_tokens + if num_bw_tokens > 0: + out = out[:-num_bw_tokens] + + # TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile + out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue( + metadata, out, offset_index=len(out) - 1 + ) + out = tuple(out) + + # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. + if maybe_subclass_metadata is not None: + assert maybe_subclass_metadata.grad_input_metas is not None + outs_wrapped = wrap_tensor_subclasses( + out, + subclass_metas=maybe_subclass_metadata.grad_input_metas, + included_subclass_symints=True, + is_runtime=True, + make_subclass_override=make_subclass_override, + ) + return outs_wrapped + return out + + +def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta): + if memory_format.memory_format is not None: + # Coerce to torch.memory_format + if not x.is_contiguous(memory_format=memory_format.memory_format): + x = x.contiguous(memory_format=memory_format.memory_format) + return x + + expected_size = memory_format.size + assert expected_size is not None + expected_stride = memory_format.stride + assert expected_stride is not None + # Expected size and stride are static ints + # ok to use == to compare runtime tensor strides and shapes + + if x.shape == expected_size and x.stride() == expected_stride: + # Runtime tangent size and stride are the same as expected, no need to coerce + return x + + # Empty_strided creates a raw Tensor. + # We are guranteed that only raw Tensors has expected size and stride. + # Subclasses have only expected memory_format. 
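
In isolation, the restriding step described above (and implemented just below) amounts to the following sketch, with made-up sizes and strides:

import torch

x = torch.randn(4, 3).t()                        # runtime tangent arriving with strides (1, 3)
expected_size, expected_stride = (3, 4), (4, 1)  # layout recorded from the traced output

if x.shape != expected_size or x.stride() != expected_stride:
    fixed = torch.empty_strided(
        expected_size, expected_stride, dtype=x.dtype, device=x.device
    )
    fixed.copy_(x)   # same values, expected layout
    x = fixed

assert x.stride() == expected_stride
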
+ restrided = torch.empty_strided( + size=expected_size, + stride=expected_stride, + dtype=x.dtype, + device=x.device, + layout=x.layout, + requires_grad=x.requires_grad, + ) + restrided.copy_(x) + return restrided + + +@contextlib.contextmanager +def _disable_saved_tensors_hooks(): + error_message = ( + "Saved tensors hooks were specialized as GraphModules." + "In this case aot_autograd inlines them in forward and backward graph " + "and disables them during runtime of aot_autograd compiled region." + "If you see this error, that means that there is some unexpected push or pop manipulation " + "during aot_autograd compiled region runtime." + "Compilation with different hooks must result in recompilation." + ) + fail_if_non_empty = False + maybe_prev_message = None + try: + maybe_prev_message = ( + torch._C._autograd._saved_tensors_hooks_get_disabled_error_message() + ) + torch._C._autograd._saved_tensors_hooks_disable( + error_message, fail_if_non_empty + ) + yield + finally: + if maybe_prev_message is None: + torch._C._autograd._saved_tensors_hooks_enable() + else: + torch._C._autograd._saved_tensors_hooks_disable( + maybe_prev_message, fail_if_non_empty + ) + + +# This is wrapped in a class just for namespacing purposes +# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly +class AOTDispatchAutograd: + @staticmethod + def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]): + if not isinstance(x, torch.Tensor): + return x, [x] + + if isinstance(x, FakeTensor): + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + return x, [x] + + expected_type: Optional[type] = torch.Tensor + expected_meta = None + if isinstance(meta, SubclassCreationMeta): + expected_type = meta.original_subclass_type + expected_meta = meta.meta + + runtime_type = type(x) + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. + if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + + runtime_meta = None + runtime_subclass_keys: Sequence[str] = [] + + if is_traceable_wrapper_subclass(x): + runtime_subclass_keys, runtime_meta = x.__tensor_flatten__() + + def maybe_coerce(x): + same_type: bool = expected_type == runtime_type + same_meta: bool = expected_meta == runtime_meta + + if same_type and same_meta: + return x + + if not hasattr(x, "__coerce_same_metadata_as_tangent__"): + return None + + if same_type: + # Backward Compatibility, as some Subclass impls can have original 1-arg function. + return x.__coerce_same_metadata_as_tangent__(expected_meta) + + return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type) + + # Coerce to expected type and metadata + orig_x = x + x = maybe_coerce(x) + if x is None: + raise RuntimeError( + f""" +During the backward, we encountered a tensor subclass where we guessed its +metadata incorrectly. + +Expected metadata: {str(expected_meta)}, expected type: {str(expected_type)} + +Runtime metadata: {str(runtime_meta)}, runtime type: {str(runtime_type)} + +shape: {str(orig_x.shape)} +To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__. 
+""" + ) + + # Coerce to expected memory format + assert meta.memory_format + x = coerce_to_expected_memory_format(x, meta.memory_format) + + if not is_traceable_wrapper_subclass(x): + return x, [x] + + assert isinstance(meta, SubclassCreationMeta) + if orig_x is not x: + runtime_subclass_keys = x.__tensor_flatten__()[0] + + assert len(meta.attrs) == len(runtime_subclass_keys) + leaves = [] + for i, (attr, attr_meta) in enumerate(meta.attrs.items()): + elem = getattr(x, attr) + new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( + elem, attr_meta + ) + if new_elem is not elem: + setattr(x, attr, new_elem) + leaves.extend(elem_leaves) + + return x, leaves + + @staticmethod + def post_compile( + compiled_fw_func, # fw_module after compilation + wrappers + compiled_bw_func, # bw_module after compilation + wrappers + maybe_subclass_meta: Optional[SubclassMeta], + num_symints_saved_for_bw_: int, + backward_state_indices: list[int], + disable_amp: bool, + indices_of_inps_to_detach: list[int], + lazy_backward_info: Optional[ + Union[ + AutogradLazyBackwardCompileInfo, + CachedAutogradLazyBackwardCompileInfo, + ] + ], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, # runtime metadata + try_save_cache_entry: Optional[Callable], # Save cache entry after compilation + ): + # For additional context see Note [CUDA Graph Safe RNG Functionalization] + # Each pair forward, backward rng states must be equal prior to its invocation on any + # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, + # running forward then backward advances them the same amount and keeps them equal. + # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. + # Initially we have: + # fwd_state0 == bwd_state0. + # Lets say we run: + # fwd0: fwd_state0 -> fwd_state1 + # fwd1: fwd_state1 -> fwd_state2 + # fwd2: fwd_state2 -> fwd_state3 + # If we now invoke bwd2, + # we need to update bwd_state equal to the rng that was observed in fwd2. + # we save the rng_state fwd_state2 in forward because we detect that it is not the + # current backward state and therefore would not be accessible if we do not save it. + # Similarly, if we are going to update the backward state to a new value, and there is a pending + # forwards which needs its current state, we will save it. + # Within the autograd context, we keep track of the curr iteration so that on backward + # we know what the generator state must be before the backward is run. 
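
A CPU-generator sketch of the invariant and the save/restore described in the note above (the real code uses torch.cuda.default_generators[...].clone_state(); plain torch.Generator and a fixed seed are used here only for illustration):

import torch

seed = 1234
fwd_gen = torch.Generator().manual_seed(seed)
bwd_gen = torch.Generator().manual_seed(seed)

# Invariant: the pair starts equal, and consuming the same amount of randomness
# in a forward and in its matching backward keeps the two generators equal.
fwd_noise = torch.rand(4, generator=fwd_gen)     # "forward" of iteration 0
bwd_noise = torch.rand(4, generator=bwd_gen)     # matching "backward" of iteration 0
assert torch.equal(fwd_noise, bwd_noise)
assert torch.equal(fwd_gen.get_state(), bwd_gen.get_state())

# If forwards run ahead of their backwards, snapshot the forward state per iteration...
saved_state_iter1 = fwd_gen.get_state()
torch.rand(4, generator=fwd_gen)                 # forward of iteration 1 runs; its backward is still pending
# ...and restore the snapshot into the backward generator right before the matching backward.
bwd_gen.set_state(saved_state_iter1)
assert torch.equal(bwd_gen.get_state(), saved_state_iter1)
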
+ num_rng = fw_metadata.num_graphsafe_rng_states + graphsafe_idx = fw_metadata.graphsafe_rng_state_index + fwd_rng_states: list[torch.Generator] = [] + bwd_rng_states: list[torch.Generator] = [] + curr_fwd_iter = itertools.count(0) + backward_state_position = 0 + pending_forwards: set[int] = set() + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} + + class CompiledFunction(torch.autograd.Function): + compiled_fw = compiled_fw_func + compiled_bw = compiled_bw_func + metadata: ViewAndMutationMeta = fw_metadata # type: ignore[assignment] + maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta + num_symints_saved_for_bw = num_symints_saved_for_bw_ + _aot_id = aot_config.aot_id + _lazy_backward_info = lazy_backward_info + + @staticmethod + def _compiled_autograd_key(ctx): + return (ctx._autograd_function_id, *ctx.symints) + + @staticmethod + def forward(ctx, *deduped_flat_tensor_args): + args = deduped_flat_tensor_args + if backward_state_indices: + bw_state = args[backward_state_indices[0]] + assert isinstance(bw_state, BackwardState) + ctx._compiled_autograd_backward_state = bw_state + + if num_rng: + if len(fwd_rng_states) == 0: + assert graphsafe_idx is not None + initialize_rng_states( + num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states + ) + + _curr_iter = next(curr_fwd_iter) + ctx._curr_iter = _curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if _curr_iter != backward_state_position: + saved_backward_tensor_states[_curr_iter] = [ + rng_state.get_state() for rng_state in fwd_rng_states + ] + + pending_forwards.add(_curr_iter) + args = (*args, *fwd_rng_states) + + # There is a pretty complicated calling convention around what the compiled fw returns. + # The full list of outputs and their relative order is: + # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) + # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version + # of the original view, and not the synthetic base + # - Note that donated buffer logic requires (*saved_tensors, *saved_symints) showing up last + # in the fw output order. 
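
Read together with the slicing done below, that ordering can be pictured on a toy output list (a sketch with made-up counts and no tokens; the real slices come from the metadata fields):

num_mutated_inputs, num_user_outputs, num_intermediate_bases = 1, 2, 1
num_forward_returns = num_mutated_inputs + num_user_outputs + num_intermediate_bases

fw_outs = ["mut_inp0", "out0", "out1", "base0", "saved_t0", "saved_t1", "saved_sym0"]
raw_returns = fw_outs[:num_forward_returns]       # returned from autograd.Function.forward
tensors_saved = fw_outs[num_forward_returns:-1]   # stashed via ctx.save_for_backward
symints_saved = fw_outs[-1:]                      # stashed on ctx.symints
assert raw_returns == ["mut_inp0", "out0", "out1", "base0"]
assert tensors_saved == ["saved_t0", "saved_t1"] and symints_saved == ["saved_sym0"]
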
+ fw_outs = call_func_at_runtime_with_args( + CompiledFunction.compiled_fw, + args, + disable_amp=disable_amp, + ) + + num_outputs = CompiledFunction.metadata.num_outputs + num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased + num_mutated_runtime_inps = ( + CompiledFunction.metadata.num_mutated_inp_runtime_indices + ) + num_forward_returns = CompiledFunction.metadata.num_forward_returns + + # Partitioners must put symint arguments at the end separate from tensor arguments + tensors_saved_for_backwards = fw_outs[ + CompiledFunction.metadata.tensors_saved_for_backwards_slice + ] + assert all( + isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards + ) + + def mark_dynamic_activations(activations: list[torch.Tensor]): + for ( + idx, + dims, + ) in CompiledFunction.metadata.dynamic_saved_tensors_idxs.items(): + maybe_mark_dynamic_helper(activations[idx], dims) + return activations + + # See Note [Detaching saved tensors in AOTAutograd] + ctx.save_for_backward( + *mark_dynamic_activations( + [ + x.detach() if x._is_view() else x + for x in tensors_saved_for_backwards + ] + ) + ) + symint_outs = fw_outs[ + CompiledFunction.metadata.symints_saved_for_backwards_slice + ] + assert all( + isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) + for x in symint_outs + ), str([type(x) for x in symint_outs]) + ctx.symints = symint_outs + + raw_returns = fw_outs[0:num_forward_returns] + + # Wrap all autograd.Function.forward() outputs that are aliases + # so that autograd.Function doesn't treat them as tensors + if num_mutated_runtime_inps > 0: + for i, idx in enumerate( + CompiledFunction.metadata.mutated_inp_runtime_indices + ): + # We could make this faster by only looping over inputs with metadata-only mutations + # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. + info = CompiledFunction.metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + raw_return_idx = i + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + user_mutated_inputs_raw = raw_returns[ + 0:num_mutated_runtime_inps + ] + mut_inp_infos = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + assert len(user_mutated_inputs_raw) == len(mut_inp_infos) + + if CompiledFunction.metadata.num_unsafe_view_outputs > 0: + for idx in CompiledFunction.metadata.unsafe_view_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + o = raw_returns[raw_return_idx] + raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view( + o, o.shape + ) + + if num_outputs_aliased > 0: + for idx in CompiledFunction.metadata.aliased_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + raw_returns[raw_return_idx] = TensorAlias( + raw_returns[raw_return_idx] + ) + + if config.debug_assert: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + assert not any( + isinstance(x, TensorAlias) for x in intermediates_raw + ) + + # invariant: intermediate bases always require gradients, so we don't have to + # consider marking them as non-differentiable. 
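
The marking done just below uses the standard torch.autograd.Function facility for non-differentiable outputs; here is a generic standalone illustration (not the AOT code path, and the example function is made up):

import torch

class SortWithIndices(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        vals, idx = x.sort()
        ctx.mark_non_differentiable(idx)   # integer indices never receive a gradient
        ctx.save_for_backward(idx)
        return vals, idx

    @staticmethod
    def backward(ctx, grad_vals, _grad_idx):
        (idx,) = ctx.saved_tensors
        grad_x = torch.zeros_like(grad_vals)
        grad_x.index_put_((idx,), grad_vals)  # route sorted grads back to original positions
        return grad_x

x = torch.randn(5, requires_grad=True)
vals, idx = SortWithIndices.apply(x)
assert not idx.requires_grad
vals.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))
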
+ raw_returns_not_including_intermediate_bases = raw_returns[ + : num_mutated_runtime_inps + num_outputs + ] + raw_returns_meta = [ + x + for x in CompiledFunction.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + CompiledFunction.metadata.output_info + + fw_outs_not_requiring_grad = [ + x + for (i, x) in enumerate( + raw_returns_not_including_intermediate_bases + ) + if isinstance(x, torch.Tensor) + and not raw_returns_meta[i].requires_grad + ] + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + return tuple(raw_returns) + + @staticmethod + def backward(ctx, *flat_args): + all_args = _backward_prologue_functional( + ctx.saved_tensors, + ctx.symints, + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + *flat_args, + ) + + if num_rng: + nonlocal backward_state_position, bwd_rng_states + curr_backward_iter = ctx._curr_iter + retain_graph = ( + torch._C._autograd._get_current_graph_task_keep_graph() + ) + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + backward_state_position in pending_forwards + and backward_state_position not in saved_backward_tensor_states + and ( + backward_state_position != curr_backward_iter + or retain_graph + ) + ): + saved_backward_tensor_states[backward_state_position] = [ + rng_state.get_state() for rng_state in bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in saved_backward_tensor_states: + if backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + bwd_rng_states, + saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not retain_graph: + del saved_backward_tensor_states[curr_backward_iter] + else: + assert backward_state_position == curr_backward_iter + + backward_state_position = curr_backward_iter + 1 + if not retain_graph: + pending_forwards.remove(curr_backward_iter) + all_args.extend(bwd_rng_states) + + def impl_fn(double_ctx=None): + out = CompiledFunction._backward_impl(ctx, all_args) + return _backward_epilogue_functional( + CompiledFunction.metadata, + CompiledFunction.maybe_subclass_metadata, + out, + ) + + needs_grad = torch.is_grad_enabled() and any( + t.requires_grad for t in all_args if isinstance(t, torch.Tensor) + ) + if needs_grad: + # double backward + return CompiledFunction._double_backward(ctx, impl_fn, all_args) + else: + return impl_fn() + + @staticmethod + def _double_backward(ctx, impl_fn, all_args): + # Ensure that the graph is connected, and error if double backward is performed. 
+ # See comment for why once_differentiable is not sufficient: + # https://github.com/pytorch/pytorch/pull/92348/files#r1072962107 + class CompiledFunctionBackward(torch.autograd.Function): + # CompiledFunctionBackward is not yet supported in dynamo skipfiles + _aot_id = aot_config.aot_id + + @staticmethod + def forward(double_ctx, *unused_args): + return impl_fn(double_ctx) + + @staticmethod + def backward(double_ctx, *args): + raise RuntimeError( + "torch.compile with aot_autograd does not currently support double backward" + ) + + CompiledFunctionBackward._compiled_autograd_key = ( # type: ignore[method-assign] + CompiledFunction._compiled_autograd_key + ) + + return CompiledFunctionBackward.apply(*all_args) + + @staticmethod + def _backward_impl(ctx, all_args): + # compiled autograd reimplements this function at proxy_call_aot_backward + assert ( + not backward_state_indices + ), "BackwardState requires CompiledAutograd" + ctx.maybe_clear_saved_tensors() + + saved_tensors_use_once = ( + not torch._C._autograd._get_current_graph_task_keep_graph() + ) + + if CompiledFunction.compiled_bw is None: + assert lazy_backward_info is not None + assert isinstance( + lazy_backward_info, AutogradLazyBackwardCompileInfo + ) + + if not saved_tensors_use_once: + fw_metadata.bw_donated_idxs = [] + # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` + if ( + hasattr(lazy_backward_info, "saved_context") + and hasattr(lazy_backward_info.saved_context, "fw_metadata") + and hasattr( + lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr] + "bw_donated_idxs", + ) + ): + lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr] + [] + ) + + bw_module = lazy_backward_info.bw_module + placeholder_list = lazy_backward_info.placeholder_list + saved_context = lazy_backward_info.saved_context + saved_compile_context = lazy_backward_info.saved_compile_context + + context = torch._C._DisableAutocast if disable_amp else nullcontext + metrics_context = get_metrics_context() + with ( + tracing(saved_context), + compile_context(saved_compile_context), + context(), + track_graph_compiling(aot_config, "backward"), + metrics_context, + dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + log_pt2_compile_event=True, + dynamo_compile_column_us="backward_cumulative_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="entire_backward_compile", + ), + callback_handler.install_callbacks( + CallbackTrigger.LAZY_BACKWARD, + str(CompileContext.current_compile_id()), + ), + ): + CompileEventLogger.compilation_metric(is_forward=False) + # See Note: [Backward graph lazy lowering] + CompiledFunction.compiled_bw = aot_config.bw_compiler( + copy.deepcopy(bw_module), placeholder_list + ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + bw_module, + fw_metadata, + aot_config, + ) + + if ( + torch._functorch.config.donated_buffer + and not saved_tensors_use_once + and fw_metadata.bw_donated_idxs != [] + ): + torch._check( + False, + lambda: ( + "This backward function was compiled with non-empty donated " + "buffers which requires create_graph=False and retain_graph=False. " + "Please keep backward(create_graph=False, retain_graph=False) " + "across all backward() function calls, or set " + "torch._functorch.config.donated_buffer=False to disable " + "donated buffer." 
+ ), + ) + + out = call_func_at_runtime_with_args( + CompiledFunction.compiled_bw, + all_args, + steal_args=True, + disable_amp=disable_amp, + ) + return out + + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=indices_of_inps_to_detach, + trace_joint=True, + disable_amp=disable_amp, + ).post_compile( + CompiledFunction.apply, + aot_config, + runtime_metadata=fw_metadata, + ) + + return compiled_function + + +@dataclass +class DebugAssertWrapper(CompilerWrapper): + flat_requires_grad: list[Optional[bool]] = field(default_factory=list) + + def post_compile( + self, + compiled_fn, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, + ): + @wraps(compiled_fn) + def debug_compiled_function(args: list[Any]): + # TODO: Check aliasing relationships + # TODO: Check strides for metadata mutation + # (NB: ideally, this logic is factored out of this function and + # you move these debug checks there) + + # Check requires grad. Bad case is when we compiled with + # requires_grad = False, but input requires_grad = True + # (vice versa is OK; we compute a gradient and then throw + # it away when it hits the input.) + for i, a in enumerate(args): + can_require_grad = self.flat_requires_grad[i] + if can_require_grad is None: + assert not isinstance(a, Tensor) + elif not can_require_grad: + assert not a.requires_grad, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would not require grad", + ) + + return compiled_fn(args) + + return debug_compiled_function + + +def pre_compile( + wrappers: list[CompilerWrapper], + flat_fn: Callable, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + """ + Runs a sequence of wrappers on the given function and arguments. + Mutates wrappers in place. + """ + for wrapper in wrappers: + flat_fn, flat_args, fw_metadata = wrapper.pre_compile( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + return flat_fn, flat_args, fw_metadata + + +def post_compile( + wrappers: list[CompilerWrapper], + compiled_fn: Callable, + aot_config: AOTConfig, + *, + runtime_metadata: ViewAndMutationMeta, +) -> tuple[Callable, ViewAndMutationMeta]: + """ + Runs a sequence of wrappers on the given function. Should be called after pre_compile() + """ + for wrapper in reversed(wrappers): + compiled_fn = wrapper.post_compile( + compiled_fn, aot_config, runtime_metadata=runtime_metadata + ) + return compiled_fn, runtime_metadata + + +def make_runtime_safe( + fw_metadata: ViewAndMutationMeta, + maybe_subclass_meta: Optional[SubclassMeta], +): + """ + Calls make_runtime_safe on all ViewAndMutationMetas. + Modifies both arguments. Allows ViewAndMutationMetas to + be safely cached in AOTAutogradCache. 
+ """ + fw_metadata.make_runtime_safe() + if maybe_subclass_meta is not None: + maybe_subclass_meta.fw_metadata.make_runtime_safe() + if maybe_subclass_meta.grad_input_metas: + for meta in maybe_subclass_meta.grad_input_metas: + if isinstance(meta, SubclassCreationMeta): + meta.make_runtime_safe() diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..f80b070dc91565b988965f014d5a37532a312d0b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/schemas.py @@ -0,0 +1,968 @@ +# mypy: allow-untyped-defs +""" +The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes +input/output types, metadata, config, function signatures etc. +""" + +import collections +import dataclasses +import functools +import itertools +from collections.abc import Iterable, Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, NewType, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._guards import Source +from torch._ops import OpOverload +from torch._subclasses import FakeTensor +from torch._subclasses.fake_tensor import is_fake +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + FunctionalTensorMetadataEq, +) +from .utils import strict_zip + + +zip = strict_zip + + +OutputType = Enum( + "OutputType", + ( + # output is not an alias + "non_alias", + # output aliases an input + "alias_of_input", + # output **is** an input tensor + "is_input", + # output has a ._base tensor, which is a graph intermediate. + # We need to return its ._base as a graph output, + # so its requires_grad info is populated correctly. + # Instructs the runtime code to regenerate the current output + # from a base tensor, graph_intermediates[base_idx] + "alias_of_intermediate_save_as_output", + # Same as above; but we don't need to explicitly add its ._base + # as a graph output, because it already **is** a graph output. + "alias_of_intermediate", + # Same as above; but the output's ._base is **already** a user output. + # Instructs the runtime code to regenerate the current output from + # a base tensor, user_outputs[base_idx] + "alias_of_intermediate_base_is_user_output", + # See Note [Intermediate Bases Optimization] + "unsafe_view_alias", + # output is an alias, but has a custom autograd.Function backward. + # In this case, we don't want to do view-replay, since we won't be able to replay the custom function. + # Instead, we'll treat this output "normally", and trace its backward into the graph. + "custom_function_view", + ), +) + + +# This class stores info about every user output. 
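
Informally, the categories above can be read off a small example (the classification comments are illustrative only; the real assignment is made by AOTAutograd's metadata collection, and the function here is made up):

import torch

def f(x, y):
    z = x + y                  # graph intermediate
    return (
        z,                     # OutputType.non_alias: a freshly created tensor
        x,                     # OutputType.is_input: the output *is* a forward input
        y.view(-1),            # OutputType.alias_of_input: a view of a forward input
        z.view(-1),            # one of the alias_of_intermediate* cases: a view of z
    )
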
+@dataclass(frozen=True) +class OutputAliasInfo: + # Tells us if this output is: + # (1) a regular (non-aliased) output + # (2) an alias of a forward input + # (3) **is** a forward input (special case of "alias_of_input") + # (4) an alias of an intermediate (aka an alias of an output of the inner traced forward) + # (5) an alias of an intermediate, that explicitly requires returning the intermediate + # as a graph output + # (6) an alias of an intermediate, where that intermediate is also a user output + output_type: OutputType + # The raw type of the output (torch.Tensor, SymInt, etc) + raw_type: type + # If (1) above, then + # - base_idx is None + # If (2) or (3) above, then + # - Tells us that the base of this alias is user_fwd_input[base_idx] + # (This is an index into the inputs *before* we make synthetic bases) + # If (4) or (5) above, then + # - Tells us that the base of this alias is output_graph_intermediates[base_idx] + # here, this refers to the index of the *direct* traced + # If (6) above, then: + # - Tells us that the base of this alias is output_user_fwds[base_idx] + # here, this refers to the index of the *direct* traced + base_idx: Optional[int] + # If it is a Tensor, what the dynamic dims are (otherwise is None) + dynamic_dims: Optional[set[int]] + # requires_grad + requires_grad: bool + # FunctionalTensorWrapper that represents this output. + # + # Provides us the means to replay views from it. + # + # We need to wrap the actual FunctionalTensorWrapper with this class so that + # we only compare the tensor's metadata. That's because with the transformations + # of the model throughout AOTAutograd, the sequence of ViewMeta and the base + # tensor might change. + functional_tensor: Optional[FunctionalTensorMetadataEq] = None + + +class MutationType(Enum): + NOT_MUTATED = 1 + MUTATED_IN_GRAPH = 2 + MUTATED_OUT_GRAPH = 3 + + +# This class tells us info about user inputs. +@dataclass(frozen=True) +class InputAliasInfo: + is_leaf: bool + mutates_data: bool + mutates_metadata: bool + mutations_hidden_from_autograd: bool + mutations_under_no_grad_or_inference_mode: bool + mutation_inductor_storage_resize: bool + mutates_storage_metadata: bool + requires_grad: bool + keep_input_mutations: bool + + def __post_init__(self): + if self.mutates_storage_metadata: + # For convenience, we guarantee that this is always true. + # In practice, If we call .set_(), then at runtime there is no need + # to additionally fix up the tensor metadata, since our runtime + # call to inp.set_(updated_inp) will already have the right metadata + assert self.mutates_metadata + + @functools.cached_property + def mutation_type(self) -> MutationType: + if ( + (not self.mutates_data) + and (not self.mutates_metadata) + and not (self.mutation_inductor_storage_resize) + ): + return MutationType.NOT_MUTATED + + if _check_if_mutation_can_be_in_graph( + self.keep_input_mutations, + self.mutates_data, + self.mutates_metadata, + self.mutations_hidden_from_autograd, + self.mutations_under_no_grad_or_inference_mode, + self.mutates_storage_metadata, + self.mutation_inductor_storage_resize, + self.requires_grad, + ): + return MutationType.MUTATED_IN_GRAPH + + return MutationType.MUTATED_OUT_GRAPH + + +@dataclass +class MemoryFormatMeta: + # For static shapes we assume tangents have the same strideness as outputs + size: Optional[Sequence[int]] = None + stride: Optional[Sequence[int]] = None + + # For dynamic shapes we assume the same memory format: contiguous, channels_last etc. 
+ memory_format: Optional[torch.memory_format] = None + + @staticmethod + def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]: + # We only memorize expected memory format for + # 1. Traceable wrapper subclasses + # We cannot create a restrided subclass tensor, as torch.empty_strided works only with dense tensors. + # 2. Dynamic shape tensors + # Support for symbolic shapes is not implemented yet. + use_memory_format: bool = ( + not torch._functorch.config.guess_tangent_strides_as_outputs + or is_traceable_wrapper_subclass(t) + ) + if not use_memory_format: + is_static_shape = True + for s in itertools.chain(t.shape, t.stride()): + if not isinstance(s, int): + is_static_shape = False + break + + use_memory_format = not is_static_shape + + if use_memory_format: + return MemoryFormatMeta( + memory_format=torch._prims_common.suggest_memory_format(t), + ) + + return MemoryFormatMeta( + size=t.size(), + stride=t.stride(), + ) + + +@dataclass +class PlainTensorMeta: + unwrapped_idx: int + memory_format: Optional[MemoryFormatMeta] = None + + +@dataclass +class SubclassCreationMeta: + """ + Used for AOTDispatch. + This dataclass gives us the information we need to reconstruct a tensor subclass + from our flat inputs. + Why is this important? The graph that we'd like to trace out contains flat tensor inputs, + but the user's original model may have subclass inputs and outputs. + So we need to wrap/unwrap subclasses as necessary to translate between the user's + view (subclass inps/outs), and the backend compiler's view (graph with no subclass args). + + Complications arise mostly from the fact that a subclass can hold more than one inner tensor; + so for a given subclass input/output, we need to carefully track which indices map + to the subclass tensor in the corresponding "dense-tensor-only" graph. + """ + + # In the inner graph that only takes in dense tensor inputs, + # this maps to the first index of "tensors that should go in this subclass wrapper" + flat_tensor_start_idx: int + # arg_count is inclusive of the arg_counts of any + # inner tensor subclasses: If I have a TwoTensor and + # both of its inner elements are TwoTensors, then the + # arg_count of the outer-most subclass will be 4 + arg_count: int + # Marks whether or not symints were included. This flag is only used in one assertion + # in "wrap_tensor_subclasses" + included_subclass_symints: bool + # meta and attrs are produced by the subclass's __tensor_flatten__. + # We need to keep them around along with outer_size / outer_stride to plumb them + # into __tensor_unflatten__ + attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]] + outer_size: Iterable[Union[None, int, torch.SymInt]] + outer_stride: Iterable[Union[None, int, torch.SymInt]] + meta: Any + # Stores the original subclass itself.
+ # This is needed because we need the autograd metadata on the original subclass + # (this is guaranteed to be a wrapper subclass that holds a fake tensor, + # so holding onto this at runtime shouldn't leak memory) + # This field is nulled out after calling make_runtime_safe() + original_subclass: Optional[torch.Tensor] + + # Used at runtime to determine the subclass type, so we don't need to save the original subclass + original_subclass_type: Optional[type] = None + memory_format: Optional[MemoryFormatMeta] = None + + def compute_outer_size_and_stride( + self, + all_args, + *, + curr_start_idx: int, + ): + from .subclass_utils import compute_symint_placeholders + + def compute(outer, start_idx): + placeholders = compute_symint_placeholders(outer) + has_symbolic = any(placeholders) + + if has_symbolic: + start = curr_start_idx + end = start_idx + sum(placeholders) + it_args = iter(all_args[start:end]) + it_placeholders = iter(placeholders) + return pytree.tree_map_only( + lambda _: next(it_placeholders), lambda _: next(it_args), outer + ), start + len(placeholders) + else: + return outer, start_idx + + outer_size, next_idx = compute(self.outer_size, curr_start_idx) + outer_stride, _ = compute(self.outer_stride, next_idx) + return outer_size, outer_stride + + def creation_fn( + self, + all_args, + *, + is_runtime: bool, + ): + inner_tensors = {} + + curr_start_idx = self.flat_tensor_start_idx + for attr, creation_meta in self.attrs.items(): + if isinstance(creation_meta, PlainTensorMeta): + subclass = all_args[curr_start_idx] + curr_start_idx += 1 + else: + subclass = creation_meta.creation_fn( + all_args, + is_runtime=is_runtime, + ) + curr_start_idx += creation_meta.arg_count + inner_tensors[attr] = subclass + + if is_runtime: + assert self.original_subclass_type is not None + original_subclass_type = self.original_subclass_type + else: + original_subclass_type = type(self.original_subclass) + + if is_runtime: + outer_size, outer_stride = self.compute_outer_size_and_stride( + all_args, + curr_start_idx=curr_start_idx, + ) + else: + outer_size, outer_stride = self.outer_size, self.outer_stride + + rebuilt = original_subclass_type.__tensor_unflatten__( # type: ignore[attr-defined] + inner_tensors, self.meta, outer_size, outer_stride + ) + + if not is_runtime: + # After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper + # has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass. + # We don't trace through the autograd engine at runtime though, so no need + # to compute this extra metadata then! + torch._mirror_autograd_meta_to(self.original_subclass, rebuilt) # type: ignore[attr-defined] + + return rebuilt + + def make_runtime_safe(self): + def _make_size_runtime_safe(x: Union[None, int, torch.SymInt]) -> Optional[int]: + dummy = -1 + if isinstance(x, torch.SymInt): + # Replace nested ints by a dummy value (-1) as NJT ignores + # the outer_size/outer_stride at runtime. + return dummy if x.node.is_nested_int() else None + return x + + assert self.original_subclass is not None + self.original_subclass_type = type(self.original_subclass) + self.original_subclass = None + + # Note: NJT outer_size in AOTDispatcher + # `_make_size_runtime_safe` replaces any nested int with a dummy value (-1) + # to prevent serializing a SymInt at runtime. Internally, nested tensor __tensor_unflatten__ + # is designed to safely ignore this dummy value. 
+ # For more details, see: https://github.com/pytorch/pytorch/blob/5141ade8e30c64e873e14dcc8de233da45d15025/torch/nested/_internal/nested_tensor.py#L266-L299 # noqa: B950 + self.outer_size = tuple(map(_make_size_runtime_safe, self.outer_size)) + self.outer_stride = tuple(map(_make_size_runtime_safe, self.outer_stride)) + + # Recurse on nested subclass info + for creation_meta in self.attrs.values(): + if isinstance(creation_meta, SubclassCreationMeta): + creation_meta.make_runtime_safe() + + def __post_init__(self): + # sanity assert to make sure we don't leak memory + assert is_fake(self.original_subclass) + + +# This class encapsulates all aliasing + mutation info we need about the forward graph +# See a more detailed overview of the edge case handling at +# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit +# NOTE: This class is saved in AOTAutogradCache, If you are adding elements, make sure +# they are covered by warm cache tests. +@dataclass(eq=False) +class ViewAndMutationMeta: + # length = # user inputs + # This gives us info about every input, and what sort of mutation happened to it (if any) + input_info: list[InputAliasInfo] + + # length = # user outputs + # This gives us info about every output (mostly around whether it aliases other tensors) + output_info: list[OutputAliasInfo] + + # length = the number of intermediate bases appended as outputs to the end of the forward graph. + # Note: this is not necessarily the same thing as: + # len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate]) + # Because outputs might share a ._base, or an output's ._base might itself be + # another user output (in both cases, we won't redundantly append bases to the end of the graph) + num_intermediate_bases: int + + # For inference only: instructs us to keep data-only input mutations directly in the graph + keep_input_mutations: bool + + # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) + # + (# intermediate bases) + # These are the FakeTensor (or potential SymInt) outputs that we traced from our + # metadata pass of the user's forward function. + # Their only use today is to pass them as a best-guess for tangents when tracing the joint. + # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis + # pass once, and re-use the output throughout AOTAutograd + traced_tangents: list[Any] + + # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs + # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors, + # Given a (potentially larger) list of plain torch tensors. + + # Taking subclass_inp_meta as an example: + # subclass_inp_meta[i] = j (an int) tells us: + # "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph." + # subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2) + # "The i'th user input is subclass holding two inner tensors, which are + # inputs[3] and inputs[4] of the plain-tensor graph". 
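+    # Hypothetical illustration: for a user forward `f(plain, two_tensor)` where `two_tensor` is a
+    # TwoTensor wrapping two dense tensors, subclass_inp_meta would conceptually look like
+    # [PlainTensorMeta(0), SubclassCreationMeta(flat_tensor_start_idx=1, arg_count=2, ...)]
+    # (ignoring any extra symint args).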
+ + # length = # user inputs + subclass_inp_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # So, the full set of outputs to the forward graph looks something like: + # (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors) + # where the first 3 of those 4 can be subclasses + # (but not saved_for_bw tensors, since these are internal to the compiler + # and not user visible, so there's no point in wrapping/unwrapping them at runtime). + # This list contains subclass information on all of the fw graph outputs + # except for saved_for_bw_tensors. + subclass_fw_graph_out_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # length = # backward graph inputs + subclass_tangent_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]] + # TODO: we should kill this + # (need to default it to not break internal) + is_train: bool = False + + # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) + # + (# intermediate bases) + # At runtime, we don't keep the traced_tangents around since they're not serializable. + # Instead, we keep any necessary subclass metadata necessary about each traced_tangent. + # This list is generated after calling make_runtime_safe(). + traced_tangent_metas: Optional[list[Any]] = None + + num_symints_saved_for_bw: Optional[int] = None + + # The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue + # NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode + # that is intended to be in effect prior to running the graph, in keeping with + # equivalence to eager mode. It is the responsibility of upstream graph acquisition + # to reset the grad mode to its pre-graph value prior to calling aot_autograd. + grad_enabled_mutation: Optional[bool] = None + + # Keeps track of whether `torch.use_deterministic_algorithms` was turned on + # when the forward was run. If deterministic mode was turned off during the + # forward, but is turned on during the backward call, then an error is + # raised + deterministic: Optional[bool] = None + + # Keeps track of which input indices store parameters (which we will treat as static) + static_input_indices: list[int] = field(default_factory=list) + + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are + # side-effectful operators, FunctionalTensorMode will populate this + # dictionary telling us how many tokens we will need during tracing. + tokens: dict[Any, torch.Tensor] = field(default_factory=dict) + + # Only filled in if/when we trace the joint function + # If an input requires grad and is mutated in the backward, it is only safe to keep the mutation + # in the graph if gradients are disabled while the backward runs + # (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True) + # At runtime during the backward, we use this list of indices to error properly if we find out + # that it was not safe to include a backward mutation in the graph. + indices_of_inputs_that_requires_grad_with_mutations_in_bw: list[int] = field( + default_factory=list + ) + + # Indexes of saved tensors which are donated buffer. + # Donated buffer means the tensor is not alias of any forward user input, forward user output, + # and backward output. + bw_donated_idxs: Optional[list[int]] = None + + # Number of tokens used in backward, appended at the end of backward outputs. + # Filled after tracing joint function. 
+ num_backward_tokens: int = 0 + + # Number of rng states that will get thread into the forward and backward for + # cudagraph compatible run_and_save_rng + num_graphsafe_rng_states: int = 0 + + graphsafe_rng_state_index: Optional[int] = None + + def __post_init__(self): + # pre-compute the indices of the inputs that are mutated. + # When keep_input_mutations is set, we don't need to worry about our epilogue + # handling data-only mutations, because we keep them directly in the graph. + mutated_inp_runtime_indices = [ + i + for i, m in enumerate(self.input_info) + if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH) + ] + + mutated_graph_handled_indices = [ + i + for i, m in enumerate(self.input_info) + if m.mutation_type == MutationType.MUTATED_IN_GRAPH + ] + self.mutated_graph_handled_indices = mutated_graph_handled_indices + self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices) + + mutated_graph_handled_indices_seen_by_autograd = [ + i + for i in mutated_graph_handled_indices + if not self.input_info[i].mutations_hidden_from_autograd + ] + + self.mutated_graph_handled_indices_seen_by_autograd = ( + mutated_graph_handled_indices_seen_by_autograd + ) + self.num_mutated_graph_handled_indices_seen_by_autograd = len( + self.mutated_graph_handled_indices_seen_by_autograd + ) + + aliased_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type + not in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + unsafe_view_out_indices = [ + i + for i, m in enumerate(self.output_info) + if m.output_type is OutputType.unsafe_view_alias + ] + + # This is pre-computed in post_init for perf. + # It contains the index of every element + # of input_info that corresponds to a mutation (data or metadata or both) + self.mutated_inp_runtime_indices = mutated_inp_runtime_indices + self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices) + + # This is pre-computed for perf. 
+ # It contains the index of every element + # of output_info that corresponds to an alias (either of an input or intermediate) + self.aliased_out_indices = aliased_out_indices + self.unsafe_view_out_indices = unsafe_view_out_indices + self.num_outputs = len(self.output_info) + self.num_outputs_non_aliased = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + ] + ) + self.num_outputs_aliased_to_inputs = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_input, + OutputType.is_input, + ] + ] + ) + self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices) + self.num_outputs_aliased_to_intermediates = len( + [ + x + for x in self.output_info + if x.output_type + in [ + OutputType.alias_of_intermediate, + OutputType.alias_of_intermediate_save_as_output, + OutputType.alias_of_intermediate_base_is_user_output, + ] + ] + ) + self.num_outputs_aliased = ( + self.num_outputs_aliased_to_inputs + + self.num_outputs_aliased_to_intermediates + ) + + # Record dynamic outputs of the Dynamo traced forward graph + # Mark them as dynamic at the end of the runtime wrapper + self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info) + + # Record the indices of dynamic outputs in the partitioned forward graph + # Mark them as dynamic in the runtime wrapper + # activation index -> dynamic dims indices + self.dynamic_saved_tensors_idxs: dict[int, set[int]] = {} + + # See Note: [AOTAutograd Backward Guards] + # This is pre-computed for fast asserts on the types of our grad_outputs in the backward. + # Eventually, we should kill this and replace with real backward guards. + # (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor) + self.output_types = [ + torch.Tensor if isinstance(x, FakeTensor) else type(x) + for x in self.traced_tangents + ] + + self.is_rng_op_functionalized = config.functionalize_rng_ops + # All of the above metadata is collected by tracing the fw function. + # However, extra outputs for rng offsets behave differently. Both fwd + # and bwd graphs have their own outputs for the total consumed offsets. + # Unlike mutated inputs, we don't have to worry about sending the right + # set of tensors between fwd and bwd. Fwd and bwd offsets are + # independent and simpler to handle. Therefore, we track them + # separately. + self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0 + + # Our forward() returns both (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints) + # Tokens will be split out before mutations/view handling and we do not count them here. + self.num_forward_returns = ( + self.num_mutated_inp_runtime_indices + + self.num_outputs + + self.num_intermediate_bases + ) + # In case of functionalization of rng ops, the fw_module returns one + # additional output for rng offset. This rng offset is used right + # away to advance the rng state, and is not passed on to the raw + # outputs. However, we need to know the exact boundary to identify + # which tensors to be saved for the bwd graph. num_forward captures + # this information. + self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset + + def make_runtime_safe(self): + """ + There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing + is completed to simplify certain fields in the metadata so that they can be safely cached. 
+ + Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime. + """ + # TODO: This function is only a best effort: there are other fields that may not be cache safe + # (i.e., there's no guarantee that tensor_flatten() returns a serializable result), or that + # SubclassCreationMeta is cache safe. + assert self.traced_tangent_metas is None + + def extract_metadata(t): + if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t): + (inner_tensors, flatten_spec) = t.__tensor_flatten__() # type: ignore[attr-defined] + # Technically, we only need the flatten_spec, not the inner tensors. + # However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None. + # And we want to be able to assert that this metadata is non-None, + # to distinguish between "this was a tensor subclass with no metadata" vs. + # "this wasn't a tensor subclass at all". + return (inner_tensors, flatten_spec) + else: + return None + + self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] + # Clear traced tangents at runtime + self.traced_tangents = [] + new_output_info = [] + for out in self.output_info: + if config.view_replay_for_aliased_outputs: + new_out = out + else: + # If we're not using view_replay, remove the functional tensor. + # Functional tensors are unfortunately not serializable, + # so doing this is required for AOTAutograd caching. + new_out = dataclasses.replace(out, functional_tensor=None) + new_output_info.append(new_out) + self.output_info = new_output_info + for inp_meta in self.subclass_inp_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + for inp_meta in self.subclass_fw_graph_out_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + for inp_meta in self.subclass_tangent_meta: + if isinstance(inp_meta, SubclassCreationMeta): + inp_meta.make_runtime_safe() + + @property + def tensors_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(self.num_forward, -self.num_symints_saved_for_bw) + else: + return slice(self.num_forward, None) + + @property + def symints_saved_for_backwards_slice(self): + assert self.num_symints_saved_for_bw is not None + if self.num_symints_saved_for_bw > 0: + return slice(-self.num_symints_saved_for_bw, None) + else: + return slice(0, 0) # empty slice + + def __eq__(self, other): + if not isinstance(other, ViewAndMutationMeta): + return NotImplemented + return ( + self.input_info == other.input_info + and self.output_info == other.output_info + and self.num_intermediate_bases == other.num_intermediate_bases + and self.keep_input_mutations == other.keep_input_mutations + and self.is_rng_op_functionalized == other.is_rng_op_functionalized + and self.num_outputs_rng_offset == other.num_outputs_rng_offset + and len(self.traced_tangents) == len(other.traced_tangents) + and all( + x.shape == y.shape and x.dtype == y.dtype + for x, y, in zip(self.traced_tangents, other.traced_tangents) + ) + and self.num_backward_tokens == other.num_backward_tokens + ) + + +@dataclass(eq=False) +class SubclassMeta: + # A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses) + # So for example, if the user had a model containing two `TwoTensor` inputs, + # Then `SubclassMeta.fw_metadata.input_infos` would have length 4 here. 
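+    # (Each TwoTensor desugars into two dense tensors in the dense-only graph, so the two subclass
+    # inputs above become four plain-tensor input infos; the count for other subclass types depends
+    # on how many inner tensors they hold.)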
+ fw_metadata: ViewAndMutationMeta + + # Note: [Computing Subclass Metadata about grad_inputs] + # Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses + # + # You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs? + # (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous) + # + # This doesn't really work though. take this example: + # + # def f(DoubleTensor, DenseTensor): + # return DoubleTensor * DenseTensor + # + # In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor. + # When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs. + # This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input) + # and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors. + # + # Note that this info **cannot** easily be figured out from ViewAndMutationMeta. + # We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed. + # + # See Note: [AOTAutograd Backward Guards] + # This will also eventually require us to install backward guards, + # in case we made incorrect assumptions about the subclass-ness of our grad_outputs + # + # Optional field because we don't compute for inference graphs + grad_input_metas: Optional[ + list[Union[PlainTensorMeta, SubclassCreationMeta]] + ] = None + + def __init__(self) -> None: + # The fields in this class get set after its construction. + pass + + +# This class exists because: +# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs +# - we only care about the metadata on those aliases, so we can regenerate them. +# We do not want them to participate in the autograd.Function. +# We do that by wrapping them in an opaque class, so the autograd.Function +# does not know to treat them as tensors. +@dataclass(frozen=True) +class TensorAlias: + alias: torch.Tensor + + +@dataclass +class BackwardSignature: + """ + Provides information about the backward section of an exported + joint forward-backward graph. + For a particular fx GraphModule, this class contains information on: + (1) A mapping from each gradient (backwards output) to the parameter + it corresponds to (forward input) + (2) A mapping from each gradient (backwards output) to the user input + it corresponds to (forward input) + (3) Which of the forward outputs corresponds to the loss, that we backprop on. + + Each string name is the `node.name` of the corresponding node in the fx graph. + """ + + gradients_to_parameters: dict[str, str] + gradients_to_user_inputs: dict[str, str] + loss_output: str + + +GraphOutputName = NewType("GraphOutputName", str) +GraphInputName = NewType("GraphInputName", str) +FQN = NewType("FQN", str) + + +@dataclass +class GraphSignature: + """ + Provides information about an exported module. + For a particular fx GraphModule, this class contains information on: + (1) Which graph inputs are parameters, buffers, or user inputs + (2) (for params/buffers) a mapping from the name of each graph argument + to its parameter/buffer FQN in the original nn.Module. + (3) If there are input mutations, these are represented as extra outputs + in the fx GraphModule. We provide a mapping from these + extra output names to the names of the actual inputs. 
+ (4) The pytree metadata on how to flatten/unflatten inputs and outputs. + The corresponding FX GraphModule only accepts and returns + pytree-flattened inputs/outputs. + (5) (Optionally) if the FX is a joint forward-backward graph, we provide + a signature on the backward section of the joint graph. + """ + + parameters: list[FQN] + buffers: list[FQN] + + user_inputs: list[GraphInputName] + user_outputs: list[GraphOutputName] + inputs_to_parameters: dict[GraphInputName, FQN] + inputs_to_buffers: dict[GraphInputName, FQN] + + # If the user's module mutates a buffer, + # it's represented in the graph as an extra graph output. + # This dict is a mapping from + # "graph outputs that correspond to updated buffers" + # to the FQN names of those mutated buffers. + buffers_to_mutate: dict[GraphOutputName, FQN] + user_inputs_to_mutate: dict[GraphOutputName, GraphInputName] + + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + backward_signature: Optional[BackwardSignature] + + input_tokens: list[GraphInputName] + output_tokens: list[GraphOutputName] + + @classmethod + def from_tracing_metadata( + cls, + *, + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec, + graph_input_names: list[str], + graph_output_names: list[str], + view_mutation_metadata: ViewAndMutationMeta, + named_parameters: list[str], + named_buffers: list[str], + num_user_inputs: int, + num_user_outputs: int, + loss_index: Optional[int], + backward_signature: Optional[BackwardSignature], + ) -> "GraphSignature": + graph_inputs = graph_input_names + graph_outputs = graph_output_names + parameters = list(named_parameters) + buffers = list(named_buffers) + num_tokens = len(view_mutation_metadata.tokens) + + # Calling convention assumptions: + # (1) graph inputs = (input_tokens, params, buffers, user_inputs) + # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients) + # (If we are capturing an inference graph, this convention is identical + # except that param_gradients is empty) + # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens + + # Address input calling conventions: + start, stop = 0, num_tokens + input_tokens = graph_inputs[start:stop] + + start, stop = stop, stop + len(parameters) + inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters)) + + start, stop = stop, stop + len(buffers) + inputs_to_buffers = dict( + zip( + graph_inputs[start:stop], + buffers, + ) + ) + + start, stop = stop, stop + num_user_inputs + user_inputs = graph_inputs[start:stop] + + # We should've gone through all the inputs now + assert len(graph_inputs) - stop == 0 + + # Address output calling conventions: + start, stop = 0, num_tokens + output_tokens = graph_outputs[start:stop] + + names = [*input_tokens, *parameters, *buffers, *user_inputs] + mutations = [] + for idx, input_info in enumerate(view_mutation_metadata.input_info): + if input_info.mutates_data: + # Only buffers can be mutated, not parameters + assert idx >= len(parameters) + mutations.append(names[idx + num_tokens]) + + assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices + + start, stop = ( + stop, + stop + view_mutation_metadata.num_mutated_inp_runtime_indices, + ) + outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations)) + + user_inputs_to_mutate = {} + buffers_to_mutate = {} + for output_name, mutation_name in outputs_to_mutations.items(): + if mutation_name in user_inputs: + user_inputs_to_mutate[output_name] = mutation_name + else: + assert mutation_name in buffers + 
buffers_to_mutate[output_name] = mutation_name + + start, stop = stop, stop + num_user_outputs + user_outputs = graph_outputs[start:stop] + + unused_outputs = len(graph_outputs) - stop + if backward_signature is not None: + unused_outputs -= len(backward_signature.gradients_to_parameters) + len( + backward_signature.gradients_to_user_inputs + ) + assert unused_outputs == 0 + + return GraphSignature( + parameters=parameters, # type: ignore[arg-type] + buffers=buffers, # type: ignore[arg-type] + user_inputs=user_inputs, # type: ignore[arg-type] + user_outputs=user_outputs, # type: ignore[arg-type] + inputs_to_buffers=inputs_to_buffers, # type: ignore[arg-type] + inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type] + user_inputs_to_mutate=user_inputs_to_mutate, + buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type] + in_spec=in_spec, + out_spec=out_spec, + backward_signature=backward_signature, + input_tokens=input_tokens, # type: ignore[arg-type] + output_tokens=output_tokens, # type: ignore[arg-type] + ) + + +@dataclass +class AOTAutogradCacheInfo: + cache_key: str + start_time_ns: int + forward_symints: list[torch.SymInt] + + +@dataclass +class AOTConfig: + """ + Configuration for AOTDispatcher + """ + + fw_compiler: Callable + bw_compiler: Callable + partition_fn: Callable + decompositions: dict[OpOverload, Callable] + num_params_buffers: int + aot_id: int + keep_inference_input_mutations: bool + is_export: bool = False + no_tangents: bool = False + dynamic_shapes: bool = False + aot_autograd_arg_pos_to_source: Optional[list[Source]] = None + static_input_indices: Optional[list[int]] = None + inference_compiler: Optional[Callable] = None + enable_log: bool = True + # this is always false outside of export. + pre_dispatch: bool = False + # Key to use for AOTAutogradCache + cache_info: Optional[AOTAutogradCacheInfo] = None + # If we should ignore the shape_env in the ambient tracing_context. + # The net effect is that if dynamic shapes are on, we end up + # specializing on example_inputs. + # Used only by standalone_compile. + ignore_shape_env: bool = False + precompile_backend_id: Optional[str] = None + + def __post_init__(self): + if self.pre_dispatch: + assert self.is_export, "Can only have pre_dispatch IR for export." + + +SubclassTracingInfo = collections.namedtuple( + "SubclassTracingInfo", + ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"], +) diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..954b31a5049c9de24dfb9f2daad77e95355768ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_parametrization.py @@ -0,0 +1,103 @@ +import dataclasses +import itertools +from collections.abc import Iterable +from typing import Any, Union + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +# This is technically very similar to SubclassCreatingMeta +# in aot_autograd, but we don't need all the stuff in there +# so just recreated a new dataclass. 
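+# Rough usage sketch (hypothetical `MyModel`; see unwrap_tensor_subclass_parameters at the bottom
+# of this file): after `unwrap_tensor_subclass_parameters(MyModel())`, each subclass parameter is
+# exposed to the optimizer/state_dict as its flattened plain tensors, and this parametrization
+# rebuilds the subclass on attribute access via __tensor_unflatten__.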
+@dataclasses.dataclass +class SubclassCreationMeta: + start_idx: int + num_tensors: int + class_type: Any + attrs: dict[str, "SubclassCreationMeta"] + metadata: Any + outer_size: Iterable[Union[None, int, torch.SymInt]] + outer_stride: Iterable[Union[None, int, torch.SymInt]] + + +class UnwrapTensorSubclass(torch.nn.Module): + def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def] + todo: list[torch.Tensor] = list(tensors) + + def _unwrap_tensor_subclasses(subclass_meta, tensors, offset): # type: ignore[no-untyped-def] + if subclass_meta is None: + return tensors[offset], offset + 1 + inner_tensors = {} + for attr, meta in subclass_meta.attrs.items(): + built_tensor, offset = _unwrap_tensor_subclasses(meta, tensors, offset) + inner_tensors[attr] = built_tensor + rebuilt = subclass_meta.class_type.__tensor_unflatten__( + inner_tensors, + subclass_meta.metadata, + subclass_meta.outer_size, + subclass_meta.outer_stride, + ) + return rebuilt, offset + + return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0] + + def right_inverse(self, tensor: torch.Tensor) -> list[torch.Tensor]: + assert type(tensor) is not torch.Tensor + plain_tensors: list[torch.Tensor] = [] + + def _create_subclass_meta(tensor, idx, plain_tensor_container): # type: ignore[no-untyped-def] + if type(tensor) is torch.Tensor: + plain_tensor_container.append(tensor) + return None, idx + 1 + inner_tensors_attrnames, metadata = tensor.__tensor_flatten__() # type: ignore[attr-defined] + new_idx = idx + attr_to_meta = {} + for attr in inner_tensors_attrnames: + val = getattr(tensor, attr) + subclass_meta, new_idx = _create_subclass_meta( + val, new_idx, plain_tensor_container + ) + attr_to_meta[attr] = subclass_meta + return ( + SubclassCreationMeta( + start_idx=idx, + num_tensors=new_idx - idx, + class_type=type(tensor), + attrs=attr_to_meta, + metadata=metadata, + outer_size=tensor.size(), + outer_stride=tensor.stride(), + ), + new_idx, + ) + + self.subclass_meta = _create_subclass_meta(tensor, 0, plain_tensors)[0] + return plain_tensors + + +def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Module: + """ + Model transformation that replaces all the parameters that are subclasses to plain tensors. + This reduces runtime overhead of flattening/unflattening the parameters. + + This transformation adds parametrization with `torch.nn.utils.parametrize`. + The FQNs of the subclass parameters will be changed and state_dict will become incompatible with the original model. + E.g. 
+ Original model state_dict: {"p1": torch.testing._internal.TwoTensor} + becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor} + + """ + for name, tensor in itertools.chain( + list(module.named_parameters(recurse=False)), + list(module.named_buffers(recurse=False)), + ): + if is_traceable_wrapper_subclass(tensor): + torch.nn.utils.parametrize.register_parametrization( + module, name, UnwrapTensorSubclass() + ) + + for name, child in module.named_children(): + unwrap_tensor_subclass_parameters(child) + + return module diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06349b18a30883a9c54e1c78c66532e15c0c287c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/subclass_utils.py @@ -0,0 +1,476 @@ +# mypy: allow-untyped-defs +""" +This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes. +AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher, +and this includes tensor subclasses that implement __torch_dispatch__. +""" + +import collections +import typing +from collections.abc import Iterable +from typing import Any, Callable, Optional, TypeVar, Union + +import torch +import torch.utils._pytree as pytree +from torch import SymInt, Tensor +from torch._subclasses.fake_tensor import get_plain_tensors +from torch.types import IntLikeType +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .schemas import ( + MutationType, + PlainTensorMeta, + SubclassCreationMeta, + ViewAndMutationMeta, +) +from .utils import strict_zip + + +zip = strict_zip + +T = TypeVar("T", bound=torch.Tensor) + + +def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: + args_flattened = pytree.arg_tree_leaves(*args) + any_subclass_args = any( + is_traceable_wrapper_subclass(x) + for x in args_flattened + if isinstance(x, Tensor) + ) + from torch._functorch._aot_autograd.schemas import SubclassCreationMeta + + any_subclass_outputs = any( + type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta + ) + # This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime. + return any_subclass_args or any_subclass_outputs + + +from .schemas import MemoryFormatMeta + + +def maybe_suggest_memory_format( + t, with_memory_format: bool +) -> Optional[MemoryFormatMeta]: + if not with_memory_format: + return None + + return MemoryFormatMeta.from_tensor(t) + + +def get_subclass_typing_container( + tensor_subclass: torch.Tensor, +) -> dict[type[torch.Tensor], list[type[torch.Tensor]]]: + """ + Given a subclass, returns a recursive dictionary mapping each + inner tensors to its' subclass types. 
+ """ + + def _get_types_for_subclass(tensor_subclass: torch.Tensor) -> None: + if not is_traceable_wrapper_subclass(tensor_subclass): + return + tracker[type(tensor_subclass)].append(tensor_subclass) + inner_keys, _ = tensor_subclass.__tensor_flatten__() + for key in inner_keys: + inner_tensor = getattr(tensor_subclass, key) + _get_types_for_subclass(inner_tensor) + + tracker: dict[Any, list[Any]] = collections.defaultdict(list) + _get_types_for_subclass(tensor_subclass) + return tracker + + +def create_subclass_metadata( + a: Any, start_idx: int, count_symints: bool, with_memory_format: bool = False +): + if not is_traceable_wrapper_subclass(a): + idx = start_idx + 1 + return ( + PlainTensorMeta( + idx, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ), + idx, + ) + + inner_keys, metadata = a.__tensor_flatten__() + new_start_idx = start_idx + attrs = {} + + for key in inner_keys: + new_subclass_meta, new_start_idx = create_subclass_metadata( + getattr(a, key), + new_start_idx, + count_symints=count_symints, + with_memory_format=with_memory_format, + ) + attrs[key] = new_subclass_meta + + # It *must* be because is_traceable_wrapper_subclass() - but mypy is not smart. + assert isinstance(a, Tensor) + + new_start_idx = ( + new_start_idx + + count_symints * len(filter_symints(a.size())) + + count_symints * len(filter_symints(a.stride())) + ) + + return ( + SubclassCreationMeta( + flat_tensor_start_idx=start_idx, + arg_count=new_start_idx - start_idx, + included_subclass_symints=count_symints, + attrs=attrs, + meta=metadata, + outer_size=a.size(), # type: ignore[attr-defined, arg-type] + outer_stride=a.stride(), # type: ignore[arg-type] + original_subclass=a, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ), + new_start_idx, + ) + + +# Given a flat list of arguments, some of which may be tensor subclasses, +# computes metadata about "how to reconstruct the current list of subclasses, +# if we were given their flattened dense tensors instead" +def create_subclass_meta( + curr_args: Union[list[Any], tuple[Any, ...]], + *, + count_symints: bool = True, + with_memory_format: bool = False, +) -> list[Union[PlainTensorMeta, SubclassCreationMeta]]: + idx = 0 + infos: list[Union[PlainTensorMeta, SubclassCreationMeta]] = [] + for a in curr_args: + if is_traceable_wrapper_subclass(a): + assert isinstance(a, Tensor) + start_idx = idx + subclass_meta, _ = create_subclass_metadata( + a, + start_idx, + count_symints=count_symints, + with_memory_format=with_memory_format, + ) + infos.append(subclass_meta) + cnt = subclass_meta.arg_count + else: + infos.append( + PlainTensorMeta( + idx, + memory_format=maybe_suggest_memory_format(a, with_memory_format), + ) + ) + cnt = 1 + idx += cnt + return infos + + +def filter_symints(lst: Iterable[IntLikeType]): + # Capture all SymInts from the iterable. + def symint_check(s: IntLikeType) -> bool: + return isinstance(s, SymInt) and not s.node.is_nested_int() + + return [s for s in lst if symint_check(s)] + + +def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]: + # Non-nested symints are replaced with None in `make_runtime_safe()` + return [s is None for s in lst] + + +# This function takes in a pytree of arguments and unwraps any tensor +# subclasses. +# +# NOTE: The reason for "append_symints": +# +# * At compile time: we append extra symint args when unwrapping primals +# (but not tangents, because they should always share symints with primals). 
+# We also append extra symints when unwrapping the subclass outputs of the +# traced function, so we can return them as extra outputs +# +# * At runtime: we similarly append subclass sizes when we unwrap subclass +# primals (but not tangents) on entry to the forward. See the runtime version of +# this function below. +def unwrap_tensor_subclasses( + wrapped_args: list[Union[Tensor, int]], + *, + append_symints: bool, +): + def flatten_subclass(t: Union[Tensor, int], *, out=None): + # unwrap a subclass into plain tensors and their size/stride if "append_symint" + # is True + if not is_traceable_wrapper_subclass(t): + out.append(t) + return + + attrs, _ = t.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(t, attr) + flatten_subclass(inner_tensor, out=out) + + if append_symints: + out.extend(filter_symints(t.size())) + out.extend(filter_symints(t.stride())) + + xs_inner: list[Union[int, Tensor, SymInt]] = [] + + for x in wrapped_args: + flatten_subclass(typing.cast(Tensor, x), out=xs_inner) + + return xs_inner + + +# subclass_metas is needed at runtime to compute which indices are symints in +# the outer_size/outer_stride +def runtime_unwrap_tensor_subclasses( + wrapped_args: list[Union[Tensor, int]], + *, + append_symints: bool, + subclass_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = None, +): + def flatten_subclass(x: Tensor, meta: Optional[SubclassCreationMeta], *, out): + if not is_traceable_wrapper_subclass(x): + out.append(x) + return out + + assert isinstance(x, Tensor) + + attrs, _ = x.__tensor_flatten__() + + for attr in attrs: + inner_tensor = getattr(x, attr) + inner_meta = meta.attrs.get(attr) + flatten_subclass(inner_tensor, inner_meta, out=out) + + if append_symints: + assert isinstance(meta, SubclassCreationMeta) + # outer_size + size = x.size() + symint_placeholders = compute_symint_placeholders(meta.outer_size) + assert len(size) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(size, symint_placeholders) if is_symint] + ) + + # outer_stride + stride = x.stride() + symint_placeholders = compute_symint_placeholders(meta.outer_stride) + assert len(stride) == len(symint_placeholders) + out.extend( + [r for (r, is_symint) in zip(stride, symint_placeholders) if is_symint] + ) + return out + + xs_inner: list[Union[int, Tensor, SymInt]] = [] + + if append_symints: + assert subclass_metas is not None + + for idx, x in enumerate(wrapped_args): + if not is_traceable_wrapper_subclass(x): + xs_inner.append(x) + continue + + if subclass_metas is None: + get_plain_tensors(typing.cast(Tensor, x), out=xs_inner) + else: + meta = subclass_metas[idx] + assert isinstance(meta, SubclassCreationMeta) + flatten_subclass(typing.cast(Tensor, x), meta, out=xs_inner) + + return xs_inner + + +def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args): + ret_unwrapped = [] + ret_indices_to_original = [] + for i, a in enumerate(wrapped_args): + a_unwrapped = unwrap_tensor_subclasses([a], append_symints=False) + ret_unwrapped.extend(a_unwrapped) + n = len(a_unwrapped) + ret_indices_to_original.extend([i] * n) + + return ret_unwrapped, ret_indices_to_original + + +def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices): + static_input_indices = set(static_input_indices) + new_ind = 0 + remapped_static_indices = [] + for i, arg in enumerate(wrapped_args): + num_indices = 1 + if is_traceable_wrapper_subclass(arg): + num_indices = ( + len(get_plain_tensors(typing.cast(Tensor, arg), out=[])) + + 
len(filter_symints(arg.size())) + + len(filter_symints(arg.stride())) + ) + + for _ in range(num_indices): + if i in static_input_indices: + remapped_static_indices.append(new_ind) + + new_ind += 1 + + return remapped_static_indices + + +# Turns a flattened list of tensor arguments into (maybe) subclass tensors. +# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in. +def wrap_tensor_subclasses( + unwrapped_args: Union[tuple[Any, ...], list[Any]], + *, + subclass_metas: list[Union[PlainTensorMeta, SubclassCreationMeta]], + num_fw_outs_saved_for_bw: Optional[int] = None, + included_subclass_symints: bool = False, + is_runtime: bool = False, + make_subclass_override: Optional[Callable] = None, +) -> tuple[Any, ...]: + wrapped_args = [] + num_args_tallied = 0 + for subclass_meta in subclass_metas: + if isinstance(subclass_meta, PlainTensorMeta): + wrapped_args.append(unwrapped_args[subclass_meta.unwrapped_idx]) + num_args_tallied += 1 + else: + assert isinstance(subclass_meta, SubclassCreationMeta) + assert subclass_meta.included_subclass_symints == included_subclass_symints + + if make_subclass_override: + wrapped_args.append( + make_subclass_override(subclass_meta, is_runtime, unwrapped_args) + ) + else: + wrapped_args.append( + subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) + ) + num_args_tallied += subclass_meta.arg_count + + # Note: [Partitioner handling for Subclasses, Part 2] + # At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw, + # to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them. + # + # When this function is called at runtime in the forward, + # we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs. + # + # One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen? + # Answer: we do it **inside of our compiled autograd.Function**. + # This seems like morally the right place: autograd happens above subclass desugaring, + # so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors. + # + # This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph + # into a forward and backward graph, we end up with some activations that show up as extra outputs + # in the compiled forward graph, that are **not** user outputs. + # These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses. + # + # On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`), + # we computed subclass metadata on every forward output, but this did **not** include activations + # created by the partitioner. + # as a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations), + # but `subclass_metas` will only correspond to subclass metatadata on `user_fw_outs`. + # We then need to make sure that we return (*wrapped_user_fw_outs, *activations). 
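+    # Illustratively (hypothetical numbers, just to make the bookkeeping concrete): if the
+    # flattened forward returned 5 dense tensors of which the last 2 are partitioner-created
+    # activations, num_fw_outs_saved_for_bw would be 2, subclass_metas would only describe the
+    # earlier outputs, and the 2 activations are passed through unwrapped below.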
+ if num_fw_outs_saved_for_bw is not None: + assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, ( + f"Expected the number actual unwrapped-subclass outputs {len(unwrapped_args)} to equal " + f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of " + f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})" + ) + activations = unwrapped_args[num_args_tallied:] + if isinstance(wrapped_args, tuple) and isinstance(activations, tuple): + return wrapped_args + activations + return tuple(list(wrapped_args) + list(activations)) + else: + assert ( + len(unwrapped_args) == num_args_tallied + ), f"Expected {len(unwrapped_args)} == {num_args_tallied}" + return tuple(wrapped_args) + + +# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses. +# This function carefully handles the inference vs. joint cases: +# - when is_joint_structure is True, args is (primals, tangents) +# - when is_joint_structure is False, args is [*primals] +def wrap_tensor_subclasses_maybe_joint( + unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta +) -> Union[tuple[Any, ...], list[Any]]: + # Since this function is re-used for both inference and joint graphs, + if is_joint_structure: + assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2 + assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance( + unwrapped_args[1], (tuple, list) + ) + primals, tangents = unwrapped_args[0], unwrapped_args[1] + wrapped_primals = wrap_tensor_subclasses( + primals, + subclass_metas=meta.subclass_inp_meta, + included_subclass_symints=True, + ) + wrapped_tangents = wrap_tensor_subclasses( + tangents, + subclass_metas=meta.subclass_tangent_meta, + included_subclass_symints=False, + ) + return (wrapped_primals, wrapped_tangents) + else: + wrapped_args = wrap_tensor_subclasses( + unwrapped_args, + subclass_metas=meta.subclass_inp_meta, + included_subclass_symints=True, + ) + return wrapped_args + + +def compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata: ViewAndMutationMeta, + inner_metadata: ViewAndMutationMeta, +) -> list[int]: + # Note: [Recomputing subclass mutation handling] + # + # Generally, if a subclass requires grad, its components will not require grad. + # But for the purposes of tracking returned tensors, we should treat those component + # tensors as if they require grad. + # + # For example, if the subclass tensor requires grad and will be mutated in a way that + # requires us to handle the mutation outside of the graph, we need to return it + # from the forward graph. The inner_meta data won't consider the component tensors + # as if they need to be returned, because they don't require grad; but really, we + # should handle those tensors the same way we handle the subclass tensor itself; i.e. + # if we'd include the subclass tensor as part of the outputs, then we should also + # include the component tensors. + # + # To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs + # from the outer subclass tensors and propagating + + updated_input_info = [] + inner_idx = 0 + if not fw_metadata.subclass_inp_meta: + # Sometimes we don't have subclass info, e.g. 
synthetic_base codepaths + return inner_metadata.mutated_inp_runtime_indices + assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info) + for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta): + if isinstance(inp_meta, PlainTensorMeta): + assert outer_idx < len(fw_metadata.input_info) + if inner_metadata is not None: + assert inner_idx < len(inner_metadata.input_info) + assert ( + inner_metadata.input_info[inner_idx] + == fw_metadata.input_info[outer_idx] + ) + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + else: + assert inp_meta.original_subclass is not None + for _ in range(inp_meta.arg_count): + updated_input_info.append(fw_metadata.input_info[outer_idx]) + inner_idx += 1 + if inner_metadata is not None: + assert len(inner_metadata.input_info) == len(updated_input_info) + + return [ + i + for i, inp in enumerate(updated_input_info) + if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4c595a86171b28d29431c579bec02efb1f69ead8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -0,0 +1,924 @@ +# mypy: allow-untyped-defs +""" +This module is responsible for transforming functions to be traced into a form +that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) +to handle. + +It does so by: +1. functionalization (including RNG functionalzation) +2. creating a joint graph when required +3. transforming mutations into extra outputs +4. dispatching subclasses +""" + +import warnings +from contextlib import contextmanager, nullcontext +from functools import wraps +from typing import Any, Callable, Union +from unittest.mock import patch + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker +from torch._guards import detect_fake_mode +from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.proxy_tensor import ( + maybe_disable_thunkify, + maybe_enable_thunkify, +) +from torch.fx.experimental.symbolic_shapes import ( + guard_or_true, + PropagateUnbackedSymInts, + sym_eq, +) +from torch.nn.utils import stateless + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, + was_inductor_storage_resized, +) +from .logging_utils import setup_stacktrace_preservation_hooks +from .schemas import ( + AOTConfig, + MutationType, + OutputType, + SubclassMeta, + SubclassTracingInfo, + ViewAndMutationMeta, +) +from .subclass_utils import ( + create_subclass_meta, + remap_unwrapped_subclass_arg_indices, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from .utils import maybe_to_fresh_input + + +# This function returns a new function that returns mutated inputs as outputs. 
+# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. +def fn_input_mutations_to_outputs( + fn: Callable, + meta: ViewAndMutationMeta, + keep_data_input_mutations: bool, +) -> Any: + @wraps(fn) + def inner_fn(*args): + outs = fn(*args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. + # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) + mutated_inputs_to_return = [ + x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices + ] + return *mutated_inputs_to_return, *outs + + return inner_fn + + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. +# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. +def fn_prepped_for_autograd( + fn: Callable, + meta: ViewAndMutationMeta, +) -> Any: + @wraps(fn) + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs = fn(*args_maybe_cloned) + assert isinstance(outs, (tuple, list)) + outs = list(outs) + assert len(meta.output_info) == len(outs) + + mutated_inputs_to_return = [ + x + for (i, x) in enumerate(args_maybe_cloned) + if i in meta.mutated_inp_runtime_indices + ] + + intermediate_bases = [] + for i, (o, info) in enumerate(zip(outs, meta.output_info)): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + intermediate_bases.append(o._base) + + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data + and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, Tensor) + and meta.output_info[i].requires_grad + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in 
range(len(intermediate_bases))] + + out_grad_mask = ( + mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + ) + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. + for arg in args_maybe_cloned: + if not isinstance(arg, Tensor): + continue + sync_functional_tensor(arg) + + return fw_outs_to_return, out_grad_mask + + return inner_fn + + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any: + def inner_fn(primals: list[Any], tangents: list[Any]): + outs, tangent_mask = fn(*primals) + + assert len(tangent_mask) == len(outs) + outs_to_grad = [ + o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent + ] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + # The guard_or_true also sketchy; if unbacked + # symints are involved, we're just going to assume that the + # decomps setup the base shape correctly + + # Return out if the result of out.shape==tangent.shape is unknown or known to be true. + # otherwise if its a known false return out.view(tangent.shape). + needed_outs.append( + out + if guard_or_true(sym_eq(out.shape, tangent.shape)) + else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + if config.functionalize_rng_ops: + PhiloxStateTracker.mark_beginning_of_backward() + backward_out: tuple[Tensor, ...] 
= () + # Call the backwards pass + if grad_primals: + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + if functional_tensor_mode is not None: + # Side-Effect Tokens: + # We want to have independent chains of tokens for forward and backward. + # functional_tensor_mode._tokens is used by both. + # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output, + # to return them as joint graph outputs. + # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward. + # Joint graph tracing allows tokens discovery, + # So all the tokens in backward will be created and added as a graph inputs during tracing. + functional_tensor_mode._tokens_forward_output = ( + functional_tensor_mode._tokens + ) + functional_tensor_mode._tokens = {} + + with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta(): + # for full graph export, we always export a joint graph where we assume no tangents are needed. + if aot_config.no_tangents: + assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + allow_unused=True, + ) + else: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + return outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads + ] + + def inner_fn_with_anomaly(*args): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.") + with torch.autograd.detect_anomaly(check_nan=False): + return inner_fn(*args) + + return inner_fn_with_anomaly + + +def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any: + # Functionalization of rng ops changes the calling convention of the joint graph. + # It goes from (primals, tangents) to (seed, offset, primals, tangents) + # At runtime, we pass on the current seed and offset. This is hidden from + # the user. 
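# [Editor's note: illustrative sketch, not part of the upstream file.]
# A stripped-down version of the "joint" idea implemented by create_joint()
# above: one callable that takes (primals, tangents) and returns both the
# forward outputs and the input gradients via torch.autograd.grad().  Tangent
# masks, functionalization and RNG handling are deliberately omitted.
import torch


def make_toy_joint(fn):
    def joint(primals, tangents):
        outs = fn(*primals)
        grad_primals = [p for p in primals if p.requires_grad]
        grad_inputs = torch.autograd.grad(
            outs, grad_primals, grad_outputs=tangents, allow_unused=True
        )
        return outs, grad_inputs

    return joint


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
outs, grads = make_toy_joint(lambda x, y: [x * y + x])([a, b], [torch.ones(3)])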
+ fake_mode = detect_fake_mode() + if fake_mode is None: + fake_mode = nullcontext() + + def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): + out = PhiloxStateTracker.get_state_as_tensor() + return out + + def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"): + PhiloxStateTracker.set_state_from_tensor(x) + + def append_rng_offsets(args): + if trace_joint: + # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset) + return ( + (*args[0], PhiloxStateTracker.get_updated_fwd_offset()), + (*args[1], PhiloxStateTracker.get_updated_bwd_offset()), + ) + else: + # args signature before: Tuple(fwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset) + return (*args, PhiloxStateTracker.get_updated_fwd_offset()) + + def traced_joint( + primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset + ): + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(primals, tangents)) + + def traced_forward(*primals_fwd_seed_fwd_base_offset): + # The signature is (*primals, seed, offset) + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2])) + + if trace_joint: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward") + return traced_joint, ( + *args, + fwd_seed, + fwd_base_offset, + bwd_seed, + bwd_base_offset, + ) + else: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + return traced_forward, (*args, fwd_seed, fwd_base_offset) + + +@contextmanager +def set_partitioner_tag(tag: str): + meta_key = "partitioner_tag" + assert fx_traceback.has_preserved_node_meta() + + original_val = fx_traceback.current_meta.get(meta_key, None) + fx_traceback.current_meta[meta_key] = tag + try: + yield + finally: + fx_traceback.current_meta[meta_key] = original_val + + +def set_partitioner_tag_is_backward(): + return set_partitioner_tag("is_backward") + + +def set_partitioner_tag_must_be_in_backward(): + return set_partitioner_tag("must_be_in_backward") + + +# This creates the final function that we want to trace using make_fx(), +# in both aot_dispatch_autograd and aot_dispatch_base. +# Preconditions: +# - fn corresponds to the user's fw function +# - fn arguments have been flattened, duplicate arguments have been handled +# - In the returned function, the "primals" arguments *includes* synthetic bases. +# This function does the work of functionalizing the input function, +# and performing copy_() calls at the end of the function if `keep_input_mutations` is set. 
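# [Editor's note: illustrative sketch, not part of the upstream file.]
# set_partitioner_tag() above uses a common save/override/restore pattern for
# tracing-time metadata.  The same pattern on a plain dict, with
# `current_meta` standing in for fx_traceback.current_meta:
from contextlib import contextmanager

current_meta = {}


@contextmanager
def set_tag(key, value):
    original = current_meta.get(key, None)
    current_meta[key] = value
    try:
        yield
    finally:
        current_meta[key] = original


with set_tag("partitioner_tag", "is_backward"):
    assert current_meta["partitioner_tag"] == "is_backward"
assert current_meta["partitioner_tag"] is None  # restored on exit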
+# The function returned has signature that is either: +# (1) "traced_fn(primals: List[Any])" if trace_joint is False +# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True +# Returns a new (functionalized) function, and updated arguments to call it with. +def create_functionalized_fn( + fn, + args, + *, + meta: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, +) -> Any: + @wraps(fn) + def _functionalized_f_helper(*args): + with maybe_enable_thunkify(): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # The functionalization code here can potentially trigger traces + # into the graph, but we'd prefer to NOT do this, because if we + # trace them now, we will end up with FX nodes that don't have + # module stack annotations, which makes unflattener unhappy. + # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + + # Run the joint + f_outs = fn(*f_args) + + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip(f_args[0], primals_before, primals_after, meta.input_info) + ): + # Store information about mutations in joint(for backward analysis) + joint_mutates_data = has_data_mutation(f_inpt) + + joint_mutates_metadata = has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) + + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert ( + not joint_mutates_metadata + ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + + # Ban storage resizing on fw inputs during the bw + if not inpt_info.mutation_inductor_storage_resize: + assert not was_inductor_storage_resized( + f_inpt + ), "Found a graph input that had storage resizing in the backward. 
This is not supported" + + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if ( + joint_mutates_data + and not inpt_info.mutates_data + and not inpt_info.mutates_storage_metadata + ): + # Not banning here mutations on inpt_info.requires_grad - + # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) + # Add node meta for copy_ for partitioner that this node should be in backward graph. + with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward(): + before.copy_(after) + meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( + idx + ) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + if has_data_mutation(f_inpt): + can_be_in_graph = _check_if_mutation_can_be_in_graph( + keep_input_mutations=True, + mutates_data=True, + mutates_metadata=False, + mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd( + f_inpt + ), + mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode( + f_inpt + ), + mutates_storage_metadata=False, + mutation_inductor_storage_resize=was_inductor_storage_resized( + f_inpt + ), + requires_grad=f_inpt.requires_grad, + ) + assert ( + can_be_in_graph + ), "a backward input that had data mutated in an autograd-aware way. This is not supported" + # Perform the input mutation + with torch.fx.traceback.preserve_node_meta(): + before.copy_(after) + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. + # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. 
+ # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + for i, (inpt_old, inpt_f) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) + ): + if not isinstance(inpt_f, torch.Tensor): + continue + assert is_fun(inpt_f) + inpt_new = from_fun(inpt_f) + if ( + meta.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if meta.input_info[i].mutates_storage_metadata: + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if meta.input_info[i].mutation_inductor_storage_resize: + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + ) + + assert isinstance(inpt_f, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=False + ) + if old_storage_size != new_storage_size: + assert ( + old_storage_size == 0 or new_storage_size == 0 + ), f"""\ + Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (the case for FSDP)""" + torch.ops.inductor.resize_storage_bytes_( + inpt_old, new_storage_size + ) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + continue + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. 
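# [Editor's note: illustrative sketch, not part of the upstream file.]
# What the set_() branch above relies on, shown in eager mode: after
# inp.set_(other), `inp` shares `other`'s storage, so no separate data copy is
# needed for later writes to become visible through `inp`.
import torch

inp = torch.zeros(3)
other = torch.ones(3)
with torch.no_grad():
    inp.set_(other)
assert inp.data_ptr() == other.data_ptr()
other.mul_(2)
assert torch.equal(inp, torch.full((3,), 2.0))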
+ if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + continue + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + if ( + meta.input_info[i].mutates_data + and meta.input_info[i].mutations_hidden_from_autograd + ): + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif ( + meta.input_info[i].mutates_data + and meta.input_info[ + i + ].mutations_under_no_grad_or_inference_mode + ): + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + with torch.no_grad(): + inpt_old.copy_(inpt_new) + elif meta.input_info[i].mutates_data: + inpt_old.copy_(inpt_new) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. + flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i in range(num_outs): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + fw_args = args[0] if trace_joint else args + flat_outs[i] = fw_args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec) + + return pytree.tree_map(from_fun, f_outs) + + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export + def joint_helper(primals, tangents): + return _functionalized_f_helper(primals, tangents) + + helper = joint_helper if trace_joint else _functionalized_f_helper + if config.functionalize_rng_ops: + # Setup the wrapper for functionalization of rng ops + helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint) + + return helper, args + + +def handle_effect_tokens_fn( + fn, + args, + *, + meta: ViewAndMutationMeta, + trace_joint: bool, +) -> Any: + num_tokens = len(meta.tokens) + + @wraps(fn) + def inner_fn(*args): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert isinstance(args, tuple) and isinstance(args[0], (list, tuple)) + tokens = args[0][:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = (args[0][num_tokens:], *args[1:]) + else: + tokens = args[:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = args[num_tokens:] + + # Populate the current 
FunctionalTensorMode with the tokens per + # operator. See Note [FunctionalTensorMode is Stateful] + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + assert functional_tensor_mode is not None + f_tokens = pytree.tree_map(to_fun, tokens) + for i, k in enumerate(meta.tokens.keys()): + functional_tensor_mode._tokens[k] = f_tokens[i] + + # Run the joint + outs = fn(*args) + + # Return both the tokens and the outputs + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert len(outs) == 2 + assert len(functional_tensor_mode._tokens_forward_output) == num_tokens + fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values() + + bwd_out_tokens = functional_tensor_mode._tokens.values() + + f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens] + f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens] + + meta.num_backward_tokens = len(bwd_out_tokens) + return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens)) + + out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()] + return (*out_tokens, *outs) + + # Additionally pass in tokens as inputs + # See Note [Side-Effectful Tokens in AOTAutograd] + additional_fwd_token_inputs = [torch.tensor([])] * num_tokens + + if trace_joint: + args = ([*additional_fwd_token_inputs, *args[0]], *args[1:]) + else: + args = [*additional_fwd_token_inputs, *args] + return inner_fn, args + + +# Given a function operating on Subclass -> Subclass, returns an function that operates on Tensor -> Tensor +# Also returns: +# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated) +# - the updated ViewAndMutationMeta for this dense -> dense function. +# The other important arguments are: +# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function. +# when is_joint_structure=False, this is just the forward function. +# - fw_only: this is *always* the forward-only function. +# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions. +# In particular, we need this to tell the partitioner how many dense forward outputs there are. +def aot_dispatch_subclass( + flat_fn_maybe_joint, + args: list[Any], + *, + is_joint_structure: bool, + meta: ViewAndMutationMeta, + fw_only: Callable, +) -> SubclassTracingInfo: + # Skip logic if we don't need to trace through any subclasses + req_subclass_dispatch = requires_subclass_dispatch(args, meta) + if not req_subclass_dispatch: + return SubclassTracingInfo( + plain_tensor_trace_fn=flat_fn_maybe_joint, + plain_tensor_args=args, + maybe_subclass_meta=None, + ) + + # TODO: add subclass guards (later PR). + + # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs). + # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint, + # so we set it later, while we're tracing the joint (see inner_fn() below). + # Another option would be to run our run_functionalized_fw_and_collect_metadata() function + # directly on the joint, but this would hurt compile time (adding yet another pass through the joint). 
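# [Editor's note: illustrative sketch, not part of the upstream file.]
# The token plumbing in handle_effect_tokens_fn() above boils down to:
# prepend N dummy empty tensors to the inputs and return N tokens ahead of the
# real outputs, so later passes see an explicit data dependency between
# side-effectful calls.  A toy, forward-only version (the real code routes the
# tokens through the with_effects higher-order op):
import torch


def add_token_calling_convention(fn, num_tokens):
    def inner(*args):
        tokens, real_args = args[:num_tokens], args[num_tokens:]
        outs = fn(*real_args)
        return (*tokens, *outs)  # tokens come back as leading outputs

    extra_inputs = [torch.tensor([]) for _ in range(num_tokens)]
    return inner, extra_inputs


fn, extra = add_token_calling_convention(lambda x: (x + 1,), num_tokens=2)
token0, token1, out = fn(*extra, torch.zeros(3))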
+ subclass_meta = SubclassMeta() + + def inner_fn(fn, args, *, use_trace_joint: bool): + # Step 1: wrap tensor inputs into subclasses if necessary + all_args = wrap_tensor_subclasses_maybe_joint( + args, is_joint_structure=use_trace_joint, meta=meta + ) + + # Step 2: call the inner function, with our (maybe subclass) inputs + wrapped_outs = fn(*all_args) + + if use_trace_joint: + # See Note: [Computing Subclass Metadata about grad_inputs] + # We also stash subclass info on our grad_inputs, if we're tracing the joint. + nonlocal subclass_meta + assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2 + # Don't need fw outs since we already have subclass metadata on them + grad_inputs = wrapped_outs[1] + subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs) + + # Add extra symints as outputs to the forward/backward graphs + # ignore nested ints here + forward_outs = unwrap_tensor_subclasses( + wrapped_outs[0], append_symints=True + ) + # ignore nested ints here + backward_outs = unwrap_tensor_subclasses( + wrapped_outs[1], append_symints=True + ) + return (forward_outs, backward_outs) + + # Step 3: Unwrap any subclass outputs back into dense tensors + unwrapped_outs = unwrap_tensor_subclasses(wrapped_outs, append_symints=True) + return unwrapped_outs + + def joint_fn(primals, tangents): + with maybe_enable_thunkify(): + return inner_fn( + flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True + ) + + def fw_fn(*primals): + with maybe_enable_thunkify(): + return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) + + def metadata_fn(*primals): + return inner_fn(fw_only, primals, use_trace_joint=False) + + if is_joint_structure: + args_unwrapped = ( + # Add extra symints (size/strides) as input to the forward graph + unwrap_tensor_subclasses(args[0], append_symints=True), + # We pass append_symints=False here because the partitioner will + # capture and add any extra argument + unwrap_tensor_subclasses(args[1], append_symints=False), + ) + else: + args_unwrapped = unwrap_tensor_subclasses(args, append_symints=True) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args, meta.static_input_indices + ) + + if is_joint_structure: + primals_unwrapped = args_unwrapped[0] + fn_to_trace = joint_fn + else: + primals_unwrapped = args_unwrapped + fn_to_trace = fw_fn + + # Note: [Partitioner handling for Subclasses, Part 1] + # The way the partitioner works is that: + # (1) we pass is a single graph containing the joint fw/bw, + # where the # of graph outputs corresponds to # fw_outputs + # grad_inputs + # (2) The partitioner accepts an arguments, num_fwd_outputs, + # and assumes that the first "num_fwd_outputs" graph outputs correspond + # to outputs of the forward graph. + # How do tensor subclasses enter the picture? + # the num_fwd_outputs in the final graph is actually non-trivial to compute, + # because it can be influenced by input mutations and intermediate bases. + # So we compute it by inspecting the current ViewAndMutationMeta object. + # However, the original ViewAndMutationMeta that we computed was created + # on the subclass -> subclass graph, + # which can have a different number of outputs than the dense -> dense graph. + # That's why we created a fresh metadata object on the dense -> dense function here, + # and plumb it back up to the partitioner. + # See Note: [Partitioner handling for Subclasses, Part 2] for more info. 
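# [Editor's note: illustrative sketch, not part of the upstream file.]
# The partitioner bookkeeping described in the note above is simple counting:
# the first num_fwd_outputs joint-graph outputs are forward outputs, and that
# count includes mutated inputs and intermediate bases, not just user outputs.
# `toy_num_fwd_outputs` below is a hypothetical helper making that explicit.
def toy_num_fwd_outputs(num_mutated_runtime_inputs, num_user_outputs,
                        num_intermediate_bases):
    # Mirrors fw_outs_to_return = (*mutated_inputs, *outs, *intermediate_bases)
    # from fn_prepped_for_autograd() earlier in this file.
    return num_mutated_runtime_inputs + num_user_outputs + num_intermediate_bases


assert toy_num_fwd_outputs(1, 2, 1) == 4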
+ meta_updated = run_functionalized_fw_and_collect_metadata( + metadata_fn, + static_input_indices=remapped_static_indices, + keep_input_mutations=meta.keep_input_mutations, + is_train=meta.is_train, + )(*primals_unwrapped) + + subclass_meta.fw_metadata = meta_updated + + return SubclassTracingInfo( + plain_tensor_trace_fn=fn_to_trace, + plain_tensor_args=args_unwrapped, + maybe_subclass_meta=subclass_meta, + ) + + +def create_functional_call(mod, params_spec, params_len, store_orig_mod=False): + # Redundant with dynamo, but worth having in case this gets invoked elsewhere. + # https://github.com/pytorch/pytorch/issues/103569 + + def functional_call(*args, **kwargs): + with stateless._reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ), maybe_disable_thunkify(): + if isinstance(mod, torch.fx.GraphModule): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Anomaly Detection has been enabled." + ) + with torch.autograd.detect_anomaly(check_nan=False): + detect_fake_mode().epoch += 1 + out = PropagateUnbackedSymInts(mod).run( + *args[params_len:], **kwargs + ) + else: + out = mod(*args[params_len:], **kwargs) + + if not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a (). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + # Note [Preserving the nn module stack metadata during export non-strict mode] + # This path is currently only used by the non-strict export flow, + # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph. + # Instead, we stash the original user nn module here, and rely on `make_fx` to grab + # this stashed module and use it to track nn module stack metadata + if store_orig_mod and not hasattr(functional_call, "_orig_mod"): + functional_call._orig_mod = mod # type: ignore[attr-defined] + + return functional_call diff --git a/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95792fa24b70e60074ac22186496e2d743e7e0cf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/_aot_autograd/utils.py @@ -0,0 +1,515 @@ +# mypy: allow-untyped-defs +""" +Contains various utils for AOTAutograd, including those for handling collections. 
+""" + +import dataclasses +import operator +import warnings +from contextlib import nullcontext +from functools import wraps +from typing import Any, Callable, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import getArtifactLogger +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import py_sym_types + + +KNOWN_TYPES = [ + torch.Tensor, + BackwardState, + int, + str, + float, + bool, + type(None), + *py_sym_types, + FakeScriptObject, + torch.ScriptObject, +] + +original_zip = zip + +aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects") + + +def strict_zip(*iterables, strict=True, **kwargs): + if not strict: + return original_zip(*iterables, **kwargs) + + length = len(iterables[0]) + for iterable in iterables[1:]: + if len(iterable) != length: + raise ValueError( + "The iterables have different lengths and strict mode is enabled." + ) + + return original_zip(*iterables, **kwargs) + + +def _get_symint_hints(exprs): + """ + Get the hints of a list/tuple of int/SymInt. + """ + if isinstance(exprs, (list, tuple)): + return type(exprs)(_get_symint_hints(e) for e in exprs) + elif isinstance(exprs, torch.SymInt): + return exprs.node.shape_env.size_hint(exprs.node.expr) + else: + return exprs + + +def partial_flatten_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj): + return { + field.name: getattr(obj, field.name) for field in dataclasses.fields(obj) + } + elif isinstance(obj, (list, tuple)): + return obj.__class__([partial_flatten_asdict(item) for item in obj]) + elif isinstance(obj, dict): + return {k: partial_flatten_asdict(v) for k, v in obj.items()} + else: + return obj + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +def _get_autocast_states(): + return [ + torch.is_autocast_enabled("cuda"), + torch.is_autocast_enabled("cpu"), + torch.get_autocast_dtype("cuda"), + torch.get_autocast_dtype("cpu"), + torch.is_autocast_cache_enabled(), + ] + + +def make_boxed_func(f): + def g(args): + return f(*args) + + g._boxed_call = True # type: ignore[attr-defined] + return g + + +def make_boxed_compiler(compiler): + @wraps(compiler) + def f(fx_g, inps): + out_f = compiler(fx_g, inps) + fx_g = make_boxed_func(out_f) + return fx_g + + return f + + +def call_func_at_runtime_with_args( + f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False +): + if not steal_args: + args = list(args) + assert isinstance(args, list) + + context = torch._C._DisableAutocast if disable_amp else nullcontext + with context(): + if getattr(f, "_boxed_call", False): + out = normalize_as_list(f(args)) + else: + # TODO: Please remove soon + # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 + warnings.warn( + "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. " + "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " + "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale." + ) + out = normalize_as_list(f(*args)) + return out + + +# Inspired by autodidax (thanks!) 
+class PytreeThunk: + spec: Optional[pytree.TreeSpec] = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple: Optional[ + bool + ] = None # if the output spec is a tuple/list, we won't bother unflattening it. + is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec + + def set(self, spec: pytree.TreeSpec) -> None: + assert self.spec is None or self.spec == spec + assert spec is not None + self.spec: pytree.TreeSpec = spec + if self.spec.type in {tuple, list} and all( + child.is_leaf() for child in spec.children_specs + ): + self.is_simple = True + if self.spec.is_leaf(): + self.is_really_simple = True + + def unflatten(self, x: list[Any]) -> Any: + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + assert self.spec is not None + return pytree.tree_unflatten(x, self.spec) + + +# Creates a function that returns flattened inputs and outputs +# Also returns the output tree spec, which is needed to recover the "unflattened" +# output tree structure later. +def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]: + if kwargs is None: + kwargs = {} + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. + nonlocal out_spec + args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + # Can't use functools.wraps here because the wrapper has different + # calling convention + if hasattr(fn, "_orig_mod"): + flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined] + + return flat_fn, out_spec + + +# This function takes in a tensor t, and returns one of t, t.view(), or t.clone(). +# When tracing the joint forward + backward, for any inputs in the graph that are mutated, +# we need to clone them first (and similarly for metadata-only mutations, we need to view them first). +# The idea is that when we trace the backward, we need to pass in the *original* primals +# to autograd.grad(), before they were mutated. +# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them. +# This means that "idx" here represents the index of the (potentially) synthetic base. +# What we need to do is: +# (1) map the current (post-synthetic-base calling convention) input argument index +# to int index pre-synthetic-base-calling-convention. +# (2) There could be multiple, if this index corresponds to a synthetic base +# that has multiple input aliases. +# (3) If any of those corresponding inputs get metadata mutations, then we clone the base. 
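# [Editor's note: illustrative sketch, not part of the upstream file.]
# create_tree_flattened_fn() above is built on the pytree round trip shown
# here: flatten structured (args, kwargs) into a flat leaf list plus a spec,
# then use the spec to rebuild the original structure later.
import torch
import torch.utils._pytree as pytree

args, kwargs = (torch.ones(2), 2.0), {"bias": torch.zeros(2)}
flat, spec = pytree.tree_flatten((args, kwargs))
rebuilt_args, rebuilt_kwargs = pytree.tree_unflatten(flat, spec)
assert rebuilt_kwargs.keys() == kwargs.keys()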
+def maybe_to_fresh_input(idx, t, meta): + if not isinstance(t, torch.Tensor): + return t + if idx in meta.mutated_inp_runtime_indices: + # We only need to bother cloning mutated inputs that participate in autograd. + if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the mutation + return t.clone() + if meta.input_info[idx] and meta.input_info[idx].mutates_metadata: + # Make sure the primal we pass to autograd.grad() + # sees the tensor before the metadata mutation + return t.view(t.shape) + return t + + +def is_with_effects(node): + return ( + node.op == "call_function" + and node.target == torch.ops.higher_order.with_effects + ) + + +def is_with_effects_op(node, op): + return is_with_effects(node) and node.args[1] == op + + +def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None): + # Remove the tokens from the inputs/outputs of the graph since inductor does + # not want these extra inputs/outputs, and replace them with + # _make_token() to create a token, and _sink_tokens() to collect the + # tokens. See Note [Side-Effectful Tokens in AOTAutograd] + # Logic: + # 1. Inputs identified as input tokens: + # - If used as a first argument in with_effects + # + # 2. Outputs identified as output tokens: + # - If Produced by getitem(with_effects, 0) + # + # 3. Checks invariants of number input output tokens: + # forward: + # expected_num_erased_inputs == len(fw_metadata.tokens) + # expected_num_erased_outputs == len(fw_metadata.tokens) + # backward: + # expected_num_erased_inputs == fw_metadata.num_backward_tokens + # expected_num_erased_outputs == fw_metadata.num_backward_tokens + num_forward_tokens = len(fw_metadata.tokens) + num_backward_tokens = fw_metadata.num_backward_tokens + + def rewrite_with_effects_input_token(module, node): + with module.graph.inserting_before(node): + new_token_node = module.graph.call_function( + torch.ops.prims._make_token.default, () + ) + new_token_node.meta["val"] = torch.tensor([]) + new_token_node.meta["tensor_meta"] = torch.tensor([]) + + args = list(node.args) + args[0] = new_token_node + node.args = tuple(args) + + def rewrite_output(module, node, output_token_nodes, other_output_args): + for output_token_node in output_token_nodes: + assert ( + output_token_node.op == "call_function" + and output_token_node.target == operator.getitem + and output_token_node.args[1] == 0 + ) + with module.graph.inserting_before(node): + module.graph.call_function( + torch.ops.prims._sink_tokens.default, + (output_token_nodes,), + ) + node.args = (other_output_args,) + + def do(module, subgraph, expected_num_erased): + num_erased_inputs = 0 + num_erased_outs = 0 + input_nodes = [] + input_token_nodes = set() + with_effect_nodes = [] + output_token_nodes = [] + other_output_nodes = [] + for node in module.graph.nodes: + if node.op == "placeholder": + input_nodes.append(node) + elif is_with_effects(node): + with_effect_nodes.append(node) + if node.args[0] in input_nodes: + input_token_nodes.add(node.args[0]) + rewrite_with_effects_input_token(module, node) + elif node.op == "output": + outs = node.args[0] + for out in outs: + if ( + isinstance(out, torch.fx.node.Node) + and out.op == "call_function" + and out.target == operator.getitem + and out.args[1] == 0 + and out.args[0] in with_effect_nodes + ): + output_token_nodes.append(out) + else: + other_output_nodes.append(out) + + rewrite_output(module, node, output_token_nodes, other_output_nodes) + 
num_erased_outs = len(output_token_nodes) + + for input_token_node in input_token_nodes: + module.graph.erase_node(input_token_node) + + num_erased_inputs = len(input_token_nodes) + + assert ( + num_erased_inputs == expected_num_erased + ), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" + assert ( + num_erased_outs == expected_num_erased + ), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" + + module.recompile() + + if num_forward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Forward graph before unlifting tokens", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + do( + fw_module, + "forward", + num_forward_tokens, + ) + + if bw_module is not None and num_backward_tokens > 0: + if aot_config.enable_log: + from torch._dynamo.utils import lazy_format_graph_code + + aot_graphs_effects_log.debug( + "%s", + lazy_format_graph_code( + "Backward graph before unlifting tokens", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + do(bw_module, "backward", num_backward_tokens) + + # This is sad, but we need to update the metadata to get rid of + # the tokens. + fw_metadata.tokens = {} + fw_metadata.num_backward_tokens = 0 + + +def root_module_when_exporting_non_strict(flat_fn): + # When exporting in non-strict mode, we wrap the root module in a specific pattern. + # See `_aot_export_non_strict` in torch.export._trace.py. + # We look for that wrapping pattern here. + if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"): + return flat_fn._orig_mod._export_root + else: + return None + + +def copy_fwd_metadata_to_bw_nodes(fx_g): + """ + Input: `fx_g` which contains the joint fwd+bwd FX graph created by + aot_autograd. + + This function walks the graph and copies over metadata from forward nodes + to backward nodes, using the `seq_nr` field as a one-to-many mapping + from forward node to backward node. This metadata is useful for performance + profiling and debugging. + """ + + def _is_forward_node_with_seq_nr(node): + # For now, assume that if nn_module_stack_metadata is populated, this + # node is from the forward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this by walking + # the descendants of graph inputs corresponding to fwd inputs, didn't + # seem obvious at first glance on how to partition graph inputs into + # fwd vs bwd without relying on string names. + return "nn_module_stack" in node.meta and "seq_nr" in node.meta + + def _is_backward_node_with_seq_nr(node): + # For now, assume that if nn_module_stack_metadata is not populated, + # this node is from the backward. Ignore nodes without `seq_nr`. + # TODO(future): there is likely a less brittle way to do this, same + # as with the forward. + return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta + + fwd_seq_nr_to_node = {} + for node in fx_g.graph.nodes: + if not _is_forward_node_with_seq_nr(node): + continue + seq_nr = node.meta["seq_nr"] + if seq_nr in fwd_seq_nr_to_node: + # If we already saw an op with the current `seq_nr`, that means + # that the current op did not create an autograd node, and there + # is no corresponding backward node, so we skip. 
+ continue + fwd_seq_nr_to_node[node.meta["seq_nr"]] = node + + for node in fx_g.graph.nodes: + if not _is_backward_node_with_seq_nr(node): + continue + # fwd_node should always exist, but handle non-existence just in case + fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"]) + if fwd_node is not None: + node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"] + node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") + + +def register_buffer_assignment_hook(mod, assigned_buffers): + """ + Register a hook that intercepts buffer assignments. + This is used to detect when a buffer is assigned to, and then we can + map that buffer to the corresponding proxy node in the graph. + """ + + def _map_assigned_buffer_to_proxy(_mod, name, buffer): + # We intercept buffer assignments on the root module through this hook. + if _mod._buffers is mod._buffers: + # either buffer is a functional tensor, which wraps a fake tensor + if isinstance(buffer, FunctionalTensor): + buffer = buffer.from_functional() + # or buffer is a fake tensor + assert isinstance(buffer, FakeTensor) + # The fake tensor in turn is associated with a proxy node. + proxy_mode = torch.fx.experimental.proxy_tensor.get_proxy_mode() + assert proxy_mode is not None + proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot( + buffer, proxy_mode.tracer + ).proxy.node + # We map the assigned buffer to this proxy node. + assigned_buffers[name] = proxy.name + return buffer + + return torch.nn.modules.module.register_module_buffer_registration_hook( + _map_assigned_buffer_to_proxy + ) + + +def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool: + """ + Checks if the module contains any metadata mutation ops. + """ + for node in module.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.inplace_view in node.target.tags + ): + return True + return False + + +def get_cuda_generator_meta_val(device_idx: int): + """ + Get a generator value to use as a meta val + + newly cloned generator will not contain tensors. it is only Generators that are + registered to a CUDAGraph that contain tensors. since this does not contain Tensor + it is fine to use in the meta. 
+ """ + return torch.cuda.default_generators[device_idx].clone_state() + + +def top_saved_tensors_hooks(): + return torch._C._autograd._top_saved_tensors_default_hooks(True) + + +def saved_tensors_hooks_are_inlineable(hooks) -> bool: + if not hooks: + return False + pack, unpack = hooks + return isinstance(pack, torch.fx.GraphModule) and isinstance( + unpack, torch.fx.GraphModule + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/aot_autograd.py b/phivenv/Lib/site-packages/torch/_functorch/aot_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..a042101e55277b4a805065444ca22156a0ddaa0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/aot_autograd.py @@ -0,0 +1,1737 @@ +# mypy: ignore-errors + +import itertools +from collections.abc import KeysView, Sequence +from contextlib import contextmanager, nullcontext +from functools import partial, wraps +from typing import Any, Callable, NewType, Optional, Protocol, TypeVar +from unittest.mock import patch + +import torch +import torch._dynamo.logging +import torch.nn as nn +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo import compiled_autograd +from torch._dynamo.utils import ( + CompileEventLogger, + dynamo_timed, + preserve_rng_state, + set_feature_use, +) +from torch._guards import detect_fake_mode +from torch._inductor.cudagraph_utils import BoxedDeviceIndex +from torch._inductor.output_code import OutputCode +from torch._inductor.utils import BoxedBool, InputType +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _pytree_subclasses_that_lose_info, + make_fx, +) +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) +from . 
import config +from ._aot_autograd.autograd_cache import ( # noqa: F401 + AOTAutogradCache, + autograd_cache_key, + should_use_local_autograd_cache, + should_use_remote_autograd_cache, +) +from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 + run_functionalized_fw_and_collect_metadata, +) +from ._aot_autograd.functional_utils import ( # noqa: F401 + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + assert_functional_graph, + from_fun, + gen_alias_from_base, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, +) +from ._aot_autograd.input_output_analysis import ( # noqa: F401 + compute_overlapping_inputs, + create_graph_signature, + create_synthetic_base_metadata, + remove_dupe_metadata, +) +from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 + aot_dispatch_autograd, + aot_dispatch_base, + aot_dispatch_export, +) +from ._aot_autograd.logging_utils import ( # noqa: F401 + callback_set, + describe_input, + format_guard_bug_msg, + get_aot_compilation_context, + get_aot_graph_name, + get_graph_being_compiled, + graph_being_compiled, + model_name, + nth_graph, + set_model_name, + setup_stacktrace_preservation_hooks, + track_graph_compiling, +) +from ._aot_autograd.runtime_wrappers import ( # noqa: F401 + AOTDedupeWrapper, + AOTSyntheticBaseWrapper, +) +from ._aot_autograd.schemas import ( # noqa: F401 + AOTConfig, + BackwardSignature, + FQN, + GraphInputName, + GraphOutputName, + GraphSignature, + InputAliasInfo, + MutationType, + OutputAliasInfo, + OutputType, + SubclassCreationMeta, + SubclassMeta, + TensorAlias, + ViewAndMutationMeta, +) +from ._aot_autograd.subclass_utils import ( # noqa: F401 + requires_subclass_dispatch, + unwrap_tensor_subclasses, + unwrap_tensor_subclasses_with_indices_to_original, + wrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from ._aot_autograd.traced_function_transforms import ( # noqa: F401 + aot_dispatch_subclass, + create_functional_call, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) +from ._aot_autograd.utils import ( # noqa: F401 + _get_autocast_states, + _get_symint_hints, + call_func_at_runtime_with_args, + create_tree_flattened_fn, + KNOWN_TYPES, + make_boxed_compiler, + make_boxed_func, + maybe_to_fresh_input, + normalize_as_list, + partial_flatten_asdict, + root_module_when_exporting_non_strict, + strict_zip, +) +from .partitioners import default_partition + + +zip = strict_zip + +# This global counter increments every time we compile a graph with +# AOTAutograd. You can use this to correlate runtime error messages +# with compile time (e.g., if you get an error at runtime saying +# compiled graph 3 failed, you can set a breakpoint at compile time +# for this graph number to investigate further at compile time.) +# +# NB: this is different from get_aot_compilation_context, which tracks +# each underlying graph that is compiled. 
In contrast, AOT_COUNTER +# corresponds to top-level invocations of aot_module/aot_function; +# one counter is allocated per entire compiled block (but this block +# may involve compiling multiple subgraphs; e.g., for forwards/backwards) +AOT_COUNTER = itertools.count() + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation +# that are external to the graph (they show up as side effects in some way when you run the graph). +# +# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some examples functions +# and what they're compiled graphs looks like. +# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them. +# +# Note [AOT Autograd: input data mutations] +# +# If we compile a function that mutates inputs, then those input mutations are real side effects +# that a user expects to see after running the compiled graph. +# However, the graph that we want to send to a backend needs to be *entirely* functional. +# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile +# but we update the graph to return (updated_inputs, user_outputs). +# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals. +# +# Example: original user code: +# def f(x): +# x.mul_(2) +# out = x.mul(3) +# return out +# +# After AOT Autograd compiles, we end up with a: +# (a) compiled graph +# (b) autograd.Function.forward() method, that executes the compiled graph +# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue +# +# The output of (a, b, c) are all written below. +# +# def compiled_forward_graph(x): +# x_updated = x.mul(2) +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated gets a gradient in the compiled backward +# def compiled_backward_graph(grad_x_updated, grad_out): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.copy_(x_updated) +# return out +# +# Another important thing to note is that updated inputs (due to data mutations) *do* participate +# in the compiled backward graph! Since the compiled forward graph gets N extra outputs +# (due to updated inputs showing up as graph outputs), +# The compiled backward gets an additional N inputs. +# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input +# back to the original input. + + +# Note [AOT Autograd: input metadata mutations] +# +# For the same reason as input mutations, we also don't put input metadata mutations in the graph. 
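# [Editor's note: illustrative sketch, not part of the upstream file.]
# A runnable eager-mode version of the pseudocode in the data-mutation note
# above: the "compiled" part stays purely functional and returns the updated
# input, and a small wrapper copies it back into the caller's tensor.
import torch


def functional_forward(x):  # stands in for compiled_forward_graph
    x_updated = x.mul(2)
    out = x_updated.mul(3)
    return x_updated, out


def wrapper(x):  # stands in for compiled_wrapper
    x_updated, out = functional_forward(x)
    x.copy_(x_updated)  # replay the input mutation outside the graph
    return out


x = torch.ones(3)
out = wrapper(x)
assert torch.equal(x, torch.full((3,), 2.0))
assert torch.equal(out, torch.full((3,), 6.0))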
+# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph +# +# Example: original user code: +# def f(x): +# x.t_() +# out = x.mul(3) +# return out +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# x_updated = x.t() +# out = x_updated.mul(3) +# return x_updated, out +# +# # x_updated does *not* get a gradient in the compiled backward +# def compiled_backward_graph(grad_out): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# x_updated, out = compiled_forward_graph(x) +# return x_updated, out +# +# def compiled_wrapper(x): +# x_updated, out = autograd.Function.apply(x) +# x.as_strided_(x_updated) +# return out + + +# Note [AOT Autograd: outputs aliasing inputs or intermediates!] +# +# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates! +# Why? +# (1) autograd.Function.forward() has a limitation, where views that returned in the forward cannot later be mutated. +# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph, +# in an epilogue. +# For outputs that alias inputs, we do the following: +# (a) *still* return the aliased output as a graph output +# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output. +# +# For outputs that alias *intermediates*, we do the following: +# (a) Return the output in the compiled forward, **and** return it's ._base (a graph intermediates) as an output in the forward +# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output). +# You might wonder why we return the aliased output directly in the graph (and making the graph compute it), +# only to not return it and instead generate a fresh alias off of the intermediate, +# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons: +# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call +# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance. +# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides, +# when it has a different set of strides. +# By including the view op directly in the graph, inductor takes that into account when deciding what memory format +# the graph intermediate should be. +# +# Another important thing to note is how our traced backward() graph handles aliases. +# (this applies to outputs aliasing inputs, outputs aliasing intermediates, +# *and* updated inputs returned in the compiled forward due to metadata-only mutations). +# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph +# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly +# at the end of the forward. 
+# +# Example: original user code: +# def f(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# return out1, out2 +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(x): +# out1 = x.t() +# intermediate = x.mul(2) +# out2 = intermediate.view(-1) +# # the compiled graph also returns the intermediate +# return out1, out2, intermediate +# +# # intermediate gets a gradient in the compiled backward. +# # both output aliases (out1 and out2) do not. +# def compiled_backward_graph(grad_intermediate): +# grad_x = ... +# return grad_x +# +# def autograd.Function.forward(x): +# out1, out2, intermediate = compiled_forward_graph(x) +# return out1, out2, intermediate +# +# def compiled_wrapper(x): +# out1, out2, intermediate = autograd.Function.apply(x) +# # regenerate out1 from the input +# out1_regenerated = out1._view_func(x) +# # regenerate out1 from the intermediate +# out2_regenerated = out2._view_func(intermediate) +# return out1_regenerated, out2_regenerated + + +# Note [AOT Autograd: mutations to inputs that alias other inputs] +# +# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input. +# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other. +# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias +# given the mutation that occurred. +# +# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input +# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base +# inside of the compiled function. +# +# This logic is fully encapsulated in aot_wrapper_synthetic_base() +# +# Example: original user code: +# def f(x, x_view): +# x.mul_(2) +# out = x * x_view +# return out +# f(x, x.view(-1)) +# +# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function): +# def compiled_forward_graph(base) +# x = generate_x(base) +# x_view = generate_x_view(base) +# x_updated = x.mul(2) +# x_view_updated = x_updated.view(-1) +# out = x_updated * x_view_updated +# return x_updated, out +# +# # The calling convention change from (aliases) -> (base) happens +# # *outside* of the autograd.Function.forward(). +# # That means the forward() only has 1 input (base), +# # and the backward() only has 1 output (grad_base) +# def compiled_backward_graph(grad_out): +# grad_base = ... +# return grad_base +# +# def autograd.Function.forward(base): +# x_updated, out = compiled_forward_graph(base) +# return x_updated, out +# +# # The compiled wrapper is where we create synthetic bases. +# # The info on which inputs are mutated is also tracked *before* synthetic base creation. +# def compiled_wrapper(x, x_view): +# base = merge_view_inputs(x, x_view) +# x_updated, out = autograd.Function.apply(base) +# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view. +# x.copy_(x_updated) +# return out + + +# Note [AOT Autograd: Views to avoid tangents aliasing inputs] +# +# We view every forward output when creating out tangent tensors to handle the problematic +# case in which a subclass does extra aliasing between graph outputs/inputs in a way that +# is not visible above the sublass. 
+# +# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd, +# we're guaranteed that the tangent tensors that we pass +# into the joint are distinct tensors from the primals. This is because when +# decide which forward outputs to create tangents for, we only create tangents +# for forward outputs that are not aliases of inputs (See Note +# [AOT Autograd: outputs aliasing inputs or intermediates!]). +# +# However, when wrapper tensor subclasses enter the picture, it is possible +# to have an output of the forward that is a subclass that is not an +# input / alias of an input, but one of its inner tensors is an alias! +# NestedTensor is an example: Performing an out-of-place pointwise op on a +# NestedTensor constructs a fresh NestedTensor that holds onto the input's +# offsets tensor directly. +# +# Having tangent tensors that are the same as the (primal) forward inputs, +# can cause problems during tracing as make_fx() will specialize on our +# duplicate inputs: If we passed in the same tensor for primals_1 and +# tangents_1 during tracing, make_fx() will happily sub out all usages of +# tangents_1 with primals_1 in the graph, which is not what we want. +# +# To work around this, we view every forward output when creating out tangent +# tensors so that tangents can never be the same as forward inputs even if +# forward inputs alias forward outputs. + +# Note [Side-Effectful Tokens in AOTAutograd] +# +# We allow some some side-effectful operators in +# the post-AOTAutograd (functional) graph, such as prints and torchbind operations. +# To ensure that these side-effects are compatible to future graph passes that +# assume that the graph is functional, we will thread "effect tokens" to show +# data dependence between these side-effectful operators. Practically speaking, +# effect tokens are just dummy values (torch.tensor([])). The graph would look +# like the following: +# +# def gm(self, token0, reader): +# token1, frame = with_token(ordered_effect_op, (reader,), token0) +# frame = frame * 2 +# token2, frame2 = with_token(ordered_effect_op, (reader,), token1) +# frame2 = frame2 * 2 +# return token2, frame, frame2 +# +# We will pass the token as an input to the graph, thread it through +# side-effectful operators using the `with_effects` high order operator, and then +# return the updated token as an output. +# So the signature of the graph input would look something like +# (*tokens, *params_buffers, *user_inputs), and the signature of the graph +# output would look something like (*tokens, *outputs). +# +# However, Inductor does not want the concept of tokens in the final generated +# code's input and output. 
Since changing the graph signature inside of inductor
+# is difficult, after generating the forward graph, we will run a pass to
+# remove the tokens from the inputs and outputs, and generate the following graph for Inductor, where
+# the tokens are created and sunk within the graph, rather than as inputs and
+# outputs:
+#
+# def gm(self, reader):
+#     token0 = torch.ops.prims._make_token()
+#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
+#     frame = frame * 2
+#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
+#     frame2 = frame2 * 2
+#     sink_token = torch.ops.prims._sink_tokens([token2])
+#     return frame, frame2
+
+#
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+aot_autograd_decompositions = {}
+
+FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any])
+
+
+TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
+
+
+class AOTDispatchCompiler(Protocol):
+    """
+    Represents a fw or bw_compiler passed to AOTAutograd.
+    """
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> Any:
+        ...
+
+
+# TODO: bikeshed on this name
+class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
+    """
+    Represents an AOTDispatchCompiler that returns an OutputCode, and is
+    therefore cacheable. SerializableAOTDispatchCompiler always returns an OutputCode.
+    A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of
+    the kwargs in _CompileFxKwargs.
+    """
+
+    def __init__(
+        self,
+        output_code_ty: type[TOutputCode],
+        compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
+    ):
+        self.output_code_ty = output_code_ty
+        self.compiler_fn = compiler_fn
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> OutputCode:
+        return self.compiler_fn(gm, example_inputs)
+
+
+def process_inputs(
+    flat_args: list[Any],
+    aot_config: AOTConfig,
+    fake_mode: FakeTensorMode,
+    shape_env: Optional[ShapeEnv],
+    ignore_shape_env: bool = False,
+) -> FakifiedFlatArgs:
+    with fake_mode:
+
+        def convert(idx, x):
+            if shape_env is not None and not ignore_shape_env:
+                from torch._dynamo.source import ConstantSource
+
+                if isinstance(x, int):
+                    # We always specialize on scalar values in export.
+ if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), hint=x, source=source + ) + if isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> tuple[Callable, ViewAndMutationMeta]: + with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True): + return _create_aot_dispatcher_function( + flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + + +def _create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> tuple[Callable, ViewAndMutationMeta]: + """ + Traces the forward and backward graphs of the attr:`flat_fn` to generate a + joint graph. The joint graph is an Fx graph with Aten ops. Please refer to + the tracing mechanism to understand the graph capturing details. + + The joint graph is then passed through attr:`partition_fn` to isolate the + forward and backward portions, which are then respectively compiled via the + provided attr:`fw_compiler` and attr:`bw_compiler`. + + The resulting compiled forward and backward graphs are then wrapped up in a + ``torch.autograd.Function`` object. + + The calling convention here is that the first aot_config.num_params_buffers + inputs in flat_args are parameters and buffers, and the rest are inputs. + + We use this to assume that parameters/buffer's shapes don't change. 
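+
+    A purely illustrative sketch of that convention (``weight``, ``bias`` and ``x``
+    are made-up names, not variables in this function):
+
+        # aot_config.num_params_buffers == 2
+        # flat_args = [weight, bias, x]   # params/buffers first, then user inputs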
+ + Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) + When aot_config.is_export is True, we return an FX graph + metadata + When aot_config.is_export is False, we return an ordinary runtime function + """ + + # This is the main entry point. + # TODO: Chillee argues that dynamo itself should pass in fake tensors to + # the list of arguments when compiling; at the moment we do not do this + + if aot_config.decompositions is None: + aot_config.decompositions = {} + + aot_config.decompositions = { + **aot_autograd_decompositions, + **aot_config.decompositions, + } + + if config.functionalize_rng_ops: + # Update the decompositions with functionalized random decompositions + aot_config.decompositions = { + **rng_decompositions, + **aot_config.decompositions, + } + + # Check flat_args to see if they're already fake. If so, use that fake + # mode instead. + + python_dispatcher_mode = ( + enable_python_dispatcher() if shape_env is not None else nullcontext() + ) + + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + # If any saved tensor hooks are active, we **don't** want to trace them. + # Instead, we'll let them run at runtime, around the custom autograd.Function + # that we generate in torch.compile. + with torch.autograd.set_multithreading_enabled( + False + ), preserve_rng_state(), ( + fake_mode + ), ( + python_dispatcher_mode + ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + from torch._library.fake_class_registry import ( + FakeScriptObject, + maybe_to_fake_obj, + ) + + # Tracing may mutate the states the fake script object, + # so we need to duplicate the fake script objects so that subsequent tracing + # won't be affected. + def _dup_fake_script_obj(fake_flat_args): + return [ + maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) + if isinstance(arg, FakeScriptObject) + else arg + for arg in fake_flat_args + ] + + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) + ) + + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. + with patch("torch.cuda.set_rng_state", lambda *args: None): + mod = root_module_when_exporting_non_strict(flat_fn) + if mod is not None: + ctx = _detect_attribute_assignment(mod) + else: + ctx = nullcontext() + + if torch._functorch.config.fake_tensor_propagate_real_tensors: + # Running dynamo_timed causes fake tensor issues when + # propagate real tensor is switched on. + dynamo_timed_ctx = nullcontext() + else: + dynamo_timed_ctx = dynamo_timed( + "aot_collect_metadata", log_pt2_compile_event=True + ) + + with dynamo_timed_ctx, ctx: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + is_export=aot_config.is_export, + )(*_dup_fake_script_obj(fake_flat_args)) + + req_subclass_dispatch = requires_subclass_dispatch( + fake_flat_args, fw_metadata + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + output_and_mutation_safe = not any( + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. 
+ # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type + in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) + + if needs_autograd and output_and_mutation_safe: + # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=False, + pre_dispatch=aot_config.pre_dispatch, + static_input_indices=aot_config.static_input_indices, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=False, + tokens=fw_metadata.tokens, + static_input_indices=fw_metadata.static_input_indices, + ) + + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ +torch.compile is currently being used with tensor subclass inputs: +{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs +that alias one another, which is currently unsupported in the subclass use case. If you run into this, +please file a github issue""" + + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. + if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError( + f"""\ +Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""" + ) + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. 
+ if ( + len( + [ + x + for x in fw_metadata.input_info + if x.requires_grad and x.mutates_data + ] + ) + != 0 + ): + raise RuntimeError( + f"""\ +Found a graph input that requires gradients, and received a mutation. +This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. + +fw_metadata={str(fw_metadata)}""" + ) + if req_subclass_dispatch: + raise RuntimeError( + """\ +aot_export is not currently supported with traceable tensor subclass. +If you need this feature, please comment on """ + ) + + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError( + """\ +Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, +or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" + ) + + def choose_dispatcher(needs_autograd, aot_config): + """ + Pick a dispatcher based on the config rules. + """ + if aot_config.is_export: + # export uses just the "graph bits", whereas the other + # two dispatchers include some extra work around handling a runtime epilogue + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="export" + ) + return partial(aot_dispatch_export, needs_autograd=needs_autograd) + elif needs_autograd and not aot_config.pre_dispatch: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + return aot_dispatch_autograd + else: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + return aot_dispatch_base + + compiler_fn = choose_dispatcher(needs_autograd, aot_config) + + compiled_fn, fw_metadata = compiler_fn( + flat_fn, + _dup_fake_script_obj(fake_flat_args), + aot_config, + fw_metadata=fw_metadata, + ) + return compiled_fn, fw_metadata + + +def aot_function( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[dict] = None, + num_params_buffers: int = 0, + keep_inference_input_mutations: bool = False, + inference_compiler: Optional[Callable] = None, + *, + # Whether or not to trace with dynamic shapes + dynamic=False, + enable_log=True, +) -> Callable: + """ + Traces the forward and backward graph of :attr:`fn` using torch dispatch + mechanism, and then compiles the generated forward and backward graphs + through :attr:`fw_compiler` and :attr:`bw_compiler`. + + :func:`aot_function` traces the forward and backward graph ahead of time, + and generates a joint forward and backward graph. :attr:`partition_fn` is + then used to separate out forward and backward graphs. The partitioner + function can be used to perform optimizations such as recomputation. One can + set `decompositions` dictionary to decompose the operators into a sequence + of core or simpler operators supported by the backend compilers. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Callable): A Python function that takes one ore more arguments. Must + return one or more Tensors. + fw_compiler (Callable): A Python function that accepts an Fx graph with + Aten ops and input args, and returns a Callable that semantically is + equivalent to the input Fx graph. 
+ bw_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + partition_fn (Callable): A Python function that takes a joint forward + and backward graph, and partitions it into separate forward and + backward graphs. + decompositions (Dict): A dictionary to define the decomposition of + larger Aten ops into simpler or core Aten ops. + inference_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. inference_compiler is invoked + if no autograd is needed. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + Returns: + Returns a ``Callable`` that retains the eager behavior of the original + :attr:`fn`, but with forward and backward graph compiled via + :attr:`fw_compile` and :attr:`bw_compile`. + + A simple example usage of :func:`aot_function` is as follows. This example + will print the forward and backward graphs of the function ``fn`` + + >>> fn = lambda x : x.sin().cos() + >>> def print_compile_fn(fx_module, args): + >>> print(fx_module) + >>> return fx_module + >>> aot_fn = aot_function(fn, print_compile_fn) + >>> x = torch.randn(4, 5, requires_grad=True) + >>> aot_fn(x) + """ + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic, + aot_autograd_arg_pos_to_source=None, + is_export=False, + no_tangents=False, + enable_log=enable_log, + ) + cached_res = None + + @wraps(fn) + def returned_function(*args, **kwargs): + nonlocal cached_res + # Now flatten the tensor args + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + + # Compile the function and save it in the cache + if cached_res is None: + flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs) + (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config) + fake_flat_args: FakifiedFlatArgs = process_inputs( + flat_args, aot_config, fake_mode, shape_env + ) + compiled_fn, _ = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + cached_res = (compiled_fn, out_spec) + + cached_fn, out_spec = cached_res + out = cached_fn(flat_args) + return out_spec.unflatten(out) + + return returned_function + + +def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: + """ + Traces the forward and backward graph of :attr:`mod` using torch dispatch + tracing mechanism. It is wrapper function, that underneath uses + :func:`aot_function` to perform tracing and compilation. + + :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs + to a new callable which is then compiled through :func:`aot_function`. + + .. warning:: + This API is experimental and likely to change. + + Args: + mod (Callable): A ``nn.Module`` module. 
+ args : args to be passed to :func:`aot_function` + kwargs : kwargs to be passed to :func:`aot_function` + + Returns: + Returns a ``nn.Module`` that retains the eager behavior of the original + :attr:`mod`, but with forward and backward graph compiled. + + """ + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) + + def functional_call(named_params, named_buffers, *args, **kwargs): + params_and_buffers = {**named_params, **named_buffers} + return torch.func.functional_call(mod, params_and_buffers, args, kwargs) + + named_params = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + num_params_buffers = len(named_params) + len(named_buffers) + compiled_f = aot_function( + functional_call, *args, num_params_buffers=num_params_buffers, **kwargs + ) + + class AOTModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.orig_module = mod + + def forward(self, *args, **kwargs): + return compiled_f( + named_params, + named_buffers, + *args, + **kwargs, + ) + + return AOTModule() + + +def _try_get_metadata_from_dynamo( + mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int +) -> tuple[Optional[list[torch._guards.Source]], list[int]]: + """ + Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. + We first verify that `mod` does come from Dynamo, then we handle cases where + metadata might be missing. + + Returns: + aot_autograd_arg_pos_to_source: used to dedup params and their guards + static_input_indices: used to identify static inputs for cudagraphs + """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): + # graph was not captured by dynamo + return None, [] + + if not hasattr(mod, "_param_name_to_source"): + # is from export + return None, [] + + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + # Additionally, we mark static indices for cudagraphs. + param_name_to_source = mod._param_name_to_source + seen_sources = set() + + aot_autograd_arg_pos_to_source = [] + static_input_indices = [] + # Collect the new inputs lifted by aotdispatch + for i, name in enumerate(param_keys): + assert name in param_name_to_source, f"{name} not found." + source = param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + static_input_indices.append(i) + + # Collect the dynamo graph inputs + # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID + # matched tensors back into the Fx graph, this might not be necessary. + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + # `source`` specifies the source from user code. 
ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name() if source else str(source) + + # input[i] in dynamo is now: + # input[i + len(extra_params)] in AOT, + # where extra_params are the params/buffers that dynamo baked into the + # OutputGraph + actual_pos = pos + len(param_keys) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", actual_pos, source_name + ) + static_input_indices.append(actual_pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", actual_pos, source_name + ) + + assert full_args_num == len(aot_autograd_arg_pos_to_source) + return aot_autograd_arg_pos_to_source, static_input_indices + + +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: AOTDispatchCompiler, + bw_compiler: Optional[AOTDispatchCompiler] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[dict] = None, + keep_inference_input_mutations=False, + inference_compiler: Optional[AOTDispatchCompiler] = None, + cudagraphs: Optional[BoxedBool] = None, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + ignore_shape_env: bool = False, +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. 
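+
+    A minimal usage sketch (``print_compile_fn`` below is illustrative only; in
+    practice TorchDynamo supplies the module, args, and compilers):
+
+    >>> def print_compile_fn(fx_module, args):
+    >>>     print(fx_module)
+    >>>     return fx_module
+    >>> mod = torch.nn.Linear(3, 3)
+    >>> x = torch.randn(2, 3)
+    >>> compiled_mod = aot_module_simplified(mod, (x,), fw_compiler=print_compile_fn)
+    >>> out = compiled_mod(x)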
+ """ + params = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + params_len = len(params_flat) + + if cudagraphs is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + + full_args = [] + # First, the params + full_args.extend(params_flat) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.params_flat = params_flat + ( + tracing_context.params_flat_unwrap_subclasses, + tracing_context.params_unwrapped_to_flat_index, + ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat) + + # Next, the input args + full_args.extend(args) + + ( + aot_autograd_arg_pos_to_source, + static_input_indices, + ) = _try_get_metadata_from_dynamo(mod, params.keys(), len(full_args)) + + dynamic_shapes = False + for x in full_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=params_len, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, + static_input_indices=static_input_indices, + is_export=False, + no_tangents=False, + cache_info=None, + ignore_shape_env=ignore_shape_env, + precompile_backend_id=getattr(mod, "_backend_id", None), + ) + fake_mode, shape_env = construct_fake_mode(full_args, aot_config) + fake_flat_args = process_inputs( + full_args, aot_config, fake_mode, shape_env, ignore_shape_env + ) + + def dispatch_and_compile(): + functional_call = create_functional_call(mod, params_spec, params_len) + with compiled_autograd._disable(): + compiled_fn, _ = create_aot_dispatcher_function( + functional_call, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + return compiled_fn + + # We only care if the forward will return an OutputCode. + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.load( + dispatch_and_compile, + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + else: + compiled_fn = dispatch_and_compile() + else: + compiled_fn = dispatch_and_compile() + + if isinstance(mod, torch._dynamo.utils.GmWrapper): + # This function is called by the flatten_graph_inputs wrapper, which boxes + # the inputs so that they can be freed before the end of this scope. 
+ # For overhead reasons, this is not the default wrapper, see comment: + # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 + def boxed_forward(runtime_args: list[Any]): + flat_args = [] + flat_args.extend(params_flat) + flat_args.extend(runtime_args) + runtime_args.clear() + return compiled_fn(flat_args) + + # Just for convenience + boxed_forward.zero_grad = mod.zero_grad + boxed_forward.named_parameters = mod.named_parameters + boxed_forward.named_buffers = mod.named_buffers + return boxed_forward + + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... + # NB: GraphModule/nn.Module rely on the non-boxed calling convention here + def forward(*runtime_args: tuple[Any]): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) + + # Just for convenience + forward.zero_grad = mod.zero_grad + forward.named_parameters = mod.named_parameters + forward.named_buffers = mod.named_buffers + + return forward + + +def aot_export_module( + mod: nn.Module, + args, + *, + decompositions: Optional[dict] = None, + # If true, we'll return a joint forward-backward graph, + # As well as metadata on the loss + gradients in the backward. + trace_joint: bool, + # If trace_joint is True, we expect your module to return a scalar loss. + # Your module can return multiple outputs, so you must specify which output the loss is. + output_loss_index: Optional[int] = None, + pre_dispatch: bool = False, + # If None, will be infered from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong. + dynamic_shapes: Optional[bool] = None, + kwargs=None, +) -> tuple[torch.fx.GraphModule, GraphSignature]: + """ + This function takes in a module, and returns: + (1) an FX graph that can be exported + (2) some metadata about the graph + + If `trace_joint=True` we will return a joint graph of the forward + backward. + + The traced FX graph will have the following properties compared to the original module: + (1) Inputs and outputs to the module will be pytree-flattened + (2) Parameters and buffers on the module will be lifted into graph inputs, + graph_inputs = (*parameters, *buffers, *user_inputs) + (3) The graph will be fully functionalized + (4) Any input mutations will be converted into additional outputs in the graph, + meaning whoever calls this graph is responsible for applying the mutations + back to the original inputs. + (5) If is_joint is provided the graph will return parameter gradients in addition to user outputs. + The graph output will look like: + graph_outputs = (*updated_inputs, *user_outputs, *param_gradients) + + There are also several restrictions on what modules can use this API. In particular: + (1) If trace_joint is specified, we expect the loss function to be **fused** + into the module forward. One of the outputs to the forward must be a scalar loss, + which is specified with `output_loss_index`. + All other outputs to the forward are presumed to not require gradients. + (2) This API cannot capture optimizers (although in theory we could build an API for this). + (3) Metadata mutations on params/buffers/inputs are banned. + (4) Data mutations on anything that requires gradients are banned (parameters) + (5) If an input is mutated, it is not allowed to alias any other inputs. 
+ (6) Parameters must not be duplicated. + """ + if pre_dispatch and trace_joint: + raise RuntimeError("pre_dispatch is not supported when trace_joint is True.") + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + + params_and_buffers = { + **dict(named_parameters), + **dict(named_buffers), + } + params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) + params_and_buffers_flat = tuple(params_and_buffers_flat) + params_len = len(params_and_buffers_flat) + + kwargs = kwargs or {} + + functional_call = create_functional_call( + mod, params_spec, params_len, store_orig_mod=True + ) + + num_fw_outs = None + + if trace_joint: + # This helper effectively just adds some extra asserts about what the backward will look like: + # Outputs must include a scalar loss, that we compute gradients w.r.t. + # We don't compute gradients w.r.t. anything else: so just in case we detach() + # and other output tensors. + def fn_to_trace(*args): + nonlocal num_fw_outs + out = functional_call(*args) + if output_loss_index is None: + raise RuntimeError( + """\ +If trace_joint=Trueit is required that one of your forward outputs must be a scalar loss. +You must specify the which (index) output is the loss with output_loss_index.""" + ) + if isinstance(out, (torch.Tensor)): + out = (out,) + if not isinstance(out, (tuple, list)): + raise RuntimeError( + f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}" + ) + + for i, o in enumerate(out): + # We only want to create a backward graph w.r.t. the loss that the user passed in. + # This implies that every other output should not require gradients. + # Instead of making this an error (and forcing the user to detach all other outputs + # of their forward), + # we'll automatically detach them here. + if o.requires_grad and i != output_loss_index: + raise RuntimeError( + f"""\ +Found an output of the forward that requires gradients, that was not the scalar loss. +We require all outputs to the forward that are not the scalar loss to not require gradient, +because we will only compute a backward graph against the scalar loss. +You can fix this by calling .detach() on each of your forward outputs that is not the loss. +You specified that output index {output_loss_index} is the loss, but we found that +the output at index {i} requires gradients.""" + ) + out_loss = out[output_loss_index] + num_fw_outs = len(out) + if not out_loss.requires_grad: + raise RuntimeError( + f"""\ +The output at index {output_loss_index} was marked as the loss, but it does not require gradients""" + ) + if out_loss.numel() != 1: + raise RuntimeError( + f"""\ +We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}""" + ) + return out + + ctx = nullcontext + else: + # Run under no_grad, so our tracing machinery only traces an inference graph. + # However if pre_dispatch=True, we want to correctly trace set_grad_enabled calls for training. 
+ ctx = nullcontext if pre_dispatch else torch.no_grad + fn_to_trace = functional_call + + full_args = [] + # First, the params + # NB: It is REQUIRED that parameters come first, Inductor infers "fixed" + # parameters by looking at the difference in parameter count outside + # and inside AOTAutograd, and assumes the prefix of arguments are fixed + # arguments + full_args.extend(params_and_buffers_flat) + # Next, the input args + full_args.extend(args) + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + fn_to_trace, + full_args, + decompositions=decompositions, + num_params_buffers=params_len, + no_tangents=True, + pre_dispatch=pre_dispatch, + dynamic_shapes=dynamic_shapes, + kwargs=kwargs, + ) + if trace_joint: + + @wraps(functional_call) + def flattened_joint(*args): + # The idea here is that the joint graph that AOTAutograd creates has some strict properties: + # (1) It accepts two arguments (primals, tangents), and pytree_flattens them + # (2) It returns a tuple of (fw_outs, gradients) + # This is a very useful convention for anyone who wants to partition the joint graph + # into a separate forward and backward graph. + # However, + # (1) for people exporting a single joint graph, it would be preferable not to have + # any pytrees in the graph. + # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss, + # and there are therefore no tangents that are needed to run the joint graph. + # (3) AOTAutograd creates a grad_input for every input in the forward, + # including None's for inputs that are not grad-requiring tensors. + # we don't want these in our export graph. + # and there are therefore no tangents that are needed to run the joint graph. + # This function "fixes" both of the above by removing any tangent inputs, + # and removing pytrees from the original FX graph. + fake_tangents = [ + None + for _ in range( + metadata.num_outputs + metadata.num_mutated_inp_runtime_indices + ) + ] + fw_outs, gradients = fx_g(args, fake_tangents) + assert len(gradients) == len(args) + output_gradients = [] + for a, grad in zip(args, gradients): + if isinstance(a, torch.Tensor) and a.requires_grad: + assert ( + grad is not None + ), """\ +Found a parameter that did not receive a gradient. +"This is most likely a bug, but if this needs to be supported please comment on this Github issue: +https://github.com/pytorch/pytorch/issues/101192 +""" + output_gradients.append(grad) + else: + assert grad is None + return *fw_outs, *output_gradients + + fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args) + + user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) + return fx_g, create_graph_signature( + fx_g, + metadata, + in_spec, + out_spec, + user_args_flat=user_args_flat, + params_and_buffers_flat=params_and_buffers_flat, + param_names=list(named_parameters.keys()), + buffer_names=list(named_buffers.keys()), + trace_joint=trace_joint, + num_user_fw_outs=num_fw_outs, + loss_index=output_loss_index, + ) + + +def aot_export_joint_simple( + func: Callable, + args, + *, + trace_joint: bool, + # It looks like the main consequence of this API is that for dynamic shapes, + # it will assume that parms/buffers are static. + # With the new inferred dynamic shapes API, maybe this doesn't matter? + num_params_buffers: int = 0, + decompositions: Optional[dict] = None, +) -> torch.fx.GraphModule: + """ + A simplified version of export. Used by higher order operators. 
+ + This function makes a high-level "no calling convention changes" guarantee: + - If no inputs require grad (so we export an inference graph), + there are *no* calling convention change between the exported graph, and "func". + - If at least one input requires grad (so we trace out and export a joint fw-bw graph), + Then if you were partition the graph into a separate forward and backward graph, + The forward graph will have no calling convention changes compared to "func". + + The above also relies on some strong restrictions around which functions this API accepts: + (1) `args` cannot contain any pytrees (they must have been pytree_flattened already) + (2) `func` cannot mutate any inputs + (3) The outputs of `func` cannot alias any inputs. + + Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops. + """ + if trace_joint: + ctx = nullcontext + else: + # Run under no_grad, so our tracing machinery only traces an inference graph. + ctx = torch.no_grad + + with ctx(): + fx_g, metadata, in_spec, out_spec = _aot_export_function( + func, + args, + decompositions=decompositions, + ) + in_spec, _kw_in_spec = in_spec.children_specs + # At this point, we can just directly return the (joint or inference graph) that we traced. + # First though: a bunch of assertions to make sure that our graph doesn't require + # any calling convention changes compared to the original function. + # These restrictions are *in addition to* the general restrictions on export. + + # No input mutations + if ( + len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) + != 0 + ): + raise RuntimeError( + f"aot_export_joint_simple does not support input mutations. {str(metadata)}" + ) + # No output aliasing + if ( + len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) + != 0 + ): + raise RuntimeError( + f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}" + ) + # No pytrees + if in_spec.is_leaf(): + raise RuntimeError( + f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" + ) + if not all(child.is_leaf() for child in in_spec.children_specs): + raise RuntimeError( + f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" + ) + if out_spec.is_leaf(): + raise RuntimeError( + f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" + ) + if not all(child.is_leaf() for child in out_spec.children_specs): + raise RuntimeError( + f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" + ) + # TODO: we might have to temporarily patch config.functionalize_rng + # so that it doesn't run when we're exporting a higher order op. + + if config.debug_assert: + # Smoke test that after partitioning, we can run the forward without any calling convention changes. + fw_module, _bw_module = aot_config.default_partition( # noqa: F821 + fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 + ) + # Attempt to run the fw_module with the original user inputs + fake_mode = detect_fake_mode(args) + if fake_mode is None: + fake_mode = FakeTensorMode() + with fake_mode: + fw_module(*args) + return fx_g + + +# Private for now because we aren't providing a contract on what to return +# for joint graphs (we could when there's a clearer use case) +# In the future, we may need to add more export API's that provide their own strong guarantees. 
+# This is meant as a general helper function for handling various export-y use cases. +def _aot_export_function( + func: Callable, + args, + *, + num_params_buffers: int = 0, + decompositions: Optional[dict] = None, + # If we're exporting a joint graph and we don't want any tangent inputs in the graph + # (because we are backpropping through a scalar 1 loss), + # we need to explicitly specify not to include tangents in the graph. + # It's not enough just to check that our tangent is a scalar, since we also + # need to know if it is a 1 (no need to make it a graph input), or something else + # (requiring it to be a graph input). + # We don't know this info at trace time though, so we need to make it an explicit config. + no_tangents: bool = False, + pre_dispatch: bool = False, + # If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong. + dynamic_shapes: Optional[bool] = None, + kwargs=None, +) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]: + kwargs = kwargs or {} + + flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs) + flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + + fake_mode = None + if dynamic_shapes is None: + # Try to infer `dynamic_shapes from inputs and graph nodes + fake_mode = detect_fake_mode(flat_args) + if ( + fake_mode is None + and hasattr(func, "_orig_mod") + and isinstance(func._orig_mod, torch.fx.GraphModule) + ): + vals = [ + node.meta["val"] + for node in func._orig_mod.graph.nodes + if "val" in node.meta + ] + fake_mode = detect_fake_mode(vals) + dynamic_shapes = fake_mode is not None and fake_mode.shape_env is not None + + # The export use case doesn't care about several bits of AOTConfig + # (1) compilers (we just export the graph) + # (2) partitioners (export is only full graph, user can partition themselves) + aot_config = AOTConfig( + fw_compiler=None, + bw_compiler=None, + inference_compiler=None, + partition_fn=None, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + # For now there's no use case involving keeping input mutations in the graph + # (which we can only do in the inference case anyway). + # We can add this later if we need to. + keep_inference_input_mutations=False, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=None, + is_export=True, + no_tangents=no_tangents, + pre_dispatch=pre_dispatch, + ) + if fake_mode is None: + fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) + else: + shape_env = fake_mode.shape_env + fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) + + fx_g, meta = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + return fx_g, meta, in_spec, out_spec.spec + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. 
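+    #
+    # A hypothetical example of the pattern this context manager is meant to catch:
+    #
+    #   class M(torch.nn.Module):
+    #       def forward(self, x):
+    #           self.last_input = x  # tensor attribute assigned in forward -> error during export
+    #           return x + 1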
+ + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + assigned_tensor_attributes = [] + + def _collect_assigned_tensor_attributes(kp, v, _v): + if _v is not v: + attr, *rest = kp + if isinstance(v, torch.Tensor): + assigned_tensor_attributes.append( + f"self.{attr.key}{pytree.keystr(rest)}" + ) + # TODO(avik): Assigning all other types are allowed right now. + # Maybe in the future we want to limit this to primitive types? + return v + + new_attrs = _get_attributes(mod) + if len(new_attrs) != len(snapshot): + added_attrs = new_attrs.keys() - snapshot.keys() + deleted_attrs = snapshot.keys() - new_attrs.keys() + + if len(added_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were created in the model.forward: {added_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + if len(deleted_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + pytree.tree_map_with_path( + _collect_assigned_tensor_attributes, snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + mod.__dict__.update(snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + raise ValueError( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." 
+ ) + + +compiled_function = aot_function +compiled_module = aot_module diff --git a/phivenv/Lib/site-packages/torch/_functorch/apis.py b/phivenv/Lib/site-packages/torch/_functorch/apis.py new file mode 100644 index 0000000000000000000000000000000000000000..ccede8682ab2b282c1e744d21b63113e8d52d874 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/apis.py @@ -0,0 +1,448 @@ +# mypy: allow-untyped-defs +# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can +# trace through functorch transforms. +# Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing +# and there isn't a mechanism to selectively expose only some functions (eg. grad) from a file +# to Dynamo. +import functools + +from torch._functorch.utils import argnums_t, exposed_in +from torch._functorch.vmap import ( + _check_out_dims_is_int_or_int_pytree, + _check_randomness_arg, + _chunked_vmap, + _process_batched_inputs, + Callable, + in_dims_t, + out_dims_t, + vmap_impl, +) + + +# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, +# sends those into func, and then unwraps the output BatchedTensors. Operations +# on BatchedTensors perform the batched operations that the user is asking for. +# +# vmap's randomness behavior differs from JAX's, which would require a PRNG key +# to be passed everywhere. + + +@exposed_in("torch.func") +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = "error", + *, + chunk_size=None, +) -> Callable: + """ + vmap is the vectorizing map; ``vmap(func)`` returns a new function that + maps ``func`` over some dimension of the inputs. Semantically, vmap + pushes the map into PyTorch operations called by ``func``, effectively + vectorizing those operations. + + vmap is useful for handling batch dimensions: one can write a function + ``func`` that runs on examples and then lift it to a function that can + take batches of examples with ``vmap(func)``. vmap can also be used to + compute batched gradients when composed with autograd. + + .. note:: + :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for + convenience. Use whichever one you'd like. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. ``in_dims`` should have a + structure like the inputs. If the ``in_dim`` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If ``out_dims`` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + chunk_size (None or int): If None (default), apply a single vmap over inputs. + If not None, then compute the vmap :attr:`chunk_size` samples at a time. + Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop. 
+ If you run into memory issues computing the vmap, please try a non-None chunk_size. + + Returns: + Returns a new "batched" function. It takes the same inputs as + ``func``, except each input has an extra dimension at the index + specified by ``in_dims``. It takes returns the same outputs as + ``func``, except each output has an extra dimension at the index + specified by ``out_dims``. + + .. warning: + :func:`vmap` works best with functional-style code. Please do not + perform any side-effects in ``func``, with the exception of + in-place PyTorch operations. Examples of side-effects include mutating + Python data structures and assigning values to variables not captured + in ``func``. + + One example of using :func:`vmap` is to compute batched dot products. PyTorch + doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully + rummaging through docs, use :func:`vmap` to construct a new function. + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) + + :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler + model authoring experience. + + >>> batch_size, feature_size = 3, 5 + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> + >>> def model(feature_vec): + >>> # Very simple linear model with activation + >>> return feature_vec.dot(weights).relu() + >>> + >>> examples = torch.randn(batch_size, feature_size) + >>> result = torch.vmap(model)(examples) + + :func:`vmap` can also help vectorize computations that were previously difficult + or impossible to batch. One example is higher-order gradient computation. + The PyTorch autograd engine computes vjps (vector-Jacobian products). + Computing a full Jacobian matrix for some function f: R^N -> R^N usually + requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`, + we can vectorize the whole computation, computing the Jacobian in a single + call to ``autograd.grad``. 
+ + >>> # Setup + >>> N = 5 + >>> f = lambda x: x ** 2 + >>> x = torch.randn(N, requires_grad=True) + >>> y = f(x) + >>> I_N = torch.eye(N) + >>> + >>> # Sequential approach + >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] + >>> for v in I_N.unbind()] + >>> jacobian = torch.stack(jacobian_rows) + >>> + >>> # vectorized gradient computation + >>> def get_vjp(v): + >>> return torch.autograd.grad(y, x, v) + >>> jacobian = torch.vmap(get_vjp)(I_N) + + :func:`vmap` can also be nested, producing an output with multiple batched dimensions + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) + >>> batched_dot(x, y) # tensor of size [2, 3] + + If the inputs are not batched along the first dimension, ``in_dims`` specifies + the dimension that each inputs are batched along as + + >>> torch.dot # [N], [N] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension + + If there are multiple inputs each of which is batched along different dimensions, + ``in_dims`` must be a tuple with the batch dimension for each input as + + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None + + If the input is a Python struct, ``in_dims`` must be a tuple containing a struct + matching the shape of the input: + + >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> input = {'x': x, 'y': y} + >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) + >>> batched_dot(input) + + By default, the output is batched along the first dimension. However, it can be batched + along any dimension by using ``out_dims`` + + >>> f = lambda x: x ** 2 + >>> x = torch.randn(2, 5) + >>> batched_pow = torch.vmap(f, out_dims=1) + >>> batched_pow(x) # [5, 2] + + For any function that uses kwargs, the returned function will not batch the kwargs but will + accept kwargs + + >>> x = torch.randn([2, 5]) + >>> def fn(x, scale=4.): + >>> return x * scale + >>> + >>> batched_pow = torch.vmap(fn) + >>> assert torch.allclose(batched_pow(x), x * 4) + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] + + .. note:: + vmap does not provide general autobatching or handle variable-length + sequences out of the box. + """ + from torch.compiler import is_compiling + + _check_randomness_arg(randomness) + if not (chunk_size is None or chunk_size > 0): + raise ValueError( + f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})" + ) + + def wrapped(*args, **kwargs): + return vmap_impl( + func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs + ) + + if not is_compiling(): + wrapped = functools.wraps(func)(wrapped) + + return wrapped + + +def chunk_vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = "error", + chunks=2, +) -> Callable: + """ + chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes + everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of + chunks at a time. 
For more details about vectorizing map, see :func:`vmap`. + + .. note:: + Please use :func:`vmap` with ``chunk_size`` argument instead of this API. + + Args: + func (function): A Python function that takes one or more arguments. + Must return one or more Tensors. + in_dims (int or nested structure): Specifies which dimension of the + inputs should be mapped over. ``in_dims`` should have a + structure like the inputs. If the ``in_dim`` for a particular + input is None, then that indicates there is no map dimension. + Default: 0. + out_dims (int or Tuple[int]): Specifies where the mapped dimension + should appear in the outputs. If ``out_dims`` is a Tuple, then + it should have one element per output. Default: 0. + randomness (str): Specifies whether the randomness in this + vmap should be the same or different across batches. If 'different', + the randomness for each batch will be different. If 'same', the + randomness will be the same across batches. If 'error', any calls to + random functions will error. Default: 'error'. WARNING: this flag + only applies to random PyTorch operations and does not apply to + Python's random module or numpy randomness. + chunks (int): Number of chunks to use to split the input data. Default is 2. + If equals to 1 then :func:`vmap` is called. + + Returns: + Returns a new "batched" function. It takes the same inputs as + ``func``, except each input has an extra dimension at the index + specified by ``in_dims``. It takes returns the same outputs as + ``func``, except each output has an extra dimension at the index + specified by ``out_dims``. + """ + _check_randomness_arg(randomness) + + if chunks == 1: + return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness) + + def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_): + flat_args_chunks = tuple( + t.chunk(chunks_, dim=in_dim) + if in_dim is not None + else [ + t, + ] + * chunks_ + for t, in_dim in zip(flat_args_, flat_in_dims_) + ) + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + @functools.wraps(func) + def wrapped_with_chunks(*args, **kwargs): + _check_out_dims_is_int_or_int_pytree(out_dims, func) + _, flat_in_dims, flat_args, args_spec = _process_batched_inputs( + in_dims, args, func + ) + # Chunk flat arguments + chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks) + + # Apply vmap on chunks + return _chunked_vmap( + func, + flat_in_dims, + chunks_flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + return wrapped_with_chunks + + +@exposed_in("torch.func") +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + """``grad`` operator helps computing gradients of ``func`` with respect to the + input(s) specified by ``argnums``. This operator can be nested to + compute higher-order gradients. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. If specified ``has_aux`` equals ``True``, + function can return a tuple of single-element Tensor and other auxiliary objects: + ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. + ``argnums`` can be single integer or tuple of integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and other + auxiliary objects: ``(output, aux)``. Default: False. 
+ + Returns: + Function to compute gradients with respect to its inputs. By default, the output of + the function is the gradient tensor(s) with respect to the first argument. + If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects + is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with + respect to each ``argnums`` value is returned. + + Example of using ``grad``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> x = torch.randn([]) + >>> cos_x = grad(lambda x: torch.sin(x))(x) + >>> assert torch.allclose(cos_x, x.cos()) + >>> + >>> # Second-order gradients + >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) + >>> assert torch.allclose(neg_sin_x, -x.sin()) + + When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad, vmap + >>> batch_size, feature_size = 3, 5 + >>> + >>> def model(weights, feature_vec): + >>> # Very simple linear model with activation + >>> assert feature_vec.dim() == 1 + >>> return feature_vec.dot(weights).relu() + >>> + >>> def compute_loss(weights, example, target): + >>> y = model(weights, example) + >>> return ((y - target) ** 2).mean() # MSELoss + >>> + >>> weights = torch.randn(feature_size, requires_grad=True) + >>> examples = torch.randn(batch_size, feature_size) + >>> targets = torch.randn(batch_size) + >>> inputs = (weights, examples, targets) + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) + + Example of using ``grad`` with ``has_aux`` and ``argnums``: + + >>> # xdoctest: +SKIP + >>> from torch.func import grad + >>> def my_loss_func(y, y_pred): + >>> loss_per_sample = (0.5 * y_pred - y) ** 2 + >>> loss = loss_per_sample.mean() + >>> return loss, (y_pred, loss_per_sample) + >>> + >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) + >>> y_true = torch.rand(4) + >>> y_preds = torch.rand(4, requires_grad=True) + >>> out = fn(y_true, y_preds) + >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``grad``. + + Case 1: Using ``torch.no_grad`` inside a function: + + >>> # xdoctest: +SKIP + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP + >>> with torch.no_grad(): + >>> grad(f)(x) + + In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``grad`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. + + """ + # To avoid cyclical dependency. + import torch._functorch.eager_transforms as eager_transforms + from torch.compiler import is_compiling + + def wrapper(*args, **kwargs): + return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) + + if not is_compiling(): + wrapper = functools.wraps(func)(wrapper) + + return wrapper + + +@exposed_in("torch.func") +def grad_and_value( + func: Callable, argnums: argnums_t = 0, has_aux: bool = False +) -> Callable: + """ + Returns a function to compute a tuple of the gradient and primal, or + forward, computation. + + Args: + func (Callable): A Python function that takes one or more arguments. + Must return a single-element Tensor. 
If specified ``has_aux`` + equals ``True``, function can return a tuple of single-element + Tensor and other auxiliary objects: ``(output, aux)``. + argnums (int or Tuple[int]): Specifies arguments to compute gradients + with respect to. ``argnums`` can be single integer or tuple of + integers. Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a tensor and + other auxiliary objects: ``(output, aux)``. Default: False. + + Returns: + Function to compute a tuple of gradients with respect to its inputs + and the forward computation. By default, the output of the function is + a tuple of the gradient tensor(s) with respect to the first argument + and the primal computation. If specified ``has_aux`` equals + ``True``, tuple of gradients and tuple of the forward computation with + output auxiliary objects is returned. If ``argnums`` is a tuple of + integers, a tuple of a tuple of the output gradients with respect to + each ``argnums`` value and the forward computation is returned. + + See :func:`grad` for examples + """ + from torch._functorch import eager_transforms + from torch.compiler import is_compiling + + def wrapper(*args, **kwargs): + return eager_transforms.grad_and_value_impl( + func, argnums, has_aux, args, kwargs + ) + + if not is_compiling(): + wrapper = functools.wraps(func)(wrapper) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/_functorch/autograd_function.py b/phivenv/Lib/site-packages/torch/_functorch/autograd_function.py new file mode 100644 index 0000000000000000000000000000000000000000..92a1b356cd463554591f99dc4a80b7a56abae3e6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/autograd_function.py @@ -0,0 +1,777 @@ +# mypy: allow-untyped-defs +from typing import NamedTuple + +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + _unwrap_for_grad, + _wrap_for_grad, + current_level, + TransformType, +) +from torch._functorch.apis import vmap +from torch._functorch.utils import enable_single_level_autograd_function +from torch._functorch.vmap import ( + _add_batch_dim, + _broadcast_to_and_flatten, + restore_vmap, + unwrap_batched, + wrap_batched, +) +from torch._ops import HigherOrderOperator +from torch.autograd.forward_ad import _set_fwd_grad_enabled + + +# autograd.Function technically runs before the regular PyTorch dispatcher. +# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot) +# work with it. One day we might decide to change this, but until then, +# we need to give the illusion that autograd.Function runs before those things. +# +# We do this by using creating a custom HigherOrderOperator that only functorch +# dispatches specially. +class CustomFunctionHigherOrderOperator(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("custom_function_call") + + def __call__(self, autograd_function, *args, **kwargs): + # When custom_function_call is done dispatching through functorch, + # it should just invoke the autograd.Function. This is consistent + # with the autograd.Function behavior of being invoked before the + # PyTorch dispatcher. + # + # This will lead us into trouble later down the line, but this is + # pre-existing. There is an invariant that a function traced by + # make_fx should have the same behavior when provided the same + # Tensor. However, make_fx sees autograd.Function as a composite + # (because autograd.Function happens before the Python dispatch key) + # and only traces the forward pass. 
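+ # A rough usage sketch (names hypothetical): for an autograd.Function MyFn,
+ #   out = custom_function_call(MyFn, x)
+ # behaves exactly like MyFn.apply(x) when no functorch transform is active,
+ # while under vmap/grad/jvp it is routed to the py_impl rules registered below.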
+ if torch._C._are_functorch_transforms_active(): + return super().__call__(autograd_function, *args, **kwargs) + return autograd_function.apply(*args, **kwargs) + + +# "custom_function_call" +# This is the mechanism for an autograd.Function that works with functorch transforms. +# It wraps an autograd.Function; interactions with functorch transforms are defined +# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch +# dispatcher. +custom_function_call = CustomFunctionHigherOrderOperator() + + +# The grad rule for custom_function_call is to construct a new _SingleLevelFunction +# (autograd.Function that only works with a single layer (level) of functorch) that: +# - unwraps the inputs +# - redispatches to custom_function_call +# - wraps the outputs +# and whose backward pass calls the original autograd.Function's backward. +# +# Why do we need to redispatch to custom_function_call? +# ----------------------------------------------------- +# This is consistent with how ATen operators work with functorch's grad transform: +# they always redispatch to the original operator. +# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x) +# +# grad1 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin (*) +# - rewrap the outputs on the return +# +# On the redispatch in (*), grad0 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin +# - rewrap the outputs on the return +# +# To "set up the autograd graph", we generate a _SingleLevelFunction +# and apply it. +@custom_function_call.py_impl(TransformType.Grad) +@custom_function_call.py_impl(TransformType.Jvp) +def custom_function_call_grad(interpreter, autograd_function, *operands): + Generated = generate_single_level_function(interpreter, autograd_function) + with enable_single_level_autograd_function(): + flat_out = Generated.apply(*operands) + return flat_out + + +def generate_single_level_function(interpreter, autograd_function): + level = interpreter.level() + + def forward(*operands): + unwrapped_operands = pytree.tree_map_only( + torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands + ) + # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter + # the transform. _SingleLevelFunction will turn off both fwd and bwd + # gradient computation and we need to turn it back on here. + with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower(): + unwrapped_output = custom_function_call( + autograd_function, *unwrapped_operands + ) + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output): + return _wrap_for_grad(output, level) + + return wrap_outputs_maintaining_identity( + unwrapped_output, unwrapped_operands, operands, wrap_fn + ) + + def setup_context(ctx, inputs, output): + return autograd_function.setup_context(ctx, inputs, output) + + # backward is only used if the transform is TransformType.Grad + def backward(ctx, *grads): + result = autograd_function.backward(ctx, *grads) + return result + + # jvp is only used if the transform is TransformType.Jvp + def jvp(ctx, *tangents): + result = autograd_function.jvp(ctx, *tangents) + return result + + # This is the sequence of magic words to dynamically generate a Subclass with + # a given name. A Tensor's .grad_fn field has a class name that is the original + # autograd.Function's name + Backward, so we do this to generate some + # meaningful name. 
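+ # For example, for a (hypothetical) autograd.Function named "MySin", the
+ # generated class is called "MySinGenerated", and tensors it produces show a
+ # grad_fn of type "MySinGeneratedBackward".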
+ name = f"{autograd_function.__name__}Generated" + Generated = type( + name, + (torch.autograd.function._SingleLevelFunction,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + "jvp": staticmethod(jvp), + "setup_context": staticmethod(setup_context), + }, + ) + return Generated + + +# wrap_outputs_maintaining_identity handles outputs from the vmap, +# backward (vjp), and jvp staticmethod. The way it distinguishes +# between the vmap case and the {backward, jvp} case is if the out_dims +# are specified or not. +# +# NB: we cannot use out_dims=None as the deciding factor. This because +# out_dims=None can still happen in the vmap staticmethod! What the +# user is saying in that case is that their output does not have a +# dimension that is being vmapped over, which is valid. +NO_OUT_DIMS = "not specified" + + +# NOTE [mark_dirty object identity check] +# autograd.Function's ctx.mark_dirty expect a returned input +# to have the same object identity as the input. +# Mode-only functorch will greatly simplify this logic. +def wrap_outputs_maintaining_identity( + outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS +): + flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs) + flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs) + + unwrapped_input_to_orig_input = { + id(unwrapped): orig + for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs) + } + + flat_outputs, spec = pytree.tree_flatten(outputs) + result = [] + + out_dims_specified = out_dims != NO_OUT_DIMS + + if out_dims_specified: + flat_out_dims = _broadcast_to_and_flatten(out_dims, spec) + # _broadcast_to_and_flatten returns None if it is unable to broadcast. + # TODO: update following link from master to stable once that's out + if flat_out_dims is None: + raise RuntimeError( + f"The autograd.Function's vmap staticmethod returned an " + f"incompatible (output, out_dims) tuple. " + f"Expected out_dims={out_dims} " + f"to be compatible with the structure of `output`. " + f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} " + f"but output has structure {spec}. " + f"For more details, please see " + f"https://pytorch.org/docs/main/notes/extending.func.html" + ) + + for i, output in enumerate(flat_outputs): + if not isinstance(output, torch.Tensor): + result.append(output) + continue + if id(output) in unwrapped_input_to_orig_input: + result.append(unwrapped_input_to_orig_input[id(output)]) + continue + if out_dims_specified: + result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index] + else: + result.append(wrap_fn(output)) + + return pytree.tree_unflatten(result, spec) + + +# NOTE: [functorch vjp and autograd interaction] +# There's an edge case with the functorch vjp and autograd interaction +# that will eventually be fixed by mode-only functorch. +# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper, +# so we (the framework) need to do it manually. Regular PyTorch operators +# automatically do so this is consistent. 
+# +# class MyExp(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.exp() +# +# @staticmethod +# def setup_context(ctx, inputs, output): +# y = output +# ctx.save_for_backward(y) +# +# @staticmethod +# def backward(gy): +# y, = ctx.saved_tensors() +# return MyMul.apply(gy, y) +# +# x = torch.randn([], requires_grad=True) +# gy = torch.randn([], requires_grad=True) +# _, vjp_fn = vjp(MySin.apply, x) +# result = vjp_fn(gy) +# +# MyMul is an autograd.Function that is not shown here. +# It saves a `y` for backward (since gy requires grad). +# +# in vjp_fn(gy), we get: +# > MyMul.apply(gy, GradTensorWrapper(y, level=dead)) +# Because the y that is saved for backward by MyExp is a GradTensorWrapper +# but is now dead since we are outside the vjp context. +# +# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper, +# will automatically unwrap the GradTensorWrapper when applied. +# But since autograd.Function technically sits above the regular PyTorch +# dispatcher, it doesn't get this treatment. So we manually do +# the unwrapping to be consistent with regular PyTorch dispatcher operations. + + +class VmapInfo(NamedTuple): + batch_size: int + randomness: str + + +def has_overridden_vmap_rule(autograd_function): + return autograd_function.vmap is not torch.autograd.Function.vmap + + +def validate_vmap_returns_tuple_of_two_elements(result): + base_error_msg = ( + "Expected the vmap staticmethod to have two returns, an output " + "and out_dims with pytree structure compatible with the output. " + ) + if not isinstance(result, tuple): + raise RuntimeError(base_error_msg + f"Got a {type(result)} instead") + if not len(result) == 2: + raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead") + + +@custom_function_call.py_impl(TransformType.Vmap) +def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs): + if any( + isinstance(val, torch.Tensor) + for val in torch.utils._pytree.tree_flatten(kwargs)[0] + ): + raise NotImplementedError( + f"Run vmap on autograd.Function with kwarg-only Tensor args. " + f"Please do not pass kwarg-only Tensors to autograd.Function. " + f"Got: {kwargs}" + ) + + if autograd_function.generate_vmap_rule: + if has_overridden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it has both generate_vmap_rule=True and an overridden vmap " + f"staticmethod. Please set generate_vmap_rule=False or delete " + f"the overridden vmap staticmethod to avoid ambiguity. " + f"For more details, please see " + f"https://pytorch.org/docs/main/notes/extending.func.html" + ) + return custom_function_call_vmap_generate_rule( + interpreter, autograd_function, *operands + ) + + if not has_overridden_vmap_rule(autograd_function): + # TODO: Update link to stable once that's out + # https://github.com/pytorch/pytorch/issues/92029 + raise RuntimeError( + f"You tried to vmap over {autograd_function.__name__}, but " + f"it does not have vmap support. Please override and implement the " + f"vmap staticmethod or set generate_vmap_rule=True. 
" + f"For more details, please see " + f"https://pytorch.org/docs/main/notes/extending.func.html" + ) + + return custom_function_call_vmap_helper( + interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs + ) + + +def custom_function_call_vmap_helper( + interpreter, vmap_function, op, *operands, **kwargs +): + current_level = interpreter.level() + info = VmapInfo( + batch_size=interpreter.batch_size(), + randomness=interpreter.randomness(), + ) + # We're either in the autograd.Function case (vmap staticmethod) + # or the torch.library.register_vmap case. + autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta) + + def lower_to_next(): + if autograd_function_case: + return interpreter.lower() + else: + return torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched) + ) + + unwrapped_operands, in_dims = unwrap_batched(operands, current_level) + # If none of the tensors are batched at the current level, then we skip the + # current level. This saves the user from needing to handle this case in + # their vmap staticmethod (and is consistent with our C++ batching rule API) + if pytree.tree_all(lambda dim: dim is None, in_dims): + with lower_to_next(): + if autograd_function_case: + return custom_function_call(op, *operands) + else: + return op(*operands, **kwargs) + + with lower_to_next(): + result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs) + validate_vmap_returns_tuple_of_two_elements(result) + unwrapped_output, out_dims = result + + # See NOTE [mark_dirty object identity check] + def wrap_fn(output, out_dim): + return ( + output + if out_dim is None + else _add_batch_dim(output, out_dim, current_level) + ) + + return wrap_outputs_maintaining_identity( + unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims + ) + + +def unpack_outputs(outputs): + out_dims = outputs[-1] + if isinstance(out_dims, tuple): + outputs = outputs[:-1] + else: + outputs = outputs[0] + return outputs, out_dims + + +def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands): + unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level()) + vmapped_function = vmapify_autograd_function( + autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness() + ) + with interpreter.lower(): + outputs = custom_function_call(vmapped_function, *unwrapped_operands) + + assert isinstance(outputs, tuple) + outputs, out_dims = unpack_outputs(outputs) + return wrap_batched(outputs, out_dims, interpreter.level()) + + +@custom_function_call.py_impl(TransformType.Functionalize) +def custom_function_call_functionalize( + interpreter, autograd_function, generate_vmap_rule, *operands +): + raise RuntimeError("NYI: Functionalize rule for custom_function_call") + + +def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness): + def forward(*operands): + outputs, out_dims = restore_vmap( + autograd_function.forward, in_dims, batch_size, randomness + )(*operands) + if isinstance(outputs, torch.Tensor): + return outputs, out_dims + else: + return *outputs, out_dims + + def setup_context(ctx, inputs, outputs): + outputs, out_dims = unpack_outputs(outputs) + key = id(Generated) + + def inner(inputs, outputs): + # wrapped_ctx.save_for_backward will: + # - unwrap batchedtensors into (tensor, bdim) + # - save_for_backward(*unwrapped_tensors) + # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims + wrapped_ctx = CtxCustomSave(ctx, current_level()) 
+ autograd_function.setup_context(wrapped_ctx, inputs, outputs) + + # input_shapes are used for reductify later to reduce expanded gradients + # to the correct shape. + # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # for more details + input_shapes = tuple( + inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + if not hasattr(ctx, "_pt_input_shapes"): + ctx._pt_input_shapes = {} + ctx._pt_input_shapes.update({key: input_shapes}) + + if not hasattr(ctx, "_pt_saved_tensors_bdims_stack"): + ctx._pt_saved_tensors_bdims_stack = {} + ctx._pt_saved_tensors_bdims_stack.update( + {key: (wrapped_ctx._pt_saved_tensors_bdims)} + ) + + # See NOTE: [Why do we need to run setup_context under a vmap?] + restore_vmap( + inner, + (in_dims, out_dims), + batch_size, + randomness, + )(inputs, outputs) + + if not hasattr(ctx, "_pt_out_dims"): + ctx._pt_out_dims = {} + ctx._pt_out_dims.update({key: out_dims}) + + def jvp(ctx, *tangents): + key = id(Generated) + + def jvp_no_context(saved_tensors, tangents): + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.jvp(wrapped_ctx, *tangents) + + tangent_in_dims = get_tangents_in_dims(in_dims, tangents) + out_tangents, out_tangents_dims = restore_vmap( + jvp_no_context, + (ctx._pt_saved_tensors_bdims_stack[key], tangent_in_dims), + batch_size, + randomness, + )(ctx.saved_tensors, tangents) + + result = reductify( + out_tangents, out_tangents_dims, ctx._pt_out_dims[key], batch_size + ) + if isinstance(result, torch.Tensor): + return result, None + else: + return *result, None + + def backward(ctx, *grad_outputs): + key = id(Generated) + grad_outputs_ = grad_outputs[:-1] + grad_outputs_in_dims = ctx._pt_out_dims[key] + + if not isinstance(grad_outputs_in_dims, tuple): + grad_outputs_in_dims = (grad_outputs_in_dims,) + + grad_outputs_in_dims = tuple( + in_dim if grad_output is not None else None + for grad_output, in_dim in zip(grad_outputs_, grad_outputs_in_dims) + ) + + def backward_no_context(inputs): + saved_tensors, grad_outputs = inputs + wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors) + return autograd_function.backward(wrapped_ctx, *grad_outputs) + + grad_ins, grad_ins_dims = restore_vmap( + backward_no_context, + ((ctx._pt_saved_tensors_bdims_stack[key], grad_outputs_in_dims),), + batch_size, + randomness, + )((ctx.saved_tensors, grad_outputs_)) + result = reductify( + grad_ins, grad_ins_dims, in_dims, batch_size, ctx._pt_input_shapes[key] + ) + return result + + name = f"Vmapped{autograd_function.__name__}" + Generated = type( + name, + (torch.autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + "jvp": staticmethod(jvp), + "setup_context": staticmethod(setup_context), + "generate_vmap_rule": True, + }, + ) + + return Generated + + +# tangents might be None, so we need to replace +# the corresponding in_dims with None. +def get_tangents_in_dims(input_dims, tangents): + flat_in_dims, spec = pytree.tree_flatten(input_dims) + flat_tangents = pytree.arg_tree_leaves(*tangents) + result = [ + None if tangent is None else in_dim + for in_dim, tangent in zip(flat_in_dims, flat_tangents) + ] + return pytree.tree_unflatten(result, spec) + + +# NOTE: [Why do we need to run setup_context under a vmap?] 
+# Consider the following autograd.Function +# +# class Sum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.sum() +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# ctx.x_shape = inputs[0] +# @staticmethod +# def backward(ctx, gy): +# return gy.expand(ctx.x_shape) +# +# x = torch.randn(B, 4) +# in_dims = 0 +# vmap(Sum.apply, in_dims)(x) +# +# Let's assume for a moment that we didn't vmap setup_context in VmappedSum: +# +# class VmappedSum(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return vmap(Sum.forward, in_dims)(x) +# +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# Sum.setup_context(ctx, inputs, outputs) +# +# @staticmethod +# def backward(ctx, gy): +# def backward_no_context(gy): +# return gy.expand(ctx.x_shape) +# +# dims = (0,) +# gx = vmap(backward_no_context, dims)(gy) +# return gx +# +# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B], +# and we're doing: +# +# def backward_no_context(gy): +# return gy.expand([B, 4]) +# +# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]") +# +# This gives us the wrong result (gx has shape [B, B, 4], but it should +# have shape [4]). Performing vmap over setup_context means the shape +# saved has shape [4] and leads to a correct result shape for gx. + + +# Wraps a ctx object. Forwards all attr accesses to the underlying object +# except for the attrs in _pt_attrs +class WrappedCtx: + _pt_reserved_attrs: tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx") + + def __init__(self, ctx): + if not isinstance(ctx, WrappedCtx): + reserved_attrs = type(self)._pt_reserved_attrs + for name in reserved_attrs: + if not hasattr(ctx, name): + continue + raise RuntimeError( + f"PyTorch reserves the {reserved_attrs} field on ctx. " + "Please name your fields on ctx something else to avoid name " + "collision." + ) + self._pt_inner_ctx = ctx + + def __getattr__(self, name): + return getattr(self._pt_inner_ctx, name) + + def __setattr__(self, name, value): + if name in type(self)._pt_reserved_attrs: + self.__dict__[name] = value + return + return setattr(self._pt_inner_ctx, name, value) + + +# Wraps ctx to create a new ctx object that overrides saved_tensors. 
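+# Rough sketch of the intended behavior (names hypothetical):
+#   wrapped = CtxWithSavedTensors(ctx, unwrapped_tensors)
+#   wrapped.saved_tensors   # -> unwrapped_tensors, not ctx.saved_tensors
+#   wrapped.anything_else   # -> forwarded to the underlying ctx via WrappedCtx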
+class CtxWithSavedTensors(WrappedCtx): + _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs) + + def __init__(self, ctx, new_saved_tensors): + super().__init__(ctx) + self._pt_new_saved_tensors = new_saved_tensors + + @property + def saved_tensors(self): + return self._pt_new_saved_tensors + + +class CtxCustomSave(WrappedCtx): + _pt_reserved_attrs = ( + "_pt_saved_tensors_bdims", + "_pt_current_level", + *WrappedCtx._pt_reserved_attrs, + ) + + def __init__(self, ctx, current_level): + super().__init__(ctx) + self._pt_saved_tensors_bdims = () + self._pt_current_level = current_level + + def save_for_backward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_backward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + def save_for_forward(self, *tensors): + unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level) + self._pt_inner_ctx.save_for_forward(*unwrapped_tensors) + self._pt_saved_tensors_bdims = bdims + + +def reductify( + grad_input, + grad_input_bdim, + input_bdim, + batch_size, + target_shape_without_bdim_to_reduce_to=None, +): + if not isinstance(grad_input, tuple): + grad_input = (grad_input,) + if not isinstance(grad_input_bdim, tuple): + grad_input_bdim = (grad_input_bdim,) + if not isinstance(input_bdim, tuple): + input_bdim = (input_bdim,) + + if target_shape_without_bdim_to_reduce_to is None: + target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,) + result = tuple( + reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape) + for gi, gi_bdim, i_bdim, maybe_ishape in zip( + grad_input, + grad_input_bdim, + input_bdim, + target_shape_without_bdim_to_reduce_to, + ) + ) + return result + + +def reductify_leaf( + grad_input, + grad_input_bdim, + input_bdim, + batch_size, + target_shape_without_bdim_to_reduce_to=None, +): + if grad_input is None: + return None + + if grad_input_bdim is None and input_bdim is None: + return grad_input + + if grad_input_bdim is not None and input_bdim is None: + return grad_input.sum(grad_input_bdim) + + # NOTE: [Why can't we rely on autograd to reduce expanded gradients?] + # For reverse-mode AD, + # given a grad_input and input, it is valid for the user to return a + # grad_input that has a broadcasted shape when compared to the input. + # In this situation, autograd automatically reduces the grad_input to + # the shape of the input. + # + # However, when input_bdim is not None, we have problems. + # + # [example 1] + # grad_input: Tensor[3, 4], input: Tensor[B, 4] + # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # [example 2] + # grad_input: Tensor[3, B, 4], input: Tensor[B, 4] + # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable + # from [B, 4]. + # + # This means that we need to also reduce the grad_input to the shape of the + # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag; + # if not-None then we do the reducing manually, otherwise, we do not do a reduction. 
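+ # Sketch of the manual reduction below (shapes hypothetical): with
+ # input: Tensor[B, 4] (input_bdim=0) and grad_input: Tensor[B, 3, 4]
+ # (grad_input_bdim=0), passing target_shape_without_bdim_to_reduce_to=(4,)
+ # makes the vmapped sum_to_size reduce each per-sample grad [3, 4] -> [4].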
+ assert input_bdim is not None + + if grad_input_bdim is None: + grad_input = grad_input.unsqueeze(input_bdim) + new_shape = list(grad_input.shape) + new_shape[input_bdim] = batch_size + grad_input = grad_input.expand(new_shape) + grad_input_bdim = input_bdim + + if target_shape_without_bdim_to_reduce_to is not None: + return vmap( + torch.Tensor.sum_to_size, + in_dims=(grad_input_bdim, None), + out_dims=input_bdim, + )(grad_input, target_shape_without_bdim_to_reduce_to) + + if input_bdim != grad_input_bdim: + grad_input = grad_input.movedim(grad_input_bdim, input_bdim) + return grad_input + + +def autograd_function_forward_rewritten(original_forward, original_setup_context): + def new_forward(ctx, *args, **kwargs): + output = original_forward(*args, **kwargs) + original_setup_context(ctx, args, output) + return output + + return new_forward + + +class AutogradFunctionApply(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("autograd_function_apply") + + def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs): + saved_values = None + args_tensor_mask = fwd_kwargs["args_tensor_mask"] + non_differentiable_idx = fwd_kwargs["non_differentiable_idx"] + length_of_tensor_args = sum(args_tensor_mask) + # Filter out the original tensor args from fwd_args, + # lifted freevars should not be args of ApplyTemplate.apply + # since we don't need to calculate the gradients of them. + new_fwd_args = fwd_args[:length_of_tensor_args] + + class ApplyTemplate(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + nonlocal saved_values + output, saved_values = fwd(None, *fwd_args) + + # If users call ctx.mark_non_differentiable() in the original fwd function. + if len(non_differentiable_idx) > 0: + non_differentiable_output = [] + for i, x in enumerate(output): + if i in non_differentiable_idx: + non_differentiable_output.append(x) + ctx.mark_non_differentiable(*non_differentiable_output) + + return output + + @staticmethod + def backward(ctx, *grad): + return bwd(None, *grad, *saved_values) + + return ApplyTemplate.apply(*new_fwd_args) + + +autograd_function_apply = AutogradFunctionApply() diff --git a/phivenv/Lib/site-packages/torch/_functorch/batch_norm_replacement.py b/phivenv/Lib/site-packages/torch/_functorch/batch_norm_replacement.py new file mode 100644 index 0000000000000000000000000000000000000000..5a9a1b22920d3d19c00e7f7bad45cebb9207a642 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/batch_norm_replacement.py @@ -0,0 +1,27 @@ +import torch.nn as nn +from torch._functorch.utils import exposed_in + + +def batch_norm_without_running_stats(module: nn.Module) -> None: + if ( + isinstance(module, nn.modules.batchnorm._BatchNorm) + and module.track_running_stats + ): + module.running_mean = None + module.running_var = None + module.num_batches_tracked = None + module.track_running_stats = False + + +@exposed_in("torch.func") +def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module: + """ + In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and + setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root` + """ + # base case + batch_norm_without_running_stats(root) + + for obj in root.modules(): + batch_norm_without_running_stats(obj) + return root diff --git a/phivenv/Lib/site-packages/torch/_functorch/benchmark_utils.py b/phivenv/Lib/site-packages/torch/_functorch/benchmark_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..cb1de9397ba5969cab6c945569fc28a03e540fe1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/benchmark_utils.py @@ -0,0 +1,231 @@ +# mypy: ignore-errors + +import contextlib +import json +import operator +import os +import time + +import torch +from torch.profiler import profile, ProfilerActivity + + +def synchronize(): + pass + + +def dump_chrome_trace( + f, + input, + trace_filename, + optimize_ctx, + activities, + num_runs=1, + devices=None, + kwargs_for_f=None, + kwargs_for_profiler=None, +): + """ + Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx] + [num_runs] times to [trace_filename]. + + [activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA. + Return total runtime without the profiler + + Outputs to trace_filename + """ + + if devices is None: + devices = ["cuda"] + + global synchronize + if devices != ["cpu"] and torch.cuda.is_available(): + synchronize = torch.cuda.synchronize + + if kwargs_for_f is None: + kwargs_for_f = {} + if kwargs_for_profiler is None: + kwargs_for_profiler = {} + + with optimize_ctx: + torch.manual_seed(1337) + for _ in range(5): # warmup runs + f(input, **kwargs_for_f) + synchronize() + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + t1 = time.perf_counter() + timing = t1 - t0 + + with profile(activities=activities, **kwargs_for_profiler) as prof: + with optimize_ctx: + synchronize() + torch.manual_seed(1337) + for _ in range(num_runs): + f(input, **kwargs_for_f) + synchronize() + prof.export_chrome_trace(trace_filename) + + return timing + + +def get_chrome_trace_events(filename): + f = open(filename) + data = json.load(f) + events = data["traceEvents"] + return events + + +def is_gpu_compute_event(event): + global gpu_pids + return ( + "pid" in event + and event["pid"] in gpu_pids + and "ph" in event + and event["ph"] == "X" + ) + + +def get_sorted_gpu_events(events): + sorted_gpu_events = [] + for event in events: + if not is_gpu_compute_event(event): + continue + sorted_gpu_events.append(event) + return sorted(sorted_gpu_events, key=operator.itemgetter("ts")) + + +def get_duration(sorted_gpu_events): + if len(sorted_gpu_events) == 0: + return 0 + event = sorted_gpu_events[0] + current_end_time = event["ts"] + event["dur"] + total_duration = event["dur"] + for event in sorted_gpu_events[1:]: + start_time = max(event["ts"], current_end_time) + end_time = event["ts"] + event["dur"] + total_duration = total_duration + max(end_time - start_time, 0) + current_end_time = max(current_end_time, end_time) + return total_duration + + +def get_sorted_gpu_mm_conv_events(events): + def is_mm_conv_event(event): + return "name" in event and ( + "gemm" in event["name"] + or "conv" in event["name"] + or "cutlass" in event["name"] + or "wgrad" in event["name"] + ) + + gpu_events = get_sorted_gpu_events(events) + sorted_events = [] + for event in gpu_events: + if not is_mm_conv_event(event): + continue + sorted_events.append(event) + return sorted_events + + +gpu_pids = [] + + +def compute_utilization(filename: str, total_length: float): + """ + Process the chrome traces outputs by the pytorch profiler to compute GPU Utilization + and percent of times spent on matmul and convolution + + Args: + filename(str): Name of chrome traces file produced by pytorch profiler + + total_length(float): total length of the process without profiler in second + + Return: + tuple: (GPU 
Utilization, percent of time spent on matmul and convolution) + """ + events = get_chrome_trace_events(filename) + + # get pids of GPU events + global gpu_pids + gpu_pids = [] + for event in events: + if "name" not in event: + continue + if event["name"] == "process_labels" and "GPU" in event["args"]["labels"]: + gpu_pids.append(event["pid"]) + + total_length = total_length * 1e6 + sorted_gpu_events = get_sorted_gpu_events(events) + utilization = get_duration(sorted_gpu_events) / total_length + + sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events) + mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length + + return utilization, mm_conv_utilization + + +def benchmark_utilization( + f, + input, + trace_folder, + optimize_ctx=None, + trace_file_name="tmp_chrome_trace", + num_runs=1, +): + """ + Benchmark the GPU Utilization and percent of time spent on matmul and convolution operations of + running f(input, **kwargs_for_f) with [optimize_ctx] [num_runs] times. + It will produce a chrome trace file in trace_folder/trace_file_name.json + + Example: + + ``` + def f(a): + return a.sum() + a = torch.rand(2**20, device="cuda") + utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") + ``` + + Args: + f: function to benchmark + + input: input to :attr:`f` + + trace_folder: name of the folder to store the chrome trace + + optimize_ctx: the context in which f will run + + trace_file_name: name of the dumped chrome trace file, default to "tmp_chrome_trace" + + num_runs: number of times to run f, excluding the warm-up runs, default to 1. + + Return: + tuple: (GPU Utilization, percent of time spent on matmul and convolution) + + """ + isExist = os.path.exists(trace_folder) + if not isExist: + os.makedirs(trace_folder) + print("create folder " + trace_folder) + + if optimize_ctx is None: + optimize_ctx = contextlib.nullcontext() + + chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json") + total_length = dump_chrome_trace( + f, + input, + chrome_trace_file_name, + optimize_ctx, + [ProfilerActivity.CUDA], + num_runs=num_runs, + devices=["cuda"], + ) + utilization, mm_conv_utilization = compute_utilization( + chrome_trace_file_name, total_length + ) + + return utilization, mm_conv_utilization diff --git a/phivenv/Lib/site-packages/torch/_functorch/compile_utils.py b/phivenv/Lib/site-packages/torch/_functorch/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..517bc69323b724cd6f35614750dd730269ad5b9b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/compile_utils.py @@ -0,0 +1,212 @@ +# mypy: ignore-errors + + +import operator +from typing import Callable + +import sympy + +import torch +import torch.fx as fx +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_flatten + + +aten = torch.ops.aten + + +def get_aten_target(node: fx.Node) -> Callable: + if hasattr(node.target, "overloadpacket"): + return node.target.overloadpacket + return node.target + + +rand_ops = [ + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +] + + +# return a new copy of torch.fx.graph.Graph with CSE applied to 
the input graph +def fx_graph_cse(fx_g: torch.fx.graph.Graph): + new_graph = fx.Graph() + env = {} # map from node in the old graph to node in the new graph + hash_env = {} # map from hash to a node in the new graph + token_map = {} # map from hash to token + + from torch._inductor.pattern_matcher import ( + compute_mutation_region_ids, + same_mutation_regions, + ) + + compute_mutation_region_ids(fx_g) # type: ignore[arg-type] + + # Make a set of separate storages returned from the output, which will be preserved + # when pruning. This prevents us from deduplicating returned tensors which have + # experienced identical operations, but are separate data structures in eager mode. + output_node: fx.Node = list(fx_g.nodes)[-1] + assert output_node.op == "output" + + def checkable_node(node: fx.Node) -> bool: + """We can evaluate only nodes that represent tensors with defined storage.""" + if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor): + return False + + try: + node.meta["val"].untyped_storage() + except NotImplementedError: + return False + + return True + + output_storages = { + StorageWeakRef(n.meta["val"].untyped_storage()) + for n in output_node.all_input_nodes + if checkable_node(n) + } + nodes_that_alias_outputs = { + n + for n in fx_g.nodes + if checkable_node(n) + and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages + } + + for n in fx_g.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in rand_ops + # aten.empty is non-deterministic, so don't CSE it. + # Also, aten.empty is almost always fusible into its consumer, + # so it's not worth CSEing. + or get_aten_target(n) is aten.empty + or n in nodes_that_alias_outputs + # This CSE pass currently doesn't handle re-propogation of unbacked + # meta where it'll sometimes eliminate a _local_scalar_dense but not + # replace the meta of downstream users. eg. one bug we've seen is: + # + # _local_scalar_dense_11: "Sym(u14)" = torch.ops.aten._local_scalar_dense.default(select_10); + # sym_sum_2: "Sym(u19 + u20 + u21)" = torch.sym_sum((_local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13)) # noqa: B950 + # + # Notice how _local_scalar_dense_11 is u14 but sym_sum_2's meta is incorrectly the old + # pre-cse value of u19. 
+ or ( + "val" in n.meta + and isinstance(n.meta["val"], sympy.Symbol) + and free_unbacked_symbols(n.meta["val"]) + ) + ): + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, torch.fx.node.Node) and v in env: + arg_list[i] = env[v] + if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)): + arg_list[i] = v.node + return tuple(arg_list), spec + + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } + + # hash substituted args to a number, do not hash specs because specs are not hashable + # We need to add type into hash to avoid situations like: + # hash((primals_2, 1.0)) == hash((primals_2, 1)) + hash_arg = hash( + (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs)) + ) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + overwrite_due_to_mutation = False + if hash_val_in_hash_env and token_map[hash_val] == token: + duplicate_n_prev = hash_env[hash_val] + if same_mutation_regions(n, duplicate_n_prev): + env[n] = duplicate_n_prev + continue + else: + # any futures duplicates should replace with n, not duplicate_n_prev + overwrite_due_to_mutation = True + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if overwrite_due_to_mutation or not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + return new_graph + + +def raise_getitems(gm: fx.GraphModule) -> fx.GraphModule: + # Pre-create a list of nodes to iterate over, as modifying the node order + # during the loop can lead to infinite loops if not handled properly. + getitem_nodes = list( + gm.graph.find_nodes(op="call_function", target=operator.getitem) + ) + + # loop through getitem nodes in the graph and raise them to the parent node + # in reverse order to perserve their original relative order + for node in reversed(getitem_nodes): + assert len(node.all_input_nodes) == 1 + parent = node.all_input_nodes[0] + parent.append(node) + + gm.recompile() + return gm + + +def strip_overloads(gm): + """ + Modifies the target of graph nodes in :attr:`gm` to strip overloads. 
+ + Args: + gm(fx.GraphModule): The input Fx graph module to be modified + """ + for node in gm.graph.nodes: + if isinstance(node.target, torch._ops.OpOverload): + node.target = node.target.overloadpacket + gm.recompile() + + +def get_placeholders(graph): + return graph.find_nodes(op="placeholder") + + +def get_outputs(graph): + for node in graph.find_nodes(op="output"): + return pytree.tree_leaves(node.args[0]) + raise AssertionError("No output node found") diff --git a/phivenv/Lib/site-packages/torch/_functorch/compilers.py b/phivenv/Lib/site-packages/torch/_functorch/compilers.py new file mode 100644 index 0000000000000000000000000000000000000000..b7aee97bac643dda993796e3f06127098d98fca2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/compilers.py @@ -0,0 +1,445 @@ +# mypy: ignore-errors + +import copy +import logging +import os +import pickle +import random +from contextlib import contextmanager +from functools import partial +from typing import Callable, Union + +import sympy + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.utils._pytree as pytree +from torch import SymInt +from torch._decomp import get_decompositions +from torch.fx.experimental.symbolic_shapes import bind_symbols + +from .aot_autograd import aot_function, aot_module, make_boxed_compiler +from .compile_utils import strip_overloads +from .partitioners import ( + default_partition, + draw_graph, + min_cut_rematerialization_partition, +) + + +log = logging.getLogger(__name__) + + +# These canonicalizations are needed here (and not decompositions), as the ops +# we're trying to canonicalize to CompositeImplicitAutograd. +def _canonicalize(fx_g): + for node in fx_g.graph.find_nodes( + op="call_function", target=torch.ops.aten._to_copy + ): + node.target = torch.ops.aten.to + fx_g.recompile() + return fx_g + + +@contextmanager +def _disable_jit_autocast(): + old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) + try: + yield + finally: + torch._C._jit_set_autocast_mode(old_jit_autocast_flag) + + +@make_boxed_compiler +def ts_compile(fx_g: fx.GraphModule, inps) -> Callable: + """ + Compiles the :attr:`fx_g` with Torchscript compiler. + + .. warning:: + This API is experimental and likely to change. + + Args: + fx_g(fx.GraphModule): The input Fx graph module to be compiled. + + Returns: + Torch scripted model. + """ + + with _disable_jit_autocast(): + strip_overloads(fx_g) + + for node in fx_g.graph.find_nodes( + op="call_function", target=torch.ops.aten._to_copy + ): + if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs: + node.target = torch.ops.aten.to + + for node in fx_g.graph.nodes: + new_kwargs = {} + for k, v in node.kwargs.items(): + if isinstance(v, torch.device): + v = v.type + new_kwargs[k] = v + node.kwargs = new_kwargs + + fx_g.graph.lint() + + fx_g.recompile() + + f = torch.jit.script(fx_g) + + torch._C._jit_pass_remove_mutation(f.graph) + + f = torch.jit.freeze(f.eval()) + f = torch.jit.optimize_for_inference(f) + if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps): + f(*inps) + return f + + +def _draw_graph_compile(fx_g, _, name, clear_meta=True): + print(fx_g.code) + draw_graph(fx_g, name, clear_meta=clear_meta) + return fx_g + + +def draw_graph_compile(name): + return make_boxed_compiler(partial(_draw_graph_compile, name=name)) + + +@make_boxed_compiler +def nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns the :attr:`fx_g` Fx graph module as it is. 
This is a no-op compiler + and can be used to check accuracy. + + .. warning:: + This API is experimental and likely to change. + + """ + return fx_g + + +class DebugInterpreter(fx.Interpreter): + def run(self, *args): + self.symbol_mapping = bind_symbols(self.module, *args) + super().run(*args) + + def run_node(self, n): + def subst_symint(ni): + if not isinstance(ni, SymInt): + return ni + r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping)) + assert r.is_number, r + return int(r) + + def subst_symint_tuple(nis): + return tuple(subst_symint(ni) for ni in nis) + + def check_significant_strides(a, b): + if subst_symint(a.numel()) > 0: + for idx in range(a.ndim): + if ( + subst_symint(a.stride(idx)) != b.stride(idx) + and subst_symint(a.size(idx)) > 1 + ): + return False + return True + + def check(nv, rv, desc): + assert callable(desc) + assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" + assert ( + subst_symint_tuple(nv.size()) == rv.size() + ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + same_strides = check_significant_strides(nv, rv) + assert ( + same_strides + ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" + + r = super().run_node(n) + if "val" in n.meta: + n_vals, _n_spec = pytree.tree_flatten(n.meta["val"]) + r_vals, _r_spec = pytree.tree_flatten(r) + # TODO: There is some sort of problem where we record that an + # operator returned a tuple/list, and then later it turns out the + # real version of the operator returned a list/tuple. Need to + # figure out what's actually going on here, the error itself is + # harmless enough as we only getitem out the outputs. + # assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}") + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) + """ + return DebugInterpreter(fx_g).run + + +@make_boxed_compiler +def simple_ts_compile(fx_g, _): + strip_overloads(fx_g) + f = torch.jit.script(fx_g) + f = torch.jit.freeze(f.eval()) + return f + + +def nnc_jit(f): + return aot_function(f, simple_ts_compile) + + +aten = torch.ops.aten +default_decompositions = { + aten.detach, + aten.gelu_backward, + aten.leaky_relu_backward, + aten.sigmoid_backward, + aten.threshold_backward, + aten.hardtanh_backward, + aten.hardsigmoid_backward, + aten.hardswish_backward, + aten.tanh_backward, + aten.silu_backward, + aten.elu_backward, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.masked_fill.Scalar, + aten.masked_fill.Tensor, + aten.elu, + aten.leaky_relu, + aten.hardtanh, + aten.hardswish, + aten.hardsigmoid, + aten.conj_physical, + aten.is_same_size, +} + +default_decompositions = get_decompositions(default_decompositions) + + +@make_boxed_compiler +def print_compile(fx_g, _): + print(fx_g.code) + return fx_g + + +def memory_efficient_fusion( + fn: Union[Callable, nn.Module], + **kwargs, +): + """ + Wrapper function over :func:`aot_function` and :func:`aot_module` to perform + memory efficient fusion. It uses the + :func:`min_cut_rematerialization_partition` partitioner to perform efficient + recomputation. 
It uses NVFuser to compile the generated forward and backward + graphs. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` + that takes one ore more arguments. Must return one or more Tensors. + **kwargs: Any other overrides you want to make to the settings + + Returns: + Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior + of the original :attr:`fn`, but whose forward and backward graphs have + gone through recomputation optimizations, and the graphs have been + compiled with nvfuser. + + """ + config = { + "fw_compiler": ts_compile, + "bw_compiler": ts_compile, + "partition_fn": min_cut_rematerialization_partition, + "decompositions": default_decompositions, + } + config.update(kwargs) + if isinstance(fn, torch.nn.Module): + return aot_module(fn, **config) + else: + return aot_function(fn, **config) + + +def debug_compile(fx_g, inps): + fx_g.to_folder("foo") + print( + f""" +############################################################## +# To minimize FX graph, copy and paste the below and run it # +############################################################## + +import torch +import torch.fx as fx +from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess + +inps = {[(i.shape, i.dtype) for i in inps]} +inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] +from foo import FxModule +mod = FxModule().cuda() + +with torch.jit.fuser("fuser2"): + # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess + minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess) +""" + ) + from foo import FxModule + + FxModule().cuda()(*inps) + + return ts_compile(fx_g, inps) + + +graph_index = 0 + + +def get_inputs(input_data_path): + """ + Return a random input for the given inputs meta generated from _save_fx_default. + """ + inputs = [] + with open(input_data_path, "rb") as f: + inputs_meta = pickle.load(f) + inputs = [] + for meta in inputs_meta: + if len(meta) == 1: + type = meta + input = type(random.rand()) + else: + type, shape, _stride, dtype, device = meta + if dtype in { + torch.int, + torch.int32, + torch.int64, + torch.bool, + torch.int, + torch.uint8, + int, + float, + }: + input = torch.randint(0, 1, shape, dtype=dtype, device=device) + else: + input = torch.rand(shape, dtype=dtype, device=device) + inputs.append(input) + return inputs + + +def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs): + """ + The forward, backward, and joint computation graph will be stored in + {folder_name}/{current_name}/{current_name}_forward_{graph_index}, + {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and + {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively. + The input shape of the graphs will be stored in the .input files. + These files can be loaded with pickle, + and is a list of format (type, shape, stride, dtype, device). + In the case of type = int or float, it is just (type,). + For joint graph input, it is a nested list [[],[]] + where the two inner lists have the same format. + If dump_example_input is True, example_inputs will be stored in .pt file. 
+ Since each function might produce multiple graphs, + the graph_index is used to distinguish difference graphs + """ + from functorch.compile import aot_module_simplified + + def get_input_meta(args): + input_meta = [] + if len(args) > 0 and isinstance(args[0], tuple): # joint input + input_meta += get_input_meta(args[0]) + input_meta += get_input_meta(args[1]) + return input_meta + for arg in args: + if type(arg) == int or type(arg) == float: + input_meta.append((type(arg),)) + else: + input_meta.append( + (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device) + ) + return input_meta + + def graph_saver_helper(gm_to_save, args, type_name): + global graph_index + if len(gm_to_save.graph.nodes) == 0: + log.log( + logging.WARNING, + "No nodes in graph {%s}_{%s}_{%s}.", + current_name, + type_name, + graph_index, + ) + return + + gm = copy.deepcopy(gm_to_save) + gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen + gm.recompile() + + input_meta = get_input_meta(args) + + os.makedirs(f"{folder_name}/{current_name}", exist_ok=True) + gm.to_folder( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" + ) + pickle.dump( + input_meta, + open( + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950 + "wb", + ), + ) # noqa: E501 + if dump_example_input: + torch.save( + args, + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 + ) # noqa: E501 + + def graph_saver_forward(gm, fw_args): + graph_saver_helper(gm, fw_args, "forward") + return gm + + def graph_saver_backward(gm, bw_args): + graph_saver_helper(gm, bw_args, "backward") + global graph_index + graph_index += 1 + return gm + + def graph_saver_joint(gm, joint_args): + graph_saver_helper(gm, joint_args, "joint") + return default_partition(gm, joint_args) + + return aot_module_simplified( + gm, + example_inputs, + fw_compiler=graph_saver_forward, + bw_compiler=graph_saver_backward, + partition_fn=graph_saver_joint, + decompositions=default_decompositions, + ) + + +# WARNING: This isn't tested anywhere!! +def graph_dumper_aot(current_name, folder_name, dump_example_input=False): + """ + Dump the forward, backward, and joint computation graph. + Example Usage: + save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False) + optimize_ctx = torchdynamo.optimize( + save_fx_func + ) + with torch.enable_grad(): + with optimize_ctx: + result = forward_and_backward_pass(model, example_inputs) + """ + global graph_index + graph_index = 0 + return partial(_save_fx_default, current_name, folder_name, dump_example_input) diff --git a/phivenv/Lib/site-packages/torch/_functorch/config.py b/phivenv/Lib/site-packages/torch/_functorch/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2eefed94b50092e338565f952d021982888959b5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/config.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Global flags for aot autograd +""" +import os +import sys +from typing import Literal, Optional, TYPE_CHECKING + +from torch.utils._config_module import Config, install_config_module + + +# Converts torch rng ops to their functional philox rng equivalents. 
Note that +# we functionalize only CUDA rng ops today. +functionalize_rng_ops = False + +# can be useful for debugging if we are incorrectly creating meta fake tensors +fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0" + +# Enables optional asserts in hotpath code to check for errors. If +# you are seeing weird accuracy problems, try turning this on. +# This is currently off by default as it will harm tracing time, +# but it is on by default for aot_eager. +debug_assert = False + +debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0" + +# See # NOTE [Export custom triton op] +decompose_custom_triton_ops = True + +static_weight_shapes = True + +# See https://github.com/pytorch/pytorch/issues/141881 +# Tells partitioner that parameters are free to save for backward. +treat_parameters_as_free_to_save = True + +# Applies CSE to the graph before partitioning +cse = True + +from torch._environment import is_fbcode + + +enable_autograd_cache: bool = Config( + justknob="pytorch/remote_cache:enable_local_autograd_cache", + env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE", + default=True, +) + +autograd_cache_allow_custom_autograd_functions: bool = Config( + env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD", default=False +) + +# For now, this is just for enabling unit testing in test_aot_autograd_cache.py +# We will either make this the default with AOTAutogradCache, or +# we'll just use it in the precompile flow. So there's no +# need to add env vars or make it configurable +bundled_autograd_cache: bool = False + + +def remote_autograd_cache_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "0": + return False + return None + + +enable_remote_autograd_cache = remote_autograd_cache_default() + + +# When AOTAutograd regenerates aliased graph outputs, +# attempt to use functionalization's view-replay logic +# before falling back to the autograd engine's view replay or as_strided. +# This can have some perf implications +# (although for many models this will not matter). +# (1) If you have many view ops chained together, replaying all of them +# at runtime can have more overhead compared to a single as_strided call +# (2) If you are doing training, AsStridedBackward is quite slow, +# and the individual view op backward formulas will likely be faster. +# (3) Some backends like XLA do not support as_strided + +# Temporary hack: disable this flag for internal +# (needed to fix an internal issue while avoiding bumping XLA pin) +# eventually: either default this config to false completely +# once XLA pin update works, +# or default config to true and fix relevant bugs + + +# View replay is currently not compatible with AOTAutogradCache, since +# FunctionalTensors are not serializable. We'll need to make them +# serializable before enabling warm cache with this config turned on. +view_replay_for_aliased_outputs = not is_fbcode() + +# Restricts the amount of computation AOTAutograd can do. +# NB: We have essentially disabled this heuristic now. However, this is kept +# here for now in case it's useful. Setting it low can artificially reduce the +# amount of recomputation AOTAutograd performs, although not in any kind of +# principled way. 
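+#
+# A minimal usage sketch, assuming only the `patch` helper that
+# `install_config_module` attaches to this module (see the bottom of this
+# file); settings such as the one below can be overridden temporarily, e.g.:
+#
+#   import torch._functorch.config as functorch_config
+#   with functorch_config.patch(max_dist_from_bw=500):
+#       ...  # trace/compile with the heuristic temporarily tightened
+#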
+max_dist_from_bw = 1000 + + +# Bans recomputation of nodes that are reading from nodes that is far before +# the current node +ban_recompute_used_far_apart = True +# Breaks up long chain of fusible ops, as otherwise we can have an arbitrarily +# long chain of recomputation in the backwards pass. +ban_recompute_long_fusible_chains = True +# Bans recomputation of nodes that must be materialized in the backwards pass +# (used by a non-fusible node) +ban_recompute_materialized_backward = True +# Chooses to ban recomputation of nodes based off an allowlist. Setting it to +# False changes it to use a denylist. Main change is on operators like +# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't +# that expensive +ban_recompute_not_in_allowlist = True +# Chooses to ban recomputation of reductions. This is generally a good idea, as +# the result of reductions is generally very small but recomputing reductions in +# a fusion can be expensive. +ban_recompute_reductions = True +# Prevents the partitioner from ever saving views (i.e. always recompute them). +# Generally a good idea since views are free to recompute. +recompute_views = False + +# By default, the partitioner is purely trying to optimize for runtime (although +# it should always use less memory than eager) +# This knob controls the partitioner to make that tradeoff for you, choosing the +# fastest option that saves less activations than the memory budget. +# Specifically, 0.0 corresponds to the activation memory from applying +# activation checkpointing to the full compiled region, and 1.0 corresponds to +# the activation memory from the default runtime-optimized strategy. So, 0.4 +# would result in a strategy that saves 40% of the activations compared to the +# default strategy. +# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below +# the activation memory budget. +# NOTE: This *cannot* be treated as +activation_memory_budget = 1.0 + +# This controls how we estimate the runtime when deciding what the cheapest +# operators to recompute are. The 3 options are +# "flops": Bases it off of the flop count provided by torch.utils.flop_counter +# "profile": Benchmarks each operator to come up with a runtime +# "testing": Returns 1 for everything +activation_memory_budget_runtime_estimator = "flops" + +# This controls the solver used for the 0-1 knapsack. By default we use a +# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" +# (which has a scipy dependency). +activation_memory_budget_solver = "dp" + +# This dumps out a SVG visualization of the expected runtime vs. activation +# memory tradeoffs for all memory budget values from 0 to 1 in increments of +# 0.5. See an example here: +# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015 +visualize_memory_budget_pareto = ( + os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1" +) + +# This controls the directory in which to dump the SVG plot with the pareto +# frontier of the activation checkpointing memory-vs-runtime tradeoffs. +memory_budget_pareto_dir = os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO_DIR") + +# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions +# Generally, this will probably result in some memory improvement, but at the +# cost of some performance +aggressive_recomputation = False + +# If FakeTensor.data_ptr() should error. 
+# This option is independent of AOTAutograd and torch.compile, but our policy +# is to turn it off during torch.compile. +fake_tensor_allow_unsafe_data_ptr_access = True + +# Unlifts effect tokens from the inputs/outputs in the traced graph and instead +# inserts make_token/sink_token calls in the graph to create tokens and then +# sink them at the end. Note that this means the graph is no longer functional +# which may lead to silent errors unless the backend knows how to handle the +# tokens. +unlift_effect_tokens = False + +# NOTE: [The default layout constraint for custom operators.] +# This must be the name of one of the layout constraint tags +# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), +# If the custom op does not have a layout constraint tag already +# then we assume the following applies. +# +# This config is respected by Inductor and we recommend other backends also +# respect it. +# This config is in torch._functorch and not torch._inductor because it affects +# ProxyTensor tracing. +custom_op_default_layout_constraint: Literal[ + "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout" +] = "needs_exact_strides" + + +# Run aot eager decomp partition with CrossRefFakeMode +# options = False, "all", "custom_ops" +fake_tensor_crossref = False + +# This mode specifies that we should also keep track of the real +# tensor along with the fake tensor, and do real compute. While +# seemingly this eliminates the whole point of fake tensors, there are +# two obvious use cases for it: +# +# 1. When users call item()/other data dependent operations, +# if we propagate_real_tensors we are able to determine what +# the true value is and keep going. +# +# 2. It can be useful for testing, when you want to see if the fake +# and real tensors agree with each other. (Note that there are +# currently known inaccuracies in how we clone real tensors, that +# would have to be tightened up for this to be useful in this +# case.) +# +# Note that fake tensors are typically understood to be cheap to store +# indefinitely, so we tend to hold on to them longer than we would +# hold onto the real tensors. So we also support you explicitly +# deallocating the real tensor associated with a fake tensor, at which +# point we will stop propagating real tensors. +# +# One more thing: when you provide a real tensor to fakeify, we will +# clone it, so that we can safely perform mutations on it if necessary. +# This will increase live memory usage. This could potentially be +# optimized by using COW. We also currently do not faithfully +# maintain autograd metadata on the real tensor; this is fine because +# AOTAutograd will only use the fake tensor to determine leafness/etc +# of tensors in question. +fake_tensor_propagate_real_tensors = False + +# This controls whether we collect donated buffer. This flag must be set +# False if a user wants to retain_graph=True for backward. +donated_buffer = False if is_fbcode() else True + +# Controls the default graph output format used by draw_graph +# Supported formats are defined here https://graphviz.org/docs/outputs/ +torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") + +# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real +# kernel mismatch is detected, bypasses by making a fake kernel from the +# real tensor outputs. +generate_fake_kernels_from_real_mismatches = False + +# CUDAGraph save run_with_rng functionalization. 
+# TODO: turn on by default +graphsafe_rng_functionalization = True + + +# Error on BypassAOTAutogradCache instead of just a warning +# Used for tests +strict_autograd_cache = False + +# Note [Recomputing collectives in the partitioner] +# The purpose of this config is as follows: +# - We have many passes in the compiler (min-cut partitioning, DCE, etc) +# which can reorder or ,delete duplicate nodes in the graph +# - If any of these passes reorder/delete/duplicate a collective +# in a setting where the compiler is being run independently on multiple +# ranks, we run the risk that the compiler will make a different decison on +# different ranks, resulting in a NCCL hang when using torch.compile +# To handle this, we will (by default) ensure that collectives are not modified +# by the compiler. +# +# A few examples: +# - don't dead-code-eliminate collectives +# (in case they are dead on rank i but not rank j) +# - don't recompute collectives in partitioning +# (in case we recompute on rank i but not rank j) +# +# Today this flag **must** be set to false, but eventually +# we want the option to set it to true. +# In order to potentially optimize collectives, we'll need the compiler +# to broadcast information across ranks at compile time to ensure +# that any decisions on collectives are made consistently. +unsafe_allow_optimization_of_collectives = False + +# See Note [AOTAutograd Tangent Subclassness for mutated inputs] +# TODO(ivankobzarev): Remove this config, being able to deduce it compile time. +disable_guess_zero_tangent_for_mutated_input_subclass = False + +# See Note [Tangents memory format] +# By default tangents strideness is guessed to be contiguous, +# At runtime non contiguous tangents will be coerced to be contiguous. +# This config changes this guess for tangents strides to be the same as outputs. +# TODO(ivankobzarev): Remove this config once extra memory usage is investigated. +guess_tangent_strides_as_outputs = False + +# This is a temporary config to ensure all ranks take the same decision in the partitioner +# it will untimately be removed once we share size_hints across ranks through compiler collectives +_broadcast_rank0_decision = False + +# By default apply inlined saved_tensors_hooks only for "donated" buffers. +# "donated" buffers are invisible to the user, they are intermediates of the forward graph. +# Applying saved tensors hooks for memory optimizations only for intermediates +# guarantees that original saved tensors could be deallocated. +# This config enables saved_tensors_hooks are applied for **all** saved tensors, +# that could include inputs, parameters, outputs. +# "donated" - applied only to saved intermediates of the graph +# "no_static" - applied to all saved but not "static" +# (this includes parameters and user marked as static) +# "all" - no filtering, everything saved for backward. +saved_tensors_hooks_filtering_mode = "donated" + + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + +# adds patch, save_config, invalid config checks, etc +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/_functorch/deprecated.py b/phivenv/Lib/site-packages/torch/_functorch/deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..fea6a97d781d67b3fe058d43897e4136117421c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/deprecated.py @@ -0,0 +1,172 @@ +# mypy: allow-untyped-defs +""" +The APIs in this file are exposed as `functorch.*`. 
They are thin wrappers +around the torch.func.* APIs that have deprecation warnings -- we're trying +to move people to the torch.func.* equivalents. + +NB: We don't use *args, **kwargs in the signatures because that changes the +documentation. +""" + +import textwrap +import warnings +from typing import Any, Callable, Optional, Union + +import torch._functorch.apis as apis +import torch._functorch.eager_transforms as _impl +import torch._functorch.make_functional as _nn_impl +import torch.nn as nn +from torch._functorch.eager_transforms import argnums_t +from torch._functorch.vmap import in_dims_t, out_dims_t + + +def get_warning(api, new_api=None, replace_newlines=False): + if new_api is None: + new_api = f"torch.func.{api}" + warning = ( + f"We've integrated functorch into PyTorch. As the final step of the \n" + f"integration, `functorch.{api}` is deprecated as of PyTorch \n" + f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" + f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" + f"and/or the `torch.func` migration guide for more details \n" + f"https://pytorch.org/docs/main/func.migrating.html" + ) + if replace_newlines: + warning = warning.replace("\n", "") + return warning + + +def warn_deprecated(api, new_api=None): + warning = get_warning(api, new_api, replace_newlines=True) + warnings.warn(warning, FutureWarning, stacklevel=3) + + +def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): + api_name = functorch_api.__name__ + if torch_func_api is None: + torch_func_api = getattr(_impl, api_name) + # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO + if torch_func_api.__doc__ is None: + return + + warning = get_warning(api_name, new_api_name) + warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ") + warning_note = textwrap.indent(warning_note, " ") + functorch_api.__doc__ = torch_func_api.__doc__ + warning_note + + +def vmap( + func: Callable, + in_dims: in_dims_t = 0, + out_dims: out_dims_t = 0, + randomness: str = "error", + *, + chunk_size=None, +) -> Callable: + warn_deprecated("vmap", "torch.vmap") + return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size) + + +def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: + warn_deprecated("grad") + return apis.grad(func, argnums, has_aux) + + +def grad_and_value( + func: Callable, argnums: argnums_t = 0, has_aux: bool = False +) -> Callable: + warn_deprecated("grad_and_value") + return apis.grad_and_value(func, argnums, has_aux) + + +def vjp(func: Callable, *primals, has_aux: bool = False): + warn_deprecated("vjp") + return _impl.vjp(func, *primals, has_aux=has_aux) + + +def jvp( + func: Callable, + primals: Any, + tangents: Any, + *, + strict: bool = False, + has_aux: bool = False, +): + warn_deprecated("jvp") + return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux) + + +def jacrev( + func: Callable, + argnums: Union[int, tuple[int]] = 0, + *, + has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False, +): + warn_deprecated("jacrev") + return _impl.jacrev( + func, + argnums, + has_aux=has_aux, + chunk_size=chunk_size, + _preallocate_and_copy=_preallocate_and_copy, + ) + + +def jacfwd( + func: Callable, + argnums: argnums_t = 0, + has_aux: bool = False, + *, + randomness: str = "error", +): + warn_deprecated("jacfwd") + return _impl.jacfwd(func, argnums, has_aux, randomness=randomness) + + +def hessian(func, argnums=0): + warn_deprecated("hessian") + return 
_impl.hessian(func, argnums=argnums) + + +def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: + warn_deprecated("functionalize") + return _impl.functionalize(func, remove=remove) + + +def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): + warn_deprecated("make_functional", "torch.func.functional_call") + return _nn_impl.make_functional(model, disable_autograd_tracking) + + +def make_functional_with_buffers( + model: nn.Module, disable_autograd_tracking: bool = False +): + warn_deprecated("make_functional_with_buffers", "torch.func.functional_call") + return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking) + + +def combine_state_for_ensemble(models): + warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state") + return _nn_impl.combine_state_for_ensemble(models) + + +setup_docs(vmap, apis.vmap, "torch.vmap") +setup_docs(grad, apis.grad) +setup_docs(grad_and_value, apis.grad_and_value) +setup_docs(vjp) +setup_docs(jvp) +setup_docs(jacrev) +setup_docs(jacfwd) +setup_docs(hessian) +setup_docs(functionalize) +setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call") +setup_docs( + make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call" +) +setup_docs( + combine_state_for_ensemble, + _nn_impl.combine_state_for_ensemble, + "torch.func.stack_module_state", +) diff --git a/phivenv/Lib/site-packages/torch/_functorch/eager_transforms.py b/phivenv/Lib/site-packages/torch/_functorch/eager_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..de2b42b6a6db414c602547e924b2152749b72f51 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/eager_transforms.py @@ -0,0 +1,1817 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import contextlib +from functools import partial, wraps +from typing import Any, Callable, Optional, Union + +import torch +import torch.autograd.forward_ad as fwAD +from torch._C._functorch import ( + _assert_wrapped_functional, + _func_decrement_nesting, + _func_increment_nesting, + _grad_decrement_nesting, + _grad_increment_nesting, + _jvp_decrement_nesting, + _jvp_increment_nesting, + _propagate_functional_input_mutation, + _unwrap_for_grad, + _unwrap_functional_tensor, + _wrap_for_grad, + _wrap_functional_tensor, + get_inplace_requires_grad_allowed, + get_unwrapped, + is_functorch_wrapped_tensor, + set_inplace_requires_grad_allowed, +) +from torch._functorch.utils import argnums_t, exposed_in +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental import const_fold +from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils import _pytree as pytree +from torch.utils._pytree import ( + tree_flatten, + tree_map, + tree_map_, + tree_map_only, + tree_unflatten, + treespec_pprint, +) + +from .apis import vmap +from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes + + +def lazy_dynamo_disallow(func): + import torch._dynamo + + return torch._dynamo.disallow_in_graph(func) + + +@contextlib.contextmanager +def enable_inplace_requires_grad(enabled): + prev_state = get_inplace_requires_grad_allowed() + set_inplace_requires_grad_allowed(enabled) + try: + yield + finally: + set_inplace_requires_grad_allowed(prev_state) + + +def _set_tensor_requires_grad(x): + # avoid graph-break on x.requires_grad_() + # https://github.com/pytorch/pytorch/pull/110053 + return x.requires_grad_() + + +def _create_differentiable(inps, level=None): + def create_differentiable(x): + if isinstance(x, torch.Tensor): + with enable_inplace_requires_grad(True): + return _set_tensor_requires_grad(x) + raise ValueError(f"Thing passed to transform API must be Tensor, got {type(x)}") + + return tree_map(create_differentiable, inps) + + +def _undo_create_differentiable(inps, level=None): + def unwrap_tensors(x): + if isinstance(x, torch.Tensor): + return _unwrap_for_grad(x, level) + # TODO: Remove the following hack for namedtuples + if isinstance(x, tuple): + return tree_map(unwrap_tensors, tuple(x)) + + raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}") + + return tree_map(unwrap_tensors, inps) + + +def _is_differentiable(maybe_tensor): + if not isinstance(maybe_tensor, torch.Tensor): + return False + return maybe_tensor.requires_grad + + +def _any_differentiable(tensor_or_tuple_of_tensors): + flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) + return any(tuple(map(_is_differentiable, flat_args))) + + +def _wrap_tensor_for_grad(maybe_tensor, level): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + return _wrap_for_grad(maybe_tensor, level) + + +def _wrap_all_tensors(tensor_pytree, level): + return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree) + + +def _as_tuple(val): + if isinstance(val, tuple): + return val + return (val,) + + +# Version of autograd.grad that handles outputs that don't depend on inputs + + +def _autograd_grad( + outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True +): + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + result = tuple( + (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad + ) + if len(result) == 0: + diff_outputs, grad_outputs = (), () + else: + 
diff_outputs, grad_outputs = zip(*result) + if len(diff_outputs) == 0: + return tuple(torch.zeros_like(inp) for inp in inputs) + with torch._dynamo.compiled_autograd._disable(): + grad_inputs = torch.autograd.grad( + diff_outputs, + inputs, + grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True, + ) + grad_inputs = tuple( + torch.zeros_like(inp) if gi is None else gi + for gi, inp in zip(grad_inputs, inputs) + ) + return grad_inputs + + +# NOTE [grad and vjp interaction with no_grad] +# +# def f(x): +# with torch.no_grad(): +# c = x ** 2 +# return x - c +# +# The thing to consider is if enable_grad is on/off before grad gets called. +# +# Case 1: enable_grad is on. +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad. +# +# Case 2: enable_grad is off +# with torch.no_grad(): +# grad(f)(x) +# In this case, `grad` should respect the inner torch.no_grad, but not the +# outer one. This is because `grad` is a "function transform": its result +# should not depend on the result of a context manager outside of `f`. +# +# This gives us the following desired behavior: +# - (nested) grad transforms must obey torch.no_grad inside them +# - (nested) grad transforms should not obey torch.no_grad outside them +# +# To achieve this behavior, upon entering grad/vjp: +# - we save the current ("previous") is_grad_enabled (*) +# - we unconditionally enable grad. +# +# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer +# off the stack: +# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad +# active, all subsequent grad transforms must obey it). +# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False, +# then we temporarily restore the previous `is_grad_enabled`. This is +# because we're crossing the boundary from a `grad` outside the +# no_grad to a `grad` inside the no_grad. +# +# NB: vjp has some interesting behavior because the vjp's callable can be called +# under a different grad_mode than the forward computation... +# +# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but +# it respects c10::AutoFwGradMode. We've implemented the same logic for +# our jvp transform (it will have special handling if FwGradMode is disabled). + + +# How do we increment and decrement the nesting? I don't think we can. +@exposed_in("torch.func") +def vjp(func: Callable, *primals, has_aux: bool = False): + """ + Standing for the vector-Jacobian product, returns a tuple containing the + results of ``func`` applied to ``primals`` and a function that, when + given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with + respect to ``primals`` times ``cotangents``. + + Args: + func (Callable): A Python function that takes one or more arguments. Must + return one or more Tensors. + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, vjp_fn)`` tuple containing the output of ``func`` + applied to ``primals`` and a function that computes the vjp of + ``func`` with respect to all ``primals`` using the cotangents passed + to the returned function. 
If ``has_aux is True``, then instead returns a + ``(output, vjp_fn, aux)`` tuple. + The returned ``vjp_fn`` function will return a tuple of each VJP. + + When used in simple cases, :func:`vjp` behaves the same as :func:`grad` + + >>> x = torch.randn([5]) + >>> f = lambda x: x.sin().sum() + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> grad = vjpfunc(torch.tensor(1.))[0] + >>> assert torch.allclose(grad, torch.func.grad(f)(x)) + + However, :func:`vjp` can support functions with multiple outputs by + passing in the cotangents for each of the outputs + + >>> x = torch.randn([5]) + >>> f = lambda x: (x.sin(), x.cos()) + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + :func:`vjp` can even support outputs being Python structs + + >>> x = torch.randn([5]) + >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} + >>> vjps = vjpfunc(cotangents) + >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) + + The function returned by :func:`vjp` will compute the partials with + respect to each of the ``primals`` + + >>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) + >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) + >>> cotangents = torch.randn([5, 5]) + >>> vjps = vjpfunc(cotangents) + >>> assert len(vjps) == 2 + >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) + >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents)) + + ``primals`` are the positional arguments for ``f``. All kwargs use their + default value + + >>> x = torch.randn([5]) + >>> def f(x, scale=4.): + >>> return x * scale + >>> + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> vjps = vjpfunc(torch.ones_like(x)) + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``vjp``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: + + >>> # xdoctest: +SKIP(failing) + >>> with torch.no_grad(): + >>> vjp(f)(x) + + In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``vjp`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. 
+ """ + return _vjp_with_argnums(func, *primals, has_aux=has_aux) + + +@contextlib.contextmanager +def grad_increment_nesting(): + try: + grad_level = _grad_increment_nesting() + yield grad_level + finally: + _grad_decrement_nesting() + + +def enter_jvp_nesting(): + global JVP_NESTING + jvp_level = _jvp_increment_nesting() + JVP_NESTING += 1 + return jvp_level + + +def exit_jvp_nesting(): + global JVP_NESTING + _jvp_decrement_nesting() + JVP_NESTING -= 1 + + +@contextlib.contextmanager +def jvp_increment_nesting(): + try: + yield enter_jvp_nesting() + finally: + exit_jvp_nesting() + + +@doesnt_support_saved_tensors_hooks +def _vjp_with_argnums( + func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False +): + # This is the same function as vjp but also accepts an argnums argument + # All args are the same as vjp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for vjp). Default: None + # + # WARN: Users should NOT call this function directly and should just be calling vjp. + # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers. + # + # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev + # + # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs + # for only the primal elements given by argnums. + with grad_increment_nesting() as level: + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + primals = _wrap_all_tensors(primals, level) + if argnums is None: + diff_primals = _create_differentiable(primals, level) + else: + diff_primals = _slice_argnums(primals, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_primals) + primals_out = func(*primals) + + if has_aux: + if not (isinstance(primals_out, tuple) and len(primals_out) == 2): + raise RuntimeError( + "vjp(f, *primals): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + primals_out, aux = primals_out + aux = _undo_create_differentiable(aux, level) + + flat_primals_out, primals_out_spec = tree_flatten(primals_out) + assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)") + flat_diff_primals, primals_spec = tree_flatten(diff_primals) + results = _undo_create_differentiable(primals_out, level) + + for primal_out in flat_primals_out: + assert isinstance(primal_out, torch.Tensor) + if primal_out.is_floating_point() or primal_out.is_complex(): + continue + raise RuntimeError( + "vjp(f, ...): All outputs of f must be " + "floating-point or complex Tensors, got Tensor " + f"with dtype {primal_out.dtype}" + ) + + def wrapper(cotangents, retain_graph=True, create_graph=None): + if create_graph is None: + create_graph = torch.is_grad_enabled() + flat_cotangents, cotangents_spec = tree_flatten(cotangents) + if primals_out_spec != cotangents_spec: + raise RuntimeError( + f"Expected pytree structure of cotangents to be the same " + f"as pytree structure of outputs to the function. 
" + f"cotangents: {treespec_pprint(cotangents_spec)}, " + f"primal output: {treespec_pprint(primals_out_spec)}" + ) + result = _autograd_grad( + flat_primals_out, + flat_diff_primals, + flat_cotangents, + retain_graph=retain_graph, + create_graph=create_graph, + ) + return tree_unflatten(result, primals_spec) + + if has_aux: + return results, wrapper, aux + else: + return results, wrapper + + +def _safe_zero_index(x): + assert len(x) == 1 + return x[0] + + +# jacrev and jacfwd don't support complex functions +# Helper function to throw appropriate error. +def error_if_complex(func_name, args, is_input): + flat_args = pytree.tree_leaves(args) + for idx, arg in enumerate(flat_args): + if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: + input_or_output = "inputs" if is_input else "outputs" + err_msg = ( + f"{func_name}: Expected all {input_or_output} " + f"to be real but received complex tensor at flattened input idx: {idx}" + ) + raise RuntimeError(err_msg) + + +@exposed_in("torch.func") +def jacrev( + func: Callable, + argnums: Union[int, tuple[int]] = 0, + *, + has_aux=False, + chunk_size: Optional[int] = None, + _preallocate_and_copy=False, +): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using reverse mode autodiff + + .. note:: + Using :attr:`chunk_size=1` is equivalent to computing the jacobian + row-by-row with a for-loop i.e. the constraints of :func:`vmap` are + not applicable. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. + chunk_size (None or int): If None (default), use the maximum chunk size + (equivalent to doing a single vmap over vjp to compute the jacobian). + If 1, then compute the jacobian row-by-row with a for-loop. + If not None, then compute the jacobian :attr:`chunk_size` rows at a time + (equivalent to doing multiple vmap over vjp). If you run into memory issues computing + the jacobian, please try to specify a non-None chunk_size. + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Jacobian of ``func`` with respect to the arg(s) at + ``argnums``. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. 
+ + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> jacobian = jacrev(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from torch.func import jacrev + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + :func:`jacrev` can be composed with vmap to produce batched + Jacobians: + + >>> from torch.func import jacrev, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacrev(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + Additionally, :func:`jacrev` can be composed with itself to produce + Hessians + + >>> from torch.func import jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacrev(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacrev` computes the Jacobian with respect to the first + input. However, it can compute the Jacboian with respect to a different + argument by using ``argnums``: + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to ``argnums`` will compute the Jacobian + with respect to multiple arguments + + >>> from torch.func import jacrev + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacrev(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + .. note:: + Using PyTorch ``torch.no_grad`` together with ``jacrev``. + Case 1: Using ``torch.no_grad`` inside a function: + + >>> def f(x): + >>> with torch.no_grad(): + >>> c = x ** 2 + >>> return x - c + + In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``. + + Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager: + + >>> with torch.no_grad(): + >>> jacrev(f)(x) + + In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the + outer one. This is because ``jacrev`` is a "function transform": its result + should not depend on the result of a context manager outside of ``f``. 
+ """ + if not (chunk_size is None or chunk_size > 0): + raise ValueError("jacrev: `chunk_size` should be greater than 0.") + + @wraps(func) + def wrapper_fn(*args): + error_if_complex("jacrev", args, is_input=True) + vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux) + if has_aux: + output, vjp_fn, aux = vjp_out + else: + output, vjp_fn = vjp_out + + # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs] + flat_output, output_spec = tree_flatten(output) + + error_if_complex("jacrev", flat_output, is_input=False) + + # NB: vjp already checks that all outputs are tensors + # Step 1: Construct grad_outputs by splitting the standard basis + flat_output_numels = tuple(out.numel() for out in flat_output) + + primals = _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + + def compute_jacobian_stacked(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + chunked_results = [] + for flat_basis_chunk in _chunked_standard_basis_for_( + flat_output, flat_output_numels, chunk_size=chunk_size + ): + if chunk_size == 1: + # sanity check. + for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = tree_map( + lambda t: torch.squeeze(t, 0), flat_basis_chunk + ) + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + if chunk_size == 1: + flat_results = tree_map( + lambda t: torch.unsqueeze(t, 0), flat_results + ) + + chunked_results.append(flat_results) + + if len(chunked_results) == 1: + # Short-circuit if we used a single chunk + return chunked_results[0] + + # Concatenate chunks. + flat_results = [] + # Iterate and concat the jacobians of different + # inputs. + for idx in range(len(flat_primals)): + r = tuple(r_[idx] for r_ in chunked_results) + flat_results.append(torch.cat(r, 0)) + + return flat_results + + def compute_jacobian_preallocate_and_copy(): + # Helper function to compute chunked Jacobian + # The intermediate chunked calculation are only + # scoped at this function level. + out_vec_size = sum(flat_output_numels) + + # Don't pre-allocate if we have a single chunk. + if not (chunk_size is None or chunk_size >= out_vec_size): + stacked_results = [ + primal.new_zeros(out_vec_size, *primal.shape) + for primal in flat_primals + ] + + for idx, flat_basis_chunk in enumerate( + _chunked_standard_basis_for_( + flat_output, flat_output_numels, chunk_size=chunk_size + ) + ): + if chunk_size == 1: + # sanity check. + for t in flat_basis_chunk: + assert t.size(0) == 1 + + flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk] + + basis = tree_unflatten(flat_basis_chunk, output_spec) + + if chunk_size == 1: + # Behaviour with `chunk_size=1` is same as `for-loop` + # i.e. user shouldn't deal with the limitations of vmap. + chunked_result = vjp_fn(basis) + else: # chunk_size is None or chunk_size != 1 + chunked_result = vmap(vjp_fn)(basis) + + flat_results = pytree.tree_leaves(chunked_result) + + # Short-circuit if we have a single chunk. 
+ if chunk_size is None or chunk_size >= out_vec_size: + if chunk_size == 1: # and out_vec_size == 1 + # Since we squeezed the output dim + flat_results = tree_map( + lambda t: torch.unsqueeze(t, 0), flat_results + ) + return flat_results + + for r, sr in zip(flat_results, stacked_results): + sr[idx * chunk_size : (idx + 1) * chunk_size].copy_(r) + + return stacked_results + + if _preallocate_and_copy: + flat_jacobians_per_input = compute_jacobian_preallocate_and_copy() + else: + flat_jacobians_per_input = compute_jacobian_stacked() + + # Step 2: The returned jacobian is one big tensor per input. In this step, + # we split each Tensor by output. + flat_jacobians_per_input = [ + result.split(flat_output_numels, dim=0) + for result in flat_jacobians_per_input + ] + flat_input_flat_output = [ + tuple( + split.view(out.shape + primal.shape) + for split, out in zip(splits, flat_output) + ) + for splits, primal in zip(flat_jacobians_per_input, flat_primals) + ] + + # Step 3: Right now, `jacobian` is a List[List[Tensor]]. + # The outer List corresponds to the number of primals, + # the inner List corresponds to the number of outputs. + # We need to: + # a. Exchange the order of the outer List and inner List + # b. tree_unflatten the inner Lists (which correspond to the primals) + # c. handle the argnums=int case + # d. tree_unflatten the outer List (which corresponds to the outputs) + flat_output_flat_input = tuple(zip(*flat_input_flat_output)) + + flat_output_input = tuple( + tree_unflatten(flat_input, primals_spec) + for flat_input in flat_output_flat_input + ) + + if isinstance(argnums, int): + flat_output_input = tuple( + _safe_zero_index(flat_input) for flat_input in flat_output_input + ) + output_input = tree_unflatten(flat_output_input, output_spec) + if has_aux: + return output_input, aux + return output_input + + return wrapper_fn + + +# NOTE: [Computing jacobian with vmap and vjp for multiple outputs] +# +# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). +# It turns out we can compute the jacobian of this function with a single +# call to autograd.grad by using vmap over the correct grad_outputs. +# +# Firstly, one way to compute the jacobian is to stack x**2 and x.sum() +# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) +# +# To get the first row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) +# To get the 2nd row of the jacobian, we call +# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) +# and so on. +# +# Using vmap, we can vectorize all 4 of these computations into one by +# passing the standard basis for R^4 as the grad_output. +# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). +# +# Now, how do we compute the jacobian *without stacking the output*? +# We can just split the standard basis across the outputs. So to +# compute the jacobian of f(x), we'd use +# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) +# The grad_outputs looks like the following: +# ( torch.tensor([[1, 0, 0], +# [0, 1, 0], +# [0, 0, 1], +# [0, 0, 0]]), +# torch.tensor([[0], +# [0], +# [0], +# [1]]) ) +# +# But we're not done yet! +# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) +# returns a Tensor of shape [4, 3]. 
We have to remember to split the +# jacobian of shape [4, 3] into two: +# - one of shape [3, 3] for the first output +# - one of shape [ 3] for the second output + + +def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + # This function: + # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. + # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. + # - Each chunk corresponds to one tensor. The chunk has the same dtype and + # device as the tensor + # + # For example, with tensor_numels = [1, 2, 1], this function returns: + # ( tensor([[1], tensor([[0, 0], tensor([[0], + # [0], [1, 0], [0], + # [0], [0, 1], [0], + # [0]]) , [0, 0]]) , [1]]) ) + # + # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) + # Precondition: tensors always has at least one element. + # + # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] + # for context behind this function. + # NOTE: Argument `chunk_size` is used to generate chunked basis instead of + # one huge basis matrix. `chunk_size` dictates the maximum size of the + # basis matrix along dim=0. + assert len(tensors) == len(tensor_numels) + assert len(tensors) > 0 + assert chunk_size is None or chunk_size > 0 + total_numel = sum(tensor_numels) + if chunk_size and chunk_size < total_numel: + chunk_numels = get_chunk_sizes(total_numel, chunk_size) + else: # chunk_size is None or chunk_size >= total_numel + chunk_size = total_numel + chunk_numels = [total_numel] + + diag_start_indices = ( + 0, + *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind(), + ) + + for chunk_idx, total_numel in enumerate(chunk_numels): + chunks = tuple( + tensor.new_zeros(total_numel, tensor_numel) + for tensor, tensor_numel in zip(tensors, tensor_numels) + ) + + for chunk, diag_start_idx in zip(chunks, diag_start_indices): + chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1) + chunks = tuple( + chunk.view(total_numel, *tensor.shape) + for chunk, tensor in zip(chunks, tensors) + ) + yield chunks + + +def _construct_standard_basis_for(tensors, tensor_numels): + for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): + return basis + + +def _validate_and_wrap_argnum(argnum, num_args): + if not isinstance(argnum, int): + raise RuntimeError(f"argnum must be int, got: {type(argnum)}") + if argnum >= 0 and argnum < num_args: + return argnum + if argnum < 0 and argnum >= -num_args: + return argnum + num_args + raise RuntimeError(f"Got argnum={argnum}, but only {num_args} positional inputs") + + +def _check_unique_non_empty(argnums): + if isinstance(argnums, tuple): + if len(argnums) == 0: + raise RuntimeError("argnums must be non-empty") + if len(set(argnums)) != len(argnums): + raise RuntimeError(f"argnums elements must be unique, got {argnums}") + + +def _replace_args(old_args, new_args, argnums): + if isinstance(argnums, int): + if len(new_args) != 1: + raise RuntimeError( + f"new_args should be of size 1, was of size {len(new_args)}" + ) + return tuple( + new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)) + ) + if isinstance(argnums, tuple): + if len(new_args) != len(argnums): + raise RuntimeError( + "new_args should have the same size as argnums. 
" + f"Argnums size {len(argnums)}, new_args size {len(new_args)}" + ) + + def get_right_elem(i): + return new_args[argnums.index(i)] if i in argnums else old_args[i] + + return tuple(get_right_elem(i) for i in range(len(old_args))) + raise RuntimeError(f"argnums must be int or Tuple[int, ...], got: {type(argnums)}") + + +def _validate_and_wrap_argnums(argnums, num_args): + if isinstance(argnums, int): + return _validate_and_wrap_argnum(argnums, num_args) + if isinstance(argnums, tuple): + return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums) + raise AssertionError("Should never get here") + + +def _slice_argnums(args, argnums, as_tuple=True): + if not isinstance(argnums, int) and not isinstance(argnums, tuple): + raise RuntimeError( + f"argnums must be int or Tuple[int, ...], got: {type(argnums)}" + ) + argnums = _validate_and_wrap_argnums(argnums, len(args)) + _check_unique_non_empty(argnums) + if isinstance(argnums, int): + if as_tuple: + return (args[argnums],) + else: + return args[argnums] + return tuple(args[i] for i in argnums) + + +JVP_NESTING = 0 + + +def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None: + if not isinstance(elts, tuple): + raise RuntimeError( + f"{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}" + ) + for elt in elts: + if isinstance(elt, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected {argname} to be a tuple of Tensors, got " + f"a tuple with an element of type {type(elt)}" + ) + if len(elts) == 0: + raise RuntimeError( + f"{api}: Expected {argname} to be a non-empty tuple of Tensors." + ) + + +def assert_non_empty_tensor_output(output: list[Any], api: str) -> None: + if (len(output) == 1 and output[0] is None) or len(output) < 1: + raise RuntimeError( + f"{api}: Expected f to be a function that has non-empty output (got output = {output})" + ) + for o in output: + if not isinstance(o, torch.Tensor): + raise RuntimeError( + f"{api}: expected f(*primals) to return only tensors" + f", got unsupported type {type(o)}" + ) + + +def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: + if isinstance(output, torch.Tensor): + return + if not isinstance(output, tuple): + raise RuntimeError( + f"{api}: Expected output of f to be a Tensor or Tensors, got " + f"{type(output)}" + ) + if len(output) == 0: + raise RuntimeError( + f"{api}: Expected output of f to be a non-empty tuple of Tensors." + ) + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected output of f to be a Tensor or Tensors, got " + f"{type(out)} as an output" + ) + + +def assert_non_empty_list_of_tensors( + output: list[torch.Tensor], api: str, argname: str +) -> None: + if len(output) == 0: + raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.") + for out in output: + if isinstance(out, torch.Tensor): + continue + raise RuntimeError( + f"{api}: Expected {argname} to only contain Tensors, got {type(out)}" + ) + + +jvp_str = "jvp(f, primals, tangents)" + + +def safe_unpack_dual(dual, strict): + if not isinstance(dual, torch.Tensor): + raise RuntimeError( + f"{jvp_str}: expected f(*args) to return only tensors" + f", got unsupported type {type(dual)}" + ) + + primal, tangent = fwAD.unpack_dual(dual) + if tangent is None: + if strict: + raise RuntimeError( + "jvp(f, primals, tangents, strict=True): " + "The output of f is independent of " + "the inputs. This is not allowed with strict=True." 
+ ) + tangent = torch.zeros_like(primal) + return primal, tangent + + +@exposed_in("torch.func") +def jvp( + func: Callable, + primals: Any, + tangents: Any, + *, + strict: bool = False, + has_aux: bool = False, +): + """ + Standing for the Jacobian-vector product, returns a tuple containing + the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at + ``primals``" times ``tangents``. This is also known as forward-mode autodiff. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. The returned function will also be computing the + derivative with respect to these arguments + tangents (Tensors): The "vector" for which Jacobian-vector-product is + computed. Must be the same structure and sizes as the inputs to + ``func``. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + other auxiliary objects that will not be differentiated. + Default: False. + + Returns: + Returns a ``(output, jvp_out)`` tuple containing the output of ``func`` + evaluated at ``primals`` and the Jacobian-vector product. + If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + + jvp is useful when you wish to compute gradients of a function R^1 -> R^N + + >>> from torch.func import jvp + >>> x = torch.randn([]) + >>> f = lambda x: x * torch.tensor([1., 2., 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> assert torch.allclose(value, f(x)) + >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) + + :func:`jvp` can support functions with multiple inputs by passing in the + tangents for each of the inputs + + >>> from torch.func import jvp + >>> x = torch.randn(5) + >>> y = torch.randn(5) + >>> f = lambda x, y: (x * y) + >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) + >>> assert torch.allclose(output, x + y) + + """ + + return _jvp_with_argnums( + func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux + ) + + +def _jvp_with_argnums( + func: Callable, + primals: Any, + tangents: Any, + argnums: Optional[argnums_t], + *, + strict: bool = False, + has_aux: bool, +): + # This is the same function as jvp but also accepts an argnums argument + # Most args are the same as jvp except for the added argument + # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. + # If None, computes the gradients with respect to all inputs (used for jvp). Default: None + # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is + # given by argnums + # + # WARN: Users should NOT call this function directly and should just be calling jvp. + # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers. 
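+    # A rough sketch of the calling convention assumed here: with argnums=1 only
+    # the second positional input is differentiated, so `tangents` carries one
+    # entry that lines up with that primal.
+    #
+    #   def f(x, y):
+    #       return x * y
+    #   out, jvp_out = _jvp_with_argnums(
+    #       f, (x, y), (torch.ones_like(y),), argnums=1, has_aux=False
+    #   )
+    #   # jvp_out is d(x*y)/dy applied to the all-ones tangent, i.e. equal to x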
+ # + # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd + # + # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to + # the primals given by argnums + if not isinstance(primals, tuple): + raise RuntimeError( + f"{jvp_str}: Expected primals to be a tuple. " + f"E.g. it should be valid to call f(*primals)." + ) + diff_args = primals if argnums is None else _slice_argnums(primals, argnums) + flat_primals, primals_spec = tree_flatten(diff_args) + flat_tangents, tangents_spec = tree_flatten(tangents) + if primals_spec != tangents_spec: + raise RuntimeError( + f"{jvp_str}: Expected primals and tangents to have the same python " + f"structure. For example, if primals is a tuple of 3 tensors, " + f"tangents also must be. Got primals with structure {primals_spec} " + f"and tangents with structure {tangents_spec}" + ) + assert_non_empty_list_of_tensors(flat_primals, jvp_str, "primals") + assert_non_empty_list_of_tensors(flat_tangents, jvp_str, "tangents") + + global JVP_NESTING + + with jvp_increment_nesting() as level: + with fwAD._set_fwd_grad_enabled(True): + ctx = fwAD.dual_level if JVP_NESTING == 1 else contextlib.nullcontext + with ctx(): + flat_duals = tuple( + fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents) + ) + duals = tree_unflatten(flat_duals, primals_spec) + if argnums is not None: + primals = _wrap_all_tensors(primals, level) + duals = _replace_args(primals, duals, argnums) + result_duals = func(*duals) + if has_aux: + if not (isinstance(result_duals, tuple) and len(result_duals) == 2): + raise RuntimeError( + f"{jvp_str}: output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + result_duals, aux = result_duals + aux = _undo_create_differentiable(aux, level) + + result_duals, spec = tree_flatten(result_duals) + assert_non_empty_tensor_output(result_duals, jvp_str) + + primals_out, tangents_out = zip( + *[safe_unpack_dual(dual, strict) for dual in result_duals] + ) + primals_out = tree_map( + partial(_undo_create_differentiable, level=level), primals_out + ) + tangents_out = tree_map( + partial(_undo_create_differentiable, level=level), tangents_out + ) + + primals_out_unflatten = tree_unflatten(primals_out, spec) + tangents_out_unflatten = tree_unflatten(tangents_out, spec) + if has_aux: + return primals_out_unflatten, tangents_out_unflatten, aux + + return primals_out_unflatten, tangents_out_unflatten + + +def safe_unflatten(tensor, dim, shape): + if len(shape) == 0: + assert tensor.shape[dim] == 1 + return tensor.squeeze(dim) + return tensor.unflatten(dim, shape) + + +@exposed_in("torch.func") +def jacfwd( + func: Callable, + argnums: argnums_t = 0, + has_aux: bool = False, + *, + randomness: str = "error", +): + """ + Computes the Jacobian of ``func`` with respect to the arg(s) at index + ``argnum`` using forward-mode autodiff + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Jacobian with respect to. + Default: 0. + has_aux (bool): Flag indicating that ``func`` returns a + ``(output, aux)`` tuple where the first element is the output of + the function to be differentiated and the second element is + auxiliary objects that will not be differentiated. + Default: False. 
+ randomness(str): Flag indicating what type of randomness to use. + See :func:`vmap` for more detail. Allowed: "different", "same", "error". + Default: "error" + + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Jacobian of ``func`` with respect to the arg(s) at + ``argnums``. If ``has_aux is True``, then the returned function + instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` + is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + An alternative is to use :func:`jacrev`, which has better operator coverage. + + A basic usage with a pointwise, unary operation will give a diagonal array + as the Jacobian + + >>> from torch.func import jacfwd + >>> x = torch.randn(5) + >>> jacobian = jacfwd(torch.sin)(x) + >>> expected = torch.diag(torch.cos(x)) + >>> assert torch.allclose(jacobian, expected) + + :func:`jacfwd` can be composed with vmap to produce batched + Jacobians: + + >>> from torch.func import jacfwd, vmap + >>> x = torch.randn(64, 5) + >>> jacobian = vmap(jacfwd(torch.sin))(x) + >>> assert jacobian.shape == (64, 5, 5) + + If you would like to compute the output of the function as well as the + jacobian of the function, use the ``has_aux`` flag to return the output + as an auxiliary object: + + >>> from torch.func import jacfwd + >>> x = torch.randn(5) + >>> + >>> def f(x): + >>> return x.sin() + >>> + >>> def g(x): + >>> result = f(x) + >>> return result, result + >>> + >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x) + >>> assert torch.allclose(f_x, f(x)) + + Additionally, :func:`jacrev` can be composed with itself or :func:`jacrev` + to produce Hessians + + >>> from torch.func import jacfwd, jacrev + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hessian = jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hessian, torch.diag(-x.sin())) + + By default, :func:`jacfwd` computes the Jacobian with respect to the first + input. 
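+
+    For a function of two arguments, the default corresponds to ``argnums=0``,
+    i.e. the Jacobian with respect to ``x`` below (a minimal sketch):
+
+    >>> from torch.func import jacfwd
+    >>> def f(x, y):
+    >>>     return x + y ** 2
+    >>>
+    >>> x, y = torch.randn(5), torch.randn(5)
+    >>> jacobian = jacfwd(f)(x, y)  # same as jacfwd(f, argnums=0)(x, y)
+    >>> expected = torch.diag(torch.ones_like(x))
+    >>> assert torch.allclose(jacobian, expected)
+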
However, it can compute the Jacboian with respect to a different + argument by using ``argnums``: + + >>> from torch.func import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=1)(x, y) + >>> expected = torch.diag(2 * y) + >>> assert torch.allclose(jacobian, expected) + + Additionally, passing a tuple to ``argnums`` will compute the Jacobian + with respect to multiple arguments + + >>> from torch.func import jacfwd + >>> def f(x, y): + >>> return x + y ** 2 + >>> + >>> x, y = torch.randn(5), torch.randn(5) + >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y) + >>> expectedX = torch.diag(torch.ones_like(x)) + >>> expectedY = torch.diag(2 * y) + >>> assert torch.allclose(jacobian[0], expectedX) + >>> assert torch.allclose(jacobian[1], expectedY) + + """ + + @wraps(func) + def wrapper_fn(*args): + error_if_complex("jacfwd", args, is_input=True) + primals = args if argnums is None else _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + flat_primals_numels = tuple(p.numel() for p in flat_primals) + flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels) + basis = tree_unflatten(flat_basis, primals_spec) + + def push_jvp(basis): + output = _jvp_with_argnums( + func, args, basis, argnums=argnums, has_aux=has_aux + ) + # output[0] is the output of `func(*args)` + error_if_complex("jacfwd", output[0], is_input=False) + if has_aux: + _, jvp_out, aux = output + return jvp_out, aux + _, jvp_out = output + return jvp_out + + results = vmap(push_jvp, randomness=randomness)(basis) + if has_aux: + results, aux = results + # aux is in the standard basis format, e.g. NxN matrix + # We need to fetch the first element as original `func` output + flat_aux, aux_spec = tree_flatten(aux) + flat_aux = [value[0] for value in flat_aux] + aux = tree_unflatten(flat_aux, aux_spec) + + jac_outs, spec = tree_flatten(results) + # Most probably below output check can never raise an error + # as jvp should test the output before + # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)') + + jac_outs_ins = tuple( + tuple( + safe_unflatten(jac_out_in, -1, primal.shape) + for primal, jac_out_in in zip( + flat_primals, + jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1), + ) + ) + for jac_out in jac_outs + ) + jac_outs_ins = tuple( + tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins + ) + + if isinstance(argnums, int): + jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins) + if has_aux: + return tree_unflatten(jac_outs_ins, spec), aux + return tree_unflatten(jac_outs_ins, spec) + + return wrapper_fn + + +@exposed_in("torch.func") +def hessian(func, argnums=0): + """ + Computes the Hessian of ``func`` with respect to the arg(s) at index + ``argnum`` via a forward-over-reverse strategy. + + The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is + a good default for good performance. It is possible to compute Hessians + through other compositions of :func:`jacfwd` and :func:`jacrev` like + ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``. + + Args: + func (function): A Python function that takes one or more arguments, + one of which must be a Tensor, and returns one or more Tensors + argnums (int or Tuple[int]): Optional, integer or tuple of integers, + saying which arguments to get the Hessian with respect to. + Default: 0. 
+ + Returns: + Returns a function that takes in the same inputs as ``func`` and + returns the Hessian of ``func`` with respect to the arg(s) at + ``argnums``. + + .. note:: + You may see this API error out with "forward-mode AD not implemented + for operator X". If so, please file a bug report and we will prioritize it. + An alternative is to use ``jacrev(jacrev(func))``, which has better + operator coverage. + + A basic usage with a R^N -> R^1 function gives a N x N Hessian: + + >>> from torch.func import hessian + >>> def f(x): + >>> return x.sin().sum() + >>> + >>> x = torch.randn(5) + >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) + >>> assert torch.allclose(hess, torch.diag(-x.sin())) + + """ + return jacfwd(jacrev(func, argnums), argnums) + + +@doesnt_support_saved_tensors_hooks +def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable: + with grad_increment_nesting() as level: + output, aux, grad_input = None, None, None + # See NOTE [grad and vjp interaction with no_grad] + with torch.enable_grad(): + args = _wrap_all_tensors(args, level) + kwargs = _wrap_all_tensors(kwargs, level) + diff_args = _slice_argnums(args, argnums, as_tuple=False) + tree_map_(partial(_create_differentiable, level=level), diff_args) + + output = func(*args, **kwargs) + if has_aux: + if not (isinstance(output, tuple) and len(output) == 2): + raise RuntimeError( + "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) " + "if has_aux is True" + ) + output, aux = output + + if not isinstance(output, torch.Tensor): + raise RuntimeError( + "grad_and_value(f)(*args): Expected f(*args) " + f"to return a Tensor, got {type(output)}" + ) + if output.dim() != 0: + raise RuntimeError( + "grad_and_value(f)(*args): Expected f(*args) " + "to return a scalar Tensor, got tensor with " + f"{output.dim()} dims. Maybe you wanted to " + "use the vjp or jacrev APIs instead?" 
+ ) + + flat_diff_args, spec = tree_flatten(diff_args) + + # NB: need create_graph so that backward pass isn't run in no_grad mode + flat_outputs = _as_tuple(output) + flat_grad_input = _autograd_grad( + flat_outputs, flat_diff_args, create_graph=True + ) + grad_input = tree_unflatten(flat_grad_input, spec) + + grad_input = _undo_create_differentiable(grad_input, level) + output = _undo_create_differentiable(output, level) + if has_aux: + aux = _undo_create_differentiable(aux, level) + + if has_aux: + return grad_input, (output, aux) + return grad_input, output + + +def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs): + results = grad_and_value_impl(func, argnums, has_aux, args, kwargs) + if has_aux: + grad, (_, aux) = results + return grad, aux + grad, _ = results + return grad + + +def _maybe_wrap_functional_tensor( + maybe_tensor, level, *, _python_functionalize: bool = False +): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + wrapped = _wrap_functional_tensor(maybe_tensor, level) + _assert_wrapped_functional(maybe_tensor, wrapped) + if _python_functionalize: + out = FunctionalTensor(wrapped) + torch._mirror_autograd_meta_to(maybe_tensor, out) + return out + return wrapped + + +def _wrap_all_tensors_to_functional( + tensor_pytree, level, *, _python_functionalize: bool = False +): + return tree_map( + partial( + lambda x: _maybe_wrap_functional_tensor( + x, level, _python_functionalize=_python_functionalize + ) + ), + tensor_pytree, + ) + + +def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool): + if not isinstance(maybe_tensor, torch.Tensor): + return maybe_tensor + if isinstance(maybe_tensor, FunctionalTensor): + maybe_tensor = maybe_tensor.elem + + if not torch._is_functional_tensor(maybe_tensor): + # If it's not a functional tensor, just return it. + # This can happen if we functionalize a fn that returns a global, + # which was never wrapped properly. + return maybe_tensor + # Sync any pending updates on the output tensor + torch._sync(maybe_tensor) + return _unwrap_functional_tensor(maybe_tensor, reapply_views) + + +def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool): + return tree_map( + lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), + tensor_pytree, + ) + + +@exposed_in("torch.func") +def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: + """ + functionalize is a transform that can be used to remove (intermediate) + mutations and aliasing from a function, while preserving the function's + semantics. + + ``functionalize(func)`` returns a new function with the same semantics + as ``func``, but with all intermediate mutations removed. + Every inplace operation performed on an intermediate tensor: + ``intermediate.foo_()`` + gets replaced by its out-of-place equivalent: + ``intermediate_updated = intermediate.foo()``. + + functionalize is useful for shipping a pytorch program off to + backends or compilers that aren't able to easily represent + mutations or aliasing operators. + + Args: + func (Callable): A Python function that takes one or more arguments. + remove (str): An optional string argument, that takes on either + the value 'mutations' or 'mutations_and_views'. + If 'mutations' is passed in then all mutating operators + will be replaced with their non-mutating equivalents. + If 'mutations_and_views' is passed in, then additionally, all aliasing + operators will be replaced with their non-aliasing equivalents. 
+ Default: 'mutations'. + + Returns: + Returns a new "functionalized" function. It takes the same inputs as + ``func``, and has the same behavior, but any mutations + (and optionally aliasing) performed on intermediate tensors + in the function will be removed. + + functionalize will also remove mutations (and views) that were performed on function inputs. + However to preserve semantics, functionalize will "fix up" the mutations after + the transform has finished running, by detecting if any tensor inputs "should have" + been mutated, and copying the new data back to the inputs if necessary. + + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch.func import functionalize + >>> + >>> # A function that uses mutations and views, but only on intermediate tensors. + >>> def f(a): + ... b = a + 1 + ... c = b.view(-1) + ... c.add_(1) + ... return b + ... + >>> inpt = torch.randn(2) + >>> + >>> out1 = f(inpt) + >>> out2 = functionalize(f)(inpt) + >>> + >>> # semantics are the same (outputs are equivalent) + >>> print(torch.allclose(out1, out2)) + True + >>> + >>> f_traced = make_fx(f)(inpt) + >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> + >>> print(f_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]) + add_ = torch.ops.aten.add_(view, 1); view = None + return add + + >>> print(f_no_mutations_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view = torch.ops.aten.view(add, [-1]); add = None + add_1 = torch.ops.aten.add(view, 1); view = None + view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None + return view_1 + + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + add = torch.ops.aten.add(a_1, 1); a_1 = None + view_copy = torch.ops.aten.view_copy(add, [-1]); add = None + add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None + return view_copy_1 + + + >>> # A function that mutates its input tensor + >>> def f(a): + ... b = a.view(-1) + ... b.add_(1) + ... return a + ... + >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) + >>> # + >>> # All mutations and views have been removed, + >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input + >>> # after the function has completed. + >>> print(f_no_mutations_and_views_traced.code) + + + + def forward(self, a_1): + view_copy = torch.ops.aten.view_copy(a_1, [-1]) + add = torch.ops.aten.add(view_copy, 1); view_copy = None + view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None + copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None + return view_copy_1 + + + There are a few "failure modes" for functionalize that are worth calling out: + (1) Like other torch.func transforms, `functionalize()` doesn't work with functions + that directly use `.backward()`. The same is true for torch.autograd.grad. + If you want to use autograd, you can compute gradients directly + with `functionalize(grad(f))`. + (2) Like other torch.func transforms, `functionalize()` doesn't work with global state. 
+ If you call `functionalize(f)` on a function that takes views / mutations of + non-local state, functionalization will simply no-op and pass the view/mutation + calls directly to the backend. + One way to work around this is is to ensure that any non-local state creation + is wrapped into a larger function, which you then call functionalize on. + (3) `resize_()` has some limitations: functionalize will only work on programs + that use resize_()` as long as the tensor being resized is not a view. + (4) `as_strided()` has some limitations: functionalize will not work on + `as_strided()` calls that result in tensors with overlapping memory. + + + Finally, a helpful mental model for understanding functionalization is that + most user pytorch programs are writing with the public torch API. + When executed, torch operators are generally decomposed into + our internal C++ "ATen" API. + The logic for functionalization happens entirely at the level of ATen. + Functionalization knows how to take every aliasing operator in ATen, + and map it to its non-aliasing equivalent + (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``), + and how to take every mutating operator in ATen, + and map it to its non-mutating equivalent + (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, -1)``), + while tracking aliases and mutations out-of-line to know when to fix things up. + Information about which ATen operators are aliasing or mutating all comes from + https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml. + """ + if remove == "mutations": + reapply_views = True + elif remove == "mutations_and_views": + reapply_views = False + else: + raise RuntimeError( + f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}." + " Valid options are:\n" + " remove='mutations': all inplace and out= operators will be removed from the program, and replaced" + " with their out-of-place equivalents.\n" + " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be" + " replaced with their non-aliasing counterparts, {view}_copy.\n" + ) + + @wraps(func) + def wrapped(*args, **kwargs): + try: + func_level = _func_increment_nesting(reapply_views) + func_args = _wrap_all_tensors_to_functional(args, func_level) + func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level) + + flattened_unwrapped_args = pytree.arg_tree_leaves(*args) + flattened_wrapped_args = pytree.arg_tree_leaves(*func_args) + flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs) + flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs) + + func_outputs = func(*func_args, **func_kwargs) + outputs = _unwrap_all_tensors_from_functional( + func_outputs, reapply_views=reapply_views + ) + + for a in flattened_wrapped_args + flattened_wrapped_kwargs: + if isinstance(a, torch.Tensor): + # Call sync_() on the inputs, to ensure that any pending mutations have been applied. + torch._sync(a) + + # And if any mutations were applied to the inputs, we need to propagate them back to the user. 
+ for unwrapped, wrapped in zip( + flattened_unwrapped_args, flattened_wrapped_args + ): + if isinstance(unwrapped, torch.Tensor) and isinstance( + wrapped, torch.Tensor + ): + _propagate_functional_input_mutation(unwrapped, wrapped) + for unwrapped, wrapped in zip( + flattened_unwrapped_kwargs, flattened_wrapped_kwargs + ): + if isinstance(unwrapped, torch.Tensor) and isinstance( + wrapped, torch.Tensor + ): + _propagate_functional_input_mutation(unwrapped, wrapped) + + return outputs + finally: + _func_decrement_nesting() + + return wrapped + + +@exposed_in("torch.func") +def linearize(func: Callable, *primals) -> tuple[Any, Callable]: + """ + Returns the value of ``func`` at ``primals`` and linear approximation + at ``primals``. + + Args: + func (Callable): A Python function that takes one or more arguments. + primals (Tensors): Positional arguments to ``func`` that must all be + Tensors. These are the values at which the function is linearly approximated. + + Returns: + Returns a ``(output, jvp_fn)`` tuple containing the output of ``func`` + applied to ``primals`` and a function that computes the jvp of + ``func`` evaluated at ``primals``. + + linearize is useful if jvp is to be computed multiple times at ``primals``. However, + to achieve this, linearize saves intermediate computation and has higher memory requirements + than directly applying `jvp`. So, if all the ``tangents`` are known, it maybe more efficient + to compute vmap(jvp) instead of using linearize. + + .. note:: + linearize evaluates ``func`` twice. Please file an issue for an implementation + with a single evaluation. + + Example:: + + >>> import torch + >>> from torch.func import linearize + >>> def fn(x): + ... return x.sin() + ... + >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) + >>> jvp_fn(torch.ones(3, 3)) + tensor([[1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.]]) + >>> + + """ + # Note: We evaluate `fn` twice. + # Once for returning the output and other while + # tracing the graph. + # If this becomes a bottle-neck, we should update + # make_fx such that it also returns the output. + + output = func(*primals) + _, output_spec = tree_flatten(output) + + flat_primals, primals_argspec = tree_flatten(primals) + + # tangents for tracing + flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals) + + # function to trace + def trace_fn(flat_tangents): + with fwAD.dual_level(): + flat_duals = tuple( + fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents) + ) + duals = tree_unflatten(flat_duals, primals_argspec) + output = func(*duals) + tangents = tree_map_only( + torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output + ) + + return tangents + + jvp_graph = lazy_dynamo_disallow(make_fx)(trace_fn)(flat_tangents) + const_folded_jvp_graph = lazy_dynamo_disallow(const_fold.split_const_subgraphs)( + jvp_graph + ) + + # Hold only the meta-data regarding the primals. + flat_primals_shape = tuple(p.shape for p in flat_primals) + flat_primals_device = tuple(p.device for p in flat_primals) + flat_primals_dtype = tuple(p.dtype for p in flat_primals) + + def forward_ad_checks(flat_tangents): + for idx, t in enumerate(flat_tangents): + if t.shape != flat_primals_shape[idx]: + msg = ( + f"tangent:{idx} with shape {t.shape} in flattened " + f"pytree doesn't match the shape {flat_primals_shape[idx]} " + "of the corresponding primal." 
+ ) + raise RuntimeError(msg) + + if t.device != flat_primals_device[idx]: + msg = ( + f"tangent:{idx} with device {t.device} in flattened " + f"pytree doesn't match the device {flat_primals_device[idx]} " + "of the corresponding primal." + ) + raise RuntimeError(msg) + + if t.dtype != flat_primals_dtype[idx]: + msg = ( + f"tangent:{idx} with dtype {t.dtype} in flattened " + f"pytree doesn't match the dtype {flat_primals_dtype[idx]} " + "of the corresponding primal." + ) + raise RuntimeError(msg) + + # jvp_fn : callable to return + # It takes care of checking the argspec of tangents, + # calling the folded fx graph and unflattening fx graph output + def jvp_fn(*tangents): + flat_tangents, tangent_argspec = tree_flatten(tangents) + if tangent_argspec != primals_argspec: + raise RuntimeError( + f"Expected the tangents {tangent_argspec} to have " + f"the same argspec as the primals {primals_argspec}" + ) + + forward_ad_checks(flat_tangents) + + flat_output = const_folded_jvp_graph(*flat_tangents) + # const folded graph can return flat output, + # so transform output. + return tree_unflatten(flat_output, output_spec) + + return output, jvp_fn + + +@exposed_in("torch.func") +def debug_unwrap(tensor: torch.Tensor, *, recurse=True) -> torch.Tensor: + """Unwraps a functorch tensor (e.g. BatchedTensor, GradTrackingTensor) to its underlying tensor. + + This function should only be used in a debug setting (e.g. trying to print the + value of a Tensor in a debugger). Otherwise, using the result of function + inside of a function being transformed will lead to undefined behavior. + """ + if not is_functorch_wrapped_tensor(tensor): + return tensor + result = get_unwrapped(tensor) + if recurse: + return debug_unwrap(result) + return result diff --git a/phivenv/Lib/site-packages/torch/_functorch/functional_call.py b/phivenv/Lib/site-packages/torch/_functorch/functional_call.py new file mode 100644 index 0000000000000000000000000000000000000000..82bdca6a260f1f25d2d05479c695dfaf8055a20f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/functional_call.py @@ -0,0 +1,253 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch._functorch.utils import exposed_in + + +@exposed_in("torch.func") +def functional_call( + module: "torch.nn.Module", + parameter_and_buffer_dicts: Union[dict[str, Tensor], Sequence[dict[str, Tensor]]], + args: Optional[Union[Any, tuple]] = None, + kwargs: Optional[dict[str, Any]] = None, + *, + tie_weights: bool = True, + strict: bool = False, +): + r"""Performs a functional call on the module by replacing the module parameters + and buffers with the provided ones. + + .. note:: If the module has active parametrizations, passing a value in the + :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter + name will completely disable the parametrization. + If you want to apply the parametrization function to the value passed + please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. + + .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected + in the ``parameter_and_buffer_dicts`` input. + + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # does self.foo = self.foo + 1 + >>> print(mod.foo) # tensor(0.) + >>> functional_call(mod, a, torch.ones(())) + >>> print(mod.foo) # tensor(0.) 
+ >>> print(a['foo']) # tensor(1.) + + .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the + tie_weights flag. + + Example:: + + >>> a = {'foo': torch.zeros(())} + >>> # xdoctest: +SKIP + >>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied + >>> print(mod.foo) # tensor(1.) + >>> mod(torch.zeros(())) # tensor(2.) + >>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too + >>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated + >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())} + >>> functional_call(mod, new_a, torch.zeros()) # tensor(0.) + + An example of passing multiple dictionaries + + .. code-block:: python + + a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries + mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer + print(mod.weight) # tensor(...) + print(mod.buffer) # tensor(...) + x = torch.randn((1, 1)) + print(x) + functional_call(mod, a, x) # same as x + print(mod.weight) # same as before functional_call + + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from torch.func import functional_call, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + + def compute_loss(params, x, t): + y = functional_call(model, params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t) + + .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the + parameters for better performance and memory usage + + Example:: + + >>> detached_params = {k: v.detach() for k, v in model.named_parameters()} + >>> grad_weights = grad(compute_loss)(detached_params, x, t) + >>> grad_weights.grad_fn # None--it's not tracking gradients outside of grad + + This means that the user cannot call ``grad_weight.backward()``. However, if they don't need autograd tracking + outside of the transforms, this will result in less memory usage and faster speeds. + + Args: + module (torch.nn.Module): the module to call + parameters_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in + the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can + be used together + args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument. + kwargs (dict): keyword arguments to be passed to the module call + tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as + tied in the reparameterized version. Therefore, if True and different values are passed for the tied + parameters and buffers, it will error. If False, it will not respect the originally tied parameters and + buffers unless the values passed for both weights are the same. Default: True. + strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and + buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will + error. Default: False. + + Returns: + Any: the result of calling ``module``. 
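+
+    A minimal end-to-end sketch of the most common call pattern, evaluating the
+    module once with an explicitly supplied set of parameters and buffers:
+
+    Example::
+
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> from torch.func import functional_call
+        >>> model = nn.Linear(3, 3)
+        >>> x = torch.randn(4, 3)
+        >>> state = {**dict(model.named_parameters()), **dict(model.named_buffers())}
+        >>> out = functional_call(model, state, (x,))
+        >>> assert torch.allclose(out, model(x))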
+ """ + if isinstance(parameter_and_buffer_dicts, dict): + parameters_and_buffers = parameter_and_buffer_dicts + elif isinstance(parameter_and_buffer_dicts, Sequence): + if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts): + raise ValueError( + "Expected all elements of parameter_and_buffer_dicts to be dictionaries" + ) + all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()] + all_keys_counter: dict[str, int] = {} + for k in all_keys: + v = all_keys_counter.get(k, 0) + all_keys_counter[k] = v + 1 + repeated_keys = [key for key, n in all_keys_counter.items() if n > 1] + if len(repeated_keys) > 0: + raise ValueError( + f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous" + ) + parameters_and_buffers = { + k: v for d in parameter_and_buffer_dicts for k, v in d.items() + } + else: + raise ValueError( + f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, " + f"but got {type(parameter_and_buffer_dicts)}" + ) + + return nn.utils.stateless._functional_call( + module, + parameters_and_buffers, + args, + kwargs, + tie_weights=tie_weights, + strict=strict, + ) + + +@exposed_in("torch.func") +def stack_module_state( + models: Union[Sequence[nn.Module], nn.ModuleList], +) -> tuple[dict[str, Any], dict[str, Any]]: + """stack_module_state(models) -> params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries + that stack all of their parameters and buffers together, indexed by name. + The stacked parameters are optimizable (i.e. they are new leaf nodes in the + autograd history that are unrelated to the original parameters and can be + passed directly to an optimizer). + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + def wrapper(params, buffers, data): + return torch.func.functional_call(models[0], (params, buffers), data) + + params, buffers = stack_module_state(models) + output = vmap(wrapper, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + When there's submodules, this follows state dict naming conventions + + .. code-block:: python + + import torch.nn as nn + class Foo(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + hidden = 4 + self.l1 = nn.Linear(in_features, hidden) + self.l2 = nn.Linear(hidden, out_features) + + def forward(self, x): + return self.l2(self.l1(x)) + + num_models = 5 + in_features, out_features = 3, 3 + models = [Foo(in_features, out_features) for i in range(num_models)] + params, buffers = stack_module_state(models) + print(list(params.keys())) # "l1.weight", "l1.bias", "l2.weight", "l2.bias" + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + """ + if len(models) == 0: + raise RuntimeError("stack_module_state: Expected at least one model, got 0.") + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError( + "stack_module_state: Expected all models to have the same training/eval mode." 
+ ) + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError( + "stack_module_state: Expected all models to be of the same class." + ) + all_params = [dict(model.named_parameters()) for model in models] + params = { + k: construct_stacked_leaf(tuple(params[k] for params in all_params), k) + for k in all_params[0] + } + all_buffers = [dict(model.named_buffers()) for model in models] + buffers = { + k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k) + for k in all_buffers[0] + } + + return params, buffers + + +def construct_stacked_leaf( + tensors: Union[tuple[Tensor, ...], list[Tensor]], name: str +) -> Tensor: + all_requires_grad = all(t.requires_grad for t in tensors) + none_requires_grad = all(not t.requires_grad for t in tensors) + if not all_requires_grad and not none_requires_grad: + raise RuntimeError( + f"Expected {name} from each model to have the same .requires_grad" + ) + result = torch.stack(tensors) + if all_requires_grad: + result = result.detach().requires_grad_() + return result diff --git a/phivenv/Lib/site-packages/torch/_functorch/fx_minifier.py b/phivenv/Lib/site-packages/torch/_functorch/fx_minifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a55c5a1cee08a66f7011990589965ffb8638bbdc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/fx_minifier.py @@ -0,0 +1,501 @@ +# mypy: ignore-errors + +import copy +import math +import os +import sys +from dataclasses import dataclass +from functools import partial, wraps +from typing import Callable + +import torch +import torch.fx as fx +from torch.hub import tqdm +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._content_store import ContentStoreWriter + +from .compile_utils import get_outputs, get_placeholders + + +is_tuple = object() + + +@dataclass +class LoadTensorMeta: + size: list[int] + stride: list[int] + dtype: torch.dtype + device: torch.device + + +class ConcreteProp(torch.fx.Interpreter): + def __init__(self, mod, *, writer=None, skip_offload=False): + super().__init__(mod) + self.writer = writer + self.skip_offload = skip_offload + self.seen_storages = set() + + def run_node(self, n): + self.pbar.update(1) + r = super().run_node(n) + name = n.name + + if isinstance(r, torch.Tensor): + if self.writer is None: + n.meta["concrete_value"] = r + else: + if StorageWeakRef(r.untyped_storage()) in self.seen_storages: + # Refuse to offload tensors which alias other live + # tensors, because this will violate operator contracts + n.meta["concrete_value"] = None + else: + if not self.skip_offload: + self.writer.write_tensor(os.path.join("eager", name), r) + n.meta["concrete_value"] = LoadTensorMeta( + r.size(), r.stride(), r.dtype, r.device + ) + self.seen_storages.add(StorageWeakRef(r.untyped_storage())) + else: + n.meta["concrete_value"] = is_tuple + + return r + + def propagate(self, *args): + with tqdm( + desc="Saving intermediates for delta debugging", + total=len(self.module.graph.nodes), + disable=self.writer is None, + ) as pbar: + self.pbar = pbar + r = super().run(*args) + if not self.skip_offload: + pbar.set_description( + "Saved! 
To skip next time, run with --skip-saving-eager-intermediates" + ) + return r + + +def is_load_tensor_node(node): + return ( + node.op == "call_function" + and node.target is torch.ops.debugprims.load_tensor.default + ) + + +# inplace modifies node/inps +def _convert_node_to_placeholder(graph, node, inps): + if node.op == "output" or node.op == "placeholder": + return False + + if is_load_tensor_node(node): + return False + + concrete_val = node.meta.get("concrete_value", None) + + if isinstance(concrete_val, torch.Tensor): + node.op = "placeholder" + node.target = node.name + node.args = () + node.kwargs = {} + + inps.append(concrete_val) + return True + + elif concrete_val is None: + return False + + elif concrete_val is is_tuple: + r = False + for tuple_user in list(node.users): + r = _convert_node_to_placeholder(graph, tuple_user, inps) or r + # NB: We must not erase the node at this point, because + # we are iterating over the nodes and this would change + # the iteration order + # graph.erase_node(node) + return r + + elif isinstance(concrete_val, LoadTensorMeta): + node.op = "call_function" + node.target = torch.ops.debugprims.load_tensor.default + node.args = ( + os.path.join("eager", node.name), + concrete_val.size, + concrete_val.stride, + ) + node.kwargs = { + "device": concrete_val.device, + "dtype": concrete_val.dtype, + } + return True + + return False + + +def create_minified_hlo_graph(minified_fx_graph, inputs): + """ + Takes minified FX graph as primary input, and ports it to HLO via StableHLO + Provides minified HLO graph as output, and archive them to local directory + """ + hlo_dir = f"{os.getcwd()}/hlo_files" + os.makedirs(hlo_dir, exists_ok=True) + + from torch_xla.stablehlo import save_torch_model_as_stablehlo + + save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir) + + +def dump_state(fx_g, inps): + print( + f""" +# Working Repro with {len(fx_g.graph.nodes)} nodes +inps = {[(i.shape, i.dtype, i.device.type) for i in inps]} +inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps] +{fx_g.code} +""" + ) + + +def is_power_of_two(n): + if n == 0: + return False + return (n & (n - 1)) == 0 + + +@dataclass +class ReproState: + graph: fx.Graph + inps: list[torch.Tensor] + + def __post_init__(self): + ph_nodes = get_placeholders(self.graph) + assert len(ph_nodes) == len(self.inps) + + +def minifier( + fail_f: fx.GraphModule, + inps, + module_fails, + dump_state: Callable = dump_state, + *, + save_dir=None, + offload_to_disk=False, + skip_offload=False, + skip_sanity=False, + max_granularity=None, +): + """ + Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails. + + Does 2 main strategies: + 1. Truncates suffix: Removes some suffix from the graph and sets a new output. + 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, + tries replacing quarter of the graph, etc. + + >>> # xdoctest: +SKIP(failing) + >>> failing_function = fx.symbolic_trace(f) + >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps)) + + note: module_fails returns True if it fails. 
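+
+    A rough sketch of a ``module_fails`` predicate, assuming the failure being
+    hunted is an exception raised while running the module:
+
+    >>> # xdoctest: +SKIP
+    >>> def module_fails(fx_g, inps):
+    ...     try:
+    ...         fx_g(*inps)
+    ...         return False
+    ...     except Exception:
+    ...         return True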
+ """ + assert isinstance(inps, (tuple, list)) + + failing_graph = fail_f.graph + cur_size = len(failing_graph.nodes) + + if max_granularity is not None and not is_power_of_two(max_granularity): + raise RuntimeError(f"max_granularity {max_granularity} not power of two") + + num_queries = 0 + + def deepcopy_fx_graph(fx_graph): + return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph + + def graph_fails(graph, inps): + nonlocal num_queries + graph = copy.deepcopy(graph) + num_queries += 1 + mod = fx.GraphModule(fail_f, graph) + mod.graph.lint() + return module_fails(mod, inps) + + writer = None + if offload_to_disk: + writer = ContentStoreWriter(save_dir) + + ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps) + if not skip_sanity and not graph_fails(failing_graph, inps): + raise RuntimeError("Input graph did not fail the tester") + print(f"Started off with {cur_size} nodes", file=sys.stderr) + + def _register_strategy(strategy: Callable, name: str): + @wraps(strategy) + def new_func(old_state: ReproState, granularity=1): + print(file=sys.stderr) + print( + f"Strategy: {name} (G: {granularity}) " + f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)", + file=sys.stderr, + ) + new_state = strategy( + deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity + ) + if new_state is not None: + new_nodes = len(new_state.graph.nodes) + old_nodes = len(old_state.graph.nodes) + new_inps = len(new_state.inps) + old_inps = len(old_state.inps) + new_outs = len(get_outputs(new_state.graph)) + old_outs = len(get_outputs(old_state.graph)) + progress_made = False + if new_nodes < old_nodes: + progress_made = True + print( + f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes", + file=sys.stderr, + ) + if new_inps > old_inps: + progress_made = True + print( + f"SUCCESS: Went from {old_inps} to {new_inps} inputs", + file=sys.stderr, + ) + if new_outs < old_outs: + progress_made = True + print( + f"SUCCESS: Went from {old_outs} to {new_outs} outputs", + file=sys.stderr, + ) + + if not progress_made: + raise RuntimeError("Success raised but no progress made?") + + if not graph_fails(new_state.graph, new_state.inps): + print( + "WARNING: Something went wrong, not applying this minification", + file=sys.stderr, + ) + return None + return new_state + else: + print(f"FAIL: {name}", file=sys.stderr) + return None + + return new_func + + def register_strategy(name: str): + return partial(_register_strategy, name=name) + + @register_strategy("Truncate suffix") + def remove_suffix(cur_graph, cur_inps, granularity): + tested = set() + new_graph = fx.Graph() + env = {} + for idx, node in enumerate(cur_graph.nodes): + new_node = new_graph.node_copy(node, lambda x: env[x]) + if node.op not in ["placeholder", "output"]: + # If idx is divisible by (granularity * 2), it would have been checked already. 
+ if ( + idx % granularity == 0 + and (idx % (granularity * 2) != 0) + and idx not in tested + ): + output_node = new_graph.output((new_node,)) + if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails( + new_graph, cur_inps + ): + return ReproState(new_graph, cur_inps) + else: + tested.add(idx) + new_graph.erase_node(output_node) + env[node] = new_node + return None + + @register_strategy("Remove outputs") + def remove_outputs(cur_graph, cur_inps, granularity): + granularity = max(1, granularity // 2) + for idx, node in enumerate(cur_graph.nodes): + node.idx = idx + if node.op == "output": + output = node + break + + if isinstance(output.args[0], fx.Node): + return None + + output_args = sorted( + output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9) + ) + if len(output_args) == 1: + return None + + for idx in range(0, len(output_args), granularity): + output.args = (output_args[:idx] + output_args[idx + granularity :],) + if graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + def remove_unused_inputs_unchecked(cur_state: ReproState): + cur_graph = cur_state.graph + cur_inps = cur_state.inps + ph_nodes = get_placeholders(cur_graph) + assert len(ph_nodes) == len(cur_inps) + + new_inps = [] + for idx in range(len(ph_nodes)): + if len(ph_nodes[idx].users) == 0: + cur_graph.erase_node(ph_nodes[idx]) + else: + new_inps.append(cur_inps[idx]) + if len(new_inps) < len(cur_inps): + return ReproState(cur_graph, new_inps) + return None + + def remove_unused_inputs_checked(cur_state: ReproState): + new_state = remove_unused_inputs_unchecked(cur_state) + if new_state is not None and graph_fails(new_state.graph, new_state.inps): + return new_state + return None + + def _remove_unused_wrapper(cur_graph, cur_inps, granularity): + return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps)) + + remove_unused_inputs = register_strategy("Remove unused inputs")( + _remove_unused_wrapper + ) + + @register_strategy("Eliminate dead code") + def eliminate_dead_code(cur_graph, cur_inps, granularity): + if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + def _consolidate_placeholders(cur_graph, inps): + new_graph = fx.Graph() + env = {} + seen_non_placeholder = False + + # Move all placeholders to the front; also, if any load_tensor + # is at the front, convert it into an input (because it can be live + # all the time) + for node in cur_graph.nodes: + if node.op == "placeholder": + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + elif not seen_non_placeholder and is_load_tensor_node(node): + new_node = new_graph.placeholder(node.name) + env[node] = new_node + inps.append( + torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs) + ) + else: + seen_non_placeholder = True + + # Move everyone else + for node in cur_graph.nodes: + if node not in env: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + return new_graph + + @register_strategy("Delta Debugging") + def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity): + num_nodes = len(cur_graph.nodes) + for start_range in range(0, num_nodes, granularity): + is_removing = False + new_graph = deepcopy_fx_graph(cur_graph) + new_inps = cur_inps[:] + end_range = min(num_nodes, start_range + granularity) + for idx in range(start_range, end_range): + new_node = list(new_graph.nodes)[idx] + if _convert_node_to_placeholder(new_graph, 
new_node, new_inps): + is_removing = True + if not is_removing: + continue + new_graph.eliminate_dead_code() + new_graph = _consolidate_placeholders(new_graph, new_inps) + new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps)) + if new_state is None: + new_state = ReproState(new_graph, new_inps) + if graph_fails(new_state.graph, new_state.inps): + return ReproState(new_state.graph, new_state.inps) + + return None + + @register_strategy("Consolidate Inputs") + def consolidate_inputs(cur_graph, cur_inps, granularity): + old_len = len(cur_inps) + cur_graph = _consolidate_placeholders(cur_graph, cur_inps) + if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps): + return ReproState(cur_graph, cur_inps) + return None + + failing_state = ReproState(failing_graph, inps) + + def try_granularity(failing_state, granularity, use_non_granular): + print(f"Trying granularity {granularity}", file=sys.stderr) + + strategies = [] + num_nodes = len(failing_state.graph.nodes) + num_outputs = len(get_outputs(failing_state.graph)) + if num_outputs > num_nodes // 2: + strategies += [remove_outputs] + + if use_non_granular: + strategies += [ + eliminate_dead_code, + remove_unused_inputs, + consolidate_inputs, + ] + + strategies += [remove_suffix, delta_debugging] + + for strategy in strategies: + new_state = strategy(failing_state, granularity) + if new_state is not None: + return new_state + return None + + while True: + dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps) + granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes))))) + if max_granularity is not None: + granularity = min(max_granularity, granularity) + new_state = try_granularity(failing_state, granularity, use_non_granular=True) + if new_state is not None: + failing_state = new_state + continue + + granularity //= 2 + has_progress = False + while granularity >= 1: + new_state = try_granularity( + failing_state, granularity, use_non_granular=False + ) + if new_state is not None: + failing_state = new_state + has_progress = True + break + granularity //= 2 + if has_progress: + continue + + new_state = remove_outputs(failing_state, 1) + if new_state is not None: + failing_state = new_state + continue + + break + + if not graph_fails(failing_state.graph, failing_state.inps): + raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing") + + print(f"Made {num_queries} queries", file=sys.stderr) + failing_fx = fx.GraphModule(fail_f, failing_state.graph) + + # If XLA debugging environment is enabled, create minified HLO graph as well + if "XLA_HLO_DEBUG" in os.environ: + create_minified_hlo_graph(failing_fx, failing_state.inps) + + dump_state(failing_fx, failing_state.inps) + print("Wrote minimal repro out to repro.py", file=sys.stderr) + return failing_fx, failing_state.inps diff --git a/phivenv/Lib/site-packages/torch/_functorch/make_functional.py b/phivenv/Lib/site-packages/torch/_functorch/make_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..60650f16ae3d2749aec05a3df3281527e606b930 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/make_functional.py @@ -0,0 +1,607 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +from collections.abc import Iterable, Sequence +from typing import Any, Callable, NoReturn, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.utils._named_member_accessor import NamedMemberAccessor + + +# Utilities to make nn.Module "functional" +# In particular the goal is to be able to provide a function that takes as input +# the parameters and evaluate the nn.Module using fixed inputs. + + +def raise_parameter_tying_error() -> NoReturn: + raise RuntimeError( + "make_functional(module): we don't yet support models that " + "do parameter tying (also sometimes known as weight sharing). " + "Please try to rewrite your model by replacing all instances of the " + "tied parameter with another and/or comment your support in " + "https://github.com/pytorch/functorch/issues/446" + ) + + +def create_names_map( + named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]], + tied_named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]], +) -> dict[str, list[str]]: + """ + named_params is a dictionary of tensors: {'A': A, 'B': B} + tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B} + with potentially tied (or 'duplicated') tensors + + This function creates a mapping from the names in named_params to the + names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}. + """ + named_params = dict(named_params) + tied_named_params = dict(tied_named_params) + + tensors_dict_keys = set(named_params.keys()) + tied_tensors_dict_keys = set(tied_named_params.keys()) + assert tensors_dict_keys.issubset(tied_tensors_dict_keys) + + tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {} + for key, tensor in named_params.items(): + tensor_to_mapping[tensor] = (key, []) + for key, tensor in tied_named_params.items(): + assert tensor in tensor_to_mapping + tensor_to_mapping[tensor][1].append(key) + return dict(tensor_to_mapping.values()) + + +def _extract_members( + mod: nn.Module, + named_members: Callable[..., Iterable[tuple[str, Tensor]]], + subclass: Callable[[Tensor], Tensor], +) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]: + all_named_members = tuple(named_members(remove_duplicate=False)) + unique_named_members = tuple(named_members(remove_duplicate=True)) + names_map = create_names_map(unique_named_members, all_named_members) + + # Remove all the members in the model + memo = {} + accessor = NamedMemberAccessor(mod) + for name, p in all_named_members: + if p not in memo: + memo[p] = subclass(torch.empty_like(p, device="meta")) + replacement = memo[p] + accessor.set_tensor(name, replacement) + + if len(unique_named_members) == 0: + names, params = (), () + else: + names, params = zip(*unique_named_members) # type: ignore[assignment] + return params, names, names_map + + +def extract_weights( + mod: nn.Module, +) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]: + """ + This function removes all the Parameters from the model and + return them as a tuple as well as their original attribute names. + The weights must be re-loaded with `load_weights` before the model + can be used again. + Note that this function modifies the model in place and after this + call, mod.parameters() will be empty. 
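+
+    A rough usage sketch (illustrative):
+
+    ```
+    model = nn.Linear(3, 3)
+    params, names, names_map = extract_weights(model)
+    # the model now holds meta-device placeholders; put the real weights back:
+    load_weights(model, names, params)
+    ```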
+ """ + return _extract_members(mod, mod.named_parameters, nn.Parameter) + + +def extract_buffers( + mod: nn.Module, +) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]: + return _extract_members(mod, mod.named_buffers, lambda x: x) + + +def load_weights( + mod: nn.Module, + names: Sequence[str], + params: Sequence[Tensor], + as_params: bool = False, +) -> None: + """ + Reload a set of weights so that `mod` can be used again to perform a forward pass. + Note that the `params` are regular Tensors (that can have history) and so are left + as Tensors. This means that mod.parameters() will still be empty after this call. + """ + accessor = NamedMemberAccessor(mod) + if as_params: + params = [nn.Parameter(p) for p in params] + accessor.set_tensors(names, params) + + +def _swap_state( + mod: nn.Module, names_map: dict[str, list[str]], elems: Iterable[Tensor] +) -> list[Tensor]: + result: list[Tensor] = [] + accessor = NamedMemberAccessor(mod) + for (_, attr_names), elem in zip(names_map.items(), elems): + for i, attr_name in enumerate(attr_names): + if i == 0: + result.append(accessor.swap_tensor(attr_name, elem)) + else: + accessor.set_tensor(attr_name, elem) + return result + + +def load_buffers( + mod: nn.Module, + names: Sequence[str], + buffers: Sequence[Tensor], + as_params: bool = False, +) -> None: + accessor = NamedMemberAccessor(mod) + accessor.set_tensors(names, buffers) + + +def load_state( + model: nn.Module, + weights: Sequence[Tensor], + weight_names: Sequence[str], + buffers: Sequence[Tensor] = (), + buffer_names: Sequence[str] = (), +) -> nn.Module: + """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model + + load_state takes `weights` and `buffers` and assigns them to the model. + This is the inverse operation of `make_functional_deprecated_v1`. + """ + assert len(weight_names) == len(weights) + load_weights(model, weight_names, weights) + if len(buffers) > 0: + assert len(buffer_names) == len(buffers) + load_buffers(model, buffer_names, buffers) + return model + + +def make_functional_deprecated_v1(model: nn.Module): + """make_functional_deprecated_v1(model) -> weights, func, weight_names + + Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights) + and returns a functional version of the model, `func`. This makes + it so that it is possible use transforms over the parameters of + `model`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, func, _ = make_functional_deprecated_v1(model) + func(weights, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, _, func = make_functional_deprecated_v1(model) + grad_weights = grad(func)(weights, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional_deprecated_v1(model): `model` has buffers. Please use " + "make_functional_with_buffers_deprecated_v1(model) instead." 
+ ) + weights, descriptors, _ = extract_weights(model) + + def fun(weights, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, descriptors, weights) + return mutable_model(*data) + + return weights, fun, descriptors + + +def make_functional_with_buffers_deprecated_v1(model: nn.Module): + """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names + + Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers) + and returns a functional version of the model, `func`. + + `func` can be invoked as follows: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + ``` + + And here is an example of applying the grad transform: + ``` + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model) + func(weights, buffers, (x,)) + grad_weights = grad(func)(weights, buffers, (x,)) + ``` + + To put the state back into a model, use `load_state`. + """ + weights, weight_descriptors, _ = extract_weights(model) + buffers, buf_descriptors, _ = extract_buffers(model) + + def fun(weights, buffers, data): + mutable_model = copy.deepcopy(model) + load_weights(mutable_model, weight_descriptors, weights) + load_buffers(mutable_model, buf_descriptors, buffers) + return mutable_model(*data) + + return weights, buffers, fun, weight_descriptors, buf_descriptors + + +class FunctionalModuleWithBuffers(nn.Module): + """ + This is the callable object returned by :func:`make_functional_with_buffers`. + """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: tuple[str, ...], + buffer_names: tuple[str, ...], + param_names_map: dict[str, list[str]], + buffer_names_map: dict[str, list[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.buffer_names = buffer_names + + self.all_names_map = dict(param_names_map) + self.all_names_map.update(buffer_names_map) + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> tuple["FunctionalModuleWithBuffers", tuple[Tensor, ...], tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, param_names_map = extract_weights(model_copy) + buffers, buffer_names, buffer_names_map = extract_buffers(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return ( + FunctionalModuleWithBuffers( + model_copy, param_names, buffer_names, param_names_map, buffer_names_map + ), + params, + buffers, + ) + + def forward( + self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs + ) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state( + self.stateless_model, + self.all_names_map, + tuple(params) + tuple(buffers), + ) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.all_names_map, old_state) + + +class FunctionalModule(nn.Module): + """ + This is the callable object returned by :func:`make_functional`. 
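+
+    A rough calling sketch (illustrative):
+
+    ```
+    func, params = make_functional(nn.Linear(3, 3))
+    out = func(params, torch.randn(4, 3))
+    ```
+
+    ``forward`` temporarily swaps ``params`` onto the stateless copy of the
+    model, runs it, and swaps the previous state back afterwards.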
+ """ + + def __init__( + self, + stateless_model: nn.Module, + param_names: tuple[str, ...], + names_map: dict[str, list[str]], + ) -> None: + super().__init__() + self.stateless_model = stateless_model + self.param_names = param_names + self.names_map = names_map + + @staticmethod + def _create_from( + model: nn.Module, disable_autograd_tracking: bool = False + ) -> tuple["FunctionalModule", tuple[Tensor, ...]]: + # TODO: We don't need to copy the model to create a stateless copy + model_copy = copy.deepcopy(model) + params, param_names, names_map = extract_weights(model_copy) + if disable_autograd_tracking: + for param in params: + param.requires_grad_(False) + return FunctionalModule(model_copy, param_names, names_map), params + + def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any: + # Temporarily load the state back onto self.stateless_model + old_state = _swap_state(self.stateless_model, self.names_map, params) + try: + return self.stateless_model(*args, **kwargs) + finally: + # Remove the loaded state on self.stateless_model + _swap_state(self.stateless_model, self.names_map, old_state) + + +def make_functional( + model: nn.Module, disable_autograd_tracking: bool = False +) -> tuple[FunctionalModule, tuple[Tensor, ...]]: + """make_functional(model, disable_autograd_tracking=False) -> func, params + + Given a ``torch.nn.Module``, :func:`make_functional` extracts the state + (params) and returns a functional version of the model, ``func``. This + makes it so that it is possible use transforms over the parameters of + ``model``. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + func(params, x) + + And here is an example of applying the grad transform over the parameters + of a model. + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params = make_functional(model) + + def compute_loss(params, x, t): + y = func(params, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, x, t) + + If the model has any buffers, please use :func:`make_functional_with_buffers` instead. + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + buffers = list(model.buffers()) + if len(buffers) > 0: + raise RuntimeError( + "make_functional(model): `model` has buffers. Please use " + "make_functional_with_buffers(model) instead." 
+ ) + return FunctionalModule._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def make_functional_with_buffers( + model: nn.Module, disable_autograd_tracking: bool = False +) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]: + """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers + + Given a ``torch.nn.Module``, make_functional_with_buffers extracts the + state (params and buffers) and returns a functional version of the model + ``func`` that can be invoked like a function. + + ``func`` can be invoked as follows: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers + + x = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + func(params, buffers, x) + + And here is an example of applying the grad transform over the parameters + of a model: + + .. code-block:: python + + import torch + import torch.nn as nn + from functorch import make_functional_with_buffers, grad + + x = torch.randn(4, 3) + t = torch.randn(4, 3) + model = nn.Linear(3, 3) + func, params, buffers = make_functional_with_buffers(model) + + def compute_loss(params, buffers, x, t): + y = func(params, buffers, x) + return nn.functional.mse_loss(y, t) + + grad_weights = grad(compute_loss)(params, buffers, x, t) + + Args: + model (torch.nn.Module): Input model. + disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters. + The returned params are unrelated to the set of params from the original model. If False (default), + the params will have ``requires_grad=True`` on them (aka they will be trackable with regular + PyTorch autograd), matching the requires_grad-ness of the params from the original model. + Otherwise, the returned params will have ``requires_grad=False``. Default, False. + If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or + ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``. + Otherwise, if you're only planning on using functorch's gradient transforms, + then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking + history with PyTorch autograd. + + """ + return FunctionalModuleWithBuffers._create_from( + model, disable_autograd_tracking=disable_autograd_tracking + ) + + +def transpose_stack( + tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...] +) -> tuple[Tensor, ...]: + tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) + results = tuple( + torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors + ) + return results + + +def combine_state_for_ensemble( + models: Sequence[nn.Module], +) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]: + """combine_state_for_ensemble(models) -> func, params, buffers + + Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`. + + Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their + parameters and buffers together to make ``params`` and ``buffers``. + Each parameter and buffer in the result will have an additional dimension + of size ``M``. + + :func:`combine_state_for_ensemble` also returns ``func``, a functional + version of one of the models in :attr:`models`. 
One cannot directly run + ``func(params, buffers, *args, **kwargs)`` directly, you probably want to + use ``vmap(func, ...)(params, buffers, *args, **kwargs)`` + + Here's an example of how to ensemble over a very simple model: + + .. code-block:: python + + num_models = 5 + batch_size = 64 + in_features, out_features = 3, 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + data = torch.randn(batch_size, 3) + + fmodel, params, buffers = combine_state_for_ensemble(models) + output = vmap(fmodel, (0, 0, None))(params, buffers, data) + + assert output.shape == (num_models, batch_size, out_features) + + .. warning:: + All of the modules being stacked together must be the same (except for + the values of their parameters/buffers). For example, they should be in the + same mode (training vs eval). + + This API is subject to change -- we're investigating better ways to + create ensembles and would love your feedback how to improve this. + """ + if len(models) == 0: + raise RuntimeError( + "combine_state_for_ensemble: Expected at least one model, got 0." + ) + if not (all(m.training for m in models) or all(not m.training for m in models)): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to " + "have the same training/eval mode." + ) + model0_typ = type(models[0]) + if not all(type(m) == model0_typ for m in models): + raise RuntimeError( + "combine_state_for_ensemble: Expected all models to be of the same class." + ) + funcs, params, buffers = zip( + *[make_functional_with_buffers(model) for model in models] + ) + params = transpose_stack(params) + buffers = transpose_stack(buffers) + return funcs[0], params, buffers + + +def functional_init( + model_class: type[nn.Module], + ensemble_shape: Union[tuple[()], tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs)) + weights = tuple(make_functional_deprecated_v1(model)[0] for model in models) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + return weights, fn, names + + return wrapped + + +def functional_init_with_buffers( + model_class: type[nn.Module], + ensemble_shape: Union[tuple[()], tuple[int]] = (), + device: torch.types.Device = "cpu", +): + def wrapped(*args, **kwargs): + if len(ensemble_shape) >= 2: + raise ValueError("NYI: ensemble_shape with more than 1 element") + if len(ensemble_shape) == 0: + model = model_class(*args, **kwargs).to(device) + return make_functional_deprecated_v1(model) + num_models = ensemble_shape[0] # type: ignore[misc] + if num_models <= 0: + raise ValueError(f"num_models {num_models} should be > 0") + # NB: Not very efficient, more of a POC + models = tuple( + model_class(*args, **kwargs).to(device) for _ in range(num_models) + ) + ( + _, + _, + fn, + weight_names, + buffer_names, + ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs)) + weights, buffers = zip( + *tuple( + 
make_functional_with_buffers_deprecated_v1(model)[:2] + for model in models + ) + ) + weights = tuple(zip(*weights)) + weights = tuple(torch.stack(shards).detach() for shards in weights) + buffers = tuple(zip(*buffers)) + buffers = tuple(torch.stack(shards).detach() for shards in buffers) + return weights, buffers, fn, weight_names, buffer_names + + return wrapped diff --git a/phivenv/Lib/site-packages/torch/_functorch/partitioners.py b/phivenv/Lib/site-packages/torch/_functorch/partitioners.py new file mode 100644 index 0000000000000000000000000000000000000000..7783346d0acd9152160ca0ffa35250c36b838f69 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/partitioners.py @@ -0,0 +1,2724 @@ +# mypy: allow-untyped-defs +import copy +import functools +import hashlib +import heapq +import itertools +import logging +import math +import operator +import os +import os.path +from collections import defaultdict +from dataclasses import dataclass, replace +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.inductor_prims +import torch.distributed +import torch.fx as fx +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters, is_node_meta_valid +from torch._functorch._activation_checkpointing.ac_logging_utils import ( + create_structured_trace_for_min_cut_info, +) +from torch._inductor import config as inductor_config +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import extract_tensor_metadata +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + find_symbol_binding_fx_nodes, + free_symbols, + hint_int, + is_symbol_binding_fx_node, + statically_known_false, + statically_known_true, +) +from torch.fx.passes import graph_drawer +from torch.utils._ordered_set import OrderedSet +from torch.utils.checkpoint import CheckpointPolicy + +from . 
import config +from ._activation_checkpointing.graph_info_provider import GraphInfoProvider +from ._activation_checkpointing.knapsack import ( + dp_knapsack, + greedy_knapsack, + ilp_knapsack, +) +from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator +from ._aot_autograd.logging_utils import get_aot_graph_name +from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects +from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems + + +if TYPE_CHECKING: + import sympy + + +AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner +log: logging.Logger = logging.getLogger(__name__) + +aten = torch.ops.aten +prims = torch.ops.prims + + +@dataclass +class OpTypes: + """Class for keeping track of different operator categories""" + + fusible_ops: OrderedSet[Callable] + compute_intensive_ops: OrderedSet[Callable] + random_ops: OrderedSet[Callable] + view_ops: OrderedSet[Callable] + recomputable_ops: OrderedSet[Callable] + + def is_fusible(self, node: fx.Node): + return get_aten_target(node) in self.fusible_ops + + def is_compute_intensive(self, node: fx.Node): + return get_aten_target(node) in self.compute_intensive_ops + + def is_random(self, node: fx.Node): + return get_aten_target(node) in self.random_ops + + def is_view(self, node: fx.Node): + return get_aten_target(node) in self.view_ops + + def is_recomputable(self, node: fx.Node): + return get_aten_target(node) in self.recomputable_ops + + +@dataclass +class NodeInfo: + # Be careful about iterating over these explicitly, as their order may not + # be deterministic + inputs: list[fx.Node] + _required_fw_nodes: OrderedSet[fx.Node] + required_bw_nodes: OrderedSet[fx.Node] + unclaimed_nodes: OrderedSet[fx.Node] + fw_order: dict[fx.Node, int] + # Effectively maps to which of our primals are parameters + static_lifetime_input_nodes: OrderedSet[fx.Node] + + @functools.cached_property + def required_fw_nodes(self) -> list[fx.Node]: + return sorted( + (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] + ) + + def is_required_fw(self, n: fx.Node) -> bool: + return n in self._required_fw_nodes + + def is_required_bw(self, n: fx.Node) -> bool: + return n in self.required_bw_nodes + + def is_unclaimed(self, n: fx.Node) -> bool: + return n in self.unclaimed_nodes + + def get_fw_order(self, n: fx.Node) -> int: + assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" 
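+        # fw_order maps each required forward node to its position in the
+        # forward graph; required_fw_nodes above is sorted by this index.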
+ return self.fw_order[n] + + +@dataclass +class MinCutOptions: + ban_if_used_far_apart: bool + ban_if_long_fusible_chains: bool + ban_if_materialized_backward: bool + ban_if_not_in_allowlist: bool + ban_if_reduction: bool + + +def must_recompute(node: fx.Node) -> bool: + return node.meta.get("recompute", None) in [ + CheckpointPolicy.MUST_RECOMPUTE, + CheckpointPolicy.PREFER_RECOMPUTE, + ] + + +def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: + for node in fx_g.graph.nodes: + if must_recompute(node): + return True + return False + + +def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: + for node in fx_g.graph.nodes: + if ( + must_recompute(node) + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return True + return False + + +def sym_node_size(node: fx.Node) -> int: + if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): + return 1 + assert isinstance(node.meta["val"], torch.SymFloat) + return 4 + + +class InvalidNodeBase: + def __repr__(self): + return "Invalid Node" + + +InvalidNode = InvalidNodeBase() + + +def _extract_graph_with_inputs_outputs( + joint_graph: fx.Graph, + inputs: list[fx.Node], + outputs: list[fx.Node], + subgraph: Optional[str] = None, +) -> fx.Graph: + """ + Given a graph, extracts out a subgraph that takes the specified nodes as + inputs and returns the specified outputs. + + This includes specifying non-placeholder nodes as inputs. + + The general strategy is to initialize all inputs with proxies as we + encounter them, and trace through the graph, only keeping values which take + in valid proxies. Then, all dead code is eliminated. + """ + new_graph = fx.Graph() + env = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in inputs: + new_node = new_graph.placeholder(node.name) + # Can't use node_copy here as we may be turning previous call_function into placeholders + new_node.meta = node.meta + env[node] = new_node + + for node in joint_graph.nodes: + if _must_be_in_backward(node) and subgraph != "backward": + env[node] = InvalidNode # type: ignore[assignment] + continue + + if node in env: + # Node must be one of our inputs. (Any member of env which wasn't an + # input to start must have been created by this loop and won't be in + # joint_graph.nodes). 
+ continue + elif node.op == "placeholder": + env[node] = InvalidNode # type: ignore[assignment] + elif node.op == "call_function": + all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) + all_args = [ + isinstance(env[x], InvalidNodeBase) + for x in all_args + if isinstance(x, fx.Node) + ] + if any(all_args): + env[node] = InvalidNode # type: ignore[assignment] + continue + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == "get_attr": + env[node] = new_graph.node_copy(node, lambda x: env[x]) + elif node.op == "output": + pass + output_values = [] + for x in outputs: + if isinstance(x, fx.Node): + if x not in env: + raise RuntimeError(f"Node {x} couldn't be found in env") + assert not isinstance( + env[x], InvalidNodeBase + ), f"Node {x} was invalid, but is output" + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(tuple(output_values)) + + new_graph.eliminate_dead_code() + new_graph.lint() + return new_graph + + +def _is_primal(node: fx.Node) -> bool: + return ( + node.op == "placeholder" + and "tangents" not in str(node.target) + and not _is_bwd_seed_offset(node) + and not _is_fwd_seed_offset(node) + ) + + +def _is_tangent(node: fx.Node) -> bool: + return node.op == "placeholder" and "tangents" in str(node.target) + + +def _is_bwd_seed_offset(node: fx.Node) -> bool: + return node.op == "placeholder" and ( + "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) + ) + + +def _is_fwd_seed_offset(node: fx.Node) -> bool: + return node.op == "placeholder" and ( + "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) + ) + + +def _is_backward_state(node: fx.Node) -> bool: + return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) + + +def _has_tag_is_backward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_backward" + + +def _has_tag_must_be_in_backward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "must_be_in_backward" + + +def _must_be_in_backward(node: fx.Node) -> bool: + return _has_tag_must_be_in_backward(node) or ( + _has_tag_is_backward(node) and is_with_effects(node) + ) + + +def _extract_fwd_bwd_outputs( + joint_module: fx.GraphModule, *, num_fwd_outputs +) -> tuple[list[fx.Node], list[fx.Node]]: + outputs = pytree.arg_tree_leaves( + *(node.args for node in joint_module.graph.find_nodes(op="output")) + ) + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs + + +def _remove_by_name(saved_values: list[fx.Node], name: str): + for saved_value in saved_values: + if saved_value.name == name: + saved_values.remove(saved_value) + break + + +def find_first_sym_node( + fwd_module_outputs: Union[list[fx.Node], tuple[fx.Node]], +) -> int: + idx = len(fwd_module_outputs) + for i in range(len(fwd_module_outputs) - 1, -1, -1): + if not is_sym_node(fwd_module_outputs[i]): + idx = i + 1 + break + return idx + + +def calculate_quantization_scaling( + graph: torch.fx.Graph, + node: torch.fx.Node, + max: float = 57344.0, + min: float = 1e-12, +): + with graph.inserting_after(node): + abs_node = graph.call_function( + torch.ops.aten.abs.default, + args=(node,), + ) + abs_node.meta["val"] = torch.ops.aten.abs.default(node.meta["val"]) + abs_node.meta["tensor_meta"] = extract_tensor_metadata(abs_node.meta["val"]) + with graph.inserting_after(abs_node): + amax_node = graph.call_function( + torch.ops.aten.amax.default, + args=(abs_node, [-1], True), + ) + 
amax_node.meta["val"] = torch.ops.aten.amax.default( + abs_node.meta["val"], [-1], True + ) + amax_node.meta["tensor_meta"] = extract_tensor_metadata(amax_node.meta["val"]) + with graph.inserting_after(amax_node): + amax_64_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(amax_node, torch.float64), + ) + amax_64_node.meta["val"] = torch.ops.prims.convert_element_type.default( + amax_node.meta["val"], torch.float64 + ) + amax_64_node.meta["tensor_meta"] = extract_tensor_metadata( + amax_64_node.meta["val"] + ) + with graph.inserting_after(amax_64_node): + clamp_min_node = graph.call_function( + torch.ops.aten.clamp_min.default, + args=(amax_64_node, min), + ) + clamp_min_node.meta["val"] = torch.ops.aten.clamp_min.default( + amax_64_node.meta["val"], min + ) + clamp_min_node.meta["tensor_meta"] = extract_tensor_metadata( + clamp_min_node.meta["val"] + ) + with graph.inserting_after(clamp_min_node): + reciprocal_node = graph.call_function( + torch.ops.aten.reciprocal.default, + args=(clamp_min_node,), + ) + reciprocal_node.meta["val"] = torch.ops.aten.reciprocal.default( + clamp_min_node.meta["val"] + ) + reciprocal_node.meta["tensor_meta"] = extract_tensor_metadata( + reciprocal_node.meta["val"] + ) + with graph.inserting_after(reciprocal_node): + mul_node = graph.call_function( + torch.ops.aten.mul.Tensor, + args=(reciprocal_node, max), + ) + mul_node.meta["val"] = torch.ops.aten.mul.Tensor( + reciprocal_node.meta["val"], max + ) + mul_node.meta["tensor_meta"] = extract_tensor_metadata(mul_node.meta["val"]) + with graph.inserting_after(mul_node): + scale_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(mul_node, torch.float32), + name="fp8_scale_" + str(node.name), + ) + scale_node.meta["val"] = torch.ops.prims.convert_element_type.default( + mul_node.meta["val"], torch.float32 + ) + scale_node.meta["tensor_meta"] = extract_tensor_metadata(scale_node.meta["val"]) + return scale_node + + +def perform_quantization( + graph: torch.fx.Graph, + node: torch.fx.Node, + scale_node: torch.fx.Node, + quant_type: torch.dtype, + clamp_min: float, + clamp_max: float, +) -> torch.fx.Node: + with graph.inserting_after(scale_node): + target_node_32 = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(node, torch.float32), + ) + target_node_32.meta["val"] = torch.ops.prims.convert_element_type.default( + node.meta["val"], torch.float32 + ) + target_node_32.meta["tensor_meta"] = extract_tensor_metadata( + target_node_32.meta["val"] + ) + with graph.inserting_after(target_node_32): + scaled_target_node = graph.call_function( + torch.ops.aten.mul.Tensor, + args=(target_node_32, scale_node), + ) + scaled_target_node.meta["val"] = torch.ops.aten.mul.Tensor( + target_node_32.meta["val"], scale_node.meta["val"] + ) + scaled_target_node.meta["tensor_meta"] = extract_tensor_metadata( + scaled_target_node.meta["val"] + ) + with graph.inserting_after(scaled_target_node): + clamp_min_scaled_node = graph.call_function( + torch.ops.aten.clamp_min.default, + args=(scaled_target_node, clamp_min), + ) + clamp_min_scaled_node.meta["val"] = torch.ops.aten.clamp_min.default( + scaled_target_node.meta["val"], clamp_min + ) + clamp_min_scaled_node.meta["tensor_meta"] = extract_tensor_metadata( + clamp_min_scaled_node.meta["val"] + ) + with graph.inserting_after(clamp_min_scaled_node): + clamp_max_scaled_node = graph.call_function( + torch.ops.aten.clamp_max.default, + args=(clamp_min_scaled_node, clamp_max), + ) + 
clamp_max_scaled_node.meta["val"] = torch.ops.aten.clamp_max.default( + clamp_min_scaled_node.meta["val"], clamp_max + ) + clamp_max_scaled_node.meta["tensor_meta"] = extract_tensor_metadata( + clamp_max_scaled_node.meta["val"] + ) + with graph.inserting_after(clamp_max_scaled_node): + quant_activation_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(clamp_max_scaled_node, quant_type), + name="fp8_quant_" + str(node.name), + ) + quant_activation_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + clamp_max_scaled_node.meta["val"], quant_type + ) + quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata( + quant_activation_node.meta["val"] + ) + return quant_activation_node + + +def calculate_tensor_size(tensor: torch.Tensor) -> float: + """ + Calculate the size of a PyTorch tensor in megabytes (MB). + + Args: + tensor (torch.Tensor): Input tensor + + Returns: + float: Memory size in MB + """ + # Get number of elements and size per element + num_elements = tensor.numel() + element_size = tensor.element_size() + + return (num_elements * element_size) / (1024 * 1024) + + +def get_allowed_dtypes() -> list[torch.dtype]: + allowed_dtypes = torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("allowed_dtypes", "torch.bfloat16") + allowed_dtypes = [ + getattr(torch, dtype.split(".")[-1]) for dtype in allowed_dtypes.split(";") + ] + return allowed_dtypes + + +def should_quantize(node: torch.fx.Node) -> bool: + allowed_dtypes = get_allowed_dtypes() + if not is_node_meta_valid(node) or node.meta["val"].dtype not in allowed_dtypes: + return False + size_threshold = torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("size_in_mb", 100) + # calculate the size of the node + size_in_mb = calculate_tensor_size(node.meta["val"]) + if not torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("skip_dynamo_guards", False): + return size_in_mb >= size_threshold + else: + # case 1: we alway quantize tensors with dynamic shapes + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("quantize_dynamic_shape", False): + return statically_known_true( + size_in_mb >= size_threshold + ) or not statically_known_false(size_in_mb >= size_threshold) + else: + # case 2: we alway not quantize tensors with dynamic shapes + return statically_known_true(size_in_mb >= size_threshold) + + +def get_quant_type() -> torch.dtype: + quant_type = torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("quant_type", "torch.float8_e5m2") + + return getattr(torch, quant_type.split(".")[-1]) + + +def calculate_range(dtype: torch.dtype) -> tuple: + """ + Calculate the range of values for a given torch.dtype. + Args: + dtype (torch.dtype): The input dtype. + Returns: + tuple: A tuple containing the minimum and maximum values. 
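+
+    Example (illustrative): ``calculate_range(torch.float8_e5m2)`` returns
+    roughly ``(-57344.0, 57344.0)``, matching the default clamp range used
+    by the scaling helper above.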
+ """ + info = torch.finfo(dtype) + return info.min, info.max + + +def quantize_activation_fw(graph: torch.fx.Graph) -> None: + output = graph.find_nodes(op="output")[0] + fwd_outputs = output.args[0] + quant_type = get_quant_type() + clamp_min, clamp_max = calculate_range(quant_type) + node_to_quant = dict() + tensor_scale_nodes, sym_scale_nodes = [], [] + for node in fwd_outputs: + # check if the activation node is the node saved for quantization + if node.meta.get("saved_for_quantization", False): + # case: use scaling + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("use_scaling", True): + # calculating the scale + scale_node = calculate_quantization_scaling( + graph, node, clamp_max, 1e-12 + ) + # converting to fp8 + quant_node = perform_quantization( + graph, node, scale_node, quant_type, clamp_min, clamp_max + ) + if not is_sym_node(scale_node): + tensor_scale_nodes.append(scale_node) + else: + sym_scale_nodes.append(scale_node) + else: + # case: do not use scaling + with graph.inserting_after(node): + quant_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(node, quant_type), + name="fp8_quant_" + str(node.name), + ) + quant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], quant_type + ) + quant_node.meta["tensor_meta"] = extract_tensor_metadata( + quant_node.meta["val"] + ) + node_to_quant[node] = quant_node + # only update the return node args, and remain all other users unchanged + output_updated_args = [ + node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs # type: ignore[union-attr] + ] + # add the scale nodes to the ouput find the first sym_node in the output + idx = find_first_sym_node(output_updated_args) + scale_nodes = tensor_scale_nodes + sym_scale_nodes + if scale_nodes: + output_updated_args = ( + output_updated_args[:idx] + scale_nodes + output_updated_args[idx:] + ) + + output.update_arg(0, tuple(output_updated_args)) + counters["inductor"]["activation_quantization_fwd_aten_pass"] += 1 + + +def quantize_activation_bw(graph: torch.fx.Graph) -> None: + bw_inputs = [node for node in graph.nodes if node.op == "placeholder"] + activation_node = None + for node in bw_inputs: + if node.meta.get("saved_for_quantization", False): + node.meta.pop("saved_for_quantization") + dequant_type = node.meta.pop("dequant_type") + # dequantize the node + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("use_scaling", False): + # case: use scaling + with graph.inserting_after(node): + # find corresponding scale node + scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "") + scale_node = next( + bwd_input + for bwd_input in bw_inputs + if bwd_input.name == scale_name + ) + with graph.inserting_after(scale_node): + activation_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(node, dequant_type), + ) + activation_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type + ) + activation_node.meta["tensor_meta"] = extract_tensor_metadata( + activation_node.meta["val"] + ) + with graph.inserting_after(activation_node): + divided_target_node_32 = graph.call_function( + torch.ops.aten.div.Tensor, + args=(activation_node, scale_node), + ) + divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor( + activation_node.meta["val"], scale_node.meta["val"] + ) + divided_target_node_32.meta[ + "tensor_meta" + ] 
= extract_tensor_metadata(divided_target_node_32.meta["val"]) + with graph.inserting_after(divided_target_node_32): + dequant_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(divided_target_node_32, dequant_type), + ) + dequant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + divided_target_node_32.meta["val"], dequant_type + ) + dequant_node.meta["tensor_meta"] = extract_tensor_metadata( + dequant_node.meta["val"] + ) + else: + with graph.inserting_after(node): + dequant_node = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(node, dequant_type), + name="dequant_" + str(node.name), + ) + dequant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type + ) + dequant_node.meta["tensor_meta"] = extract_tensor_metadata( + dequant_node.meta["val"] + ) + # find the users of the node and replace them with the new node except the dequant_node + for user in list(node.users.keys()): + if user != dequant_node and user != activation_node: + user.replace_input_with(node, dequant_node) + + counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1 + + +def enable_activation_quantization( + saved_values: list[fx.Node], + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> None: + if ( + inductor_config.post_grad_fusion_options.get( + "activation_quantization_aten_pass", None + ) + is None + ): + return + + static_input_names = ( + [node.name for node in static_lifetime_input_nodes] + if static_lifetime_input_nodes + else [] + ) + saved_values_names = {node.name: node for node in saved_values} + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("exclude_primals", False): + saved_values_names = { + node.name: node for node in saved_values if "primals" not in node.name + } + fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] + bwd_module_inputs = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for node in fwd_module_outputs: + if node.name in saved_values_names and should_quantize(node): + if node.name in static_input_names: + log.debug("Skipping quantization of static input %s: ", node.name) + continue + node.meta["saved_for_quantization"] = True + node.meta["dequant_type"] = node.meta["val"].dtype + # some of the fwd outputs and bwd inputs are not share the same object + bwd_module_inputs[node.name].meta["saved_for_quantization"] = True + bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_activation_quantization_fwd_aten_pass", + "encoding": "string", + }, + payload_fn=lambda: fwd_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + quantize_activation_fw(fwd_module.graph) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_activation_quantization_fwd_aten_pass", + "encoding": "string", + }, + payload_fn=lambda: fwd_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_activation_quantization_bwd_aten_pass", + "encoding": "string", + }, + payload_fn=lambda: bwd_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + quant_fwd_module_outputs = 
fwd_module.graph.find_nodes(op="output")[0].args[0] + # update the corresponding bwd_inputs due to the fwd_outputs quantization + for fwd_node in quant_fwd_module_outputs: + if "fp8_quant_" in fwd_node.name: + bwd_input = bwd_module_inputs[fwd_node.name.replace("fp8_quant_", "")] + with bwd_module.graph.inserting_after(bwd_input): + quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name) + dequant_type = bwd_input.meta["dequant_type"] + quant_bwd_input.meta.update(fwd_node.meta) + quant_bwd_input.meta["saved_for_quantization"] = True + quant_bwd_input.meta["dequant_type"] = dequant_type + bwd_input.replace_all_uses_with(quant_bwd_input) + bwd_module.graph.erase_node(bwd_input) + # update the bwd_inputs if quantization with scaling is used + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("use_scaling", True): + quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder")) + # update the corresponding bwd input nodes find the last non-tangent node + bwd_input_loc = quant_bwd_module_inputs[-1] + for bw_input in reversed(quant_bwd_module_inputs): + if not _is_tangent(bw_input): + bwd_input_loc = bw_input + break + + scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] + for fwd_node in scaled_fwd_module_outputs: + if "fp8_scale_" in fwd_node.name: + # fwd node is a scale node + with bwd_module.graph.inserting_after(bwd_input_loc): + scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name) + scale_bwd_input.meta.update(fwd_node.meta) + bwd_input_loc = scale_bwd_input + + quantize_activation_bw(bwd_module.graph) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_activation_quantization_bwd_aten_pass", + "encoding": "string", + }, + payload_fn=lambda: bwd_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + +def _extract_fwd_bwd_modules( + joint_module: fx.GraphModule, + saved_values: list[fx.Node], + saved_sym_nodes: list[fx.Node], + *, + num_fwd_outputs: int, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> tuple[fx.GraphModule, fx.GraphModule]: + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + placeholders = joint_module.graph.find_nodes(op="placeholder") + primal_inputs = [*filter(_is_primal, placeholders)] + tangent_inputs = [*filter(_is_tangent, placeholders)] + fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] + bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] + backward_state_inputs = [*filter(_is_backward_state, placeholders)] + + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, + bwd_outputs, + "backward", + ) + + distributed_enabled = torch.distributed.is_available() + + for node in bwd_graph.find_nodes(op="placeholder"): + # This is to filter out saved values that don't actually end up being used by the backwards pass + if not node.users: + _remove_by_name(saved_values, node.name) + _remove_by_name(saved_sym_nodes, node.name) + # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw, + # but this dead activation is actually a collective, + # then the collective will generally by followed by a wait_tensor() call. + # we need to peak one node further to see if this wait_tensor is dead as well. 
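+        # i.e. if every user is a wait_tensor that itself has no users, the
+        # saved value is dead as well and can be dropped.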
+ elif distributed_enabled and all( + n.target is torch.ops._c10d_functional.wait_tensor.default + and len(n.users) == 0 + for n in node.users + ): + _remove_by_name(saved_values, node.name) + _remove_by_name(saved_sym_nodes, node.name) + elif _is_backward_state(node): + # BackwardState is saved directly + _remove_by_name(saved_values, node.name) + assert backward_state_inputs + + # Now that we have the finalized list of saved values, we need to ensure + # we propagate all symbols which are referenced by backwards inputs. + # These are not directly used in the graph but are required for downstream + # sizevar assignment + saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet() + saved_sym_nodes_binding = [] + saved_sym_nodes_derived = [] + + # Some symbols may already be bound in the directly saved_sym_nodes, + # keep track of them so we don't re-bind them + for node in saved_sym_nodes: + symbol = is_symbol_binding_fx_node(node) + if symbol: + saved_symbols.add(symbol) + saved_sym_nodes_binding.append(node) + else: + saved_sym_nodes_derived.append(node) + + # Now go through all of the prospective backward inputs and track any + # other symbols we need to bind + symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) + for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): + if "val" not in node.meta: + continue + new_symbols = free_symbols(node.meta["val"]) - saved_symbols + # NB: Deterministic order please! + for s in sorted(new_symbols, key=lambda s: s.name): + # NB: For well formed graphs, the symbol should always be present, + # but we also have ways to produce ill-formed graphs, e.g., direct + # make_fx usages, so don't choke in this case + if s not in symbol_bindings: + continue + saved_sym_nodes_binding.append(symbol_bindings[s]) + saved_symbols |= new_symbols + + # Update saved_sym_nodes that are now reordered to have all bindings at + # front. This can also be used later on to figure out the position of saved + # sym nodes in the output of fwd graph. + saved_sym_nodes.clear() + saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) + + # Now, we re-generate the fwd/bwd graphs. + # NB: This might increase compilation time, but I doubt it matters + fwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + primal_inputs + fwd_seed_offset_inputs, + fwd_outputs + saved_values + saved_sym_nodes, + "forward", + ) + bwd_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, + saved_sym_nodes + + saved_values + + tangent_inputs + + bwd_seed_offset_inputs + + backward_state_inputs, + bwd_outputs, + "backward", + ) + + fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) + bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) + enable_activation_quantization( + saved_values, fwd_module, bwd_module, static_lifetime_input_nodes + ) + return fwd_module, bwd_module + + +def default_partition( + joint_module: fx.GraphModule, + _joint_inputs, + *, + num_fwd_outputs, + static_lifetime_input_indices: Optional[list[int]] = None, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the :attr:`joint_module` in a manner that closely resembles the + behavior observed in the original ``.forward()`` and ``.backward()`` of the + callable, i.e., the resulting forward graph contains those operators that + are executed in the original ``.forward()`` callable passed to + :func:`aot_function`. 
+ + The default partitioner collects the operators that are between the forward + inputs and the forward outputs. This helps in finding the tensors which have + to be stashed for the backward pass. These stashed tensors become the output + of the generated forward graph. The remaining operators are then placed in + the backward graph. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + + Returns: + Returns the generated forward and backward Fx graph modules. + """ + if has_recomputable_ops(joint_module): + return min_cut_rematerialization_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, "forward" + ) + forward_node_names = OrderedSet( + node.name for node in forward_only_graph.nodes if node.op != "output" + ) + saved_values = [] + saved_sym_nodes = [] + + for node in joint_module.graph.nodes: + if node.name not in forward_node_names: + continue + if is_sym_node(node): + # Symints must be kept separate from tensors so that PythonFunction only calls + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes.append(node) + elif "tensor_meta" not in node.meta and node.op == "call_function": + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + users = node.users + assert all(user.target == operator.getitem for user in users) + saved_values.extend(users) + else: + backward_usages = [ + n for n in node.users if n.name not in forward_node_names + ] + if "tensor_meta" in node.meta and all( + is_sym_node(n) for n in backward_usages + ): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). 
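+                # Stash only the sym (size/stride) users instead of the tensor itself.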
+ saved_sym_nodes.extend(backward_usages) + else: + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) + + return _extract_fwd_bwd_modules( + joint_module, + saved_values, + saved_sym_nodes=saved_sym_nodes, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_nodes=static_lifetime_input_nodes, + ) + + +INT_INF = int(1e6) + + +def _tensor_nbytes(numel: int, dtype) -> int: + return numel * dtype.itemsize + + +def _size_of(node: fx.Node) -> int: + def object_nbytes(x) -> int: + if not isinstance(x, torch.Tensor): + return 0 + return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype) + + if "val" in node.meta: + val = node.meta["val"] + if isinstance(val, py_sym_types): + return 1 + # NB: The fallback values here are meaningless, maybe we should respect + # torch._inductor.config.unbacked_symint_fallback (but this is a + # layering violation) + elif isinstance(val, (list, tuple)): + return sum(object_nbytes(n) for n in val) + elif isinstance(val, dict): + return sum(object_nbytes(n) for _, n in val.items()) + elif isinstance(val, torch.Tensor): + return object_nbytes(val) + + raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}") + if node.op == "get_attr" or node.target is torch.ops.aten._assert_scalar.default: + return 0 + raise RuntimeError( + f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes." + ) + + +# Used for some investigative purposes +def _count_ops(graph: fx.Graph): + from collections import defaultdict + + cnt: dict[str, int] = defaultdict(int) + for node in graph.nodes: + if node.op == "call_function": + cnt[node.target.__name__] += 1 + log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) + + +@functools.cache +def pointwise_ops(): + ops = [] + for attr_name in dir(torch.ops.aten): + opoverloadpacket = getattr(torch.ops.aten, attr_name) + if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): + continue + + for overload in opoverloadpacket.overloads(): + op_overload = getattr(opoverloadpacket, overload) + if torch.Tag.pointwise in op_overload.tags: + # currently aot autograd uses packet not overload + ops.append(opoverloadpacket) + break + + return ops + + +def sort_depths(args, depth_map: dict[fx.Node, int]) -> list[tuple[fx.Node, int]]: + arg_depths = { + arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) + } + return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) + + +def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: + """ + This pass finds the first bwd node in the graph (by looking at users of + tangents) and then reorders the graph by walking from this node to all the + way to the end of the graph. At each op in this traveral, we insert this op + in a new graph and try to bring only the relevant subgraph from the other + non-bwd edges relevant for this op. This closely mimics the behavior of + autograd engine. + + Why is this pass required in the first place? + + This is an artifact of how partitioners work today. The starting point of + partitioner is a joint graph, which is fwd and then bwd graph. In the case + of checkpointing, we keep portions of fwd graph in their original place in + the joint graph, while obtaining a bwd graph. As a result, the resulting bwd + graph has copies of recomputed fwd subgraphs followed by the original bwd + graph. 
If we run this naively, this leads to bad memory footprint, because + the fwd subgraphs are live for way longer duration than necessary. This pass + reorders the operations such that we prioritize the ops for the original bwd + graph while only realizing those ops from the fwd graph that are necessary + at any given point in the graph. + """ + + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + + # Add new placeholder nodes in the order specified by the inputs + for node in gm.graph.find_nodes(op="placeholder"): + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + order = {node: idx for idx, node in enumerate(gm.graph.nodes)} + + def insert_node_in_graph(node): + cur_nodes = [node] + insertable_nodes: OrderedSet[fx.Node] = OrderedSet() + while len(cur_nodes) > 0: + node = cur_nodes.pop() + if node in insertable_nodes or node in env: + continue + insertable_nodes.add(node) + + # Bias traversal towards the nodes that have higher depth - prioritizes + # critical path first. + cur_nodes += node.all_input_nodes + + insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) + for node in insertable_nodes: + env[node] = new_graph.node_copy(node, lambda x: env[x]) + + # Find first bwd node in the graph + tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) + first_node_in_bwd = None + minimum_order = math.inf + for tangent in tangent_inputs: + for user in tangent.users: + if order[user] < minimum_order: + minimum_order = order[user] + first_node_in_bwd = user + + # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass" + if first_node_in_bwd is None: + return gm + + # Build the graph op-by-op by starting from the node all the way to the end + # copy_ can be not using tangents at all, we must copy it. + for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]: + if node.op == "call_function" and node.target == torch.ops.aten.copy_.default: + insert_node_in_graph(node) + + for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: + insert_node_in_graph(node) + + # The output node is already built by the traversal. + new_gm = torch.fx.GraphModule(gm, new_graph) + return new_gm + + +def apply_graphsafe_rng_functionalization( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_node: torch.fx.Node, + bw_node: torch.fx.Node, + device: torch.device, + rng_count: int, + last_fwd_input: torch.fx.Node, + last_bwd_input: torch.fx.Node, +): + """ + Note [CUDA Graph Safe RNG Functionalization] + + CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values, + while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a + CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer + to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator + (and its cuda-tensor RNG state during graph capture). 
+ + For each RNG operation's forward/backward pair: + + - We create two generators initialized with identical values + - Each forward and backward call advances its respective generator equally + - This keeps generators synchronized so forward and backward operations use matching RNG values + + When forward is called multiple times before backward (causing desynchronization): + + - We save the forward RNG state + - We update the backward Generator's state before executing backward + + Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator + changes are reflected during replay. + + This function modifies both forward and backward computation graphs by: + + Creating RNG state placeholders for both passes + Updating the forward node to use graph-safe RNG state + Updating the backward node to use graph-safe RNG state + + For more details: https://github.com/pytorch/pytorch/issues/113541 + """ + device_idx = device.index + assert device_idx is not None + fw_graph = fw_module.graph + bw_graph = bw_module.graph + graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state + + # Handle forward pass + + # Note: [Generator arguments in AOTDispatcher] + # Generator arguments in AOTDispatcher are added to support graphsafe rng + # functionalization. See note above [CUDA Graph Safe RNG Functionalization] + with fw_module.graph.inserting_after(last_fwd_input): + fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}") + fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_fwd_input = fwd_rng_state + + # Handle backward pass + with bw_module.graph.inserting_after(last_bwd_input): + bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}") + # as above, clone so that meta val generator will not contain tensors + bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_bwd_input = bwd_rng_state + + # Update forward node + fw_kwargs = dict(fw_node.kwargs) + fw_kwargs["rng_state"] = fwd_rng_state + with fw_module.graph.inserting_after(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(fw_node.target, *fw_node.args), # type: ignore[arg-type] + kwargs=fw_kwargs, + ) + fw_node.replace_all_uses_with(functional_fw_node) + fw_graph.erase_node(fw_node) + + # Update backward node + bwd_kwargs = dict(bw_node.kwargs) + bwd_kwargs["rng_state"] = bwd_rng_state + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(bw_node.target, *bw_node.args), # type: ignore[arg-type] + kwargs=bwd_kwargs, + ) + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + return last_fwd_input, last_bwd_input + + +def functionalize_rng_ops( + joint_module: fx.GraphModule, + fw_module: fx.GraphModule, + bw_module: fx.GraphModule, + num_sym_nodes: int, +) -> tuple[fx.GraphModule, fx.GraphModule]: + # During user-driven activation checkpointing, we have to ensure that a rng + # op in fwd yields the same output as the recomputed rng op in the bwd. To + # do this, we use functionalize wrappers to wrap the random ops and share + # rng state between the fwd and bwd graphs. + + # There are 3 main steps to do this + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. 
+ # Step 2 - Modify the fwd pass such that + # 1) Replace rand with run_and_save_rng_state wrapper + # 2) Replace the users of the original op with the output[1] of this op. + # 3) Collect all the rng_state - output[0] of each op, and make them + # output nodes. Special care needs to be taken here because fwd outputs + # has symints at the very end. + # Step 3 - Modify the bwd pass such that + # 1) Add the input nodes just before the tangents for the stashed rng states + # 2) Replace rand with run_with_save_rng_state wrappers + # 3) Use the stashed states as inputs to these ops + + # Unique id to generate name + uid = itertools.count() + + def get_rng_ops(gmod): + random_nodes = {} + for node in gmod.graph.nodes: + if ( + node.op == "call_function" + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + random_nodes[node.name] = node + return random_nodes + + def get_device(node) -> Optional[torch.device]: + """ + Check the example value of the node outputs to find the device type. + """ + if "val" not in node.meta: + return None + + candidates = node.meta["val"] + if not isinstance(candidates, tuple): + candidates = (candidates,) + + for candidate in candidates: + if isinstance(candidate, torch.Tensor): + if candidate.device.type == "cuda": + return candidate.device + + return torch.device("cpu") + + def get_sample_rng_state(device: Optional[torch.device]): + if device is not None and device.type == "cuda": + return torch.cuda.get_rng_state() + return torch.get_rng_state() + + # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. + joint_graph_rng_ops = get_rng_ops(joint_module) + fw_graph_rng_ops = get_rng_ops(fw_module) + bw_graph_rng_ops = get_rng_ops(bw_module) + recomputable_rng_ops_map = {} + for node in joint_module.graph.nodes: + if ( + must_recompute(node) + and hasattr(node.target, "tags") + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + base_node = joint_graph_rng_ops[node.name] + fw_node = fw_graph_rng_ops[node.name] + bw_node = bw_graph_rng_ops[node.name] + recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} + + run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state + run_with_rng_state = torch._prims.rng_prims.run_with_rng_state + + bw_tangent_start_node = None + for node in bw_module.graph.find_nodes(op="placeholder"): + if "tangent" in node.name: + bw_tangent_start_node = node + break + if bw_tangent_start_node is None: + raise RuntimeError( + "Couldn't find tangent node in graph inputs. 
This is unexpected, please file a bug if you see this" + ) + + fw_rng_state_outputs = [] + + last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder"))) + last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder"))) + + devices = OrderedSet( + get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values() + ) + devices.discard(torch.device("cpu")) + # multiple cuda devices wont work with cudagraphs anyway, + # fallback to non graphsafe rng checkpointing + multi_cuda_devices = len(devices) > 1 + + # this changes numerics, so if fallback_random is set we will not use it + ind_config = torch._inductor.config + use_rng_graphsafe_rng_functionalization = ( + config.graphsafe_rng_functionalization + and not multi_cuda_devices + and ( + not ind_config.fallback_random + or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random + ) + ) + + for rng_count, (base_node, node_pair) in enumerate( + recomputable_rng_ops_map.items() + ): + # Step 2 - Modify the fwd pass such that + fw_node = node_pair["fwd"] + bw_node = node_pair["bwd"] + device = get_device(fw_node) + + fw_graph = fw_module.graph + bw_graph = bw_module.graph + + if ( + use_rng_graphsafe_rng_functionalization + and device is not None + and device.type == "cuda" + ): + last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization( + fw_module, + bw_module, + fw_node, + bw_node, + device, + rng_count, + last_fwd_input, + last_bwd_input, + ) + else: + with fw_graph.inserting_before(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + run_and_save_rng, + args=(fw_node.target, *fw_node.args), + kwargs=fw_node.kwargs, + ) + state = fw_graph.create_node( + "call_function", + operator.getitem, + args=(functional_fw_node, 0), + kwargs={}, + ) + rng_output = fw_graph.create_node( + "call_function", + operator.getitem, + args=( + functional_fw_node, + 1, + ), + kwargs={}, + ) + fw_node.replace_all_uses_with(rng_output) + fw_graph.erase_node(fw_node) + fw_rng_state_outputs.append(state) + + # Step 3 - Modify the bwd pass such that + with bw_graph.inserting_before(bw_tangent_start_node): + state_name = f"rng_state_output_{next(uid)}" + bw_rng_state_node = bw_graph.placeholder(state_name) + bw_rng_state_node.meta["val"] = get_sample_rng_state(device) + + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + run_with_rng_state, + args=(bw_rng_state_node, bw_node.target, *bw_node.args), + kwargs=bw_node.kwargs, + ) + + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + # Add the rng states in the output of the fwd graph. AOT Autograd assumes + # that symints are at the end of forward graph outputs. So, insert the new + # rng states accordingly. + if fw_rng_state_outputs: + fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) + fw_outputs = fw_output_node.args[0] + sym_node_start_idx = len(fw_outputs) - num_sym_nodes + outputs = ( + fw_outputs[:sym_node_start_idx] + + tuple(fw_rng_state_outputs) + + fw_outputs[sym_node_start_idx:] + ) + fw_module.graph.output(outputs) + fw_module.graph.erase_node(fw_output_node) + fw_module.recompile() + bw_module.recompile() + return fw_module, bw_module + + +def force_save_collectives(joint_module: fx.GraphModule) -> None: + """ + By default, the partitioner is not allowed to recompute collectives + unless they come from a user-annotated AC region. 
+ See Note [Recomputing collectives in the partitioner] + """ + for node in joint_module.graph.nodes: + if ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace == "_c10d_functional" + and not must_recompute(node) + ): + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + + +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: + """ + If there are two consecutive checkpointed blocks with no operator in + between, we would still want to stash the tensor at the boundary of + checkpointed blocks. The following pass makes the last output node + non-recomputable to allow for that. + """ + for node in joint_module.graph.nodes: + if must_recompute(node): + for user in node.users: + if ( + must_recompute(user) + and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] + ): + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + if node.meta.get("has_backward_hook", False) and not any( + must_recompute(user) for user in node.users + ): + # If node is AC region output and has a backward hook on it, we intentionally choose to save it. + # This is to work around circular dependencies in Traceable FSDP2+AC. + # Example: + # ``` + # out = fully_shard(utils.checkpoint(module))(x) + # norm_out = layer_norm(out) + # ``` + # Here there is a circular dependency: + # 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`. + # 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) + # in order to be recomputed. + # 3. `out`'s backward hook, as is the case for all eager backward hooks, depends on `out_grad` + # -> circular dependency with (1)! + # + # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` + # in forward graph outputs. With this, we can break the above circular dependency. 
+ node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + return joint_module + + +def solve_min_cut( + joint_graph: fx.Graph, + node_info: NodeInfo, + min_cut_options: MinCutOptions, + dont_ban: Optional[OrderedSet[fx.Node]] = None, +): + if dont_ban is None: + dont_ban = OrderedSet() + op_types = get_default_op_list() + + if AOT_PARTITIONER_DEBUG: + joint_module_ops = OrderedSet( + str(node.target._overloadpacket) + for node in joint_graph.nodes + if node.op == "call_function" and hasattr(node.target, "_overloadpacket") + ) + ops_ignored = joint_module_ops - OrderedSet( + str(i) for i in op_types.recomputable_ops + ) + log.info("Ops banned from re-materialization: %s", ops_ignored) + + def can_fuse_into_auto_functionalized(a, b): + if b.target != torch.ops.higher_order.auto_functionalized: + return False + mutable_op = b.args[0] + ( + mutable_arg_names, + _, + ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op) + for name in mutable_arg_names: + arg = b.kwargs[name] + if a is arg: + return True + if isinstance(arg, list): + if a in arg: + return True + return False + + def can_fuse_into_triton_kernel_wrapper_functional(a, b): + if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional: + return False + mutable_arg_names = b.kwargs["tensors_to_clone"] + for name in mutable_arg_names: + arg = b.kwargs["kwargs"][name] + if a is arg: + return True + return False + + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + if can_fuse_into_auto_functionalized(a, b): + return True + if can_fuse_into_triton_kernel_wrapper_functional(a, b): + return True + if ( + a.target is operator.getitem + and a.args[0].target + is torch.ops.higher_order.triton_kernel_wrapper_functional + ): + # if a is the output of a user triton kernel, + # then (by default) we will not be able to fuse b into it + return False + return op_types.is_fusible(a) and op_types.is_fusible(b) + + try: + import networkx as nx + except ImportError as e: + raise RuntimeError( + "Need networkx installed to perform smart recomputation heuristics" + ) from e + + def is_materialized_backwards(node): + if op_types.is_view(node): + return False + cur_nodes = OrderedSet([node]) + while len(cur_nodes) > 0: + cur = cur_nodes.pop() + for user in cur.users: + if not node_info.is_required_fw(user) and not is_fusible(cur, user): + return True + if op_types.is_view(user): + cur_nodes.add(user) + + return False + + def should_ban_recomputation(node): + if node.op != "call_function": + return False + if node.target == operator.getitem: + return False + if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: + return True + if config.recompute_views and op_types.is_view(node): + return False + if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: + return False + + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): + return True + else: + if op_types.is_random(node) or op_types.is_compute_intensive(node): + return True + + # If a node *must* be materialized in the backwards pass, then we + # should never recompute it. This is a pretty subtle point. In + # general, the assumption we make is that recomputing a node in the + # backwards pass is "free". However, if a node must be materialized + # in the backwards pass, then recomputing it is never free. 
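+        # For example, if the only backward consumer of this node is an op it cannot
+        # fuse into, the recomputed value has to be fully materialized in memory
+        # anyway, so recomputing costs extra compute without saving any memory.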
+ if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( + node + ): + log.debug("materialized backwards: %s %s", node, tuple(node.users)) + return True + + # Arbitrary hack that sometimes seems to help things. The above + # modification appears to have made this heuristic a lot less critical + # for performance. + # NB: As of PR #121692, this hack no longer seems necessary. + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return True + + # If the output of an op is 4x smaller (arbitrary choice), + # then we don't allow recomputation. The idea here is that for + # things like reductions, saving the output of the reduction is very + # cheap/small, and it makes sure we don't do things like recompute + # normalizations in the backwards. + if min_cut_options.ban_if_reduction: + input_tensors_size = sum( + _size_of(i) for i in node.args if isinstance(i, fx.Node) + ) + output_size = _size_of(node) + return output_size * 4 < input_tensors_size + return False + + def is_materialized(node): + if node.op == "placeholder": + return True + + return not all(is_fusible(node, user) for user in node.users) + + def get_node_weight(node, static_lifetime_input_nodes) -> float: + if ( + config.treat_parameters_as_free_to_save + and node in static_lifetime_input_nodes + ): + return 0 + mem_sz = _size_of(node) + if config.recompute_views and op_types.is_view(node): + # If `config.recompute_views=True`, we don't save views. This is generally + # a good idea since views are free to recompute, and it makes it a bit simpler + # to analyze. + # NB: If they're not free to recompute (e.g. nested tensors)... I + # think we should modify checks for view_ops to `is_view` and check + # that. Basically, with nested tensors, `aten.view` is not a "view + # op". + return math.inf + + if isinstance(node.meta["val"], py_sym_types): + # We never want to save symfloats + if not isinstance(node.meta["val"], torch.SymInt): + return INT_INF + + # Heuristic to bias towards nodes closer to the backwards pass + # Complete guess about current value + mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) + if is_materialized(node): + return mem_sz + else: + return mem_sz * 2 + + nx_graph = nx.DiGraph() + banned_nodes: OrderedSet[fx.Node] = OrderedSet() + + def ban_recomputation_if_allowed(node): + if op_types.is_view(node): + return False + if node in dont_ban: + # collectives are *always* banned from recompute, overriding `dont_ban` + # (in particular, the activation memory budget logic is not allowed to recompute collectives) + is_collective = ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace == "_c10d_functional" + ) + if config.unsafe_allow_optimization_of_collectives or not is_collective: + return False + # This bans recomputation of the node unless we've been forced not to by + # user annotation + if must_recompute(node): + return False + + if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat): + return False + banned_nodes.add(node) + # A node will only ever be recomputed if there is a path from an + # ancestor of this node to the backwards path through this node that + # doesn't go through any saved value. If this node is saved, then that + # condition is not possible. 
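+        # Mechanically, the infinite-capacity edge from the source to this node's
+        # "_in" vertex cannot be cut, so the min cut must fall on this node or
+        # somewhere downstream of it; every path from it into the backward then
+        # crosses a saved value and the op never needs to be re-executed.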
+ nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) + return True + + for node in joint_graph.nodes: + if node.op == "output": + continue + + if node in node_info.required_bw_nodes: + if node not in node_info.inputs: + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + # If someone saves a input for backward as-is and backward + # returns that tensor as-is as a grad input, then the node x would + # be both a required_bw_node and an input. In this case we + # (1) connect x_in to to the source, (2) x_out to the sink, and + # (3) assign the proper weight to the x_in-x_out edge, so that + # x would be part of cut nodes. A case where this happens is if + # NestedTensor saves a offset tensor as part of the singleton int + # in sizes. + nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) + + if must_recompute(node): + # If user explicitly says they want to recompute a node, we honor it + # by adding an inf-capacity edge from X_in to the sink. + # This way, X_in node is guaranteed to be part of the subgraph that contains "sink" + # after the cut, thus guaranteeing that X op will be recomputed. + nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) + continue + + if _is_primal(node) or _is_fwd_seed_offset(node): + ban_recomputation_if_allowed(node) + + # If a node can't be recomputed (too expensive or involves randomness), + # we prevent it from being recomputed by adding an inf edge to the source + # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. + if node_info.is_required_fw(node) and should_ban_recomputation(node): + ban_recomputation_if_allowed(node) + + # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. + is_non_tensor_node = ( + "val" not in node.meta and "tensor_meta" not in node.meta + ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) + + if is_sym_node(node): + weight = float(sym_node_size(node)) + elif is_non_tensor_node: + weight = ( + 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + ) + else: + weight = get_node_weight(node, node_info.static_lifetime_input_nodes) + # Creates the weights on the "node" edge + nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) + for user in node.users: + nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) + + # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute. + # Some example models to look at where this helps perf: poolformer_m36, + # mixer_b16_224, cait_m36_384 + + # The "rough" idea here is that if you have some node that is used by both a + # node nearby downstream as well as a node far downstream, if we recompute + # both of the downstream nodes, we're unlikely to be able to fuse both + # downstream nodes together. + + # Thus, we shouldn't aim to recompute far downstream nodes that depend on + # this node. That intuition of "far downstream" is captured by whether + # there's an unfusible op along the chain somewhere + + # It could probably be improved by properly analyzing what's going on in the + # backwards pass instead of only relying on whether it's unfusible in the + # forwards. + + def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int: + """ + Finds the first unfusible node in the chain of nodes starting from + `start_nodes` and returns its position. 
+ """ + sorted_nodes: list[tuple[int, fx.Node, bool]] = [] + for n in start_nodes: + heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) + + while len(sorted_nodes) > 0: + _, node, node_is_fusible = heapq.heappop(sorted_nodes) + if not node_is_fusible: + return node_info.get_fw_order(node) + for user in node.users: + if node_info.is_required_fw(user): + if node_info.get_fw_order(user) > max_range: + continue + val: tuple[int, fx.Node, bool] = ( + node_info.get_fw_order(user), + user, + is_fusible(node, user), + ) + if val not in sorted_nodes: + heapq.heappush(sorted_nodes, val) + return max_range + + if min_cut_options.ban_if_used_far_apart: + for used_node in node_info.required_fw_nodes: + orders = [ + node_info.get_fw_order(user) + for user in used_node.users + if node_info.is_required_fw(user) + ] + fw_users = [ + user for user in used_node.users if node_info.is_required_fw(user) + ] + if len(orders) > 0: + first_unfusible_use = find_first_unfusible(fw_users, max(orders)) + for user in tuple(used_node.users): + if ( + node_info.is_required_fw(user) + and node_info.get_fw_order(user) > first_unfusible_use + and is_fusible(used_node, user) + ): + if user in banned_nodes: + continue + log.info( + "used above/below fusible %s:(%s) -> %s -> %s:(%s)", + used_node, + node_info.get_fw_order(used_node), + first_unfusible_use, + user, + node_info.get_fw_order(user), + ) + ban_recomputation_if_allowed(user) + + # This heuristic is fairly straightforward. The idea is that although it is + # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation + # where we have a long chain of pointwise ops from the beginning to the end + # of the model (like say, residual connections) + + # todo: I'm not totally sure why this heuristic matters. 
It's possible that this is + # working around Inductor fusion decisions, or that it's a patch over + # suboptimal partitioning decisions + + # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 + + if min_cut_options.ban_if_long_fusible_chains: + visited: OrderedSet[fx.Node] = OrderedSet() + for start_node in joint_graph.nodes: + if not node_info.is_required_fw(start_node): + continue + fusible: list[tuple[int, fx.Node]] = [ + (node_info.get_fw_order(start_node), start_node) + ] + start_order = node_info.get_fw_order(start_node) + while len(fusible) > 0: + _, cur = heapq.heappop(fusible) + if cur in visited: + continue + visited.add(cur) + # 100 is arbitrary choice to try and prevent degenerate cases + if ( + node_info.get_fw_order(cur) > start_order + 100 + and len(fusible) == 0 + ): + log.info( + "too long %s %s %s %s", + cur, + start_node, + node_info.get_fw_order(cur), + node_info.get_fw_order(start_node), + ) + ban_recomputation_if_allowed(cur) + break + + for user in cur.users: + if ( + node_info.is_required_fw(user) + and is_fusible(cur, user) + and user not in banned_nodes + ): + heapq.heappush(fusible, (node_info.get_fw_order(user), user)) + + try: + cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") + except Exception: + log.info("Failed to compute min-cut on following graph:") + log.info("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + visualize_min_cut_graph(nx_graph) + raise + + reachable, non_reachable = partition + cutset: OrderedSet[tuple[str, str]] = OrderedSet() + for u, nbrs in ((n, nx_graph[n]) for n in reachable): + cutset.update((u, v) for v in nbrs if v in non_reachable) + + cut_nodes: OrderedSet[str] = OrderedSet() + for node_in, node_out in cutset: + assert node_in[:-3] == node_out[:-4] + node_name = node_in[:-3] + cut_nodes.add(node_name) + + name_to_node = get_name_to_node(joint_graph) + # To make this stuff deterministic + node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} + saved_values = sorted( + (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] + ) + return saved_values, banned_nodes + + +def visualize_min_cut_graph(nx_graph): + import networkx as nx + import pydot + + dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() + dot_graph = pydot.graph_from_dot_data(dot_format)[0] # type: ignore[index] + for edge in dot_graph.get_edges(): + weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] + # Set edge label to weight + edge.set_label(str(weight)) # type: ignore[union-attr] + # Color edges with weight 'inf' as red + if weight == float("inf"): + edge.set_color("red") # type: ignore[union-attr] + log.info("Visualizing the failed graph to min_cut_failed.svg") + dot_graph.write_svg("min_cut_failed.svg") # type: ignore[union-attr] + + +def get_default_op_list() -> OpTypes: + default_recomputable_ops: list[Callable] = [ + aten.add, + aten.sub, + aten.div, + aten.atan2, + aten.mul, + aten.max, + aten.min, + aten.pow, + aten.remainder, + aten.fmod, + aten.__and__, + aten.__or__, + aten.__xor__, + aten.__lshift__, + aten.__rshift__, + aten.eq, + aten.ne, + aten.ge, + aten.gt, + aten.le, + aten.lt, + aten.abs, + aten.bitwise_not, + aten.ceil, + aten.floor, + aten.frac, + aten.neg, + aten.relu, + aten.round, + aten.silu, + aten.trunc, + aten.log, + aten.log10, + aten.log1p, + aten.log2, + aten.lgamma, + aten.exp, + aten.expm1, + aten.erf, + aten.erfc, + aten.cos, + aten.acos, + aten.cosh, + aten.sin, + aten.asin, + aten.sinh, + aten.tan, + aten.atan, + 
aten.tanh, + aten.atanh, + aten.sqrt, + aten.rsqrt, + aten.reciprocal, + aten.sigmoid, + aten.softplus, + aten.threshold, + aten.threshold_backward, + aten.clamp, + aten.where, + aten.lerp, + aten.addcmul, + aten.gelu, + aten.gelu_backward, + aten.sum, + aten.mean, + aten._grad_sum_to_size, + aten.sum_to_size, + aten.amax, + aten.to, + aten.type_as, + operator.getitem, + aten.squeeze, + aten.unsqueeze, + aten.rsub, + aten._to_copy, + ] # noqa: E501,B950 + recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + recomputable_view_ops += [ + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + aten.select, + aten.split, + ] + view_ops = recomputable_view_ops + default_recomputable_ops += [ + prims.div, + prims.convert_element_type, + aten.clone, + aten._to_copy, + aten.full_like, + prims.var, + prims.sum, + aten.var, + aten.std, + prims.broadcast_in_dim, + aten.select, + aten._unsafe_view, + aten.view, + aten.expand, + aten.slice, + aten.reshape, + aten.broadcast_tensors, + aten.scalar_tensor, + aten.ones, + aten.new_zeros, + aten.lift_fresh_copy, + aten.arange, + aten.triu, + aten.var_mean, + aten.isinf, + aten.any, + aten.full, + aten.as_strided, + aten.zeros, + aten.empty, + aten.empty_like, + aten.argmax, + aten.maximum, + prims.iota, + prims._low_memory_max_pool_offsets_to_indices, + ] # noqa: E501,B950 + # Natalia said that we should allow recomputing indexing :) + default_recomputable_ops += [aten.index, aten.gather] + default_recomputable_ops += view_ops + + default_recomputable_ops += pointwise_ops() + + default_recomputable_ops += [ + aten.zeros_like, + ] + + default_recomputable_ops += [method_to_operator(m) for m in magic_methods] + recomputable_ops = OrderedSet(default_recomputable_ops) + + random_ops = OrderedSet[Callable[..., Any]]( + [aten.native_dropout, aten.rand_like, aten.randn_like] + ) + compute_intensive_ops = [ + aten.mm, + aten.convolution, + aten.convolution_backward, + aten.bmm, + aten.addmm, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_efficient_attention, + aten._flash_attention_forward, + aten._efficient_attention_forward, + aten.upsample_bilinear2d, + aten._scaled_mm, + ] # noqa: E501,B950 + + fusible_ops = recomputable_ops | random_ops + return OpTypes( + fusible_ops, + OrderedSet(compute_intensive_ops), + random_ops, + OrderedSet(view_ops), + recomputable_ops, + ) + + +def get_name_to_node(graph: fx.Graph): + name_to_node = {} + for node in graph.nodes: + name_to_node[node.name] = node + return name_to_node + + +def _optimize_runtime_with_given_memory( + joint_graph: fx.Graph, + memory: list[float], + runtimes: list[float], + max_memory: float, + node_info: NodeInfo, + all_recomputable_banned_nodes: list[fx.Node], +) -> tuple[float, list[int], list[int]]: + SOLVER = config.activation_memory_budget_solver + if SOLVER == "greedy": + return greedy_knapsack(memory, runtimes, max_memory) + elif SOLVER == "ilp": + return ilp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp": + return dp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dynamic_memory_budget_dp": + log.warning( + "dynamic_memory_budget_dp is an experimental solver. " + "It does not guarantee performance improvements. " + "Additionally, it is not guaranteed to be stable." 
+ ) + graph_info_provider = GraphInfoProvider.inialize_from_graph( + joint_graph=joint_graph, + all_recomputable_banned_nodes=all_recomputable_banned_nodes, + recorded_knapsack_input_memories=memory, + recorded_knapsack_input_runtimes=runtimes, + ) + return dp_knapsack( + memory, + runtimes, + KnapsackEvaluator( + graph_info_provider=graph_info_provider, + ).get_knee_point_memory_budget( + knapsack_algo=dp_knapsack, + max_mem_budget=max_memory, + ), + ) + elif callable(SOLVER): + saved_node_idx, recomp_node_idx = SOLVER( + memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes + ) + return (0.0, saved_node_idx, recomp_node_idx) + else: + raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") + + +from torch.utils._mode_utils import no_dispatch + + +# replace symbols in size and strides with their hints without guarding. +def _remove_symbols_without_guarding(x: torch.Tensor, fallback: int) -> torch.Tensor: + shape = list(x.shape) + + def realize_symbol(d): + return hint_int(d, fallback=fallback) + + shape = [realize_symbol(s) for s in shape] + stride = [realize_symbol(s) for s in x.stride()] + return x.new_empty_strided(shape, stride=stride) + + +def estimate_runtime(node): + RUNTIME_MODE = config.activation_memory_budget_runtime_estimator + + def materialize_arg(x): + if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): + return _remove_symbols_without_guarding(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): + return hint_int(x.meta["val"], fallback=4096) + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): + return 1.0 + elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): + return True + else: + return x + + if RUNTIME_MODE == "testing": + return 1 + + elif RUNTIME_MODE == "profile": + with no_dispatch(): + from torch._inductor.runtime.benchmarking import benchmarker + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs)) + return ms + + elif RUNTIME_MODE == "flops": + # todo(chilli): Normalize this to also return ms + from torch.utils.flop_counter import FlopCounterMode + + args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) + with FlopCounterMode(display=False) as mode: + node.target(*args, **kwargs) + counted_flops = mode.get_total_flops() + return max(counted_flops, 1) + else: + raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") + + +def choose_saved_values_set( + joint_graph: fx.Graph, + node_info: NodeInfo, + memory_budget=1, +) -> list[fx.Node]: + if memory_budget > 1 or memory_budget < 0: + raise RuntimeError( + f"The valid ranges for memory budget are 0 <= m <= 1. 
The provided value is {memory_budget}" + ) + min_cut_options = MinCutOptions( + ban_if_used_far_apart=config.ban_recompute_used_far_apart, + ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, + ban_if_materialized_backward=config.ban_recompute_materialized_backward, + ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, + ban_if_reduction=config.ban_recompute_reductions, + ) + + if config.aggressive_recomputation: + min_cut_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ban_if_not_in_allowlist=False, + ) + if memory_budget == 0: + return node_info.inputs + + runtime_optimized_saved_values, _ = solve_min_cut( + joint_graph, + node_info, + min_cut_options, + ) + # return runtime_optimized_saved_values + if memory_budget == 1: + return runtime_optimized_saved_values + + def estimate_activations_size(saved_values: list[fx.Node]) -> float: + return sum(map(_size_of, saved_values)) / 1e9 + + min_act_size = estimate_activations_size(node_info.inputs) + max_act_size = estimate_activations_size(runtime_optimized_saved_values) + # The optimized choice is smaller than the inputs anyways + if max_act_size <= min_act_size: + return runtime_optimized_saved_values + + def get_normalized_size(sz): + return (sz / 1e9) / (max_act_size - min_act_size) + + def get_mem_ratio(activations: list[fx.Node]): + return (estimate_activations_size(activations) - min_act_size) / ( + max_act_size - min_act_size + ) + + more_aggressive_options = replace( + min_cut_options, + ban_if_used_far_apart=False, + ban_if_long_fusible_chains=False, + ban_if_materialized_backward=False, + ) + more_aggressive_saved_values, _ = solve_min_cut( + joint_graph, node_info, more_aggressive_options + ) + if get_mem_ratio(more_aggressive_saved_values) < memory_budget: + return more_aggressive_saved_values + + aggressive_options = replace( + more_aggressive_options, + ban_if_not_in_allowlist=False, + ) + aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( + joint_graph, node_info, aggressive_options + ) + + if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: + return aggressive_recomputation_saved_values + + from torch._inductor.fx_utils import get_node_storage + + input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs) + + def get_recomputable_banned_nodes( + banned_nodes: OrderedSet[fx.Node], + ) -> list[fx.Node]: + return [ + i + for i in banned_nodes + if ( + # Only allow recomputing nodes that are actually required for BW + i.dist_from_bw < int(1e9) # type: ignore[attr-defined] + and get_node_storage(i) not in input_storages + ) + ] + + recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) + must_save_nodes = [ + i + for i in recomputable_banned_nodes + if i.meta.get("recompute", False) == CheckpointPolicy.MUST_SAVE + ] + recomputable_banned_nodes = [ + i for i in recomputable_banned_nodes if i not in must_save_nodes + ] + + # default: runtime_optimized_saved_values + # more aggressive: more_aggressive_saved_values + # full aggressive: aggressive_recomputation_saved_values + + all_recomputable_banned_nodes = sorted( + recomputable_banned_nodes, key=_size_of, reverse=True + ) + if len(all_recomputable_banned_nodes) == 0: + return node_info.inputs + must_save_nodes + memories_banned_nodes = [ + get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes + ] + runtimes_banned_nodes = [ + estimate_runtime(node) for node in 
all_recomputable_banned_nodes + ] + from torch.utils._mode_utils import no_dispatch + + def get_saved_values_knapsack(memory_budget, node_info, joint_graph): + with no_dispatch(): + ( + expected_runtime, + saved_node_idxs, + recomputable_node_idxs, + ) = _optimize_runtime_with_given_memory( + joint_graph, + memories_banned_nodes, + runtimes_banned_nodes, + max(memory_budget, 0), + node_info, + all_recomputable_banned_nodes, + ) + dont_ban: OrderedSet[fx.Node] = OrderedSet() + for idx in recomputable_node_idxs: + # if idx in all_recomputable_banned_nodes: + try: + dont_ban.add(all_recomputable_banned_nodes[idx]) + except BaseException: + pass + + assert dont_ban.issubset(all_recomputable_banned_nodes) + + saved_values, _ = solve_min_cut( + joint_graph, + node_info, + aggressive_options, + dont_ban, + ) + if AOT_PARTITIONER_DEBUG: + create_structured_trace_for_min_cut_info( + joint_graph=joint_graph, + all_recomputable_banned_nodes=all_recomputable_banned_nodes, + saved_node_idxs=saved_node_idxs, + recomputable_node_idxs=recomputable_node_idxs, + expected_runtime=expected_runtime, + memories_banned_nodes=memories_banned_nodes, + runtimes_banned_nodes=runtimes_banned_nodes, + min_cut_saved_values=saved_values, + ) + return saved_values, expected_runtime + + if config.visualize_memory_budget_pareto: + + def estimate_for_budget(b): + saved_values, expected_runtime = get_saved_values_knapsack( + b, node_info=node_info, joint_graph=joint_graph + ) + return ( + b, + sum(runtimes_banned_nodes) - expected_runtime, + get_mem_ratio(saved_values), + ) + + options = [estimate_for_budget(0.0), estimate_for_budget(1.0)] + + if options[0][1:] != options[1][1:]: + bisects = [(options[0], options[1])] + while bisects: + lhs, rhs = bisects.pop() + if rhs[0] - lhs[0] < 1e-3: + options.append(lhs) + options.append(rhs) + continue + mid = estimate_for_budget((lhs[0] + rhs[0]) / 2) + if mid[1:] != lhs[1:]: + bisects.append((lhs, mid)) + if mid[1:] != rhs[1:]: + bisects.append((mid, rhs)) + options.sort() + + import matplotlib.pyplot as plt + + x_values = [item[2] for item in options] + y_values = [item[1] for item in options] + + # Plotting the values with updated axis labels and chart title + plt.figure(figsize=(10, 6)) + plt.plot(x_values, y_values, marker="o") + + # Adding labels for each point + for i, txt in enumerate(x_values): + plt.annotate( + f"{txt:.4f}", + (txt, y_values[i]), + textcoords="offset points", + xytext=(0, 10), + ha="center", + ) + + plt.xlabel("Memory Budget") + plt.ylabel("Runtime of Recomputed Components") + plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") + plt.grid(True) + fig = plt.gcf() + plt.show() + fig_dir = os.getcwd() + if config.memory_budget_pareto_dir is not None: + fig_dir = config.memory_budget_pareto_dir + os.makedirs(fig_dir, exist_ok=True) + rank_suffix = "" + if torch.distributed.is_available() and torch.distributed.is_initialized(): + rank_suffix = f"_rank_{torch.distributed.get_rank()}" + fig_name = os.path.join( + fig_dir, f"memory_budget_pareto{rank_suffix}_{get_aot_graph_name()}.svg" + ) + fig.savefig(fig_name) + log.warning("Generated Pareto frontier curve at %s", fig_name) + + # todo(chilli): Estimated doesn't align exactly with actual - actual is + # usually less memory than estimated. i'm guessing (actually quite + # unsure about this) that's because estimated is just only including + # tensors we actually banned from recompute, but there may be other + # tensors that we choose to save. 
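+    # The knapsack treats each recomputable-but-banned node as an item whose weight
+    # is its normalized activation size and whose value is its estimated recompute
+    # runtime: it picks which of these nodes stay saved so that the saved memory
+    # fits under `memory_budget`, and the rest are handed back to the min-cut pass
+    # via `dont_ban` so they may be recomputed instead.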
+ + return get_saved_values_knapsack( + memory_budget=memory_budget, node_info=node_info, joint_graph=joint_graph + )[0] + + +def _broadcast_rank0_decision( + joint_graph: torch.fx.Graph, saved_values: list[torch.fx.Node] +): + # use the same policy across different GPUs + from torch._subclasses.fake_tensor import unset_fake_temporarily + + def has_collectives(joint_graph): + for node in joint_graph.nodes: + if isinstance( + node.target, torch._ops.OpOverload + ) and node.target.namespace in {"_c10d_functional", "c10d_functional"}: + return True + return False + + def has_same_nodes(joint_graph): + # proxy to check if the graph is the same across different GPUs. + # We only consider the name and order of nodes. A more robust way + # would be to check the hash of the whole graph (disregarding input shapes), + # this is is a reasonable first-order approximation. + node_str = "/".join(x.name for x in joint_graph.nodes) + inputs = hashlib.sha256(node_str.encode("utf-8")).hexdigest() + all_inputs = [None for _ in range(torch.distributed.get_world_size())] + with no_dispatch(), unset_fake_temporarily(): + # TODO: maybe use a different process group? + torch.distributed.all_gather_object(all_inputs, inputs) + return all(all_inputs[0] == x for x in all_inputs) + + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and torch.distributed.get_world_size() > 1 + and has_collectives(joint_graph) + and has_same_nodes(joint_graph) + ): + with no_dispatch(), unset_fake_temporarily(): + objects = [[x.name for x in saved_values]] + # TODO: maybe use a different process group for this + torch.distributed.broadcast_object_list(objects, src=0) + saved_values_names = objects[0] + name_to_node = get_name_to_node(joint_graph) + saved_values = [name_to_node[n] for n in saved_values_names] + return saved_values + + +def min_cut_rematerialization_partition( + joint_module: fx.GraphModule, + _joint_inputs, + compiler="inductor", + *, + num_fwd_outputs, + static_lifetime_input_indices: Optional[list[int]] = None, +) -> tuple[fx.GraphModule, fx.GraphModule]: + """ + Partitions the joint graph such that the backward recomputes the forward. + Recomputing helps in trading off memory bandwidth with computation. + + To create the fwd and bwd graph, we copy the joint graph, manually set the + outputs to just original forward or backward outputs. And then we run the + resulting graphs through dead code elimination. + + .. warning:: + This API is experimental and likely to change. + + Args: + joint_module(fx.GraphModule): The joint forward and backward graph. This + is the result of AOT Autograd tracing. + _joint_inputs: The inputs to the joint graph. This is unused. + compiler: This option determines the default set of recomputable ops. + Currently, there are two options: ``nvfuser`` and ``inductor``. + recomputable_ops: This is an optional set of recomputable ops. If this + is not None, then this set of ops will be used instead of the + default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. + + Returns: + Returns the generated forward and backward Fx graph modules. 
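+
+    Example (illustrative only; assumes ``joint_gm`` is a joint fwd/bwd graph
+    traced by AOT Autograd and ``joint_inputs`` are its example inputs)::
+
+        fw_gm, bw_gm = min_cut_rematerialization_partition(
+            joint_gm, joint_inputs, num_fwd_outputs=1
+        )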
+ """ + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + joint_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(joint_module) + + def classify_nodes(joint_module, static_lifetime_input_indices): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes(joint_module, static_lifetime_input_indices) + + # networkx blows up on graphs with no required backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. 
+ if len(node_info.required_bw_nodes) == 0: + return default_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + node.dist_from_bw = int(1e9) + elif not node_info.is_required_fw(node): + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + memory_budget = config.activation_memory_budget + for node in joint_graph.nodes: + if isinstance(node.meta.get("memory_budget", None), float): + memory_budget = node.meta["memory_budget"] + break + saved_values = choose_saved_values_set( + joint_graph, + node_info, + memory_budget=memory_budget, + ) + if config._broadcast_rank0_decision: + saved_values = _broadcast_rank0_decision(joint_graph, saved_values) + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes = list(filter(is_sym_node, saved_values)) + saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols + fw_module, bw_module = _extract_fwd_bwd_modules( + joint_module, + saved_values, + saved_sym_nodes=saved_sym_nodes, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, + ) + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + if AOT_PARTITIONER_DEBUG: + # Calculate sorted sizes of saved values + sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) + + # Log total theoretical activations stored + total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9 + log.info("Theoretical Activations Stored: %.2f GB", total_activations_size_gb) + + # Log theoretical per activation storage sizes + log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes) + fw_module_nodes = OrderedSet( + node.name for node in fw_module.graph.nodes if node.op == "call_function" + ) + bw_module_nodes = OrderedSet( + node.name for node in bw_module.graph.nodes if node.op == "call_function" + ) + remat_nodes = fw_module_nodes & bw_module_nodes + + counts: dict[str, int] = defaultdict(int) + for node in fw_module.graph.nodes: + if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): + counts[str(node.target._overloadpacket)] += 1 + log.info( + "# remat/fw/bw: %d/%d/%d", + len(remat_nodes), + len(fw_module_nodes), + len(bw_module_nodes), + ) + rematerialized_ops = sorted( + counts.items(), key=operator.itemgetter(1), reverse=True + ) + log.info("Count of Ops Rematerialized: %s", rematerialized_ops) + return fw_module, bw_module + + +def draw_graph( + traced: torch.fx.GraphModule, + fname: str, + figname: str = "fx_graph", + clear_meta: bool = True, + prog: Optional[Union[str, list[str]]] = None, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, +) -> None: + if clear_meta: + new_graph = copy.deepcopy(traced.graph) + traced = 
fx.GraphModule(traced, new_graph) + for node in traced.graph.nodes: + node.meta = {} + base, ext = os.path.splitext(fname) + if not ext: + ext = "." + config.torch_compile_graph_format + log.info("Writing FX graph to file: %s%s", base, ext) + g = graph_drawer.FxGraphDrawer( + traced, + figname, + parse_stack_trace=parse_stack_trace, + dot_graph_shape=dot_graph_shape, + ) + x = g.get_main_dot_graph() + write_method = getattr(x, "write_" + ext.lstrip(".")) + fname = f"{base}{ext}" + if prog is None: + write_method(fname) + else: + write_method(fname, prog=prog) diff --git a/phivenv/Lib/site-packages/torch/_functorch/pyfunctorch.py b/phivenv/Lib/site-packages/torch/_functorch/pyfunctorch.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9df956ed2bf86a4afa78a34adc844c4ae7021b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/pyfunctorch.py @@ -0,0 +1,312 @@ +# mypy: allow-untyped-defs +import contextlib +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Any + +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + CFunctionalizeInterpreterPtr, + CGradInterpreterPtr, + CInterpreter, + CJvpInterpreterPtr, + CVmapInterpreterPtr, + pop_dynamic_layer_stack, + push_dynamic_layer_stack, + RandomnessType, + TransformType, +) +from torch.autograd.forward_ad import _set_fwd_grad_enabled + + +""" +This file contains the functorch integration with PyDispatcher. + +PyDispatcher does not understand functorch's DynamicLayerStack dispatching +logic because it is entirely implemented in C++ in the fallbacks for two +dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable +to directly reuse C++ boxed fallbacks). + +Instead of trying to hammer PyDispatcher into understanding those fallbacks, +we re-implement the logic of peeking the top of the stack for an interpreter, +selecting the interpreter to dispatch on, etc, in Python. This leads to a +simpler design. + +The main difference between C++ functorch and PyDispatcher's functorch logic +is that: +- C++ functorch needs to manually tweak dispatch keys to ping-pong between + DynamicLayerFrontMode and DynamicLayerBackMode. +- PyDispatcher's functorch logic pops an Interpreter from the top of the stack + and asks it to execute the rule associated with the Interpreter. + +In C++ we do the ping-pong because e.g. vmap rules are associated with the +batched DispatchKey, but in PyDispatcher we are able to avoid this by asking +the user to register a batching rule directly to a transform that an +interpreter then invokes. +""" + + +# FuncTorchInterpreter is the Python version of Interpreter (recall that +# the DynamicLayerStack is a stack of interpreters). +# It is a wrapper around the actual C++ Interpreter object. +# +# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h +class FuncTorchInterpreter(ABC): + def __init__(self, cptr: Any): + self._cptr = cptr + + # Process an operation. eg for vmap, this is invoking a batching rule. + # Conceptually this is analogous to Interpreter::process in C++ + @abstractmethod + def process(self, op, args, kwargs): + pass + + # lower an operation from this Interpreter to the next Interpreter on the stack. + # Concretely, this involves temporarily popping the current Interpreter. 
+ # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++ + def lower(self): + return temporarily_pop_interpreter_stack() + + def level(self): + return self._cptr.level() + + def key(self): + return self._cptr.key() + + def get_state(self): + raise NotImplementedError + + def check_state(self, state): + return state == self.get_state() + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_cptr", None) + return state + + +@contextlib.contextmanager +def temporarily_pop_interpreter_stack(): + try: + saved = pop_dynamic_layer_stack() + yield + finally: + push_dynamic_layer_stack(saved) + + +@contextlib.contextmanager +def temporarily_clear_interpreter_stack(): + stack = [] + try: + while torch._C._functorch.peek_interpreter_stack() is not None: + stack.append(pop_dynamic_layer_stack()) + yield list(stack) + finally: + while stack: + push_dynamic_layer_stack(stack.pop()) + + +@contextlib.contextmanager +def temporarily_restore_interpreter_stack(stack): + pushed = [] + try: + for s in reversed(stack): + push_dynamic_layer_stack(s) + pushed.append(s) + yield + finally: + for s in reversed(pushed): + # TODO: would be nice to assert that the layers are the same, but + # Python object identity is not preserved + pop_dynamic_layer_stack() + + +class VmapInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Vmap + # NOTE: [Interpreter cdata vs cptr] + # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr + # so that we can access methods specific to the vmap interpreter + self._cdata = cdata + + @cached_property + def _cptr(self): + return CVmapInterpreterPtr(self._cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Vmap] + return kernel(self, *args, **kwargs) + + def batch_size(self): + return self._cptr.batchSize() + + def randomness(self): + typ = self._cptr.randomness() + if typ == RandomnessType.Error: + return "error" + elif typ == RandomnessType.Same: + return "same" + elif typ == RandomnessType.Different: + return "different" + raise RuntimeError(f"Unknown RandomnessType: {typ}") + + def get_state(self): + return (self.key().name, self.level(), self.randomness()) + + +@contextlib.contextmanager +def nested(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx) + yield contexts + + +class GradInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Grad + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + + @cached_property + def _cptr(self): + return CGradInterpreterPtr(self._cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, self._cptr.lift, [args, kwargs] + ) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Grad] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # GradInterpreter has custom lower because of the no_grad interaction + # See NOTE [grad and vjp interaction with no_grad] + # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_grad_mode = self.prev_grad_mode() + if not prev_grad_mode: + return nested(torch.no_grad(), super().lower()) + return super().lower() + + def prev_grad_mode(self): + return self._cptr.prevGradMode() + + def get_state(self): + return (self.key().name, self.level(), self.prev_grad_mode()) + + +class 
JvpInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Jvp + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + + @cached_property + def _cptr(self): + return CJvpInterpreterPtr(self._cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only( + torch.Tensor, self._cptr.lift, [args, kwargs] + ) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Jvp] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # Jvp has custom lower because of the no_fwd_grad interaction + # See NOTE [grad and vjp interaction with no_grad] for related info. + # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_fwd_grad_mode = self.prev_fwd_grad_mode() + if not prev_fwd_grad_mode: + return nested(_set_fwd_grad_enabled(False), super().lower()) + return super().lower() + + def prev_fwd_grad_mode(self): + return self._cptr.prevFwdGradMode() + + def get_state(self): + return (self.key().name, self.level(), self.prev_fwd_grad_mode()) + + +class FunctionalizeInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Functionalize + self._cdata = cdata + + @cached_property + def _cptr(self): + return CFunctionalizeInterpreterPtr(self._cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Functionalize] + return kernel(self, *args, **kwargs) + + def functionalize_add_back_views(self): + return self._cptr.functionalizeAddBackViews() + + def get_state(self): + return (self.key().name, self.level()) + + +def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter: + key = cinterpreter.key() + if key == TransformType.Grad: + return GradInterpreter(cinterpreter) + if key == TransformType.Vmap: + return VmapInterpreter(cinterpreter) + if key == TransformType.Jvp: + return JvpInterpreter(cinterpreter) + if key == TransformType.Functionalize: + return FunctionalizeInterpreter(cinterpreter) + raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}") + + +def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter: + interpreter = torch._C._functorch.peek_interpreter_stack() + assert interpreter is not None + return coerce_cinterpreter(interpreter) + + +def retrieve_all_functorch_interpreters() -> list[FuncTorchInterpreter]: + cis = torch._C._functorch.get_interpreter_stack() + if cis is None: + return [] + return [coerce_cinterpreter(ci) for ci in cis] + + +def compare_functorch_state(states: list[tuple[Any, ...]]) -> bool: + # There are four possible cases covered here: + # 1. Current stack empty AND stack when generated not empty -> Invalidate + # 2. Current stack not empty AND stack when generated empty -> Invalidate + # 3. Current stack and generated stack empty -> Valid FX graph + # 4. 
Current stack and generated stack not empty -> Valid if both states match + peek = torch._C._functorch.peek_interpreter_stack() + if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0): + return False + + cis = retrieve_all_functorch_interpreters() + return len(cis) == len(states) and all( + ci.check_state(state) for ci, state in zip(cis, states) + ) + + +def dispatch_functorch(op, args, kwargs): + interpreter = retrieve_current_functorch_interpreter() + # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's + # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers. + # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch + # transforms, so we manually unwrap the dead tensors here. + # This logic won't need to exist when we have mode-only functorch. + args, kwargs = pytree.tree_map_only( + torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs) + ) + return interpreter.process(op, args, kwargs) diff --git a/phivenv/Lib/site-packages/torch/_functorch/python_key.py b/phivenv/Lib/site-packages/torch/_functorch/python_key.py new file mode 100644 index 0000000000000000000000000000000000000000..e020bba370ff41727285d987758213edd95bf9f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/python_key.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"] +from torch.fx.experimental.proxy_tensor import ( + decompose, + dispatch_trace, + make_fx, + PythonKeyTracer, +) + + +pythonkey_decompose = decompose diff --git a/phivenv/Lib/site-packages/torch/_functorch/pytree_hacks.py b/phivenv/Lib/site-packages/torch/_functorch/pytree_hacks.py new file mode 100644 index 0000000000000000000000000000000000000000..aae7a0c5fa3fd386b289b6f1cbe6cf0d93deaa89 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/pytree_hacks.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +# TODO: remove this file when the migration of the pytree utility is done +from torch.utils._pytree import tree_map_, treespec_pprint + + +__all__ = ["tree_map_", "treespec_pprint"] + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. " + "Please `use torch.utils._pytree` instead.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/phivenv/Lib/site-packages/torch/_functorch/top_operators_github_usage.py b/phivenv/Lib/site-packages/torch/_functorch/top_operators_github_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..0659d53dd13f3df8019db2934afaae6f649ad786 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/top_operators_github_usage.py @@ -0,0 +1,629 @@ +# mypy: ignore-errors + +""" +From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 +Try to keep this list in sync with that. 
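The lists below pair operator names with GitHub usage counts. get_nn_functional_top_list()
at the bottom of this file folds the nn.Module counts into their corresponding
nn.functional entries, and the result is merged with top_torch to produce the
module-level usage_count mapping.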
+""" +import operator + + +top_torch = [ + ("t", 6837449), + ("tensor", 585786), + ("mode", 462182), + ("cat", 394818), + ("max", 368038), + ("zeros", 329495), + ("load", 327756), + ("no_grad", 294694), + ("save", 265130), + ("from_numpy", 243063), + ("manual_seed", 165044), + ("ones", 153696), + ("randn", 150796), + ("stack", 133358), + ("sum", 130772), + ("arange", 98087), + ("rand", 94715), + ("mean", 88546), + ("exp", 73883), + ("zeros_like", 72831), + ("min", 72248), + ("sigmoid", 66798), + ("log", 62135), + ("matmul", 47811), + ("clamp", 45304), + ("sqrt", 44911), + ("abs", 43535), + ("tanh", 42793), + ("empty", 40311), + ("argmax", 38435), + ("bmm", 33984), + ("pow", 33571), + ("norm", 31125), + ("mm", 30995), + ("is_tensor", 29546), + ("ones_like", 29512), + ("nonzero", 28681), + ("full", 28373), + ("unsqueeze", 27911), + ("where", 26585), + ("randperm", 26450), + ("eye", 24342), + ("mul", 23236), + ("topk", 22537), + ("as_tensor", 21967), + ("sort", 21412), + ("squeeze", 20863), + ("randint", 20771), + ("linspace", 20041), + ("add", 19201), + ("transpose", 18663), + ("split", 18325), + ("gather", 17904), + ("set_grad_enabled", 16013), + ("sin", 15669), + ("cos", 15562), + ("div", 15513), + ("index_select", 14866), + ("multinomial", 14331), + ("flatten", 14267), + ("isnan", 14170), + ("randn_like", 13096), + ("eq", 12680), + ("einsum", 12480), + ("round", 12367), + ("floor", 11628), + ("allclose", 11000), + ("reshape", 10605), + ("diag", 10167), + ("chunk", 9581), + ("std", 9379), + ("set_default_tensor_type", 9281), + ("triu", 8559), + ("meshgrid", 8292), + ("set_num_threads", 8126), + ("unique", 7964), + ("full_like", 7780), + ("tril", 7538), + ("dot", 7275), + ("sign", 6943), + ("equal", 6916), + ("normal", 6750), + ("cumsum", 6556), + ("dist", 6058), + ("isfinite", 6030), + ("gt", 5935), + ("set_printoptions", 5888), + ("range", 5491), + ("empty_like", 5351), + ("flip", 5342), + ("masked_select", 5341), + ("bernoulli", 5262), + ("atan", 5253), + ("var", 5247), + ("prod", 5200), + ("erf", 5088), + ("inverse", 5072), + ("addmm", 4854), + ("logsumexp", 4582), + ("fft", 4436), + ("lt", 4421), + ("log2", 4316), + ("enable_grad", 4238), + ("rand_like", 4187), + ("argsort", 3972), + ("seed", 3932), + ("mv", 3547), + ("ger", 3309), + ("ge", 3248), + ("atan2", 3210), + ("ceil", 3202), + ("ne", 3075), + ("bincount", 3063), + ("acos", 3055), + ("rsqrt", 3031), + ("svd", 3029), + ("numel", 3003), + ("log1p", 2840), + ("unbind", 2808), + ("le", 2714), + ("isinf", 2707), + ("cross", 2646), + ("set_default_dtype", 2536), + ("argmin", 2535), + ("sparse_coo_tensor", 2489), + ("log10", 2304), + ("kthvalue", 2192), + ("set_rng_state", 2158), + ("get_rng_state", 1996), + ("get_default_dtype", 1879), + ("det", 1868), + ("qr", 1864), + ("histc", 1852), + ("symeig", 1832), + ("trace", 1801), + ("median", 1795), + ("addcmul", 1751), + ("remainder", 1717), + ("baddbmm", 1693), + ("lgamma", 1665), + ("repeat_interleave", 1598), + ("fmod", 1576), + ("reciprocal", 1575), + ("tan", 1560), + ("initial_seed", 1532), + ("take", 1529), + ("stft", 1487), + ("get_num_threads", 1477), + ("real", 1459), + ("cholesky", 1406), + ("quantize_per_tensor", 1392), + ("diag_embed", 1364), + ("lerp", 1363), + ("asin", 1345), + ("eig", 1333), + ("trunc", 1290), + ("diagonal", 1287), + ("cosh", 1279), + ("rfft", 1269), + ("cumprod", 1260), + ("addr", 1211), + ("roll", 1198), + ("narrow", 1188), + ("digamma", 1172), + ("square", 1163), + ("sinh", 1131), + ("logspace", 1084), + ("broadcast_tensors", 1070), + ("irfft", 1013), + 
("frac", 997), + ("hann_window", 994), + ("solve", 989), + ("logdet", 977), + ("expm1", 968), + ("cdist", 946), + ("addmv", 903), + ("randint_like", 888), + ("tensordot", 888), + ("ifft", 877), + ("true_divide", 854), + ("erfinv", 830), + ("addcdiv", 819), + ("addbmm", 813), + ("renorm", 781), + ("pinverse", 753), + ("isclose", 740), + ("erfc", 729), + ("is_storage", 725), + ("triangular_solve", 723), + ("rot90", 709), + ("logical_not", 686), + ("geqrf", 681), + ("slogdet", 677), + ("lu", 665), + ("hamming_window", 659), + ("orgqr", 651), + ("ormqr", 622), + ("is_floating_point", 602), + ("diagflat", 562), + ("cholesky_solve", 559), + ("tril_indices", 552), + ("chain_matmul", 551), + ("triu_indices", 548), + ("angle", 522), + ("poisson", 505), + ("matrix_power", 485), + ("unique_consecutive", 471), + ("quantize_per_channel", 465), + ("std_mean", 458), + ("bartlett_window", 447), + ("var_mean", 428), + ("lstsq", 421), + ("logical_and", 419), + ("mvlgamma", 411), + ("blackman_window", 400), + ("bitwise_not", 395), + ("cholesky_inverse", 388), + ("as_strided", 384), + ("floor_divide", 353), + ("cartesian_prod", 321), + ("lu_solve", 317), + ("set_flush_denormal", 310), + ("empty_strided", 283), + ("logical_xor", 282), + ("polygamma", 282), + ("logical_or", 280), + ("set_num_interop_threads", 278), + ("combinations", 274), + ("trapz", 270), + ("matrix_rank", 260), + ("lu_unpack", 255), + ("result_type", 244), + ("conj", 231), + ("cummax", 230), + ("lobpcg", 229), + ("bitwise_xor", 217), + ("promote_types", 213), + ("get_num_interop_threads", 211), + ("cummin", 205), + ("bitwise_and", 198), + ("dequantize", 192), + ("bitwise_or", 191), + ("imag", 191), + ("can_cast", 184), + ("istft", 180), + ("compiled_with_cxx11_abi", 159), + ("is_complex", 151), + ("block_diag", 136), + ("pca_lowrank", 124), + ("absolute", 122), + ("svd_lowrank", 108), + ("neg", 2), +] + +top_nn_functional = [ + ("nn.functional.softmax", 10522), + ("nn.functional.relu", 8572), + ("nn.functional.interpolate", 7277), + ("nn.functional.pad", 5207), + ("nn.functional.log_softmax", 4699), + ("nn.functional.normalize", 2338), + ("nn.functional.cross_entropy", 2083), + ("nn.functional.grid_sample", 1970), + ("nn.functional.one_hot", 1967), + ("nn.functional.mse_loss", 1920), + ("nn.functional.conv2d", 1593), + ("nn.functional.dropout", 1516), + ("nn.functional.softplus", 1385), + ("nn.functional.sigmoid", 1128), + ("nn.functional.linear", 1036), + ("nn.functional.gelu", 930), + ("nn.functional.avg_pool2d", 899), + ("nn.functional.max_pool2d", 876), + ("nn.functional.nll_loss", 863), + ("nn.functional.embedding", 737), + ("nn.functional.tanh", 664), + ("nn.functional.leaky_relu", 640), + ("nn.functional.adaptive_avg_pool2d", 633), + ("nn.functional.cosine_similarity", 627), + ("nn.functional.unfold", 609), + ("nn.functional.conv1d", 596), + ("nn.functional.binary_cross_entropy_with_logits", 591), + ("nn.functional.l1_loss", 571), + ("nn.functional.binary_cross_entropy", 492), + ("nn.functional.elu", 416), + ("nn.functional.batch_norm", 413), + ("nn.functional.upsample", 413), + ("nn.functional.fold", 305), + ("nn.functional.affine_grid", 298), + ("nn.functional.max_pool1d", 297), + ("nn.functional.torch", 294), + ("nn.functional.threshold", 263), + ("nn.functional.smooth_l1_loss", 262), + ("nn.functional.pairwise_distance", 253), + ("nn.functional.logsigmoid", 243), + ("nn.functional.adaptive_max_pool2d", 235), + ("nn.functional.relu6", 213), + ("nn.functional.pixel_shuffle", 209), + ("nn.functional.avg_pool3d", 203), + 
("nn.functional.bilinear", 203), + ("nn.functional.conv_transpose2d", 201), + ("nn.functional.gumbel_softmax", 197), + ("nn.functional.max_unpool2d", 196), + ("nn.functional.kl_div", 191), + ("nn.functional.hardtanh", 189), + ("nn.functional.ctc_loss", 185), + ("nn.functional.layer_norm", 178), + ("nn.functional.conv3d", 172), + ("nn.functional.max_unpool3d", 167), + ("nn.functional.hardshrink", 165), + ("nn.functional.hardswish", 156), + ("nn.functional.selu", 156), + ("nn.functional.glu", 155), + ("nn.functional.assert_int_or_pair", 150), + ("nn.functional.hardsigmoid", 146), + ("nn.functional.upsample_bilinear", 146), + ("nn.functional.max_pool3d", 140), + ("nn.functional.adaptive_avg_pool3d", 139), + ("nn.functional.instance_norm", 124), + ("nn.functional.embedding_bag", 122), + ("nn.functional.upsample_nearest", 110), + ("nn.functional.avg_pool1d", 105), + ("nn.functional.prelu", 102), + ("nn.functional.celu", 92), + ("nn.functional.dropout2d", 86), + ("nn.functional.hinge_embedding_loss", 82), + ("nn.functional.softsign", 81), + ("nn.functional.max_unpool1d", 74), + ("nn.functional.silu", 74), + ("nn.functional.softshrink", 70), + ("nn.functional.leaky_relu_", 68), + ("nn.functional.softmin", 67), + ("nn.functional.channel_shuffle", 66), + ("nn.functional.multilabel_margin_loss", 66), + ("nn.functional.dropout3d", 65), + ("nn.functional.multi_margin_loss", 65), + ("nn.functional.lp_pool2d", 64), + ("nn.functional.conv_transpose1d", 62), + ("nn.functional.triplet_margin_loss", 62), + ("nn.functional.tanhshrink", 61), + ("nn.functional.adaptive_max_pool1d", 59), + ("nn.functional.cosine_embedding_loss", 58), + ("nn.functional.multi_head_attention_forward", 58), + ("nn.functional.max_pool1d_with_indices", 53), + ("nn.functional.poisson_nll_loss", 53), + ("nn.functional.margin_ranking_loss", 52), + ("nn.functional.soft_margin_loss", 52), + ("nn.functional.adaptive_max_pool3d", 51), + ("nn.functional.group_norm", 51), + ("nn.functional.local_response_norm", 51), + ("nn.functional.multilabel_soft_margin_loss", 51), + ("nn.functional.relu_", 50), + ("nn.functional.alpha_dropout", 49), + ("nn.functional.feature_alpha_dropout", 49), + ("nn.functional.lp_pool1d", 49), + ("nn.functional.adaptive_max_pool1d_with_indices", 48), + ("nn.functional.adaptive_max_pool2d_with_indices", 48), + ("nn.functional.adaptive_max_pool3d_with_indices", 48), + ("nn.functional.fractional_max_pool2d", 48), + ("nn.functional.fractional_max_pool2d_with_indices", 48), + ("nn.functional.fractional_max_pool3d", 48), + ("nn.functional.fractional_max_pool3d_with_indices", 48), + ("nn.functional.max_pool2d_with_indices", 48), + ("nn.functional.max_pool3d_with_indices", 48), + ("nn.functional.handle_torch_function", 47), + ("nn.functional.has_torch_function", 47), + ("nn.functional.adaptive_avg_pool1d", 43), + ("nn.functional.pdist", 43), + ("nn.functional.rrelu_", 37), + ("nn.functional.elu_", 34), + ("nn.functional.boolean_dispatch", 33), + ("nn.functional.hardtanh_", 26), + ("nn.functional.triplet_margin_with_distance_loss", 23), + ("nn.functional.selu_", 20), + ("nn.functional.pixel_unshuffle", 19), + ("nn.functional.conv_transpose3d", 18), + ("nn.functional.gaussian_nll_loss", 15), + ("nn.functional.has_torch_function_unary", 15), + ("nn.functional.has_torch_function_variadic", 15), + ("nn.functional.celu_", 13), + ("nn.functional.huber_loss", 7), + ("nn.functional.mish", 4), + ("nn.functional.threshold_", 3), + ("nn.functional.grad", 2), + ("nn.functional.conv_tbc", 1), + ("nn.functional.math", 1), +] + +top_nn_module 
= [ + ("nn.Module", 927129, None), + ("nn.Linear", 530688, "nn.functional.linear"), + ("nn.Sequential", 384968, None), + ("nn.Conv2d", 383320, "nn.functional.conv2d"), + ("nn.ReLU", 318877, "nn.functional.relu"), + ("nn.BatchNorm2d", 233265, "nn.functional.batch_norm"), + ("nn.Dropout", 179268, "nn.functional.dropout"), + ("nn.ModuleList", 171225, None), + ("nn.Parameter", 153291, None), + ("nn.CrossEntropyLoss", 152696, "nn.functional.cross_entropy"), + ("nn.MaxPool2d", 138619, "nn.functional.max_pool2d"), + ("nn.Embedding", 111844, "nn.functional.embedding"), + ("nn.DataParallel", 104238, None), + ("nn.MSELoss", 82954, "nn.functional.mse_loss"), + ("nn.Sigmoid", 75810, "nn.functional.sigmoid"), + ("nn.LeakyReLU", 65632, "nn.functional.leaky_relu"), + ("nn.BatchNorm1d", 65374, "nn.functional.batch_norm"), + ("nn.Softmax", 65114, "nn.functional.softmax"), + ("nn.Tanh", 59445, "nn.functional.tanh"), + ("nn.AdaptiveAvgPool2d", 59071, "nn.functional.adaptive_avg_pool2d"), + ("nn.AvgPool2d", 58377, "nn.functional.avg_pool2d"), + ("nn.ConvTranspose2d", 57524, "nn.functional.conv_transpose2d"), + ("nn.LSTM", 57411, None), + ("nn.Conv1d", 41108, "nn.functional.conv1d"), + ("nn.LayerNorm", 36089, "nn.functional.layer_norm"), + ("nn.BCELoss", 34005, "nn.functional.binary_cross_entropy"), + ("nn.Upsample", 32527, "nn.functional.interpolate"), + ("nn.BCEWithLogitsLoss", 29944, "nn.functional.binary_cross_entropy_with_logits"), + ("nn.GRU", 25421, None), + ("nn.Dropout2d", 23512, "nn.functional.dropout2d"), + ("nn.LogSoftmax", 22897, "nn.functional.log_softmax"), + ("nn.L1Loss", 22778, "nn.functional.l1_loss"), + ("nn.GroupNorm", 22183, "nn.functional.group_norm"), + ("nn.NLLLoss", 21751, "nn.functional.nll_loss"), + ("nn.Conv3d", 20874, "nn.functional.conv3d"), + ("nn.Identity", 17911, None), + ("nn.InstanceNorm2d", 16426, "nn.functional.instance_norm"), + ("nn.BatchNorm3d", 16378, "nn.functional.batch_norm"), + ("nn.PReLU", 13472, "nn.functional.prelu"), + ("nn.ReLU6", 12622, "nn.functional.relu6"), + ("nn.ELU", 12508, "nn.functional.elu"), + ("nn.LSTMCell", 10885, None), + ("nn.Flatten", 10384, "torch.flatten"), + ("nn.ModuleDict", 10255, None), + ("nn.ReflectionPad2d", 9954, "nn.functional.pad"), + ("nn.MaxPool3d", 9526, "nn.functional.max_pool3d"), + ("nn.MaxPool1d", 9154, "nn.functional.max_pool1d"), + ("nn.RNN", 9154, None), + ("nn.ZeroPad2d", 8847, "nn.functional.pad"), + ("nn.ParameterList", 7702, None), + ("nn.SyncBatchNorm", 6814, None), + ("nn.PixelShuffle", 6571, "nn.functional.pixel_shuffle"), + ("nn.SmoothL1Loss", 6517, "nn.functional.smooth_l1_loss"), + ("nn.Hardswish", 6458, "nn.functional.hardswish"), + ("nn.AdaptiveMaxPool2d", 6071, "nn.functional.adaptive_max_pool2d"), + ("nn.SELU", 6043, "nn.functional.selu"), + ("nn.ConvTranspose3d", 6039, "nn.functional.conv_transpose3d"), + ("nn.GRUCell", 5840, None), + ("nn.ReplicationPad2d", 5600, "nn.functional.pad"), + ("nn.KLDivLoss", 5541, "nn.functional.kl_div"), + ("nn.ConvTranspose1d", 5183, "nn.functional.conv_transpose1d"), + ("nn.Softplus", 5120, "nn.functional.softplus"), + ("nn.SiLU", 4895, "nn.functional.silu"), + ("nn.AvgPool3d", 4523, "nn.functional.avg_pool3d"), + ("nn.CosineSimilarity", 4058, "nn.functional.cosine_similarity"), + ("nn.GELU", 3932, "nn.functional.gelu"), + ("nn.UpsamplingBilinear2d", 3673, "nn.functional.interpolate"), + ("nn.InstanceNorm1d", 3658, "nn.functional.instance_norm"), + ("nn.Transformer", 3604, None), + ("nn.MultiheadAttention", 3435, "nn.functional.multi_head_attention_forward"), + 
("nn.AvgPool1d", 3195, "nn.functional.avg_pool1d"), + ("nn.Dropout3d", 2964, "nn.functional.dropout3d"), + ("nn.AdaptiveAvgPool3d", 2915, "nn.functional.adaptive_avg_pool3d"), + ("nn.InstanceNorm3d", 2893, "nn.functional.instance_norm"), + ("nn.Hardtanh", 2613, "nn.functional.hardtanh"), + ("nn.MarginRankingLoss", 2568, "nn.functional.margin_ranking_loss"), + ("nn.GLU", 2526, "nn.functional.glu"), + ("nn.AdaptiveAvgPool1d", 2481, "nn.functional.adaptive_avg_pool1d"), + ("nn.EmbeddingBag", 2344, "nn.functional.embedding_bag"), + ("nn.TransformerEncoderLayer", 2292, None), + ("nn.TransformerEncoder", 2091, None), + ("nn.MaxUnpool2d", 2031, "nn.functional.max_unpool2d"), + ("nn.UpsamplingNearest2d", 2004, "nn.functional.interpolate"), + ("nn.ConstantPad1d", 1904, "nn.functional.pad"), + ("nn.ConstantPad2d", 1791, "nn.functional.pad"), + ("nn.CTCLoss", 1789, "nn.functional.ctc_loss"), + ("nn.AdaptiveMaxPool1d", 1713, "nn.functional.adaptive_max_pool1d"), + ("nn.AdaptiveLogSoftmaxWithLoss", 1665, None), + ("nn.Bilinear", 1664, "nn.functional.bilinear"), + ("nn.RNNCell", 1653, None), + ("nn.MultiLabelSoftMarginLoss", 1624, "nn.functional.multilabel_soft_margin_loss"), + ("nn.Unfold", 1452, "nn.functional.unfold"), + ("nn.RReLU", 1431, "nn.functional.rrelu"), + ("nn.CosineEmbeddingLoss", 1357, "nn.functional.cosine_embedding_loss"), + ("nn.LocalResponseNorm", 1331, "nn.functional.local_response_norm"), + ("nn.Softmax2d", 1300, "nn.functional.softmax"), + ("nn.PairwiseDistance", 1241, "nn.functional.pairwise_distance"), + ("nn.LogSigmoid", 1235, "nn.functional.logsigmoid"), + ("nn.TripletMarginLoss", 1230, "nn.functional.triplet_margin_loss"), + ("nn.RNNBase", 1133, None), + ("nn.Threshold", 1043, "nn.functional.threshold"), + ("nn.AdaptiveMaxPool3d", 1025, "nn.functional.adaptive_max_pool3d"), + ("nn.CELU", 1018, "nn.functional.celu"), + ("nn.NLLLoss2d", 966, "nn.functional.nll_loss"), + ("nn.Softsign", 877, "nn.functional.softsign"), + ("nn.ReplicationPad1d", 862, "nn.functional.pad"), + ("nn.SoftMarginLoss", 856, "nn.functional.soft_margin_loss"), + ("nn.ParameterDict", 742, None), + ("nn.ReflectionPad1d", 731, "nn.functional.pad"), + ("nn.Softshrink", 713, "nn.functional.softshrink"), + ("nn.AlphaDropout", 710, "nn.functional.alpha_dropout"), + ("nn.Tanhshrink", 681, "nn.functional.tanhshrink"), + ("nn.PoissonNLLLoss", 676, "nn.functional.poisson_nll_loss"), + ("nn.MaxUnpool3d", 660, "nn.functional.max_unpool3d"), + ("nn.Fold", 630, "nn.functional.fold"), + ("nn.MultiMarginLoss", 622, "nn.functional.multi_margin_loss"), + ("nn.TransformerDecoderLayer", 614, None), + ("nn.TransformerDecoder", 607, None), + ("nn.Hardshrink", 592, "nn.functional.hardshrink"), + ("nn.ConstantPad3d", 582, "nn.functional.pad"), + ("nn.MultiLabelMarginLoss", 580, "nn.functional.multilabel_margin_loss"), + ("nn.LPPool2d", 550, "nn.functional.lp_pool2d"), + ("nn.Softmin", 537, "nn.functional.softmin"), + ("nn.MaxUnpool1d", 518, "nn.functional.max_unpool1d"), + ("nn.FractionalMaxPool2d", 484, "nn.functional.fractional_max_pool2d"), + ("nn.Hardsigmoid", 477, "nn.functional.hardsigmoid"), + ("nn.ReplicationPad3d", 470, "nn.functional.pad"), + ("nn.HingeEmbeddingLoss", 442, "nn.functional.hinge_embedding_loss"), + ("nn.LPPool1d", 386, "nn.functional.lp_pool1d"), + ("nn.FractionalMaxPool3d", 252, "nn.functional.fractional_max_pool3d"), + ("nn.Container", 217, None), + ("nn.Unflatten", 206, "nn.functional.unflatten"), + ("nn.FeatureAlphaDropout", 136, "nn.functional.feature_alpha_dropout"), + ( + 
"nn.TripletMarginWithDistanceLoss", + 107, + "nn.functional.triplet_margin_with_distance_loss", + ), + ("nn.ChannelShuffle", 90, "nn.functional.channel_shuffle"), + ("nn.RNNCellBase", 88, None), + ("nn.LazyLinear", 81, "nn.functional.linear"), + ("nn.UninitializedParameter", 60, None), + ("nn.CrossMapLRN2d", 59, None), + ("nn.GaussianNLLLoss", 55, "nn.functional.gaussian_nll_loss"), + ("nn.PixelUnshuffle", 45, "nn.functional.pixel_unshuffle"), + ("nn.Mish", 31, "nn.functional.mish"), + ("nn.ReflectionPad3d", 22, "nn.functional.pad"), + ("nn.HuberLoss", 18, "nn.functional.huber_loss"), + ("nn.LazyConv2d", 15, None), + ("nn.LazyConv1d", 9, None), + ("nn.LazyConv3d", 8, None), + ("nn.LazyConvTranspose1d", 8, None), + ("nn.LazyConvTranspose2d", 8, None), + ("nn.LazyConvTranspose3d", 8, None), + ("nn.LazyBatchNorm1d", 3, None), + ("nn.LazyBatchNorm2d", 3, None), + ("nn.LazyBatchNorm3d", 3, None), + ("nn.UninitializedBuffer", 3, None), +] + +# No rankings because these are a little hard to get rankings for +method_only_ops = [ + "bfloat16", + "bool", + "byte", + "char", + "contiguous", + "cpu", + "cuda", + "detach", + "double", + "expand", + "expand_as", + "float", + "get_device", + "half", + "hardshrink", + "index_add", + "index_copy", + "index_fill", + "index_put", + "int", + "is_contiguous", + "is_pinned", + "is_set_to", + "is_shared", + "is_signed", + "item", + "long", + "masked_scatter", + "masked_fill", + "narrow_copy", + "numpy", + "pin_memory", + "repeat", + "reshape_as", + "select", + "short", + "storage_offset", + "sum_to_size", + "to", + "to_mkldnn", + "tolist", + "type", + "type_as", + "unfold", + "view", + "view_as", +] + + +def get_nn_functional_top_list(): + top_nn_functional_ = dict(top_nn_functional) + for _, count, functional_name in top_nn_module: + if functional_name is None: + continue + if functional_name == "torch.flatten": + continue + if functional_name not in top_nn_functional_: + top_nn_functional_[functional_name] = count + else: + top_nn_functional_[functional_name] += count + + top_nn_functional_ = list(top_nn_functional_.items()) + top_nn_functional_.sort(key=operator.itemgetter(1), reverse=True) + return top_nn_functional_ + + +usage_count = dict(get_nn_functional_top_list()) +usage_count.update(top_torch) diff --git a/phivenv/Lib/site-packages/torch/_functorch/utils.py b/phivenv/Lib/site-packages/torch/_functorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74214664e2d3b47b1067cd54dfeb0f8e8d23bb2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/utils.py @@ -0,0 +1,40 @@ +import contextlib +from collections.abc import Generator +from typing import Any, Union + +import torch +from torch._C._functorch import ( + get_single_level_autograd_function_allowed, + set_single_level_autograd_function_allowed, + unwrap_if_dead, +) +from torch.utils._exposed_in import exposed_in + + +__all__ = [ + "exposed_in", + "argnums_t", + "enable_single_level_autograd_function", + "unwrap_dead_wrappers", +] + + +@contextlib.contextmanager +def enable_single_level_autograd_function() -> Generator[None, None, None]: + try: + prev_state = get_single_level_autograd_function_allowed() + set_single_level_autograd_function_allowed(True) + yield + finally: + set_single_level_autograd_function_allowed(prev_state) + + +def unwrap_dead_wrappers(args: tuple[Any, ...]) -> tuple[Any, ...]: + # NB: doesn't use tree_map_only for performance reasons + result = tuple( + unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args + ) + 
return result + + +argnums_t = Union[int, tuple[int, ...]] diff --git a/phivenv/Lib/site-packages/torch/_functorch/vmap.py b/phivenv/Lib/site-packages/torch/_functorch/vmap.py new file mode 100644 index 0000000000000000000000000000000000000000..5c1b10a9d1b3a6b372c0467f21278020e6df93a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_functorch/vmap.py @@ -0,0 +1,539 @@ +# mypy: ignore-errors + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import functools +import itertools +import os +import threading +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +from torch import Tensor +from torch._C._functorch import ( + _add_batch_dim, + _remove_batch_dim, + _vmap_decrement_nesting, + _vmap_increment_nesting, + is_batchedtensor, +) +from torch.utils._pytree import ( + _broadcast_to_and_flatten, + tree_flatten, + tree_map_, + tree_unflatten, + TreeSpec, +) + + +in_dims_t = Union[int, tuple] +out_dims_t = Union[int, tuple[int, ...]] + + +def doesnt_support_saved_tensors_hooks(f): + message = ( + "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. " + "Please open an issue with your use case." + ) + + @functools.wraps(f) + def fn(*args, **kwargs): + with torch.autograd.graph.disable_saved_tensors_hooks(message): + return f(*args, **kwargs) + + return fn + + +# Checks that all args-to-be-batched have the same batch dim size +def _validate_and_get_batch_size( + flat_in_dims: list[Optional[int]], flat_args: list +) -> int: + batch_sizes = [ + arg.size(in_dim) + for in_dim, arg in zip(flat_in_dims, flat_args) + if in_dim is not None + ] + if len(batch_sizes) == 0: + raise ValueError("vmap: Expected at least one Tensor to vmap over") + if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): + raise ValueError( + f"vmap: Expected all tensors to have the same size in the mapped " + f"dimension, got sizes {batch_sizes} for the mapped dimension" + ) + return batch_sizes[0] + + +def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int: + if isinstance(batched_outputs, tuple): + return len(batched_outputs) + return 1 + + +# If value is a tuple, check it has length `num_elements`. +# If value is not a tuple, make a tuple with `value` repeated `num_elements` times + + +def _as_tuple( + value: Any, num_elements: int, error_message_lambda: Callable[[], str] +) -> tuple: + if not isinstance(value, tuple): + return (value,) * num_elements + if len(value) != num_elements: + raise ValueError(error_message_lambda()) + return value + + +def _process_batched_inputs( + in_dims: in_dims_t, args: tuple, func: Callable +) -> tuple[int, list[Any], list[Any], TreeSpec]: + if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"expected `in_dims` to be int or a (potentially nested) tuple " + f"matching the structure of inputs, got: {type(in_dims)}." + ) + if len(args) == 0: + raise ValueError( + f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " + f"inputs, or you are trying to vmap over a function with no inputs. " + f"The latter is unsupported." 
+ ) + + flat_args, args_spec = tree_flatten(args) + flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) + if flat_in_dims is None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"in_dims is not compatible with the structure of `inputs`. " + f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " + f"has structure {args_spec}." + ) + + for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): + if not isinstance(in_dim, int) and in_dim is not None: + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but in_dim must be either " + f"an integer dimension or None." + ) + if isinstance(in_dim, int) and not isinstance(arg, Tensor): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for an input but the input is of type " + f"{type(arg)}. We cannot vmap over non-Tensor arguments, " + f"please use None as the respective in_dim" + ) + if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): + raise ValueError( + f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " + f"Got in_dim={in_dim} for some input, but that input is a Tensor " + f"of dimensionality {arg.dim()} so expected in_dim to satisfy " + f"-{arg.dim()} <= in_dim < {arg.dim()}." + ) + if in_dim is not None and in_dim < 0: + flat_in_dims[i] = in_dim % arg.dim() + + return ( + _validate_and_get_batch_size(flat_in_dims, flat_args), + flat_in_dims, + flat_args, + args_spec, + ) + + +# Creates BatchedTensors for every Tensor in arg that should be batched. +# Returns the (potentially) batched arguments and the batch_size. + + +def _create_batched_inputs( + flat_in_dims: list[Any], flat_args: list[Any], vmap_level: int, args_spec +) -> tuple: + # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] + batched_inputs = [ + arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level) + for in_dim, arg in zip(flat_in_dims, flat_args) + ] + return tree_unflatten(batched_inputs, args_spec) + + +def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim): + if out_dim is None: + if isinstance(batched_output, torch.Tensor) and is_batchedtensor( + batched_output + ): + raise ValueError( + f"vmap({name}, ...): `{name}` can not return a " + f"BatchedTensor when out_dim is None" + ) + return batched_output + + # out_dim is non None + if not isinstance(batched_output, torch.Tensor): + raise ValueError( + f"vmap({name}, ...): `{name}` must only return " + f"Tensors, got type {type(batched_output)}. " + "Did you mean to set out_dims= to None for output?" + ) + + return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) + + +# Undos the batching (and any batch dimensions) associated with the `vmap_level`. +def _unwrap_batched( + batched_outputs: Union[Tensor, tuple[Tensor, ...]], + out_dims: out_dims_t, + vmap_level: int, + batch_size: int, + func: Callable, +) -> tuple: + flat_batched_outputs, output_spec = tree_flatten(batched_outputs) + + def incompatible_error(): + raise ValueError( + f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): " + f"out_dims is not compatible with the structure of `outputs`. " + f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs " + f"has structure {output_spec}." 
+ ) + + if isinstance(batched_outputs, torch.Tensor): + # Some weird edge case requires us to spell out the following + # see test_out_dims_edge_case + if isinstance(out_dims, int): + flat_out_dims = [out_dims] + elif isinstance(out_dims, tuple) and len(out_dims) == 1: + flat_out_dims = out_dims + elif out_dims is None: + flat_out_dims = [out_dims] + else: + incompatible_error() + else: + flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec) + if flat_out_dims is None: + incompatible_error() + + flat_outputs = [ + _maybe_remove_batch_dim( + _get_name(func), batched_output, vmap_level, batch_size, out_dim + ) + for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims) + ] + return tree_unflatten(flat_outputs, output_spec) + + +def _check_int_or_none(x, func, out_dims): + if isinstance(x, int): + return + if x is None: + return + raise ValueError( + f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " + f"an int, None or a python collection of ints representing where in the outputs the " + f"vmapped dimension should appear." + ) + + +def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None: + if isinstance(out_dims, int): + return + tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims) + + +def _get_name(func: Callable): + if hasattr(func, "__name__"): + return func.__name__ + + if isinstance(func, functools.partial): + return f"functools.partial({_get_name(func.func)}, ...)" + + # Not all callables have __name__, in fact, only static functions/methods + # do. A callable created via nn.Module, to name one example, doesn't have a + # __name__. + return repr(func) + + +DECOMPOSITIONS_LOADED = False +DECOMPOSITIONS_LOCK = threading.Lock() +VMAP_DECOMPOSITIONS_LIB = None + + +# torch.package, Python 3.11, and torch.jit-less environments are unhappy with +# decompositions. Only load them when needed if possible. +def lazy_load_decompositions(): + global DECOMPOSITIONS_LOADED + if DECOMPOSITIONS_LOADED: + return + + with DECOMPOSITIONS_LOCK: + if DECOMPOSITIONS_LOADED: + return + + if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__): + DECOMPOSITIONS_LOADED = True + return + + # use an alternate way to register an operator into the decomposition table + # _register_jit_decomposition doesn't work for some operators, e.g. 
addr, + # because the Tensor types generated cannot be unioned by torchscript + # decomp should be type OpOverload + global VMAP_DECOMPOSITIONS_LIB + VMAP_DECOMPOSITIONS_LIB = torch.library.Library( + "aten", "IMPL", "FuncTorchBatched" + ) + + from torch._decomp import decomposition_table + + def _register_python_decomposition_vmap(decomp): + if decomp in decomposition_table: + VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp]) + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + + _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) + _register_python_decomposition_vmap( + torch.ops.aten.smooth_l1_loss_backward.default + ) + _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.addr.default) + + DECOMPOSITIONS_LOADED = True + + +def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs): + lazy_load_decompositions() + _check_out_dims_is_int_or_int_pytree(out_dims, func) + batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs( + in_dims, args, func + ) + + if chunk_size is not None: + chunks_flat_args = _get_chunked_inputs( + flat_args, flat_in_dims, batch_size, chunk_size + ) + return _chunked_vmap( + func, + flat_in_dims, + chunks_flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + # If chunk_size is not specified. + return _flat_vmap( + func, + batch_size, + flat_in_dims, + flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + + +def get_chunk_sizes(total_elems, chunk_size): + n_chunks = n_chunks = total_elems // chunk_size + chunk_sizes = [chunk_size] * n_chunks + # remainder chunk + remainder = total_elems % chunk_size + if remainder != 0: + chunk_sizes.append(remainder) + return chunk_sizes + + +def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size): + split_idxs = (batch_size,) + if chunk_size is not None: + chunk_sizes = get_chunk_sizes(batch_size, chunk_size) + split_idxs = tuple(itertools.accumulate(chunk_sizes)) + + flat_args_chunks = tuple( + ( + t.tensor_split(split_idxs, dim=in_dim) + if in_dim is not None + else [ + t, + ] + * len(split_idxs) + ) + for t, in_dim in zip(flat_args, flat_in_dims) + ) + + # transpose chunk dim and flatten structure + # chunks_flat_args is a list of flatten args + chunks_flat_args = zip(*flat_args_chunks) + return chunks_flat_args + + +def _flatten_chunks_output(chunks_output_): + # chunks_output is a list of chunked outputs + # flatten chunked outputs: + flat_chunks_output = [] + arg_spec = None + for output in chunks_output_: + flat_output, arg_specs = tree_flatten(output) + flat_chunks_output.append(flat_output) + if arg_spec is None: + arg_spec = arg_specs + + # transpose chunk dim and flatten structure + # flat_output_chunks is flat list of chunks + flat_output_chunks = list(zip(*flat_chunks_output)) + return flat_output_chunks, arg_spec + + +def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks): + # concat chunks on out_dim + flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec) + assert len(flat_out_dims) == len(flat_output_chunks) + flat_output = [] + for idx, out_dim in 
enumerate(flat_out_dims): + flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim)) + # release tensors + flat_output_chunks[idx] = None + + return flat_output + + +# Applies vmap on chunked_input and returns concatenated output over the chunks. +def _chunked_vmap( + func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs +): + chunks_output = [] + rs = torch.get_rng_state() if randomness == "same" else None + for flat_args in chunks_flat_args: + batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) + + # The way we compute split the input in `_get_chunked_inputs`, + # we may get a tensor with `0` batch-size. We skip any computation + # in that case. + # Eg. + # >>> chunk_size = 1 + # >>> batch_size = 6 + # >>> t = torch.zeros(batch_size, 1) + # >>> t.tensor_split([1, 2, 3, 4, 5, 6]) + # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), + # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1))) + if batch_size == 0: + continue + + if rs is not None: + torch.set_rng_state(rs) + chunks_output.append( + _flat_vmap( + func, + batch_size, + flat_in_dims, + flat_args, + args_spec, + out_dims, + randomness, + **kwargs, + ) + ) + + flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output) + + # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`. + # eagerly remove the reference from `chunks_output`. + del chunks_output + + # concat chunks on out_dim + flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks) + + # finally unflatten the output + return tree_unflatten(flat_output, arg_spec) + + +# Vmap refactored helper functions: +def _check_randomness_arg(randomness): + if randomness not in ["error", "different", "same"]: + raise RuntimeError( + f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}" + ) + + +@contextlib.contextmanager +def vmap_increment_nesting(batch_size, randomness): + try: + vmap_level = _vmap_increment_nesting(batch_size, randomness) + yield vmap_level + finally: + _vmap_decrement_nesting() + + +def _flat_vmap( + func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs +): + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = _create_batched_inputs( + flat_in_dims, flat_args, vmap_level, args_spec + ) + batched_outputs = func(*batched_inputs, **kwargs) + return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) + + +# `restore_vmap` is a private helper function. It is vmap but has the following +# differences: +# - instead of returning outputs, it returns an (outputs, out_dims) tuple. +# out_dims is a pytree of same shape as outputs and contains Optional[int] +# specifying where the vmapped dimension, if it exists, is in the corresponding output. +# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped). +# restore_vmap allows for no inputs to have the vmap dimension +# - does no validation on outputs (vmap expects only Tensor outputs) +# restore_vmap allows for return of arbitrary outputs (not just Tensors) +# +# The TL;DR is that restore_vmap is more general than vmap and has a slightly +# different API. The relaxations are so that we can "pause" vmap in the middle +# of its execution and then "restore" it later (this is what we do in +# the generate_vmap_rule=True implementation of autograd.Function). 
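# A minimal usage sketch (illustrative only; the function, tensor shapes, and dims
# below are made up, not taken from real callers):
#
#   def f(x):
#       return x.sum(-1)
#
#   wrapped = restore_vmap(f, (0,), batch_size=4, randomness="error")
#   outputs, out_dims = wrapped(torch.randn(4, 3))
#   # `outputs` still physically carries the vmapped dimension at the position
#   # given by `out_dims`; out_dims is None for outputs that were not batched.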
+# +# restore_vmap can be technically used in the implementation of vmap, but doing +# that refactor is a bit technically challenging because: +# - vmap couples the tensor-wrapping code with error checking +# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it +# in python because it overlaps with unwrap_batched +def restore_vmap(func, in_dims, batch_size, randomness): + def inner(*args, **kwargs): + with vmap_increment_nesting(batch_size, randomness) as vmap_level: + batched_inputs = wrap_batched(args, in_dims, vmap_level) + batched_outputs = func(*batched_inputs, **kwargs) + return unwrap_batched(batched_outputs, vmap_level) + + return inner + + +def wrap_batched(args, bdims, level): + flat_args, spec = tree_flatten(args) + flat_bdims = _broadcast_to_and_flatten(bdims, spec) + assert flat_bdims is not None + result = _create_batched_inputs(flat_bdims, flat_args, level, spec) + return result + + +def unwrap_batched(args, level): + flat_args, spec = tree_flatten(args) + if len(flat_args) == 0: + return args, () + result = [ + ( + torch._C._functorch._unwrap_batched(arg, level) + if isinstance(arg, torch.Tensor) + else (arg, None) + ) + for arg in flat_args + ] + output, bdims = zip(*result) + return tree_unflatten(output, spec), tree_unflatten(bdims, spec) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__init__.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b49333c737eeac91b1df96a48b3708dc25e6ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/__init__.py @@ -0,0 +1,72 @@ +from torch._higher_order_ops._invoke_quant import ( + invoke_quant, + invoke_quant_packed, + InvokeQuant, +) +from torch._higher_order_ops.aoti_call_delegate import aoti_call_delegate +from torch._higher_order_ops.associative_scan import associative_scan +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) +from torch._higher_order_ops.base_hop import BaseHOP +from torch._higher_order_ops.cond import cond +from torch._higher_order_ops.effects import with_effects +from torch._higher_order_ops.executorch_call_delegate import executorch_call_delegate +from torch._higher_order_ops.flat_apply import flat_apply +from torch._higher_order_ops.flex_attention import ( + flex_attention, + flex_attention_backward, +) +from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map +from torch._higher_order_ops.hints_wrap import hints_wrapper +from torch._higher_order_ops.invoke_subgraph import invoke_subgraph +from torch._higher_order_ops.map import map +from torch._higher_order_ops.out_dtype import out_dtype +from torch._higher_order_ops.run_const_graph import run_const_graph +from torch._higher_order_ops.scan import scan +from torch._higher_order_ops.strict_mode import strict_mode +from torch._higher_order_ops.torchbind import call_torchbind +from torch._higher_order_ops.while_loop import while_loop +from torch._higher_order_ops.wrap import ( + dynamo_bypassing_wrapper, + tag_activation_checkpoint, + wrap_activation_checkpoint, + wrap_with_autocast, + wrap_with_set_grad_enabled, +) + + +__all__ = [ + "cond", + "while_loop", + "invoke_subgraph", + "scan", + "map", + "flex_attention", + "flex_attention_backward", + "hints_wrapper", + "BaseHOP", + "flat_apply", + "foreach_map", + "_foreach_map", + "with_effects", + "tag_activation_checkpoint", + "auto_functionalized", + "auto_functionalized_v2", + 
"associative_scan", + "out_dtype", + "executorch_call_delegate", + "call_torchbind", + "run_const_graph", + "InvokeQuant", + "invoke_quant", + "invoke_quant_packed", + "wrap_with_set_grad_enabled", + "wrap_with_autocast", + "wrap_activation_checkpoint", + "dynamo_bypassing_wrapper", + "strict_mode", + "aoti_call_delegate", + "map", +] diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d042d088f06f5e811d132b66ca49fff961e2a931 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4417ef967818c69e88d6f760a2f360455f9688d3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275961664a687da4b075f0ad949366c93e195b76 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d25f678ee8d5b96d97cbbc722f285338deff422 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/associative_scan.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78c3d99a692fed4395ea450be51f8ad3892f1cb0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f59526c8bed87d1025207c70ece9802cd162c1c3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/base_hop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12eb57916004d900e7555439802440f7c0a24031 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..7df43d773d60791e98c80382555c58bc3e260e8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc3d13fe411fd3b06fbedb84db35ce6072f0959b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..392852b2c95acb4f2a7223e1e186ec491b4d4f85 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flat_apply.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f7da9e647bda9b09f3e75b198d7c14866f6445c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/flex_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a2852a80fc94ba3406cb327b39e654310f4af5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/foreach_map.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67c4b361cc9707d9cbbb559a0a8b9d5409962af Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/hints_wrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5a6701f7e3b30fd6c279dde746994149b5c7bbf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f23d1a38b002e0f22af79694f12706a100e57aa4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eca8f5c93b5de4809451c28888f334d81938ca9 Binary 
files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d517d8c3b9b154dc92339e65b547695726cd612 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/run_const_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a63b83536f5c136f3664a9401abb4aaad7ae43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/scan.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63197914f273e254a6bc8c9240bee140abc5d776 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/schema.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a8c22ea2ccee0d67cd931a36c0f3ae85ac659f9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9c9f658177ac3abf0fb25cdebede8060cb4033b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0c1d9b48221d0d8f85dbdfee60b2258d6d443ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f533a921541c55ab8f435707ac15ed250b6d5d16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd66e4d7801b9266a1209220a5594db66849c38a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..165e095f3e230ff82b4b755f771889f61e2bd0b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/_invoke_quant.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/_invoke_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..e8df9c2c712501b028dd9b66f068451812ed4b5b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/_invoke_quant.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +# need to fix prim_hop_base type annotations first + +import dataclasses +from typing import Optional + +import torch +from torch._higher_order_ops.base_hop import BaseHOP, FunctionWithNoFreeVars + + +class InvokeQuantTracer(BaseHOP): + def __init__(self) -> None: + super().__init__("invoke_quant_packed") + + def __call__(self, subgraph, *operands, scheme=None, quant_options=None): + subgraph = FunctionWithNoFreeVars(subgraph) + return super().__call__( + subgraph, *operands, scheme=scheme, quant_options=quant_options + ) + + +invoke_quant_packed = InvokeQuantTracer() + + +class InvokeQuantUnpacked(BaseHOP): + def __init__(self) -> None: + super().__init__("invoke_quant") + + def __call__(self, subgraph, *operands, scheme=None): + return super().__call__(subgraph, *operands, scheme=scheme) + + +invoke_quant = InvokeQuantUnpacked() + + +@dataclasses.dataclass(frozen=True, repr=True) +class InvokeQuant: + """ + Invoke a quantization function that will be preserved as a single operator. Preservation + as a single operator aids in pattern matching and custom lowerings. + + The operation appears as: + torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=scheme) + + Args: + codegen_low_precision: Use observed subgraph dtypes for codegen instead of + upcasting to fp32. Can improve performance for prologue fusion but + requires careful testing of numerics. + """ + + codegen_low_precision: bool = True + + def __call__( + self, + *args, + scheme: Optional[str] = None, + **kwargs, + ): + if not torch.compiler.is_compiling(): + return args[0](*args[1:], **kwargs) + + if scheme is not None: + kwargs["scheme"] = scheme + + return invoke_quant_packed(*args, **kwargs, quant_options=self) # type: ignore[call-arg] diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/aoti_call_delegate.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/aoti_call_delegate.py new file mode 100644 index 0000000000000000000000000000000000000000..fc166a0128aeb517835875073643f4a1ef7c6c0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/aoti_call_delegate.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper" + + +class AOTICallDelegate(HigherOrderOperator): + """aoti_call_delegate is a HOP for calling AOTInductor lowered submodule in ExportedProgram. + + It has the following signature: + aoti_call_delegate( + lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper] + original_gm:fx.GraphModule, + weight_args: List[Tensor], + input_args: List[Tensor], + ) -> outputs: List[Tensor] + + where, + - lowered_module is the AOTInductor lowered submodule, backed by compiled .so file, supporting real tensor inputs + - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation + - weight_args is the list of weights in original GraphModule, including parameters and buffers + - input_args is the list of flatten inputs + """ + + def __init__(self) -> None: + super().__init__("aoti_call_delegate") + + def __call__( + self, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], + ) -> list[torch.Tensor]: + return super().__call__(lowered_module, original_gm, weight_args, input_args) + + +aoti_call_delegate = AOTICallDelegate() +aoti_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) +aoti_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) +aoti_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) +aoti_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) + + +@aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +def call_delegate_cpu( + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +) -> list[torch.Tensor]: + # FX creates this immutable_dict/list concept. Get rid of this. + map_types: dict[type, type] = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_args = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + weight_args + input_args, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args) + if has_fake_args: + # use stateless original_gm for tracing with fake tensors + fake_out = original_gm(*new_args) + return fake_out + else: + # use AOTI Runner for real tensors + new_input_args = new_args[len(weight_args) :] + if type(lowered_module).__name__ == "AOTInductorRunnerWrapper": + return lowered_module(*new_input_args) # type: ignore[misc] + elif type(lowered_module).__name__ == "AOTInductorEPModule": + return lowered_module(new_input_args) # type: ignore[misc] + else: + raise RuntimeError( + f"Unexpected lowered_module type: {type(lowered_module)}." 
+ ) + + +def trace_aoti_call_delegate( + proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args +): + proxy_mode.tracer.root.register_module("lowered_module", lowered_module) + proxy_mode.tracer.root.register_module("original_gm", original_gm) + + node_args = (lowered_module, original_gm, weight_args, input_args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate" + ) + with disable_proxy_modes_tracing(): + out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@aoti_call_delegate.py_impl(ProxyTorchDispatchMode) +def call_delegate_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + res = trace_aoti_call_delegate( + mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args + ) + return res + + +@aoti_call_delegate.py_impl(FakeTensorMode) +def call_delegate_fake_tensor_mode( + mode: FakeTensorMode, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +) -> list[torch.Tensor]: + with mode: + return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) + + +@aoti_call_delegate.py_functionalize_impl +def call_delegate_functionalize( + ctx, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + unwrapped_weight_args = tuple( + ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args + ) + unwrapped_input_args = tuple( + ctx.unwrap_tensors(input_arg) for input_arg in input_args + ) + with ctx.redispatch_to_next(): + res = aoti_call_delegate( + lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type] + ) + return ctx.wrap_tensors(res) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/associative_scan.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/associative_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..44da4555637dc43a103fa330eeff812065201ad6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/associative_scan.py @@ -0,0 +1,491 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from typing import Any, Callable + +import torch +import torch._prims_common as utils +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _maybe_compile_and_run_fn, + _maybe_run_with_interpreter, + autograd_not_implemented, + check_meta_consistency, + first_slice_copy, + reenter_make_fx, + unique_graph_id, + validate_subgraph_args_types, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +aten = torch._ops.ops.aten + + +def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): + assert ( + len(args) == 2 * num_leaves + ), f"Combin_fn received wrong number of 
arguments, expected {2 * num_leaves}, but got {len(args)}" + lhs = pytree.tree_unflatten(args[:num_leaves], spec) + rhs = pytree.tree_unflatten(args[num_leaves:], spec) + return combine_fn(lhs, rhs) + + +def _interleave(a, b, dim=0): + # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors + if b_trunc := (a.shape[dim] == b.shape[dim] + 1): + pad = ( + [0] * ((b.ndim - dim - 1) * 2 + 1) + + [1] + + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2)) + ) + b = torch.nn.functional.pad(b, pad) + + stacked = torch.stack([a, b], dim=dim + 1) + interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) + if b_trunc: + # TODO: find torch alternative for slice_along dim for torch.jit.script to work + interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) + return interleaved + + +def safe_map(f, *args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + if len(arg) != n: + raise ValueError("length mismatch: {list(map(len, args))}") + + def nf(a): + return f(*a) + + return list(map(nf, zip(*args))) + + +class AssociativeScanOp(HigherOrderOperator): + def __init__(self): + super().__init__("associative_scan") + + def __call__(self, combine_fn, xs, additional_inputs): + # There is currently an issue that the ScanOp is sometimes called with + # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 + # Once this issue is resolved, the assertion should only allow tuples + # and the tuple cast should be removed + assert isinstance( + additional_inputs, (tuple, list) + ), "additional_inputs must be a tuple." + additional_inputs = ( + tuple(additional_inputs) + if isinstance(additional_inputs, list) + else additional_inputs + ) + validate_subgraph_args_types(additional_inputs) + return super().__call__(combine_fn, xs, additional_inputs) + + +associative_scan_op = AssociativeScanOp() + + +def associative_scan( + combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree], + xs: pytree.PyTree, + dim: int, + reverse: bool = False, + combine_mode: str = "pointwise", +) -> torch.Tensor: + r""" + Performs an inclusive scan with an associative combine function. + + .. warning:: + `torch.associative_scan` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + This operator requires runtime code generation and so requires support for + ``torch.compile``. Further, only CUDA device codegen is supported at the moment. + + Args: + combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, i.e., no lifted arguments are supported at the moment, + satisfy the associative property and have no side-effects. + xs (torch.Tensor): The input tensor, or nested pytree of tensors. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``, default ``pointwise``. + If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations + and ``xs`` must be CUDA tensors. + In all other cases ``combine_mode=generic`` should be used. 
+ Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. + + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + cumsum = associative_scan(add, x, dim) + + """ + # The reason we flatten xs before calling into dynamo is that + # we want to create a consistent input ordering for combine_fn + # and we also want the input ordering to match the output ordering. + leaves_xs_orig, spec_xs = pytree.tree_flatten(xs) + + def _validate_input(cfn, lxs, d, r, cm): + # Basic arguments check + if not callable(cfn): + raise ValueError(f"Combine_fn must be a callable, but got {cfn}") + if not isinstance(d, int): + raise ValueError("Dim must be an int, but got " + str(type(d))) + if not isinstance(r, bool): + raise RuntimeError("Reverse must be a bool, but got " + str(type(r))) + if cm not in ["pointwise", "generic"]: + raise ValueError( + f"Combine_mode must be either 'pointwise' or 'generic', but got {cm}" + ) + if cm == "pointwise" and not all(l.device.type == "cuda" for l in lxs): + raise ValueError( + "For combine_mode='pointwise', all input tensors need to be on CUDA" + ) + + # Checks for xs + if len(lxs) == 0: + raise ValueError("Expected at least 1 xs leaf") + if any(not isinstance(x, torch.Tensor) for x in lxs): + raise ValueError("xs leaves must be a Tensor") + if any(x.is_sparse for x in lxs): + raise ValueError( + "xs leaves must be dense Tensors, consider using `to_dense()`" + ) + if any(x.ndim <= d for x in lxs): + raise ValueError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) + if any(x.shape[d] == 0 for x in lxs): + raise ValueError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) + + ndim = leaves_xs_orig[0].ndim + dim = utils.canonicalize_dim(ndim, dim) + + _validate_input(combine_fn, leaves_xs_orig, dim, reverse, combine_mode) + + # Move scan dim to 0 and always perform scan on dim 0 + leaves_xs = [torch.movedim(elem, dim, 0) for elem in leaves_xs_orig] + + if reverse: + leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs] + + # TODO: Support Autograd + # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. + + if combine_mode == "generic": + # The generic_associative_scan implementation calls the combine_fn with a `batch` along the scan dimension + # For example, consider: + # def add(x: torch.Tensor, y: torch.Tensor): + # return x + y + # leaves = torch.tensor([[0.0, 1.0, 2.0, 3.0], + # [0.0, 1.0, 2.0, 3.0]]) + # which has shape 2 x 4; + # dim = 1; + # In the first iteration of `_scan` the combine_fn gets invoked with + # combine_fn([torch.tensor([[0.0, 2.0], + # [0.0, 2.0]])], + # [torch.tensor([[1.0, 3.0], + # [1.0, 3.0]])]) + # The arguments are of shape 2 x 2, but can be evaluated in parallel along the scan dimension.
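+ # The vmapped combine_fn below therefore consumes and produces whole slices at once;
+ # semantically the result is still an inclusive scan, e.g. for ``add`` it matches
+ # torch.cumsum along ``dim``.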
+ combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=torch.vmap( + combine_fn, + in_dims=( + pytree.tree_unflatten([0] * len(leaves_xs), spec_xs), + pytree.tree_unflatten([0] * len(leaves_xs), spec_xs), + ), + out_dims=0, + ), + spec=spec_xs, + num_leaves=len(leaves_xs), + ) + out = generic_associative_scan(combine_fn, leaves_xs, additional_inputs=()) + out = pytree.tree_unflatten(out, spec_xs) + else: + combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=combine_fn, + spec=spec_xs, + num_leaves=len(leaves_xs), + ) + + def run_flattened_associative_scan(combine_fn, leaves_xs): + return associative_scan_op(combine_fn, leaves_xs, additional_inputs=()) + + out = _maybe_compile_and_run_fn( + run_flattened_associative_scan, + combine_fn, + leaves_xs, + ) + + if reverse: + out = pytree.tree_map(lambda elem: elem.flip([0]), out) + + out = pytree.tree_map(lambda elem: torch.movedim(elem, 0, dim), out) + + return out + + +def generic_associative_scan(operator, leaves, dim=0, additional_inputs=()): + r""" + This function performs the associative_scan operation. + The algorithm works by recursively collecting neighbours of ``leaves`` and subsequently + applying the ``operator`` on all pairs in parallel along ``dim``. + The results of the recursive calls are later combined. + + Args: + operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, pointwise, and satisfy the associative property. + leaves (torch.Tensor): A list of torch.Tensors converted from the pytree of + ``xs`` provided to ``associative_scan``. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + additional_inputs (Tuple of tensors): A tuple of lifted parameters from the global scope. + This parameter will be populated internally. + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + leaves = torch.tensor([0.0, 1.0, 2.0, 3.0]) + + First iteration of _scan -> + # odd_elems -> apply operator on all neighbours + # odd_elems = operator([torch.tensor([0.0, 2.0])], + # [torch.tensor([1.0, 3.0])]) + odd_elems = torch.tensor([1.0, 5.0]) + Second iteration of _scan -> + # odd_elems = operator([torch.tensor([1.0])], + # [torch.tensor([5.0])]) + odd_elems = torch.tensor([6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [1.0] + # Merges odd_elems and even_elems + res = torch.tensor([1.0, 6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [0.0, 3.0] + # Merges odd_elems and even_elems + res = torch.tensor([0.0, 1.0, 3.0, 6.0]) + + """ + + def call_operator(*args): + return pytree.tree_leaves(operator(*args)) + + def _scan(elems): + """Perform the actual recursive scan on ``elems``.""" + num_elems = elems[0].shape[dim] + + if num_elems < 2: + return elems + + reduced_elems = call_operator( + *[aten.slice(elem, dim, 0, -1, 2) for elem in elems], + *[aten.slice(elem, dim, 1, None, 2) for elem in elems], + *additional_inputs, + ) + + # Recursively compute scan for partially reduced tensors. 
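+ # Each level of the recursion halves the number of scanned elements, so the recursion
+ # is O(log n) deep and, together with the even_elems pass below, performs O(n)
+ # pairwise combines in total.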
+ odd_elems = _scan(reduced_elems) + + if num_elems % 2 == 0: + even_elems = call_operator( + *[aten.slice(e, dim, 0, -1) for e in odd_elems], + *[aten.slice(e, dim, 2, None, 2) for e in elems], + *additional_inputs, + ) + else: + even_elems = call_operator( + *odd_elems, + *[aten.slice(e, dim, 2, None, 2) for e in elems], + *additional_inputs, + ) + + # The first element of a scan is the same as the first element + # of the original `elems`. + even_elems = [ + torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim) + if result.shape.numel() > 0 and elem.shape[dim] > 0 + else result + if result.shape.numel() > 0 + else aten.slice( + elem, dim, 0, 1 + ) # Jax allows/ignores concat with 0-dim, Pytorch does not + for (elem, result) in zip(elems, even_elems) + ] + + return list( + safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems) + ) + + scans = _scan(leaves) + + return scans + + +def trace_associative_scan( + proxy_mode, + func_overload, + combine_fn: Callable, + xs: list[torch.Tensor], + additional_inputs: tuple[torch.Tensor], +): + from torch._dynamo.utils import clone_input + + with disable_proxy_modes_tracing(): + sample_xs = [first_slice_copy(x) for x in itertools.chain(xs, xs)] + sample_additional_inputs = [ + clone_input(x) if isinstance(x, torch.Tensor) else x + for x in additional_inputs + ] + combine_graph = reenter_make_fx(combine_fn)( + *sample_xs, *sample_additional_inputs + ) + + outputs = None + for node in combine_graph.graph.nodes: + if node.op == "output": + assert outputs is None + assert len(node.args) == 1 + outputs = node.args[0] + + assert outputs is not None + outputs = pytree.tree_leaves(outputs) + assert len(outputs) == len( + xs + ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" + + xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ + first_slice_copy(x) for x in xs + ] + output_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ + c.meta["val"] for c in outputs + ] + check_meta_consistency( + xs_fake_tensors, output_fake_tensors, "init", "carry", include_contiguity=False + ) + + _, combine_graph_name = unique_graph_id( + proxy_mode, prefix="associative_scan_combine_graph" + ) + + proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) + + args = (combine_graph, xs, additional_inputs) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="associative_scan" + ) + + with disable_proxy_modes_tracing(): + out = tuple(aten.clone(x) for x in xs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def associative_scan_op_dense(combine_fn, xs, additional_inputs): + return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs) + + +associative_scan_op.py_autograd_impl( + autograd_not_implemented(associative_scan_op, deferred_error=True) +) + + +@associative_scan_op.py_impl(ProxyTorchDispatchMode) +def associative_scan_proxy_mode(mode, combine_fn, xs, additional_inputs): + return trace_associative_scan( + mode, associative_scan_op, combine_fn, xs, additional_inputs + ) + + +@associative_scan_op.py_impl(FakeTensorMode) +def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, additional_inputs): + with mode: + return tuple(x.clone() for x in xs) + + +@associative_scan_op.py_functionalize_impl +def associative_scan_functionalize(ctx, combine_fn, 
xs, additional_inputs): + from torch._higher_order_ops.utils import _check_alias_and_mutation + + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) + with ctx.redispatch_to_next(): + functional_combine_fn = ctx.functionalize( + _maybe_run_with_interpreter(combine_fn) + ) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + sample_unwrapped_xs_sliced = [ + first_slice_copy(inp) for inp in itertools.chain(unwrapped_xs, unwrapped_xs) + ] + sample_inputs = list( + itertools.chain( + sample_unwrapped_xs_sliced, + unwrapped_additional_inputs, + ) + ) + _check_alias_and_mutation( + combine_fn, sample_inputs, "associative_scan", pre_dispatch + ) + ret = associative_scan_op( + functional_combine_fn, + unwrapped_xs, + unwrapped_additional_inputs, + ) + return ctx.wrap_tensors(ret) + + +def _fake_associative_scan(combine_fn, xs, dim, reverse=False): + inp_leaves, spec = pytree.tree_flatten(xs) + result_flat: list[Any] = [] + num_leaves = len(inp_leaves) + op = reversed if reverse else lambda x: x + + for ind in op(range(inp_leaves[0].size(dim))): + r = [ + inp_leaves[leave_ind][(slice(None),) * dim + (ind,)] + for leave_ind in range(num_leaves) + ] + if (ind > 0 and not reverse) or ( + ind < (inp_leaves[0].size(dim) - 1) and reverse + ): + r = combine_fn( + pytree.tree_unflatten(result_flat[-1], spec), + pytree.tree_unflatten(r, spec), + ) + r_flat, _ = pytree.tree_flatten(r) + result_flat.append(r_flat) + + results = [ + torch.stack([e[leave_ind] for e in op(result_flat)], dim) + for leave_ind in range(num_leaves) + ] + return pytree.tree_unflatten(results, spec) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py new file mode 100644 index 0000000000000000000000000000000000000000..f263e39fb3a6f06bd3d7cf13446758076b0d0466 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/auto_functionalize.py @@ -0,0 +1,1006 @@ +# mypy: allow-untyped-defs +import warnings +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, get_args, Optional, Union + +import torch +import torch._library.utils as library_utils +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_gen_schema, + call_op, + HopInstance, + HopSchema, + materialize_callable_in_args, + unique_graph_id, +) +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class SchemaHolder: + def __init__(self, schema: torch.FunctionSchema): + self.schema = schema + + def __eq__(self, other): + return self.schema == other.schema + + def __hash__(self) -> int: + return hash(self.schema) + + @classmethod + def from_tree_spec(cls, tree_spec: pytree.TreeSpec): + assert tree_spec is not None + return cls(pytree.tree_unflatten([], tree_spec).schema) + + +# regsiter_constant allows us to get a tree_spec from pytree.tree_flatten(SchemaHolder(FunctionSchema)). 
+# The tree_spec is proxable in the graph and we can get back the schema via +# schema = pytree.tree_unflatten([], tree_spec).schema +pytree.register_constant(SchemaHolder) + + +def get_base(tensor): + if torch.is_inference_mode_enabled(): + return tensor._inference_mode_base + else: + return tensor._base + + +class ViewInfo(ABC): + base_index: int + + def __init__(self, base_index): + self.base_index = base_index + + @abstractmethod + def regenerate_view(self, bases_list: list[Tensor]): + pass + + +@dataclass +class AsStridedViewInfo(ViewInfo): + size: Sequence[Union[int, torch.SymInt]] + stride: Sequence[Union[int, torch.SymInt]] + storage_offset: int + + def __init__(self, base_index, size, stride, storage_offset): + super().__init__(base_index) + self.size = size + self.stride = stride + self.storage_offset = storage_offset + + def regenerate_view(self, bases_list: list[Tensor]): + return torch.as_strided( + bases_list[self.base_index], + self.size, + self.stride, + self.storage_offset, + ) + + +@dataclass +class SliceViewInfo(ViewInfo): + dim: Union[int, torch.SymInt] + start: Union[int, torch.SymInt] + end: Union[int, torch.SymInt] + + def __init__(self, base_index, dim, start, end): + super().__init__(base_index) + self.dim = dim + self.start = start + self.end = end + + def regenerate_view(self, bases_list: list[Tensor]): + return torch.ops.aten.slice.Tensor( + bases_list[self.base_index], self.dim, self.start, self.end + ) + + +@dataclass +class AliasViewInfo(ViewInfo): + def __init__(self, base_index): + super().__init__(base_index) + + def regenerate_view(self, bases_list: list[Tensor]): + return torch.ops.aten.alias.default(bases_list[self.base_index]) + + +@dataclass +class NotView(ViewInfo): + def __init__(self, base_index): + super().__init__(base_index) + + def regenerate_view(self, bases_list: list[Tensor]): + return bases_list[self.base_index] + + +def is_alias(base, tensor): + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + + return all( + statically_known_true(a) + for a in [ + sym_eq(base.storage_offset(), tensor.storage_offset()), + sym_eq(base.stride(), tensor.stride()), + sym_eq(base.size(), tensor.size()), + ] + ) + + +# return None or (dim, start, end) +def try_use_slice(base, tensor): + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + + # This condition should never be triggered. + if is_alias(base, tensor): + return (0, 0, base.size()[0]) + + # TODO is there cases can we use slice even if stride or len(sizes) are not equal? + if not statically_known_true(sym_eq(tensor.stride(), base.stride())): + return None + if not statically_known_true(sym_eq(len(tensor.size()), len(base.size()))): + return None + + dim = None + count = 0 + for i in range(len(tensor.size())): + if base.size()[i] != tensor.size()[i]: + dim = i + count = count + 1 + if count != 1: + return None + + if tensor.storage_offset() % tensor.stride()[dim] != 0: + return None + start = tensor.storage_offset() // tensor.stride()[dim] + end = start + tensor.size()[dim] + return (dim, start, end) + + +def write_view_information_to_args( + mutable_arg_names: list[str], + mutable_arg_types: list[torch.Type], + kwargs: dict[str, Any], + arg_to_base_index: dict[str, Any], +): + """ + This function writes the view information into kwargs. It reads mutable_args from kwargs. + and uses arg_to_base_index and tensor information to write ViewInfo into kwargs. + mutable_arg_names: mutable custom operator arg names. 
+ mutable_arg_types: mutable custom operator arg types. + kwargs: the original custom operator args. + arg_to_base_index: maps mutable_arg_name to int | [int] that refers to the base tensor that + corresponds to the input tensor + """ + + def write_single_view(prefix: str, tensor: Tensor, base_index: int): + assert f"{prefix}_base_index" not in kwargs + assert f"{prefix}_size" not in kwargs + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + assert f"{prefix}_slice_dim" not in kwargs + assert f"{prefix}_slice_start" not in kwargs + assert f"{prefix}_slice_end" not in kwargs + + def use_as_strided(tensor): + kwargs[f"{prefix}_size"] = tensor.size() + kwargs[f"{prefix}_stride"] = tensor.stride() + kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + + def use_slice(dim, start, end): + kwargs[f"{prefix}_slice_dim"] = dim + kwargs[f"{prefix}_slice_start"] = start + kwargs[f"{prefix}_slice_end"] = end + + def use_alias(): + kwargs[f"{prefix}_alias"] = True + + # The start if the function + if tensor is None: + kwargs[f"{prefix}_base_index"] = None + else: + base = get_base(tensor) + kwargs[f"{prefix}_base_index"] = base_index + if base is None: + # no need to add anything else other than _base_index + return + elif is_alias(base, tensor): + use_alias() + elif (slice_info := try_use_slice(base, tensor)) is not None: + use_slice(*slice_info) + else: + use_as_strided(tensor) + + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + arg = kwargs[arg_name] + if library_utils.is_tensorlist_like_type(arg_type): + if arg is None: + kwargs[f"_{arg_name}_length"] = None + else: + kwargs[f"_{arg_name}_length"] = len(arg) + for i, elem in enumerate(arg): + write_single_view( + f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] + ) + + elif library_utils.is_tensor_like_type(arg_type): + write_single_view( + f"_{arg_name}", + kwargs[arg_name], + arg_to_base_index.get(arg_name, None), # type: ignore[arg-type] + ) + else: + raise RuntimeError(f"Unsupported type {arg_type}") + + +# Returns a dict of arg_name -> ViewInfo | [ViewInfo] +def read_view_information_from_args( + mutable_arg_names: list[str], + mutable_arg_types: list[torch.Type], + kwargs: dict[str, Any], + all_bases: list[Tensor], +): + """ + This reads the view information added by `write_view_information_to_args` from kwargs, pop them, + and returns a dict arg_name -> ViewInfo | [ViewInfo](if the input is list). that maps each mutable arg + to its view information. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs : args of auto_functionalize(custom_op, kwargs) + """ + + def get_arg(name): + return kwargs.pop(name) + + def read_single_view(prefix): + base_index = get_arg(f"{prefix}_base_index") + if base_index is None: + return None + elif f"{prefix}_alias" in kwargs: + get_arg(f"{prefix}_alias") + return AliasViewInfo(base_index) + elif f"{prefix}_storage_offset" in kwargs: + # The view is regenerated using as_strided. 
+ size = get_arg(f"{prefix}_size") + stride = get_arg(f"{prefix}_stride") + storage_offset = get_arg(f"{prefix}_storage_offset") + return AsStridedViewInfo(base_index, size, stride, storage_offset) + elif f"{prefix}_slice_dim" in kwargs: + dim = get_arg(f"{prefix}_slice_dim") + start = get_arg(f"{prefix}_slice_start") + end = get_arg(f"{prefix}_slice_end") + return SliceViewInfo(base_index, dim, start, end) + else: + # This means that the argument is the base tensor + return NotView(base_index) + + args_view_info: dict[str, Any] = {} + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + if library_utils.is_tensorlist_like_type(arg_type): + length = get_arg(f"_{arg_name}_length") + if length is None: + # The whole list is None. + args_view_info[arg_name] = None + else: + args_view_info[arg_name] = [ + read_single_view(f"_{arg_name}_{i}") for i in range(length) + ] + + elif library_utils.is_tensor_like_type(arg_type): + args_view_info[arg_name] = read_single_view(f"_{arg_name}") + else: + raise RuntimeError(f"Unsupported type {arg_type}") + return args_view_info + + +# NOTE: [auto-functionalizing custom ops] +# Users may wish to torch.compile custom ops that mutate their inputs. +# torch.compile will automatically support this op without anyone needing +# to provide a functionalization kernel for it. Here's how. +# +# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> () +# op. First, when FakeTensor sees this op: +# - If the schema says it returns nothing, we can generate a trivial +# FakeTensor rule for it (that returns nothing). +# - Otherwise, the user needs to provide a FakeTensor impl (fake impl) +# +# Next, when Python FunctionalTensor sees the op, it will functionalize +# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...}) +# HOP and replacing the mutated inputs with corresponding outputs of this HOP. +# This HOP effectively runs the functional version of the op when +# called: it clones inputs that will be mutated, runs the op, and +# then returns (output, Tensors with the new values) +# +# auto_functionalize_v2 is an improved version of auto_functionalize that better handle +# re-inplacing views. + + +class AutoFunctionalized(HigherOrderOperator): + """auto_functionalized(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + + Concretely, it looks at all the arguments that are mutable through + _mutable_op's operator schema, clones those kwargs, runs + `out = _mutable_op(**kwargs)` with the cloned values, and then returns the + operator output concatenated with the cloned values that were mutated. + + We have some restrictions on `_mutable_op`. + See `can_auto_functionalize` for the restrictions. We can likely lift + many of these if users request it. + + The reason why _mutable_op is prefixed with an + underscore is to prevent collisions with kwarg names in **kwargs. 
+ """ + + def __init__(self) -> None: + super().__init__("auto_functionalized", cacheable=True) + + def __call__( + self, + /, + _mutable_op: OpOverload, + **kwargs: Any, + ) -> tuple[Any, tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized = AutoFunctionalized() +auto_functionalized.__module__ = "torch.ops.higher_order" + +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +_MutableOpType = Union[OpOverload, HigherOrderOperator] + + +class AutoFunctionalizedV2(HigherOrderOperator): + """auto_functionalized_v2(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + Unlike AutoFunctionalized, this version is improved to better handle + view tensors. This version is only used in non export mode. + """ + + def __init__(self) -> None: + super().__init__("auto_functionalized_v2", cacheable=True) + + def __call__( + self, + /, + _mutable_op: _MutableOpType, + **kwargs: Any, + ) -> tuple[Any, tuple[Tensor, ...]]: + _op_to_check: Optional[Union[OpOverload, HopInstance]] = None + if isinstance(_mutable_op, HigherOrderOperator): + _op_to_check = HopInstance( + _mutable_op, + SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema, # type: ignore[arg-type] + ) + else: + _op_to_check = _mutable_op + + assert _op_to_check is not None + assert can_auto_functionalize(_op_to_check) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized_v2 = AutoFunctionalizedV2() +auto_functionalized_v2.__module__ = "torch.ops.higher_order" + +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA) + + +def can_auto_functionalize( + op: Union[OperatorBase, HopInstance], +) -> bool: + if isinstance(op, HopInstance): + # HOPs that implement gen_schema and schema is not functional are auto_functionalizable. + if not _has_gen_schema(op._op): + return False + + else: + if not isinstance(op, OpOverload): + return False + + if torch._library.utils.is_builtin(op): + # We control the built-ins. These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + + schema = op._schema + if not schema.is_mutable: + return False + schema = op._schema + + for arg in schema.arguments: + if arg.alias_info is None: + continue + if not arg.alias_info.is_write: + continue + if torch._library.utils.is_tensor_like_type(arg.type): + continue + if torch._library.utils.is_tensorlist_like_type(arg.type): + continue + return False + + if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType): + # Skip schema returns -> None + return True + if isinstance(op, OpOverload): + # The returns of OpOverload must not alias anything + for ret in schema.returns: + if ret.alias_info is None and type(ret.type) is torch.TensorType: + continue + # Not yet supported: List[Tensor] return. + return False + if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"): + return False + return True + + +def get_mutable_args_from_schema( + schema: torch.FunctionSchema, +) -> tuple[list[str], list[torch.Type]]: + """ + Returns the list of argument names that get mutated according to the + schema and their types. 
+ """ + mutable_args_names = [ + arg.name + for arg in schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + + mutable_args_types = [ + arg.type + for arg in schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names, mutable_args_types # type: ignore[return-value] + + +def get_mutable_args(op: OpOverload) -> tuple[list[str], list[torch.Type]]: + return get_mutable_args_from_schema(op._schema) + + +def do_auto_functionalize( + mode: "torch._subclasses.functional_tensor.FunctionalTensorMode", + op: OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """Functionalizes a call to op(*args, **kwargs) by emitting a call to + `outs = auto_functionalized(op, normalized_kwargs)` + and replacing the mutated (args, kwargs) with the corresponding outputs. + + The normalized_kwargs are just the (args, kwargs), but all in kwarg form. + This makes handling easier for the auto_functionalized HOP. + """ + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI(mode=mode) + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." + ) + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized( + op, **unwrapped_kwargs # type: ignore[arg-type] + ) + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, _ = get_mutable_args(op) + + unwrapped_actual_out: Union[Any, tuple[Any]] = unwrapped_outs[ + : -len(mutable_args_names) + ] + unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :] + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. 
+ def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + orig_arg = normalized_kwargs[name] + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +def do_auto_functionalize_v2( + mode: "torch._subclasses.functional_tensor.FunctionalTensorMode", + op: Union[OpOverload, HopInstance], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI(mode=mode) + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + + schema = op._schema + op = op._op if isinstance(op, HopInstance) else op + assert isinstance(op, get_args(_MutableOpType)) + + def _functionalize_callable(arg: Any): + if callable(arg): + + def functional_fn(*args, **kwargs): + # We call torch.func.functionalize. This allows us to inline the epilogue graph. + # Inlining has the benefit of allowing easiser fusion inside subgraph. + # Though the epilogue graph contains copy_, it is OK becuase inductor can handle it + # and this is also how we have been supporting top-level graph input mutation. + return tuple( + pytree.tree_leaves(torch.func.functionalize(arg)(*args, **kwargs)) + ) + + return torch._higher_order_ops.base_hop.FunctionWithNoFreeVars( + functional_fn + ) + return arg + + args, kwargs = pytree.tree_map(_functionalize_callable, (args, kwargs)) + + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, mutable_args_types = get_mutable_args_from_schema(schema) + + # A list of all bases of mutable args without duplication + all_bases = [] + all_bases_addresses: list[int] = [] + + # Map arg_name to the index of its base in all_bases. 
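+ # For example, if two mutable args happen to be views of the same storage, they share
+ # a single entry in all_bases and arg_to_base_index maps both names to that one index
+ # (deduplicated below via the base tensor's _cdata address).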
+ arg_to_base_index: dict[str, Any] = {} + + def update_dict(tensor, arg_name, index=None): + base = tensor if get_base(tensor) is None else get_base(tensor) + + def set_result(base_index): + if index is None: + arg_to_base_index[arg_name] = base_index + else: + arg_to_base_index[arg_name][index] = base_index + + if not all_bases_addresses.__contains__(base._cdata): + all_bases_addresses.append(base._cdata) + all_bases.append(base) + set_result(len(all_bases) - 1) + else: + set_result(all_bases_addresses.index(base._cdata)) + + for arg_name in mutable_args_names: + arg = normalized_kwargs[arg_name] + if arg is None: + continue + + if isinstance(arg, list): + arg_to_base_index[arg_name] = {} + for i, tensor in enumerate(arg): + if tensor is None: + arg_to_base_index[arg_name].append(None) + continue + + update_dict(tensor, arg_name, i) + + else: + update_dict(arg, arg_name) + + # add view_meta for each args into unwrapped_kwargs. + write_view_information_to_args( + mutable_args_names, + mutable_args_types, + normalized_kwargs, + arg_to_base_index, + ) + + # remove mutated args from the kwargs (its a function of _all_bases now) + for arg_name in mutable_args_names: + del normalized_kwargs[arg_name] # type: ignore[arg-type] + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." + ) + all_basis_unwrapped = ctx.unwrap_tensors(all_bases) + + assert "_all_bases" not in unwrapped_kwargs, (op, unwrapped_kwargs) + auto_func_kwargs = dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped) + if isinstance(op, HigherOrderOperator): + assert "_ops_schema" not in unwrapped_kwargs, (op, unwrapped_kwargs) + # We pass in the tree_spec of tree_flatten(SchemaHolder) to make it proxable + auto_func_kwargs.update( + {"_op_schema": pytree.tree_flatten(SchemaHolder(schema))[1]} + ) + + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized_v2( + op, **auto_func_kwargs # type: ignore[arg-type] + ) + + unwrapped_actual_out: Union[Any, tuple[Any]] = ( + unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)] + ) + + unwrapped_mutable_out = ( + [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :] + ) + + if isinstance(op, HigherOrderOperator): + assert ( + len(schema.returns) > 0 + ), f"hop is expected to return at least one output {schema}." + assert len(unwrapped_actual_out) == len(schema.returns) + else: + if len(schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(schema.returns) + + for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. 
+ def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +# auto_functionalize functions +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: OpOverload, + _only_clone_these_tensors: Optional[tuple[str, ...]] = None, + **kwargs: Any, +) -> tuple[Any, tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names, _ = get_mutable_args(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + [clone_preserve_strides(x) for x in kwargs[name]] + if kwargs[name] is not None and isinstance(kwargs[name], list) + else ( + clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> tuple[Any, tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense( + _mutable_op, _only_clone_these_tensors=None, **kwargs + ) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> tuple[Any, tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized.py_functionalize_impl +def auto_functionalized_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) + + +# auto_functionalized_v2 functions +@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_v2_dense( + _mutable_op: _MutableOpType, + _only_clone_these_bases: Optional[tuple[int, ...]] = None, + **kwargs: Any, +) -> tuple[Any, tuple[Tensor, ...]]: + _all_bases: list[Tensor] = kwargs.pop("_all_bases", []) + if _only_clone_these_bases is None: + _only_clone_these_bases = tuple(range(len(_all_bases))) + + if isinstance(_mutable_op, OpOverload): + schema: torch._C.FunctionSchema = _mutable_op._schema + else: + schema = pytree.tree_unflatten([], kwargs.pop("_op_schema")).schema + + if isinstance(_mutable_op, OpOverload): + _callable_op: Union[HopInstance, OpOverload] = _mutable_op + else: + assert isinstance(schema, HopSchema) + _callable_op = 
HopInstance(_mutable_op, schema) + + op_kwargs_new, all_bases_new = _generate_new_op_kwargs_from_bases( + schema, + kwargs, + _all_bases, + _only_clone_these_bases, + ) + + out = call_op( + _callable_op, + tuple(), + op_kwargs_new, + ) + + if isinstance(out, tuple): + return (*out, *all_bases_new) # type: ignore[return-value] + else: + return (out, *all_bases_new) # type: ignore[return-value] + + +def _generate_new_op_kwargs_from_bases( + schema, kwargs, all_bases, _only_clone_these_bases +): + mutable_args_names, mutable_args_types = get_mutable_args_from_schema(schema) + args_view_info = read_view_information_from_args( + mutable_args_names, mutable_args_types, kwargs, all_bases + ) + + def maybe_copy(i, t): + if t is None: + return None + if i in _only_clone_these_bases: + return clone_preserve_strides(t) + else: + return t + + all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)] + + # create new args + new_kwargs = dict(**kwargs) + + # re-generate all inputs from all_bases_new using args_view_info and add them to new_kwargs. + for arg_name in mutable_args_names: + if args_view_info[arg_name] is None: + new_kwargs[arg_name] = None + elif isinstance(args_view_info[arg_name], list): + new_kwargs[arg_name] = [] + for i, elem in enumerate(args_view_info[arg_name]): + if elem is None: + new_kwargs[arg_name].append(None) + else: + view_info = args_view_info[arg_name][i] + new_kwargs[arg_name].append( + view_info.regenerate_view(all_bases_new) + ) + else: + new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view( + all_bases_new + ) + + return new_kwargs, all_bases_new + + +@auto_functionalized_v2.py_impl(FakeTensorMode) +def auto_functionalized_v2_fake( + mode, + _mutable_op: _MutableOpType, + **kwargs: dict[str, Any], +) -> tuple[Any, tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_v2_dense( + _mutable_op, _only_clone_these_bases=None, **kwargs + ) + return result + + +@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_v2_proxy( + mode, + _mutable_op: _MutableOpType, + **kwargs: Any, +) -> tuple[Any, tuple[Tensor, ...]]: + if isinstance(_mutable_op, HigherOrderOperator): + # Note [materialize callable inputs as graph] + # Below code materializes the callable inputs to the hop as graph modules. + # kwargs may contain general callables, that are not proxable e.g. FunctionWithNoFreeVars + # this could happen when we auto_functionalize the backward of the hop, + # where backward fn is a callablle that wrapps forward graph module. + # This function materialize the callable args according to the schema of the hop. + + # We cannot materialize the callables in kwargs directly because the inputs to callable + # vary from hops to hop. To make the materialiation process generic to all hops, + # we trace a function that wraps the hop and let each hop itself figure out how to trace + # its callable inputs. Then we look at the schema of the traced hop node and replace the + # callable in original kwarg with the traced subgraphs. + # + # Specifically, we first trace a wrapped_fn that calls into the hop. Then we look for the + # hop node in the traced graph and graph module inputs to the hop. Finally, we replace the + # original kwarg's callable with the graph module. 
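+ # Concretely, a non-proxable callable kwarg (e.g. a FunctionWithNoFreeVars wrapping a
+ # backward function) leaves the materialization step below as a torch.fx.GraphModule,
+ # which can then be registered on the tracer root and referenced from the proxy node.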
+ all_bases = kwargs.get("_all_bases", []) + _only_clone_these_bases = kwargs.get("_only_clone_these_bases", None) + if _only_clone_these_bases is None: + _only_clone_these_bases = tuple(range(len(all_bases))) + + schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema # type: ignore[arg-type] + new_kwargs, _ = _generate_new_op_kwargs_from_bases( + schema, + {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")}, + all_bases, + _only_clone_these_bases, + ) + + _, materialized_kwargs = materialize_callable_in_args( + HopInstance(_mutable_op, schema), tuple(), new_kwargs + ) + + # Only replace the callabes in kwargs with the materialized subgraphs. + # The rest of the kwargs are kept unchanged. + for k, v in kwargs.items(): + if callable(v): + assert k in materialized_kwargs and isinstance( + materialized_kwargs[k], torch.fx.GraphModule + ) + kwargs[k] = materialized_kwargs[k] + + with disable_proxy_modes_tracing(): + out = auto_functionalized_v2(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + if isinstance(_mutable_op, HigherOrderOperator): + + def _maybe_register_subgraph(val: Any): + if isinstance(val, torch.fx.GraphModule): + _, graph_name = unique_graph_id( + mode, prefix="auto_functionalized_subgraph" + ) + mode.tracer.root.register_module(graph_name, val) + return val + return val + + proxy_kwargs = pytree.tree_map(_maybe_register_subgraph, proxy_kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized_v2, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized_v2.py_functionalize_impl +def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/base_hop.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/base_hop.py new file mode 100644 index 0000000000000000000000000000000000000000..9e74904fa6d9468c66affd04f08ab3b292181a42 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/base_hop.py @@ -0,0 +1,261 @@ +# mypy: allow-untyped-defs + +import abc + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._higher_order_ops.utils import ( + check_input_alias_and_mutation_return_outputs, + HopInstance, + materialize_as_graph, + reenter_make_fx, +) +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensorMode +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class BaseHOP(HigherOrderOperator, abc.ABC): + """ + This is the "Base" HOP implementation for a HOP that looks like: + + call_subgraph_hop(subgraph, *operands, **kwargs) + + That is: + 1) the HOP stays alive until Inductor + 2) the HOP's semantics are subgraph(*operands) + 3) kwargs may be some config options but aren't passed directly to the subgraph. 
+ + To use this, please subclass this class and override methods as necessary: + ``` + class InvokeQuant(BaseHOP): + def __init__(self): + return super().__init__("invoke_quant") + + invoke_quant = InvokeQuant() + + def g(x): + return x.sin().cos() + + @torch.compile(backend="aot_eager") + def f(x): + return invoke_quant(g, x, scheme="nf4") + ``` + + NOTE: don't subclass BaseHOP out of tree! That is not allowed. All + usages must be in tree. + """ + + def __init__(self, hop_name) -> None: + super().__init__(hop_name) + + # Set up the registrations + # If you want to override any of these, override them in your subclass. + self.py_autograd_impl(self._call_Autograd) + self.py_functionalize_impl(self._call_Functionalize) + self.py_impl(ProxyTorchDispatchMode)(self._call_ProxyTorchDispatchMode) + self.py_impl(FakeTensorMode)(self._call_FakeTensorMode) + self.py_impl(DispatchKey.CompositeExplicitAutograd)( + self._call_CompositeExplicitAutograd + ) + + def __call__(self, subgraph, *operands, **kwargs): + if not isinstance(subgraph, (torch.fx.GraphModule, FunctionWithNoFreeVars)): + raise RuntimeError( + f"{self._name}: when calling this API without torch.compile, " + f"we require that the subgraph be a torch.fx.GraphModule (or " + f"a function we know doesn't have free variables)." + ) + return super().__call__(subgraph, *operands, **kwargs) + + def _call_Autograd(self, subgraph, *operands, **kwargs): + if isinstance(subgraph, torch.fx.GraphModule): + pass + + # We assume the subgraph doesn't mutate inputs and there is no aliasing. + # In the PT2 stack, this is Dynamo's responsibility to figure out. + return BaseHOPFunction.apply(self, subgraph, kwargs, *operands) + + def _call_CompositeExplicitAutograd(self, subgraph, *operands, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return subgraph(*operands) + + def _call_ProxyTorchDispatchMode(self, proxy_mode, subgraph, *operands, **kwargs): + traced_graph = reenter_make_fx(subgraph)(*operands) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + qualname = proxy_mode.tracer.get_fresh_qualname("subgraph") + proxy_mode.tracer.root.register_module(qualname, traced_graph) + + node_args = (traced_graph, *operands) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[attr-defined] + proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, kwargs) # type: ignore[attr-defined] + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", self, proxy_args, proxy_kwargs + ) + + out = self(subgraph, *operands, **kwargs) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=proxy_mode.tracer # type: ignore[arg-type] + ) + + def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs): + # TODO: this should probably route through FakeTensorMode to reuse caching + with mode: + return subgraph(*operands) + + # NOTE [Support input mutation of hops] + # To support input mutation, hop's subgraph must be functionalized because many inductor passes are + # applied to subgraph recursively and only work on functional graph. However, we could inline an + # epilogue graph (i.e. the copy_) into the subgraph because this is how input mutation + # is implemented in the top-level graph when no hop is presented. All passes must have been and will be + # aware of the epilogue graph. 
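+    # As a rough sketch (illustrative subgraph, not the exact graph that gets produced), a
+    # subgraph that mutates its input `x` becomes a functional body plus a copy_ epilogue:
+    #   def g(x):                     def g_functional(x):
+    #       x.add_(1)        -->          out = x.add(1)
+    #       return x                      x.copy_(out)   # inlined epilogue
+    #                                     return out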
+ # + # Since we've supported input mutation for custom op with auto_functionalized, we share the infra for hops + # The plan is: + # 1. In hop's Functionalization key, it calls do_auto_functionalize_v2 if subgraph mutates input + # 2. In do_auto_functionalize_v2: + # a. we functionalize the callables in hop's argument. This is to make the subgraphs functional so we + # could recursively run passes on them. Also the epilogue graph is inlined at the end. + # b. we call auto_functionalized_v2 and pass in an additional schema in order to properly invoke + # the hop with normalized kwargs. + # 3. In inductor, we decompose the auto_functionalized hop by callilng into the dense implementation, which + # copies the mutated inputs to the hop if necessary and call the hop. + # After these steps, the rest of the inductor stack knows how to fuse the copy_ in subgraph with other ops. + def _call_Functionalize(self, ctx, subgraph, *operands, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize_v2, + ) + + # invoke_quant has non-proxable argument of type InvokeQuant that + # we cannot generate schema for. + if self is not torch.ops.higher_order.invoke_quant_packed: + hop_instance = HopInstance.create(self, subgraph, *operands, **kwargs) + if can_auto_functionalize(hop_instance): + return do_auto_functionalize_v2( + ctx.mode, hop_instance, (subgraph, *operands), kwargs + ) + + unwrapped_operands = ctx.unwrap_tensors(operands) + with ctx.redispatch_to_next(): + # We assume the subgraph doesn't mutate inputs and there is no aliasing. + # In the PT2 stack, this is Dynamo's responsibility to figure out. + functionalized_subgraph = FunctionWithNoFreeVars( + ctx.functionalize(subgraph) + ) + out = self(functionalized_subgraph, *unwrapped_operands, **kwargs) + return ctx.wrap_tensors(out) + + def gen_schema(self, subgraph, *operands, **kwargs): + from .schema import HopSchemaGenerator + + if not isinstance(subgraph, torch.fx.GraphModule): + subgraph = materialize_as_graph(subgraph, operands) + + fake_args = [ + ph.meta["example_value"] if "example_value" in ph.meta else ph.meta["val"] + for ph in subgraph.graph.find_nodes(op="placeholder") + ] + ( + inp_inp_alias, + inp_out_alias, + out_out_alias, + mutated_inp_idx, + output, + ) = check_input_alias_and_mutation_return_outputs(subgraph, fake_args) + + if not ( + len(inp_inp_alias) == 0 + and len(inp_out_alias) == 0 + and len(out_out_alias) == 0 + ): + # TODO: turn this into an error. + # test_foreach_map_backward_binary_foreach_map_addrecip_op fails the alias test. + import warnings + + warnings.warn( + "Aliasing is not suppported for HOP subgraph.\n" + f"{subgraph.print_readable(print_output=False)}\n" + f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, out-out alias{out_out_alias}" + f"This may lead to silent incorrectness." 
+ ) + + schema_gen = HopSchemaGenerator(self) + schema_gen.add_arg("subgraph", subgraph) + for idx, arg in enumerate(operands): + schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inp_idx) + + for name, arg in kwargs.items(): + schema_gen.add_arg(name, arg, default_value=arg, kw_only=True) + + for out in output: + schema_gen.add_output(out) + + return schema_gen.gen_schema() + + +class BaseHOPFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, hop, subgraph, kwargs, *operands): + ctx.hop = hop + ctx.operands = operands + ctx.subgraph = subgraph + ctx.kwargs = kwargs + + with torch._C._AutoDispatchBelowAutograd(): + return hop(subgraph, *operands, **kwargs) + + @staticmethod + def backward(ctx, *grad_outputs): + subgraph = ctx.subgraph + operands = ctx.operands + kwargs = ctx.kwargs + + # TODO: Something special needs to happen with min cut partitioner + with suspend_functionalization(), disable_functional_mode(), torch.enable_grad(): + with disable_proxy_modes_tracing(): + from .invoke_subgraph import create_fw_bw_graph + from .utils import _from_fun + + fw_inputs = pytree.tree_map(_from_fun, operands) + ( + _, + joint_graph, + _, + ) = create_fw_bw_graph(subgraph, fw_inputs, grad_outputs) + + # The joint graph returns (*grad_inputs, *fwd_outputs). + # We only need the grad_inputs. + def bwd_fn(*args): + operands = args[: -len(grad_outputs)] + grad_outs = args[-len(grad_outputs) :] + result = joint_graph(*operands, *grad_outs) + grad_inputs = result[: -len(grad_outputs)] + return grad_inputs + + return ( + None, + None, + None, + *ctx.hop( + FunctionWithNoFreeVars(bwd_fn), *operands, *grad_outputs, **kwargs + ), + ) + + +class FunctionWithNoFreeVars: + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/cond.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/cond.py new file mode 100644 index 0000000000000000000000000000000000000000..68d189d26bf99e8a6387095cdac7bd0684d0af8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/cond.py @@ -0,0 +1,732 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import logging +import warnings +from typing import Any, Callable, Optional, Union + +import torch +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._C._functorch import ( + _add_batch_dim, + get_unwrapped, + is_batchedtensor, + maybe_get_bdim, +) +from torch._functorch.utils import exposed_in +from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, + _set_compilation_env, + materialize_as_graph, + reenter_make_fx, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, + unique_graph_id, + validate_subgraph_args_types, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + +from .utils import clone_outputs_aliasing_inputs + + +log = logging.getLogger(__name__) + +""" +We're going to define a `cond_op` operation. +In order to do this, we need implementations for each of the dispatch keys. 
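+
+The registrations below cover, roughly: a dense CompositeExplicitAutograd implementation,
+an autograd implementation, tracing under ProxyTorchDispatchMode, metadata inference under
+FakeTensorMode, a functionalization rule, and a vmap batching rule.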
+""" + + +class CondOp(HigherOrderOperator): + def __init__(self): + super().__init__("cond") + + def __call__(self, pred, true_fn, false_fn, operands): + validate_subgraph_args_types(operands) + return super().__call__(pred, true_fn, false_fn, operands) + + +cond_op = CondOp() + + +@exposed_in("torch") +def cond( + pred: Union[bool, int, float, torch.Tensor], + true_fn: Callable, + false_fn: Callable, + operands: Union[tuple, list] = (), +) -> Any: + r""" + Conditionally applies `true_fn` or `false_fn`. + + .. warning:: + `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and + doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + `cond` is structured control flow operator. That is, it is like a Python if-statement, + but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be + capturable using torch.compile and torch.export. + + Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following:: + + def cond(pred, true_branch, false_branch, operands): + if pred: + return true_branch(*operands) + else: + return false_branch(*operands) + + Args: + pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element, + indicating which branch function to apply. + + true_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. + + false_fn (Callable): A callable function (a -> b) that is within the + scope that is being traced. The true branch and false branch must + have consistent input and outputs, meaning the inputs have to be + the same, and the outputs have to be the same type and shape. Int + output is also allowed. We'll make the output dynamic by turning it + into a symint. + + operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the + true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to (). + + Example:: + + def true_fn(x: torch.Tensor): + return x.cos() + def false_fn(x: torch.Tensor): + return x.sin() + return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) + + Restrictions: + - The conditional statement (aka `pred`) must meet one of the following constraints: + + - It's a `torch.Tensor` with only one element, and torch.bool dtype + + - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10` + + - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints: + + - The function signature must match with operands. + + - The function must return a tensor with the same metadata, e.g. shape, + dtype, etc. + + - The function cannot have in-place mutations on inputs or global variables. + (Note: in-place tensor operations such as `add_` for intermediate results + are allowed in a branch) + + """ + if torch.compiler.is_dynamo_compiling(): + return cond_op(pred, true_fn, false_fn, operands) + + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + if isinstance(pred, (bool, int, float)): + # This is the non-strict export case. Strict export and torch.compile are + # handled above in dynamo. + if torch.compiler.is_compiling(): + warnings.warn( + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." 
+ " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool.", + UserWarning, + ) + # This is the eager case. We can just run the true or false branch. + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + def _validate_input(pred, true_fn, false_fn, operands): + if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)): + raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.") + + if isinstance(pred, torch.Tensor) and pred.numel() != 1: + raise RuntimeError( + f"Expected pred to be bool or single-element tensor, but got {pred}." + ) + + if not callable(true_fn) or not callable(false_fn): + raise RuntimeError("Expect both branches to be callable.") + + if not isinstance(operands, (tuple, list)) or pytree.tree_any( + lambda t: not isinstance(t, torch.Tensor), operands + ): + raise RuntimeError( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only " + f"consists of tensor leaves, but got {operands}." + ) + + _validate_input(pred, true_fn, false_fn, operands) + + if not torch._dynamo.is_dynamo_supported(): + raise RuntimeError("torch.cond requires dynamo support.") + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass cond_op to it. So we wrap it in a dummy function. + def _cond_op_wrapper(*args, **kwargs): + return cond_op(*args, **kwargs) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) + + +def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: + """ + For a fn that accepts flat inputs and returns flat outputs: + fw_out = fn(*args), + this function returns: + grad_args = bw_fn(*args_and_grad_output) + with the following invariants: + 1. args + fw_out has an 1-1 correspondence to args_and_grad_output + 2. grad_args has an 1-1 corresponsence to args + 3. for tensor arg whose requires_grad is False, its corresponding grad in + grad_args will be a zero tensor with the same shape. + """ + + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + n_primals = len(args) + + bw_fn = create_joint( + prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config + ) + + def flat_fn(*args_and_grad_outs): + primals = args_and_grad_outs[:n_primals] + tangents = args_and_grad_outs[n_primals:] + grad_args = bw_fn(primals, tangents)[1] + assert len(args) == len(grad_args) + # In order to keep HOPs functional where the backward graph, + # would have outputs that are aliasing inputs. + # For example in cases where the backward of the function is simply + # passing the upstream gradients through. 
+ maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) + + return [ + ( + torch.zeros_like(arg) + if isinstance(arg, torch.Tensor) and grad is None + else maybe_clone(grad) + ) + for grad, arg in zip(grad_args, primals) + ] + + return flat_fn + + +def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): + assert isinstance( + operands, (list, tuple) + ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}" + + true_graph = reenter_make_fx(true_fn)(*operands) + false_graph = reenter_make_fx(false_fn)(*operands) + + true_outs = [] + false_outs = [] + for node in true_graph.graph.nodes: + if node.op == "output": + true_outs.extend(node.args) + + for node in false_graph.graph.nodes: + if node.op == "output": + false_outs.extend(node.args) + + flat_true_outs = pytree.arg_tree_leaves(*true_outs) + flat_false_outs = pytree.arg_tree_leaves(*false_outs) + if len(flat_true_outs) != len(flat_false_outs): + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected to return same number of outputs but got:" + f"\n true branch returns {len(flat_true_outs)} item(s)" + f"\n false branch returns {len(flat_false_outs)} item(s)" + ) + + i, true_name = unique_graph_id(proxy_mode, prefix="true_graph") + + false_name = f"false_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, false_name) + + proxy_mode.tracer.root.register_module(true_name, true_graph) + proxy_mode.tracer.root.register_module(false_name, false_graph) + + args = (pred, true_graph, false_graph, operands) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {} + ) + + out = func_overload(pred, true_graph, false_graph, operands) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def cond_op_dense(pred, true_fn, false_fn, operands): + assert all( + isinstance(o, (torch.Tensor, int)) for o in operands + ), f"Dense implementation operands must be a list of tensors and ints {operands}" + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + if pred: + return true_fn(*operands) + else: + return false_fn(*operands) + + +class CondAutogradOp(torch.autograd.Function): + @staticmethod + def forward( + ctx, + pred, + true_fn, + false_fn, + *operands, + ): + ctx._pred = pred + ctx._true_bw_fn = create_bw_fn( + true_fn, + operands, + ) + ctx._false_bw_fn = create_bw_fn( + false_fn, + operands, + ) + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + save_tensors_and_symints_for_backward(ctx, operands) + + with torch._C._AutoDispatchBelowAutograd(): + return cond_op(pred, true_fn, false_fn, operands) + + @staticmethod + def backward(ctx, *flat_grads): + operands = saved_tensors_and_symints(ctx) + args = operands + flat_grads + # TODO: we need to materialize the bw graphs because dynamo is unable to + # trace through the joint funcion when torch.compile torch.autograd.grad. 
+ true_bw_gm = materialize_as_graph( + ctx._true_bw_fn, + args, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + force_enable_grad=True, + ) + false_bw_gm = materialize_as_graph( + ctx._false_bw_fn, + args, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + force_enable_grad=True, + ) + grads = cond_op( + ctx._pred, + true_bw_gm, + false_bw_gm, + args, + ) + return None, None, None, *grads + + +# Note: +# As long as one of the tensors in pred or operands requires grad, +# all the output would require grad with backward fn set to be the CondAutogradOp. +# This is consistent with autograd.Function's semantic. +@cond_op.py_autograd_impl +def cond_autograd(pred, true_fn, false_fn, operands): + return CondAutogradOp.apply( + pred, + true_fn, + false_fn, + *operands, + ) + + +@cond_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, pred, true_fn, false_fn, operands): + return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands) + + +@cond_op.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): + # Ignore here, because if you've gotten here but you're not manually + # tracing the inner graphs, that means that you intend to reuse the graph + # directly. Which means the old unbacked symbol bindings are appropriate. + # This strategy will not work if unbacked symbols can escape. + ignore_fresh_unbacked = contextlib.nullcontext() + if mode.shape_env: + ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols() + + with mode, ignore_fresh_unbacked: + flat_true_outs, true_out_spec = pytree.tree_flatten(true_fn(*operands)) + flat_false_outs, false_out_spec = pytree.tree_flatten(false_fn(*operands)) + if true_out_spec != false_out_spec: + raise RuntimeError( + "Unmatched output spec from torch.cond branches: " + f"true branch tree_spec {true_out_spec} vs false branch tree_spec {false_out_spec}." 
+ ) + + merged_outs = [] + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + merged_outs.append(_merge_output(true_out, false_out, mode)) + return pytree.tree_unflatten(merged_outs, true_out_spec) + + +def check_tensor_meta_match( + t1: torch.Tensor, t2: torch.Tensor, attr_names: tuple[str, ...], msg_prefix: str +) -> None: + def _get_attr_maybe_call(t: torch.Tensor, attr_name: str) -> Any: + attr = getattr(t, attr_name) + if callable(attr): + return attr() + return attr + + for attr_name in attr_names: + lattr = _get_attr_maybe_call(t1, attr_name) + rattr = _get_attr_maybe_call(t2, attr_name) + torch._check( + lattr == rattr, + lambda: f"{msg_prefix} expected same {attr_name} but got {lattr} and {rattr}.", + ) + + +def _merge_output( + a: Optional[Union[torch.Tensor, int]], + b: Optional[Union[torch.Tensor, int]], + mode: FakeTensorMode, +): + from torch.fx.experimental.symbolic_shapes import ( + has_free_unbacked_symbols, + SymIntEqByExpr, + ) + + if a is None or b is None: + assert a is None and b is None, (a, b) + return None + + def min_max(s0, s1): + def _bound(s0, lower_bound: bool): + if isinstance(s0, int): + return s0 + r = mode.shape_env.var_to_range.get( # type: ignore[union-attr] + s0.node.expr, + torch.utils._sympy.value_ranges.ValueRanges.unknown(), + ) + return r.lower if lower_bound else r.upper + + return min(_bound(s0, True), _bound(s1, True)), max( + _bound(s0, False), _bound(s1, False) + ) + + if type(a) is int and type(b) is int: + if a == b: + return a + assert mode.shape_env is not None + merged_out = mode.shape_env.create_unbacked_symint() + mode.shape_env.constrain_symbol_range(merged_out.node.expr, *min_max(a, b)) + return merged_out + + assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b)) + + # Note: we don't check size, stride because + # they'll be merged with unbacked symints if they differ. + _meta_to_check = { + "dtype", + "device", + "layout", + "dim", + "is_quantized", + "is_conj", + "is_sparse", + "storage_offset", + } + check_tensor_meta_match( + a, + b, + tuple(_meta_to_check), + msg_prefix="When merging two branches' output in torch.cond, ", + ) + # NYI + assert not a.is_quantized and not b.is_quantized + assert not a.is_sparse and not b.is_sparse + assert not a.is_conj() and not b.is_conj() + + """ + Step 1: create unbacked symints for sizes that are different + along the same axis. For example: + a.size is [s0, 4, s0, 5, 4, 5] + b.size is [s1, 4, s2, 8, 4, 7] + merged_size will be [u0, 4, u1, u2, 4, u3], where + u0 has range [min(s0, s1), max(s0, s1)] + u1 has range [min(s0, s2), max(s0, s2)] + u2 has range [5, 8] + u3 has range [5, 7] + """ + merged_size: list[Union[int, torch.SymInt]] = [] + + def _has_unbacked_symbols(s: Union[int, torch.SymInt]) -> bool: + if isinstance(s, int): + return False + else: + return has_free_unbacked_symbols(s.node.expr) + + for s0, s1 in zip(a.size(), b.size()): + # If there are unbacked symbols leaked out of true_branch or false_branch + # we need to merge them with a new unbacked symbol and track in parent graph. 
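+        # e.g. (illustrative) if the true branch returned a tensor whose size u5 was allocated
+        # inside that branch, u5 cannot simply be reused here; a fresh unbacked symint bound in
+        # the parent graph's shape_env is created instead.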
+        if (
+            not _has_unbacked_symbols(s0)
+            and not _has_unbacked_symbols(s1)
+            and SymIntEqByExpr(s0) == SymIntEqByExpr(s1)
+        ):
+            merged_size.append(s0)
+        else:
+            assert mode.shape_env is not None
+            new_size = mode.shape_env.create_unbacked_symint()
+            mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
+            merged_size.append(new_size)
+
+    """
+    This follows the logic in symbolic_shapes._compute_symbolic_stride
+    Step 2: Since tensor stride is an accumulative multiplication of the sizes, which is a permuted
+    (due to view ops) non-descending sequence.
+
+    Case 1: No size is 1. In this case, strides have unique values.
+    For example, suppose we have a tensor with:
+    size [3, 4, 3, 5, 4, 5],
+    stride (1200, 300, 1, 12, 3, 60),
+    merged_size [u0, u1, u2, u3, u4, u5].
+
+    We visit the strides in ascending order: 1, 3, 12, 60, 300, 1200. In each step, we check whether
+    the current stride is bounded or not and bound the next stride by setting
+    stride_expr[next_stride] = current_stride_expr * current_size_expr
+    1st round:
+    current_stride is 1, current_size is 3, so next_stride is 1 * 3 = 3,
+    current_stride_expr is set to 1, current_size_expr is u2, so stride_expr[3] is therefore 1 * u2 = u2
+    2nd round:
+    current_stride is 3, current_size is 4, so next_stride is 3 * 4 = 12,
+    current_stride_expr is stride_expr[3] i.e. u2, current_size_expr is u4, so stride_expr[12] = u2 * u4
+    ...
+
+    Case 2: At least one dimension has size 1, which can produce duplicates in strides.
+    In this case, theoretically, we cannot uniquely determine the expr of strides because
+    accessing stride_expr with the same key in a different order causes the final stride expression
+    to be different.
+
+    Suppose we have:
+    size: (3, 1)
+    stride: (1, 1)
+    merged_size: (u0, u1)
+
+    The stride expr could either be (u1, 1) or (1, u0) depending on whether we start with u1 or u0.
+    For this reason, we try to break the tie by sorting via descending index so we always get (u1, 1).
+
+    Note that the backend might optimize the strides anyway, so this is usually not a problem as long
+    as the two branches match. See relevant discussions in https://github.com/pytorch/pytorch/issues/142024.
+
+    Case 3: Dim has 0 stride. A 0 stride doesn't participate in the accumulative multiplication of
+    sizes. So they're always treated as constant even if their corresponding size is turned into an unbacked symint.
+ + Suppose we have: + size: (3, 3) + stride: (0, 1) + merged_size: (u0, u1) + + The merged stride would be (0, 1) + """ + + def _bound_stride( + a_ex_size: torch.Size, + b_ex_size: torch.Size, + a_ex_stride: tuple[int, ...], + b_ex_stride: tuple[int, ...], + merged_size: list[Union[int, torch.SymInt]], + ) -> list[Union[int, torch.SymInt]]: + from torch._inductor.ir import get_stride_order + + a_sorted_stride_idx = get_stride_order(a_ex_stride, mode.shape_env) + b_sorted_stride_idx = get_stride_order(b_ex_stride, mode.shape_env) + + a_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [ + None + ] * len(a_ex_stride) + b_stride_li: list[Optional[tuple[Union[int, torch.SymInt], int]]] = [ + None + ] * len(b_ex_stride) + for i, idx in enumerate(a_sorted_stride_idx): + a_stride_li[idx] = (a_ex_stride[i], -i) + for i, idx in enumerate(b_sorted_stride_idx): + b_stride_li[idx] = (b_ex_stride[i], -i) + + for a_pair, b_pair in zip(a_stride_li, b_stride_li): + assert a_pair is not None and b_pair is not None + _, a_idx = a_pair + _, b_idx = b_pair + + if a_idx != b_idx: + raise RuntimeError( + f"The sorted order of strides of the two branches' output doesn't match." + f"this indicates the contiguousness of the two branches are different. " + f"True branch has stride {a_ex_stride} but false branch has stride {b_ex_stride}." + f"Consider using contiguous() to make the two branches have the same contiguousness." + ) + + def _maybe_expr(s: Union[int, torch.SymInt]): + if isinstance(s, int): + return s + return s.node.expr + + a_stride_expr: dict[Any, Union[int, torch.SymInt]] = {} + b_stride_expr: dict[Any, Union[int, torch.SymInt]] = {} + merged_strides: list[Union[int, torch.SymInt]] = [None] * len(a_ex_stride) # type: ignore[list-item] + for a_pair, b_pair in zip(a_stride_li, b_stride_li): + assert a_pair is not None and b_pair is not None + a_val, neg_i = a_pair + b_val, _ = b_pair + + i = -neg_i + if a_val == 0: + assert b_val == 0, (a_val, b_val) + merged_strides[i] = 0 + continue + + if _maybe_expr(a_val) in a_stride_expr: + a_expr = a_stride_expr[_maybe_expr(a_val)] + assert ( + b_stride_expr[_maybe_expr(b_val)] == a_expr + ), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}" + merged_strides[i] = a_expr + else: + if a_val == 1: + assert b_val == 1 + a_stride_expr[_maybe_expr(a_val)] = 1 + b_stride_expr[_maybe_expr(b_val)] = 1 + merged_strides[i] = 1 + else: + # If we cannot find the expr of a_val in a_stride_expr, it means + # the strides is not a simple accumulative multiplication of sizes. + # In this case, we cannot determine the expr of strides from the new + # shapes so we error out and hint users to call contiguous(). + raise RuntimeError( + f"It seems one of cond's output stride is not a simple accumulative multiplication of sizes. " + f"This could be because cond returns a slice of a tensor, which is not dense in memory. " + f"True branch has size {a_ex_size}, stride {a_ex_stride} and false branch has size {b_ex_size} " + f"stride {b_ex_stride}. Hint: can call t.contiguous(). 
" + ) + nxt_merged_stride_expr = merged_strides[i] * merged_size[i] + a_stride_expr[_maybe_expr(a_val * a_ex_size[i])] = nxt_merged_stride_expr + b_stride_expr[_maybe_expr(b_val * b_ex_size[i])] = nxt_merged_stride_expr + return merged_strides + + merged_stride: list[Union[int, torch.SymInt]] = _bound_stride( + a.size(), b.size(), a.stride(), b.stride(), merged_size + ) + + with mode: + return torch.empty_strided( + merged_size, merged_stride, dtype=a.dtype, device=a.device + ) + + +@cond_op.py_functionalize_impl +def cond_func(ctx, pred, true_fn, false_fn, inputs): + from torch._higher_order_ops.utils import _check_alias_and_mutation + + unwrapped_inputs = ctx.unwrap_tensors(inputs) + unwrapped_pred = ctx.unwrap_tensors(pred) + with ctx.redispatch_to_next(): + functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn)) + functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn)) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for branch, branch_name in [(true_fn, "cond_true"), (false_fn, "cond_false")]: + _check_alias_and_mutation( + branch, unwrapped_inputs, branch_name, pre_dispatch + ) + + cond_return = cond_op( + unwrapped_pred, functional_true, functional_false, unwrapped_inputs + ) + return ctx.wrap_tensors(cond_return) + + +@cond_op.py_impl(torch._C._functorch.TransformType.Vmap) +def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): + assert isinstance( + inputs, (list, tuple) + ), "Cond inputs must be a list or tuple of tensors" + assert all( + isinstance(i, torch.Tensor) for i in inputs + ), "Cond inputs must be a list of tensors" + + pred_is_batched = isinstance(pred, torch.Tensor) and is_batchedtensor(pred) + pred_ = get_unwrapped(pred) if pred_is_batched else pred + + # unbatched tensors are not vmapped + tensors, in_dims = zip( + *[ + (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None) + for t in inputs + ] + ) + + if pred_is_batched: + # prepend "pred" and vmap everything + tensors = (pred_,) + tensors + in_dims = (0,) + in_dims + + def fn(p, *args): + t = true_fn(*args) + f = false_fn(*args) + return torch.where(p, t[0], f[0]) + + with interpreter.lower(): + result = torch.vmap(fn, in_dims=in_dims)(*tensors) + + else: + # predicate is known at this stage and it is a boolean expression or a + # tensor with one element. 
+ true_fn = torch.vmap(true_fn, in_dims=in_dims) + false_fn = torch.vmap(false_fn, in_dims=in_dims) + + with interpreter.lower(): + result = cond_op(pred, true_fn, false_fn, tensors) + + if not isinstance(result, tuple): + result = (result,) + lvl = interpreter.level() + return tuple([_add_batch_dim(r, 0, lvl) for r in result]) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/effects.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..c0eed6ef73abee73c71002d3c434685d27242609 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/effects.py @@ -0,0 +1,301 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import Any, Optional, Union +from weakref import WeakKeyDictionary + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.torchbind import call_torchbind +from torch._library.fake_class_registry import FakeScriptObject +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class _EffectType(Enum): + ORDERED = "Ordered" + + +OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] + + +SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType]( + [ + (torch.ops.aten._print.default, _EffectType.ORDERED), + (call_torchbind, _EffectType.ORDERED), + ] +) + + +def _register_effectful_op(op: OpType, effect: _EffectType): + assert isinstance( + op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) and not has_aliasing(op) + if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect: + raise RuntimeError( + f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, " + f"trying to register a different effect type {effect}." + ) + SIDE_EFFECTS[op] = effect + + +def _deregister_effectful_op(op: OpType): + if op not in SIDE_EFFECTS: + raise RuntimeError(f"Op {op} is not registered as effectful") + + del SIDE_EFFECTS[op] + + +class WithEffects(HigherOrderOperator): + """ + with_effects(token, op, args, kwargs) -> (new_token, op_results) + + This HOP helps ensure ordering between side effectful ops like prints or ops + using torchbind objects. This is needed to ensure a traced graph from + AOTAutograd is functional so that future optimization passes do not reorder + these operators. This is done through threading "effect tokens" through the + graph to enforce data dependence between side effectful ops. + + The tokens are basically dummy values (torch.tensor([])). We create a token + per "effect type", which are enumerated in the _EffectType enum. 
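+
+    A minimal eager sketch of the threading (illustrative values):
+
+        tok0 = new_token_tensor()
+        tok1, _ = with_effects(tok0, torch.ops.aten._print.default, "hello")
+        tok2, _ = with_effects(tok1, torch.ops.aten._print.default, "world")
+
+    Because the second call consumes the token produced by the first, a functional graph
+    cannot reorder the two prints.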
+ """ + + def __init__(self) -> None: + super().__init__("with_effects") + + def __call__( + self, + token, + op: OpType, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ) -> tuple[Any, ...]: + assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) + assert not has_aliasing(op), "Ops with aliasing is not supported" + assert has_effects(op, args, kwargs) + assert isinstance(kwargs, dict) + return super().__call__(token, op, *args, **kwargs) + + +with_effects = WithEffects() + + +def has_aliasing(op: OpType): + # NOT FOR PUBLIC USE + if isinstance(op, torch._ops.HigherOrderOperator): + return op not in SIDE_EFFECTS + + for arg in op._schema.arguments: + if arg.alias_info is not None: + return True + for arg in op._schema.returns: + if arg.alias_info is not None: + return True + return False + + +def has_effects(op, args, kwargs) -> bool: + # Skip over the profiler's RecordFunction as they should not show up in the graph + _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction} + if op in _skip_ops: + return False + + return ( + isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) + and not has_aliasing(op) + and get_effect_key(op, args, kwargs) is not None + ) + + +def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: + if op in SIDE_EFFECTS: + return SIDE_EFFECTS[op] + + for arg in args: + if isinstance(arg, (torch.ScriptObject, FakeScriptObject)): + # Add it to the table so that next time we see the same op we don't + # have to parse through the args again + SIDE_EFFECTS[op] = _EffectType.ORDERED + return _EffectType.ORDERED + + for arg in kwargs.values(): + if isinstance(arg, (torch.ScriptObject, FakeScriptObject)): + # Add it to the table so that next time we see the same op we don't + # have to parse through the args again + SIDE_EFFECTS[op] = _EffectType.ORDERED + return _EffectType.ORDERED + + return None + + +def new_token_tensor() -> torch.Tensor: + return torch.tensor([]) + + +@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) +def with_effects_dense( + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], +) -> tuple[torch.Tensor, ...]: + out = op(*args, **kwargs) + new_token = new_token_tensor() + # [NOTE: with_effects return type] + # Note that we should only do *out for tuple type, but not list type. + # This is to match the schema of the op. + # For tuple output, the length of schema output is the same as the length of out. + # For list output, the length of schema output is 1 (e.g. Tensor[]) regardless of the + # length of the list. 
+ if isinstance(out, tuple): + return (new_token, *out) + return (new_token, out) + + +@with_effects.py_impl(FakeTensorMode) +def with_effects_fake( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], +) -> tuple[torch.Tensor, ...]: + with mode: + result = with_effects_dense(token, op, *args, **kwargs) + return result + + +@with_effects.py_impl(ProxyTorchDispatchMode) +def with_effects_proxy( + mode, + token: torch.Tensor, + op: torch._ops.OpOverload, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], +) -> tuple[torch.Tensor, ...]: + with disable_proxy_modes_tracing(): + out = with_effects(token, op, *args, **kwargs) + + proxy_token = mode.tracer.unwrap_proxy(token) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + from torch.fx.node import has_side_effect + + # To avoid the being DCEed by graph.eliminate_dead_code if they. + # don't have output or their outputs are not used. + has_side_effect(op) + + out_proxy = mode.tracer.create_proxy( + "call_function", + with_effects, + (proxy_token, op, *proxy_args), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +with_effects.fallthrough(DispatchKey.AutogradCPU) +with_effects.fallthrough(DispatchKey.AutogradCUDA) + + +def _get_schema(op, args) -> torch.FunctionSchema: + if isinstance(op, torch._ops.OpOverload): + return op._schema + elif op == call_torchbind: + return getattr(args[0], args[1]).schema + else: + raise RuntimeError(f"Unable to get schema for op {op}") + + +def handle_effects( + allow_token_discovery: bool, + tokens: dict[_EffectType, torch.Tensor], + op: OpType, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """ + Args: + allow_token_discovery: Whether or not we are discovering tokens. If this + is true, we will create a token for every side effect type seen that + does not have a token assigned yet. If this is false, the tokens + should've all been created ahead of time, so we will error if there is + no token mapping to every effect type. + + tokens: Map of effect type to tokens. This is to chain operators of the + same effects together so that they do not get reordered in later + optimization passes. + """ + + # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because + # this will create an empty tensor during proxy mode tracing if the token + # doesn't exist. But the tokens should always exist during proxy mode tracing. + key = get_effect_key(op, args, kwargs) + assert key is not None + if key not in tokens: + assert ( + allow_token_discovery + ), f"Could not find a token for effect {key} which came from the function {op}" + proxy_tensor_mode = torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.PROXY + ) + if proxy_tensor_mode is not None: + # If we discovered a new token during tracing, we are in backward. + # Then we patch the graph, adding additional tangents_token as input to the joint graph. 
+ tracer = proxy_tensor_mode.tracer + + from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + track_tensor_tree, + ) + + with disable_proxy_modes_tracing(): + token_tensor = new_token_tensor() + + token_proxy = proxy_tensor_mode.tracer.create_proxy( + "placeholder", "tangents_token", (), {}, name="tangents_token" + ) + track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer) + + tokens[key] = token_tensor + else: + tokens[key] = new_token_tensor() + + token = tokens[key] + + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + unwrapped_token = ctx.unwrap_tensors([token])[0] + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + (new_token, *unwrapped_outs) = with_effects( + unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs + ) + + schema = _get_schema(op, unwrapped_args) + if len(schema.returns) == 0: + assert unwrapped_outs[0] is None + unwrapped_outs = None # type: ignore[assignment] + elif len(schema.returns) == 1: + assert len(unwrapped_outs) == 1 + unwrapped_outs = unwrapped_outs[0] + else: + assert len(unwrapped_outs) == len(schema.returns) + + # Add the newly created token into the tokens map for a following call to + # use this token. + wrapped_token = ctx.wrap_tensors(new_token) + assert isinstance(wrapped_token, torch.Tensor) + tokens[key] = wrapped_token + + return ctx.wrap_tensors(unwrapped_outs) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py new file mode 100644 index 0000000000000000000000000000000000000000..807e603bc93b96b4f82bac2889d9316bb0e67936 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py @@ -0,0 +1,175 @@ +# mypy: allow-untyped-defs + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from __future__ import annotations + +from typing import Any, cast + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_slot, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._pytree import tree_flatten + + +class ExecutorchCallDelegate(HigherOrderOperator): + def __init__(self): + super().__init__("executorch_call_delegate") + + def __call__(self, lowered_module, *args): + return super().__call__(lowered_module, *args) + + +executorch_call_delegate = ExecutorchCallDelegate() +executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) +executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) + +LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" + + +# pyre-ignore +def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): + # pyre-ignore + def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e + return get_proxy_slot( + cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined] + ) + + if not is_lowered_module(lowered_module): + raise ValueError( + "executorch_call_delegate()'s first argument must be a LoweredBackendModule" + ) + + with disable_proxy_modes_tracing(): + out = call_delegate_cpu(lowered_module, *args) + + get_lowered_module_name(proxy_mode.tracer.root, lowered_module) + + node_args = (lowered_module, *args) + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="executorch_call_delegate" + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +# pyre-ignore +def call_delegate_cpu(lowered_module, *args): + # FX creates this immutable_dict/list concept. Get rid of this. + map_types: dict[type, type] = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_args = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + args, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + return lowered_module.original_module.module()(*new_args) + + +@executorch_call_delegate.py_autograd_impl +# pyre-ignore +def call_delegate_autograd(lowered_module, *args): + # TODO: support autograd + flat_operands, _ = tree_flatten([lowered_module, *args]) + requires_grad = any( + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) + ) + + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) + ): + res = executorch_call_delegate(lowered_module, *args) + + if requires_grad: + # Create aliases of the output that has requires_grad=True. We need + # at least one of the inputs to err_fn to require grad so that the + # output will have a grad_fn. 
+ + # pyre-ignore + def fake_requires_grad(var): + if var is not None: + var = var.detach() + if torch.is_floating_point(var) or torch.is_complex(var): + var.requires_grad = True + return var + + return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) + + return res + + +@executorch_call_delegate.py_impl(ProxyTorchDispatchMode) +# pyre-ignore +def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): + res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) + return res + + +@executorch_call_delegate.py_impl(FakeTensorMode) +# pyre-ignore +def call_delegate_fake_tensor_mode(mode, lowered_module, *args): + with mode: + return call_delegate_cpu(lowered_module, *args) + + +@executorch_call_delegate.py_functionalize_impl +# pyre-ignore +def call_delegate_functionalize(ctx, lowered_module, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + with ctx.redispatch_to_next(): + res = executorch_call_delegate(lowered_module, *unwrapped_args) + return ctx.wrap_tensors(res) + + +# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre +def is_lowered_module(obj: Any) -> bool: + """ + This function is added to avoid using isinstance(obj, + LoweredBackendModule) as it will import LoweredBackendModule, which may + cause a circular import. + """ + return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE + + +def get_lowered_module_name( + root: torch.nn.Module, + # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. + lowered_module: LOWERED_BACKEND_MODULE_TYPE, # type: ignore[valid-type] +) -> str: + """ + Adds the given lowered_module into the given root module and returns the + name of the module added. + """ + # Find a qualifying name for the lowered submodule + qualname = None + i = 0 + while True: + qualname = f"lowered_module_{i}" + if not hasattr(root, qualname): + break + i += 1 + assert qualname is not None + + root.add_module(qualname, lowered_module) + return qualname diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..044580026ac8e4704d58250e8bb9118205a39f12 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py @@ -0,0 +1,125 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Callable + +import torch +import torch.fx.node +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator + + +def is_graphable(val) -> bool: + """Definition: a graphable type is a type that that is an acceptable input/output type to a FX node.""" + return isinstance(val, torch.fx.node.base_types) + + +def is_graphable_type(typ) -> bool: + """Return whether the given type is graphable""" + return issubclass(typ, torch.fx.node.base_types) + + +def to_graphable(stuff): + """Flattens stuff into a flat list of graphable types.""" + # We can consider preserving things like List[int] to improve + # perf and readability (right now that is all flattened out) + flat_args, spec = pytree.tree_flatten(stuff) + for arg in flat_args: + if not is_graphable(arg): + raise RuntimeError( + f"Expected all pytree.tree_leaves of (args, kwargs) to be graphable types, but found " + f"non-fx-graphable type {type(arg)}. 
If this type is meant to be constant, mark it as " + f"via pytree.register_constant; otherwise, register it as a pytree." + ) + return flat_args, spec + + +def from_graphable(flat_args, spec): + """The inverse of to_graphable.""" + stuff = pytree.tree_unflatten(flat_args, spec) + return stuff + + +def func_to_graphable(func): + """ + Pack and flatten a function type into graphable types. + This is useful for legalizing the function argument of `flat_apply`. + """ + return pytree.tree_flatten(_ConstantFunction(func)) + + +@dataclass(frozen=True) +class _ConstantFunction: + func: Callable + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +pytree.register_constant(_ConstantFunction) + +_op_types = ( + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + torch._ops.HigherOrderOperator, +) + + +class FlatApply(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("flat_apply") + + def __call__(self, func, in_spec, *flat_args, **_unused): + """ + Functions that take in non-graphable types cannot directly be put into FX graph. + + Given func(*args, **kwargs), if all of the non-graphable types are pytrees, + then we're able to store a call to flat_apply(func, in_spec, *flat_args) in the FX graph. + + The semantics of flat_apply(func, in_spec, *flat_args) are roughly equivalent to: + + >>> def flat_apply_impl(func, in_spec, *flat_args): + >>> args, kwargs = pytree.tree_unflatten(flat_args, in_spec) + >>> output = func(*args, **kwargs) + >>> return output + + flat_apply supports the following two cases: + - an input type is a container type (e.g. of tensors) registered as a pytree. + We'll tree_flatten the input type and store the spec. + - an input type is a constant type (i.e. torch.compile will specialize on it) + registered with pytree.register_constant. The constant type goes directly + into the spec. + + """ + assert isinstance(func, _op_types) or pytree._is_constant_holder(func) + assert len(_unused) == 0 + return impl(func, in_spec, *flat_args) + + +def impl(func, in_spec, *flat_args): + if not isinstance(func, _op_types): + # assume _ConstantFunction + func = pytree._retrieve_constant(func) + assert isinstance(func, _ConstantFunction) + + args, kwargs = from_graphable(flat_args, in_spec) + out = func(*args, **kwargs) + + # Right now, all outputs must either be graphable or lists/tuples of graphables. + # + # TODO: The following can be updated to support non-graphable outputs and pytrees. + # For non-graphable constant outputs: the assumption would be that they are constant + # (everytime the function runs those MUST be the same) + # For pytree outputs: + # I'm not sure if we need to return (flat_output, spec) or just (flat_output,): + # in the latter case the tracers need to carry out the output specs + # (they need to know how to reconstruct the object from just the flat_output). 
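+    # e.g. (illustrative) a Tensor, an int/SymInt, or a tuple/list of those passes the check
+    # below; returning a dict or an unregistered custom object would trip the assert.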
+ def is_valid_output(x): + if isinstance(x, (tuple, list)): + return all(map(is_valid_output, x)) + return is_graphable(x) + + assert is_valid_output(out) + return out + + +flat_apply = FlatApply() diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f273c763c21e3f96a07a94eeff8205c9fced1bf8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py @@ -0,0 +1,1268 @@ +import math +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch import Tensor +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_mutation, + _maybe_reenter_make_fx, + autograd_not_implemented, + has_user_subclass, + redirect_to_mode, + reenter_make_fx, + register_fake, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, + UnsupportedAliasMutationException, + validate_subgraph_args_types, +) +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx.experimental.proxy_tensor import ( + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.graph_module import GraphModule +from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode + + +# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import +def _construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor: + """ + Create a new tensor with the same data and shape as the input, + but with strides permuted based on the input tensor's stride order. + + Args: + out (torch.Tensor): The output tensor of attention. + query_strides (List[int]): The stride order of the input query tensor + + Returns: + torch.Tensor: A new tensor with same shape and data as the input, + but with strides permuted based on the query tensor's stride order. 
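+
+    Illustrative example (hypothetical shapes): for an ``out`` of shape (2, 3, 4), a query whose
+    strides have fill order (2, 1, 0) yields strides (12, 4, 1) (contiguous), whereas fill order
+    (1, 2, 0) yields strides (12, 1, 3).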
+ """ + from torch._inductor.ir import get_fill_order + + fill_order = get_fill_order(query_strides) + assert out.storage_offset() == 0, "Only support storage_offset == 0" + out_strides = _construct_strides(out.shape, fill_order) + new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) + new_out.copy_(out) + return new_out + + +class FlexAttentionHOP(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("flex_attention", cacheable=True) + + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), + ) -> tuple[torch.Tensor, torch.Tensor]: + validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers) + return super().__call__( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +flex_attention = FlexAttentionHOP() + + +class FlexAttentionBackwardHOP(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("flex_attention_backward") + + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] + ]: + validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers) + return super().__call__( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +flex_attention_backward = FlexAttentionBackwardHOP() + + +def _math_attention_inner( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32 + + scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision) + + b = torch.arange(0, scores.size(0), device=scores.device) + h = torch.arange(0, scores.size(1), device=scores.device) + m = torch.arange(0, scores.size(2), device=scores.device) + n = torch.arange(0, scores.size(3), device=scores.device) + + captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) + from torch.nn.attention.flex_attention import _vmap_for_bhqkv + + # first input is score + score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim) + + mask_mod = block_mask[-1] + mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers) + mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers) + + with TransformGetItemToIndex(): + scores = (scores * scale).to(working_precision) + post_mod_scores = torch.where( + mask_mod(b, h, m, n, *mask_mod_other_buffers), + score_mod(scores, b, h, m, n, *score_mod_other_buffers), + 
torch.tensor(-float("inf"), dtype=working_precision, device=scores.device), + ) + + return scores, post_mod_scores + + +def math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + """Eager implementation + + This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions. + We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the + batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorized over all 4 dimensions. + + Args: + query: The query tensor + key: The key tensor + value: The value tensor + score_mod: The score_mod function + other_buffers: Other buffers that are passed to the score_mod function + """ + # broadcast query & key along head dim for GQA + G = query.size(1) // key.size(1) + value = torch.repeat_interleave(value, G, dim=1) + key = torch.repeat_interleave(key, G, dim=1) + + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + + _, post_mod_scores = _math_attention_inner( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + # Set fully masked rows' sumexp to 0.0 + logsumexp = post_mod_scores.logsumexp(dim=-1) + masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1) + logsumexp = torch.where(masked_rows, -float("inf"), logsumexp) + + post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1) + + return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2) + + +@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd) +def sdpa_dense( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + out, lse = math_attention( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + out = _permute_strides(out, query.stride()) + return out, lse + + +def trace_flex_attention( + proxy_mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + """Traces the flex_attention operator with the given score_mod function and other_buffers. + + Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function + This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We + access this graph module in inductor to inline the score_mod function to the triton template. 
+ """ + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + example_out = flex_attention( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [ + query.new_zeros((), dtype=torch.int) for _ in range(4) + ] + mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)] + mask_mod = block_mask[-1] + with TransformGetItemToIndex(): + score_graph = reenter_make_fx(score_mod)( + *example_vals, *score_mod_other_buffers + ) + mask_graph = reenter_make_fx(mask_mod)( + *mask_example_vals, *mask_mod_other_buffers + ) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + block_mask = block_mask[:-1] + (mask_graph,) + qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score") + proxy_mode.tracer.root.register_module(qualname, score_graph) + mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask") + proxy_mode.tracer.root.register_module(mask_qualname, mask_graph) + node_args = ( + query, + key, + value, + score_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", flex_attention, proxy_args, {} + ) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@flex_attention.py_impl(ProxyTorchDispatchMode) +def flex_attention_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + assert mode is not None, "Mode should always be enabled for python fallback key" + return trace_flex_attention( + mode, + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +@flex_attention.py_functionalize_impl +def flex_attention_functionalize( + ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + """Defines the functionalization rules for the flex_attention operator. + + Write now we are unwrapping each tensor and then redispatching to the next, however we want to + guard against any mutations in the score_mod function, to the other_buffers since those + are free variables. 
+ """ + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + if has_user_subclass( + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor, FunctionalTensor), + ): + return NotImplemented + + query_unwrapped = ctx.unwrap_tensors(query) + key_unwrapped = ctx.unwrap_tensors(key) + value_unwrapped = ctx.unwrap_tensors(value) + block_mask_unwrapped = ctx.unwrap_tensors(block_mask) + score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers) + mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers) + + # Appease the mypy overlords + assert isinstance(query_unwrapped, torch.Tensor) + assert isinstance(key_unwrapped, torch.Tensor) + assert isinstance(value_unwrapped, torch.Tensor) + assert isinstance(block_mask_unwrapped, tuple) + assert isinstance(score_mod_other_buffers_unwrapped, tuple) + assert isinstance(mask_mod_other_buffers_unwrapped, tuple) + + example_vals = ( + [query_unwrapped.new_zeros(())] + + [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)] + + list(score_mod_other_buffers_unwrapped) + ) + with ctx.redispatch_to_next(): + functional_score_mod = ctx.functionalize(score_mod) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + with TransformGetItemToIndex(): + # TODO: So far only the input mutations are checked + # In the other HOPs, also aliases are checked which is + # omitted here + mutates = _has_potential_branch_input_mutation( + score_mod, example_vals, pre_dispatch + ) + # The only care about mutations of existing buffers since we can't replay these. + # However, we can just error if anything is detected + if mutates: + raise UnsupportedAliasMutationException("Mutations detected in score_mod") + + out = flex_attention( + query_unwrapped, + key_unwrapped, + value_unwrapped, + functional_score_mod, + block_mask_unwrapped, + scale, + kernel_options, + score_mod_other_buffers_unwrapped, + mask_mod_other_buffers_unwrapped, + ) + return ctx.wrap_tensors(out) # type: ignore[return-value, arg-type] + + +@register_fake(flex_attention) +def flex_attention_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[torch.Tensor, torch.Tensor]: + if has_user_subclass( + ( + query, + key, + value, + score_mod, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor,), + ): + return NotImplemented + + # TODO: Figure out a better way to handle this for NJT than using sum() + if query.is_nested: + out = torch.empty_like(query, memory_format=torch.contiguous_format) + logsumexp = query.sum(dim=-1) + return out, logsumexp + + v_head_dim = value.size(-1) + batch_size, num_heads, seq_len_q, _q_head_dim = query.shape + logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32) + out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) + out = query.new_empty(out_shape) + out = _permute_strides(out, query.stride()) + return out, logsumexp + + +# Registers dispatches for SAC +redirect_to_mode(flex_attention, _CachingTorchDispatchMode) +redirect_to_mode(flex_attention, _CachedTorchDispatchMode) + + +# ---------------------------- Autograd Implementation ---------------------------- 
+def create_fw_bw_graph( + score_mod: Callable, + index_values: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], + other_buffers: tuple[Tensor, ...], +) -> tuple[Callable, Callable]: + # See Note:[HOP create fw_bw graph] + + # All of these imports need to be here in order to avoid circular dependencies + from torch._dispatch.python import suspend_functionalization + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + from torch._subclasses.functional_tensor import disable_functional_mode + from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + + def _from_fun( + t: Union[Tensor, torch.SymInt, int], + ) -> Union[Tensor, torch.SymInt, int]: + if isinstance(t, torch.Tensor): + return torch.empty_strided( + t.size(), + t.stride(), + device=t.device, + dtype=t.dtype, + requires_grad=t.requires_grad, + ) + return t + + # If someone runs this hop under the default compiler backend ("eager") + # Then this path will be run with the actual user inputs. We convert them + # to fake tensors in order to not perform any actual compute. + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(index_values) + if fake_mode is None: + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + + with fake_mode: + unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values) + unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers) + + assert all( + isinstance(t, (FakeTensor, int, torch.SymInt)) + for t in unwrapped_score_mod_indexes + unwrapped_other_buffers + ) + + example_flat_out = pytree.tree_map( + _from_fun, + score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers), + ) + if not isinstance(example_flat_out, torch.Tensor): + raise RuntimeError( + "Expected output of score_mod to be a tensor." + f"Got type {type(example_flat_out)}." + ) + example_grad = _from_fun(example_flat_out) + + def joint_f( + score: Tensor, + b: Tensor, + h: Tensor, + m: Tensor, + n: Tensor, + example_grad: Tensor, + *other_buffers: tuple[Tensor, ...], + ) -> tuple[Tensor, ...]: + def fw_with_masks( + *args: tuple[Tensor, ...] 
+ ) -> tuple[tuple[Tensor], tuple[bool]]: + fw_out = score_mod(*args) + out_requires_grad = fw_out.requires_grad + return ((fw_out,), (out_requires_grad,)) + + joint = create_joint(fw_with_masks, aot_config=dummy_aot_config) + args = [score, b, h, m, n] + list(other_buffers) + optional_grad = [example_grad] if example_grad.requires_grad else [] + _, grads = joint(args, optional_grad) + + return grads + + joint_graph = make_fx(joint_f)( + *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers + ) + return score_mod, joint_graph + + +class FlexAttentionAutogradOp(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + query: Tensor, + key: Tensor, + value: Tensor, + fw_graph: Callable, + joint_graph: Callable, + block_mask: tuple[Any, ...], + scale: float, + kernel_options: dict[str, Any], + mask_mod_other_buffers: tuple[Any, ...], + *score_mod_other_buffers: tuple[Any, ...], + ) -> tuple[torch.Tensor, torch.Tensor]: + any_buffer_requires_grad = any( + buffer.requires_grad + for buffer in mask_mod_other_buffers + if isinstance(buffer, torch.Tensor) + ) + assert ( + not any_buffer_requires_grad + ), "Captured buffers from mask mod that require grad are not supported." + ctx._fw_graph = fw_graph + ctx._joint_graph = joint_graph + ctx._mask_graph = block_mask[-1] + ctx.scale = scale + ctx.kernel_options = kernel_options + ctx._score_mod_other_buffers_len = len(score_mod_other_buffers) + with torch._C._AutoDispatchBelowAutograd(): + out, logsumexp = flex_attention( + query, + key, + value, + fw_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + save_tensors_and_symints_for_backward( + ctx, + ( + query, + key, + value, + out, + logsumexp, + *block_mask[:-1], + *score_mod_other_buffers, + *mask_mod_other_buffers, + ), + ) + return out, logsumexp + + @staticmethod + def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override] + fw_args = saved_tensors_and_symints(ctx) + ( + query, + key, + value, + out, + logsumexp, + query_lengths, + kv_lengths, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + *other_buffers, + ) = fw_args + fw_graph = ctx._fw_graph + joint_graph = ctx._joint_graph + mask_graph = ctx._mask_graph + scale = ctx.scale + kernel_options = ctx.kernel_options + score_mod_other_buffers = tuple( + other_buffers[: ctx._score_mod_other_buffers_len] + ) + mask_mod_other_buffers = tuple( + other_buffers[ctx._score_mod_other_buffers_len :] + ) + # We have asserted that mask_mod_other_buffers do not require grad, + # but score_mod_other_buffers can require grad. 
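+ # The six Nones returned below line up with the non-differentiable forward
+ # inputs (fw_graph, joint_graph, block_mask, scale, kernel_options,
+ # mask_mod_other_buffers); real gradients are produced only for query, key,
+ # value and the captured score_mod buffers.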
+ none_grads = [None] * 6 + ( + grad_query, + grad_key, + grad_value, + grad_score_mod_captured, + ) = flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + ( + query_lengths, + kv_lengths, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + mask_graph, + ), + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured + + +# TODO: Rework DispatchKey.Autograd to py_autograd_impl +@flex_attention.py_impl(DispatchKey.Autograd) +def flex_attention_autograd( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + score_mod: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple[Tensor, ...] = (), + mask_mod_other_buffers: tuple[Tensor, ...] = (), +) -> tuple[torch.Tensor, torch.Tensor]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + with TransformGetItemToIndex(): + input_requires_grad = any( + isinstance(t, torch.Tensor) and t.requires_grad + for t in (query, key, value, *score_mod_other_buffers) + ) + if torch.is_grad_enabled() and input_requires_grad: + example_vals = ( + query.new_zeros((), requires_grad=input_requires_grad), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + ) + fw_graph, bw_graph = create_fw_bw_graph( + score_mod, example_vals, score_mod_other_buffers + ) + else: + fw_graph, bw_graph = score_mod, None + out, logsumexp = FlexAttentionAutogradOp.apply( + query, + key, + value, + fw_graph, + bw_graph, + block_mask, + scale, + kernel_options, + mask_mod_other_buffers, + *score_mod_other_buffers, + ) + return out, logsumexp + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd) +def sdpa_dense_backward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Callable, # GraphModule type hint? + joint_graph: Callable, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple, + mask_mod_other_buffers: tuple, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] 
+]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + Bq, Hq, seq_len_q, qk_head_dim = query.shape + Bkv, Hkv, seq_len_kv, v_head_dim = value.shape + + # Get outputs before calling repeat interleave and permute to input stride orders + actual_grad_query = query.new_empty((Bq, Hq, seq_len_q, qk_head_dim)) + actual_grad_query = _permute_strides(actual_grad_query, query.stride()) + + actual_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim)) + actual_grad_key = _permute_strides(actual_grad_key, key.stride()) + + actual_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim)) + actual_grad_value = _permute_strides(actual_grad_value, value.stride()) + + def _maybe_new_buffer( + buffer: Union[torch.Tensor, torch.SymInt, int], + ) -> Optional[Union[torch.Tensor, torch.SymInt, int]]: + if isinstance(buffer, torch.Tensor): + return ( + torch.empty_like(buffer, memory_format=torch.contiguous_format) + if buffer.requires_grad + else None + ) + return buffer + + actual_grad_score_mod_captured = [ + _maybe_new_buffer(buffer) for buffer in score_mod_other_buffers + ] + + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + + G = query.size(1) // key.size(1) + key = torch.repeat_interleave(key, G, dim=1) + value = torch.repeat_interleave(value, G, dim=1) + + # We're undoing the log -> log2 change of base in the forwards + logsumexp = logsumexp * math.log(2) + # The backwards formula for the log -> log2 change of base in the forwards + grad_logsumexp = grad_logsumexp / math.log(2) + scores, post_mod_scores = _math_attention_inner( + query, + key, + value, + fw_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + masked_out_rows = logsumexp == -float("inf") + softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) + softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores) + + grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out + + grad_softmax_scores = grad_out @ value.transpose(-2, -1) + + sum_scores = torch.sum(out * grad_out, -1, keepdim=True) + grad_score_mod = softmax_scores * ( + grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1) + ) + + b = torch.arange(0, scores.size(0), device=scores.device) + h = torch.arange(0, scores.size(1), device=scores.device) + m = torch.arange(0, scores.size(2), device=scores.device) + n = torch.arange(0, scores.size(3), device=scores.device) + + mask_graph = block_mask[-1] + # Gradient of the inline score_mod function, with respect to the scores + captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) + out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers) + from torch.nn.attention.flex_attention import _vmap_for_bhqkv + + # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...] 
+ # score and gradOut are "fully" batched + joint_score_mod = _vmap_for_bhqkv( + joint_graph, + prefix=(0,), + suffix=(0,) + captured_buffers_in_dim, + out_dims=out_dims, + ) + with TransformGetItemToIndex(): + grad_scores, _, _, _, _, *grad_score_mod_captured = joint_score_mod( + scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers + ) + grad_scores = grad_scores * scale + grad_scores = grad_scores.to(query.dtype) + + mask_mod = _vmap_for_bhqkv( + mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers) + ) + with TransformGetItemToIndex(): + mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers) + grad_scores = torch.where( + mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype) + ) + + grad_query = grad_scores @ key + grad_key = grad_scores.transpose(-2, -1) @ query + + # Reduce DK, DV along broadcasted heads. + grad_key = grad_key.view( + grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1) + ) + grad_value = grad_value.view( + grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1) + ) + + grad_key = torch.sum(grad_key, 2, keepdim=False) + grad_value = torch.sum(grad_value, 2, keepdim=False) + + # Fill to correctly strided outputs + actual_grad_query.copy_(grad_query) + actual_grad_key.copy_(grad_key) + actual_grad_value.copy_(grad_value) + + if Bq != Bkv: + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + + actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True) + actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True) + + score_mod_other_buffer_grads = [ + actual_grad.copy_(grad) if isinstance(actual_grad, torch.Tensor) else None + for actual_grad, grad in zip( + actual_grad_score_mod_captured, grad_score_mod_captured + ) + ] + + return ( + actual_grad_query, + actual_grad_key, + actual_grad_value, + tuple(score_mod_other_buffer_grads), + ) + + +def trace_flex_attention_backward( + proxy_mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] 
+]: + """We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs""" + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + + example_out = flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + requires_grad = any(pytree.tree_map(lambda x: x.requires_grad, (query, key))) + fw_example_vals = [query.new_zeros((), requires_grad=requires_grad)] + [ + query.new_zeros((), dtype=torch.int) for _ in range(4) + ] + bw_example_vals = fw_example_vals + [query.new_zeros(())] + mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)] + mask_graph = block_mask[-1] + with TransformGetItemToIndex(): + # There's no active make_fx during the compiled autograd graph's initial capture + fw_graph = _maybe_reenter_make_fx(fw_graph)( + *fw_example_vals, *score_mod_other_buffers + ) + joint_graph = _maybe_reenter_make_fx(joint_graph)( + *bw_example_vals, *score_mod_other_buffers + ) + mask_graph = _maybe_reenter_make_fx(mask_graph)( + *mask_example_vals, *mask_mod_other_buffers + ) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + block_mask = block_mask[:-1] + (mask_graph,) + + qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph") + proxy_mode.tracer.root.register_module(qualname, fw_graph) # type: ignore[arg-type] + qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph") + proxy_mode.tracer.root.register_module(qualname, joint_graph) + qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph") + proxy_mode.tracer.root.register_module(qualname, mask_graph) + + node_args = ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", + flex_attention_backward, + proxy_args, + {}, + name="flex_attention_backward", + ) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@flex_attention_backward.py_impl(ProxyTorchDispatchMode) +def flex_attention_backward_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] 
+]: + assert mode is not None, "Mode should always be enabled for python fallback key" + return trace_flex_attention_backward( + mode, + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + +@flex_attention_backward.py_functionalize_impl +def flex_attention_backward_functionalize( + ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] +]: + """Defines the functionalization rules for the flex_attention operator. + + Write now we are unwrapping each tensor and then redispatching to the next, + since we know that the forward score mod function is assured to be free of mutations + to the other_buffers, we skip that mutate check and go straight to redispatching. + """ + + if has_user_subclass( + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor, FunctionalTensor), + ): + return NotImplemented + query_unwrapped = ctx.unwrap_tensors(query) + key_unwrapped = ctx.unwrap_tensors(key) + value_unwrapped = ctx.unwrap_tensors(value) + out_unwrapped = ctx.unwrap_tensors(out) + logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp) + grad_out_unwrapped = ctx.unwrap_tensors(grad_out) + grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp) + block_mask_unwrapped = ctx.unwrap_tensors(block_mask) + score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers) + mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers) + + # Appease the mypy overlords + assert isinstance(query_unwrapped, torch.Tensor) + assert isinstance(key_unwrapped, torch.Tensor) + assert isinstance(value_unwrapped, torch.Tensor) + assert isinstance(out_unwrapped, torch.Tensor) + assert isinstance(logsumexp_unwrapped, torch.Tensor) + assert isinstance(grad_out_unwrapped, torch.Tensor) + assert isinstance(grad_logsumexp_unwrapped, torch.Tensor) + assert isinstance(block_mask_unwrapped, tuple) + assert isinstance(score_mod_other_buffers_unwrapped, tuple) + assert isinstance(mask_mod_other_buffers_unwrapped, tuple) + + with ctx.redispatch_to_next(): + functional_fw_graph = ctx.functionalize(fw_graph) + functional_joint_graph = ctx.functionalize(joint_graph) + + ( + grad_query, + grad_key, + grad_value, + grad_score_mod_captured, + ) = flex_attention_backward( + query_unwrapped, + key_unwrapped, + value_unwrapped, + out_unwrapped, + logsumexp_unwrapped, + grad_out_unwrapped, + grad_logsumexp_unwrapped, + functional_fw_graph, # type: ignore[arg-type] + functional_joint_graph, # type: ignore[arg-type] + block_mask_unwrapped, + scale, + kernel_options, + score_mod_other_buffers_unwrapped, + mask_mod_other_buffers_unwrapped, + ) + + return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type] + + +@register_fake(flex_attention_backward) +def 
flex_attention_backward_fake_tensor_mode( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + grad_out: torch.Tensor, + grad_logsumexp: torch.Tensor, + fw_graph: Union[Callable, GraphModule], + joint_graph: GraphModule, + block_mask: tuple, + scale: float, + kernel_options: dict[str, Any], + score_mod_other_buffers: tuple = (), + mask_mod_other_buffers: tuple = (), +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] +]: + if has_user_subclass( + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ), + allowed_subclasses=(FakeTensor,), + ): + return NotImplemented + Bq, _, _, qk_head_dim = query.shape + Bkv, Hkv, seq_len_kv, v_head_dim = value.shape + + grad_query = torch.empty_like(query) + # zeros_and_scatter creates a contiguous zeros tensor -> contiguous_format + grad_score_mod_captured = tuple( + [ + ( + torch.empty_like(buffer, memory_format=torch.contiguous_format) + if isinstance(buffer, torch.Tensor) and buffer.requires_grad + else None + ) + for buffer in score_mod_other_buffers + ] + ) + + broadcasted_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim)) + broadcasted_grad_key = _permute_strides(broadcasted_grad_key, key.stride()) + + broadcasted_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim)) + broadcasted_grad_value = _permute_strides(broadcasted_grad_value, value.stride()) + + if Bq > 1 and Bkv == 1: + grad_key = torch.sum(broadcasted_grad_key, dim=0, keepdim=True) + grad_value = torch.sum(broadcasted_grad_value, dim=0, keepdim=True) + else: + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + + return grad_query, grad_key, grad_value, grad_score_mod_captured + + +flex_attention_backward.py_autograd_impl( + autograd_not_implemented(flex_attention_backward, deferred_error=True) +) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py new file mode 100644 index 0000000000000000000000000000000000000000..b960a4fea5d10374266337fc9551c02b7666ec27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from typing import Any, Callable + +from torch._higher_order_ops.base_hop import BaseHOP, FunctionWithNoFreeVars + + +class ForeachMap(BaseHOP): + def __init__(self): + super().__init__("foreach_map") + + def __call__(self, fn, *operands, **kwargs): # type: ignore[override] + fn = FunctionWithNoFreeVars(fn) + return super().__call__(fn, *operands, **kwargs) + + +_foreach_map = ForeachMap() + + +def foreach_map(op: Callable, *operands: Any, **kwargs: dict[str, Any]): + from torch._dynamo.polyfills import foreach_map_fn + + return _foreach_map(foreach_map_fn, op, *operands, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbe8c5434a626b614e4e216c67c70687da1eac0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py @@ -0,0 +1,142 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + autograd_not_implemented, + 
reenter_make_fx, + unique_graph_id, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +# used for wrapping a function/op with context hints +class HintsWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("hints_wrapper") + + def __call__(self, body_fn, args, kwargs, hints): + r""" + Call implementation of hints_wrapper + + Args: + body_fn (Callable): A callable function that is within the scope + that is being traced. + + args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to + body_fn. + + kwargs (dict): Keyword argument to the body_fn. + + hints (dict): A dict of context hints which could be passed to + backend compiler. + """ + if not isinstance(args, tuple): + raise RuntimeError(f"args must be a tuple, got {type(args)}") + + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): + raise RuntimeError( + "args must be a tuple of tensors, ints, floats, or bools, got " + f"{args}" + ) + + if not isinstance(kwargs, dict): + raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}") + + if len(kwargs) > 0: + raise RuntimeError( + f"kwargs except for hints are not supported, got {kwargs}" + ) + + if not isinstance(hints, dict): + raise RuntimeError(f"hints must be a dict, got {type(hints)}") + + for k, v in hints.items(): + if not isinstance(k, str): + raise RuntimeError(f"hints key must be a str, got {k}.") + + if not isinstance(v, (int, float, bool, str)): + raise RuntimeError( + "hints must be a dict containing int, float, bool or str " + f"value, got value {v} for key {k}." + ) + + return super().__call__(body_fn, args, kwargs, hints) + + +hints_wrapper = HintsWrapper() + + +@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd) +def hints_wrapper_dense(body_fn, args, kwargs, hints): + return body_fn(*args, **kwargs) + + +hints_wrapper.py_autograd_impl( + autograd_not_implemented(hints_wrapper, deferred_error=True) +) + + +@hints_wrapper.py_impl(FakeTensorMode) +def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints): + flat_args = pytree.tree_leaves(args) + with mode: + return body_func(*flat_args, **kwargs) + + +@hints_wrapper.py_functionalize_impl +def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints): + from torch._higher_order_ops.utils import _check_alias_and_mutation + + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + unwrapped_hints = ctx.unwrap_tensors(hints) + with ctx.redispatch_to_next(): + functional_body_fn = ctx.functionalize(body_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + _check_alias_and_mutation( + body_fn, unwrapped_args, "hints_wrapper", pre_dispatch + ) + + outputs = hints_wrapper( + functional_body_fn, + unwrapped_args, + unwrapped_kwargs, + unwrapped_hints, + ) + return ctx.wrap_tensors(outputs) + + +def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints): + flat_args = tuple(pytree.tree_leaves(args)) + body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs) + + _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph") + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + new_args: tuple = (body_graph, flat_args, {}) + # merge hints into kwargs + new_kwargs = {} + new_kwargs["hints"] = hints + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args) + proxy_kwargs = 
pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper" + ) + + out = body_fn(*flat_args, **kwargs) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@hints_wrapper.py_impl(ProxyTorchDispatchMode) +def inner(proxy_mode, body_fn, args, kwargs, hints): + if proxy_mode.enable_tracing: + return trace_hints_wrapper( + proxy_mode, hints_wrapper, body_fn, args, kwargs, hints + ) + else: + return hints_wrapper(body_fn, args, kwargs, hints) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d7cddd36f8a87fa0e0b669197ea7ba28de6dfa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py @@ -0,0 +1,658 @@ +# mypy: allow-untyped-defs + + +import contextlib +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._higher_order_ops.utils import ( + _from_fun, + _maybe_reenter_make_fx, + _set_compilation_env, + clone_outputs_aliasing_inputs, + FunctionalizeCtxWrapper, + get_dummy_aot_autograd_config, + HopInstance, + prepare_fw_with_masks, + reenter_make_fx, + register_fake, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.graph_module import GraphModule +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + +invoke_subgraph_counter = 0 + + +# During the tracing of the joint graph, we construct this information. This is +# used to filter out grad_outs/tangents in the `backward` method of +# InvokeSubgraphAutogradOp. +@dataclass +class OutputMetadata: + num_fw_outs: Optional[int] = None + indexes_with_none: set[int] = field(default_factory=set) + indexes_with_no_grad: set[int] = field(default_factory=set) + + +class InvokeSubgraphHOP(HigherOrderOperator): + def __init__(self) -> None: + # Invoke subgraph does not have any state, it is just a wrapper over a + # subgraph, so we can safely cache the HOP. + super().__init__("invoke_subgraph", cacheable=True) + # This is used by the fake tensor cache key validator to extract the + # subgraph and iterate over the nodes to find if all nodes are fake + # tensor cacheable. + self.subgraph_indexes = [ + 0, + ] + + # identifier is setup by upper part of the stack. This helps us in + # identifying two invoke_subgraph calls have same subgraph. 
+ def __call__( + self, + subgraph: Union[GraphModule, FunctionalizeCtxWrapper], + identifier: Optional[str], + *operands, + ): + assert identifier is None or isinstance( + identifier, str + ), "identifier must be a None or a string" + + assert all( + isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands + ), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}" + + return super().__call__(subgraph, identifier, *operands) + + def gen_schema(self, subgraph, identifier, *operands): + from torch._higher_order_ops.schema import HopSchemaGenerator + from torch._higher_order_ops.utils import ( + check_input_alias_and_mutation_return_outputs, + materialize_as_graph, + ) + + gm: torch.fx.GraphModule = ( + subgraph + if isinstance(subgraph, torch.fx.GraphModule) + else materialize_as_graph(subgraph, operands) + ) + + schema_gen = HopSchemaGenerator(self) + schema_gen.add_arg("subgraph", gm) + schema_gen.add_arg("identifier", identifier) + ( + _, + _, + _, + mutated_inputs, + outputs, + ) = check_input_alias_and_mutation_return_outputs(gm, operands) + for idx, arg in enumerate(operands): + schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inputs) + for out in outputs: + schema_gen.add_output(out) + + return schema_gen.gen_schema() + + +invoke_subgraph = InvokeSubgraphHOP() + + +def invoke_subgraph_placeholder(func, *args, **kwargs): + if torch.compiler.is_dynamo_compiling(): + # This is just a placeholder for Dynamo to replace with invoke_subgraph + raise RuntimeError("invoke_subgraph should not be called directly in Dynamo") + + if torch.compiler.is_compiling(): + # For non-strict export tracing, we still want to go through Dynamo + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + def _invoke_subgraph_placeholder_wrapper(func, args): + return invoke_subgraph_placeholder(func, *args) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + + return torch.compile( + _invoke_subgraph_placeholder_wrapper, + backend=backend, + fullgraph=True, + )(func, args) + + return func(*args, **kwargs) + + +def mark_compile_region(fn=None): + """ + This wrapper instructs torch.compile to compile the wrapped region once and + reuse the compiled artifact, instead of the usual way of aggressively + inlining the function. + + Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the + region. For PyTorch eager, this is a no-op. 
+ """ + + def wrap(func): + def inner(*args, **kwargs): + # Get the innermost function to avoid nested compile regions + inner_func = func + while hasattr(inner_func, "__marked_compile_region_fn__"): + inner_func = inner_func.__marked_compile_region_fn__ + return invoke_subgraph_placeholder(inner_func, *args, **kwargs) + + inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined] + + return inner + + if fn: + return wrap(fn) + else: + return wrap + + +def get_invoke_subgraph_cache(): + cache = None + if tracing_ctx := torch._guards.TracingContext.try_get(): + cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph) + return cache + + +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra +def trace_joint_graph(fn, fw_inputs, fw_outputs): + """ + Naively trace out a joint graph. This simplifies the reconstruction of joint + graph in the min-cut partitioner later on. + """ + from torch._functorch.aot_autograd import create_joint + + dummy_aot_config = get_dummy_aot_autograd_config() + + # This joint_fn is inserted as the backward graph as is. This simplifies the + # min-cut partitioner work later on. + # Input signature - (*primals, *tangents) + # Output signature - (*grads, *fw_outs) + # The output signature is deliberately kept grads first and fw_outs second. + # Having grads first makes the min-cut partitioner HOP graph stitching + # easier. + def joint_fn(*primals_and_tangents): + primals = primals_and_tangents[: len(fw_inputs)] + tangents = primals_and_tangents[len(fw_inputs) :] + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fn), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + + # return signature is deliberately kept (*grads, *fw_outs). This + # simplifies partitioning work later on. + return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs))) + + primals = list(fw_inputs) + # This assumes that the tangent strides match fw_outputs strides. Check the + # InvokeSubgraphAutogradOp backward op for the contiguous call. + tangents = [_from_fun(out) for out in fw_outputs] + + joint_operands = primals + tangents + + return _maybe_reenter_make_fx(joint_fn)(*joint_operands) + + +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra +def create_fw_bw_graph(subgraph, operands, grad_outputs=None): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + # args are functional tensors, generate some example tensors + fw_inputs = pytree.tree_map(_from_fun, operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(fw_inputs) + context = ( + nullcontext() + if fake_mode is None or fake_mode.shape_env is None + else fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + + num_fw_outs = len(fw_outs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. 
+ output_metadata = OutputMetadata() + + output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + output_metadata.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + + if grad_outputs is None: + # Infer grad_outputs to be the same properties as the fw_outputs + # if they're not passed in + # Although fw_outs are equivalent to grad_outputs for tracing + # purposes, we have to carefully handle the None and fw_out that do + # not have require_grad. At those indexes, we will have None in the + # backward graph. + grad_outputs = fw_outs + grad_outputs = [grad for grad in grad_outputs if grad is not None] + grad_outputs = [grad for grad in grad_outputs if grad.requires_grad] + + # Force grad_out to be contiguous. This is because at runtime, + # grad_out could have different strides than fw_outs. So, we + # force the grad_outs to be contiguous for both tracing and + # runtime. + grad_outputs = [grad.contiguous() for grad in grad_outputs] + + if any( + not isinstance(out, torch.Tensor) + for out in grad_outputs + if out is not None + ): + raise RuntimeError( + "Expect outputs of invoke_subgraph to only contains tensors or None. " + f"Got types {[type(out) for out in grad_outputs]}." + ) + + # Trace the forward subgraph + fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs) + + # Trace the joint graph and assign it to the bwd graph + bw_graph = trace_joint_graph( + subgraph, + fw_inputs, + grad_outputs, + ) + return fw_graph, bw_graph, output_metadata + + +def get_output_metadata(subgraph, *operands): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + # args are functional tensors, generate some example tensors + fw_inputs = pytree.tree_map(_from_fun, operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(fw_inputs) + context = ( + nullcontext() + if fake_mode is None or fake_mode.shape_env is None + else fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + + num_fw_outs = len(fw_outs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. + output_metadata = OutputMetadata() + + output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + output_metadata.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + return output_metadata + + +def trace_joint_graph_as_bwd( + subgraph, num_primals, joint_operands, include_key_set, exclude_key_set +): + """ + Naively trace out a joint graph. This simplifies the reconstruction of joint + graph in the min-cut partitioner later on. + """ + from torch._functorch.aot_autograd import create_joint + + dummy_aot_config = get_dummy_aot_autograd_config() + + if isinstance(subgraph, torch.fx.GraphModule): + + def graph_with_interpreter(*args): + # Running graph with interpreter is needed for propagating the stack_trace + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(subgraph).run(*args) + + fn = graph_with_interpreter + else: + fn = subgraph + + # This joint_fn is inserted as the backward graph as is. 
This simplifies the + # min-cut partitioner work later on. + # Input signature - (*primals, *tangents) + # Output signature - (*grads, *fw_outs) + # The output signature is deliberately kept grads first and fw_outs second. + # Having grads first makes the min-cut partitioner HOP graph stitching + # easier. + def joint_fn(*primals_and_tangents): + primals = primals_and_tangents[:num_primals] + tangents = primals_and_tangents[num_primals:] + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fn), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + + # return signature is deliberately kept (*grads, *fw_outs). This + # simplifies partitioning work later on. + return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs))) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + joint_operands = [_from_fun(arg) for arg in joint_operands] + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), + ) + with torch.enable_grad(): + return _maybe_reenter_make_fx(joint_fn)(*joint_operands) + + +class InvokeSubgraphAutogradOp(torch.autograd.Function): + """ + Saves the subgraph, i.e. original callable, in the forward method. And then + traces out a joint graph in the backward. This delaying of tracing in + backward, also called as lazy backward, ensures that the assumptions about + the grad_out strides and tensor-subclass-ness are already accounted for. + """ + + @staticmethod + def forward( + ctx, + subgraph, + identifier, + output_metadata, + *operands, + ): + # We want to delay the backward graph construction until the backward. + # So in forward, we just run the fw callable as is. And save all the + # information necessary to construct the backward graph in the ctx. + ctx._subgraph = subgraph + ctx._identifier = identifier + ctx._output_metadata = output_metadata + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + save_tensors_and_symints_for_backward(ctx, operands) + + with torch._C._AutoDispatchBelowAutograd(): + out = invoke_subgraph( + subgraph, + f"fw_{identifier}", + *operands, + ) + + # Check that None is at expected indexes. + for idx, o in enumerate(out): + if o is None: + assert idx in output_metadata.indexes_with_none + + return out + + @staticmethod + def backward( + ctx, + *grad_outs, + ): + from torch._dynamo.utils import dynamo_timed + + subgraph = ctx._subgraph + identifier = ctx._identifier + output_metadata = ctx._output_metadata + primals = saved_tensors_and_symints(ctx) + + # Filter out grads that are None or do not require_grad. This was + # the assumption we made during the tracing of joint_graph. + filtered_grad_outs = [] + for idx, o in enumerate(grad_outs): + if o is None: + assert idx in output_metadata.indexes_with_none + elif idx in output_metadata.indexes_with_no_grad: + # Deliberately skip over the grad_outs which we know should be + # None because the corresponding fwd_out does not require_grad. + pass + else: + filtered_grad_outs.append(o) + filtered_grad_outs = tuple(filtered_grad_outs) + + # Important note - Even though the forward graph can be same for + # different invoke_subgraphs, the backward graph can be different + # because the tangent strides can be different. 
So, here we cache on + # tangent_metadata in addition to identifier + from torch._guards import detect_fake_mode + from torch._subclasses._fake_tensor_utils import _CacheKeyState + from torch._subclasses.fake_tensor import extract_tensor_metadata + + fake_mode = detect_fake_mode(primals + filtered_grad_outs) + state = _CacheKeyState(fake_mode.shape_env) + + tangent_metadata: list[object] = [] + for tangent in filtered_grad_outs: + metadata = extract_tensor_metadata(tangent) + metadata._flatten_into(tangent_metadata, fake_mode, state) + tangent_metadata = tuple(tangent_metadata) + + # bw_graph is a joint graph with signature (*primals_and_tangents) and + # returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs + # to extract the grads. + primals_and_tangents = primals + filtered_grad_outs + + # Check if we have already traced the bwd subgraph. + bw_graph = None + suffix = None + invoke_subgraph_cache = get_invoke_subgraph_cache() + cache_hit = False + if invoke_subgraph_cache: + bw_graph, suffix = invoke_subgraph_cache.get_lazy_bwd_entry( + identifier, tangent_metadata + ) + cache_hit = bw_graph is not None + + if bw_graph is None: + assert suffix is None + with dynamo_timed( + "invoke_subgraph_trace_joint_graph", log_pt2_compile_event=True + ): + bw_graph = trace_joint_graph_as_bwd( + subgraph, + len(primals), + primals_and_tangents, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + ) + + if invoke_subgraph_cache and not cache_hit: + suffix = invoke_subgraph_cache.add_lazy_bwd_entry( + identifier, tangent_metadata, bw_graph + ) + + grads = invoke_subgraph( + bw_graph, f"bw_{identifier}_{suffix}", *primals_and_tangents + )[: -output_metadata.num_fw_outs] + return None, None, None, *grads + + +@invoke_subgraph.py_autograd_impl +def _(subgraph, identifier, *operands): + # Check if we have already traced the subgraph. + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry( + identifier + ): + return saved_autograd_fn(*operands) + + output_metadata = get_output_metadata(subgraph, *operands) + + def autograd_fn_callable(*args): + return InvokeSubgraphAutogradOp.apply( + subgraph, identifier, output_metadata, *args + ) + + # Save the autograd_fn_callable in the dispatch set cache. + if invoke_subgraph_cache: + invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable) + + return autograd_fn_callable(*operands) + + +@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) +def _(subgraph, identifier, *operands): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return subgraph(*operands) + + +@invoke_subgraph.py_functionalize_impl +def _(ctx, subgraph, identifier, *operands): + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize_v2, + ) + + unwrapped_operands = ctx.unwrap_tensors(operands) + hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) + if can_auto_functionalize(hop_instance): + # NOTE: [auto_functionalize x invoke_subgraph caching] + # We call auto_functionalized_v2 to support input mutation of invoke_subgraph. + # See NOTE [Support input mutation of hops] for the overall design. + # + # invoke_subgraph is special because of its identifier based caching machanism. 
+ # In invoke_subgraph's functionalization key implementation, we create a new + # identifer because the subgraph is replaced by FunctionWithNoFreeVars in a + # functional + epilogue form. + assert isinstance(identifier, str), identifier + return do_auto_functionalize_v2( + ctx.mode, + hop_instance, + (subgraph, "auto_functionalized_" + identifier, *operands), + {}, + ) + + with ctx.redispatch_to_next(): + # NB: There is an assumption that subgraph does not mutate inputs and + # there is no aliasing. Its Dynamo responsibility to prevent formation + # of invoke_subgraph ops if input aliasing/mutation is detected. + functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph) + out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands) + return ctx.wrap_tensors(out) + + +# Register the hop fake fn. This will be called in the fake_tensor _dispatch_impl. +@register_fake(invoke_subgraph) +def _(subgraph, identifier, *operands): + from torch._dynamo.utils import dynamo_timed + + with dynamo_timed("invoke_subgraph_fake_tensor", log_pt2_compile_event=True): + return subgraph(*operands) + + +@invoke_subgraph.py_impl(ProxyTorchDispatchMode) +def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands): + # Check if we have already traced the subgraph. + graph = None + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier) + + if graph is None: + from torch._dynamo.utils import dynamo_timed + + with dynamo_timed("invoke_subgraph_proxy_tensor", log_pt2_compile_event=True): + graph = reenter_make_fx(subgraph)(*operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(operands) + insert_deferred_runtime_asserts( + graph, + fake_mode.shape_env, + "invoke_subgraph_proxy_torch_dispatch_mode", + export=True, + ) + graph.recompile() + + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph) + + node_args = (graph, identifier, *operands) + + def _unwrap_proxy(arg): + if isinstance(arg, torch.fx.GraphModule): + # NOTE: [invoke_subgraph proxy_mode x auto_functionalize] + # Previously, we assumed that `invoke_subgraph` would always be traced with the same tracer. + # This allowed us to cache modules by their identifiers, assuming they were already registered. + # + # However, this assumption no longer holds when we auto-functionalize `invoke_subgraph`. + # auto_functionalize functionalizes the subgraph and wrap it with `FunctionWithNoFreeVars`. + # In the proxy mode implementation of `auto_functionalized_v2`, we need to materialize `FunctionWithNoFreeVars` + # input as a graph module. To do this, we re-trace the `invoke_subgraph` hop, which starts a new sub-tracer + # (see NOTE [materialize callable inputs as graph]). # When the new sub-tracer traces the `invoke_subgraph` + # with a previously cached identifier, the corresponding graph module might not + # exist as a submodule in the new tracer's root. Therefore, we register it as a submodule below. + # + # The alternative is to give a new identifer when we re-trace the invoke_subgraph but this will increase + # the compilatoin time, which defeats the purpose of caching. 
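+ # Scan the submodules already registered on this tracer's root; the graph is
+ # registered under a fresh qualname below only if no earlier trace registered it.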
+ registered_before = False + for ( + _, + submod, + ) in proxy_mode.tracer.root.named_modules(): # type: ignore[union-attr] + if arg is submod: + registered_before = True + + if not registered_before: + qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") # type: ignore[union-attr] + proxy_mode.tracer.root.register_module(qualname, arg) # type: ignore[union-attr] + return proxy_mode.tracer.unwrap_proxy(arg) # type: ignore[union-attr] + + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) # type: ignore[union-attr] + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", invoke_subgraph, proxy_args, {} + ) + + example_out = invoke_subgraph(graph, identifier, *operands) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/map.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/map.py new file mode 100644 index 0000000000000000000000000000000000000000..75be06cf9504cdad59b6b86e4fe3c5f9afc5fac5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/map.py @@ -0,0 +1,291 @@ +# mypy: allow-untyped-defs +import functools +from typing import Callable, Union +from typing_extensions import TypeVarTuple + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) + +from .utils import ( + _from_fun, + _stack_pytree, + _unstack_pytree, + clone_outputs_aliasing_inputs, + prepare_fw_with_masks, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, +) + + +class MapImpl(HigherOrderOperator): + def __init__(self): + super().__init__("map_impl") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +map_impl = MapImpl() + + +def create_fw_bw_graph(f, num_mapped_args, *args): + mapped_xs = args[:num_mapped_args] + pos_args = args[num_mapped_args:] + + # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + example_flat_out = pytree.tree_map( + _from_fun, f(*example_xs, *example_pos_args) + ) + if any( + not isinstance(out, torch.Tensor) + for out in example_flat_out + if out is not None + ): + raise RuntimeError( + "Expect outputs of map only contains tensors or None. " + f"Got types {[type(out) for out in example_flat_out]}." 
+ ) + example_grad = [_from_fun(out) for out in example_flat_out] + + fw_graph = make_fx(f)(*example_xs, *example_pos_args) + + from torch._functorch.aot_autograd import AOTConfig, create_joint + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + def joint_f(*example_args): + joint_mapped_args = example_args[:joint_num_mapped] + args = example_args[joint_num_mapped:] + + mapped_input = joint_mapped_args[:num_mapped_args] + mapped_grads = joint_mapped_args[num_mapped_args:] + + joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + maybe_clone = clone_outputs_aliasing_inputs(example_args) + + return pytree.tree_map(maybe_clone, grads) + + joint_num_mapped = len(example_grad) + len(example_xs) + joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) + return fw_graph, joint_graph + + +def map( + f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree], + xs: Union[pytree.PyTree, torch.Tensor], + *args: TypeVarTuple, +): + r""" + Perfoms a map of f with xs. Intuitively, you can think of the semantic being: + + out = [] + for idx in len(xs.size(0)): + xs_sliced = xs.select(0, idx) + out.append(f(xs_sliced, *args)) + torch.stack(out) + + .. warning:: + `torch._higher_order_ops.map` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + + Args: + f (Callable): a callable that takes an input x, that could either be a single Tensor + or a nested dict, list of tensors and some additional inputs + xs: the inputs that're to be mapped over. We'll iterate over the first dim of each x + and perform f on each slice. + + *args: additional arguments provided to each step of f. They could also be omitted and + map is able to automatically figure out the read dependency. + + Return: + the stacked output for each step of f + + Example: + + def f(xs): + return xs[0] + xs[1] + const1 + const2 + + xs = [torch.randn(2, 3), torch.randn(2, 3)] + const1 = torch.randn(2, 3) + const2 = torch.randn(2, 3) + # returns a tensor of shape [2, 2, 3] + torch._higher_order_ops.map(f, xs) + + """ + flat_xs, xs_spec = pytree.tree_flatten(xs) + flat_args, args_spec = pytree.tree_flatten(args) + if not all(isinstance(t, torch.Tensor) for t in flat_xs): + raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.") + + shapes = [xs.shape for xs in flat_xs] + leading_dim_size = shapes[0][0] + if leading_dim_size == 0: + raise RuntimeError("Leading dimensions of mapped xs cannot be 0.") + + if any(cur_shape[0] != leading_dim_size for cur_shape in shapes): + raise RuntimeError( + f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}." 
+ ) + + def run_flattened_map(f, flat_xs, flat_args): + def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs): + xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec) + args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec) + return f(xs, *args) + + inner_f = functools.partial( + wrapped_fn, + f=f, + xs_tree_spec=xs_spec, + args_tree_spec=args_spec, + num_xs=len(flat_xs), + ) + return map_impl(inner_f, flat_xs, flat_args) + + from torch._higher_order_ops.utils import _maybe_compile_and_run_fn + + return _maybe_compile_and_run_fn(run_flattened_map, f, flat_xs, flat_args) + + +class MapAutogradOp(torch.autograd.Function): + @staticmethod + def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): + save_tensors_and_symints_for_backward(ctx, flat_args) + ctx._joint_graph = joint_graph + ctx._num_mapped_args = num_mapped_args + with torch._C._AutoDispatchBelowAutograd(): + return ( + *map_impl( + fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] + ), + ) + + @staticmethod + def backward(ctx, *flat_grads): + fw_args = saved_tensors_and_symints(ctx) + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + fw_mapped_args + flat_grads, + pos_args, + ) + return None, None, None, *grads + + +def trace_map(proxy_mode, func_overload, f, xs, pos_args): + example_input = _unstack_pytree(xs)[0] + body_graph = f + + body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) + + next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_") + + proxy_mode.tracer.root.register_module(next_name, body_graph) + + fake_outs = map_impl(body_graph, xs, pos_args) + + node_args = (body_graph, list(xs), list(pos_args)) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="map_impl" + ) + return track_tensor_tree( + fake_outs, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + +@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd) +def map_dense(f, xs, pos_args): + pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)] + return _stack_pytree(pytrees) + + +@map_impl.py_autograd_impl +def map_autograd(f, xs, pos_args): + num_mapped_args = len(xs) + fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) + return flat_out + + +@map_impl.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(mode, f, xs, args): + return trace_map(mode, map_impl, f, xs, args) + + +@map_impl.py_impl(FakeTensorMode) +def map_fake_tensor_mode(mode, f, xs, args): + with mode: + return map_dense(f, xs, args) + + +@map_impl.py_functionalize_impl +def map_functionalize(ctx, f, xs, pos_args): + from torch._higher_order_ops.utils import _check_alias_and_mutation + + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_args = ctx.unwrap_tensors(pos_args) + wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f)) + + with ctx.redispatch_to_next(): + example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch) + map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args) + return ctx.wrap_tensors(map_return) + + +def _fake_map(f, x, *args): + from functorch.experimental.control_flow import 
_stack_pytree, _unstack_pytree + + x_pytrees = _unstack_pytree(x) + zs = [] + for xp in x_pytrees: + zs.append(f(xp, *args)) + return _stack_pytree(zs) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..e8bb0da3728ad44894bbc88005986804f2606425 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + maybe_handle_decomp, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +# TODO to figure out a more generic approach +ALLOWABLE_OPS = [ + torch.ops.aten.linear.default, + torch.ops.aten.mm.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.convolution.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul.Scalar, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Scalar, +] + + +class OutDtypeOperator(HigherOrderOperator): + """ + The out_dtype operator takes an existing ATen functional operator, an + `out_dtype` argument, and arguments to the original operator, and executes + the original operator and returns a Tensor with the `out_dtype` precision. + This operator does not mandate a compute precision so it allows the + representation to not be opinionated about the exact implementation. + + The general implementation for all operators will be the following: + 1. Promote inputs dtypes based on default PyTorch dtype promotion rules, + using the dtypes of all input Tensors/Scalars and the `out_dtype` + arugument. + 2. Execute the operator + 3. Cast the output to `out_dtype` + """ + + def __init__(self) -> None: + super().__init__("out_dtype") + + def __call__(self, op, output_dtype, *args): + if not isinstance(op, torch._ops.OpOverload): + raise ValueError("out_dtype's first argument must be an OpOverload") + if op._schema.is_mutable: + raise ValueError( + "out_dtype's first argument needs to be a functional operator" + ) + if not ( + len(op._schema.returns) == 1 + and isinstance(op._schema.returns[0].type, torch.TensorType) + ): + raise ValueError( + "out_dtype's can only apply to ops that return a single tensor" + f"Instead got {[r.type for r in op._schema.returns]}" + ) + + if op not in ALLOWABLE_OPS: + raise ValueError( + f"out_dtype only allows the following operators: {ALLOWABLE_OPS}." + ) + + res = super().__call__(op, output_dtype, *args) + + return res + + +out_dtype = OutDtypeOperator() + + +def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args): + # NB: Long-term we should put the decomposition logic into + # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp + # in all HigherOrderOp proxy implementations. + r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {}) + if r is not NotImplemented: + return r + + with disable_proxy_modes_tracing(): + # This is a simplified implementation of this operator just for tracing. 
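+ # (Illustrative use: out_dtype(torch.ops.aten.mm.default, torch.int32, a_int8, b_int8); for tracing we only need an example output with the right shape and dtype.)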
+ # Actual implementation may also first promote the arguments + out = op(*args).to(dtype=output_dtype) + + node_args = (op, output_dtype, *args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="out_dtype" + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd) +def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args): + if is_int_mm(op, output_dtype, args): + return torch._int_mm(*args) + return out_dtype_fallback(op, output_dtype, *args) + + +def is_int_mm(op, output_dtype, args): + return ( + op == torch.ops.aten.mm.default + and output_dtype == torch.int32 + and len(args) == 2 + and args[0].dtype == torch.int8 + and args[1].dtype == torch.int8 + and args[0].is_cuda + and args[1].is_cuda + ) + + +def out_dtype_fallback(op, output_dtype, *args): + flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)] + promote_dtype: torch.dtype = elementwise_dtypes( + *flat_inputs, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + )[0] + + casted_args = pytree.tree_map_only( + torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args + ) + res = op(*casted_args).to(dtype=output_dtype) + return res + + +out_dtype.py_autograd_impl(autograd_not_implemented(out_dtype, deferred_error=True)) + + +@out_dtype.py_impl(ProxyTorchDispatchMode) +def out_dtype_proxy( + mode: ProxyTorchDispatchMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args, +): + return trace_out_dtype(mode, out_dtype, op, output_dtype, *args) + + +@out_dtype.py_impl(FakeTensorMode) +def out_dtype_fake_tensor_mode( + mode: FakeTensorMode, + op: torch._ops.OpOverload, + output_dtype: torch.dtype, + *args, +): + with mode: + return out_dtype_dense(op, output_dtype, *args) + + +@out_dtype.py_functionalize_impl +def out_dtype_func(ctx, op, output_dtype, *args): + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) + + with ctx.redispatch_to_next(): + res = out_dtype(op, output_dtype, *unwrapped_args) + return ctx.wrap_tensors(res) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b7452b45d5c85ac61f8dc3e677e639884b3a3579 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils import _pytree as pytree + + +class RunConstGraph(HigherOrderOperator): + def __init__(self): + super().__init__("run_const_graph") + + def __call__(self, graph, args): + return super().__call__(graph, args) + + +run_const_graph = RunConstGraph() + + +@run_const_graph.py_impl(ProxyTorchDispatchMode) +def run_const_graph_dispatch_mode(mode, graph, args): + const_gm, weights = graph, args + p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args)) + assert isinstance(const_gm, torch.fx.GraphModule) + assert not hasattr(mode.tracer.root, "_const_graph") + 
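+ # Register the constant-folded graph under a fixed name; the assert above guards against registering it twice on the same tracer root.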
mode.tracer.root.register_module("_const_graph", const_gm) + + proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {}) + + out = const_gm(*weights) + return track_tensor_tree(out, proxy, constant=None, tracer=mode.tracer) + + +@run_const_graph.py_functionalize_impl +def run_const_graph_functional(ctx, graph, args): + unwrapped_args = ctx.unwrap_tensors(args) + + with ctx.redispatch_to_next(): + out = run_const_graph(*unwrapped_args) + return ctx.wrap_tensors(out) + + +run_const_graph.py_autograd_impl( + autograd_not_implemented(run_const_graph, deferred_error=True) +) + + +@run_const_graph.py_impl(FakeTensorMode) +def run_const_graph_fake_tensor_mode(mode, graph, args): + assert isinstance(graph, torch.fx.GraphModule) + with mode: + return graph(*args) + + +@run_const_graph.py_impl(DispatchKey.CPU) +def run_const_graph_cpu(graph, args): + assert isinstance(graph, torch.fx.GraphModule) + return graph(*args) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/scan.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/scan.py new file mode 100644 index 0000000000000000000000000000000000000000..5f76f66410624b0a76966e3c7b6c22965121cee3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/scan.py @@ -0,0 +1,929 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from collections.abc import Sequence +from typing import Any, Callable, Optional + +import torch +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.cond import create_bw_fn +from torch._higher_order_ops.utils import ( + _maybe_compile_and_run_fn, + check_meta_consistency, + first_slice_copy, + materialize_as_graph, + reenter_make_fx, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, + unique_graph_id, + validate_subgraph_args_types, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +aten = torch._ops.ops.aten + + +def wrap_combine_fn_flat( + *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves +): + assert len(args) == ( + num_init_leaves + num_inp_leaves + ), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}" + carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init) + xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs) + return combine_fn(carry, xs) + + +def _extract_carry_and_out(flat_out: list[Any], num_carry: int): + return split_into_chunks(flat_out, [num_carry, len(flat_out) - num_carry]) + + +# We also do a clone with contiguous_format. This is to be consistent with +# eager semantic of scan, which stacks the outputs. The result is contiguous +# as a result of the stack operation. 
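+# For example, a per-step output y of shape (M, N) with scan_length T becomes a
+# contiguous (T, M, N) tensor, equivalent to torch.stack([y] * T).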
+def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor: + return ( + y.unsqueeze(0) + .repeat(*([scan_length] + [1] * y.ndim)) + .clone(memory_format=torch.contiguous_format) + ) + + +# NOTE: These functions can be reused in associative_scan and eventually moved to +# torch._higher_order_ops.utils +def get_tensor_mask(tensor_list: list[Any]) -> list[bool]: + # Returns a mask whether a list element is a tensor or not + return [True if isinstance(v, torch.Tensor) else False for v in tensor_list] + + +def mask_list( + mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None +) -> list[Any]: + # Masks elements on an `inp` list. + # If other is None, then the elements of the `inp` list where the mask is False are removed + # If other is not None, then the elements of the `inp` list where the mask is False are + # replaced with the elements of the `other` list + assert len(mask) == len( + inp + ), "The length of the mask needs to be identical to the length of the input" + if other is not None: + assert len(inp) == len( + other + ), "If an input and an other list is provided, they need to have the same length" + return [i if m else o for m, i, o in zip(mask, inp, other)] + else: + return [i for m, i in zip(mask, inp) if m] + + +def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: + # First_slice_copy does not keep the original requires_grad flag, + # but we need it for materialize_as_graph + # in order to compute the correct gradients + # The reason why first_slice_copy doesn't keep requires_grad flag is + # because it's called in torch.autograd.Function.backward/forward. + slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li] + return slc + + +def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: + it = iter(iterable) + assert sum(chunk_sizes) == len( + iterable + ), "the sum of all chunks needs to match the length of the iterable." + return [list(itertools.islice(it, size)) for size in chunk_sizes] + + +def call_operator(operator, *args): + return pytree.tree_leaves(operator(*args)) + + +def scan( + combine_fn: Callable[ + [pytree.PyTree, pytree.PyTree], tuple[pytree.PyTree, pytree.PyTree] + ], + init: pytree.PyTree, + xs: pytree.PyTree, + *, + dim: int = 0, + reverse: bool = False, +) -> tuple[pytree.PyTree, pytree.PyTree]: + r""" + Performs an inclusive scan with a combine function. + + .. warning:: + `torch.scan` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + Args: + combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``, + or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``. + The first input to ``combine_fn`` is the previous or initial scan carry + and the second input element to ``combine_fn`` is a slice of the input along dim. + The first output element of ``combine_fn`` is the next scan carry + and the second output of ``combine_fn`` represents a slice of the output. + This function must be pure, i.e., no lifted arguments are supported at the moment + and may not have any side effects. + init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors. + The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry) + of ``combine_fn``. 
+ xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors. + + Kwargs: + dim (int): the dimension to scan over, default 0. + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + + Returns: + final_carry (torch.Tensor or pytree with tensor leaves), + the final carry of the scan operation with same pytree structure as init. + out (torch.Tensor or pytree with tensor leaves), + each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration. + + Restrictions: + - The combine_fn shouldn't have any aliasing between input-input, input-output, and output-output. E.g. return a view + or the same tensor as input is not supported. As a workaround, can clone the output to avoid aliasing. + + - The combine_fn shoudn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue + if you input mutation support for training is needed. + + - The combine_fn's init carry should match the next_carry in pytree structure and in tensor metadata. + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + next_carry = y = x + y + # clone the output to avoid output-output aliasing + return next_carry, y.clone() + + i0 = torch.zeros(1) + xs = torch.arange(5) + # returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]]) + last_carry, cumsum = scan(add, init=i0, xs=xs) + + + """ + # The reason we flatten init and xs before calling into dynamo is that + # we want to create a consistent input ordering for combine_fn + # and we also want to the input ordering matches the output ordering. + leaves_init, spec_init = pytree.tree_flatten(init) + leaves_xs_orig, spec_xs = pytree.tree_flatten(xs) + + # Shortcut if no xs is provided + if len(leaves_xs_orig) == 0: + return init, [] + + def _validate_input(cfn, lxs, linit, d, r): + # Basic arguments check + if not callable(cfn): + raise RuntimeError("Combine_fn must be a callable, but got {cfn}") + if not isinstance(d, int): + raise RuntimeError("Dim must be an int, but got " + str(type(d))) + if not isinstance(r, bool): + raise RuntimeError("Reverse must be a bool, but got " + str(type(r))) + + # Checks for init + if len(linit) == 0: + raise RuntimeError("scan() operator requires init leaves.") + for x in linit: + if not isinstance(x, torch.Tensor): + raise RuntimeError(f"All init leaves must be a Tensor but got {x}") + + # Checks for xs + for x in lxs: + if not isinstance(x, torch.Tensor): + raise RuntimeError(f"All xs leaves must be a Tensor but got {x}") + if any(x.ndim <= d for x in lxs): + raise RuntimeError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) + if any(x.shape[d] == 0 for x in lxs): + raise RuntimeError( + "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0" + ) + + ndim = leaves_xs_orig[0].ndim + dim = utils.canonicalize_dim(ndim, dim) + + _validate_input(combine_fn, leaves_xs_orig, leaves_init, dim, reverse) + + # Move scan dim to 0 and always perform scan on dim 0 + leaves_xs = [] + for elem in leaves_xs_orig: + leaves_xs.append(torch.movedim(elem, dim, 0)) + + if reverse: + leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs] + + # TODO: Support _inductor lowering + # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. 
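+ # wrap_combine_fn_flat re-packs the flattened leaves into the caller's original pytree structure before invoking the user combine_fn, so scan_op below only ever sees flat lists of tensors.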
+ + combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=combine_fn, + spec_init=spec_init, + spec_xs=spec_xs, + num_init_leaves=len(leaves_init), + num_inp_leaves=len(leaves_xs), + ) + + def run_flattened_scan(combine_fn, leaves_init, leaves_xs): + return scan_op(combine_fn, leaves_init, leaves_xs, additional_inputs=()) + + carry, out = _maybe_compile_and_run_fn( + run_flattened_scan, + combine_fn, + leaves_init, + leaves_xs, + ) + + if reverse: + out = pytree.tree_map(lambda elem: elem.flip([0]), out) + + return carry, out + + +class ScanOp(HigherOrderOperator): + def __init__(self): + super().__init__("scan") + + def __call__(self, combine_fn, init, xs, additional_inputs): + # There is currently an issue that the ScanOp is sometimes called with + # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 + # Once this issue is resolved, the assertion should only allow tuples + # and the tuple cast should be removed + assert isinstance( + additional_inputs, (tuple, list) + ), "additional_inputs must be a tuple." + additional_inputs = ( + tuple(additional_inputs) + if isinstance(additional_inputs, list) + else additional_inputs + ) + validate_subgraph_args_types(additional_inputs) + return super().__call__(combine_fn, init, xs, additional_inputs) + + +scan_op = ScanOp() + + +def generic_scan(operator, init, xs, dim=0, additional_inputs=()): + def _scan(init, xs): + """Perform scan on `elems` using `elems_init.""" + carry = init + if len(xs) == 0: + return carry, [] + + num_elems = xs[0].shape[dim] + ind = 0 + + # Compute dummy shapes for the pre-allocation + num_init_leaves = len(init) + dummy_carry, dummy_out = _extract_carry_and_out( + call_operator( + operator, + *carry, + *[first_slice_copy(elem, dim) for elem in xs], + *additional_inputs, + ), + num_init_leaves, + ) + + out_tensor_mask = get_tensor_mask(dummy_out) + dummy_out_masked = mask_list(out_tensor_mask, dummy_out) + + # Pre-alocate + # outs -> Output matrix + # idxs -> Index matrix for scatter_ + # out: (num_elems, M, N, ...) + # idx: (1, M, N) + outs = [ + torch.zeros( + [num_elems] + list(e.size()), + dtype=e.dtype, + device=e.device, + ) + for i, e in enumerate(dummy_out_masked) + ] + idxs = [ + torch.ones_like(e, dtype=torch.int64).unsqueeze(0) + for i, e in enumerate(dummy_out_masked) + ] + + def store_out_in_outs(out, ind): + # Store the intermediate out in the outs matrix + for o, x, idx in zip(outs, out, idxs): + # o: (num_elems, M, N ...) + # x: (M, N, ...) -> (1, M, N) + # ind * idx: (1, M, N,) with values to be ind + # essentially: o[ind][n][k] = x[0][n][k] + o.scatter_(0, ind * idx, x.unsqueeze(0)) + + for i in range(num_elems): + ind = i + carry, out = _extract_carry_and_out( + call_operator( + operator, + *carry, + *[elem.select(dim, ind) for elem in xs], + *additional_inputs, + ), + num_init_leaves, + ) + + # Store the inits in the outs matrix. 
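+ # (More precisely, the masked per-step outputs are scattered into row `ind` of the pre-allocated `outs` buffers.)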
+ store_out_in_outs(mask_list(out_tensor_mask, out), ind) + + # Expand outs with None depending on the tensor mask of the output + outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask] + + return [*carry, *outs_expanded] + + scans = _scan(init, xs) + return scans + + +def trace_scan( + proxy_mode, + func_overload, + combine_fn: Callable, + init: list[torch.Tensor], + xs: list[torch.Tensor], + additional_inputs: tuple[torch.Tensor], +): + from torch._dynamo.utils import clone_input + + with disable_proxy_modes_tracing(): + sample_inits = [clone_input(x_init) for x_init in init] + sample_inputs = [first_slice_copy(x) for x in xs] + sample_additional_inputs = [ + clone_input(x) if isinstance(x, torch.Tensor) else x + for x in additional_inputs + ] + combine_graph = reenter_make_fx(combine_fn)( + *sample_inits, *sample_inputs, *sample_additional_inputs + ) + + outputs = None + for node in combine_graph.graph.nodes: + if node.op == "output": + assert outputs is None + assert len(node.args) == 1 + outputs = node.args[0] + + assert outputs is not None + + carry, output = _extract_carry_and_out(outputs, len(init)) + init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ + i.clone() for i in init + ] + carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ + c.meta["val"] for c in carry + ] + check_meta_consistency( + init_fake_tensors, carry_fake_tensors, "init", "carry", include_contiguity=False + ) + + _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") + + proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) + + args = (combine_graph, init, xs, additional_inputs) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="scan" + ) + + with disable_proxy_modes_tracing(): + scan_length = xs[0].shape[0] + fake_carry, fake_outputs = _extract_carry_and_out( + [o.meta["val"] for o in outputs], len(init) + ) + out = ( + *fake_carry, + *(stack_y(t, scan_length) for t in fake_outputs), + ) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def scan_op_dense(combine_fn, init, xs, additional_inputs): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return generic_scan(combine_fn, init, xs, additional_inputs=additional_inputs) + + +class ScanAutogradOp(torch.autograd.Function): + """ + Example :: + + def combine_fn(x: torch.Tensor, y: torch.Tensor): + next_carry = y = x * y + return next_carry, y + + The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as: + def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor): + return g_y * y + g_carry * y, g_y * x + g_carry * x + + Note: In a real usecase of scan, there may be additional_inputs that participate in the + forward as well as in the backward of the scan operator. For the sake of readability those inputs + have been omitted in the following example, but are included in the subsequent detailed description below + + The forward output of scan is computed as: + carry, ys = scan(combine_fn, init, xs). + + This computation can be unpacked as + c_0, ys_0 = combine_fn(init, xs_0) + c_1, ys_1 = combine_fn(carry_0, xs_1) + c_2, ys_2 = combine_fn(carry_1, xs_2) + ... 
+ c_T, ys_T = combine_fn(carry_(T-1), xs_T) + + We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward, + but we only output (c_T, ys), + where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T]. + + Given the carries and the ys, the gradients for xs and for init can be computed as follows: + We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys, + where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T] + + We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a + scan operation reverse over time. For example, + + g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T) + g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1)) + g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2)) + ... + g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1) + 0 , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0), + + where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t), + the gradients of the carry of step t (i.e. g_c_t) and + the upstream gradient of the output of step t (i.e. g_ys_T) + and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1). + + Through this procedure we end up with the + gradients for the init -> g_init, + the gradients for the xs -> g_xs. + + + NOTE: [scan autograd implementation] + + The forward of scan can be computed as: + 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``: + To use a scan operation for the backward path as well, we need access to the carries from all steps. + Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry. + In particular, we define ``combine_fn_with_carry_checkpoint``: + def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor): + carry, y = combine_fn(x, y) + return carry, (carry, y) + + The scan operator will stack all outputs along the scan dimension. + Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``, + the carries from all steps will be stacked and hence gives us chekpointed_carries + + 2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``: + c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs), + Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``. + However, carries are checkpointed carries from all steps. + As a result of the forward, only the last carry c_T and the ys are returned, + while all carries are saved for the backward. + + The backward of scan can be computed as: + + 3.) Prepare the backward graph: + We prepare the backward graph to be used in the backward function. + We utilize ``create_bw_fn`` to generate the joint function, i.e., + ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs] + + The ctx._combine_fn_bw requires the primals (operands) + followed by the tangents (upstream gradients) from a single step + and produces the gradients of that step, i.e., + g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T). + + 4.) 
Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``: + In the forward, there may be additional inputs that participate in every forward step. + The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps, + which is taken care of in this wrapper. For example: + def combine_fn_bw_grad_accumulation(*args): + carried_g_additional_input = args[:num_additional_inputs] + inputs_bw_fn = args[num_additional_inputs:] + g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn) + new_g_additional_inputs = carried_g_additional_input + g_additional_input_t + # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator + # The ``g_xs_t`` is encoded as the output of the backward scan operator + return [*new_g_additional_inputs, *g_c_t, *g_xs_t] + + 5.) Perform the backward scan as + g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where + bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s): + initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus: + bwd_init = [*initial_g_additional_inputs, *g_c_T]. + + bw_xs consists of the combination of the upstream gradients g_ys, + the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and + the fw_xs. In particular, + bwd_xs = [*g_ys, *bw_carries, *fw_xs]. + + Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input + + As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init + and the gradient for the xs -> g_xs + + NOTE: [scan partial grad handling] + If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients, + i.e., requires_grad=False, there will be still gradients returned for those elements, + but those gradients will be a tensor filled with zeros of the same shape as the element itself. + + A special case are additional_inputs that are not tensors. Such inputs can occur for example with symbolic tracing, + where the shape symbol (SymInt) becomes an additional_input. + For such cases, we compute a ``additional_inputs_tensor_mask``, which is True for elements of additional_inputs + that are tensors and False otherwise. Gradients of additional_inputs are only accumulated if this mask is True, + otherwise, the value of initial_g_additional_inputs is passed, which is None for non-Tensor values. + """ + + @staticmethod + def forward( + ctx, + combine_fn, + num_leaves_init, + num_leaves_xs, + num_additional_inputs, + *operands, + ): + ctx._num_leaves_init = num_leaves_init + ctx._num_leaves_xs = num_leaves_xs + ctx._num_additional_inputs = num_additional_inputs + ctx._combine_fn = combine_fn + init, xs, additional_inputs = split_into_chunks( + operands, [num_leaves_init, num_leaves_xs, num_additional_inputs] + ) + additional_inputs_tensor_mask = get_tensor_mask(additional_inputs) + ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask + + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + # 1.) 
Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint`` + # The wrapper of the forward graph returns carries from all iterations, + # not just from the last iteration. These are required in the backward path + def combine_fn_with_carry_checkpoint(*args): + carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init) + return [ + *carry, + # We additionally checkpoint all the intemediate carry outputs for backward. + *[ + n_c.clone().detach() if isinstance(n_c, torch.Tensor) else n_c + for n_c in carry + ], + *y, + ] + + with torch._C._AutoDispatchBelowAutograd(): + # 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint`` + c_T, carries_ys = _extract_carry_and_out( + scan_op( + combine_fn_with_carry_checkpoint, + init, + xs, + additional_inputs, + ), + num_leaves_init, + ) + + # Collect the carries for each time step from the outs + # and save them for the backward path + carries = list(carries_ys[:num_leaves_init]) + ys = list(carries_ys[num_leaves_init:]) + save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys) + ctx._num_leaves_ys = len(ys) + + return (*c_T, *ys) + + @staticmethod + def backward(ctx, *flat_grads): + r""" + This function computes the gradients of the scan operation. + It does so by using a scan operator using all carries and the upstream gradients (see description above) + + Args: + flat_grads (torch.Tensor): The tensor of flattened upstream gradients. + """ + + # Collect the saved items from the forward + num_leaves_init = ctx._num_leaves_init + num_leaves_xs = ctx._num_leaves_xs + num_leaves_ys = ctx._num_leaves_ys + num_additional_inputs = ctx._num_additional_inputs + additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask + + def prepend_init_to_carries(init, carries): + # Prepare the carries for the backward path. + # This requires to concatenate the init and the carries + return [ + torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0) + for i, c in zip(init, carries) + ] + + def initialize_g_additional_inputs( + additional_inputs, + ): + # The initial gradients for the additional_inputs are all zeros + g_additional_inputs = [ + torch.zeros_like(ai) if ai_tm else None + for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs) + ] + return g_additional_inputs + + # Retrieve the forward inputs and the forward outputs and dissect them + flat_args = saved_tensors_and_symints(ctx) + fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks( + flat_args, + [ + num_leaves_init, + num_leaves_xs, + num_additional_inputs, + num_leaves_init, + num_leaves_ys, + ], + ) + + # 3.) Prepare the backward graph + fw_operands = ( + *fw_init, + *[first_slice_copy(xs) for xs in fw_xs], + *additional_inputs, + ) + ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands) + + # 4.) 
Create the BW wrapper to accumulate the gradients for the additional_inputs + def combine_fn_bw_grad_accumulation(*args): + # Dissect args and re-order them for the ``ctx._combine_fn_bw`` + # The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g] + # The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs] + ( + carried_g_additional_input, + combine_fn_bw_tangents, + combine_fn_bw_primals, + ) = split_into_chunks( + args, + [ + num_additional_inputs, + num_leaves_init + num_leaves_ys, + num_leaves_init + num_leaves_xs + num_additional_inputs, + ], + ) + combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents) + + g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks( + ctx._combine_fn_bw(*combine_fn_bw_args), + [num_leaves_init, num_leaves_xs, num_additional_inputs], + ) + + new_g_additional_inputs = [ + # If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added + carr_g + curr_g if add_inp_tm else carr_g + for add_inp_tm, carr_g, curr_g in zip( + additional_inputs_tensor_mask, + carried_g_additional_input, + g_additional_inputs_t, + ) + ] + + # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator + # The ``g_xs_t`` is encoded as the output of the backward scan operator + return [*new_g_additional_inputs, *g_c_t, *g_xs_t] + + # Materialize the ``combine_fn_bw_grad_accumulation`` + def construct_args_single_step_bw(): + # This function constructs the arguments for a single step of the backward scan. + # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation`` + # The order of the arguments returned is identical to the order the backward scan + # operations provides + + # The following arguments are used for the backward part of the joint graph + # The first argument relates to the gradient accumulation of the additional inputs. + # Because only tensor elements of additional inputs can have requires_grad=True, + # the values for non-tensor elements of additional inputs are None + masked_additional_inputs = [ + a.clone() if add_inp_tm else None + for add_inp_tm, a in zip( + additional_inputs_tensor_mask, additional_inputs + ) + ] + + # The second argument relates to the gradients of the carries. + # Because the arguments are for a single step only, + # only the first slice of the carries is used. + sliced_carries = [first_slice_copy(c) for c in fw_carries] + + # The third argument relates to the gradients of the ys. + # Because the arguments are for a single step only, + # only the first slice of the ys is used. + sliced_ys = [first_slice_copy(o) for o in fw_ys] + + # The following arguments are used for the forward part of the joint graph + # The fourth argument relates to the init for the forward. + # I.e., fw_init + + # The fifth argument relates to the xs for the forward. + # Because the arguments are for a single step only, + # only the first slice of the xs is used. + # Note: It is important to preserve the requires_grad flag of xs + # and thus we use the wrapper function ``first_slice_copy_with_grad`` + fw_xs_slice = first_slice_copy_with_grad(fw_xs) + + # The last argument relates to the additional inputs for the forward. 
+ # I.e., additional_inputs + + return ( + *masked_additional_inputs, + *sliced_carries, + *sliced_ys, + *fw_init, + *fw_xs_slice, + *additional_inputs, + ) + + args_single_step_bw = construct_args_single_step_bw() + + # TODO: we need to materialize the bw graphs because dynamo is unable to + # trace through the joint function when torch.compile torch.autograd.grad. + combine_fn_bw_grad_accumulation_gm = materialize_as_graph( + combine_fn_bw_grad_accumulation, + args_single_step_bw, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + force_enable_grad=True, + ) + + # Decompose the flat_grads into g_c_T, g_ys + g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys]) + + # Initialize the g_additional_inputs with zero-tensors. + # This step is necessary because the gradients of the additional inputs are accumulated in the + # ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point + initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs) + + # Prepend the inits to the carries. + # This is needed, because when computing the gradients, the last carry is not needed + # but the first carry, the init, is required. + bw_carries = prepend_init_to_carries(fw_init, fw_carries) + + # Prepare the xs for the backward scan. + bwd_xs = [*g_ys, *bw_carries, *fw_xs] + + # The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse + bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs] + + # Prepare the bwd_init + bwd_init = [*initial_g_additional_inputs, *g_c_T] + + # 5.) Perform the backwrad scan: + # The ``combine_fn_bw_wrapped`` receives the + # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the + # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs`` + gradients = scan_op( + combine_fn_bw_grad_accumulation_gm, + bwd_init, + bwd_xs, + additional_inputs, + ) + + # Unpack the computed gradients + g_additional_inputs, g_init, g_xs = split_into_chunks( + gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs] + ) + + # The flipping back along the scan dimension is required to get the gradients in the right order for ``xs`` + g_xs = [torch.flip(elem, [0]) for elem in g_xs] + + return *[None] * 4, *g_init, *g_xs, *g_additional_inputs + + +@scan_op.py_autograd_impl +def scan_autograd(combine_fn, init, xs, additional_inputs): + num_leaves_init = len(init) + num_leaves_xs = len(xs) + num_additional_inputs = len(additional_inputs) + + flat_out = ScanAutogradOp.apply( + combine_fn, + num_leaves_init, + num_leaves_xs, + num_additional_inputs, + *(tuple(init) + tuple(xs) + additional_inputs), + ) + return *flat_out[:num_leaves_init], *flat_out[num_leaves_init:] + + +@scan_op.py_impl(ProxyTorchDispatchMode) +def scan_proxy_mode(mode, combine_fn, init, xs, additional_inputs): + return trace_scan(mode, scan_op, combine_fn, init, xs, additional_inputs) + + +@scan_op.py_impl(FakeTensorMode) +def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs): + with mode: + scan_length = xs[0].shape[0] + carry, outputs = _extract_carry_and_out( + combine_fn( + *init, + *[first_slice_copy(inp) for inp in xs], + *additional_inputs, + ), + len(init), + ) + out = ( + *carry, + *(stack_y(t, scan_length) for t in outputs), + ) + return out + + +@scan_op.py_functionalize_impl +def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs): + from torch._higher_order_ops.utils import ( + _check_alias_and_mutation, + 
_maybe_run_with_interpreter, + ) + + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_init = ctx.unwrap_tensors(init) + unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) + + with ctx.redispatch_to_next(): + functional_combine_fn = ctx.functionalize( + _maybe_run_with_interpreter(combine_fn) + ) + sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs] + sample_inputs = list( + itertools.chain( + unwrapped_init, + sample_unwrapped_xs_sliced, + unwrapped_additional_inputs, + ) + ) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + _check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch) + ret = scan_op( + functional_combine_fn, + unwrapped_init, + unwrapped_xs, + unwrapped_additional_inputs, + ) + return ctx.wrap_tensors(ret) + + +# dense implementation for scan. Used for testing only. +def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): + carry_leaves, carry_spec = pytree.tree_flatten(init) + inp_leaves, inp_spec = pytree.tree_flatten(xs) + if xs is None or len(inp_leaves) == 0: + return init, [] + result_flat = [] + carry = carry_leaves + op = reversed if reverse else lambda x: x + + dummy_carry, dummy_out = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten( + [first_slice_copy(elem, dim) for elem in inp_leaves], + inp_spec, + ), + ) + dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out) + num_leaves = len(dummy_out_leaves) + + for ind in op(range(inp_leaves[0].size(dim))): + xs = [elem.select(dim, ind) for elem in inp_leaves] + + carry, y = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(xs, inp_spec), + ) + carry, _ = pytree.tree_flatten(carry) + y, _ = pytree.tree_flatten(y) + result_flat.append(y) + + results = [ + torch.stack([e[leave_ind] for e in op(result_flat)]) + for leave_ind in range(num_leaves) + ] + return ( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(results, dummy_out_spec), + ) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/schema.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..97aaf6632b71e885e6ed5b732de3fb2cd2807510 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/schema.py @@ -0,0 +1,306 @@ +import copy +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch.fx.node import Target + + +# Below is an implementation of generating FunctionSchema from example values. +# This is helpful for generating FunctionSchema for HigherOrderOperator, where +# we don't have a function to inspect and each call of the higher order operator +# would have different schema. +@dataclass(frozen=True) +class HopArgumentInfo: + # Could give a name to the operand by default it's empty string. + name: str + example_value: Any + # Provide an default_value + default_value: Any + # Whether this arugment gets mutated in the hop subgraph. 
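+ # Mutation is surfaced in the generated FunctionSchema as an alias annotation on the corresponding argument (see CArgumentGen.from_hop_argument_info below).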
+ # For output, this should always be False + is_mutated: bool + kw_only: bool + + +class HopArgumentInfoGen: + @staticmethod + def from_example( + example_value: Any, + *, + name: str = "", + default_value: Optional[Any] = None, + is_mutated: bool = False, + kw_only: bool = False, + ) -> HopArgumentInfo: + if default_value is not None: + assert type(example_value) == type( + default_value + ), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}" + + return HopArgumentInfo( + name=name, + example_value=example_value, + default_value=default_value, + is_mutated=is_mutated, + kw_only=kw_only, + ) + + +class CTypeGen: + convert_to_base_ty = { + int: torch._C.IntType.get(), + float: torch._C.FloatType.get(), + str: torch._C.StringType.get(), + bool: torch._C.BoolType.get(), + } + + # should return torch._C.JitType but that annotation is busted + @staticmethod + def from_example(obj: Any) -> Any: + import torch + + if isinstance(obj, torch.fx.GraphModule): + return torch._C.AnyType.get() + elif isinstance(obj, torch.SymInt): + return torch._C.SymIntType.get() + return torch._C._jit_try_infer_type(obj).type() + + +class CArgumentGen: + @staticmethod + def from_hop_argument_info( + arg_idx: int, arg_info: HopArgumentInfo, is_output: bool = False + ) -> Any: + typ = CTypeGen.from_example(arg_info.example_value) + if is_output: + return torch._C.Argument("", typ, None, None, False, None) + + alias_set = set({f"alias::a{arg_idx}"}) if arg_info.is_mutated else set() + alias_info = torch._C._AliasInfo(arg_info.is_mutated, alias_set, alias_set) # type: ignore[attr-defined] + return torch._C.Argument( + arg_info.name, + typ, + None, + arg_info.default_value, + arg_info.kw_only, + alias_info, + ) + + +class HopSchemaGenerator: + def __init__(self, hop: torch._ops.HigherOrderOperator): + self.arg_infos: list[HopArgumentInfo] = [] + self.example_outputs: list[Any] = [] + self.schema_tree_spec: Optional[pytree.TreeSpec] = None + self.hop = hop + + def add_arg( + self, + name: str, + example_value: Any, + default_value: Optional[Any] = None, + is_mutated: bool = False, + kw_only: bool = False, + ) -> None: + if callable(example_value): + assert isinstance( + example_value, (torch.fx.GraphModule, torch._ops.OperatorBase) + ), ( + "Expect callable to be a GraphModule or an. Please call materialize_as_graph first " + f"to turn callable arguments {example_value} into a GraphModule." + ) + _, flat_spec = pytree.tree_flatten(example_value) + if not flat_spec.is_leaf(): + raise RuntimeError( + f"example_value {example_value} is not a leaf node. " + "Please only add flattened inputs to the hop schema. " + "If you need some structure in the arguments, please" + "add_arg for flattened args one by one then " + "call add_schema_tree_spec to register the original pytree " + " spec of the args." + ) + + arg_info = HopArgumentInfoGen.from_example( + example_value=example_value, + name=name, + default_value=default_value, + is_mutated=is_mutated, + kw_only=kw_only, + ) + self.arg_infos.append(arg_info) + + def add_output(self, output: Any) -> None: + self.example_outputs.append(output) + + def add_schema_tree_spec(self, *args: Any, **kwargs: Any) -> None: + """schema tree spec is the tree spec from flattening all inputs to the hop with pytree.tree_flatten + Since torch.FunctionSchema only have proper mutation/alias support for flattened inputs, we need + to store the tree spec in order to reconstruct the inputs to the hop. 
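+
+ Rough usage sketch (placeholder names; assumes the hop is called as
+ my_hop(subgraph_gm, (x, y)) where (x, y) is a tuple of tensors):
+
+ gen = HopSchemaGenerator(my_hop)
+ gen.add_arg("subgraph", subgraph_gm)
+ gen.add_arg("x", x)
+ gen.add_arg("y", y)
+ gen.add_schema_tree_spec(subgraph_gm, (x, y))
+ gen.add_output(example_out)
+ schema = gen.gen_schema()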
+ """ + self.schema_tree_spec = pytree.tree_flatten((args, kwargs))[1] + + def gen_schema(self) -> torch._C.FunctionSchema: + for i, arg_info in enumerate(self.arg_infos): + arg_spec = pytree.tree_flatten(arg_info.example_value)[1] + if not arg_spec.is_leaf() and self.schema_tree_spec is None: + raise RuntimeError( + f"example_value of arg_infos[{i}] is {arg_info.example_value}, which is not a leaf node. " + "Please call add_schema_tree_spec to add a schema tree spec first. " + "Or consider changing the hop's signature to only take flattened arguments." + ) + + return CFunctionSchemaGen.from_hop_argument_info( + str(self.hop), + self.arg_infos, + HopArgumentInfoGen.from_example(tuple(self.example_outputs), name="out"), + self.schema_tree_spec, + ) + + +class CFunctionSchemaGen: + """ + Note: [HigherOrderOperator schema generation] + Each invocation of a HigherOrderOperator will have a different schema. + For example, the schema of torch.cond varies depending on the true_fn and + false_fn. So we need a way to generate the schema for each invocation of a HOP. + + We want to enforce the following invariants for HOP's schema: + 1. Flattened inputs. There should be no pytree structure in it. + 2. Flattened outputs. Note even if the hop returns a single value, it should be wrapped as a tuple. + 3. No aliasing. This includes inp-inp aliasing, inp-out aliasing and out-out aliasing. + + By enforcing these invariants, we could make HOP's schema meets the requirement of schema parser + and makes hop easier to handle downstream. For example, suppose we have an invoke_quant_test HOP: + + class GraphModule(torch.nn.Module): + def forward(self, l_x_, l_y_): + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_, l_y_): + add_ = l_x_.add_(1) + matmul = l_x_ @ l_y_ + sin = matmul.sin() + child = sin.cos() + child_1 = l_x_ + l_y_ + child_2 = l_x_ - l_y_ + child_3 = l_x_ @ l_y_ + return (child, child_1, child_2, child_3) + + By encoding the inputs of hop into a list of HopArgumentInfo and output as a single HopArgumentInfo, + we would get the following schema: + invoke_quant_test(Any arg0, Tensor(!) arg1, Tensor arg2, str scheme="\\"nf4\\"") -> (Tensor, Tensor, Tensor, Tensor) + """ + + @staticmethod + def from_hop_argument_info( + op_name: str, + inp_argument_info: list[HopArgumentInfo], + out_argument_info: HopArgumentInfo, + schema_tree_spec: Optional[pytree.TreeSpec], + ) -> Any: + args = [] + for i, arg_info in enumerate(inp_argument_info): + args.append(CArgumentGen.from_hop_argument_info(i, arg_info)) + + # NOTE: we want the output to always be a single argument with torch._C.TupleType. + assert isinstance( + out_argument_info.example_value, tuple + ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" + assert ( + not out_argument_info.is_mutated + ), "out_argument_info.is_mutated should always be set to False." 
+ rets = None + if len(out_argument_info.example_value) == 1: + rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)] + else: + rets = [ + CArgumentGen.from_hop_argument_info( + i, + HopArgumentInfoGen.from_example( + name=f"out{i}", + example_value=val, + default_value=None, + is_mutated=False, + ), + is_output=True, + ) + for i, val in enumerate(out_argument_info.example_value) + ] + + return HopSchema( + op_name, + "", + args, + rets, + False, + False, + schema_tree_spec, + ) + + +class HopSchema(torch._C.FunctionSchema): + def __init__( + self, + name: str, + overload_name: str, + arguments: list[torch._C.Argument], + returns: list[torch._C.Argument], + is_vararg: bool, + is_varret: bool, + schema_tree_spec: Optional[pytree.TreeSpec], + ): + self.tree_spec = schema_tree_spec + self.is_vararg = is_vararg + self.is_varret = is_varret + super().__init__( + name, + overload_name, + arguments, + returns, + self.is_vararg, + self.is_varret, + ) + + def __deepcopy__(self, memo: Any) -> "HopSchema": + # Need to additionally copy the tree_spec since + # it's not a member of torch._C.FunctionSchema + return HopSchema( + self.name, + self.overload_name, + self.arguments, + self.returns, + self.is_vararg, + self.is_varret, + copy.deepcopy(self.tree_spec), + ) + + +def find_hop_schema( + gm: torch.fx.GraphModule, target: Target +) -> list[torch._C.FunctionSchema]: + schemas = [] + for node in gm.graph.find_nodes(op="call_function", target=target): + + def _get_example_value(node: torch.fx.Node) -> Any: + if node.op == "get_attr": + assert isinstance(node.target, str) + return getattr(gm, node.target) + else: + return ( + node.meta["example_value"] + if "example_value" in node.meta + else node.meta["val"] + ) + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.fx.Node, + _get_example_value, + (node.args, node.kwargs), + ) + schema = node.target.gen_schema(*fake_args, **fake_kwargs) + schemas.append(schema) + return schemas diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f9257c4455d3dc694a2e33183497c2ed658bb99a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +import torch +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._functorch.utils import exposed_in +from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + _temp_remove_pre_dispatch_torch_function_mode, + disable_proxy_modes_tracing, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +@exposed_in("torch") +def strict_mode(callable, operands): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_modes, + ) + + if torch.compiler.is_dynamo_compiling(): + return strict_mode_op(callable, operands) + + with _set_compilation_env(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode: + modes = [metadata_mode, predispatch_mode] + modes = [mode for mode in modes if mode is 
not None] + if modes: + backend = make_eager_backend_with_torch_function_modes(modes) + else: + backend = "eager" + with torch._dynamo.utils.disable_cache_limit(): + return torch.compile( + strict_mode_op, backend=backend, fullgraph=True + )(callable, operands) + + +class StrictMode(HigherOrderOperator): + def __init__(self): + super().__init__("strict_mode") + + def __call__(self, callable, operands): + return super().__call__(callable, operands) + + +strict_mode_op = StrictMode() + + +@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def strict_mode_op_dense(callable, operands): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return callable(*operands) + + +strict_mode_op.py_autograd_impl( + autograd_not_implemented(strict_mode_op, deferred_error=True) +) + + +@strict_mode_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, callable, operands): + return trace_strict_mode(mode, strict_mode_op, callable, operands) + + +def trace_strict_mode(mode, strict_mode_op, callable, operands): + pre_dispatch = getattr(mode, "pre_dispatch", False) + + with disable_proxy_modes_tracing(): + graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands) + + graph_name = mode.tracer.get_fresh_qualname("strict_graph_") + mode.tracer.root.register_module(graph_name, graph) + + args = (graph, operands) + + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + + out_proxy = mode.tracer.create_proxy( + "call_function", strict_mode_op, proxy_args, {}, name="strict_mode" + ) + + out = graph(*operands) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + +@strict_mode_op.py_impl(FakeTensorMode) +def strict_mode_fake_tensor_mode(mode, callable, operands): + with mode: + true_outs = callable(*operands) + return true_outs + + +@strict_mode_op.py_functionalize_impl +def strict_mode_func(ctx, callable, inputs): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + with ctx.redispatch_to_next(): + functional_callable = ctx.functionalize(callable) + + cond_return = strict_mode_op(functional_callable, unwrapped_inputs) + return ctx.wrap_tensors(cond_return) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/torchbind.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/torchbind.py new file mode 100644 index 0000000000000000000000000000000000000000..89859bed12516d619fa370142351157cb83b352e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/torchbind.py @@ -0,0 +1,164 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager + +import torch +from torch._C import DispatchKey # @manual +from torch._functorch._aot_autograd.utils import KNOWN_TYPES +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import ( + _is_script_object, + _ns_and_class_name, + FakeScriptObject, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.node import has_side_effect +from torch.utils import _pytree as pytree + + +log = logging.getLogger(__name__) + + +# The call_torchbind operator represents a method invocation on a torchbind +# object. The calling convention is: +# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) +# We do not expect users to write this operator directly. 
Instead it will be +# emitted by Dynamo when tracing encounters a torchbind object. +class CallTorchBind(HigherOrderOperator): + def __init__(self): + super().__init__("call_torchbind") + + def __call__(self, obj, method, *args, **kwargs): + return super().__call__(obj, method, *args, **kwargs) + + @staticmethod + def schema(obj, method) -> torch.FunctionSchema: + """ + Returns the schema of ``CallTorchbind.__call__``. + """ + assert isinstance(obj, torch._inductor.ir.TorchBindObject) + val = obj.get_real_obj() + schema = val._get_method(method).schema + schema_str = str(schema) + new_schema_str = f"call_torchbind({str(schema.arguments[0].real_type)} {schema.arguments[0].name}," + first_comma_index = schema_str.find(",") + if first_comma_index == -1: + # If no comma is found, find the last closing parenthesis + first_comma_index = schema_str.rfind(") ->") + new_schema_str = new_schema_str + " str method" + schema_str[first_comma_index:] + new_schema = torch._C.parse_schema(new_schema_str) + return new_schema + + +call_torchbind = CallTorchBind() + +# Register this operator as side-effectful with FX. +# TODO: this is not really sufficient. While passes (hopefully) check +# Node.is_impure() and make good decisions, we also assume we can execute the +# graph as many times as we want without changing behavior, which is NOT true of +# ops that mutate torchbind object state. +has_side_effect(call_torchbind) + +_orig_scriptmethod_call = torch.ScriptMethod.__call__ + + +def torchbind_method_redispatch(self, *args, **kwargs): + if _is_script_object(self.raw_owner): + return call_torchbind(self.raw_owner, self.name, *args, **kwargs) + return _orig_scriptmethod_call(self, *args, **kwargs) + + +@contextmanager +def enable_torchbind_tracing(): + """Context manager that acts as a feature flag to enable torchbind tracing + behavior. Once torchbind tracing has been stabilized, we can remove this and + turn it always on. + """ + try: + KNOWN_TYPES.append(torch.ScriptObject) + torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] + yield + finally: + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." + torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] + + +@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) +def call_torchbind_impl(obj, method, *args, **kwargs): + if isinstance(obj, torch.ScriptObject): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + elif isinstance(obj, FakeScriptObject): + return getattr(obj.wrapped_obj, method)(*args, **kwargs) + else: + raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind") + + +@call_torchbind.py_impl(ProxyTorchDispatchMode) +def inner(mode, *args, **kwargs): + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + call_torchbind, + proxy_args, + proxy_kwargs, + ) + out = call_torchbind(*args, **kwargs) + + obj, method, *_rest_args = args + if isinstance(obj, torch.ScriptObject): + ns, class_name = _ns_and_class_name( + obj._type().qualified_name() # type: ignore[attr-defined] + ) + log.warning( + "Tracing torchbind method %s.%s with real ScriptObject. This may" + " cause the original object being mutated. 
If this is not intended," + ' You can register a fake class with torch._library.register_fake_class("%s::%s").', + class_name, + method, + ns, + class_name, + ) + + ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + if "val" not in out_proxy.node.meta: + assert out is None or isinstance( + out, (int, float, bool) + ), "Currently, only these constant dtypes are supported to be returned from torchbind methods." + out_proxy.node.meta["val"] = out + return ret + + +# When tracing with fake script object, the call_torchbind op will return a fake tensor +# When tracing with real script object, the call_torchbind op may return a real tensor, +# we need to convert it to fake tensor mannually. Dynamic shape is surpported. +@call_torchbind.py_impl(FakeTensorMode) +def call_torchbind_fake(mode, *args, **kwargs): + with mode: + out = call_torchbind_impl(*args, **kwargs) + return pytree.tree_map_only( + torch.Tensor, + lambda x: mode.from_tensor(x, static_shapes=True) + if not isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + else x, + out, + ) + + +call_torchbind.py_autograd_impl( + autograd_not_implemented(call_torchbind, deferred_error=True) +) + + +@call_torchbind.py_functionalize_impl +def call_torchbind_func(ctx, *args, **kwargs): + from torch._higher_order_ops.effects import handle_effects + + return handle_effects( + ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs + ) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..ce416d47aeda41b3923220ebafb6a081121f5311 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py @@ -0,0 +1,2051 @@ +import collections +import copy +import dataclasses +import functools +import inspect +import itertools +import logging +import operator +import threading +from collections import defaultdict +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import Never + +import sympy + +import torch.fx as fx +import torch.utils._pytree as pytree +from torch import SymInt, Tensor +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._prims_common import clone_preserve_strides +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import guard_scalar +from torch.types import IntLikeType + + +if TYPE_CHECKING: + from triton._C.libtriton.ir import ( + module as TritonIRModule, + operation as TritonIROperation, + ) + + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.variables.constant import ConstantVariable + from torch._dynamo.variables.functions import TritonKernelVariable + from torch._subclasses.functional_tensor import BaseFunctionalizeAPI + from torch.fx.proxy import Proxy + from torch.utils._triton import has_triton + + TritonMetaParamsType = dict[str, int] + TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...] 
+ TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]] + TritonGridType = Union[TritonGridTupleType, TritonGridCallableType] + + if has_triton(): + from triton.runtime.autotuner import Autotuner, Config as TritonConfig + from triton.runtime.jit import JITFunction + else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + TritonKernelType = Union[Autotuner, JITFunction] + # mypy specifically complains that TritonAutotunerType is not a valid type if Autotuner is not inside of a Union. + TritonAutotunerType = Union[Autotuner] + +log = logging.getLogger("torch._dynamo") + +# e.g. for a host-side Triton TMA API call ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``, +# the metadata will look like ``("experimental", ([50, 60], [32, 15], 4))`` +TMAExperimentalMetadata = tuple[ + str, # type of TMA (should be "experimental") + tuple[ + list[IntLikeType], # dims + list[IntLikeType], # block_dims + IntLikeType, # element_size + ], +] + +# e.g. for host-side Triton TMA API call ``TensorDescriptor.from_tensor(ptr, [32, 64])`` +# the metadata will look like ``("stable", ([32, 64],))`` +TMAStableMetadata = tuple[ + str, # type of TMA ("experimental" or "stable") + tuple[list[IntLikeType],], # block_shape +] + + +def create_tma_experimental_metadata( + dims: list[IntLikeType], + block_dims: list[IntLikeType], + element_size: IntLikeType, +) -> TMAExperimentalMetadata: + return ("experimental", (dims, block_dims, element_size)) + + +def maybe_unpack_tma_experimental_metadata( + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata] +) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]: + if not tma_meta or len(tma_meta) != 2: + return None + if tma_meta[0] == "experimental": + return tma_meta[1] # type: ignore[return-value] + return None + + +def create_tma_stable_metadata( + block_shape: list[IntLikeType], +) -> TMAStableMetadata: + return ("stable", (block_shape,)) + + +def maybe_unpack_tma_stable_metadata( + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata] +) -> Optional[tuple[list[IntLikeType]]]: + if not tma_meta or len(tma_meta) != 2: + return None + if tma_meta[0] == "stable": + return tma_meta[1] # type: ignore[return-value] + return None + + +# TMADescriptorMetadata maps kernel parameter names to the metadata that allows +# reconstructing TMA descriptors from the underlying tensors (passed as kernel +# arguments in the fx graph, instead of the TMA descriptors). +# +# Since there are two TMA APIs (the old "experimental" API and the new "stable" API), +# each entry in the dict is a tuple that starts with a string, either "experimental" +# or "stable". The second entry in the tuple is another tuple, with data that depends +# on the API type (see TMAExperimentalMetadata and TMAStableMetadata above). +# +# These are stored as raw tuples (instead of classes) for ease of serialization. +TMADescriptorMetadata = dict[ + str, # kernel parameter name + Union[TMAExperimentalMetadata, TMAStableMetadata], +] + + +############################################################################### +# Kernel Side Table + + +# We cannot put Triton Kernels into the FX graph as the graph nodes +# do not support arbitrary functions. +# Use a side table. 
+# We use two dicts so that fetching both the kernel and id are O(1) +class KernelSideTable: + id_to_kernel: dict[int, "TritonKernelType"] = {} + kernel_to_id: dict["TritonKernelType", int] = {} + constant_args: dict[int, dict[str, Any]] = {} + lock = threading.Lock() + + # Returns index on the table + def add_kernel(self, kernel: "TritonKernelType") -> int: + with self.lock: + if kernel in self.kernel_to_id: + return self.kernel_to_id[kernel] + + idx = len(self.id_to_kernel) + self.id_to_kernel[idx] = kernel + self.kernel_to_id[kernel] = idx + return idx + + # Returns the triton kernel at the given index + def get_kernel(self, idx: int) -> "TritonKernelType": + # No need to lock here as fetching from dict is atomic + assert idx in self.id_to_kernel + return self.id_to_kernel[idx] + + # Not every constant arg can be added to the graph. Use this side table + # for constant args. + def add_constant_args(self, args: dict[str, Any]) -> int: + with self.lock: + idx = len(self.constant_args) + self.constant_args[idx] = args + return idx + + # Returns the constant args + def get_constant_args(self, idx: int) -> dict[str, Any]: + # No need to lock here as fetching from dict is atomic + assert idx in self.constant_args + return self.constant_args[idx] + + # Resets the table (only meant to be used in unit tests) + # This is only safe assuming single threaded execution + def reset_table(self) -> None: + self.id_to_kernel = {} + self.kernel_to_id = {} + self.constant_args = {} + + +kernel_side_table = KernelSideTable() + + +############################################################################### +# Mutation Tracker + + +@dataclasses.dataclass(frozen=True) +class Param: + idx: int + + +@dataclasses.dataclass(frozen=True) +class Intermediate: + idx: int + + def fake(self) -> bool: + return self.idx < 0 + + +@dataclasses.dataclass(frozen=True) +class Op: + name: str + fn_call_name: Optional[str] + args: list[Union[Param, Intermediate]] + ret: Intermediate = dataclasses.field(repr=False) + # used for scf.yield: see [Note: scf.yield fix-up] + sub_idx: Optional[int] = None + # used for tt.elementwise_inline_asm + # `is_pure = True` assumes the asm block has no side-effects + is_pure: bool = False + + def __post_init__(self) -> None: + if self.name == "tt.call": + assert self.fn_call_name is not None + else: + assert self.fn_call_name is None + + +def generate_ttir( + kernel: "TritonKernelType", + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, +) -> tuple["TritonIRModule", list[str]]: + """ + Uses Triton's internal code generation to create TTIR + """ + import sympy + import triton + import triton.runtime.jit + from triton.compiler.compiler import ASTSource + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + from torch._inductor.utils import ( + get_triton_attrs_descriptor_version, + triton_version_uses_attrs_dict, + TritonAttrsDescriptorVersion, + ) + from torch.utils._triton import has_triton_tensor_descriptor_host_tma + + triton_version = get_triton_attrs_descriptor_version() + + import torch._inductor.ir + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(kernel, Autotuner): + if len(kernel.configs) > 0: + # If we are autotuning, then it doesn't matter which version gets + # picked for tracing purposes, so lets pick the first one + kwargs = {**kwargs, **kernel.configs[0].kwargs} + kernel = kernel.fn + + assert isinstance(kernel, JITFunction) + + context = triton._C.libtriton.ir.context() + target = 
triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options({}) + + # ignore backend-specific kwargs same way as in the native Triton code + # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596 + # why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800 + for name in list(kwargs): + if name not in kernel.arg_names and name in options.__dict__: + kwargs.pop(name) + + if len(kwargs) != len(kernel.arg_names): + raise ValueError( + "Incorrect number of arguments passed to kernel: " + f"passed {list(kwargs.keys())}, expected {kernel.arg_names}." + ) + + # Replace all SymExprs with a regular value for TTIR generation + # Replace all FakeTensor/TensorBox with real tensors + # These replacements are needed for triton's type, key and config functions + ordered_args: dict[str, Any] = {} + for name in kernel.arg_names: + a = kwargs[name] + if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)): + ordered_args[name] = 2 + elif ( + stable_meta := maybe_unpack_tma_stable_metadata( + tma_descriptor_metadata.get(name, None) + ) + ) is not None: + from triton.tools.tensor_descriptor import TensorDescriptor + + block_shape = stable_meta[0] + with torch._C._DisableTorchDispatch(): + # need 16-byte aligned strides + elements_per_dim = max(1, 16 // a.dtype.itemsize) + base_tensor = torch.empty( + [elements_per_dim] * len(block_shape), dtype=a.dtype + ) + ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape) + elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)): + with torch._C._DisableTorchDispatch(): + ordered_args[name] = torch.empty(2, dtype=a.dtype) + else: + ordered_args[name] = a + + def is_stable_tensor_descriptor_arg(arg: Any) -> bool: + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor + + if isinstance(arg, TensorDescriptor): + return True + return False + + def is_tensor_like_arg(arg: Any) -> bool: + if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg): + return True + return False + + # Note: one would expect that each input to the triton kernel maps to + # one input parameter in the TTIR. This is _not_ true for TMA descriptors: + # one TMA descriptor gets converted into: + # * one TMA descriptor input + # * N strides, for a rank-N tensor + # * N sizes, for a rank-N tensor + # To account for this, we inject some fake arg names as placeholders for + # the stride and size parameters. + def get_tensor_names(name: str, arg: Any) -> list[str]: + if isinstance(arg, Tensor): + return [name] + if is_stable_tensor_descriptor_arg(arg): + stable_meta = maybe_unpack_tma_stable_metadata( + tma_descriptor_metadata[name] + ) + assert stable_meta is not None + block_shape = stable_meta[0] + tensor_rank = len(block_shape) + names = [name] + names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank)) + names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank)) + return names + return [] + + ordered_tensor_names = list( + itertools.chain.from_iterable( + get_tensor_names(name, arg) for name, arg in ordered_args.items() + ) + ) + + def _get_specialization(args): # type: ignore[no-untyped-def] + # Support multiple triton versions. + # This code basically copies JITFunction.run() logic to get the attrs to construct an ASTSource. 
+ if triton_version == TritonAttrsDescriptorVersion.V1_COMPILER: + return kernel._get_config(*args) + elif triton_version in { + TritonAttrsDescriptorVersion.V2_BACKENDS, + TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE, + }: + from triton.backends.compiler import AttrsDescriptor # noqa: F401 + + target = triton.runtime.driver.active.get_current_target() + backend_ = triton.compiler.compiler.make_backend(target) + return backend_.get_attrs_descriptor(args, kernel.params) + else: + assert ( + get_triton_attrs_descriptor_version() + == TritonAttrsDescriptorVersion.V4_DICT + ) + # specialize_impl switched to create_specialize_impl in https://github.com/triton-lang/triton/pull/6099 + if hasattr(triton.runtime.jit, "create_specialize_impl"): + try: + # Latest versions of Triton take specialize_extra as an arg to create_specialize_impl + specialize_impl = triton.runtime.jit.create_specialize_impl( + specialize_extra=backend.get_arg_specialization + ) + except TypeError: # Unknown arg `specialize_extra` + # Older versions of Triton take specialize_extra as an arg to specialize_impl + specialize_impl = functools.partial( + triton.runtime.jit.create_specialize_impl(), + specialize_extra=backend.get_arg_specialization, + ) + else: + from triton.runtime.jit import specialize_impl as specialize_impl_orig + + specialize_impl = functools.partial( + specialize_impl_orig, + specialize_extra=backend.get_arg_specialization, + ) + + from triton._utils import find_paths_if, get_iterable_path + + # logic is copied from: binder = create_function_from_signature(self.signature, self.params, backend) + attrvals = [] + for arg, kp in zip(args, kernel.params): + if kp.is_constexpr: + attrvals.append(arg) + else: + spec = specialize_impl( + arg, + is_const=kp.is_const, + specialize_value=not kp.do_not_specialize, + align=not kp.do_not_specialize_on_alignment, + ) + attrvals.append(spec[1]) + + attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) + attrs = { + k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs + } + return attrs + + specialization = _get_specialization(ordered_args.values()) + constants = { + name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg) + } + + if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None: + + def get_signature_value(idx: int, arg: Any) -> str: + if kernel.params[idx].is_constexpr: + return "constexpr" + return mangle_type(arg) + + else: + + def get_signature_value(idx: int, arg: Any) -> str: + return kernel._type_of(kernel.key_of(arg)) + + if triton_version_uses_attrs_dict(): + # In newer versions of Triton, the signature includes constexpr args + signature = { + name: get_signature_value(i, arg) + for i, (name, arg) in enumerate(ordered_args.items()) + } + else: + # In older versions of Triton, the signature does not include constexpr args + signature = { + name: get_signature_value(i, arg) + for i, (name, arg) in enumerate(ordered_args.items()) + if i not in kernel.constexprs + } + + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + src = ASTSource(kernel, signature, constants, specialization) + + # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle + # backward compatibility here. 
+ make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) + get_codegen_implementation_sig_params = len( + inspect.signature(backend.get_codegen_implementation).parameters + ) + if make_ir_sig_params == 2: + ttir_module = src.make_ir(options, context) + elif make_ir_sig_params == 3: + codegen_fns = backend.get_codegen_implementation() + ttir_module = src.make_ir(options, codegen_fns, context) + else: + codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] + codegen_fns = backend.get_codegen_implementation(*codegen_args) + module_map = backend.get_module_map() + ttir_module = src.make_ir(options, codegen_fns, module_map, context) + if not ttir_module.verify(): + raise RuntimeError("Verification for TTIR module has failed") + + return ttir_module, ordered_tensor_names + + +def ttir_to_functions( + ttir_module: "TritonIRModule", +) -> dict[str, dict[Intermediate, list[Op]]]: + """ + Walk the `ttir_module` bottom up to mine the `functions` from + the structured MLIR entities representing the Triton kernel + (mlir::Operation, mlir::Block, mlir::Region). + """ + functions: dict[str, dict[Intermediate, list[Op]]] = {} + + # block id --> op result (Intermediate) --> one or more ops + op_stack: dict[int, dict[Intermediate, list[Op]]] = defaultdict( + lambda: defaultdict(list) + ) + region_id_to_block_ids: dict[int, list[int]] = defaultdict(list) + block_id_to_block_arg_ids: dict[int, list[int]] = {} + replacements: dict[int, Union[Intermediate, Param]] = {} + reindex_map: dict[int, int] = {} + next_fake_intermediate = 0 + + def reindex(idx: int) -> int: + if idx not in reindex_map: + reindex_map[idx] = len(reindex_map) + return reindex_map[idx] + + def mlir_to_functions(op: "TritonIROperation") -> None: + name: str = op.get_name() + if name == "builtin.module": + # this wraps all tt.func ops + return + + operand_ids: list[int] = [ + reindex(op.get_operand(i).id()) for i in range(op.get_num_operands()) + ] + result_ids: list[int] = [ + reindex(op.get_result(i).id()) for i in range(op.get_num_results()) + ] + + child_block_ids: list[int] = [] + for i in [op.get_region(i).id() for i in range(op.get_num_regions())]: + # as the walk is bottom-up, the region_id_to_block_ids[i] + # must be populated by the time we process the enclosing op + child_block_ids.extend(region_id_to_block_ids[i]) + + parent_block_id = -1 + parent_block = op.get_block() + if parent_block is not None: + parent_block_id = parent_block.id() + if parent_block_id not in block_id_to_block_arg_ids: + block_id_to_block_arg_ids[parent_block_id] = [] + for i in range(parent_block.get_num_arguments()): + block_id_to_block_arg_ids[parent_block_id].append( + reindex(parent_block.get_argument(i).id()), + ) + # the region info is collected via ops' parent blocks to be + # used later when the region's encloding op is traversed + parent_region = parent_block.get_parent() + if parent_region is not None: + region_id_to_block_ids[parent_region.id()].append(parent_block_id) + + nonlocal next_fake_intermediate + + if name == "tt.func": + # for function ops: gather and inline + # the ops from all child blocks + fn_ops = defaultdict(list) + for child_block_id in child_block_ids: + for result, block_fn_ops in op_stack.pop(child_block_id).items(): + for block_fn_op in block_fn_ops: + fn_ops[result].append(block_fn_op) + + # replace the corresponding Intermediates in the + # child op args with the function args (Params) + for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]): + replacements[idx] = 
Param(i) + + for fn_op_list in fn_ops.values(): + for fn_op in fn_op_list: + for i in range(len(fn_op.args)): + arg = fn_op.args[i] + seen = set() # to break cycles + # there can be transitive replacements, but likely + # no cycles (we keep the `seen` set just in case) + while ( + isinstance(arg, Intermediate) + and arg.idx in replacements + and arg.idx not in seen + ): + seen.add(arg.idx) + arg = fn_op.args[i] = replacements[arg.idx] + + # next function capture starts + # with empty replacements + replacements.clear() + + fn_name = op.get_str_attr("sym_name") + functions[fn_name] = fn_ops + elif child_block_ids: + if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}: + # for blocked ops: inline the enclosed ops into + # the parent block + rewire the last op in each + # child block to return the block result + return_ops = [] + for block_id in child_block_ids: + if name == "scf.for": + # example: + # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ... + # block args: 2 (%iv, %arg) + # op operands: 4 (%lb, %ub, %step, %init) + # `%arg` is mapping to `%init` + for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): + if i == 0: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + else: + replacements[idx] = Intermediate(operand_ids[i + 2]) + elif name == "scf.while": + # example: + # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ... + # block args: 3 (%arg2, %arg3, %arg4) + # op operands: 3 (%1, %2, %c0_i32_8) + # `%arg2` is mapping to `%1`, `%arg3` is mapping to `%2`, ... + for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): + replacements[idx] = Intermediate(operand_ids[i]) + elif name == "scf.if": + # the scf block args are ignored by the pass. but, as they + # may be used as operands of the ops inside the block + # (and nested blocks inlined in the current block by now), + # they are replaced by new fake Intermediates to avoid "this + # operand is not returned by any other op in the fn" error + # in the downstream analysis + for idx in block_id_to_block_arg_ids[block_id]: + next_fake_intermediate -= 1 + replacements[idx] = Intermediate(next_fake_intermediate) + else: + assert name in ("tt.reduce", "tt.scan") + # wire the block arguments to the op arguments + num_operands = len(operand_ids) + block_arg_ids = block_id_to_block_arg_ids[block_id] + assert len(block_arg_ids) == 2 * num_operands, ( + f"{name} is expected to have twice as " + "many block arguments as op arguments: " + f"{operand_ids=}, {block_arg_ids=}." 
+ ) + for i, idx in enumerate(block_arg_ids): + # for a tt.reduce/tt.scan op with N arguments, the block + # arguments comprise N reduced values followed by + # N current values corresponding to the N op args + replacements[idx] = Intermediate( + operand_ids[i % num_operands] + ) + + if block_id in op_stack: + block_ops = op_stack.pop(block_id) + if not block_ops: + continue + last_ret, last_ops = block_ops.popitem() + if all( + op.name + in ("scf.yield", "tt.reduce.return", "tt.scan.return") + for op in last_ops + ): + # if last_ops are all return ops, treat them separately + return_ops.extend(last_ops) + else: + # otherwise, return last_ops to the block + block_ops[last_ret] = last_ops + for op_result, child_ops in block_ops.items(): + op_stack[parent_block_id][op_result].extend(child_ops) + + scf_results = [Intermediate(idx) for idx in result_ids] + + if return_ops and all( + (op.name == "scf.yield" and len(result_ids) == len(op.args)) + for op in return_ops + ): + # [Note: scf.yield fix-up] + # + # TL;DR: if our scf.yield takes N args, then we'll create N scf.yield ops to handle each of the + # args. + # + # **Context**: + # During mutation analysis, the analysis pass will identify mutating ops (e.g. tt.store) + # and then DFS upwards towards the parameters of the function. Specifically, the analysis pass + # looks at the mutated arg in tt.store; then looks for its source ops; and then recurses on the + # arguments to each of the source ops. + # + # In the case of scf.if/scf.for, we may have multiple return ops, each passed as an arg + # to scf.yield: + # + # %18:2 = scf.if %... -> (!tt.ptr, !tt.ptr) { + # ... + # scf.yield %1, %2 + # } else { + # scf.yield %3, %4 + # } + # + # And for each of the returns of the scf.if, we'd naively assign the source op of each of the + # return values to be the scf.yields. But the scf.yields take _all_ the returns as arguments. + # Therefore, if _any_ of the return values of the scf.if are mutated, then the analysis pass + # would mark _all_ of the yield args as mutated. + # + # **Solution**: + # For the purposes of this analysis pass, we create N yield ops - one for each + # return-val/yield-arg. In the example above, we'll have two scf.yield's for each branch of the + # scf.if. + + for return_op in return_ops: + for i, (scf_result, yield_arg) in enumerate( + zip(scf_results, return_op.args) + ): + sub_yield_op = Op( + return_op.name, + return_op.fn_call_name, + [yield_arg], + return_op.ret, + sub_idx=i, + ) + op_stack[parent_block_id][scf_result].append(sub_yield_op) + + else: + for scf_result in scf_results: + for return_op in return_ops: + op_stack[parent_block_id][scf_result].append(return_op) + else: + raise RuntimeError( + f"Unknown blocked function: {name}. Can't capture the TTIR." 
+ ) + else: + callee = None + if name == "tt.call": + callee = op.get_flat_symbol_ref_attr("callee") + args: list[Union[Param, Intermediate]] = [ + Intermediate(operand) for operand in operand_ids + ] + block_ops = op_stack[parent_block_id] + + is_pure = False + # Handle the case for tt.elementwise_inline_asm to set `is_pure` for mutation analysis + if name == "tt.elementwise_inline_asm": + is_pure = op.get_bool_attr("pure") + + if result_ids: + for result_id in result_ids: + res = Intermediate(result_id) + block_ops[res].append(Op(name, callee, args, res, is_pure=is_pure)) + else: + next_fake_intermediate -= 1 + fake_res = Intermediate(next_fake_intermediate) + block_ops[fake_res].append( + Op(name, callee, args, fake_res, is_pure=is_pure) + ) + + ttir_module.walk(mlir_to_functions) + + return functions + + +class MemoizeWithCycleCheck: + fn: Callable[..., Any] + cache: dict[tuple[Any], Any] + + def __init__(self, fn: Callable[..., Any]) -> None: + self.fn = fn + self.reset() + + def __call__( + self, + functions: dict[str, dict[Intermediate, list[Op]]], + fn_name: str, + *args: Any, + ) -> list[bool]: + key: tuple[Any, ...] = (fn_name, *args) + if key not in self.cache: + self.cache[key] = None + self.cache[key] = self.fn(functions, fn_name, *args) + if self.cache[key] is None: + raise RuntimeError("Recursion is not supported") + return self.cache[key] + + def reset(self) -> None: + self.cache = {} + + +@MemoizeWithCycleCheck +def get_tma_stores( + functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str +) -> set[Union[Intermediate, Param]]: + """ + Identifies all intermediates and parameters that are written to by a + `tt.experimental_descriptor_store`. It tracks only the specific values + written to via experimental_descriptor_store and the input values to + `tt.reinterpret_tensor_descriptor` used to construct the direct inputs + to tt.experimental_descriptor_store - not any recursive values + used to construct those values. + + For example: for + tt.reinterpret_tensor_descriptor(Intermediate(idx=0), ...) + Intermediate(idx=1) = tt.experimental_descriptor_store(Intermediate(idx=0), ...) + this function will return [Intermediate(idx=0), Intermediate(idx=1)], + + However + Intermediate(idx=4) = arith.addptr(Intermediate(idx=2), Intermediate(idx=3)) + Intermediate(idx=5) = tt.experimental_descriptor_store(Intermediate(idx=4), ...) + tt.experimental_descriptor_store(Intermediate(idx=5), ...) + this function will mark only idx=4 and idx=5 (but not idx=2 or idx=3) + + If an intermediate/parameter is passed into a function and is written to + via experimental_descriptor_store within that function, the argument to the + function will also be marked. 
+ """ + + result: set[Union[Intermediate, Param]] = set() + + ops = functions[fn_name] + for op_list in ops.values(): + for op in op_list: + if op.name == "tt.call": + assert op.fn_call_name in functions + tma_stores = get_tma_stores(functions, op.fn_call_name) + for i, inp in enumerate(op.args): + if Param(idx=i) in tma_stores: + result.add(inp) + elif op.name == "tt.experimental_descriptor_store": + assert len(op.args) >= 1 + result.add(op.args[0]) + + for val in list(result): + if val in ops: + if not isinstance(val, Intermediate): + continue + for op in ops[val]: + if op.name == "tt.reinterpret_tensor_descriptor": + assert len(op.args) >= 1 + result.add(op.args[0]) + + return result + + +@MemoizeWithCycleCheck +def analyze_kernel_mutations( + functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int +) -> list[bool]: + """ + Analyzes the graph to detect all sinks from a predefined list of sinks + by using triton's MemWrite trait list. NOTE: What if triton exposed this? + From each sink, it traverses the CFG backwards to identify all the input + pointers that are mutated. + """ + # Name of mutation op to mutated parameter indices + # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td + # All the OPs that have MemWrite trait. + # What if Triton exposed this? + MUTATION_OPS = { + "tt.store": [0], + "tt.atomic_cas": [0], + "tt.atomic_rmw": [0], + "tt.experimental_descriptor_store": [0], + "tt.experimental_tensormap_create": [0], + "tt.descriptor_store": [0], + } + # Ops that we want to bail out on + UNKNOWN_OPS = {"tt.elementwise_inline_asm"} + + stack: list[Union[Param, Intermediate]] = [] + visited = set() + ops = functions[fn_name] + tma_stores = get_tma_stores(functions, fn_name) + + for op_list in ops.values(): + for op in op_list: + # If we encounter an operation with effects that cannot be reliably analyzed + # (e.g. `tt.elementwise_inline_asm`), we assume it does not mutate any input parameters. + if op.name in UNKNOWN_OPS: + if op.name == "tt.elementwise_inline_asm" and op.is_pure: + log.warning( + "TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)" + ) + continue + raise RuntimeError( + f"ttir analysis hit an op we do not know how to analyze: {op.name}" + ) + + if op.name == "tt.experimental_tensormap_create": + # Note: this is how we implement experimental_descriptor_store mutation analysis. + # for on-device TMA. + # experimental_tensormap_store(a, b, ...) stores b to the location specified + # by descriptor in the memory of a. + # To track this, we first find all the intermediates/params to which we store via + # experimental_tensormap_store (get_tma_stores, called above). Then, during this + # analysis we wait to find the corresponding experimental_tensormap_create (if it + # exists), at which point we will mark the global_ptr as mutated (as done below). 
+ assert len(op.args) >= 2 + if op.args[0] in tma_stores: + stack.append(op.args[1]) + + if op.name == "tt.call": + assert op.fn_call_name in functions + mutations = analyze_kernel_mutations( + functions, op.fn_call_name, len(op.args) + ) + stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated) + else: + stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, [])) + + # The following is an iterative DFS algorithm + mutated = [False] * num_args + while stack: + arg = stack.pop() + if arg in visited: + continue + + visited.add(arg) + + if isinstance(arg, Param): + if arg.idx >= num_args: + # This is an argument defined in the kernel, not passed in + continue + mutated[arg.idx] = True + elif isinstance(arg, Intermediate) and not arg.fake(): + for op in ops[arg]: + # Skip arguments to load + if op.name != "tt.load": + stack.extend(op.args) + return mutated + + +def identify_mutated_tensors( + kernel: "TritonKernelType", + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, +) -> list[str]: + """ + Given a triton kernel and the arguments for this kernel, this function + 1) Retrieves the TTIR converted version of the kernel from Triton's API. + 2) Parses the TTIR and creates a control flow graph + 3) Analyzes the graph to detect all input tensor mutations + """ + + ttir_module = None + functions = None + try: + ttir_module, ordered_tensor_names = generate_ttir( + kernel, kwargs, tma_descriptor_metadata + ) + + # extract functions from TTIR using MLIR bindings exposed by Triton code + functions = ttir_to_functions(ttir_module) + + assert functions is not None + kernel_name = next(iter(functions.keys())) + # Triton codegen modifies the name + assert kernel.fn.__name__ in kernel_name + # Reset the cache between top level invocations + # The cache for analyze kernel mutations is mainly used for cycle + # detection, so each top level invocation needs a clean cache + analyze_kernel_mutations.reset() + get_tma_stores.reset() + mutations = analyze_kernel_mutations( + functions, kernel_name, len(ordered_tensor_names) + ) + + return [ + ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated + ] + except Exception: + log.warning( + "Encountered an exception in identify_mutated_tensors, assuming every input is mutated", + exc_info=True, + ) + if ttir_module is not None: + log.debug("TTIR:\n%s", str(ttir_module)) + if functions is not None: + log.debug("functions:") + for name, fn in functions.items(): + log.debug("===\t%s\t===", name) + for ret, ops in fn.items(): + log.debug("%s\t=>\t%s", ret, ops) + return [key for key, value in kwargs.items() if isinstance(value, Tensor)] + + +############################################################################### +# Triton Kernel Wrappers + + +# Used for wrapping a Triton Kernel +class TritonKernelWrapperMutation(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("triton_kernel_wrapper_mutation", cacheable=True) + + def __call__( + self, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + ) -> Any: + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kwargs=kwargs, + ) + + +triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() + + +# Used for wrapping a Triton Kernel in a functional manner +class TritonKernelWrapperFunctional(HigherOrderOperator): + def 
__init__(self) -> None: + super().__init__("triton_kernel_wrapper_functional", cacheable=True) + + def __call__( + self, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + tensors_to_clone: list[str], + ) -> dict[str, Any]: + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kwargs=kwargs, + tensors_to_clone=tensors_to_clone, + ) + + +triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() + + +@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_mutation_dense( + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], +) -> None: + from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code + + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + + if len(grid) == 1: + grid_fn = grid[0] + else: + fn_name, code = user_defined_kernel_grid_fn_code( + kernel.fn.__name__, kernel.configs, grid + ) + namespace: dict[str, Any] = {} + exec(code, namespace) + grid_fn = namespace[fn_name] + + if tma_descriptor_metadata: + # as we need to launch the kernel here, we "unwrap" the + # tma_descriptor_metadata, create the TMA descriptors + # from it, and replace the tensors in the kwargs by the + # correspoinding TMA descriptors before launching + kwargs = kwargs.copy() + for k, v in tma_descriptor_metadata.items(): + tensor = kwargs[k] + if (exp_meta := maybe_unpack_tma_experimental_metadata(v)) is not None: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + dims, block_dims, element_size = exp_meta + create_tma_descriptor = ( + create_1d_tma_descriptor + if len(dims) == 1 + else create_2d_tma_descriptor + ) + kwargs[k] = create_tma_descriptor( + tensor.data_ptr(), + *dims, + *block_dims, + element_size, + ) + else: + stable_meta = maybe_unpack_tma_stable_metadata(v) + assert stable_meta is not None + from triton.tools.tensor_descriptor import TensorDescriptor + + block_shape = stable_meta[0] + kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape) + + # move as many positional arguments from dicts to args as we + # can to circumvent the bug with the kwargs and pre_/post_hook: + # https://github.com/triton-lang/triton/issues/5082 + # TODO: remove this when the Triton issue above is fixed + args = [] + # copy kwargs and constant_args here to + # avoid mutating the original inputs + kwargs = kwargs.copy() + constant_args = constant_args.copy() + for name in kernel.arg_names: + if name in kwargs: + args.append(kwargs.pop(name)) + elif name in constant_args: + args.append(constant_args.pop(name)) + else: + break + + kernel[grid_fn](*args, **kwargs, **constant_args) + + +@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode) +def triton_kernel_wrapper_mutation_fake_tensor_mode( + mode: FakeTensorMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], +) -> None: + with mode: + return None + + +@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta) +def _( + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: 
TMADescriptorMetadata, + kwargs: dict[str, Any], +) -> None: + return None + + +def trace_triton_kernel_wrapper( + proxy_mode: ProxyTorchDispatchMode, + func_overload: Callable[..., Any], + node_args: dict[str, Any], +) -> Optional[dict[str, Any]]: + with disable_proxy_modes_tracing(): + out = func_overload(**node_args) + + proxy_args = pytree.tree_map( + proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr] + ) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", + func_overload, + (), + proxy_args, + name=func_overload.__name__ + "_proxy", + ) + + ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + return ret + + +@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], +) -> None: + trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_mutation, + { + "kernel_idx": kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grid, + "tma_descriptor_metadata": tma_descriptor_metadata, + "kwargs": kwargs, + }, + ) + + return None + + +def get_mutated_tensors( + kernel_idx: int, + constant_args_idx: int, + kwargs: dict[str, Any], + tma_descriptor_metadata: TMADescriptorMetadata, +) -> list[str]: + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + return identify_mutated_tensors( + kernel, {**kwargs, **constant_args}, tma_descriptor_metadata + ) + + +@triton_kernel_wrapper_mutation.py_functionalize_impl +def triton_kernel_wrapper_mutation_functionalize( + ctx: "BaseFunctionalizeAPI", + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], +) -> None: + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each + # other, and one gets mutated in kernel, and later another gets mutated, + # they are no longer equal. Fix this by graph breaking on this condition + # earlier in dynamo. 
+ tensors_to_clone = get_mutated_tensors( + kernel_idx, constant_args_idx, unwrapped_kwargs, tma_descriptor_metadata + ) + with ctx.redispatch_to_next(): + unwrapped_outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + + assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys())) + for key, output_arg in unwrapped_outputs.items(): + if not isinstance(output_arg, Tensor): + continue + input_arg = kwargs[key] + assert isinstance(input_arg, Tensor) + + ctx.replace(input_arg, output_arg) + # indicate that above replace is hidden from autograd + ctx.mark_mutation_hidden_from_autograd(input_arg) + ctx.commit_update(input_arg) + ctx.sync(input_arg) + return None + + +@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd) +def triton_kernel_wrapper_functional_dense( + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + tensors_to_clone: list[str], +) -> dict[str, Any]: + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). + # Requires https://github.com/pytorch/pytorch/issues/109240 + kwargs = { + key: (clone_preserve_strides(val) if key in tensors_to_clone else val) + for key, val in kwargs.items() + } + triton_kernel_wrapper_mutation( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kwargs=kwargs, + ) + return {key: val for key, val in kwargs.items() if key in tensors_to_clone} + + +@triton_kernel_wrapper_functional.py_impl(FakeTensorMode) +def triton_kernel_wrapper_functional_fake_tensor_mode( + mode: FakeTensorMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + tensors_to_clone: list[str], +) -> dict[str, Any]: + # TODO(oulgen): For performance reasons, we want to ensure that these + # `clone_preserve_strides` calls are never executed at runtime + # (inductor should always optimize them away). 
+ # Requires https://github.com/pytorch/pytorch/issues/109240 + with mode: + return { + key: clone_preserve_strides(val) + for key, val in kwargs.items() + if key in tensors_to_clone + } + + +@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode) +def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + tensors_to_clone: list[str], +) -> dict[str, Any]: + ret = trace_triton_kernel_wrapper( + mode, + triton_kernel_wrapper_functional, + { + "kernel_idx": kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": grid, + "tma_descriptor_metadata": tma_descriptor_metadata, + "kwargs": kwargs, + "tensors_to_clone": tensors_to_clone, + }, + ) + assert ret is not None + return ret + + +@triton_kernel_wrapper_functional.py_functionalize_impl +def triton_kernel_wrapper_functional_functionalize( + ctx: "BaseFunctionalizeAPI", + kernel_idx: int, + constant_args_idx: int, + grid: list["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: dict[str, Any], + tensors_to_clone: list[str], +) -> dict[str, Any]: + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] + with ctx.redispatch_to_next(): + outputs = triton_kernel_wrapper_functional( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kwargs=unwrapped_kwargs, + tensors_to_clone=tensors_to_clone, + ) + return ctx.wrap_tensors(outputs) # type: ignore[return-value,arg-type] + + +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU) + +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) +triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU) + + +############################################################################### +# The "TritonHOPifier": a class that transforms a call to a triton kernel into +# a call to the triton_kernel_wrapper_mutation HOP. + + +class TritonHOPifier: + """Orchestrator for converting a user-defined triton kernel into a call + to the triton_kernel_wrapper_mutation HOP. + + It has two main use cases. + + 1. 
When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable + and uses the TritonHOPifier to convert calls to the TritonKernelVariable + into a call to the HOP. + + 2. In order to capture a user-defined triton kernel while performing + tracing (via make_fx or non-strict export), a user must annotate their + triton kernel with the `wrap_triton` decorator. The decorator uses + TritonHOPifier to convert calls to the triton kernel into a call + to the HOP (which can then be traced). + + Because Dynamo has its own calling conventions for e.g. invoking a user-defined function + TritonHOPifier is an abstract class that can be overridden by its subclasses. + """ + + def raise_unsupported(self, msg: str) -> Never: + raise NotImplementedError("abstract method") + + def is_callable(self, maybe_callable: Any) -> bool: + raise NotImplementedError("abstract method") + + def get_value(self, val: Any) -> Any: + raise NotImplementedError("abstract method") + + def call_grid( # type: ignore[no-untyped-def] + self, + grid, + meta, + tx, + ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]: + raise NotImplementedError("abstract method") + + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Optional[ + Union["TritonKernelVariable", "TraceableTritonKernelWrapper"] + ], + name: str, + ) -> Any: + raise NotImplementedError("abstract method") + + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: list, + kwargs: dict, + tx: Optional["InstructionTranslator"], + variable: Optional[ + Union["TritonKernelVariable", "TraceableTritonKernelWrapper"] + ], + ) -> Any: + raise NotImplementedError("abstract method") + + def maybe_unpack_configs( + self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"] + ) -> list["TritonConfig"]: + raise NotImplementedError("abstract method") + + def maybe_unpack_heuristic_result(self, result: Any) -> Any: + raise NotImplementedError("abstract method") + + @staticmethod + def do_prune_configs( # type: ignore[no-untyped-def] + autotuner: "TritonAutotunerType", + early_config_prune: Optional[Callable], + perf_model: Optional[Callable], + top_k: float, + configs: list, + named_args: dict, + kwargs: dict, + ) -> list["TritonConfig"]: + # Reimplement autotuner.prune_configs(...) 
here + # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950 + # We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model + # These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo + + if early_config_prune: + configs = early_config_prune(configs, named_args, **kwargs) + + if perf_model: + # we assert top_k is a float before calling this + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(configs) * top_k) + elif not isinstance(top_k, int): + """ + Slice index must be an integer, SupportsIndex or None + """ + raise TypeError( + "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int" + ) + if len(configs) > top_k: + est_timing = [ + ( + config, + float( + perf_model(**named_args, **kwargs, **config.all_kwargs()) + ), + ) + for config in configs + ] + configs = [ + config[0] + for config in sorted(est_timing, key=operator.itemgetter(1))[:top_k] + ] + return configs + + def call_HOP( # type: ignore[no-untyped-def] + self, + variable, + grids, + combined_args: dict[str, Any], + tx, + ) -> Optional["ConstantVariable"]: + raise NotImplementedError("abstract method") + + def check_grid( # type: ignore[no-untyped-def] + self, grid + ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]: + raise NotImplementedError("abstract method") + + def init_variable( + self, + variable: Union["TraceableTritonKernelWrapper", "TritonKernelVariable"], + kernel: "TritonKernelType", + kernel_idx: Optional[int], + grid: Optional["TritonGridType"], + ) -> None: + from triton.runtime.autotuner import Autotuner + + assert kernel is not None + + variable.kernel = kernel + variable.kernel_idx = kernel_side_table.add_kernel(kernel) + + assert kernel_idx is None or variable.kernel_idx == kernel_idx + + variable.grid = grid + + if isinstance(kernel, Autotuner): + import torch + import torch._dynamo + + # We only support configs, keys, and restore_value arguments + # of triton.autotune. Make sure other arguments are defaulted. + defaults = inspect.signature(Autotuner.__init__).parameters + # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep. + # The call to get_first_attr is to maintain backward-compatibility. + + def defaults_ok( + attr: str, alternates: tuple[str, ...], values: tuple[Any, ...] 
+ ) -> bool: + if attr not in defaults: + return True + value = torch._dynamo.utils.get_first_attr(kernel, attr, *alternates) + if value == defaults[attr].default: + return True + return value in values + + if ( + not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args + and ( + not defaults_ok("num_warmups", ("warmup",), (25, None)) + or not defaults_ok("num_reps", ("rep",), (100, None)) + or not defaults_ok("use_cuda_graph", (), (False,)) + ) + ): + self.raise_unsupported( + "Only configs, keys, restore_value, and reset_to_zero are supported for triton.autotune" + ) + if ( + not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args + and ( + # pre_hook requires running arbitrary code at runtime, which we cannot handle at this time + # https://github.com/pytorch/pytorch/issues/139059 + # we can't support pre_hook or post_hook in user defined triton kernels at the moment, + # as they require the ability to execute code at runtime (AOTI can't support this) + ( + hasattr(kernel, "user_defined_pre_hook") + and kernel.user_defined_pre_hook is not False + ) + or ( + hasattr(kernel, "user_defined_post_hook") + and kernel.user_defined_post_hook is not False + ) + or ( + # Check Config passed to autotuner in configs + any(cfg.pre_hook is not None for cfg in kernel.configs) + ) + ) + ): + self.raise_unsupported( + "pre_hook and post_hook are not supported in triton.Autotune or triton.Config" + ) + + def call_getitem( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + ) -> Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]: + # __getitem__ should only be called if we don't already have a grid + # Only grid needs to be passed + if variable.grid is not None or len(args) != 1: + self.raise_unsupported( + "Triton kernels should be called with only a single grid" + ) + + return type(variable)( + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=args[0], + ) + + def call_run( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + kwargs: dict[str, Any], + tx: Optional["InstructionTranslator"], + ) -> Optional["ConstantVariable"]: + if "grid" not in kwargs: + self.raise_unsupported("Triton kernel requires to be called with a grid") + grid = kwargs.pop("grid") + kwargs.pop("warmup", None) + # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args) + return self.call_triton_kernel( + type(variable)( + kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid + ), + args, + kwargs, + tx, + ) + + def call_triton_kernel( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + kwargs: dict[str, Any], + tx: Optional["InstructionTranslator"], + ) -> Optional["ConstantVariable"]: + from triton import JITFunction + from triton.runtime.autotuner import autotune, Autotuner, Config, Heuristics + + # Check if num_ctas is in kwargs + if "num_ctas" in kwargs: + self.raise_unsupported( + "Passing num_ctas directly to the Triton kernel is not supported. " + "Please use a Config in @triton.autotune instead." 
+ ) + + # Make sure the kernel has a grid + if variable.grid is None: + self.raise_unsupported("Triton kernels should always be called with a grid") + + # raise an exception if there are multiple @triton.autotune decorators + iter_kernel = variable.kernel + autotuner_count = 0 + while not isinstance(iter_kernel, JITFunction): + if isinstance(iter_kernel, Autotuner): + autotuner_count += 1 + if autotuner_count > 1: + self.raise_unsupported( + "Passing multiple @triton.autotune decorators is not supported. " + "Please use a single @triton.autotune decorator instead." + ) + iter_kernel = iter_kernel.fn + + # Process the @triton.heuristics decorator: + # - We know there is only 1 autotuner decorator here + # - We can apply the heuristic to all triton.Configs in the order that the decorators appear + # This way, when the config is selected, the heuristics have already been applied. + # - Decorators that appear *before* the autotuner are already processed correctly + if isinstance(variable.kernel, Autotuner) and isinstance( + variable.kernel.fn, Heuristics + ): + # unwrap the heuristics decorator, we don't need it anymore + # variable.kernel ==> Autotuner + # variable.kernel.fn ==> Heuristics + # ... + # There can be arbitrarily many heuristics wrappers here! + # ... + # variable.kernel.fn ==> JITFunction + + # Copy the configs, we are going to be modifying them + new_configs = copy.deepcopy(variable.kernel.configs) + + named_args = dict(zip(variable.kernel.arg_names, args)) + + # Iterate through all of the heuristics wrappers that come after the autotune wrapper + iter_kernel = variable.kernel.fn + while isinstance(iter_kernel, Heuristics): + # For each config, apply the heuristic fn(s) + for config_idx in range(len(new_configs)): + for kwarg_key, heuristic_fn in iter_kernel.values.items(): + # Run heuristics on the combined configs + kwargs + heuristic_result = self.call_user_defined_fn( + heuristic_fn, + [ + { + **named_args, + **kwargs, + **new_configs[config_idx].__dict__["kwargs"], + }, + ], + {}, + tx, + variable, + ) + + # Update the kwargs in each config + # maybe_unpack_heuristic_result raises unsupported if the value is non-constant + new_configs[config_idx].__dict__["kwargs"][ + kwarg_key + ] = self.maybe_unpack_heuristic_result(heuristic_result) + + iter_kernel = iter_kernel.fn + assert isinstance(iter_kernel, JITFunction) + prune_configs_by = { + "perf_model": variable.kernel.perf_model, + "early_config_prune": variable.kernel.early_config_prune, + "configs_top_k": variable.kernel.configs_top_k, + } + new_kernel = autotune( + configs=new_configs, key=[], prune_configs_by=prune_configs_by + )(iter_kernel) + # create a new variable to contain the new (wrapped) kernel; + # skip kernel_idx to get a new record in the kernel side table + new_var = type(variable)(new_kernel, None, variable.grid) + return self.call_triton_kernel(new_var, args, kwargs, tx) + + SPECIAL_CONFIG_NAMES = { + "num_warps", + "num_stages", + "num_ctas", + "num_consumer_groups", + "num_buffers_warp_spec", + } + + # move special config names to configs out of kwargs + special_kwargs = {} + for name in SPECIAL_CONFIG_NAMES: + if name in kwargs: + # remove special kwargs from `kwargs` + val = kwargs.pop(name) + special_kwargs[name] = self.get_value(val) + + if special_kwargs: + if isinstance(variable.kernel, Autotuner): + # if there is Autotuner already, set + # special kwargs to each of its configs + new_configs = copy.deepcopy(variable.kernel.configs) + for config in new_configs: + 
config.__dict__.update(special_kwargs) + prune_configs_by = { + "perf_model": variable.kernel.perf_model, + "early_config_prune": variable.kernel.early_config_prune, + "configs_top_k": variable.kernel.configs_top_k, + } + + new_kernel = autotune( + configs=new_configs, key=[], prune_configs_by=prune_configs_by + )(variable.kernel.fn) + else: + # if there is no Autotuner, wrap the kernel into a + # new one with a single config with special kwargs + new_config = Config(kwargs={}, **special_kwargs) + + new_kernel = autotune(configs=[new_config], key=[])(variable.kernel) + + # create a new variable to contain the new (wrapped) kernel; + # skip kernel_idx to get a new record in the kernel side table + new_var = type(variable)(new_kernel, None, variable.grid) + return self.call_triton_kernel(new_var, args, kwargs, tx) + + if isinstance(variable.kernel, Autotuner): + special_param_names = [] + for name in SPECIAL_CONFIG_NAMES: + if name in variable.kernel.fn.arg_names: + special_param_names.append(name) + + if special_param_names: + # If the Triton kernel has SPECIAL_CONFIG_NAMES in parameters, those should + # be passed from the kernel configs: the behavior of Triton runtime is that + # those values get folded into the kernel arguments iff there are parameters + # with the same name. Normally the values of those parameters are defined + # outside the `kwargs` part of the autotuning configs. Here we move them to + # the `kwargs` part (if they're absent there) to facilitate passing them as + # arguments to the kernel downstream. + updated = False + new_configs = copy.deepcopy(variable.kernel.configs) + for config in new_configs: + for name in special_param_names: + if name not in config.__dict__["kwargs"]: + assert ( + name in config.__dict__ + ), f"{name} must be in autotuning configs to be used as a kernel parameter" + config.__dict__["kwargs"][name] = config.__dict__[name] + updated = True + + if updated: + prune_configs_by = { + "perf_model": variable.kernel.perf_model, + "early_config_prune": variable.kernel.early_config_prune, + "configs_top_k": variable.kernel.configs_top_k, + } + + new_kernel = autotune( + configs=new_configs, prune_configs_by=prune_configs_by, key=[] + )(variable.kernel.fn) + new_var = type(variable)(new_kernel, None, variable.grid) + return self.call_triton_kernel(new_var, args, kwargs, tx) + + # These are the default values in upstream Triton + # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950 + default_perf_model = None + default_early_config_prune = None + + # run prune_configs_by + if isinstance(variable.kernel, Autotuner) and ( + variable.kernel.perf_model != default_perf_model + or variable.kernel.early_config_prune != default_early_config_prune + ): + # Prune the configs + named_args = dict(zip(variable.kernel.arg_names, args)) + + # The source information is important here so the guards are installed correctly + + wrapped_early_configs_prune = self.wrap_user_defined_obj( + variable.kernel.early_config_prune, + tx, + variable, + "early_config_prune", + ) + + wrapped_perf_model = self.wrap_user_defined_obj( + variable.kernel.perf_model, tx, variable, "perf_model" + ) + + wrapped_configs_top_k = self.wrap_user_defined_obj( + variable.kernel.configs_top_k, tx, variable, "configs_top_k" + ) + + wrapped_configs = self.wrap_user_defined_obj( + variable.kernel.configs, tx, variable, "configs" + ) + + pruned_configs = self.call_user_defined_fn( + self.do_prune_configs, + [ + 
variable, + wrapped_early_configs_prune, + wrapped_perf_model, + wrapped_configs_top_k, + wrapped_configs, + named_args, + kwargs, + ], + {}, + tx, + variable, + ) + + pruned_configs = self.maybe_unpack_configs(pruned_configs, tx) + + # after pruning the configs, create a new autotuner object with + # these configs and recurse. + new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn) + # create a new variable to contain the new (wrapped) kernel; + # skip kernel_idx to get a new record in the kernel side table + new_var = type(variable)(new_kernel, None, variable.grid) + return self.call_triton_kernel(new_var, args, kwargs, tx) + + # Both for grid's meta as well as for the kernel, we need combined + # args and kwargs combined and normalized + combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs} + + # precompute the grid for the kernel + configs = ( + [config.kwargs for config in variable.kernel.configs] + if isinstance(variable.kernel, Autotuner) + else [{}] + ) + grids = [] + for config_args in configs: + # If the grid is a function, then lets execute it and convert it to + # a list + grid = variable.grid + assert grid is not None + if self.is_callable(grid): + # Populate the special "meta" argument to call the grid function + meta = {**combined_args_raw, **config_args} + grid = self.call_grid(grid, meta, tx) # type: ignore[arg-type] + grids.append(self.check_grid(grid)) + + for i in range(len(grids)): + if not isinstance(grids[i], tuple): + self.raise_unsupported("Only tuple grids are supported") + # inductor expects all grids to be 3-tuple so lets make it + if len(grids[i]) == 1: + grids[i] = (grids[i][0], 1, 1) + elif len(grids[i]) == 2: + grids[i] = (grids[i][0], grids[i][1], 1) + elif len(grids[i]) > 3: + self.raise_unsupported("Grid can have at most rank 3") + + assert len(grids) != 0 + if isinstance(variable.kernel, JITFunction): + constexprs = variable.kernel.constexprs + else: + # If we are looking at an @triton.autotune decorator, the nested function should be a JITFunction + # This is because we don't support @triton.heuristics or nested @triton.autotune decorators yet + assert isinstance(variable.kernel, Autotuner) + constexprs = variable.kernel.fn.constexprs + + for idx, arg_name in enumerate(variable.kernel.arg_names): + if idx in constexprs: + if arg_name in combined_args_raw: + # [Note: Specialize tl.constexpr args in user-defined triton kernels] + # This arg is marked as tl.constexpr. That means that triton will recompile every time + # this value changes. + # https://github.com/pytorch/pytorch/issues/136504 + # One option is to correctly pass the symints in so that the symbolic expressions are defined + # when the triton code is being executed. + # But since triton will have to recompile either way, we instead just specialize on the value. 
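+                # Illustrative aside (not part of the original source): what
+                # "specialize on the value" looks like in practice. Suppose a
+                # hypothetical kernel declares BLOCK: tl.constexpr and is called
+                # under torch.compile(dynamic=True) with a SymInt-backed value:
+                #
+                #     my_kernel[grid](x_ptr, n_elements, BLOCK=block_size)
+                #
+                # Rather than threading the symbolic expression through, the
+                # concrete integer behind block_size (say 128) is recorded for
+                # BLOCK, since Triton would recompile on a new value anyway.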
+ # + # Depending on the type of `variable` we might expect different types for the symbolic args: + # either SymNodeVariables (for TritonKernelVariables) or SymInts (TracingTritonKernelWrapper) + combined_args_raw[arg_name] = variable.specialize_symbolic( + combined_args_raw[arg_name] + ) + return self.call_HOP(variable, grids, combined_args_raw, tx) + + +############################################################################### +# Helpers for wrap_triton API that makes a user-defined triton kernel traceable into +# a graph via make_fx or non-strict export (coming soon) + + +class TracingTritonHOPifier(TritonHOPifier): + def raise_unsupported(self, msg: str) -> Never: + raise RuntimeError(msg) + + def is_callable(self, maybe_callable: Any) -> bool: + return callable(maybe_callable) + + def get_value(self, val: Any) -> Any: + return val + + def call_grid( + self, + grid: "TritonGridCallableType", + meta: "TritonMetaParamsType", + tx: None, + ) -> tuple[Union[int, sympy.Expr, SymInt], ...]: + assert tx is None + assert isinstance(meta, dict) + assert callable(grid) + return grid(meta) + + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Optional[ + Union["TritonKernelVariable", "TraceableTritonKernelWrapper"] + ], + name: str, + ) -> Any: + assert tx is None + return user_obj + + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: list, + kwargs: dict, + tx: Optional["InstructionTranslator"], + variable: Optional[ + Union["TritonKernelVariable", "TraceableTritonKernelWrapper"] + ], + ) -> Any: + assert isinstance(args, list) + assert isinstance(kwargs, dict) + assert callable(user_fn) + return user_fn(*args, **kwargs) + + def maybe_unpack_configs( + self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"] + ) -> list["TritonConfig"]: + assert isinstance(configs, list) + return configs + + def maybe_unpack_heuristic_result(self, result: Any) -> Any: + return result + + def check_grid( + self, + grid: "TritonGridType", + ) -> tuple[Union[int, sympy.Expr, SymInt], ...]: + if not isinstance(grid, collections.abc.Sequence): + raise RuntimeError( + "wrap_triton can only handle grids that resolve to Sequence[int]." + ) + # normalize to tuple + return tuple(grid) + + def store_non_graphable_args( + self, + combined_args: dict[str, Any], + ) -> tuple[dict, int]: + """ + Some args cannot be stored in the FX graph. + Put them in the side table. 
+ """ + + def is_graphable(val: Any) -> bool: + return isinstance(val, (fx.node.base_types, fx.Node)) + + non_graphable_args = { + k: v for k, v in combined_args.items() if not is_graphable(v) + } + graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} + + constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) + + return graphable_args, constant_args_idx + + def call_HOP( + self, + variable: "TraceableTritonKernelWrapper", + grids: list["TritonGridTupleType"], + combined_args: dict[str, Any], + tx: None, + ) -> None: + assert tx is None + assert isinstance(variable, TraceableTritonKernelWrapper) + + graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args) + + assert isinstance(variable.kernel_idx, int) + return triton_kernel_wrapper_mutation( + kernel_idx=variable.kernel_idx, + constant_args_idx=constant_args_idx, + grid=grids, # type: ignore[arg-type] + # TMA descriptor capturing not yet + # supported in non-dynamo tracing + tma_descriptor_metadata={}, + kwargs=graphable_args, + ) + + +tracing_triton_hopifier_singleton = TracingTritonHOPifier() + + +class TraceableTritonKernelWrapper: + kernel: "TritonKernelType" + kernel_idx: Optional[int] + grid: Optional["TritonGridType"] + + def __init__( + self, + kernel: "TritonKernelType", + kernel_idx: Optional[int], + grid: Optional["TritonGridType"], + ) -> None: + self.kernel = None + self.grid = None + tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + assert self.kernel is not None + + def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper": + return tracing_triton_hopifier_singleton.call_getitem(self, args) # type: ignore[return-value] + + def run(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any: + from torch._library.triton import is_wrap_triton_enabled + + if is_wrap_triton_enabled(): + return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None) + else: + assert self.kernel is not None + return self.kernel.run(*args, **kwargs) + + def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any: + from torch._library.triton import is_wrap_triton_enabled + + if is_wrap_triton_enabled(): + return tracing_triton_hopifier_singleton.call_triton_kernel( + self, args, kwargs, None + ) + else: + assert self.kernel is not None + return self.kernel[self.grid](*args, **kwargs) + + def specialize_symbolic(self, arg: Sequence[Any]) -> Any: + import torch + + # See [Note: Specialize tl.constexpr args in user-defined triton kernels] + if isinstance(arg, (torch.SymInt, torch.SymBool, torch.SymFloat)): + return guard_scalar(arg) + return arg diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/utils.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b103c1f6b5b65c5b956cac53d9a8d6ffb9b80b8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/utils.py @@ -0,0 +1,1134 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +from contextlib import contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import Any, Callable, Optional, overload, TypeVar, Union + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch._dispatch.python import suspend_functionalization +from torch._guards import detect_fake_mode +from torch._higher_order_ops.schema import HopSchema +from torch._ops import HigherOrderOperator, 
OperatorBase, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import ( + disable_functional_mode, + FunctionalTensor, +) +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + disable_proxy_modes_tracing, + make_fx, +) +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.multiprocessing.reductions import StorageWeakRef + + +@dataclass +class UnsupportedAliasMutationException(RuntimeError): + reason: str + + +def autograd_not_implemented_inner( + operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any +) -> Any: + """If autograd is enabled and any of the arguments require grad this will either + raise an error or return a DelayedError depending on the value of delayed. + + Args: + operator: The Operator to call with the *args and **kwargs with + op_name: The name of the Operator + delayed_error: If True, return a DelayedError instead of raising an error + args: The flattened operands to the Operator + kwargs: The keyword arguments to the Operator + + Raises: + RuntimeError: If autograd is enabled and any of the arguments to the Operator + """ + with torch._C._AutoDispatchBelowAutograd(): + result = operator(*args, **kwargs) + flat_operands = pytree.arg_tree_leaves(*args) + if torch.is_grad_enabled() and any( + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) + ): + if delayed_error: + err_fn = torch._C._functions.DelayedError( + f"Autograd not implemented for {str(operator)}", + 1, + ) + + def fake_requires_grad(tensor): + if torch.is_floating_point(tensor) or torch.is_complex(tensor): + tensor = tensor.detach() + tensor.requires_grad = True + return tensor + + return pytree.tree_map_only( + torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result + ) + else: + raise RuntimeError(f"Autograd not implemented for {str(operator)}") + return result + + +def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable: + def inner(*args, **kwargs): + return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs) + + return inner + + +def _maybe_run_with_interpreter(fn): + maybe_interpreted_fn = fn + if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta(): + # Running graph with interpreter is needed for propagating the stack_trace + def graph_with_interpreter(*args): + with fx_traceback.preserve_node_meta(): + return torch.fx.Interpreter(fn).run(*args) + + maybe_interpreted_fn = graph_with_interpreter + return maybe_interpreted_fn + + +def _maybe_compile_and_run_fn(fn, *args): + if not torch.compiler.is_dynamo_compiling(): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(fn, backend=backend, fullgraph=True)(*args) + else: + return fn(*args) + + +def reenter_make_fx(fn): + from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + + @functools.wraps(fn) + def wrapped(*args): + assert ( + _CURRENT_MAKE_FX_TRACER is not None + ), "Cannot reenter make_fx when we're not under a make_fx tracing session" + return 
_CURRENT_MAKE_FX_TRACER.trace_subgraph( + _maybe_run_with_interpreter(fn), *args + ) + + return wrapped + + +def _maybe_reenter_make_fx(fn): + from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + + if _CURRENT_MAKE_FX_TRACER is not None: + return reenter_make_fx(fn) + else: + + def _maybe_make_fx_with_fake_mode(fn): + @functools.wraps(fn) + def wrapped(*args): + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(args) + if fake_mode is None: + # we creaeta a fake_mode here to make sure we could + # trace the graph with data-dependent calls e.g. .item() + return make_fx(fn, tracing_mode="fake")(*args) + # Tracing with real if all inputs have been fakfied + return make_fx(fn)(*args) + + return wrapped + + return _maybe_make_fx_with_fake_mode(fn) + + +def check_meta_consistency( + lhs_list: list[Union[torch.Tensor, torch.SymInt, int]], + rhs_list: list[Union[torch.Tensor, torch.SymInt, int]], + lhs_name: str, + rhs_name: str, + include_contiguity: bool = True, +) -> None: + def diff_meta_pairs( + lhs_list: list[Union[torch.Tensor, torch.SymInt, int]], + rhs_list: list[Union[torch.Tensor, torch.SymInt, int]], + ) -> list[str]: + def diff_meta( + lhs: Union[torch.Tensor, torch.SymInt, int], + rhs: Union[torch.Tensor, torch.SymInt, int], + ) -> str: + if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor): + return ", ".join( + diff_tensor_meta( + _extract_tensor_metadata( + lhs, include_contiguity=include_contiguity + ), + _extract_tensor_metadata( + rhs, include_contiguity=include_contiguity + ), + check_grad=False, + ) + ) + else: + + def _both_int_types(lhs, rhs): + return isinstance(lhs, (int, torch.SymInt)) and isinstance( + rhs, (int, torch.SymInt) + ) + + def _both_tensor(lhs, rhs): + return isinstance(lhs, torch.Tensor) and isinstance( + rhs, torch.Tensor + ) + + if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs): + return f"type: {lhs} vs {rhs}" + + return "" + + # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata + def diff_device( + lhs: Union[torch.Tensor, torch.SymInt, int], + rhs: Union[torch.Tensor, torch.SymInt, int], + ) -> str: + if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor): + if ( + rhs.device.type == lhs.device.type + and rhs.device.index == lhs.device.index + ): + return "" + else: + return "device" + return "" + + if len(lhs_list) != len(rhs_list): + raise torch._dynamo.exc.UncapturedHigherOrderOpError( + f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}" + ) + all_diffs = [] + for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)): + if diff := diff_meta(lhs, rhs): + all_diffs.append( + f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}" + ) + if diff := diff_device(lhs, rhs): + all_diffs.append( + f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}" + ) + return all_diffs + + if all_diffs := diff_meta_pairs(lhs_list, rhs_list): + diff_str = "\n".join(all_diffs) + raise torch._dynamo.exc.UncapturedHigherOrderOpError( + f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}" + ) + + +@contextmanager +def _set_compilation_env(): + _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag + _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs + # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds + # the top-level frame produces no graph, the default behavior is to 
fallback to eager. + # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary. + # For while_loop, during inspecting the inner call, we trace into the python dispathcer + # logic, which is not tracable as of today. So the proper fix can be either 1. allow dispatch + # logic to be dynamo tracable or 2. fixing https://github.com/pytorch/pytorch/issues/144360. + # but it exposes some bugs in existing tests so we have to have a temporary flag to control + # the behavior, which allows dynamo to store an empty graph for a frame without falling back to eager + try: + # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo + # once we are confident fx tracing works with dynamo. + torch.fx._symbolic_trace._is_fx_tracing_flag = False + torch._dynamo.config.allow_empty_graphs = True + yield + finally: + torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing + torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs + + +# The invariant here is that we always trace the branch with fake tensor +def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): + fake_mode = detect_fake_mode(inputs) + tracing_mode = "real" + if fake_mode is None: + fake_mode = nullcontext() + tracing_mode = "fake" + + # Note: we need to turn off proxy tensor mode to avoid tracing infra + # code that happens in make_fx e.g. we now call as_strided when wrapping tensor + # as fake tensor. + with fake_mode, disable_proxy_modes_tracing(): + gm = make_fx( + fn, + tracing_mode=tracing_mode, + pre_dispatch=pre_dispatch, + _error_on_data_dependent_ops=False, + )(*inputs) + if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts( + gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True + ) + return gm + + +def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): + try: + gm = _maybe_fake_tracing(gm, inputs, pre_dispatch) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + example_inputs = [ + ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder") + ] + ( + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + inp_mutation, + ) = check_input_alias_and_mutation(gm, example_inputs) + return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation + + +def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations): + if any(len(a) > 0 for a in aliases): + # TODO: Investigate here further which node is exactly aliasing + raise RuntimeError( + f"{name} where aliases appear. " + + f"In particular, these inputs \ + {set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} " # noqa: C401 + + "get aliased. Please ensure that this doesn't happen." + ) + if len(input_mutations): + # TODO: Investigate here further which node is exactly mutating the inputs + raise RuntimeError( + f"{name} where the inputs are mutated. " + + f"In particular, these nodes are mutating the inputs \ + {set(el for el in input_mutations)}." # noqa: C401 + + "Please ensure that this doesn't happen." 
+ ) + + +def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False): + ( + _, + _, + _, + ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) + + return len(inp_mutation) > 0 + + +def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): + ( + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) + return ( + any( + ( + len(inp_inp_alias_map) > 0, + len(inp_out_alias_map) > 0, + len(out_out_alias_map) > 0, + ) + ), + len(inp_mutation) > 0, + ) + + +def _collect_fake_inputs(inputs): + from torch._subclasses.fake_tensor import FakeTensor + + # Get the example values of the inputs. + inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = [] + for inp in inputs: + if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)): + inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp + if hasattr(inp, "meta"): + val = inp.meta["example_value"] + if isinstance(val, torch.Tensor): + if torch._C._functorch.is_batchedtensor( + val + ) or torch._C._functorch.is_functionaltensor(val): + # This case is for batched or functional tensors + # Unwrap the tensors + while torch._C._functorch.is_batchedtensor( + val + ) or torch._C._functorch.is_functionaltensor(val): + val = torch._C._functorch.get_unwrapped(val) + assert isinstance(val, FakeTensor) + inputs_fake.append(val) + else: + # This is the standard case of a TensorVariable + assert isinstance(val, FakeTensor) + inputs_fake.append(val) + else: + # This case is for SymInts and other non-Tensor elements + assert not isinstance(val, torch.Tensor) + inputs_fake.append(val) + else: + # This case is for ints + assert isinstance(inp, int) + inputs_fake.append(inp) + + return inputs_fake + + +def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch): + aliases, inp_mutation = has_potential_input_alias_or_mutation( + graph_module, inputs_fake, pre_dispatch=pre_dispatch + ) + if aliases: + raise RuntimeError( + f"{name} might be aliasing the input or the output!" + ) # noqa: F541 + if inp_mutation: + raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541 + + +def unique_graph_id(proxy_mode, prefix): + """Returns a unique name and id for a graph to be added to a proxy_mode tracer""" + # There are probably better ways - I know that create_arg has some self incrementing name + # magic to it, but since we explicitly have to get the name for register_module, + # I was not sure how to do that. This kinda simulates it. 
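+    # Illustrative sketch (not part of the original source) of the helper defined
+    # below, which probes the tracer root for the first free "<prefix>_<i>" name:
+    #
+    #     root = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
+    #     unique_graph_name_with_root(root, "cond_graph")   # -> (0, "cond_graph_0")
+    #     root.register_module("cond_graph_0", torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()))
+    #     unique_graph_name_with_root(root, "cond_graph")   # -> (1, "cond_graph_1")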
+ return unique_graph_name_with_root(proxy_mode.tracer.root, prefix) + + +def unique_graph_name_with_root( + root: torch.fx.GraphModule, prefix: str +) -> tuple[int, str]: + next_name = None + i = 0 + while not next_name: + candidate = f"{prefix}_{i}" + if hasattr(root, candidate): + i += 1 + else: + next_name = candidate + return i, next_name + + +def _from_fun(t): + from torch._functorch.aot_autograd import from_fun + + if isinstance(t, torch.Tensor): + if t.dtype != torch.bool: + return torch.empty_strided( + t.size(), + t.stride(), + dtype=t.dtype, + requires_grad=t.requires_grad, + device=t.device, + ) + else: + # clone of a functional tensor produces a functional tensor + # but we want to avoid it so we clone a non-functional version + maybe_unfunc_t = t + if isinstance(t, FunctionalTensor): + torch._sync(t) + maybe_unfunc_t = from_fun(t) + elif torch._is_functional_tensor(t): + # need to handle both types of functionalization here: + # these are the tensors that came from the user, + # which could be either FunctionalTensorWrapper or FunctionalTensor + torch._sync(t) + maybe_unfunc_t = torch._from_functional_tensor(t) + return maybe_unfunc_t.clone() + return t + + +def clone_outputs_aliasing_inputs(args): + input_storage = { + StorageWeakRef(arg._typed_storage()) + for arg in args + if isinstance(arg, torch.Tensor) + } + + def maybe_clone(t): + if ( + isinstance(t, torch.Tensor) + and StorageWeakRef(t._typed_storage()) in input_storage + ): + return t.clone() + return t + + return maybe_clone + + +def prepare_fw_with_masks(fn): + def fw_with_masks(*args): + fw_out = fn(*args) + return fw_out, [ + True if isinstance(ret, torch.Tensor) and ret.requires_grad else False + for ret in fw_out + ] + + return fw_with_masks + + +def prepare_fw_with_masks_all_requires_grad(fn): + def fw_with_masks(*args): + fw_out = fn(*args) + # Note [force all outputs to be require grad] + # Instead of using the original fn, we set the output of original + # fn to all require grad. This is consistent with the behavior + # of autograd.Function, where if any one of the inputs requires grad + # all output will be require grad. This also makes the downstream + # require_gradness reasoning much easier. + if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args): + fw_out = pytree.tree_map_only( + torch.Tensor, lambda x: x.requires_grad_(True), fw_out + ) + return fw_out, pytree.tree_map_only( + torch.Tensor, lambda x: x.requires_grad, fw_out + ) + + return fw_with_masks + + +# This function replaces None gradients with all-zero gradients. +# `None` gradients are problematic for CUDA graphs. Those gradients are +# replaced with an all-zero tensor for better optimization +def unmask_none_gradients(grads, operands): + allowed_types = (torch.Tensor, int, torch.SymInt) + assert all( + isinstance(o, allowed_types) for o in operands + ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}" + + unmasked_grads = [] + for g, o in zip(grads, operands): + if g is not None: + unmasked_grads.append(g) + else: + # In case the operand is an int or a torch.SymInt, return None + # This can happen for lifted_arguments. 
E.g., the shapes of a dynamic tensor are lifted and passed + # as additional arguments + unmasked_grads.append( + torch.zeros_like(o) if isinstance(o, torch.Tensor) else None + ) + + return unmasked_grads + + +def _maybe_fake_prop_ignore_unbacked(fn, args): + with ExitStack() as ctx_stack: + if (fake_mode := detect_fake_mode(args)) is not None: + ctx_stack.enter_context(fake_mode) + if fake_mode.shape_env is not None: + ctx_stack.enter_context( + fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + return fn(*args) + + +def redirect_to_mode(hop: OperatorBase, mode): + """Utility for redispatching HOP to underlying mode + + Args: + hop: The HOP to redispatch + mode: The mode to redispatch to + + Returns: + A decorated function that implements the HOP for the given mode + """ + + @hop.py_impl(mode) + def impl(mode, *args, **kwargs): + return mode.__torch_dispatch__(hop, [], args, kwargs) + + return impl + + +# TODO: The parameter use_output_and_grad_bw is required because some operations +# that utilize this function, such as the while_loop, may require (grad, fwd_outputs) +def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs): + from torch._functorch.aot_autograd import AOTConfig, create_joint + + # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys + # between Autograd and Python key. Currently, we only suspend functionalization but more can be + # added when required. Will encounter two problems if we don't suspend functionalization: + # + # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper, + # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching. + # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to + # fetch the proxy for the inputs and fail to capture any operations on them. + # + # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further + # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer + # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore, + # when creating the output node, it fails to associate the wrapped tensor with its proxy. + # Instead, it will create _tensor_constant as output. 
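+    # Illustrative sketch (not part of the original source), with a made-up body fn:
+    #
+    #     def f(x):
+    #         return (x.sin(),)
+    #
+    #     fw_inputs = (torch.randn(3, requires_grad=True),)
+    #     fw_outputs = f(*fw_inputs)
+    #     fw_graph, joint_graph = create_fw_bw_graph(f, False, fw_inputs, fw_outputs)
+    #
+    # fw_graph traces just the forward (sin); joint_graph takes
+    # (*example_grads, *fw_inputs) and returns the input gradients
+    # (here grad * x.cos()), with outputs that alias the inputs cloned.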
+ + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + example_grad = [_from_fun(out) for out in fw_outputs] + num_grads = len(example_grad) + fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs) + + def joint_fn(*joint_operands_grads): + if use_output_and_grad_bw: + grads = joint_operands_grads[0] + inputs = joint_operands_grads[1][-1:] + else: + grads = joint_operands_grads[:num_grads] + inputs = joint_operands_grads[num_grads:] + + joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config) + _, grads = joint( + list(inputs), + [grad for grad in grads if grad is not None and grad.requires_grad], + ) + + # Unmask None gradients to all-zero gradients + unmasked_grads = unmask_none_gradients(grads, inputs) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads) + + return pytree.tree_map(maybe_clone, unmasked_grads) + + if use_output_and_grad_bw: + example_xs_out = list(fw_inputs) + list(fw_outputs) + joint_graph = _maybe_reenter_make_fx(joint_fn)( + (list(example_grad), list(example_xs_out)) + ) + else: + example_xs_out = list(fw_inputs) + joint_graph = _maybe_reenter_make_fx(joint_fn)( + *(list(example_grad) + list(example_xs_out)) + ) + + return fw_graph, joint_graph + + +def _unstack_pytree(xs): + flat_xs, inspec = pytree.tree_flatten(xs) + if not all(isinstance(xs, torch.Tensor) for xs in flat_xs): + raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") + + if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): + raise RuntimeError( + f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" + ) + + a = zip(*flat_xs) + + pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a] + return pytrees + + +def _stack_pytree(pytrees): + flat_out = [] + out_spec = None + for pt in pytrees: + flat_pt, out_spec = pytree.tree_flatten(pt) + flat_out.append(flat_pt) + assert out_spec is not None + b = zip(*flat_out) + stacked_out = [] + for leaves in b: + if all(isinstance(leaf, torch.Tensor) for leaf in leaves): + stacked_out.append(torch.stack(leaves)) + elif all(leaf is None for leaf in leaves): + # Backward graph can return None output when forward inputs doesn't require grad. + # When we eagerly execute backward graph, we need to call _stack_pytree on its output, + # therefore we need to deal with None output. + stacked_out.append(None) # type: ignore[arg-type] + else: + raise RuntimeError(f"Cannot stack {leaves}.") + return pytree.tree_unflatten(stacked_out, out_spec) + + +# We cannot call save_for_backward for symints. This helper function +# can be used to save symints as direct attributes of ctx in autograd.Function. +# +# For example, if args = (x, y, s0, z, s1), +# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos: +# partitioned_args[0] = (x, y, z) +# partitioned_args[1] = (s0, s1) +# pos = (0, 0, 1, 0, 1) +# pos list keeps track of which partition the args +# is partitioned into in order to recover it in saved_tensors_and_symints. +# +# In saved_tensors_and_symints, we can recover the original args by: +# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]]. 
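+# Illustrative sketch (not part of the original source): typical use inside a
+# torch.autograd.Function with hypothetical forward/backward bodies:
+#
+#     class MyOp(torch.autograd.Function):
+#         @staticmethod
+#         def forward(ctx, x, y, s0, z, s1):
+#             save_tensors_and_symints_for_backward(ctx, (x, y, s0, z, s1))
+#             return x + y + z
+#
+#         @staticmethod
+#         def backward(ctx, grad):
+#             x, y, s0, z, s1 = saved_tensors_and_symints(ctx)
+#             return grad, grad, None, grad, None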
+# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists. +def save_tensors_and_symints_for_backward(ctx, args): + assert all( + isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args + ), args + partitioned_args: list[Any] = [[], []] + pos = [] + for arg in args: + idx = 0 if isinstance(arg, torch.Tensor) else 1 + partitioned_args[idx].append(arg) + pos.append(idx) + + assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute." + assert not hasattr(ctx, "pos"), "ctx already has pos attribute." + ctx.save_for_backward(*partitioned_args[0]) + ctx.sym_int_args = partitioned_args[1] + ctx.pos = pos + + +def saved_tensors_and_symints(ctx): + args = [] + t_idx = 0 + s_idx = 0 + saved_tensors = ctx.saved_tensors + for p in ctx.pos: + if p == 0: + args.append(saved_tensors[t_idx]) + t_idx += 1 + else: + args.append(ctx.sym_int_args[s_idx]) + s_idx += 1 + assert t_idx + s_idx == len(ctx.pos) + return tuple(args) + + +def get_dummy_aot_autograd_config(): + from torch._functorch.aot_autograd import AOTConfig + + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + +# Slices off the first element of a given dimension +def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor: + return torch.select_copy(t, dim, 0) + + +# Reports the difference between meta of two tensors in a string +def diff_tensor_meta( + meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True +) -> list[str]: + from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode + + pair_diffs = [] + for meta_name in TensorMetadata._fields: + if not check_grad and meta_name == "requires_grad": + continue + val1 = getattr(meta1, meta_name) + val2 = getattr(meta2, meta_name) + try: + if val1 != val2: + pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'") + except GuardOnDataDependentSymNode as _: + pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'") + continue + return pair_diffs + + +# Note [lifted arg types in hop] +# For dynamoed hops, we automatically lift the free symbols in tensors as arguments. +# This has implications for the types of lifted args for different dispatch keys: +# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.Symint +# lifted args because it's on the path of torch.compile(dynamic=True). +# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need +# to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner +# hops may receive int inputs from the shape of outer tensor inputs. +# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs. +def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]): + allowed_types = (torch.Tensor, int, torch.SymInt) + assert all( + isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args + ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}" + + +# TODO: Return a more detailed information as to which node +# causes a mutation or an alias. 
This may requires a per operator tensor version checking +def check_input_alias_and_mutation( + gm: torch.fx.GraphModule, + fake_args: list[FakeTensor], +) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]: + ( + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + mutated_inputs, + ) = check_input_alias_and_mutation_return_outputs(gm, fake_args)[:-1] + return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs + + +def check_input_alias_and_mutation_return_outputs( + gm: torch.fx.GraphModule, + fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]], +) -> tuple[ + dict[int, int], + dict[int, int], + dict[int, int], + list[int], + Union[tuple[Any, ...], list[Any]], +]: + # This function can be called under autograd, functional, proxy and fake tensor mode. + # We need to return either a fake tensor or a real tensor depending on the mode. + # to detect the input mutation/aliasing. + with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization(): + + def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor: + if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t): + return torch.empty_strided( + t.size(), + t.stride(), + dtype=t.dtype, + requires_grad=t.requires_grad, + device=t.device, + ) + return t + + fake_args = pytree.tree_map_only( + torch.Tensor, _from_functional_tensor, fake_args + ) + # We want to disable active functional, proxy and fake modes if any. + # to create a encapsulated environment for fake tensor prop + with torch.utils._python_dispatch._disable_current_modes(): + """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias + in the graph module gm. It checks whether input tensor versions have + changed after run gm once to detect mutation and checks tensor storage + to detect alias. + """ + + def _tensor_version(t) -> Optional[int]: + if isinstance(t, torch.Tensor): + if not isinstance(t, FakeTensor): + raise RuntimeError("Only fake tensor is allowed") + return t._version + return None + + def _tensor_storage(t) -> StorageWeakRef: + return StorageWeakRef(t._typed_storage()) + + def _get_shape_env( + fake_args, + ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]: + # detect_fake_mode requires there could be only one active fake mode. This + # restricts the usage of this function because the global TracingContext + # has a persistent fake mode but fake tensors can be created + # outside of the tracing context (e.g. in testing). + # Instead, we just look at fake_args fake tensor mode + if len(fake_args) == 0: + return torch.fx.experimental.symbolic_shapes.ShapeEnv() + + for arg in fake_args: + if isinstance(arg, FakeTensor): + return arg.fake_mode.shape_env + return None + + # Clone the fake args to avoid mutating the original fake args + with ExitStack() as ctx_stack: + # We need to re-use prev_fake_mode's shape env to resolve + # the runtime assertions for unbacked symbols. + new_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=_get_shape_env(fake_args), + allow_non_fake_inputs=False, + ) + # We need to temporarily turn inference_mode off because + # under inference mode, tensor version counter is not tracked. 
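+            # Illustrative aside (not part of the original source): the mutation
+            # check below relies on the per-tensor version counter, e.g.
+            #
+            #     t = torch.ones(2)
+            #     v0 = t._version
+            #     t.add_(1)
+            #     assert t._version > v0  # the in-place op bumped the version
+            #
+            # which is why inference mode (no version tracking) is turned off here.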
+ no_inference_mode_ctx = torch.inference_mode(False) + ctx_stack.enter_context(new_fake_mode) + ctx_stack.enter_context(no_inference_mode_ctx) + if new_fake_mode.shape_env is not None: + ctx_stack.enter_context( + new_fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + + # create new fake tensors in new fake mode to avoid mutating original tensors + cloned = [ + torch.empty_strided( + arg.size(), + arg.stride(), + dtype=arg.dtype, + device=arg.device, + requires_grad=arg.requires_grad, + layout=arg.layout, + ) + if isinstance(arg, torch.Tensor) + else arg + for arg in fake_args + ] + before = [_tensor_version(arg) for arg in cloned] + outputs = gm(*cloned) + outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs + after = [_tensor_version(arg) for arg in cloned] + mutated_inputs = [ + i for i, (v1, v2) in enumerate(zip(before, after)) if v1 != v2 + ] + # We need to analyze the original fake_args to detect + # inp-inp alias. + inp_storage_map = { + _tensor_storage(inp): i + for i, inp in enumerate(fake_args) + if isinstance(inp, torch.Tensor) + } + inp_inp_alias_map = { + i: inp_storage_map[_tensor_storage(inp)] + for i, inp in enumerate(fake_args) + if isinstance(inp, torch.Tensor) + and inp_storage_map[_tensor_storage(inp)] != i + } + out_storage_map = { + _tensor_storage(out): i + for i, out in enumerate(outputs) + if isinstance(out, torch.Tensor) + } + out_out_alias_map = { + i: out_storage_map[_tensor_storage(out)] + for i, out in enumerate(outputs) + if isinstance(out, torch.Tensor) + and out_storage_map[_tensor_storage(out)] != i + } + inp_out_alias_map = { + i: out_storage_map[_tensor_storage(inp)] + for i, inp in enumerate(cloned) + if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map + } + return ( + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + mutated_inputs, + outputs, + ) + + +registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {} + + +F = TypeVar("F", bound=Callable) + + +@overload +def register_fake(hop, fn: None = None) -> Callable[[F], F]: + ... + + +@overload +def register_fake(hop, fn: F) -> F: + ... + + +def register_fake(hop, fn=None): + """ + Register a fake function for a HOP. This is conceptually equivalent of the + register_fake utility for the custom ops. The registered function is called + inside the fake_tensor _dispatch_impl. + """ + assert hop not in registered_hop_fake_fns + + def register(func: F) -> F: + from torch._subclasses.fake_tensor import FakeTensorMode + + redirect_to_mode(hop, FakeTensorMode) + + registered_hop_fake_fns[hop] = func + return func + + if fn is None: + return register + return register(fn) + + +class FunctionalizeCtxWrapper: + """ + This is a dummy wrapper to facilitate fake tensor caching. + + For AOT Dispatcher metadata collection pass, HOPs go from functionalization + key to fake tensor key. The functionalization key wraps the subgraphs in a + function, which changes from call to call even though the subgraph might + still be same. + + To enable fake tensor caching, we just wrap the ctx and subgraph in this + class and then use the subgraph as the hash. 
+ """ + + # Prevents PYTORCH_TEST_WITH_DYNAMO=1 test failures + @torch._disable_dynamo + def __init__(self, ctx, subgraph): + self.ctx = ctx + self.subgraph = subgraph + + def __hash__(self): + return id(self.subgraph) + + def __repr__(self): + return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})" + + def __call__(self, *args, **kwargs): + if isinstance(self.subgraph, torch.fx.GraphModule): + # Running graph with interpreter is needed for propagating the stack_trace + with fx_traceback.preserve_node_meta(): + return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)( + *args, **kwargs + ) + return self.ctx.functionalize(self.subgraph)(*args, **kwargs) + + +# A wrapper over HigherOrderOperator that also carries its schema +class HopInstance: + def __init__(self, op: HigherOrderOperator, schema: HopSchema): + assert isinstance(op, HigherOrderOperator), op + self._op = op + # Using "_" to be consistent with how we access _schema of OpOverload + self._schema = schema + + def __call__(self, *args, **kwargs): + return self._op(*args, **kwargs) + + @staticmethod + def create(hop: HigherOrderOperator, *args, **kwargs): + return HopInstance(hop, hop.gen_schema(*args, **kwargs)) + + +# This call_op can be used to call a HopInstance with +# flat args and kwargs. We need to make use of the hop's schema's tree_spec +# to unflatten the args and kwargs before calling the hop. +def call_op(op: Union[OpOverload, HopInstance], args, kwargs): + if isinstance(op, OpOverload): + return op(*args, **kwargs) + + assert isinstance(op, HopInstance), op + schema = op._schema + bound_args = list(args) + bound_kwargs = {} + for arg in schema.arguments[len(bound_args) :]: + assert arg.name in kwargs, (arg.name, kwargs) + val = kwargs[arg.name] + if not arg.kwarg_only: + bound_args.append(val) + else: + bound_kwargs[arg.name] = val + + if schema.tree_spec is not None: + assert len(bound_args) == len(schema.arguments) and len(bound_kwargs) == 0 + args, kwargs = pytree.tree_unflatten(bound_args, schema.tree_spec) + return op(*args, **kwargs) + else: + assert len(bound_args) + len(bound_kwargs) == len(schema.arguments) + return op(*bound_args, **bound_kwargs) + + +def materialize_as_graph( + fn: Callable, + args: tuple[Any], + include_key_set: Optional[torch._C.DispatchKeySet] = None, + exclude_key_set: Optional[torch._C.DispatchKeySet] = None, + force_enable_grad=False, +) -> torch.fx.GraphModule: + if include_key_set is None: + include_key_set = torch._C._dispatch_tls_local_include_set() + if exclude_key_set is None: + exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + @torch._dynamo.disable(recursive=True, reason=None) + def _materialize_as_graph_inner(): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + unfunc_t = [_from_fun(arg) for arg in args] + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), + ) + if force_enable_grad: + stack.enter_context(torch.enable_grad()) + return _maybe_reenter_make_fx(fn)(*unfunc_t) + + gm = _materialize_as_graph_inner() + assert gm is not None + return gm + + +def materialize_callable_in_args(op: HopInstance, args, kwargs): + schema = op._schema + hop = op._op + flat_args, flat_spec = pytree.tree_flatten((args, kwargs)) + + def wrapped_fn(*flat_args): + return call_op(op, args, kwargs) + + # We need to trace the higher order op in order to materilaize the callable inputs that + # are a callable (e.g. 
after functionalization key) + gm = reenter_make_fx(wrapped_fn)(*flat_args) + hop_node = gm.graph.find_nodes(op="call_function", target=hop)[0] + arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs)) + assert isinstance(schema, torch._C.FunctionSchema) and len(arg_proxies) == len( + schema.arguments + ) + + # call_op preserves ordering of proxies via schema + materialized_args = [] + for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)): + if ( + isinstance(proxy, torch.fx.Node) + and proxy.op == "get_attr" + and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule) # type: ignore[arg-type] + ): + assert callable(flat_args[i]), (schema, args, kwargs) + materialized_args.append(getattr(gm, proxy.target)) # type: ignore[arg-type] + else: + materialized_args.append(flat_args[i]) + + return pytree.tree_unflatten(materialized_args, flat_spec) + + +def has_user_subclass(args, allowed_subclasses): + """Check if any tensor arguments are user subclasses. + + This is used to determine if tensor subclasses should get a chance to run + their own implementation first before falling back to the default implementation. + + Args: + args: Arguments to check (will be flattened with pytree) + allowed_subclasses: Tuple of allowed subclass types + + Returns: + True if user tensor subclasses are found, False otherwise + """ + flat_args, _ = pytree.tree_flatten(args) + + val = any( + isinstance(a, torch.Tensor) + and type(a) is not torch.Tensor + and not isinstance(a, allowed_subclasses) + for a in flat_args + ) + return val + + +def _has_gen_schema(op: HigherOrderOperator): + # There is an InvokeQuant argument we cannot gen_schema. + if op is torch.ops.higher_order.invoke_quant_packed: + return False + method = "gen_schema" + return hasattr(type(op), method) and getattr(type(op), method) is not getattr( + HigherOrderOperator, method + ) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/while_loop.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/while_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf2d08ddade729b74ed8ec9b331f1f9021f6b65 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/while_loop.py @@ -0,0 +1,420 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Callable, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, + _set_compilation_env, + autograd_not_implemented, + check_meta_consistency, + reenter_make_fx, + validate_subgraph_args_types, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) + + +class WhileLoopOp(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("while_loop") + + def __call__( + self, + cond_fn: Callable, + body_fn: Callable, + carried_inputs: tuple[Union[torch.Tensor, int, float, bool]], + additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...], + /, + ): + if not isinstance(carried_inputs, (tuple, list)): + raise RuntimeError( + f"carried_inputs must be a tuple or list, got {type(carried_inputs)}" + ) + if not isinstance(additional_inputs, (tuple, list)): + raise RuntimeError( + f"additional_inputs must be a tuple or list, got {type(additional_inputs)}" + ) + + validate_subgraph_args_types(carried_inputs) + 
        validate_subgraph_args_types(additional_inputs)
+        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
+
+
+while_loop_op = WhileLoopOp()
+
+
+def while_loop(cond_fn, body_fn, carried_inputs):
+    r"""
+    Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or
+    the initial carried_inputs.
+
+    .. warning::
+        `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
+        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
+        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
+
+    `while_loop` is a structured control flow operator. It preserves the loop semantics across torch.compile and torch.export.
+
+    `while_loop` is equivalent to the following:
+
+        def while_loop(cond_fn, body_fn, carried_inputs):
+            val = carried_inputs
+            while cond_fn(*val):
+                val = body_fn(*val)
+            return val
+
+    Args:
+        cond_fn (Callable): A callable function that returns a boolean Scalar tensor or a python boolean.
+
+        body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors or ints
+
+        carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints): A tuple of inputs to cond_fn and body_fn.
+            It's also the initial value of the states that are carried across iterations. Note that when passing an integer as a carry,
+            the corresponding return of while_loop will be another int with an unknown value, because we don't know how many
+            iterations while_loop will run.
+
+    Example 1:
+
+        def cond_fn(iter, x):
+            return iter.sum() < 10
+
+        def body_fn(iter, x):
+            return iter + 1, x.sin()
+
+        while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
+
+    Example 2:
+
+        def cond_fn(int_iter, x):
+            return 2 * int_iter < x.shape[0]
+
+        def body_fn(int_iter, x):
+            return int_iter + 1, x + int_iter
+
+        while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))
+
+    Restrictions:
+
+        - body_fn must return tensors or ints with the same metadata (e.g. shape, dtype) as the inputs.
+
+        - body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
+
+        - body_fn and cond_fn must not mutate python variables (e.g. list/dict) created outside of the body_fn.
+
+        - body_fn and cond_fn's outputs cannot alias any of the inputs. A clone is required.
+
+    .. warning::
+        Temporal Limitations:
+
+        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
+
+    """
+    from torch._dynamo.backends.debugging import (
+        make_eager_backend_with_torch_function_mode,
+    )
+
+    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
+    # Parameters, buffers, and tensor closures accessed in cond_fn or body_fn will become additional_inputs.
+    additional_inputs: tuple = ()
+
+    # The reason we flatten the inputs before calling into dynamo is that
+    # we want to create a consistent input ordering for cond_fn and body_fn,
+    # and we also want the input ordering to match the output ordering.
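+    # A small illustrative example (assuming dict-style carried_inputs, as permitted above):
+    # for carried_inputs = {"i": it, "xs": [x, y]} and additional_inputs = (),
+    # pytree.tree_flatten((carried_inputs, additional_inputs)) yields
+    # flat_inputs = [it, x, y] together with in_spec, and the flat_cond_fn / flat_body_fn
+    # defined below rebuild the original structure via pytree.tree_unflatten(flat_args, in_spec)
+    # before calling the user-provided cond_fn and body_fn.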
+    # Also see NOTE: [why we cannot use "automatic" for while_loop]
+    # Construct flat_cond_fn and flat_body_fn, which take flattened inputs
+    flat_inputs, in_spec = pytree.tree_flatten((carried_inputs, additional_inputs))
+
+    def flat_cond_fn(*flat_args):
+        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
+        return cond_fn(*carried, *additional)
+
+    def flat_body_fn(*flat_args):
+        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
+        return body_fn(*carried, *additional)
+
+    if torch.compiler.is_dynamo_compiling():
+        return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
+
+    def _validate_input(cond_fn, body_fn, carried_inputs):
+        from torch._higher_order_ops.utils import validate_subgraph_args_types
+
+        if not callable(cond_fn) or not callable(body_fn):
+            raise RuntimeError("Expect cond_fn and body_fn to be callable.")
+
+        validate_subgraph_args_types(flat_inputs)
+
+        if not pytree.tree_all(
+            lambda t: isinstance(t, (torch.Tensor, torch.SymInt, int)), carried_inputs
+        ):
+            raise RuntimeError(
+                "Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only "
+                f"consists of tensor or int leaves, but got {carried_inputs}."
+            )
+
+    _validate_input(cond_fn, body_fn, carried_inputs)
+
+    # Dynamo is expecting a callable with a "__code__" attribute.
+    # We cannot directly pass cond_op to it, so we wrap it in a dummy function.
+    def _while_loop_op_wrapper(*args, **kwargs):
+        return while_loop_op(*args, **kwargs)
+
+    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
+        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
+            if metadata_mode:
+                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
+            else:
+                backend = "eager"
+            return torch.compile(
+                _while_loop_op_wrapper, backend=backend, fullgraph=True
+            )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
+
+
+@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
+def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
+    carried_vals = carried_inputs
+
+    def _validate_cond_output(pred):
+        if (
+            isinstance(pred, torch.Tensor)
+            and pred.size() == torch.Size([])
+            and pred.dtype == torch.bool
+        ) or isinstance(pred, bool):
+            return
+        else:
+            raise RuntimeError(
+                f"cond_fn must return a boolean scalar tensor or a boolean but got {pred}"
+            )
+
+    if not isinstance(carried_inputs, (tuple, list)):
+        raise RuntimeError(
+            f"carried_inputs must be a tuple or list but got {type(carried_inputs)}"
+        )
+
+    while pred := cond_fn(*carried_vals, *additional_inputs):
+        _validate_cond_output(pred)
+        out = body_fn(*carried_vals, *additional_inputs)
+        assert isinstance(
+            out, tuple
+        ), f"body_fn should return a tuple but got {type(out)}"
+        assert len(out) == len(
+            carried_inputs
+        ), "body_fn should return the same number of elements as carried_inputs"
+        carried_vals = out
+    return carried_vals
+
+
+while_loop_op.py_autograd_impl(
+    autograd_not_implemented(while_loop_op, deferred_error=True)
+)
+
+
+def _find_or_create_fake_mode() -> FakeTensorMode:
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    fake_mode = torch._guards.detect_fake_mode()
+    if fake_mode is None:
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+
+    return fake_mode
+
+
+def _create_unbacked_symint(
+    fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool
+) -> torch.SymInt:
+    assert (
+        fake_mode is not None and fake_mode.shape_env is not None
+    ), "Must provide a fake_mode with shape_env."
+    ctx = (
+        contextlib.nullcontext()
+        if not ignore_fresh_unbacked_symbols
+        else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+    )
+    with ctx:
+        return fake_mode.shape_env.create_unbacked_symint()
+
+
+@while_loop_op.py_impl(ProxyTorchDispatchMode)
+def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
+    def _trace_while_loop(
+        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
+    ):
+        # NOTE [unspecialize int carry with unbacked symints]
+        # When we support int carry, we'll also need to support int output of body_fn, because
+        # the previous iteration's output is the next iteration's input and they must match.
+        # For carries, when we start tracing while_loop, they can be
+        # - constants e.g. (0, [1, 3])
+        # - backed symints (x.shape[0], [x.shape[1] + x.stride[1], x.shape[2]])
+        # - unbacked symints e.g. (u0, [u0 + u1, u2])
+        # We choose the most conservative design: in all cases, we create new unbacked symints to trace the
+        # subgraph. It's possible to do some analysis on the initial carry and the output of the first
+        # iteration to determine a better range for the output unbacked symbol, e.g. when the input is an unbacked
+        # symint >= 0 before the while_loop, but in general this is difficult because we don't know
+        # the number of iterations. Users would have to re-constrain the unbacked symint in the subgraph if needed.
+        #
+        # For the output of the fake cond_fn, it could be a constant bool or a SymBool (e.g. return x.shape[0] < 4,
+        # where x.shape[0] can be either static or dynamic). In the case of a constant bool, we should do a
+        # specialization (NYI).
+
+        # For the output of the fake body_fn, it could be all three types, though from the user's point of view,
+        # they're all integers e.g.
+
+        # init_carry = (0, s0, u1, t)
+        # def body_fn(u0, s0, u1, t):
+        #     ...
+        #     return (t.shape[0], t.shape[1], t.shape[2], y + 1)
+        #
+        # It may seem that a constant output isn't possible: users shouldn't write a while_loop
+        # that always returns 0. But it could be that a shape is not set as dynamic properly (e.g.
+        # automatic dynamic hasn't been triggered).
+        #
+        # For this reason, we treat int and symint outputs in the same way:
+        # - they can match against any int or symint carry
+        # - we unspecialize them with new unbacked symints in the fake while_loop
+        # Similarly, we could do some analysis to refine the output ranges but it's easier to start with
+        # fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by
+        # users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each
+        # iteration. Ideally, we should know that the final output is >= 0, but we don't constrain the
+        # unbacked symint output of the subgraph as of today because this requires a smart range analysis.
+ fake_mode: FakeTensorMode = _find_or_create_fake_mode() + unspecialized_carried_inputs = pytree.tree_map_only( + (int, torch.SymInt), + # For temporarily created unbacked symints, we don't need to bind them to any proxy + lambda _: _create_unbacked_symint( + fake_mode, ignore_fresh_unbacked_symbols=True + ), + carried_inputs, + ) + + cond_graph = reenter_make_fx(cond_fn)( + *unspecialized_carried_inputs, *additional_inputs + ) + body_graph = reenter_make_fx(body_fn)( + *unspecialized_carried_inputs, *additional_inputs + ) + + next_name = None + i = 0 + while not next_name: + candidate = f"while_loop_cond_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + cond_graph_name = next_name + body_graph_name = f"while_loop_body_graph_{i}" + assert not hasattr(proxy_mode.tracer.root, body_graph_name) + + proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph) + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + args = (cond_graph, body_graph, carried_inputs, additional_inputs) + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", while_loop_op, proxy_args, {}, name="while_loop" + ) + + out = while_loop_op( + cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs + ) + return track_tensor_tree( + out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) + + return _trace_while_loop( + mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs + ) + + +@while_loop_op.py_impl(FakeTensorMode) +def while_loop_fake_tensor_mode( + mode, cond_fn, body_fn, carried_inputs, additional_inputs +): + with mode: + # NOTE: [Handling unback symints in subgraph of while_loop] + # The idea is that the scope of unbacked symints are limited to the subgraph. + # + # We're implementing the fake tensor mode of while_loop operator. + # and we run body_fn once to get an fake output. + # Let's first consider the case that unbacked symints are tensor shapes: + # + # Case 1: + # if the unbacked symints is local to the subgraph e.g. + # def body_fn(it, x): + # nz = x.nonzero() + # return it+1. nz.sum() + # we can just ignore the newly created unbacked symints because it has + # no effect on the output of while_loop and it's tracked when we tracing. + # the subgraph. + # + # Case 2: + # if the unbacked symints are shape of output of while_loop e.g. + # def body_fn(it, x): + # nz = x.nonzero() + # return it+1, nz + # This will fail the shape check because in each iteration, the carried_input's shape + # must match the output shape as nz.shape contains newly allocated unbacked symint, this + # won't match the carried_input's shape. + # + # Case 3: + # if the unbacked symints are shape of carried_inputs e.g. + # nz = a.nonzero() + # body_fn(it, nz): + # return it+1. nz.sin() + 1, + # There's no new unbacked symints allocated in subgraph, so we're safe. + with mode.shape_env.ignore_fresh_unbacked_symbols(): + # body_fn return output with the same pytree and tensor meta data as carried_inputs + # so we could just return the output after one iteration. 
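+            # For example (illustrative, assuming a simple carry): if carried_inputs
+            # is (it, x) with it an int/SymInt and x a f32[3, 4] tensor, then body_outs
+            # must also be a pair whose second element is a f32[3, 4] tensor;
+            # check_meta_consistency below verifies this (contiguity excluded via
+            # include_contiguity=False).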
+ body_outs = body_fn(*carried_inputs, *additional_inputs) + check_meta_consistency( + carried_inputs, + body_outs, + "carried_inputs", + "body_output", + include_contiguity=False, + ) + # See NOTE [unspecialize int carry with unbacked symints] + return pytree.tree_map_only( + (int, torch.SymInt), + # For while_loop's unbacked symint output, we want them to be bound + # to the proxy of while_loop's output. + lambda _: _create_unbacked_symint( + mode, ignore_fresh_unbacked_symbols=False + ), + body_outs, + ) + + +@while_loop_op.py_functionalize_impl +def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs): + from torch._higher_order_ops.utils import _check_alias_and_mutation + + unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs) + unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) + unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs + with ctx.redispatch_to_next(): + functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn)) + functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn)) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + for fn, fn_name in [ + (cond_fn, "cond_fn"), + (body_fn, "body_fn"), + ]: + _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch) + ret = while_loop_op( + functional_cond_fn, + functional_body_fn, + unwrapped_carried_inputs, + unwrapped_additional_inputs, + ) + return ctx.wrap_tensors(ret) diff --git a/phivenv/Lib/site-packages/torch/_higher_order_ops/wrap.py b/phivenv/Lib/site-packages/torch/_higher_order_ops/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..90f96d53ee3d57ac135439cb3d161a8cd9b6d3b8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_higher_order_ops/wrap.py @@ -0,0 +1,286 @@ +# mypy: allow-untyped-defs +import inspect +import itertools +import logging +from typing import Optional + +from torch._logging import warning_once +from torch._ops import HigherOrderOperator +from torch.types import _dtype + + +log = logging.getLogger(__name__) + +uid = itertools.count(1) + + +# Used for testing the HigherOrderOperator mechanism +class Wrap(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("wrap") + + def __call__(self, func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + result = func(*args, **kwargs) + return result + + return wrapper() + + +wrap = Wrap() + + +class WrapWithSetGradEnabled(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("wrap_with_set_grad_enabled") + + def __call__(self, enable_grad, wrapped_func, *args, **kwargs): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. 
+ import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + with torch.set_grad_enabled(enable_grad): + return wrapped_func(*args, **kwargs) + + return wrapper() + + +wrap_with_set_grad_enabled = WrapWithSetGradEnabled() + + +class WrapWithAutocast(HigherOrderOperator): + def __init__(self): + super().__init__("wrap_with_autocast") + + def __call__( + self, + device_type: str, + dtype: Optional[_dtype], + enabled: bool, + cache_enabled: Optional[bool], + wrapped_func, + *args, + **kwargs, + ): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + @disable + def wrapper(): + with torch.autocast(device_type, dtype, enabled, cache_enabled): + return wrapped_func(*args, **kwargs) + + return wrapper() + + +wrap_with_autocast = WrapWithAutocast() + + +# This HOP allows you to bypass dynamo tracing of the wrapper function while +# still tracing the inner function. +# Takes two callables: The first, `wrapper_fn`, accepts `inner_fn` and returns a +# callable with the same signature. The second is the `inner_fn` itself. Any +# extra *args and **kwargs are forwarded to `wrapper_fn(inner_fn)` when it is +# executed. +class DynamoBypassingWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("dynamo_bypassing_wrapper") + + def __call__( + self, + wrapper_fn_or_key, + inner_fn, + *args, + **kwargs, + ): + # Dynamo already traces the body of HigherOrderOp beforehand when it + # so no need to trace into it. + import torch._dynamo # noqa: F401 + from torch._dynamo import disable + + is_compiling = isinstance(wrapper_fn_or_key, str) + if is_compiling: + assert isinstance(inner_fn, torch.fx.GraphModule) + wrapper_fn = inner_fn.meta[wrapper_fn_or_key] + else: + wrapper_fn = wrapper_fn_or_key + + @disable + def wrapper(): + return wrapper_fn(inner_fn)(*args, **kwargs) + + return wrapper() + + +dynamo_bypassing_wrapper = DynamoBypassingWrapper() + + +class WrapActivationCheckpoint(HigherOrderOperator): + """ + This operator is used to wrap torch.utils.checkpoint. This avoids + TorchDynamo to look into saved tensor hooks and directly passes the control + to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of + AOT tracing torch.utils.checkpoint code, we have a backward graph with + recomputed forward nodes. + + However, we might deprecate this operator soon. The difficulty arises in the + functionalization of rng ops. Today, there are two different + functionalization of rng ops - one at AOT autograd and other at Inductor. + And they are difficult to map to each other. The rng states also complicate + pattern matching in Inductor. Due to the ease of implementation, we are + currently inclined towards functionalization at Inductor level, which means + that duplication/recomputation is done as a compiler pass in the + partitioners. See TagActivationCheckpoint for more information. + """ + + def __init__(self) -> None: + super().__init__("wrap_activation_checkpoint", cacheable=False) + + def __call__(self, function, *args, **kwargs): + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. 
+ import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + + kwargs["use_reentrant"] = False + kwargs["preserve_rng_state"] = False + # Using interpreter allows preservation of metadata through torch.compile stack. + with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + + return checkpoint(Interpreter(function).run, *args, **kwargs) + + +wrap_activation_checkpoint = WrapActivationCheckpoint() + + +class TagActivationCheckpoint(HigherOrderOperator): + """ + This operator is supposed to be used only with torch.compile stack. This + accepts a Fx graph module which needs to be checkpointed. This operator adds + "recomputable" tag to the nodes of the Fx graph that should be recomputed. + + The goal is to: + 1. Avoid using Dynamo to trace through saved tensor hooks. + 2. For selective checkpointing case, let AOTAutograd trace through + saved tensor hooks but has special logic with TorchDispatchMode to override + the usual saved_tensor_hooks fn logic in order to tag the nodes. + 3. Rely on the partitioners to actually duplicate the nodes. + This sits well in the torch.compile stack, because by the time graph + reaches partitioner, inductor has already run its functionalization of rng + ops (by setting fixed seed for each random op, see `replace_random_passes`). + Therefore, the duplication of nodes, by design, respects the rng states in + the forward and recomputed forward in backward. + """ + + def __init__(self) -> None: + super().__init__("tag_activation_checkpoint", cacheable=False) + + @staticmethod + def divide_kwargs(kwargs): + """ + checkpoint fn can have mixed kwargs between checkpointed fn and + checkpoint fn itself. For example + >> def gn(x, y, z=None): + >> a = torch.matmul(x, y) + >> if z is not None: + >> return torch.matmul(a, z) + >> return a + >> def fn(x, y, z): + >> return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z)) + In the above case, z belongs to checkpointed function gn, but + use_reentrant belongs to the checkpoint function. This function splits + the kwargs into checkpoint_kwargs and gmod_kwargs (or + checkpointed_fn_kwargs). + We do sorting to ensure same graph from run to run for better + debuggability. It is not required for correctness. + """ + from torch.utils.checkpoint import checkpoint + + ckpt_signature = inspect.signature(checkpoint) + checkpoint_keys = set() + for name in ckpt_signature.parameters: + if name in ("function", "args", "kwargs"): + continue + checkpoint_keys.add(name) + + # `preserve_rng_state` is not a regular kwarg + checkpoint_keys.add("preserve_rng_state") + + checkpoint_kwargs = { + name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys + } + gmod_kwargs = { + name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys + } + return checkpoint_kwargs, gmod_kwargs + + def tag_nodes(self, gmod, is_sac): + from torch.utils.checkpoint import CheckpointPolicy + + unique_graph_id = next(uid) + for node in gmod.graph.nodes: + if node.op in ("call_function", "call_method", "call_module"): + node.meta["ac_graph_id"] = unique_graph_id + if is_sac: + # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode. + node.meta["recompute"] = None + else: + # Under vanilla activation checkpointing, all nodes should be recomputed. 
+ node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE + return gmod + + def __call__(self, gmod, *args, **kwargs): + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + + if "_checkpoint_context_fn" in gmod.meta: + warning_once( + log, + """ +Detected that context_fn is passed to torch.utils.checkpoint under torch.compile. +Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_). +""", + ) + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + kwargs["use_reentrant"] = False + # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through + # `torch.random.fork_rng` op (which is not supported yet under CUDA). + # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state + # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor + # instead of in AOTAutograd). + kwargs["preserve_rng_state"] = False + kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"] + # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag + # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py. + gmod = self.tag_nodes(gmod, is_sac=True) + # Using interpreter allows preservation of metadata through torch.compile stack. + with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + + return checkpoint(Interpreter(gmod).run, *args, **kwargs) + else: + gmod = self.tag_nodes(gmod, is_sac=False) + # Using interpreter allows preservation of metadata through torch.compile stack. + # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here + # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile. 
+ # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test) + with fx_traceback.preserve_node_meta(): + return Interpreter(gmod).run(*args) + + +tag_activation_checkpoint = TagActivationCheckpoint() diff --git a/phivenv/Lib/site-packages/torch/_inductor/__autotune_main__.py b/phivenv/Lib/site-packages/torch/_inductor/__autotune_main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f80be9bfe48db04edb31a093691c615391af2ed3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/__autotune_main__.py @@ -0,0 +1,33 @@ +import argparse +import logging +import os + +from torch._inductor.autotune_process import TuningProcess +from torch._inductor.compile_worker.utils import _async_compile_initializer + + +log = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + args = parser.parse_args() + read_pipe = os.fdopen(args.read_fd, "rb") + write_pipe = os.fdopen(args.write_fd, "wb") + + try: + # Ensures the subprocess exits if the parent crashes: + _async_compile_initializer(args.parent) + TuningProcess.process_main(read_pipe, write_pipe) + except Exception: + log.exception("Uncaught exception in autotune subprocess") + finally: + read_pipe.close() + write_pipe.close() + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/_inductor/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a051f2115e98d9d55a13e3db2abcd93b8fdee34f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,415 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import io +import logging +import os +from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union + +import torch._inductor.config +import torch.fx + +from .standalone_compile import CompiledArtifact # noqa: TC001 + + +if TYPE_CHECKING: + from torch._inductor.utils import InputType + from torch.export import ExportedProgram + from torch.export.pt2_archive._package_weights import Weights + from torch.types import FileLike + +__all__ = [ + "compile", + "list_mode_options", + "list_options", + "cudagraph_mark_step_begin", + "standalone_compile", +] + + +log = logging.getLogger(__name__) + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: list[InputType], + options: Optional[dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. + """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aoti_compile_and_package( + exported_program: ExportedProgram, + _deprecated_unused_args=None, + _deprecated_unused_kwargs=None, + *, + package_path: Optional[FileLike] = None, + inductor_configs: Optional[dict[str, Any]] = None, +) -> str: + """ + Compiles the exported program with AOTInductor, and packages it into a .pt2 + artifact specified by the input package_path. To load the package, you can + call ``torch._inductor.aoti_load_package(package_path)``. + + An example usage is as follows: + + .. 
code-block:: python + + ep = torch.export.export(M(), ...) + aoti_file = torch._inductor.aoti_compile_and_package( + ep, package_path="my_package.pt2" + ) + compiled_model = torch._inductor.aoti_load_package("my_package.pt2") + + To compile and save multiple models into a single ``.pt2`` artifact, you can do + the following: + + .. code-block:: python + + ep1 = torch.export.export(M1(), ...) + aoti_file1 = torch._inductor.aot_compile( + ep1, ..., options={"aot_inductor.package": True} + ) + ep2 = torch.export.export(M2(), ...) + aoti_file2 = torch._inductor.aot_compile( + ep2, ..., options={"aot_inductor.package": True} + ) + + from torch._inductor.package import package_aoti, load_package + + package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2}) + + compiled_model1 = load_package("my_package.pt2", "model1") + compiled_model2 = load_package("my_package.pt2", "model2") + + Args: + exported_program: An exported program created through a call from torch.export + package_path: Optional specified path to the generated .pt2 artifact. + inductor_configs: Optional dictionary of configs to control inductor. + + Returns: + Path to the generated artifact + """ + from torch.export import ExportedProgram + + from .debug import aot_inductor_minifier_wrapper + + if not isinstance(exported_program, ExportedProgram): + raise ValueError("Only ExportedProgram is supported") + + if exported_program.example_inputs is None: + raise RuntimeError( + "exported_program.example_inputs is required to be set in order " + "for AOTInductor compilation." + ) + + if _deprecated_unused_args is not None or _deprecated_unused_kwargs is not None: + log.warning( + "You no longer need to specify args/kwargs to aoti_compile_and_package " + "as we can get this information from exported_program.example_inputs." + ) + + assert ( + package_path is None + or ( + isinstance(package_path, (io.IOBase, IO)) + and package_path.writable() + and package_path.seekable() + ) + or ( + isinstance(package_path, (str, os.PathLike)) + and os.fspath(package_path).endswith(".pt2") + ) + ), ( + f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}" + ) + + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package"] = True + + if inductor_configs.get("aot_inductor.output_path"): + raise RuntimeError( + "Please pass in a package path to aot_inductor_compile() instead " + "of setting the aot_inductor.output_path config." + ) + + # a wrapper around aoti_compile_and_package_inner. + return aot_inductor_minifier_wrapper( + _aoti_compile_and_package_inner, + exported_program, + package_path=package_path, + inductor_configs=inductor_configs, + ) + + +def _aoti_compile_and_package_inner( + gm: torch.nn.Module, + # flat_example_inputs: List[Any], + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + *, + load_and_run: bool = False, + check_accuracy: Optional[str] = None, + package_path: Optional[Union[str, io.BytesIO]] = None, + inductor_configs: Optional[dict[str, Any]] = None, +): + """ + See docstring for aoti_compile_and_package. + + If `load_and_run` is True, this function will load the compiled model and run it. + This is for the minifier to check the correctness of the compiled model. + + If `check_accuracy` is set, this function will check the accuracy of the compiled + model against gm. kwargs must be None if check_accuracy is set. 
+ "strict_accuracy" means "we will minify any time we see anything that + diverges", whereas "accuracy" is more conservative, and will only minify if there + is a meaningful fp64 divergence + """ + + if check_accuracy: + assert kwargs is None or len(kwargs) == 0, ( + "when checking for accuracy, the inputs must have been flattened and kwargs is None" + ) + + from .package import package_aoti + + assert isinstance(gm, torch.fx.GraphModule) + + kwargs = kwargs or {} + + aoti_files = aot_compile(gm, args, kwargs, options=inductor_configs) + assert isinstance(aoti_files, list) + + if package_path is None: + path = [ + os.path.splitext(file)[0] + for file in aoti_files + if isinstance(file, str) and os.path.splitext(file)[1] == ".so" + ] + if len(path) == 0: + path = [ + os.path.splitext(file)[0] + for file in aoti_files + if isinstance(file, str) and os.path.splitext(file)[1] == ".cpp" + ] + package_path = path[0] + ".pt2" + + res = package_aoti(package_path, aoti_files) + assert res == package_path + + if load_and_run or check_accuracy: + compiled_model = aoti_load_package(package_path) + if check_accuracy: + from torch._dynamo.debug_utils import AccuracyError, same_two_models + + # This might look inverted but it's not. strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + not_strict_accuracy = check_accuracy == "accuracy" + if not same_two_models( + gm, + compiled_model, + args, + only_fwd=True, + require_fp64=not_strict_accuracy, + ignore_non_fp=not_strict_accuracy, + ): + raise AccuracyError("Bad accuracy detected") + else: + compiled_model(*args, **kwargs) + + return package_path + + +def aoti_load_package( + path: FileLike, run_single_threaded: bool = False, device_index: int = -1 +) -> Any: # type: ignore[type-arg] + """ + Loads the model from the PT2 package. + + If multiple models were packaged into the PT2, this will load the default + model. To load a specific model, you can directly call the load API + + .. code-block:: python + + from torch._inductor.package import load_package + + compiled_model1 = load_package("my_package.pt2", "model1") + compiled_model2 = load_package("my_package.pt2", "model2") + + Args: + path: Path to the .pt2 package + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + """ + from torch._inductor.package import load_package + + return load_package( + path, run_single_threaded=run_single_threaded, device_index=device_index + ) + + +def aot_compile( + gm: torch.fx.GraphModule, + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + *, + options: Optional[dict[str, Any]] = None, +) -> Union[str, list[Union[str, Weights]]]: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + args: Example arguments + kwargs: Example keyword arguments + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Path to the generated shared library, or a list of files generated by + AOTI if aot_inductor.package=True. 
+ TODO: make it return a list by default + """ + from .compile_fx import _aoti_flatten_inputs, compile_fx_aot + + flat_example_inputs, options = _aoti_flatten_inputs( + gm, args, kwargs, options=options + ) + from torch._export.utils import _compiling_state_context + + with _compiling_state_context(): + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. + + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: dict[str, dict[str, bool]] = { + "default": {}, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + "coordinate_descent_tuning": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + "coordinate_descent_tuning": True, + }, + } + try: + return mode_options[mode] if mode else mode_options + except KeyError as e: + raise RuntimeError( + f"Unrecognized mode={mode}, should be one of: {', '.join(mode_options.keys())}" + ) from e + + +def list_options() -> list[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. + + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: dict[str, Any] = config.get_config_copy() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." + from .cudagraph_trees import mark_step_begin + + mark_step_begin() + + +def standalone_compile( + gm: torch.fx.GraphModule, + example_inputs: list[InputType], + *, + dynamic_shapes: Literal[ + "from_example_inputs", "from_tracing_context", "from_graph" + ] = "from_graph", + options: Optional[dict[str, Any]] = None, +) -> CompiledArtifact: + """ + Precompilation API for inductor. + + .. code-block:: python + + compiled_artifact = torch._inductor.standalone_compile(gm, args) + compiled_artifact.save(path=path, format="binary") + + # Later on a new process + loaded = torch._inductor.CompiledArtifact.load(path=path, format="binary") + compiled_out = loaded(*args) + + Args: + gm: Graph Module + example_inputs: Inputs for the graph module + dynamic_shapes: If "from_graph" (default), we will use the dynamic + shapes in the passed-in graph module. + If "from_tracing_context", we use the dynamic shape info in the + ambient tracing context. + If "from_example_inputs", we will specialize the graph on the + example_inputs. + options: Inductor compilation options + + Returns: + CompiledArtifact that can be saved to disk or invoked directly. 
+ """ + from .standalone_compile import standalone_compile + + options = options if options else {} + return standalone_compile( + gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/analyze_preserves_zero_mask.py b/phivenv/Lib/site-packages/torch/_inductor/analyze_preserves_zero_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..8b326265ca83c2d8824c24784d049a23a6511364 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/analyze_preserves_zero_mask.py @@ -0,0 +1,165 @@ +import dataclasses +import itertools +from typing import Any, Optional, TYPE_CHECKING + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.dtype_propagation import DtypePropagationOpsHandler +from torch._inductor.index_propagation import SymPyOps, TypedExpr + +from .ops_handler import DefaultHandler +from .virtualized import StoreMode, V + + +if TYPE_CHECKING: + from torch._inductor.scheduler import SchedulerNode + + +def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol: + return sympy.Symbol(f"unknown_{count}") + + +class PreservesZeros(SymPyOps, DefaultHandler): + """ + For prologue kernels where the loads are masked, does the final store of this kernel preserve + the zeros. + """ + + def __init__(self) -> None: + self.count = itertools.count(0) + self.store_preserves_zeros: Optional[bool] = None + self.dtype_prop = DtypePropagationOpsHandler() + + def load(self, name: str, index: sympy.Expr) -> TypedExpr: + # In prologue fusion, all loads get broadcasted + dtype = self.dtype_prop.load(name, index) + return TypedExpr( + sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype + ) + + def store( + self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + ) -> None: + assert isinstance(self, PreservesZeros) + # should only have a single store in prologue + assert self.store_preserves_zeros is None + self.store_preserves_zeros = value.is_constant() and value.expr == 0 + + def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr: + return construct_symbol(next(self.count), torch.int32) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + from torch._inductor.codegen.common import OpDecompositions + + if hasattr(OpDecompositions, name): + return getattr(OpDecompositions, name)(*args, **kwargs).value + + dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + return TypedExpr(construct_symbol(next(self.count), dtype), dtype) + + +def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: + """ + Does this prologue preserve zero masks + """ + preserves_zeros = PreservesZeros() + with V.set_ops_handler(preserves_zeros): + prologue._body(*prologue.get_ranges()) + + store_preserves_zeros = preserves_zeros.store_preserves_zeros + assert isinstance(store_preserves_zeros, bool) + + return store_preserves_zeros + + +@dataclasses.dataclass +class DTypeContainer: + dtype: torch.dtype + is_scalar: bool = False + + +class RecordLowPrecisionOps(DefaultHandler): + def __init__(self, disallow_fp32_ops: bool = False) -> None: + self.disallow_fp32_ops = disallow_fp32_ops + self.low_precision_numeric_op = False + self.dtype_prop = DtypePropagationOpsHandler() + self.non_numeric_ops = ( + "to_dtype", + "constant", + "where", + ) + + def load(self, name: str, index: sympy.Expr) -> DTypeContainer: + return DTypeContainer(self.dtype_prop.load(name, index)) + + @staticmethod + def store( + 
name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + ) -> None: + pass + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + pass + + @staticmethod + def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: + return sympy.S.Zero + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) + if name == "constant": + return DTypeContainer(torch.float, is_scalar=True) + + uses_low_prec = any( + isinstance(dtype_cont, DTypeContainer) + and dtype_cont.dtype is not None + and low_prec_float(dtype_cont.dtype) + for dtype_cont in itertools.chain((out,), args, kwargs.values()) + ) + + if uses_low_prec and name not in self.non_numeric_ops: + self.low_precision_numeric_op = True + + if ( + self.disallow_fp32_ops + and out.dtype in (torch.float32, torch.float64) + and not out.is_scalar + ): + self.low_precision_numeric_op = True + + return out + + +def low_prec_float(dtype: torch.dtype) -> bool: + return dtype.is_floating_point and dtype.itemsize < 4 + + +def can_codegen_without_upcasts( + prologue: "SchedulerNode", + disallow_fp32_ops: bool = False, +) -> bool: + """ + Can this prologue be run without `upcast_to_fp32` while preserving numerics. + + This is only true if the node only contains dtype conversions, indexing, and other non-arithmetic operators. + + If disallow_fp32_ops is True, then we also disallow ops that are explicitly computed in fp32 or fp64. + """ + if prologue.get_operation_names() <= V.graph.low_precision_codegen_ops: + return True + + low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops) + + # Need to turn off upcasting to do analysis of whether we can turn it off + with ( + config.patch("triton.codegen_upcast_to_fp32", False), + V.set_ops_handler(low_prec_analysis), + ): + prologue._body(*prologue.get_ranges()) + + return not low_prec_analysis.low_precision_numeric_op diff --git a/phivenv/Lib/site-packages/torch/_inductor/aoti_eager.py b/phivenv/Lib/site-packages/torch/_inductor/aoti_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..0b60fb1f14159922e05d2c888c88887fd6161b5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/aoti_eager.py @@ -0,0 +1,298 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Callable, Optional +from unittest import mock + +import torch +import torch._export +from torch._inductor.utils import is_cpu_device + +from .runtime.runtime_utils import cache_dir + + +log = logging.getLogger(__name__) + + +def aoti_eager_cache_dir(namespace: str, device: str) -> Path: + return Path(cache_dir()) / "aoti_eager" / namespace / device + + +def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: + # Avoid circular import + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + from torch.utils._filelock import FileLock + + op_conf_lock_file = f"{op_func_name_with_overload}.lock" + lock_dir = get_lock_dir() + return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) + + +def load_aoti_eager_cache( + ns: str, op_func_name_with_overload: str, device_type: str +) -> list[Optional[dict[str, Any]]]: + device_kernel_cache = aoti_eager_cache_dir(ns, device_type) + op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" + if not op_conf.exists(): + return [] + + try: + with 
aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf) as f: + json_data = json.load(f) + for item in json_data: + # Get absolution path for kernel library + kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] + item["kernel_path"] = kernel_lib_abs_path.as_posix() + + # Check if the kernel library exists + if not kernel_lib_abs_path.exists(): + return [] + + for metadata in item["meta_info"]: + if metadata.get("is_dynamic"): + raise NotImplementedError( + "Only support static shape for now" + ) + if ( + "device_type" in metadata + and metadata["device_type"] == "cpu" + ): + metadata["device_index"] = -1 + for dtype_key in ["dtype", "dtype_value"]: + if dtype_key in metadata: + metadata[dtype_key] = getattr( + torch, metadata[dtype_key].split(".")[-1] + ) + if "layout_value" in metadata: + metadata["layout_value"] = getattr( + torch, metadata["layout_value"].split(".")[-1] + ) + if "memory_format_value" in metadata: + metadata["memory_format_value"] = getattr( + torch, metadata["memory_format_value"].split(".")[-1] + ) + + return json_data + except Exception as e: + err_msg = f"Failed to load aoti eager cache: {e}" + log.exception(err_msg) + return [] + + +def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]: + return {int: torch.int32, float: torch.float, bool: torch.bool} + + +def supported_scalar_types() -> tuple[type, ...]: + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + return tuple(type_to_torch_dtype.keys()) + + +def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]: + metadata: dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) + metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + +def extract_tensor_list_metadata( + dynamic: bool, + input: list[torch.Tensor], +) -> dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(dynamic, item)) + + metadata: dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + +def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]: + assert isinstance(input, supported_scalar_types()) + metadata: dict[str, Any] = {} + metadata["is_dynamic"] = False + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + +def extract_string_metadata(input: str) -> dict[str, Any]: + assert isinstance(input, str) + metadata: dict[str, Any] = {} + metadata["string_value"] = input + return metadata + + +def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]: + assert isinstance(input, torch.dtype) + metadata: dict[str, Any] = {} + metadata["dtype_value"] = f"{input}" + return metadata + + +def extract_device_metadata(input: torch.device) -> dict[str, Any]: + assert isinstance(input, torch.device) + metadata: dict[str, Any] = {} + metadata["device_type_value"] = f"{input.type}" 
+ metadata["device_index_value"] = input.index + return metadata + + +def extract_layout_metadata(input: torch.layout) -> dict[str, Any]: + assert isinstance(input, torch.layout) + metadata: dict[str, Any] = {} + metadata["layout_value"] = f"{input}" + return metadata + + +def aoti_compile_with_persistent_cache( + ns: str, + op_func_name_with_overload: str, + device_type: str, + dynamic: bool, + f: Callable[..., Any], + args: tuple[Any], + kwargs: dict[str, Any], + *, + dynamic_shapes: Optional[dict[str, Any]] = None, + options: Optional[dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Compile the given function with persistent cache for AOTI eager mode. + """ + assert not dynamic, "Only support static shape for now" + flattened_inputs = list(args) + list(kwargs.values()) + if not all( + isinstance( + input, + ( + supported_scalar_types(), + torch.Tensor, + list, + str, + torch.dtype, + torch.device, + torch.layout, + ), + ) + for input in flattened_inputs + ): + err_msg = f"Unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + persistent_cache = aoti_eager_cache_dir(ns, device_type) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + + persistent_cache_lib = persistent_cache / "lib" + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() + + with mock.patch.dict( + os.environ, + {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, + ): + try: + kernel_lib_path = torch._export.aot_compile( + f, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + remove_runtime_assertions=remove_runtime_assertions, + disable_constraint_solver=disable_constraint_solver, + # Some operations may have non-Tensor parameters like int, float, bool. These + # non-Tensor parameters will not be the input of the graph. Therefore, we do + # need to keep the same signature. 
+ same_signature=False, + ) + assert isinstance(kernel_lib_path, str) + + kernel_metadata_items = [] + + for idx, input in enumerate(flattened_inputs): + if isinstance(input, torch.Tensor): + metadata = extract_tensor_metadata(dynamic, input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(dynamic, input) + elif isinstance(input, supported_scalar_types()): + metadata = extract_scalar_metadata(device_type, input) + elif isinstance(input, str): + metadata = extract_string_metadata(input) + elif isinstance(input, torch.dtype): + metadata = extract_dtype_metadata(input) + elif isinstance(input, torch.device): + metadata = extract_device_metadata(input) + elif isinstance(input, torch.layout): + metadata = extract_layout_metadata(input) + else: + raise NotImplementedError(f"Unsupported input type: {type(input)}") + + metadata["arg_order"] = idx + kernel_metadata_items.append(metadata) + + kernel_meta_info: dict[str, Any] = {} + kernel_meta_info["meta_info"] = kernel_metadata_items + kernel_meta_info["kernel_path"] = ( + Path(kernel_lib_path).relative_to(persistent_cache).as_posix() + ) + + json_data = [] + update_json = True + op_conf = persistent_cache / f"{op_func_name_with_overload}.json" + mode = "r" if op_conf.exists() else "w" + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf, mode) as op_conf_file: + try: + json_data = json.load(op_conf_file) + except Exception: + json_data = [] + + assert isinstance(json_data, list) + for item in json_data: + assert isinstance(item, dict) + # Same kernel meta info already exists in the json file + if item["meta_info"] == kernel_metadata_items: + update_json = False + break + + if update_json: + json_data.append(kernel_meta_info) + with open(op_conf, "w") as op_conf_file: + json.dump(json_data, op_conf_file, indent=4) + + return kernel_lib_path + except Exception as e: + err_msg = f"Failed to compile {op_func_name_with_overload}: {e}" + log.exception(err_msg) + return "" diff --git a/phivenv/Lib/site-packages/torch/_inductor/async_compile.py b/phivenv/Lib/site-packages/torch/_inductor/async_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..ef30ad037c23bc6609a154c525471722e9c75048 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/async_compile.py @@ -0,0 +1,541 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import atexit +import functools +import json +import logging +import multiprocessing +import os +import sys +from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from functools import partial +from time import time, time_ns +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._dynamo.utils import ( + counters, + dynamo_timed, + get_metrics_context, + set_feature_use, +) +from torch._inductor import config +from torch._inductor.codecache import ( + _load_triton_kernel_from_source, + code_hash, + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + ROCmCodeCache, + StaticAutotunerFuture, + torch_key, +) +from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import 
_async_compile_initializer +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch._inductor.utils import clear_on_fresh_cache +from torch._inductor.virtualized import V +from torch.hub import _Faketqdm, tqdm +from torch.utils._ordered_set import OrderedSet +from torch.utils._triton import has_triton_package + + +if TYPE_CHECKING: + from torch._inductor.runtime.hints import HalideMeta + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + +log = logging.getLogger(__name__) + +_triton_kernel_metrics: Optional[dict[str, dict[str, Any]]] = None + + +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ImportError: + # Triton might not be installed or might be an old version. + pass + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0, _triton_kernel_metrics + if _t0 is None: + _t0 = time() + if _triton_kernel_metrics is None: + _triton_kernel_metrics = {} + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0, _triton_kernel_metrics + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + if _triton_kernel_metrics: + # Log triton kernel info + sorted_info = dict(sorted(_triton_kernel_metrics.items())) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "triton_kernel_info", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(sorted_info), + ) + _triton_kernel_metrics = None + + +def _add_triton_kernel_info(kernel_name: str, info: dict[str, Any]): + global _triton_kernel_metrics + # Must be called between _compile_start and _compile_end + if _triton_kernel_metrics is not None: + _triton_kernel_metrics[kernel_name] = info + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + +# Used to keep track of all process pools invoked so far. +_pool_set = OrderedSet[AnyPool]() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +def get_compile_threads() -> int: + """ + Temporary for internal rollout. Assign config.compile_threads lazily and return it. + TODO: remove after rollout. + """ + if config.compile_threads is None: + config.compile_threads = config.decide_compile_threads() + return config.compile_threads + + +@clear_on_fresh_cache +class CompiledTritonKernels: + """ + In memory cache for storing compiled triton kernels. 
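+ Entries are process-local; cache_clear() empties the cache (see the @clear_on_fresh_cache decorator above).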
+ + Each triton kernel is keyed by the hash of its source code. Each value stored + in the cache is a return value of AsyncCompile.triton(). + + Currently, the cache stores Future objects, but it should be generalizable for any kernels. + """ + + _cache: dict[str, CodeCacheFuture] = {} + + @staticmethod + def key(kernel_src: str): + """ + Generates a cache key given a triton kernel's full source code. + This source includes the inductor meta, compilation metadata, the kernel itself, etc. + `kernel_src` should be the exact string passed to async_compile.triton()'s first argument. + """ + # Hashes the kernel source with torch_key into a single hash key + return code_hash(kernel_src, extra=torch_key()) + + @staticmethod + def save(kernel_src: str, future: CodeCacheFuture): + """ + Saves a compiled triton kernel to the cache. + TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton, + but the real type we want to return here is actually an abstract triton kernel. + + TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc. + so it could be less strict. + """ + key = CompiledTritonKernels.key(kernel_src) + CompiledTritonKernels._cache[key] = future + + @staticmethod + def get(kernel_src: str) -> Optional[CodeCacheFuture]: + key = CompiledTritonKernels.key(kernel_src) + return CompiledTritonKernels._cache.get(key, None) + + @staticmethod + def cache_clear(): + CompiledTritonKernels._cache = {} + + @staticmethod + def remove_future(kernel_src: str) -> None: + key = CompiledTritonKernels.key(kernel_src) + + # Delete the LambdaFuture if there is one + if key in CompiledTritonKernels._cache: + del CompiledTritonKernels._cache[key] + + +class AsyncCompile: + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert get_compile_threads() > 1 + return ThreadPoolExecutor(get_compile_threads()) + + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert get_compile_threads() > 1 + log.info( + "Creating '%s' pool with %d workers", + config.worker_start_method, + get_compile_threads(), + ) + + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool(get_compile_threads()) + else: + if config.worker_start_method == "spawn": + # Avoid creating pools in the spawned subprocs themselves: + os.environ["TORCH_WARM_POOL"] = "0" + pre_fork_setup() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = TrackedProcessPoolExecutor( + get_compile_threads(), + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + # Set an attribute we can check to see if the pool is ready. 
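+ # ready_future resolves once a worker has executed the no-op _get_ready task;
+ # until then use_process_pool() reports False and kernels compile in-process.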
+ pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if get_compile_threads() <= 1: + return + _compile_start() + # Pool is initialized on first access + cls.process_pool() + _compile_end() + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if get_compile_threads() <= 1: + return task() + return cls.pool().submit(task) + + def use_process_pool(self): + return ( + get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + """ + Async_compile.triton is more complicated than the other backends because + we're trying to optimize compile time as much as possible for this hot callsite. + + First of all, the function is cached by CompiledTritonKernels; if there's a kernel + already compiled, we grab it directly from the cache and return. + + Otherwise, if we have multiple compile threads, we kick off triton compilations on each + worker process by giving it a kernel and source code to compile. The worker initializes + a CachingAutotuner, runs triton compilation, and pickles the kernel back to us. + We use TritonCompileResult to represent the objects being pickled back to us by each + worker. + + Some maybe not obvious things that are pickled back to us: + - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata + and do not have to pay the cost of loading the triton kernel on the parent. But certain + cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function + in the parent lazily when we require it. + - The AutotuneCache, if enabled, is constructed on each worker per triton config + and pickled by to us via `CachingAutotuner.save_cache_hook`. + """ + load_kernel = functools.partial( + _load_triton_kernel_from_source, kernel_name, source_code + ) + + def reload_kernel_in_parent(): + # Benchmark how often this happens + with dynamo_timed("reload_kernel_in_parent"): + return load_kernel() + + counters["inductor"]["async_compile_cache_miss"] += 1 + + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + + if os.environ.get("TRITON_INTERPRET", "0") == "1": + return getattr( + torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name + ) + + is_parallel = self.use_process_pool() + set_feature_use("parallel_compile_post_warmup", is_parallel) + + compile_id = torch._guards.CompileContext.current_compile_id() + is_backward = getattr(V.graph, "is_backward", False) + + if (future := CompiledTritonKernels.get(source_code)) is not None: + counters["inductor"]["async_compile_cache_hit"] += 1 + # Set reload_kernel_from_src properly based on source_code + if isinstance(future, StaticAutotunerFuture): + # Remove the future now that we've cache hit + CompiledTritonKernels.remove_future(source_code) + future.reload_kernel_from_src = reload_kernel_in_parent + if is_parallel: + return future + else: + return future.result() + + # Cache miss + if is_parallel: + # We want to support changing these env vars after (and while) the + # process pool is running, so pass them to the subprocess to reset. 
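+ # Only variables that are actually set in the parent process are forwarded;
+ # unset ones keep whatever defaults the workers already have.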
+ env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + extra_config = { + "use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher + } + + task = self.process_pool().submit( + _worker_compile_triton, + load_kernel, + extra_env, + extra_config, + ) + + def get_result() -> CachingAutotuner: + kernel, elapsed_us = task.result() + # Now that we've compiled, we should clear the future + # so it can't be used again + kernel.set_compile_info(compile_id, is_backward) + CompiledTritonKernels.remove_future(source_code) + + kernel.precompile( + warm_cache_only=False, + reload_kernel=reload_kernel_in_parent, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + return kernel + + future = LambdaFuture(get_result, future=task) + CompiledTritonKernels.save(source_code, future) + return future + else: + with dynamo_timed( + "async_compile.precompile", + log_pt2_compile_event=True, + dynamo_compile_column_us="triton_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="compile_triton", + ): + start_ns = time_ns() + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.set_compile_info(compile_id, is_backward) + kernel.precompile( + warm_cache_only=False, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + elapsed_us = (time_ns() - start_ns) // 1000 + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + return kernel + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if get_compile_threads() <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: list[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if get_compile_threads() <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext, aot_compile=False): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + if aot_compile: + # We rely on JITInductor to compile the CUDA code, + # so that we can load it into AOTInductor. 
+ output_path, *_ = CUDACodeCache.compile(source_code, "o") + CUDACodeCache.aot_kernels_o.append(output_path) + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def rocm( + self, + source_code, + dst_file_ext, + aot_compile=False, + ): + kernel_code_log.info("ROCm Kernel:\n%s", source_code) + + def task(): + if aot_compile: + output_path, *_ = ROCmCodeCache.compile(source_code, dst_file_ext="o") + ROCmCodeCache.aot_kernels_o.append(output_path) + if config.rocm.generate_test_runner: + _ = ROCmCodeCache.compile(source_code, dst_file_ext="exe") + return ROCmCodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if get_compile_threads() <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def wait(self, scope: dict[str, Any]) -> None: + if get_compile_threads() > 1: + with dynamo_timed( + "async_compile.wait", + log_pt2_compile_event=True, + dynamo_compile_column_us="triton_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="compile_triton", + ): + self._wait_futures(scope) + + _compile_end() + + def _wait_futures(self, scope: dict[str, Any]) -> None: + kernels = { + key: value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + } + pbar = tqdm( + total=len(kernels), + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + for key, result in kernels.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + try: + kernel = result.result() + scope[key] = kernel + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." + ) from e + pbar.update(1) + + +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" + # The subprocess pool is only used for the Triton backend + or not has_triton_package() + # Skip for fbcode. We have internal reports of usages inside multiprocessing + # pools that lead a multiplicative number of compile subprocesses. + or config.is_fbcode() +): + pass +else: + AsyncCompile.warm_pool() + +# On exit give the workers a chance to clean themselves up. 
Without this the +# resource_tracker can complain about leaked semaphores coming from the +# ProcessPoolExecutor: +# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects +# to clean up at shutdown +atexit.register(shutdown_compile_workers) diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a5b64260dc471471cecdb1c4ab8841168b8bd49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c7551ed5c81d663a4da38fba9d1e76200dfb3ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62a78728b89d43a15b62e6358d203ac553fec690 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dce9ded1eb0d9d49d16bc920d04dbf2c509ffc6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b904a434a751c753b2fd5af3331bf6429656e3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5eb3f981b0300cb6a321252944cc7d1755dad2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 
51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + 
return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..18a0cc4aaefd5bcbce4c36a40461b36628e7073b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + 
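The get_best_choices tree that follows branches on features of the matmul problem: m, k, n, their pairwise products, operand strides, and arith_intensity, which (per get_arith_intensity in autoheuristic_utils.py, added later in this patch) is the multiply-accumulate count divided by the total number of elements in the three matrices. A worked example of that feature, so the numeric thresholds in the tree below are easier to read:

# arith_intensity as defined by get_arith_intensity(): m*k*n / (m*k + k*n + m*n)
def arith_intensity(m: int, k: int, n: int) -> float:
    if m == 0 or k == 0 or n == 0:
        return 0.0
    return m * k * n / (m * k + k * n + m * n)

print(arith_intensity(64, 64, 64))        # ~21.3   -> low-intensity branch (<= 29.90)
print(arith_intensity(4096, 4096, 4096))  # ~1365.3 -> high-intensity branch (> 56.99)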
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), 
(0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), 
(0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..d24410f14552936c26d1a1301b439d17272b97c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py @@ -0,0 +1,150 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + 
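Each generated artifact is pinned to the GPU model it was trained on: check_precondition in MixedMMA100 above only accepts metadata whose shared-memory size and compute capability match the A100 values 166912 and (8, 0), and the MixedMMH100 class later in this patch does the same with the H100 values 232448 and (9, 0). The metadata itself is filled in by autoheuristic.py from get_gpu_shared_memory() and torch.cuda.get_device_capability(). A self-contained sketch of that gate with stand-in metadata objects:

# Stand-in metadata; the real AHMetadata carries these fields (plus the choice list).
class FakeMetadata:
    def __init__(self, name, shared_memory, device_capa):
        self.name = name
        self.shared_memory = shared_memory
        self.device_capa = device_capa

def accepts_a100_mixed_mm(metadata) -> bool:
    # mirrors MixedMMA100.check_precondition: wrong op or wrong GPU -> decline
    return (
        metadata.name == "mixed_mm"
        and metadata.shared_memory == 166912
        and str(metadata.device_capa) == "(8, 0)"
    )

a100 = FakeMetadata("mixed_mm", shared_memory=166912, device_capa=(8, 0))
h100 = FakeMetadata("mixed_mm", shared_memory=232448, device_capa=(9, 0))
assert accepts_a100_mixed_mm(a100) and not accepts_a100_mixed_mm(h100)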
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if str(context.get_value('1LEQmLEQ16')) != 'True': + if context.get_value('m') <= 32.5: + if context.get_value('n') <= 6976.0: + if context.get_value('n') <= 3520.0: + if context.get_value('m*n') <= 37632.0: + return None + else: + return [(1.000, 13)] + else: + if context.get_value('m*k') <= 452352.0: + return [(0.590, 13), (0.256, 8), (0.103, 7), (0.051, 11)] + else: + return [(0.778, 8), (0.222, 13)] + else: + if context.get_value('k*n') <= 102776832.0: + if context.get_value('n') <= 14656.0: + return [(1.000, 11)] + else: + return [(0.889, 11), (0.111, 13)] + else: + return [(1.000, 11)] + else: + if context.get_value('m*n') <= 446464.0: + if context.get_value('m*n') <= 223424.0: + if context.get_value('mat1_stride_0') <= 3968.0: + return None + else: + return None + else: + if context.get_value('m*n') <= 346112.0: + return [(0.960, 16), (0.040, 7)] + else: + return [(0.750, 16), (0.136, 14), (0.114, 7)] + else: + if str(context.get_value('33LEQmLEQ64')) != 'True': + if context.get_value('n') <= 6976.0: + return [(1.000, 14)] + else: + return [(0.753, 2), (0.222, 1), (0.015, 7), (0.007, 16), (0.004, 12)] + else: + if context.get_value('n') <= 13888.0: + return [(0.710, 14), (0.275, 21), (0.014, 12)] + else: + return [(0.374, 19), (0.339, 20), (0.106, 21), (0.101, 16), (0.066, 17), (0.009, 14), (0.004, 18)] + else: + if context.get_value('n') <= 3520.0: + if context.get_value('arith_intensity') <= 3.994754433631897: + if str(context.get_value('mat2_dtype')) != 'torch.uint8': + if context.get_value('m*k') <= 18944.0: + return [(0.577, 5), (0.423, 6)] + else: + return [(0.988, 5), (0.012, 6)] + else: + if context.get_value('arith_intensity') <= 2.9899919033050537: + return None + else: + return None + else: + if context.get_value('arith_intensity') <= 7.956453561782837: + if context.get_value('k*n') <= 9244032.0: + return [(0.822, 5), (0.178, 6)] + else: + return [(0.977, 5), (0.023, 0)] + else: + if context.get_value('m*k') <= 978944.0: + return [(1.000, 5)] + else: + return [(0.971, 5), (0.029, 0)] + else: + if context.get_value('n') <= 13632.0: + if context.get_value('n') <= 6976.0: + return [(1.000, 6)] + else: + if context.get_value('k') <= 3968.0: + return [(0.617, 3), (0.111, 5), (0.099, 7), (0.086, 9), (0.062, 6), (0.025, 8)] + else: + return [(0.779, 8), (0.119, 5), (0.053, 7), (0.035, 6), (0.013, 3)] + else: + if context.get_value('k*n') <= 39518208.0: + return [(0.385, 4), (0.327, 3), (0.192, 6), (0.038, 7), (0.038, 10), (0.019, 5)] + else: + if context.get_value('n') <= 20800.0: + return [(0.821, 6), (0.121, 
7), (0.029, 4), (0.014, 5), (0.007, 3), (0.007, 8)] + else: + return [(0.530, 7), (0.386, 6), (0.046, 8), (0.021, 3), (0.015, 4), (0.002, 5)] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..704a4e14efa6d675c77f1d52e00f8d3cc4697088 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def 
get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1d390ff498b1bc18df73e0d8eb240a615d0788 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
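The leaves of the get_best_choices trees above return (probability, choice_index) pairs rather than choice strings; the index refers to the order of the self.choices.append calls in fill_choices, so index 0 is 'extern_fallback_mixed_mm' for both MixedMM artifacts. LearnedHeuristicDecision.get_decision (defined in learnedheuristic_interface.py later in this patch) takes the first pair, requires its probability to exceed get_confidence_threshold() (0.0 for these classes), and maps the index back to a string via get_choice. A small sketch of that decoding step, using a leaf that actually appears in the MixedMMH100 tree:

def decode_leaf(best_choices, choices, confidence_threshold=0.0):
    # mirrors LearnedHeuristicDecision.get_decision followed by get_choice
    if not best_choices:
        return None                       # tree is unsure at this leaf
    probability, index = best_choices[0]
    if probability <= confidence_threshold:
        return None                       # not confident enough
    return choices[index] if index < len(choices) else None

choices = [
    "extern_fallback_mixed_mm",   # index 0 in fill_choices()
    "type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4",
]
assert decode_leaf([(0.875, 0), (0.125, 6)], choices) == "extern_fallback_mixed_mm"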
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0150fc46f276946c1cf325125e2d89ba0f802f88 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb72dd286d2da5a7b76f9c85ae27b3995adfa8b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa3e625c3846d497d9c1dfdf708595aa2727f71f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c30a205e3c086821578788a80a17548f0466e00a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a35e61f3057da10532ede4cd881bdc1a08bbb8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df2e62cf5c382ce18d904fbc3ed6759b23ee8704 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic.py new file mode 100644 index 0000000000000000000000000000000000000000..430a56063be8837079c4fab71e055b2cabcf1210 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic.py @@ -0,0 +1,315 @@ +import json +import os +from functools import partial +from typing import Any, Callable, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + AHOperation, + Choice, + CHOICE_COL, + Feedback, + FEEDBACK_COL, + get_metadata_str_from_log, +) +from 
torch._inductor.autoheuristic.learned_heuristic_controller import ( + LearnedHeuristicController, +) +from torch._inductor.ir import ChoiceCaller +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import get_gpu_shared_memory + + +class LocalFeedback: + """ + To be able to collect data for a choice, a function providing feedback given a choice has to be provided. + LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice + (see pad_mm.py, where the autotuning happens locally, for an example). + """ + + def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None: + self.feedback_fn = feedback_fn + + def __call__(self, choice: Choice) -> Feedback: + return self.feedback_fn(choice) + + +class InconsistentMetadata(Exception): + """ + Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does + not match the metadata it would store if the file didn't exist. + """ + + +class AutoHeuristic: + """ + AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and + generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train + a heuristic (see torchgen/autoheuristic/). + """ + + collected_feedback: dict[Choice, Feedback] + + def __init__( + self, + fallback: Callable[[], Choice], + choices: list[Choice], + feedback: Optional[LocalFeedback], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + Initializes an instance of the AutoHeuristic class. + + Args: + fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or + AutoHeuristic is in data collection mode. + choices: A list of possible choices the heuristic can make. + feedback: An instance of LocalFeedback that provides feedback for a given choice. + context: Context to store with each choice and feedback. + name: A string that identifies the heuristic. + augment_context: An optional list of AHOperation instances that augment the context. + precondition: A callable that returns a boolean indicating whether AutoHeuristic should run. + """ + self.fallback = fallback + self.choices = choices + self.feedback = feedback + self.context = context + self.name = name + self.collected_feedback = {} + self.augment_context = augment_context + self.metadata = AHMetadata( + get_gpu_shared_memory(), + torch.cuda.get_device_capability(), + self.choices, + self.name, + ) + self.precondition = precondition + + if not self.satisfies_precondition(): + return + + if torch._inductor.config.autoheuristic_log_path == "DEFAULT": + self.log_path = self.get_default_log_path() + else: + self.log_path = torch._inductor.config.autoheuristic_log_path + + if torch._inductor.config.collect_autoheuristic(self.name): + if self.feedback is not None: + for choice in self.choices: + feedback_val = self.feedback(choice) + self.save_data(choice, feedback_val) + + def satisfies_precondition(self) -> bool: + return self.precondition is None or self.precondition( + self.metadata, self.context + ) + + def get_choice(self) -> Choice: + """ + Returns the chosen option based on the value of autoheuristic_use. + If self.name is one of the comma separated strings in autoheuristic_use, + it queries a learned heuristic to make a decision. 
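In the local data-collection mode described above, a caller builds an AHContext with the features it wants logged, wraps its benchmarking routine in a LocalFeedback, and constructs an AutoHeuristic; the constructor then benchmarks every choice and appends one row per choice to the per-GPU log, while get_choice() later consults a learned heuristic (or the fallback). A simplified usage sketch with invented choice names, features and timings (pad_mm.py is the real in-tree user, and constructing AutoHeuristic needs a CUDA device because the metadata records GPU properties):

from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext

def benchmark_choice(choice: str) -> float:
    # stand-in for the caller's real local benchmarking of `choice`
    return 1.0 if choice == "pad" else 1.2

context = AHContext()
context.add_feature("m", 2048)
context.add_feature("k", 2048)
context.add_feature("n", 2048)

ah = AutoHeuristic(
    fallback=lambda: "no_pad",                 # used when the heuristic is unsure
    choices=["pad", "no_pad"],                 # illustrative choice strings
    feedback=LocalFeedback(benchmark_choice),  # only invoked in collection mode
    context=context,
    name="pad_mm",
)
choice = ah.get_choice()                       # learned heuristic if enabled, else "no_pad"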
Otherwise, it returns the fallback option. + """ + + if not self.satisfies_precondition(): + return self.fallback() + + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + decision = controller.get_decision() + if decision not in self.choices: + # TODO(AlnisM): We might want to allow this in the future + return self.fallback() + if decision is not None: + return decision + return self.fallback() + + def get_top_k_choices( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[Choice]]: + if not self.satisfies_precondition(): + return None + if torch._inductor.config.use_autoheuristic(self.name): + if self.augment_context is not None: + self.context.apply_operations(self.augment_context) + controller = LearnedHeuristicController( + self.metadata, + self.context, + ) + choices = controller.get_decisions_ranked(top_k) + if choices is None: + return None + if always_included is not None: + for choice in always_included: + if choice not in choices: + choices.append(choice) + return choices + return None + + def get_collected_feedback(self, choice: Choice) -> Any: + return self.collected_feedback.get(choice, None) + + @staticmethod + def get_device_identifier() -> str: + # a heuristic might work well for one GPU, but not for another + # we store the collected data per GPU model and learn a heuristic per GPU model + + # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names + device_name = torch.cuda.get_device_name().replace(" ", "_") + return device_name + + def get_default_log_path(self) -> str: + device_name = self.get_device_identifier() + path = f"{cache_dir()}/autoheuristic/{device_name}/" + os.makedirs(path, exist_ok=True) + path += f"{self.name}.txt" + return path + + def serialize_metadata(self) -> str: + metadata_dict = self.metadata.to_dict() + ( + num_features, + cat_features, + ) = self.context.get_numerical_and_categorical_features() + metadata_dict["numerical_features"] = num_features + metadata_dict["categorical_features"] = cat_features + return json.dumps(metadata_dict) + + def save_data(self, choice: Choice, feedback_val: Feedback) -> None: + self.collected_feedback[choice] = feedback_val + log_path = self.log_path + + lines = [] + log_exists = os.path.exists(log_path) + if log_exists: + # if log already exists, make sure it is consistent + metadata = self.serialize_metadata() + existing_metadata = get_metadata_str_from_log(self.log_path) + if existing_metadata != metadata: + raise InconsistentMetadata( + "Given metadata does not match existing metadata" + ) + else: + lines.append(self.serialize_metadata()) + feature_header = self.context.get_feature_names_csv() + header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL + lines.append(header) + + line = "" + feature_values = self.context.get_feature_values_csv() + line += feature_values + "," + choice + "," + str(feedback_val) + lines.append(line) + + with open(log_path, "a") as f: + f.write("\n".join(lines) + "\n") + + +class AutoHeuristicSelectAlgorithm(AutoHeuristic): + """ + AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic + when one wants to use AutoHeuristic for kernel choice selection. 
+ """ + + def __init__( + self, + fallback: Callable[[], Optional[ChoiceCaller]], + choices: list[ChoiceCaller], + input_nodes: list[Any], + context: AHContext, + name: str, + augment_context: Optional[list[AHOperation]] = None, + precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, + ) -> None: + """ + The arguments choices, input_nodes and name have to match the ones used in the call to + autotune_select_algorithm(), e.g. if the following call is made + autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes + have to be used here. + """ + self.input_nodes = input_nodes + self.choicestr2choice: dict[str, ChoiceCaller] = {} + for choice in choices: + self.choicestr2choice[choice.autoheuristic_id()] = choice + choices_str = list(self.choicestr2choice.keys()) + + def fallback_str() -> str: + fallback_choice = fallback() + if fallback_choice is None: + # TODO: Find a nicer way to handle this + return "unsure" + return fallback_choice.autoheuristic_id() + + super().__init__( + fallback_str, + choices_str, + None, + context, + name, + augment_context, + precondition, + ) + + if ( + torch._inductor.config.collect_autoheuristic(self.name) + and self.satisfies_precondition() + ): + self.register_global_feedback(input_nodes, choices) + + def register_global_feedback( + self, input_nodes: list[Any], choices: list[ChoiceCaller] + ) -> None: + """ + Registers a callback in select_algorithm, which is called with the timing of each choice. + """ + + from torch._inductor.select_algorithm import ( + add_feedback_saver, + create_inputs_key, + create_precompile_key, + ) + + def store_global_feedback( + ah_inputs_key: str, + ah_precompile_key: str, + timings: dict[ChoiceCaller, float], + name: str, + input_nodes: list[Any], + choices: list[ChoiceCaller], + ) -> None: + current_inputs_key = create_inputs_key(input_nodes) + if current_inputs_key != ah_inputs_key: + return + current_precompile_key = create_precompile_key( + name, current_inputs_key, choices + ) + if current_precompile_key != ah_precompile_key: + return + for choice, time in timings.items(): + self.save_data(choice.autoheuristic_id(), time) + + inputs_key = create_inputs_key(input_nodes) + precompile_key = create_precompile_key(self.name, inputs_key, choices) + feedback_saver = partial(store_global_feedback, inputs_key, precompile_key) + add_feedback_saver(feedback_saver) + + def get_choice_caller(self) -> Optional[ChoiceCaller]: + choice = self.get_choice() + return self.choicestr2choice.get(choice, None) + + def get_top_k_choices_caller( + self, top_k: int, always_included: Optional[list[str]] = None + ) -> Optional[list[ChoiceCaller]]: + choices = self.get_top_k_choices(top_k, always_included) + if choices is None: + return None + return [self.choicestr2choice[choice] for choice in choices] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01bb1dd4d6f0f5b7b87fe6c34555063f20cc20e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/autoheuristic_utils.py @@ -0,0 +1,339 @@ +import functools +from typing import Any, Callable + +import torch + + +Feedback = float +Choice = str +Value = Any + +CHOICE_COL = "choice" +FEEDBACK_COL = "feedback" + + +class AHFeature: + """ + The context, that AutoHeuristic stores, is a list of features. 
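AutoHeuristic.save_data above produces a small plain-text log per GPU model: the first line is the JSON metadata (GPU identity, heuristic name, feature names), the second line is a CSV header ending in choice,feedback, and every later line is one benchmarked (features, choice, feedback) row. A sketch of what such a file looks like, with invented values:

# Shape of the per-GPU log written by AutoHeuristic.save_data (values invented).
example_log = """\
{"shared_memory": 166912, "device_capa": [8, 0], "name": "pad_mm", "numerical_features": ["m", "k", "n"], "categorical_features": []}
m,k,n,choice,feedback
2048,2048,2048,pad,1.0
2048,2048,2048,no_pad,1.2
"""
metadata_line, header, *rows = example_log.splitlines()
assert header == "m,k,n,choice,feedback"
assert len(rows) == 2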
AutoHeuristic needs to know whether a feature is + categorical (i.e., not a continuous variable) to learn a machine learning model. + """ + + def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None: + self.name = name + self.value = value + self.is_categorical = is_categorical + + +class AHOperation: + """ + AHOperation can be used to augment the data collected by AutoHeuristic. + One might for example store features like m, k, n, but also want to use + features like m*n, or k*n, to learn a heuristic. Instead of storing features + that can be created from the collected data, one can use AHOperation to + create new features from the collected data. + """ + + def __init__( + self, name: str, func: Callable[[Any], Value], is_categorical: bool = False + ) -> None: + self.name = name + self.func = func + self.is_categorical = is_categorical + + def apply_operation(self, data: Any) -> None: + data[self.name] = self.func(data) + + +class AHContext: + """ + This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeursitic will + store the context and the collected feedback. The context could be something like the shape of a tensor, i.e., + information that will help to learn a heuristic. + """ + + features: list[AHFeature] + context_dict: dict[str, Value] + + def __init__(self) -> None: + self.features = [] + self.context_dict = {} + + def add_feature( + self, name: str, value: Value, is_categorical: bool = False + ) -> None: + self.features.append(AHFeature(name, value, is_categorical=is_categorical)) + self.context_dict[name] = value + + def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]: + numerical_features = [] + categorical_features = [] + for feature in self.features: + if feature.is_categorical: + categorical_features.append(feature.name) + else: + numerical_features.append(feature.name) + + return numerical_features, categorical_features + + def get_feature_names_csv(self) -> str: + return ",".join(feature.name for feature in self.features) + + def get_feature_values_csv(self) -> str: + return ",".join(str(feature.value) for feature in self.features) + + def get_value(self, name: str) -> Value: + return self.context_dict[name] + + def apply_operations(self, operations: list[AHOperation]) -> None: + for op in operations: + op.apply_operation(self.context_dict) + + +class AHMetadata: + def __init__( + self, + shared_memory: Any, + device_capa: tuple[int, int], + choices: list[Choice], + name: str, + ) -> None: + # use amount of shared_memory and device_capability to identify GPU + # TODO(AlnisM): there might be a better way to do this + self.shared_memory = shared_memory + self.device_capa = device_capa + self.choices = choices + self.name = name + + def to_dict(self) -> dict[str, Value]: + return { + "shared_memory": self.shared_memory, + "device_capa": self.device_capa, + "name": self.name, + } + + +def get_metadata_str_from_log(log_path: str) -> str: + with open(log_path, newline="") as file: + json_string = file.readline().strip() + return json_string + + +def check_minsize(context: AHContext, minsize: int) -> bool: + return ( + context.get_value("m") >= minsize + and context.get_value("k") >= minsize + and context.get_value("n") >= minsize + ) + + +def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0): + # A100 precondition + return check_minsize(context, 512) + elif metadata.shared_memory == 232448 
and metadata.device_capa == (9, 0): + # H100 precondition + return check_minsize(context, 768) + return True + + +def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool: + m = context.get_value("m") + k = context.get_value("k") + n = context.get_value("n") + if m > 128 or k < 1024 or n < 1024: + return False + mat1_iscontig = context.get_value("mat1_iscontig") + mat2_iscontig = context.get_value("mat2_iscontig") + return mat1_iscontig and not mat2_iscontig + + +def get_mult_dims_ops() -> list[AHOperation]: + m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"]) + m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"]) + k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"]) + return [m_times_k_op, m_times_n_op, k_times_n_op] + + +def get_arith_intensity(data: Any) -> float: + m = data["m"] + k = data["k"] + n = data["n"] + if m == 0 or k == 0 or n == 0: + return 0.0 + return m * k * n / (m * k + k * n + m * n) + + +def pad_mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + k_div_m_times_n_op = AHOperation( + "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"]) + ) + + def bfloat_perf_hit(data: Any) -> bool: + m = data["m"] + k = data["k"] + n = data["n"] + is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16" + return k > (m * 1024) and k > (n * 1024) and is_bfloat + + bfloat_perf_hit_op = AHOperation( + "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True + ) + + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + dims_need_padding_ops = get_dims_need_padding_ops() + dims_multiple_ops = get_dims_multiple_ops() + is_contig_ops = get_is_contig_ops() + + ah_operations = mult_dims_ops + [ + k_div_m_times_n_op, + bfloat_perf_hit_op, + arith_intensity_op, + ] + ah_operations.extend(dims_need_padding_ops) + ah_operations.extend(dims_multiple_ops) + ah_operations.extend(is_contig_ops) + return ah_operations + + +def between_op(data: Any, dim: str, lower: int, upper: int) -> bool: + return data[dim] >= lower and data[dim] <= upper + + +def between_ops() -> list[AHOperation]: + dims = ["m", "k", "n"] + limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)] + ah_operations = [] + for dim in dims: + for lower, upper in limits: + between_op_fn = functools.partial( + between_op, dim=dim, lower=lower, upper=upper + ) + # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot + between_op_name = f"{lower}LEQ{dim}LEQ{upper}" + ah_operations.append( + AHOperation(between_op_name, between_op_fn, is_categorical=True) + ) + return ah_operations + + +def pow2_op(data: Any, dim: str, exponent: int) -> bool: + return data[dim] == 2**exponent + + +def mm_operations() -> list[AHOperation]: + mult_dims_ops = get_mult_dims_ops() + arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity) + return mult_dims_ops + [arith_intensity_op] + + +def mixed_mm_operations() -> list[AHOperation]: + return mm_operations() + between_ops() + + +def is_multiple(data: Any, dim: str, mult: int) -> bool: + return data[dim] % mult == 0 + + +def get_dims_multiple_ops() -> list[AHOperation]: + multiples = [2, 4, 8, 16, 32] + dims = ["m", "k", "n"] + dims_multiple_ops = [] + for dim in dims: + for mult in multiples: + is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult) + dims_multiple_op = AHOperation( + f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True + ) + dims_multiple_ops.append(dims_multiple_op) + return dims_multiple_ops + + +def 
get_dims_need_padding_ops() -> list[AHOperation]: + def mat1_innermost_needs_padding_fn(data: Any) -> bool: + mat1_stride_0 = data["mat1_stride_0"] + mat1_stride_1 = data["mat1_stride_1"] + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + mat1_innermost_needs_padding = False + if mat1_stride_0 == 1 and m_padded_length != 0: + mat1_innermost_needs_padding = True + if mat1_stride_1 == 1 and k_padded_length != 0: + mat1_innermost_needs_padding = True + return mat1_innermost_needs_padding + + mat1_innermost_op = AHOperation( + "mat1_innermost_needs_padding", + mat1_innermost_needs_padding_fn, + is_categorical=True, + ) + + def mat2_innermost_needs_padding_fn(data: Any) -> bool: + mat2_stride_0 = data["mat2_stride_0"] + mat2_stride_1 = data["mat2_stride_1"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + mat2_innermost_needs_padding = False + if mat2_stride_0 == 1 and k_padded_length != 0: + mat2_innermost_needs_padding = True + if mat2_stride_1 == 1 and n_padded_length != 0: + mat2_innermost_needs_padding = True + return mat2_innermost_needs_padding + + mat2_innermost_op = AHOperation( + "mat2_innermost_needs_padding", + mat2_innermost_needs_padding_fn, + is_categorical=True, + ) + + def num_dims_needs_padding_fn(data: Any) -> int: + m_padded_length = data["m_padded_length"] + k_padded_length = data["k_padded_length"] + n_padded_length = data["n_padded_length"] + num_dims_needs_padding = 0 + if m_padded_length != 0: + num_dims_needs_padding += 1 + if k_padded_length != 0: + num_dims_needs_padding += 1 + if n_padded_length != 0: + num_dims_needs_padding += 1 + return num_dims_needs_padding + + num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn) + return [mat1_innermost_op, mat2_innermost_op, num_dims_op] + + +def get_is_contig_ops() -> list[AHOperation]: + def mat1_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat1_stride_0"] + stride_1 = data["mat1_stride_1"] + k = data["k"] + return stride_0 == k and stride_1 == 1 + + mat1_is_contig_op = AHOperation( + "mat1_iscontig", mat1_is_contig_fn, is_categorical=True + ) + + def mat2_is_contig_fn(data: Any) -> bool: + stride_0 = data["mat2_stride_0"] + stride_1 = data["mat2_stride_1"] + n = data["n"] + return stride_0 == n and stride_1 == 1 + + mat2_is_contig_op = AHOperation( + "mat2_iscontig", mat2_is_contig_fn, is_categorical=True + ) + + return [mat1_is_contig_op, mat2_is_contig_op] + + +def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None: + for i, s in enumerate(stride): + context.add_feature(f"{name}_stride_{i}", s) + + +def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None: + using_tf32 = "not_float_32" + if dtype == torch.float32: + using_tf32 = torch.backends.cuda.matmul.allow_tf32 + context.add_feature("using_tf32", using_tf32, is_categorical=True) diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..8000a000cebe1dd591351be35836b82f73d1a336 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learned_heuristic_controller.py @@ -0,0 +1,119 @@ +import importlib +import inspect +import pkgutil +from collections import defaultdict +from typing import Any, Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + 
AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeuristic + + +def find_and_instantiate_subclasses( + package_name: str, base_class: Any +) -> list[LearnedHeuristic]: + instances = [] + + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages( + package.__path__, package.__name__ + "." + ): + try: + module_basename = module_name.split(".")[-1] + if not module_basename.startswith("_"): + # learned heuristics start with an underscore + continue + module = importlib.import_module(module_name) + + # look for classes that are subclasses of base_class + for _name, obj in inspect.getmembers(module): + if ( + inspect.isclass(obj) + and issubclass(obj, base_class) + and obj != base_class + ): + instance = obj() + instances.append(instance) + except Exception as e: + print(f"Error processing module {module_name}: {e}") + + return instances + + +class LearnedHeuristicController: + """ + Class that finds and instantiates all learned heuristics. It also provides + a way to get the decision of a learned heuristic. + """ + + existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list) + """ + A dictionary that stores all the learned heuristics for each optimization. + The key is the optimization name, and the value is a list of LearnedHeuristic objects. + """ + + heuristics_initialized: bool = False + """ + A flag that indicates whether the learned heuristics have been initialized. + Set to true when the get_decision() function is called for the first time. + """ + + def __init__( + self, + metadata: AHMetadata, + context: AHContext, + ) -> None: + self.metadata = metadata + self.context = context + + def get_heuristics(self, name: str) -> list[LearnedHeuristic]: + """ + Returns a list of learned heuristics for the given optimization name. + """ + + if not LearnedHeuristicController.heuristics_initialized: + # learned heuristics are generated into the following package + learned_heuristics_package = "torch._inductor.autoheuristic.artifacts" + + # learned heuristics have to be of type LearnedHeuristic + base_class = LearnedHeuristic + found_heuristics = find_and_instantiate_subclasses( + learned_heuristics_package, base_class + ) + + for learned_heuristic in found_heuristics: + opt_name = learned_heuristic.get_name() + LearnedHeuristicController.existing_heuristics[opt_name].append( + learned_heuristic + ) + LearnedHeuristicController.heuristics_initialized = True + + return LearnedHeuristicController.existing_heuristics[name] + + def get_decision(self) -> Optional[Choice]: + """ + Returns the decision made by the learned heuristic or None if no heuristic was found or the heuristic is unsure + which choice to make. 
+ """ + + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + return heuristic.get_decision(self.context, self.metadata.choices) + return None + + def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]: + heuristics = self.get_heuristics(self.metadata.name) + for heuristic in heuristics: + if heuristic.check_precondition(self.metadata, self.context): + choices = heuristic.get_decisions_ranked(self.context) + if choices is None: + return None + avail_choices = [ + choice for choice in choices if choice in self.metadata.choices + ] + return avail_choices[:top_k] + return None diff --git a/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf4c95cb53d004580983d94ac7e0210ff644737 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -0,0 +1,95 @@ +import operator +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) + + +class LearnedHeuristic: + """ + LearnedHeuristic is a base class for all learned heuristics. + """ + + def __init__(self) -> None: + pass + + def check_precondition( + self, + metadata: AHMetadata, + context: AHContext, + ) -> bool: + return True + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + return None + + def get_confidence_threshold(self) -> float: + return 1.0 + + def get_name(self) -> str: + return "" + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + return None + + +class LearnedHeuristicRegression(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + return 1.0 + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + choice2feedback = {} + for choice in choices: + predicted_feedback = self.get_feedback(context, choice) + choice2feedback[choice] = predicted_feedback + sorted_choices_feedback = sorted( + choice2feedback.items(), key=operator.itemgetter(1) + ) + highest_feedback = sorted_choices_feedback[-1][1] + second_highest_feedback = sorted_choices_feedback[-2][1] + if highest_feedback / second_highest_feedback > self.get_confidence_threshold(): + return sorted_choices_feedback[-1][0] + # We are not sure which choice is the best one + return None + + +class LearnedHeuristicDecision(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + + def get_choice(self, idx: int) -> Optional[str]: + return None + + def get_decision( + self, context: AHContext, choices: list[Choice] + ) -> Optional[Choice]: + best_choices = self.get_best_choices(context) + if not best_choices: + return None + (best_choice_proba, best_choice_idx) = best_choices[0] + if best_choice_proba <= self.get_confidence_threshold(): + return None + return self.get_choice(best_choice_idx) + + def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: + feedback_idx_list = self.get_best_choices(context) + if feedback_idx_list is None: + return None + choices = [ + self.get_choice(feedback_idx[1]) for feedback_idx in feedback_idx_list + ] + choices = [choice for choice in choices if choice is not None] + return choices + + def 
get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + return [] diff --git a/phivenv/Lib/site-packages/torch/_inductor/autotune_process.py b/phivenv/Lib/site-packages/torch/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc97678b2acde8ecc086faeb03912b40190a238 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/autotune_process.py @@ -0,0 +1,890 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import atexit +import ctypes +import dataclasses +import functools +import logging +import os +import pickle +import queue +import selectors +import subprocess +import sys +import time +import warnings +from collections.abc import Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref, c_size_t, c_void_p, CDLL +from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import rand_strided +from torch._inductor import ir +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) +from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu +from torch._logging import getArtifactLogger +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + from types import ModuleType + + from torch._inductor.select_algorithm import TritonTemplateCaller + +from . import config +from .runtime.benchmarking import benchmarker +from .virtualized import V + + +CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +class NonzeroWorkspaceNotSupportedError(Exception): + pass + + +class TuningProcess: + """ + Class to launch and interact with a benchmarking subprocess. + """ + + @staticmethod + def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None: + """ + Entry point for the child process. + """ + autotuning_log.debug( + "Started autotune subprocess %s. Visible devices: %s", + os.getpid(), + os.environ.get(CUDA_VISIBLE_DEVICES), + ) + + def workloop(): + while True: + job = TuningProcess.recv(read_pipe) + if job is None: + # None is a sentinel for the child to shut down + break + try: + result = job() + except Exception as e: + result = e + TuningProcess.send(result, write_pipe) + + try: + workloop() + except EOFError: + # The parent closed the pipe + pass + + @staticmethod + def send(obj: Any, write_pipe: IO[bytes]) -> None: + pickle.dump(obj, write_pipe) + write_pipe.flush() + + @staticmethod + def recv(read_pipe: IO[bytes]) -> Any: + return pickle.load(read_pipe) + + def __init__(self, device: Optional[int]): + self.device = device + self.start() + + def start(self): + """ + Start the benchmarking subprocess. + """ + entry = os.path.join(os.path.dirname(__file__), "__autotune_main__.py") + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + + self.selector = selectors.DefaultSelector() + self.selector.register(self.read_pipe, selectors.EVENT_READ) + + cmd = [ + sys.executable, + entry, + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + ] + extra_env = { + # We need to set the PYTHONPATH so the subprocess can find torch. 
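+            # If TORCH_CUSTOM_PYTHONPATH is set, it is used verbatim; otherwise the child
+            # inherits the parent's full sys.path joined with os.pathsep.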
+ "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + # We shouldn't be using the Triton async compile subprocess pool, + # but as a precaution set the env var that disables its creation. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": get_ld_library_path(), + # This will cause the subprocs to profile using the profiler. + "TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING": "1" + if config.profile_bandwidth_with_do_bench_using_profiling + else "0", + } + if self.device is not None: + extra_env[CUDA_VISIBLE_DEVICES] = str(self.device) + self.process = subprocess.Popen( + cmd, + env={**os.environ, **extra_env}, + pass_fds=(subproc_read_fd, subproc_write_fd), + ) + os.close(subproc_read_fd) + os.close(subproc_write_fd) + + self.running = True + + def alive(self) -> bool: + """ + True if the subprocess is still running. + """ + return self.running and self.process.poll() is None + + def put(self, req: Any) -> None: + """ + Push a work item to the child process. + """ + if not self.alive(): + self.start() + TuningProcess.send(req, self.write_pipe) + + def get(self, timeout: float = 120.0) -> Any: + """ + Get a response from the child process. Raises TimeoutError on timeout; + raises EOFError if the subprocess crashes. + """ + try: + if not self.selector.select(timeout): + raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}") + result = TuningProcess.recv(self.read_pipe) + except TimeoutError: + self.kill() + raise + except EOFError: + # The subprocess crashed + self.close() + raise + except Exception: + autotuning_log.exception( + "Unexpected exception in autotune subprocess %s", self.process.pid + ) + self.kill() + raise + + if isinstance(result, Exception): + raise result + return result + + def shutdown(self, wait: bool = True) -> None: + """ + Signal the child process to shut down gracefully. + """ + if self.alive(): + TuningProcess.send(None, self.write_pipe) + if wait: + self.wait() + + def wait(self) -> None: + """ + Wait for the child process to exit. + """ + if self.alive(): + self.process.wait() + self.close() + + def close(self) -> None: + """ + Close resources. + """ + self.selector.close() + self.read_pipe.close() + self.write_pipe.close() + self.running = False + + def kill(self) -> None: + """ + Send a SIGKILL to the child process. + """ + if self.alive(): + autotuning_log.error( + "Sending SIGKILL to autotune subprocess %d", + self.process.pid, + ) + self.process.kill() + self.close() + + +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. + """ + + def __init__(self) -> None: + """ + Start the child processes. + """ + devices = self.get_device_list() + autotuning_log.debug("Sub-process autotune device list: %s", devices) + + # Launch the child processes. + self.processes = [TuningProcess(device=device) for device in devices] + + self.process_queue: queue.Queue[TuningProcess] = queue.Queue() + for p in self.processes: + self.process_queue.put(p) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. 
+ self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + @staticmethod + def get_device_list() -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + gpu_type = get_gpu_type() + device_interface = get_interface_for_device(gpu_type) + count = device_interface.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices + + return list(range(count)) + + def shutdown(self) -> None: + """ + Signal all child processes to exit. + """ + self.executor.shutdown() + + for p in self.processes: + p.shutdown(wait=False) + for p in self.processes: + p.wait() + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. + """ + assert choice.bmreq is not None + + process = self.process_queue.get() + process.put(choice.bmreq.benchmark) + try: + return process.get( + config.max_autotune_subproc_result_timeout_seconds, + ) + except TimeoutError: + warnings.warn( + f"Timed out benchmarking choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # Set to INF so this choice will be ignored + return float("inf") + except Exception: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # Set to INF so this choice will be ignored + return float("inf") + finally: + self.process_queue.put(process) + + def benchmark( + self, + choices: list[TritonTemplateCaller], + ) -> dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + + # Use a ThreadExecutorPool to spread the work across the subprocesses and + # to grab subprocesses as soon as they're free. 
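+        # Rough usage sketch of this pool (illustrative; the public entry point is
+        # benchmark_in_sub_process() at the bottom of this file):
+        #
+        #     pool = get_tuning_process_pool()
+        #     timings = pool.benchmark(choices)  # {TritonTemplateCaller: latency}
+        #     best = min(timings, key=timings.__getitem__)
+        #
+        # Each helper thread in target() checks one TuningProcess out of process_queue,
+        # ships choice.bmreq.benchmark to it, and returns the process afterwards, so at
+        # most len(devices) benchmarks run concurrently.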
+ results = dict(zip(choices, self.executor.map(self.target, choices))) + + return results + + +LayoutOrBuffer = Union[ir.Layout, ir.Buffer] + + +@dataclasses.dataclass +class TensorMeta: + device: torch.device + dtype: torch.dtype + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType + offset: int + name: Optional[str] = None + + @classmethod + def from_irnodes( + cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + ) -> Union[TensorMeta, list[TensorMeta]]: + if isinstance(irnodes, Sequence): + result: list[Any] = [cls.from_irnodes(x) for x in irnodes] + assert all(isinstance(x, TensorMeta) for x in result) + return result + + node = irnodes + if isinstance(node, ir.Layout): + node = ir.Buffer(name="fake", layout=node) + + dtype = node.get_dtype() + assert dtype is not None + device = node.get_device() + assert device is not None + + return TensorMeta( + device=device, + dtype=dtype, + sizes=V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + strides=V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + offset=V.graph.sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + name=node.get_name(), + ) + + def to_tensor(self) -> torch.Tensor: + return rand_strided( + self.sizes, + self.strides, + device=self.device, + dtype=self.dtype, + extra_size=self.offset, + ) + + +@dataclasses.dataclass +class BenchmarkRequest: + """ + Only handle triton template benchmark for now. The extern kernel benchmark + can be done inside the same process since they usually don't cause crash. + + Important: Instances of this class and subclasses have to be serializable + across process boundaries. Do not put CUDA Tensors in here! 
+ """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + ) -> None: + # the kernel name defined in the module + self.kernel_name = kernel_name + + if isinstance(input_tensor_meta, TensorMeta): + input_tensor_meta = [input_tensor_meta] + self.input_tensor_meta = input_tensor_meta + + if isinstance(output_tensor_meta, (tuple, list)): + if len(output_tensor_meta) > 1: + # Each output with same meta for Grouped GEMM + assert all( + getattr(output_tensor_meta[0], attr) == getattr(x, attr) + for x in output_tensor_meta + for attr in ["device", "dtype", "sizes", "strides", "offset"] + ) + output_tensor_meta = output_tensor_meta[0] + self.output_tensor_meta = output_tensor_meta + + self.extra_args = extra_args + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + raise NotImplementedError + + def cleanup_run_fn(self) -> None: + pass + + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError + + def benchmark( + self, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + debug = autotuning_log.isEnabledFor(logging.DEBUG) + if debug: + start_ts = time.time() + + # create args and out tensor + if out is None: + assert len(input_tensors) == 0 + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + out = self.output_tensor_meta.to_tensor() + + if debug: + create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + try: + fn = self.make_run_fn(*input_tensors, out=out) + except NonzeroWorkspaceNotSupportedError: + # Skipping all ops with nonzero workspace requirements + autotuning_log.info("Skipping op due to nonzero workspace requirement") + return float("inf") + + if debug: + load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + res = self.do_bench(fn, *input_tensors, out) + + if debug: + bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + autotuning_log.debug( + "InChildProcess %s: load %f, create tensor %f, bench %f", + str(self), + load_elapse, # type: ignore[possibly-undefined] + create_tensor_elapse, # type: ignore[possibly-undefined] + bench_elapse, + ) + self.cleanup_run_fn() + return res + + +class _TestBenchmarkRequest(BenchmarkRequest): + """ + Supports unit testing. Defined in this file instead of the test file so the + TuningProcess sub-process can unpickle these objects. 
+ """ + + def __init__( + self, + result: float = 0.0, + device: Optional[int] = None, + sleep: Optional[float] = None, + exc: Optional[Exception] = None, + crash: bool = False, + ): + self.result = result + self.device = device + self.sleep = sleep + self.exc = exc + self.crash = crash + + def benchmark( + self, *input_tensors: torch.Tensor, out: Optional[torch.Tensor] = None + ) -> float: + if self.device is not None: + assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device) + if self.sleep: + time.sleep(self.sleep) + if self.exc: + raise self.exc + if self.crash: + sys.exit(1) + return self.result + + +class GPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = OrderedSet( + tensor.device.index + for tensor in [*input_tensors, out] + if isinstance(tensor, torch.Tensor) + and is_gpu(tensor.device.type) + and tensor.device.index is not None + ) + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + device_type = next( + ( + tensor.device.type + for tensor in input_tensors + if is_gpu(tensor.device.type) + ), + "cuda", + ) + device_interface = get_interface_for_device(device_type) + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = device_interface.current_device() + with device_interface.device(device_idx): # type: ignore[attr-defined] + res = benchmarker.benchmark_gpu(fn) + device_interface.synchronize() # shake out any CUDA errors + + return res + + +class CPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class TritonBenchmarkRequest(BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + module_path: str, # the path of the module defining the triton kernel + module_cache_key: str, + num_stages: int, + num_warps: int, + num_consumer_groups: int = 0, + num_buffers_warp_spec: int = 0, + matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. + waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit + kpack: int = 0, # ROCm specific gemm parameter + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.num_stages = num_stages + self.num_warps = num_warps + self.num_consumer_groups = num_consumer_groups + self.num_buffers_warp_spec = num_buffers_warp_spec + self.matrix_instr_nonkdim = matrix_instr_nonkdim + self.waves_per_eu = waves_per_eu + self.kpack = kpack + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + autotuning_log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) + + run_method = getattr(mod, self.kernel_name).run + extra_args = list(self.extra_args) + run_method.__self__.with_bandwidth_info = False + + # Newer version of triton add warmup argument to JITFunction.run. + # This code handles backward-compatibility. 
+ warmup_arg = {} + import inspect + + if "warmup" in inspect.signature(run_method).parameters: + warmup_arg["warmup"] = False + + if out.device.type == "cpu": + stream = 0 + else: + device_type = out.device.type + device_interface = get_interface_for_device(device_type) + stream = device_interface.get_raw_stream( + self.output_tensor_meta.device.index + ) + + if isinstance( + getattr(mod, self.kernel_name), + torch._inductor.runtime.triton_heuristics.DebugAutotuner, + ): + return functools.partial( + run_method, + *input_tensors, + out, + *extra_args, + **warmup_arg, + stream=stream, + ) + else: + return functools.partial( + run_method, + *input_tensors, + out, + *extra_args, + **warmup_arg, + stream=stream, + benchmark_run=True, + ) + + def precompile(self): + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + getattr(mod, self.kernel_name).precompile() + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" + + +class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + """ + A class to handle CUDA (CUTLASS) benchmark requests. This class is for + managing the lifecycle of a CUDA kernel benchmark, including compiling + the source code, managing workspace memory, and executing the kernel. + + Important: Instances of this class have to be serializable across + process boundaries. Do not put CUDA Tensors in here! + """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + + def precompile(self): + """ + Precompile the CUDA source code to populate the CUDACodeCache. + This may happen in a separate thread pool. + """ + autotuning_log.debug("Precompiling %s", self) + CUDACodeCache.compile(self.source_code, "so") + autotuning_log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + """ + Create a function to run the CUDA kernel with the given input and output tensors. + """ + + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + autotuning_log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. 
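+        # The generated CUTLASS kernel entry point is invoked positionally as
+        #   kernel(*tensor_ptrs, *extra_args, workspace_size_ptr, workspace_ptr, stream_ptr);
+        # the workspace-size pointer is None here because the required size was already
+        # queried by update_workspace_size() above.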
+ ret = functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + # sanity check to make sure we cleanup run fn properly + try: + ret() + except RuntimeError as e: + err_msg = str(e) + + def raise_runtime_error(): + raise RuntimeError(err_msg) + + self.cleanup_run_fn() + return raise_runtime_error + + return ret + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. + c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + autotuning_log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.DLL = None + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + + +class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! 
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + autotuning_log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, device_type="cpu") + autotuning_log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf + self.DLL = CppCodeCache.load(self.source_code, device_type="cpu") + args = [tensor.data_ptr() for tensor in list(input_tensors) + [out]] + autotuning_log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + # Assume only size with type ctypes.c_ulonglong in extra_args + assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args) + run_method.argtypes = [ctypes.c_ulonglong] * ( + len(args) + len(list(self.extra_args)) + ) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + """ + Check close attr due to it crash on Windows. + """ + if hasattr(self.DLL, "close"): + self.DLL.close() + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + +@functools.cache +def get_tuning_process_pool() -> TuningProcessPool: + pool = TuningProcessPool() + atexit.register(pool.shutdown) + return pool + + +def benchmark_in_sub_process( + choices: list[TritonTemplateCaller], +) -> dict[TritonTemplateCaller, float]: + """ + Do benchmarking in a subprocess and return the perf number (latency). + """ + return get_tuning_process_pool().benchmark(choices) diff --git a/phivenv/Lib/site-packages/torch/_inductor/bounds.py b/phivenv/Lib/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..af88a99e924e7fe020d3eea9c28cf1cb8bdd0bcf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,259 @@ +import logging +import operator +from functools import partial +from typing import Any, Callable, Optional, Union + +import sympy +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRanges, +) + +from ..utils._sympy.functions import PowByNatural +from ..utils._sympy.numbers import int_oo +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .ops_handler import DefaultHandler, ReductionType, StoreMode +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +log = logging.getLogger(__name__) + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. 
This may benefit + the case a bounded variable is returned by a kernel and fed into another. + """ + + def __init__(self, loop_body: LoopBody) -> None: + def upper_bound(v: Union[Expr, int]) -> int: + return bound_sympy(v).upper if isinstance(v, Expr) else v + + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, upper_bound(v) - 1) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {} + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + + @cache_on_self + def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) + interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: dict[str, Callable[..., Any]] + ) -> dict[str, Callable[..., ValueRanges[Expr]]]: + result: dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules.keys(): + if key == "get_index": + result[key] = self.get_index + elif "masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn( + subblock: LoopBodyBlock, + ) -> Callable[[Any, Any], ValueRanges[Expr]]: + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # 
pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + def get_index(self, name: str) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound + + +class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler): + def __init__(self) -> None: + self.name = "ValueRangeAnalysis" + boolean_operators = ( + "xor", + "logical_and", + "logical_or", + "logical_not", + ) + for op in boolean_operators: + setattr(self, op, self.bool_handler) + + @staticmethod + def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]: + # just assuming bools can have both values + return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # many ops are unlikely to show up in optimizable indexing compute, + # so we dont have full coverage + return ValueRanges.unknown() + + def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]: + return ValueRanges.unknown() + + def store( + self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None + ) -> None: + return + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Any, + ) -> ValueRanges[Any]: + return ValueRanges.unknown() + + @classmethod + def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: + assert isinstance(index, ValueRanges) + return cls.to_dtype(index, dtype) + + @staticmethod + def to_dtype( + x: Any, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ValueRanges[Any]: + x = ValueRanges.wrap(x) + + if dtype == torch.bool: + if x.is_singleton(): + return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x + elif 0 not in x: + return ValueRanges.wrap(sympy.true) + else: + return ValueRanges(sympy.false, sympy.true) + + def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: + # dtype is int or float + if dtype.is_floating_point: + return sympy.Float(x) + else: + if x in (int_oo, -int_oo): + return x + try: + return sympy.Integer(x) + except TypeError: + # inf cannot be cast to Integer + return x + + if x.is_bool: + if x.is_singleton(): + val = 1 if x.lower else 0 + return ValueRanges.wrap(cast(val, dtype)) + else: + return ValueRanges(cast(0, dtype), cast(1, dtype)) + else: + # int to float or float to int + return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) + + @staticmethod + def square(x: Any) -> ValueRanges[Any]: + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + + @staticmethod + def neg(x: Any) -> ValueRanges[Any]: + return ValueRanges.decreasing_map(x, operator.neg) + + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds + @classmethod + def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]: + x = cls.truediv(a, b) + if x == 
ValueRanges.unknown(): + return x + + return cls.trunc(x) + + @classmethod + def sub(cls, a: Any, b: Any) -> ValueRanges[Any]: + return cls.add(a, cls.neg(b)) diff --git a/phivenv/Lib/site-packages/torch/_inductor/choices.py b/phivenv/Lib/site-packages/torch/_inductor/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f4e71880c7c8c9e482a55972e81c45fa369ede --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/choices.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import typing +from typing import Any, Optional, TYPE_CHECKING + +import sympy + +import torch + +from . import config +from .codecache import write_text +from .metrics import get_metric_table, is_metric_table_enabled +from .runtime.hints import DeviceProperties, ReductionHint +from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse +from .template_heuristics import ( + BaseConfigHeuristic, + CPUConfigHeuristic, + CUDAConfigHeuristic, + ROCmConfigHeuristic, + XPUConfigHeuristic, +) +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator + from functools import partial + + from triton import Config as TritonConfig + + from torch.utils._ordered_set import OrderedSet + + from .codegen.simd_kernel_features import SIMDKernelFeatures + from .codegen.triton import TritonKernel + + +class Sortable(typing.Protocol): + """Anything that can be used as a list.sort() key (int/tuple/etc)""" + + def __lt__(self, other: typing.Self) -> bool: ... + + +class InductorChoices: + """ + This class contains a collection of default heuristics that effect performance of our generated + code. We try to not put correctness requirements in this file. + + You can override the choices made here by doing: + + class MyHeuristics(InductorChoices): + ... 
+ + torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()) + """ + + def get_config_heuristics( + self, device_type: Optional[str] = "cuda" + ) -> BaseConfigHeuristic: + if device_type == "cuda": + if torch.version.hip is None: + return CUDAConfigHeuristic() + else: + return ROCmConfigHeuristic() + elif device_type == "xpu": + return XPUConfigHeuristic() + elif device_type == "cpu": + return CPUConfigHeuristic() + else: + return BaseConfigHeuristic() + + # GEMM configs + def get_base_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + if config.max_autotune_gemm_search_space != "EXHAUSTIVE": + return mm_heuristics.get_mm_configs() + else: + return mm_heuristics.get_exhaustive_mm_configs() + + def get_extra_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_extra_mm_configs() + + def get_int8_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_int8_mm_configs() + + def get_mixed_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_mixed_mm_configs() + + def get_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_persistent_mm_configs() + + def get_scaled_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_mm_configs() + + def get_scaled_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_persistent_mm_configs() + + def get_mm_plus_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_mm_plus_mm_configs() + + # Conv configs + def get_conv_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + conv_heuristics = self.get_config_heuristics(device_type) + return conv_heuristics.get_conv_configs() + + # Flex attention configs + def get_flex_attention_fwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype) + + def get_flex_attention_bwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype) + + def get_flex_decode_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_decode_configs(head_dim, dtype) + + def 
triton_kernel_kwargs( + self, + kernel_cls: type[TritonKernel], + features: SIMDKernelFeatures, + groups: list[sympy.Expr], + kernel_kwargs: dict[str, Any], + ) -> dict[str, Any]: + """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations""" + return kernel_kwargs + + @staticmethod + def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool: + """Heuristic to decide if a cooperative reduction should be used.""" + if config.triton.force_cooperative_reductions: + return True + if ( + not config.triton.cooperative_reductions + or V.graph.get_current_device_or_throw().type == "cpu" + ): + return False + + xhint = V.graph.sizevars.size_hint(features.numel, fallback=2) + if xhint <= 8: + threshold = 32768 * xhint + elif xhint <= 16: + threshold = 2097152 + else: + return False + # TODO(jansel): should this default on for dynamic shapes? + return V.graph.sizevars.statically_known_geq( + features.reduction_numel, threshold + ) + + @staticmethod + def should_use_persistent_reduction( + features: SIMDKernelFeatures, cooperative_reduction: bool + ) -> bool: + """ + Heuristic to decide if a persistent reduction should be used. + """ + if not config.triton.persistent_reductions: + return False + threshold = { + ReductionHint.INNER: 1024, + }.get(features.get_reduction_hint(), 64) + + if cooperative_reduction: + # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements + try: + threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) + except ValueError: + pass # unbacked symint + + # If multi_kernel is enabled, we do more aggressive persistent reduction. + # This may result in some persistent reductions slower than the + # corresponding non-persistent reductions. MultiKernel will do benchmarking + # to pick the faster one. + if config.triton.multi_kernel: + threshold *= 16 + return V.graph.sizevars.statically_known_leq( + features.reduction_numel, threshold + ) # type: ignore[arg-types] + + @staticmethod + def want_no_x_dim(features: SIMDKernelFeatures) -> bool: + """ + Heuristic to decide if we should drop the X dimension from a persistent reduction kernel. + So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1. + Strangely this is faster than a [1, RBLOCK] block in some cases. + """ + return ( + features.get_reduction_hint() == ReductionHint.INNER + and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256) + ) + + @staticmethod + def reduction_split_factor( + device: torch.device, + reduction_numel_hint: int, + numel_hint: int, + inner_reduction: bool, + ) -> int: + """Heuristic to decide the RSPLIT used for split reductions. 
+ When a reduction has a small number of outputs there is not enough parallelism, + so we will do the reduction in two phases.""" + props = DeviceProperties.create(device) + num_sm = props.multi_processor_count + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + num_warps = 8 + num_threads = 32 * num_warps + + if inner_reduction: + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer even splits, but never smalle than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + else: + # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128 + # extend to even smaller number of outputs + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + @staticmethod + def can_fuse( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """ + Heuristics to prevent fusion applied to both horizontal and vertical fusions. 
Heuristics here should not + be needed for correctness and tweaking them may yield additional performance. + + See also some related heuristics that can be changed via config: + - config.triton.tiling_prevents_pointwise_fusion + - config.triton.tiling_prevents_reduction_fusion + - config.aggressive_fusion (will cause this function to be called more times) + """ + if shared_data_score == 0 and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): + common_buf_names: OrderedSet[str] = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) + if len(common_buf_names) > 0: + get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( + lambda: { + "pre_grad_graph_id": V.graph.graph_id, + "post_grad_graph_id": V.graph.post_grad_graph_id, + "node1_name": node1.get_name(), + "node2_name": node2.get_name(), + "node1_debug_str": write_text(node1.debug_str()), + "node2_debug_str": write_text(node2.debug_str()), + "common_buffer_names": list(common_buf_names), # type: ignore[dict-item] + "failure_reason": scheduler.decide_fusion_fail_reason( + node1, node2, common_buf_names + ), + } + ) + + WhyNoFuse(node1, node2)("no shared data due to indexing mismatch") + return False + WhyNoFuse(node1, node2)("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + WhyNoFuse(node1, node2)("exceeds max fusion") + return False # heuristic not needed for correctness + + if scheduler.can_fusion_increase_peak_memory(node1, node2): + WhyNoFuse(node1, node2)("Fusion will increase peak memory") + return False + + return True + + @staticmethod + def can_fuse_vertical( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """Hook for heuristics to prevent vertical (producer/consumer) fusions""" + return True + + @staticmethod + def can_fuse_horizontal( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """Hook for heuristics to prevent horizontal (consumer/consumer) fusions""" + if shared_data_score < config.score_fusion_memory_threshold: + WhyNoFuse(node1, node2)("score_fusion_memory_threshold") + return False + if scheduler.are_long_distant_nodes(node1, node2): + WhyNoFuse(node1, node2)( + "Nodes are too far away. Fusing them may increase peak memory." + ) + return False + return True + + @staticmethod + def score_fusion( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + ) -> Sortable: + """ + Assign a score (higher comes first) to the fusion of node1 and node2. + When different fusions conflict with each other, this is the way we + decide what order to run them in. 
+ + Our current score is based on: + - The type of fusion (template/reduction/etc) + - Estimate of the saved memory operations + - Fusions closer together in original graph order + """ + memory_score = scheduler.score_fusion_memory(node1, node2) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + + # prologue fusion always last + if node2.is_template(): + template_score = 0 + else: + template_score = 1 + ( + (node1.is_template() == config.epilogue_fusion_first) + and memory_score > 0 + ) + + return ( + template_score, + node1.is_reduction() == node2.is_reduction() and memory_score > 0, + memory_score, + proximity_score, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codecache.py b/phivenv/Lib/site-packages/torch/_inductor/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..a23ebdde49197a804457509c10bd286e3f2ec9d3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codecache.py @@ -0,0 +1,4042 @@ +from __future__ import annotations + +import base64 +import copyreg +import dataclasses +import functools +import hashlib +import importlib +import importlib.resources +import io +import itertools +import json +import logging +import os +import pickle +import pkgutil +import re +import shlex +import shutil +import struct +import subprocess +import sys +import tempfile +import textwrap +import threading +import warnings +from bisect import bisect_right +from copy import copy +from ctypes import c_void_p, CDLL, cdll +from datetime import timedelta +from functools import lru_cache, partial +from pathlib import Path +from time import time, time_ns +from types import ModuleType +from typing import ( + Any, + Callable, + cast, + Generic, + NoReturn, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import override, Self + +import torch +import torch.distributed as dist +from torch import SymInt, Tensor +from torch._dynamo.exc import SkipFrame +from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed +from torch._inductor import config, exc, metrics +from torch._inductor.codegen.common import ( + custom_backend_passes, + init_backend_registration, +) +from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.codegen.rocm.compile_command import ( + rocm_compile_command, + rocm_compiler, +) +from torch._inductor.compile_worker.utils import in_toplevel_process +from torch._inductor.cpp_builder import ( + _LINKER_SCRIPT, + _set_gpu_runtime_env, + _TORCH_PATH, + _transform_cuda_paths, + convert_cubin_to_obj, + CppBuilder, + CppOptions, + CppTorchDeviceOptions, + get_compiler_version_info, + get_ld_and_objcopy, + get_name_and_dir_from_output_file_path, + normalize_path_separator, +) +from torch._inductor.cpu_vec_isa import pick_vec_isa +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, + CustomGraphPassType, +) +from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param +from torch._inductor.runtime.compile_tasks import _reload_python_module +from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir +from torch._inductor.utils import ( + ALIGN_BYTES, + clear_on_fresh_cache, + is_linux, + is_windows, +) +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + TensorMetadata, +) +from torch._utils_internal import log_cache_bypass +from torch.compiler import config as cconfig +from 
torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.export.pt2_archive._package_weights import TensorProperties, Weights +from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX +from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv +from torch.utils._ordered_set import OrderedSet + +from .output_code import CompiledFxGraph +from .remote_cache import create_cache +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler +from .triton_bundler import TritonBundler +from .virtualized import V + + +if config.is_fbcode(): + from triton.fb.build import build_paths + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + pass + + def use_global_cache() -> bool: + return False + + +T = TypeVar("T") + +if TYPE_CHECKING: + from collections.abc import Generator, KeysView, Sequence + from concurrent.futures import Future + + from .compile_fx import _CompileFxKwargs + from .cpp_builder import BuildOptionsBase + from .graph import GraphLowering + from .ir import ChoiceCaller + from .output_code import CompiledFxGraphConstants, OutputCode + from .remote_cache import JsonDataTy, RemoteCache + from .runtime.hints import HalideInputSpec, HalideMeta + from .runtime.triton_heuristics import CachingAutotuner + from .utils import InputType + + +_IS_WINDOWS = sys.platform == "win32" +LOCK_TIMEOUT = 600 + +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") +log = logging.getLogger(__name__) + + +def use_re_build() -> bool: + """ + Use for CUTLASS compilation only right now. 
+ """ + if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()): + from triton.fb.re_build_helper import should_build_locally + + return not should_build_locally() + return False + + +def get_cpp_wrapper_cubin_path_name() -> str: + return "cubin_path" if torch.version.hip is None else "hsaco_path" + + +def get_kernel_bin_format(device: str) -> str: + if device == "cuda": + return "cubin" if torch.version.hip is None else "hsaco" + elif device == "xpu": + return "spv" + else: + return "" + + +@functools.cache +def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: + return ( + Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"])) + if global_cache_dir is not None + else None + ) + + +class CacheBase: + @staticmethod + @functools.cache + def get_system() -> dict[str, Any]: + try: + from triton.compiler.compiler import triton_key + + # Use triton_key instead of triton.__version__ as the version + # is not updated with each code change + triton_version = triton_key() + except ModuleNotFoundError: + triton_version = None + + try: + system: dict[str, Any] = { + "device": {"name": None}, + "version": { + "triton": triton_version, + }, + } + device_properties = torch.cuda.get_device_properties( + torch.cuda.current_device() + ) + if torch.version.cuda is not None: + system["device"]["name"] = device_properties.name + system["version"]["cuda"] = torch.version.cuda + else: + system["device"]["name"] = device_properties.gcnArchName + system["version"]["hip"] = torch.version.hip + except (AssertionError, RuntimeError): + # If cuda is not installed, none of the above config is relevant. + system = {} + + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + + @staticmethod + @clear_on_fresh_cache + @functools.cache + def get_local_cache_path() -> Path: + return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) + + @staticmethod + def get_global_cache_path() -> Optional[Path]: + return get_global_cache_path_impl(config.global_cache_dir) + + def __init__(self) -> None: + self.system = CacheBase.get_system() + + def get_local_cache(self) -> dict[str, Any]: + local_cache_path = self.get_local_cache_path() + if not local_cache_path.is_file(): + return {} + with open(local_cache_path) as local_cache_fp: + local_cache = json.load(local_cache_fp) + return local_cache["cache"] + + def update_local_cache(self, local_cache: dict[str, Any]) -> None: + local_cache_path = self.get_local_cache_path() + write_atomic( + str(local_cache_path), + json.dumps({"system": self.system, "cache": local_cache}, indent=4), + make_dirs=True, + ) + + +class LocalCache(CacheBase): + def lookup(self, *keys: str) -> Optional[dict[str, Any]]: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys: + if key in cache: + sub_cache = cache[key] + else: + return None + + return sub_cache + + def set_value(self, *keys: str, value: Any) -> None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys[0:-1]: + sub_cache.setdefault(key, {}) + sub_cache = sub_cache[key] + sub_cache[keys[-1]] = value + + self.update_local_cache(cache) + + +class PersistentCache(CacheBase): + @functools.cache # noqa: B019 + def get_global_cache(self) -> dict[str, Any]: + global_cache_path = self.get_global_cache_path() + if global_cache_path is None or not global_cache_path.is_file(): + return {} + with open(global_cache_path) as global_cache_fp: + global_cache = json.load(global_cache_fp) + 
return global_cache["cache"] + + def lookup( + self, + choices: list[ChoiceCaller], + op: str, + inputs: str, + benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]], + ) -> dict[ChoiceCaller, float]: + """ + Check to see if we have benchmarked the given choice callers. For each + choice caller: + + 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. + 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 3. If benchmark is not None: + a. `max_autotune_gemm=True`: benchmark the choice, update + local_cache[op][inputs][choice], and return the benchmark. + b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. + """ + precision = torch.get_float32_matmul_precision() + + log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) + log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) + log_errors = partial( + log_global_cache_errors, self.system, op, inputs, precision + ) + timings = {} + + def check_cache(cache: dict[str, Any], callback: Any = None) -> bool: + """Check if `cache` contains data for all the choices""" + hit = True + for choice in choices: + choice_hash = choice.hash_key() + if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}): + # cache hit + timings[choice] = cache[op][inputs][precision][choice_hash] + else: + # cache miss + hit = False + break + if callback: + callback(cached=hit) + return hit + + if config.max_autotune or config.max_autotune_gemm: + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + # check local cache first since it is data specific to the current machine + if ( + not check_cache(local_cache) + and not ( + use_global_cache() + and check_cache(self.get_global_cache(), callback=log_stats) + ) + and benchmark is not None + ): + try: + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing + except RuntimeError as e: + # catch and log autotuning failures + log_errors(e) + raise e + + self.update_local_cache(local_cache) + + timings_to_log = { + choice.hash_key(): timings[choice] for choice in choices + } + log_vals(timings_to_log) + elif use_global_cache(): + # only check global cache, not local one + check_cache(self.get_global_cache(), callback=log_stats) + # may have a partial cache hit, where not everything is benchmarked + + return timings + + +def get_lock_dir() -> str: + lock_dir = os.path.join(cache_dir(), "locks") + if not os.path.exists(lock_dir): + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. 
+ return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str: + hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") + if extra: + extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8") + hashing_str = hashing_str + b"||" + extra_b + return "c" + sha256_hash(hashing_str) + + +def get_path( + basename: str, extension: str, specified_dir: str = "" +) -> tuple[str, str, str]: + if specified_dir: + if os.path.isabs(specified_dir): + subdir = specified_dir + else: + subdir = os.path.join(cache_dir(), specified_dir) + else: + subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{extension}") + return basename, subdir, path + + +def get_hash( + content: Union[str, bytes], extra: str = "", hash_type: str = "code" +) -> str: + if hash_type in {"amdgcn", "code", "ptx", "spv"}: + return code_hash(content, extra) + if hash_type in {"cubin", "hsaco", "spv"}: + return code_hash(repr(content)) + raise AssertionError(f"Unknown hash type {hash_type}") + + +def write( + content: Union[str, bytes], + extension: str, + extra: str = "", + hash_type: str = "code", + specified_dir: str = "", + key: Optional[str] = None, +) -> tuple[str, str]: + if key is None: + # use stripped content to compute hash so we don't end up with different + # hashes just because the content begins/ends with different numbers of + # spaces. + key = get_hash(content.strip(), extra, hash_type) + basename, _subdir, path = get_path(key, extension, specified_dir) + if not os.path.exists(path): + write_atomic(path, content, make_dirs=True) + return basename, path + + +def write_text(text: str) -> str: + """ + Write the `text` to a file and return the path computed based on the hash. + """ + return write(text, "txt")[1] + + +def write_atomic( + path_: str, + content: Union[str, bytes], + make_dirs: bool = False, + encode_utf_8: bool = False, +) -> None: + # Write into temporary file first to avoid conflicts between threads + # Avoid using a named temporary file, as those have restricted permissions + assert isinstance(content, (str, bytes)), ( + "Only strings and byte arrays can be saved in the cache" + ) + path = Path(path_) + if make_dirs: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" + write_mode = "w" if isinstance(content, str) else "wb" + with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: + f.write(content) + try: + tmp_path.rename(target=path) + except FileExistsError: + if not _IS_WINDOWS: + raise + # On Windows, the file already existing is expected: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rename + # The two lines below are equivalent to `tmp_path.rename(path)` on a non-Windows OS. + # 1. Copy tmp_file to Target(Dst) file. + shutil.copy2(src=tmp_path, dst=path) + # 2. Delete tmp_file. + os.remove(tmp_path) + + +@dataclasses.dataclass +class TensorMetadataAndValues: + """ + TensorMetadata plus the elements as a list of raw values. + Used for hashing inlined constants. 
+ """ + + tensor_metadata: TensorMetadata + values: list[Any] + + +def _ident(x: T) -> T: + return x + + +def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata: + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + + return meta + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + has_user_defined_triton_kernels: bool = False, + ) -> None: + """ + Create an FX graph pickler. If include_non_inlined=True, then pickling will + include the _values_ for all Tensors. (Note that any tensors are constants + attached as attributes to the GraphModule). Otherwise, pickling will include + only the metadata for these tensors. + """ + self._stream = io.BytesIO() + super().__init__(self._stream) + + self.dispatch_table = copyreg.dispatch_table.copy() + self.dispatch_table.update( + { + FakeTensor: functools.partial(self._reduce_fake_tensor), + torch.Tensor: functools.partial(self._reduce_tensor), + torch.nn.parameter.Parameter: functools.partial(self._reduce_tensor), + torch.SymInt: functools.partial(self._reduce_symint), + torch.fx.experimental._backward_state.BackwardState: functools.partial( + self._reduce_unsupported + ), + } + ) + if has_user_defined_triton_kernels: + # Need to use runtime type as GraphModule generates a singleton in __new__ function + self.dispatch_table[gm.__class__] = functools.partial( + self._reduce_graph_module + ) + + # Run with pickler.fast so it doesn't intern strings, making the hash result more predictable + # TODO: pickler.fast is technically deprecated. Will this work on new python versions? + self.fast = True + + def _reduce_fake_tensor( + self, t: Tensor + ) -> tuple[Callable[[T], T], tuple[TensorMetadata]]: + """ + Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata_for_cache_key(t) + return (_ident, (metadata,)) + + def _reduce_tensor( + self, t: Tensor + ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]: + """ + Custom reducer to pickle Tensors. If we see tensors, we know they're constants + stored as attributes on the GraphModule. + """ + from .graph import GraphLowering + + if t.is_mkldnn: + # TODO: These tensors don't currently pickle, so we can't cache a compiled + # graph containing them. Just fail now. If mkldnn tensors get pickling + # support, we can remove this. + raise BypassFxGraphCache("mkldnn tensors unpickleable") + + metadata = extract_tensor_metadata_for_cache_key(t) + + # If this is a non-inlined frozen parameter, we consider the metadata only. + if is_frozen_param(t) and not GraphLowering.can_inline_constant(t): + return (_ident, (metadata,)) + + # Very large tensors will be expensive to copy to cpu and hash. Let's at least + # report any slowness. + start = time() + values = t.tolist() + elapsed = time() - start + if elapsed > 1.0: + warnings.warn( + f"FX graph cache copying of a large constant took {elapsed:.1}s. " + "Please file an issue." 
+ ) + + return (_ident, (TensorMetadataAndValues(metadata, values),)) + + def _reduce_symint(self, s: SymInt) -> tuple[Callable[[T], T], tuple[str]]: + """ + Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and not the + # backed value. We evaluate guards stored with a cached graph to ensure a cached + # entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + def _reduce_unsupported(self, s: Any) -> NoReturn: + """ + Custom reducer to handle any objects that we don't support and therefore + raise to bypass caching. + """ + raise BypassFxGraphCache("Reduce unsupported") + + def _reduce_graph_module( + self, gm: torch.fx.GraphModule + ) -> tuple[Any, tuple[dict[str, Any], str]]: + """ + Custom reducer for graph module to handle irrelevant data for user + defined triton kernels + Essentially what we are doing here is a huge hack where user defined + triton kernel contain a dynamo time side table and the arguments to the + call_function are indices into this side table. These arguments are not + for hashing purposes since we included the source code into the cache + key and the numbers are prone to give false negatives due to ordering. + """ + fn, (data, imports) = gm.__reduce__() + code = data["_code"] + code = re.sub(r"kernel_idx = \d+", "", code) + code = re.sub(r"constant_args_idx = \d+", "", code) + data["_code"] = code + return fn, (data, imports) + + def dumps(self, obj: Any) -> bytes: + """ + Pickle an object and return a byte string. + """ + try: + self.dump(obj) + return self._stream.getvalue() + except (TypeError, AttributeError) as e: + # Some configs options may not pickle. + log.warning("Failed to pickle cache key", exc_info=True) + raise BypassFxGraphCache("Failed to pickle cache key") from e + finally: + # Reset our stream for the next dump. + self._stream.seek(0) + self._stream.truncate(0) + + def get_hash(self, obj: Any) -> str: + """ + Serialize an object and return a hash of the bytes. + """ + serialized_data = self.dumps(obj) + return sha256_hash(serialized_data) + + def debug_lines(self, inp: FxGraphHashDetails) -> list[str]: + """ + Get a printable string describing in more detail all the attributes + comprising an object. Useful for debugging when one graph hashes + to a different value than another. 
+ """ + + def get_str(obj: Any) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata_for_cache_key(obj)) + elif isinstance(obj, bytes): + return "" + elif type(obj) in self.dispatch_table: + # Run the reducer on the object + return str(self.dispatch_table[type(obj)](obj)[1]) + else: + return str(obj) + + lines = [] + for attr, obj in vars(inp).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = self.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = self.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = self.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return lines + + +def build_code_hash( + roots: list[str] | None, prefix: str, hasher: hashlib._Hash +) -> None: + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): + spec = lib.module_finder.find_spec(lib.name, None) + assert spec is not None + module = spec.origin + assert module is not None + with open(module, "rb") as f: + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) + + +def torch_key_cache(func: Callable[[], bytes]) -> Callable[[], bytes]: + """ + This function is a reimplementation of functools.lru_cache with a + set function that allows prepopulating the cache. + """ + # Use list for reference semantics + _cache: list[bytes] = [] + + def wrapper() -> bytes: + if len(_cache) == 0: + _cache.append(func()) + return _cache[0] + + def set_val(val: bytes) -> None: + assert len(_cache) == 0 + _cache.append(val) + + def clear() -> None: + _cache.clear() + + wrapper.set = set_val # type: ignore[attr-defined] + wrapper.clear = clear # type: ignore[attr-defined] + return wrapper + + +@torch_key_cache +def torch_key() -> bytes: + """ + Compute a key that contains relevant information about torch source files + """ + with dynamo_timed("inductor_codecache_torch_key", log_pt2_compile_event=False): + if not config.is_fbcode(): + + def get_code_hash(root: str) -> bytes: + # This function isn't meant to be used outside of torch_key, just a + # helper for clarity. Instead, use torch_key() directly when you need + # a hash representing the state of the source code. + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "script.ld", + ) + inductor_root = os.path.dirname(__file__) + extra_files = [os.path.join(inductor_root, x) for x in extra_files] + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash([root], "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + return get_code_hash(_TORCH_PATH) + + from libfb.py import parutil + + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") + + +def get_inductor_root() -> str: + return os.path.dirname(__file__) + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: list[Any] + + +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. + """ + + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. 
+ """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + ) -> None: + self.gm = gm + self.example_inputs = example_inputs + self.cache_key_tag = cconfig.cache_key_tag + + # Order kwargs so hashing is stable to changes in kwarg order. Although + # it's technically a _CompileFxKwargs we don't actually need it typed as + # such since we're just using it to generate a hash. + self.fx_kwargs: dict[str, object] = {} + for k, v in sorted(fx_kwargs.items()): + if k not in self.EXCLUDED_KWARGS: + if type(v) in (set, OrderedSet): # noqa: set_linter + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. + self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) # type: ignore[call-overload] + else: + self.fx_kwargs[k] = v + + from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, + triton_kernel_wrapper_mutation, + ) + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + # Node meta will not be part of gm's reduce function, so lets remember + # the kernel source code separately + self.user_defined_triton_source: list[Any] = [] + if gm is not None: + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in itertools.chain( + module.graph.find_nodes( + op="call_function", target=triton_kernel_wrapper_functional + ), + module.graph.find_nodes( + op="call_function", target=triton_kernel_wrapper_mutation + ), + ): + from triton.runtime.autotuner import Autotuner + + kernel = kernel_side_table.get_kernel(node.kwargs["kernel_idx"]) + configs = None + if isinstance(kernel, Autotuner): + if kernel.configs: + configs = str( + sorted( + sorted(str(kv) for kv in c.all_kwargs().items()) + for c in kernel.configs + ) + ) + kernel = kernel.fn + + kernel_source = ( + user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + ) + constant_args = kernel_side_table.get_constant_args( + node.kwargs["constant_args_idx"] + ) + self.user_defined_triton_source.append( + (kernel_source, constant_args, configs) + ) + + # Alignment checks + self.inputs_to_check = inputs_to_check + + no_tensor_inputs = not any(isinstance(x, torch.Tensor) for x in example_inputs) + # This device index is usually already encoded by the device of the inputs + # but fx graphs don't necessarily have tensor inputs. If there aren't any, + # we need to guard on the device index in case we allocate cuda tensors + if no_tensor_inputs and torch.accelerator.is_available(): + self.default_cuda_device_index = torch.accelerator.current_device_index() + + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + self.deterministic_algorithms_settings = ( + torch.are_deterministic_algorithms_enabled(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + # Global settings affecting matmul codegen. + self.cuda_matmul_settings = ( + torch.backends.cuda.matmul.allow_tf32, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, + ) + + # Also hash on various system info (including the triton compiler version). 
+ self.torch_version = torch_key() + self.system_info = CacheBase.get_system() + self.inductor_config = config.save_config_portable(ignore_private_configs=False) + # Custom post grad passes should provide an ID to hash. + self.post_grad_custom_pre_pass = self._get_custom_pass_detail( + config.post_grad_custom_pre_pass + ) + self.post_grad_custom_post_pass = self._get_custom_pass_detail( + config.post_grad_custom_post_pass + ) + self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe( + config._pre_fusion_custom_pass + ) + self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe( + config._fuse_ddp_communication_passes + ) + + # Register inductor backends and custom passes and get their UUIDs. + init_backend_registration() + self.custom_backend_passes = tuple( + map(self._get_custom_pass_detail, custom_backend_passes.values()) + ) + + # This is mainly added to handle these two inductor configs, which are (unfortunately) + # sometimes cache safe: + # - _pre_fusion_custom_pass + # - _fuse_ddp_communication_passes + # Their types can be found in `torch/_inductor/config.py`, but: + # - if they are string names, we can cache them safely (one is by default) + # - if any of them are set to custom callables, we will need to cache miss + # Future work is for someone to find any places where these functions are used + # and force them to be of type CustomGraphPass, so we can guarantee serialization. + def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Optional[Any]: + if not custom_pass: + return None + if isinstance(custom_pass, list): + return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass] + if isinstance(custom_pass, str): + return custom_pass + if isinstance(custom_pass, CustomGraphPass): + return custom_pass.uuid() + if callable(custom_pass): + # Returning None is safe here because we raise an explicit bypass error + # later if we detect these passes are set to callables + return None + raise AssertionError(f"unknown config type: {str(type(custom_pass))}") + + def _get_custom_pass_detail( + self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass] + ) -> Optional[Any]: + if not custom_pass: + return None + assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass)) + return custom_pass.uuid() + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], +) -> tuple[str, list[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) + has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0 + pickler = FxGraphCachePickler(gm, has_user_defined_triton_kernels) + + # The prefix distinguishes among the other kinds of objects we + # cache in this module. 
+ key = "f" + pickler.get_hash(details) + debug_lines = pickler.debug_lines(details) + debug_str = "\n".join(debug_lines) + log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}") # noqa: G004 + return key, debug_lines + + +def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int: + """ + Ephemerally increases the NCCL timeout when compiling for a distributed job + Returns amount of seconds increased + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + increased_timeout_sec = int(time_saved_ns // 1e9) # convert to seconds + + if config.is_fbcode(): + fudge_factor = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:ephemeral_timeout_fudge_factor_percentage" + ) + log.info( + "Ephemeral NCCL timeout increase fudge factor %d and original increase value %d", + fudge_factor, + increased_timeout_sec, + ) + increased_timeout_sec += int(increased_timeout_sec * fudge_factor / 100) + + log.info("Increasing NCCL timeout by %d", increased_timeout_sec) + dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs( + timedelta(seconds=increased_timeout_sec) + ) + return increased_timeout_sec + + +class GuardedCache(Generic[T]): + """ + Mixin for caches that have guards associated with their entries. + """ + + @classmethod + def _get_tmp_dir_for_key(cls: type[GuardedCache[T]], _key: str) -> str: + raise NotImplementedError("Implement _get_tmp_dir_for_key on parent class") + + @classmethod + def iterate_over_candidates( + cls: type[GuardedCache[T]], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + key: str, + ) -> Generator[tuple[T, bytes], None, None]: + if local: + subdir = cls._get_tmp_dir_for_key(key) + if os.path.exists(subdir): + for path in sorted(os.listdir(subdir)): + try: + with open(os.path.join(subdir, path), "rb") as f: + content = f.read() + yield pickle.loads(content), content + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + + if remote_cache: + try: + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content), content + except Exception: + log.warning( + "%s unable to load compiled graph", cls.__name__, exc_info=True + ) + + @classmethod + def find_guarded_entry( + cls: type[GuardedCache[T]], + key: str, + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool], + hints: list[int], + ) -> tuple[Optional[T], Optional[bytes], dict[str, str]]: + """ + Find the first cache entry in iterate_over_candidates that passes `evaluate_guards`. + + Args: + key: The cache key to look up + local: Whether to check the local cache + remote_cache: The remote cache to check, if any + evaluate_guards: Function that evaluates whether a guard passes the check, + given a list of hint values and the guard expression. + hints: List of symint hints paired with evaluate_guards + + Returns: + A tuple of (graph, pickled_content) if found, or (None, None) if not found + """ + graph = None + pickled_content = None + result_status = "full_miss" + sample_guards_expr = None + + # Iterate over any entries in the subdir for this key and evaluate + # guards to determine whether there's a hit. 
+ + for candidate, content in cls.iterate_over_candidates(local, remote_cache, key): + assert hasattr(candidate, "guards_expr") + if not candidate.guards_expr: # type: ignore[attr-defined] + # No guards to evaluate, so this is a hit. + graph = candidate + pickled_content = content + result_status = "hit" + break + + # Evaluate the guard expression in the current context. + # If there's not a cache hit, we don't want the evaluation to + # affect the current env, e.g., cause the creation of new guards, + # so we evaluate with the hints instead of the symbols. + hit = bool(evaluate_guards(candidate.guards_expr, hints)) # type: ignore[attr-defined] + if hit: + graph = candidate + pickled_content = content + result_status = "hit" + sample_guards_expr = candidate.guards_expr + break + else: + # At least one guard missed, log this + result_status = "guard_miss" + sample_guards_expr = candidate.guards_expr + + info = {"cache_status_detailed": result_status} + if sample_guards_expr is not None: + info["cache_status_guard_expr"] = sample_guards_expr + return graph, pickled_content, info + + @classmethod + def _filter_backed_symints( + cls: type[GuardedCache[T]], inputs: Sequence[InputType] + ) -> list[torch.SymInt]: + """ + Get the backed SymInt objects from the input list. Note that we can never + have guards that depend on unbacked symint. + """ + return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)] + + @classmethod + def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]: + """ + Helper to get the shape env from the tracing context. + """ + ctx = torch._guards.TracingContext.try_get() + if not ctx: + return None + return ctx.fake_mode.shape_env + + +@CacheArtifactFactory.register +class InductorCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + FxGraphCache._write_to_local_cache(self.key, self.content) + + @override + @staticmethod + def type() -> str: + return "inductor" + + +class FxGraphCache(GuardedCache[CompiledFxGraph]): + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. 
+ """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "fxgraph") + + @classmethod + def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str: + """ + Return the disk location for a given cache key. + """ + return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) + + @staticmethod + def cache_hit_post_compile( + graph: CompiledFxGraph, + cache_info: dict[str, Any], + constants: CompiledFxGraphConstants, + ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]: + """ + Cache specific post compile steps that need to run if we find a graph in the cache + This includes putting bundled triton artifacts in the right place, + reloading the PyCodeCache artifact, etc. + + These don't always happen (i.e. on a cache miss, so they are in a separate function from + CompiledFxGraph.post_compile) + """ + if bundle := graph._triton_bundle: + triton_bundler_meta = TritonBundler.read_and_emit(bundle) + if (meta := triton_bundler_meta) is not None: + cache_info["triton_bundler_meta"] = str(meta) + CompileEventLogger.try_add_pt2_compile( + "inductor_compile", cached_kernel_names=meta.cached_kernel_names + ) + CompileEventLogger.try_add_pt2_compile( + "AOTAutogradCache.inductor_load", + cached_kernel_names=meta.cached_kernel_names, + ) + if len(meta.cached_kernel_names) > 0: + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, "num_triton_bundles" + ) + + try: + artifact_path = graph.after_deserialization(constants) + + from .graph import GraphLowering + + # This is used by tests to check the output for specific details. + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(graph.source_code) + + except OSError: + # Not expected, but in case the PyCodeCache entry is removed from + # underneath us, treat it as a cache miss and recompile. + return None, cache_info + + inductor_meta = autotune_cache.inductor_meta_from_config() + code = graph.source_code + AutotuneCacheBundler.begin_compile(inductor_meta, code=code) + + # Increment the cached metrics/counters by the amounts recorded when the FX + # graph was compiled for this cache entry. Pretending these counters + # were incremented normally is useful for testing with the cache enabled. 
+ metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas) + counters["inductor"] += graph.counter_deltas + + output_code_log.debug("Output code: \n%s", code) + output_code_log.debug("Output code written to: %s", artifact_path) + # On cache hit, use artifact path as filename + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: graph.runnable_graph_str, + ) + trace_structured( + "inductor_post_grad_graph", + payload_fn=lambda: graph.inductor_post_grad_graph_str, + ) + trace_structured( + "inductor_output_code", + lambda: {"filename": artifact_path}, + payload_fn=lambda: code, + ) + return graph, cache_info + + @staticmethod + def _lookup_graph( + key: str, + example_inputs: Sequence[InputType], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + constants: CompiledFxGraphConstants, + evaluate_guards: Optional[ + Callable[[str, Union[list[int], list[torch.SymInt]]], bool] + ] = None, + ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]: + """ + Lookup a compiled graph in the cache by key. On a hit, return the + deserialized CompiledFxGraph object. On a miss, return None. + `constants` tracks a list of constants, or a way to obtain the list of constants + associated with a given cache entry + `evaluate_guards` allows AOTAutogradCache and other callers to customize + what constitutes a guard success. Normally, a guard hit happens if + `shape_env.evaluate_guards_expression` returns True. + """ + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + + symints = FxGraphCache._filter_backed_symints(example_inputs) + hints = [hint_int(s) for s in symints] + + # If this config is turned on, everything is a guard hit and we check nothing + if config.unsafe_skip_cache_dynamic_shape_guards: + # This also makes it so we don't add anything to the dynamic + # shape environment + evaluate_guards = lambda x, y: True # noqa: E731 + + if evaluate_guards is None: + evaluate_guards = shape_env.evaluate_guards_expression + + cache_info: dict[str, Any] = dict() + + # Use the find_graph_for_key method to find a graph for the given key + graph, pickled_content, guard_info = FxGraphCache.find_guarded_entry( + key, local, remote_cache, evaluate_guards, hints + ) + cache_info.update(guard_info) + if graph is None: + return None, cache_info + + if pickled_content is not None: + CacheArtifactManager.record_artifact( + InductorCacheArtifact.type(), key, pickled_content + ) + + # Now re-evaluate with the symints to add any guards to the current env. + if graph.guards_expr: + check = bool(evaluate_guards(graph.guards_expr, symints)) + assert check is True + log.debug( + "fx graph cache key %s post-load guards: %s", key, shape_env.guards + ) + + return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants) + + @staticmethod + def _write_to_local_cache(key: str, content: bytes) -> None: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. 
+ path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + @staticmethod + def _save_graph( + key: str, + compiled_graph: OutputCode, + example_inputs: Sequence[InputType], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + ) -> None: + """ + Store a serialized CompiledFxGraph on disk. + """ + from .compile_fx import CompiledFxGraph + + assert isinstance(compiled_graph, CompiledFxGraph), ( + f"serialization for {type(compiled_graph)} NYI" + ) + + # Before serializing, compute the guard expression that will be used to + # ensure that a CompiledFxGraph is valid when loaded from the cache. It's + # sufficient to consider only the SymInt args to the fx graph since the + # Tensor shapes are already captured in the hash for the cache key. Any + # Tensor arg with a symbolic shape will have a SymInt arg for the graph. + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + symints = FxGraphCache._filter_backed_symints(example_inputs) + guards = shape_env.get_pruned_guards(symints) + compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) + disk_compiled_graph = copy(compiled_graph) + disk_compiled_graph.prepare_for_serialization() + + try: + content = pickle.dumps(disk_compiled_graph) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) + counters["inductor"]["fxgraph_cache_pickle_error"] += 1 + return + + try: + CacheArtifactManager.record_artifact( + InductorCacheArtifact.type(), key, content + ) + if local: + FxGraphCache._write_to_local_cache(key, content) + + if remote_cache: + time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 + + @staticmethod + def _check_for_hop(gm: torch.fx.GraphModule) -> None: + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if ( + isinstance(node.target, torch._ops.HigherOrderOperator) + and not node.target.cacheable() + ): + raise BypassFxGraphCache( + f"Can't cache HigherOrderOperator: {node.target.name()}" + ) + if node.op == "getattr" and isinstance( + getattr(gm, node.target), torch._C.ScriptObject + ): + raise BypassFxGraphCache("Can't cache torchbind objects") + + @staticmethod + def _check_can_cache(gm: torch.fx.GraphModule) -> None: + """ + Check some conditions that would preclude caching and raise BypassFxGraphCache + to bypass in case caching is not possible. + """ + # Post grad custom passes must implement the CustomGraphPass or we don't + # know how to include them in the cache key calculation. 
+ for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass): + if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): + raise BypassFxGraphCache("Unsupported post grad custom pass") + # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes + # and ensure they are not passing us raw callables + if config._pre_fusion_custom_pass is not None: + if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass): + raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass") + for p in config._fuse_ddp_communication_passes: + if callable(p) and not isinstance(p, CustomGraphPass): + raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass") + + # Freezing can embed constants that wouldn't be static across runs. + if has_frozen_params(gm) and not torch._utils_internal.justknobs_check( + "pytorch/inductor:allow_freezing_with_caching" + ): + raise BypassFxGraphCache("Skipping graph with frozen constants") + + if config.aot_inductor.use_runtime_constant_folding: + raise BypassFxGraphCache( + "Runtime constant folding can introduce constants that aren't " + "static across runs" + ) + + from torch._inductor.compiler_bisector import CompilerBisector + + if CompilerBisector.bisection_enabled: + log.debug("dont cache graph when bisect enabled") + raise BypassFxGraphCache + + # The treatment of guards in the caching implementation requires that + # we have a shape env. + if FxGraphCache._get_shape_env() is None: + log.debug("fx graph cache no shape env") + raise BypassFxGraphCache("No shape env") + + # We skip caching if there are any HOPs or torchbind objects. + FxGraphCache._check_for_hop(gm) + + @staticmethod + def prepare_key( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + remote: bool, + ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]: + """ + Checks that the inductor input is cacheable, then computes + and returns the cache key for the input. + Returns (key_info, cache_info) where: + - key_info is (hash_key, debug_lines), and + - cache_info will contain debug info in the event of BypassFxGraphCache. + + NB: It is possible to have this function return a union instead. But + I personally believe it is more annoying/difficult to read in that format. + """ + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + log.info("Bypassing FX Graph Cache because '%s'", e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_info = { + "cache_state": "bypass", + "cache_bypass_reason": str(e), + "cache_event_time": time_ns(), + } + return None, cache_info + # If key exists, then cache_info will come from load_with_key + return (key, debug_lines), {} + + @staticmethod + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. 
+ """ + cache_id = "fx-graph-v1" + return create_cache( + cache_id, + config.is_fbcode(), + "FbRemoteFxGraphCache", + "RemoteFxGraphCache", + ) + + @staticmethod + def load_with_key( + key: str, + debug_lines: list[str], + example_inputs: Sequence[InputType], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + is_backward: bool, + constants: CompiledFxGraphConstants, + evaluate_guards: Optional[ + Callable[[str, Union[list[int], list[torch.SymInt]]], bool] + ] = None, + ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]: + """ + Lookup the graph with the given key, and return results and metadata. + Doesn't do any logging on its own, because AOTAutograd handles a cache miss + differently from FXGraphCache. + """ + compiled_graph, cache_info = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache, constants, evaluate_guards + ) + cache_info = { + **cache_info, + "key": key, + "components": debug_lines, + "cache_event_time": time_ns(), + } + if compiled_graph is not None: + log.info("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_info["cache_state"] = "hit" + if remote_cache: + # Count remote cache hit stats + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_hit_count", + ) + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_hit_keys", + key, + ) + + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "distributed_ephemeral_timeout_us", + time_saved_ns // 1000, + ) + if ( + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + else: + if remote_cache: + # Count remote cache miss stats + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_miss_count", + ) + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_miss_keys", + key, + ) + log.info("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_info["cache_state"] = "miss" + + return compiled_graph, cache_info + + @staticmethod + def clear() -> None: + """ + Clear out the on-disk cache. 
+ """ + try: + shutil.rmtree(FxGraphCache._get_tmp_dir()) + except FileNotFoundError: + pass + + +@functools.cache +def split_aot_inductor_output_path(path: str) -> tuple[str, str]: + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(".so"): + return os.path.split(path) + elif path.endswith(".pt2"): + return os.path.split(path) + else: + return path, "" + + +@clear_on_fresh_cache +class CudaKernelParamCache: + cache: dict[str, dict[str, Any]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def set( + cls, + key: str, + params: dict[str, Optional[str]], + cubin: str, + bin_type: str, + asm: Optional[str] = None, + asm_type: Optional[str] = None, + ) -> None: + basename = None + if config.aot_inductor.package_cpp_only: + assert config.triton.unique_kernel_names, ( + "package_cpp_only requires triton kernel names to be unique" + ) + assert params["mangled_name"], "Missing kernel name" + basename = params["mangled_name"] + + _, bin_path = write( + cubin, + bin_type, + hash_type=bin_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + key=basename, + ) + # Retrieve the basename again in case it is a generated hashcode + basename, _ = get_name_and_dir_from_output_file_path(bin_path) + + if config.aot_inductor.emit_multi_arch_kernel: + bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} + assert bin_type in bin_type_to_ext.keys(), ( + "multi_arch_kernel_binary only supported in CUDA/XPU" + ) + base_path, _ = os.path.splitext(bin_path) + bin_path = base_path + bin_type_to_ext[bin_type] + + asm_path: str = "" + if ( + config.aot_inductor.emit_multi_arch_kernel + or config.aot_inductor.package_cpp_only + ): + assert asm, "Missing kernel assembly code" + assert asm_type, "Missing kernel assembly type" + _, asm_path = write( + asm, + asm_type, + hash_type=asm_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + # make sure asm file has the same basename + key=basename, + ) + + params[get_cpp_wrapper_cubin_path_name()] = bin_path + params["asm"] = asm_path + cls.cache[key] = params + + @classmethod + def get(cls, key: str) -> Optional[dict[str, Any]]: + return cls.cache.get(key, None) + + @classmethod + def get_keys(cls) -> KeysView[str]: + return cls.cache.keys() + + +class AotCodeCompiler: + """ + Compile AOT Inductor generated code. + """ + + @classmethod + def compile( + cls, + graph: GraphLowering, + wrapper_code: str, + kernel_code: str, + serialized_extern_kernel_nodes: Optional[str], + *, + device_type: str, + additional_files: list[str], + ) -> Union[list[Union[str, Weights]], str]: + """ + Returns the .so path, or returns a list of files that were generated if + config.aot_inductor.package=True. + """ + generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment] + + if sys.platform == "win32": + raise RuntimeError("AotCodeCompiler not yet supported for inductor") + + _set_gpu_runtime_env() # cpp_extension consults the env + + picked_vec_isa = pick_vec_isa() + vec_isa_cmd_gen = CppBuilder( + name="o", + sources="i", + BuildOption=CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + ), + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. 
+ # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + cpp_command = repr(vec_isa_cmd_gen.get_command_line()) + + # Meta internal AOTInductor CPU + use_relative_path = ( + config.is_fbcode() and device_type == "cpu" and graph.aot_mode + ) + + ( + specified_output_path, + specified_artifact_name, + ) = split_aot_inductor_output_path(config.aot_inductor.output_path) + + # TODO (benjaminglass1): the CMake packaging path doesn't support linking files + # built with different flags. Until that's implemented, append the kernel code + # to the wrapper and build everything at max optimization. + if config.aot_inductor.package_cpp_only: + wrapper_code = "\n".join((wrapper_code, kernel_code)) + kernel_code = "" + + wrapper_key, wrapper_path = write( + wrapper_code, + "wrapper.cpp", + extra=cpp_command, + specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, + ) + kernel_code = ( + f"// Triton kernels are embedded as comments in {wrapper_path}\n" + + kernel_code + ) + _, kernel_path = write( + kernel_code, + "kernel.cpp", + extra=cpp_command, + specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, + ) + + # Log the AOTInductor wrapper and kernel code, if needed. + with tempfile.NamedTemporaryFile("w+") as t: + t.writelines((wrapper_code, "\n", kernel_code, "\n")) + t.flush() + V.debug.output_code(t.name, extension="cpp") + + if config.aot_inductor.package: + generated_files.append(wrapper_path) + if not config.aot_inductor.package_cpp_only: + generated_files.append(kernel_path) + + output_code_log.info("Wrapper code written to: %s", wrapper_path) + output_code_log.info("Kernel code written to: %s", kernel_path) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_wrapper_code", + "type": "cpp", + "filename": wrapper_path, + }, + payload_fn=lambda: wrapper_code, + ) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_kernel_code", + "type": "cpp", + "filename": kernel_path, + }, + payload_fn=lambda: kernel_code, + ) + + # We use a file lock below to protect FS operations. The lock file + # is scoped to the 'key', so make sure the consts_s is protected + # by the same lock: + wrapper_path_operator = Path(wrapper_path) + kernel_path_operator = Path(kernel_path) + specified_sub_dir = wrapper_path_operator.parent / wrapper_key + if not specified_sub_dir.exists(): + specified_sub_dir.mkdir(exist_ok=True) + cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt") + + def _compile_consts(consts: bytes, platform: str) -> str: + if platform == "linux": + if graph.mutated_buffers & OrderedSet(graph.constants.keys()): + # .data section is between .text and .bss. When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + if len(consts) > 2_000_000_000: + raise ValueError( + "Models with buffer mutation included doesn't support constants greater than 2GB!" 
+ ) + section_attr = '.ldata, "aw"' + else: + section_attr = '.lrodata, "a"' + symbol_prefix = "" + elif platform == "darwin": + section_attr = "__DATA,__data" + symbol_prefix = "_" + else: + raise RuntimeError(f"Unsupported platform: {platform}") + + is_large_consts = len(consts) > 1024 + consts_asm = f"\t.section\t{section_attr}\n" + consts_asm += f"\t.balign {ALIGN_BYTES}\n" + consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + # Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" + _, consts_s = write( + consts_asm, + "S", + specified_dir=str(specified_sub_dir), + ) + consts_s = Path(consts_s) + object_build_options = CppTorchDeviceOptions( + # Intel compiler failed to compile this manually constructed assembly file. + # it is ok to use gcc to compile the .S to a .o and linked with Intel compiler . + device_type=device_type if device_type != "xpu" else "cpu", + aot_mode=graph.aot_mode, + compile_only=True, + use_relative_path=use_relative_path, + ) + object_builder = CppBuilder( + name=str(consts_s.stem), + sources=str(consts_s), + output_dir=str(consts_s.parent), + BuildOption=object_build_options, + ) + consts_o = object_builder.get_target_file_path() + object_builder.build() + + if is_large_consts: + with open(consts_o, "r+b") as f: + f.seek(0) + hdr = f.read(1024) + # Search for magic number and write the actual data over it + start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + assert start_idx != -1 + f.seek(start_idx) + pos = 0 + while pos < len(consts): + rc = f.write(consts[pos:]) + pos += rc + + # Remove the .S file to save space + os.remove(consts_s) + + return consts_o + + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, wrapper_key + ".lock"), timeout=LOCK_TIMEOUT + ) + with lock: + if serialized_extern_kernel_nodes: + extern_kernel_nodes_json = str( + wrapper_path_operator.with_suffix(".json") + ) + with open(extern_kernel_nodes_json, "w") as f: + f.write(serialized_extern_kernel_nodes) + + if config.aot_inductor.package: + generated_files.append(extern_kernel_nodes_json) + + metadata = config.aot_inductor.metadata + metadata["AOTI_DEVICE_KEY"] = device_type + + # Save user provided metadata + meta_json = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_metadata.json" + ) + ) + for k, v in config.aot_inductor.metadata.items(): + assert isinstance(k, str) and isinstance(v, (str)), ( + "Metadata must only contain strings" + ) + + with open(meta_json, "w") as f: + f.write(json.dumps(config.aot_inductor.metadata)) + + kernel_meta_json = str( + kernel_path_operator.with_name( + f"{kernel_path_operator.stem}_metadata.json" + ) + ) + shutil.copy(meta_json, kernel_meta_json) + + if config.aot_inductor.package: + generated_files.append(meta_json) + if not config.aot_inductor.package_cpp_only: + generated_files.append(kernel_meta_json) + + output_so = ( + config.aot_inductor.output_path + if specified_artifact_name + else str(wrapper_path_operator.with_suffix(".so")) + ) + all_cuda = all( + 
graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def _pad_to_alignment(raw_bytes: bytes) -> bytes: + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + + # This serializes the tensor's untyped_storage to bytes by accessing + # the raw data of the underlying structure. + import ctypes + + if t.numel() == 0: + return b"" + + if t.is_mkldnn: + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() + + raw_array = ctypes.cast( + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), + ) + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) + + if config.aot_inductor.package_constants_in_so: + serialized_weights = b"".join( + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + else: + serialized_weights = b"" + + if config.aot_inductor.package_constants_on_disk: + # We need to return a storage key here because the original value tensor might be a clone + weights_dict = Weights( + { + graph.allocated_constant_name[name]: ( + graph.get_original_value_of_constant(name), + TensorProperties(graph.constants[name]), + ) + for name in graph.constants.keys() + if name not in graph.folded_constants + } + ) + generated_files.append(weights_dict) + + consts_size = len(serialized_weights) + + # TODO: Fix mmap weights with cuda + use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000 + if config.aot_inductor.force_mmap_weights: + use_mmap_weights = True + + compile_command: dict[str, Any] = { + "aot_mode": graph.aot_mode, + "device_type": device_type, + "use_mmap_weights": use_mmap_weights, + "use_relative_path": use_relative_path, + "vec_isa": picked_vec_isa, + } + # If we're packaging via CMake, we build the whole code at max optimization. 
+ wrapper_build_options = CppTorchDeviceOptions( + compile_only=True, + min_optimize=not config.aot_inductor.package_cpp_only, + **compile_command, + ) + kernel_build_options = CppTorchDeviceOptions( + compile_only=True, + **compile_command, + ) + + # potentially, precompile the AOT header for this device + if config.aot_inductor.precompile_headers and not _IS_WINDOWS: + header_file = _get_cpp_wrapper_header( + device_type, aot_mode=graph.aot_mode + ) + wrapper_build_options.precompiled_header = _precompile_header( + header_file, + cpp_command, + min_optimize=not config.aot_inductor.package_cpp_only, + **compile_command, + ) + if cpp_prefix := _get_cpp_prefix_header(device_type): + kernel_build_options.precompiled_header = _precompile_header( + cpp_prefix, + cpp_command, + **compile_command, + ) + + wrapper_builder = CppBuilder( + name=str(wrapper_path_operator.stem), + sources=wrapper_path, + output_dir=str(wrapper_path_operator.parent), + BuildOption=wrapper_build_options, + ) + wrapper_compile_cmd = wrapper_builder.get_command_line() + wrapper_o = wrapper_builder.get_target_file_path() + + kernel_builder = CppBuilder( + name=str(kernel_path_operator.stem), + sources=kernel_path, + output_dir=str(wrapper_path_operator.parent), + BuildOption=kernel_build_options, + ) + kernel_compile_cmd = kernel_builder.get_command_line() + kernel_o = kernel_builder.get_target_file_path() + + log.debug("aot wrapper compilation command: %s", wrapper_compile_cmd) + log.debug("aot kernel compilation command: %s", kernel_compile_cmd) + if config.aot_inductor.package_cpp_only: + # Not doing the actual compilation here + compile_flags = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_compile_flags.json" + ) + ) + wrapper_build_options.save_flags_to_json(compile_flags) + generated_files.append(compile_flags) + wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type) + wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path) + generated_files.append(cmake_path) + else: + try: + wrapper_builder.build() + except (exc.CppCompileError, SkipFrame) as e: + if " is too big to optimize" in str(e): + raise RuntimeError( + "Please use torch._inductor.config.aot_inductor.compile_wrapper_opt_level = 'O0' flag." + ) from e + raise e + kernel_builder.build() + + if not use_mmap_weights: + aot_constants = serialized_weights + magic_number = 0 + else: + magic_number = cast( + int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() + ) + aot_constants = struct.pack("qq", consts_size + 8, magic_number) + + consts_o = _compile_consts(aot_constants, sys.platform) + custom_obj_idx = 0 + # Note that custom_objs_config.json file is different from the model_constants_config.json file produced + # in package_sigmoid(). The keys in custom_objs_config.json directly correspond to the arg name in extern + # nodes json. The key in model_constants_config.json produced by package_sigmoid is the attribute name in the + # user model code. 
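+ # For example (hypothetical constant names), the JSON written below could look like {"script_obj_arg0": f"{CUSTOM_OBJ_FILENAME_PREFIX}0", "script_obj_arg1": f"{CUSTOM_OBJ_FILENAME_PREFIX}1"}.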
+ + qual_name_to_id = {} # Map from constant name to its name in constants folder + for custom_obj_idx, (name, constant) in enumerate( + graph.torchbind_constants.items() + ): + if isinstance( + constant, torch._library.fake_class_registry.FakeScriptObject + ): + constant = constant.real_obj + assert isinstance(constant, torch._C.ScriptObject) + custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}" + + log.debug("saving script object %s as %s", name, custom_obj_name) + + qual_name_to_id[name] = custom_obj_name + custom_obj_bytes = torch._C._pickle_save(constant) + custom_obj_path = os.path.join( + wrapper_path_operator.parent, custom_obj_name + ) + + write_atomic(custom_obj_path, custom_obj_bytes, True) + generated_files.append(custom_obj_path) + + if qual_name_to_id: + constants_config_json = os.path.join( + wrapper_path_operator.parent, "custom_objs_config.json" + ) + with open(constants_config_json, "w") as f: + f.write(json.dumps(qual_name_to_id)) + generated_files.append(constants_config_json) + + gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = ( + ROCmCodeCache() if torch.version.hip else CUDACodeCache() + ) + gpu_kernels_o = gpu_codecache.aot_kernels_o.copy() + # clear the list of aot kernels after each linking + gpu_codecache.aot_kernels_o.clear() + + if gpu_kernels_o: + assert not config.aot_inductor.emit_multi_arch_kernel, ( + "TODO: add emit_multi_arch_kernel support for cutlass kernels" + ) + + cubins_o = [] + asm_files = [] + ld, objcopy = get_ld_and_objcopy(use_relative_path) + for kernel_name, value in CudaKernelParamCache.cache.items(): + if asm_file := value["asm"]: + asm_files.append(asm_file) + + cubin_file = value[get_cpp_wrapper_cubin_path_name()] + if config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda": + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " + ) + subprocess.run( + cmd.split(), capture_output=True, text=True, check=True + ) + + if config.aot_inductor.embed_kernel_binary: + # Embed cubin files into model.so using objcopy + cubins_o.append( + convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + ) + + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_relative_path=use_relative_path, + ) + + obj_srcs = [wrapper_o, kernel_o, consts_o, *gpu_kernels_o, *cubins_o] + so_builder = CppBuilder( + name=output_name, + sources=obj_srcs, + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + + # Append cmds to the end of codegen-ed wrapper file + with open(wrapper_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {wrapper_compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + with open(kernel_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {kernel_compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + if config.aot_inductor.package_cpp_only: + linker_flags = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_linker_flags.json" + ) + ) + 
so_build_options.save_flags_to_json(linker_flags) + generated_files.append(linker_flags) + generated_files.append(_LINKER_SCRIPT) + + # If we only want to package the cpp, then we need to save the + # weights separately into a bin, and we also need to prevent compiling the so + if use_mmap_weights: + weight_file = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_serialized_weights.bin" + ) + ) + with open(weight_file, "wb") as f_weights: + f_weights.write(serialized_weights) + f_weights.write(struct.pack("q", magic_number)) + + generated_files.append(weight_file) + else: + # TODO: unify to always use mmap_weights + generated_files.append(consts_o) + so_builder.save_src_to_cmake(cmake_path, consts_o) + + if config.aot_inductor.emit_multi_arch_kernel: + so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files) + generated_files.extend(asm_files) + else: + obj_srcs = [*gpu_kernels_o, *cubins_o] + generated_files.extend(obj_srcs) + for obj in obj_srcs: + so_builder.save_src_to_cmake(cmake_path, obj) + + so_builder.save_link_cmd_to_cmake(cmake_path) + else: + so_builder.build() + for o_file in obj_srcs: + if o_file in gpu_kernels_o: + continue + # Remove these as they are not needed anymore + os.remove(o_file) + + if use_mmap_weights: + import resource + + page_size_ = resource.getpagesize() + page_size = max(16384, page_size_) + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (page_size - so_size % page_size)) + f_so.write(serialized_weights) + f_so.write(struct.pack("q", magic_number)) + + if config.aot_inductor.package: + generated_files.append(output_so) + + if config.aot_inductor.package: + # We want to return the directory that contains all the AOTI + # generated files, not just the so + # return os.path.split(output_so)[0] + return generated_files + + return output_so + + +_libgomp: Optional[CDLL] = None + + +def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]: + # This function will be called from generated cpp wrapper code in the JIT mode. + # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them. 
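+    # Tensor arguments arrive as PyCapsule-wrapped AtenTensorHandle objects; convert_arg below
+    # turns them (and any nested lists/tuples of them) back into torch.Tensor before dispatch.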
+    def convert_arg(arg: Any) -> Any:
+        if str(type(arg)) == "<class 'PyCapsule'>":
+            # No easy way to do isinstance check on PyCapsule
+            return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg)
+        elif isinstance(arg, (list, tuple)):
+            return type(arg)(convert_arg(a) for a in arg)
+        else:
+            return arg
+
+    converted_args = [convert_arg(arg) for arg in args]
+
+    assert op.startswith("torch.ops."), (
+        op + " can not be called through custom_op_wrapper"
+    )
+    func = None
+    for i, s in enumerate(op.split(".")):
+        if i == 0:
+            func = importlib.import_module(s)
+        func = getattr(func, s)
+
+    assert callable(func), op + " can not be loaded through custom_op_wrapper"
+
+    # convert any kwarg-only arguments to kwargs
+    kwargs = dict()
+    for func_arg, conv_arg in zip(func._schema.arguments, converted_args):
+        if func_arg.kwarg_only:
+            kwargs[func_arg.name] = conv_arg
+    if kwargs:
+        del converted_args[-len(kwargs) :]
+
+    result = func(*converted_args, **kwargs)
+    if result is None:
+        return None
+
+    if isinstance(result, (list, tuple)):
+        # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only
+        result = [torch.tensor([]) if r is None else r for r in result]
+        for i, r in enumerate(result):
+            assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors"
+        return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result)  # type: ignore[arg-type]
+
+    assert isinstance(result, torch.Tensor), op + " returns a non-tensor"
+    return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result)
+
+
+# Precompiled headers are persistent past program runtime, but associated with one
+# specific compiler version and set of flags. We explicitly use default_cache_dir here
+# because these headers need to be global, rather than ignored by fresh_cache.
+_HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers")
+_HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks")
+
+
+@functools.cache
+def _precompile_header(
+    header: str,
+    hashable_cmd_line: str,
+    **compile_command: Any,
+) -> str:
+    assert not _IS_WINDOWS, (
+        "CppBuilder does not currently support precompiling on Windows!"
+    )
+
+    # Get the preprocessed output from the header file to be precompiled. This allows
+    # us to properly invalidate the file cache when any header dependency changes. This
+    # is thread-safe, as each thread will get its own temporary directory.
+    #
+    # N.B. we can't use NamedTemporaryFile here because Windows errors out on attempts
+    # to read from a file with an open write handle.
+    with tempfile.TemporaryDirectory() as preprocessing_dir:
+        preprocessing_header = Path(preprocessing_dir) / "header.hpp"
+        preprocessing_header.write_text(f"#include <{header}>\n")
+        preprocessor = CppBuilder(
+            name=str(preprocessing_header)[:-4],  # strip off the .hpp extension
+            sources=str(preprocessing_header),
+            BuildOption=CppTorchDeviceOptions(**compile_command, preprocessing=True),
+        )
+        preprocessor.build()
+
+        def _get_file_checksum(filename: str) -> str:
+            """Reading the whole preprocessed header in for hashing is very expensive,
+            but calling a fast hashing utility in a subprocess is cheap."""
+            # If Windows support needs to be added here, use certutil -hashfile.
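+            # A rough Windows sketch (untested assumption) would be
+            #   subprocess.run(("certutil", "-hashfile", filename, "SHA512"), ...)
+            # and then taking the hash from the second line of stdout.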
+ cmd_output = subprocess.run( + ("openssl", "sha512", filename), capture_output=True, text=True + ) + return cmd_output.stdout.split()[-1] + + preprocessor_hash = _get_file_checksum(preprocessor.get_target_file_path()) + + header_build_option = CppTorchDeviceOptions(**compile_command, precompiling=True) + header_hash, header_full_path = write( + content=f"#include <{header}>\n", + extension="h", + extra=( + hashable_cmd_line + + preprocessor_hash + + get_compiler_version_info(header_build_option.get_compiler()) + ), + specified_dir=_HEADER_DIR, + ) + cpp_builder = CppBuilder( + name=header_full_path, + sources=header_full_path, + BuildOption=header_build_option, + ) + # _worker_compile_cpp will automatically ignore any compilation whose result already + # exists, so this is always safe. + os.makedirs(_HEADER_LOCK_DIR, exist_ok=True) + _worker_compile_cpp( + os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"), + (cpp_builder,), + ) + + return header_full_path + + +def _get_cpp_prefix_header(device: str) -> Optional[str]: + if device.startswith("cpu"): + return "torch/csrc/inductor/cpp_prefix.h" + return None + + +def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: + """Given a device type (and optionally whether we're in AOT Inductor mode), returns + the path to the cpp_wrapper header file to be precompiled.""" + base_device = device.split(":")[0] + is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" + return ( + "torch/csrc/inductor/" + f"{'aoti_include' if aot_mode else 'cpp_wrapper'}/" + f"{'array_ref' if is_array_ref else base_device}.h" + ) + + +@clear_on_fresh_cache +class CppCodeCache: + """Compiles and caches C++ libraries. Users of this class supply the source code to + be compiled, while compilation flags are set by CppBuilder.""" + + cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags: dict[str, Any] = {} + + @staticmethod + def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]: + return cdll.LoadLibrary(path) + + @classmethod + def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: + try: + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + except (ImportError, OSError) as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + if "failed to map segment from shared object" in str(e): + raise OSError( + f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " + "is mounted with noexec (e.g., by default Docker mounts tmp file systems " + f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " + "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." + ) from e + raise + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + """ + Given a device type, returns the path to a CPP header file to be precompiled. + """ + return None + + @classmethod + def load_async( + cls, + main_code: str, + device_type: str = "cpu", + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + optimized_code: Optional[str] = None, + ) -> Any: + """Compile and load a C++ library. 
Returns a callable that returns the loaded + library.""" + compile_command = { + **cls.cpp_compile_command_flags, + "device_type": device_type, + "extra_flags": extra_flags, + "use_relative_path": config.is_fbcode(), + "vec_isa": pick_vec_isa(), + } + + _set_gpu_runtime_env() # cpp_extension consults the env + + # Note the distinction between the two booleans. We do minimal optimization if + # the optimized_code argument is present at all, since that's how the user of + # this function opts in, but we do compilation and linking in one step if the + # optimized_code argument is empty (as a micro-optimization). + main_build_option = CppTorchDeviceOptions( + compile_only=bool(optimized_code), + min_optimize=optimized_code is not None, + **compile_command, + ) + optimized_build_option = CppTorchDeviceOptions( + compile_only=True, **compile_command + ) + + def get_hashable_command_line(build_option: BuildOptionsBase) -> str: + """Writing the code to file will calculate a hash, which we need to vary if + the command line flags change. This implements a mostly-generic way of + validating that.""" + return CppBuilder( + name="o", sources="i", BuildOption=build_option + ).get_command_line() + + main_cmd_line = get_hashable_command_line(main_build_option) + optimized_cmd_line = get_hashable_command_line(optimized_build_option) + + key, main_path = write( + main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}" + ) + + # Don't bother writing if the argument is empty. + if optimized_code: + _, optimized_path = write( + optimized_code, "optimized.cpp", extra=optimized_cmd_line + ) + else: + # Unused, but makes type checkers happy. + optimized_path = os.devnull + + if key not in cls.cache: + from torch.utils._filelock import FileLock + + lock_path = os.path.join(get_lock_dir(), key + ".lock") + future: Optional[Future[Any]] = None + lib = None + + # if requested, pre-compile any headers + if config.cpp_cache_precompile_headers and not _IS_WINDOWS: + if header := cls._get_uncompiled_header(device_type): + main_build_option.precompiled_header = _precompile_header( + header, + main_cmd_line, + min_optimize=optimized_code is not None, + **compile_command, + ) + + # Currently, the optimized_code field is only used for cpp kernel code, + # so go ahead and precompile the relevant header here. Revisit this + # decision if that ever changes. 
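+ # _get_cpp_prefix_header returns torch/csrc/inductor/cpp_prefix.h only for cpu-like devices, so this branch only fires for CPU kernel code.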
+ if optimized_code and (header := _get_cpp_prefix_header(device_type)): + optimized_build_option.precompiled_header = _precompile_header( + header, + optimized_cmd_line, + **compile_command, + ) + + main_name, output_dir = get_name_and_dir_from_output_file_path(main_path) + main_builder = CppBuilder( + name=main_name, + sources=main_path, + BuildOption=main_build_option, + output_dir=output_dir, + ) + + if optimized_code: + optimized_name, _ = get_name_and_dir_from_output_file_path( + optimized_path + ) + optimized_builder = CppBuilder( + name=optimized_name, + sources=optimized_path, + BuildOption=optimized_build_option, + output_dir=output_dir, + ) + + linker = CppBuilder( + name=main_name, + sources=[ + main_builder.get_target_file_path(), + optimized_builder.get_target_file_path(), + ], + BuildOption=CppTorchDeviceOptions(**compile_command), + output_dir=output_dir, + ) + + worker_fn = functools.partial( + _worker_compile_cpp, + lock_path, + (main_builder, optimized_builder, linker), + ) + binary_path = normalize_path_separator(linker.get_target_file_path()) + else: + worker_fn = functools.partial( + _worker_compile_cpp, lock_path, (main_builder,) + ) + binary_path = normalize_path_separator( + main_builder.get_target_file_path() + ) + + def load_fn() -> Any: + nonlocal lib + if lib is None: + if future is not None: + future.result() + result = worker_fn() + assert result is None + lib = cls._load_library(binary_path, key) + assert lib is not None + return lib + + if submit_fn is not None: + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + if not os.path.exists(binary_path): + future = submit_fn(worker_fn) + + cls.cache[key] = load_fn + + return cls.cache[key] + + @classmethod + def load(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_async(*args, **kwargs)() + + +def _worker_compile_cpp( + lock_path: str, + cpp_builders: Sequence[CppBuilder], +) -> None: + from torch.utils._filelock import FileLock + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + for builder in cpp_builders: + if not os.path.exists(builder.get_target_file_path()): + builder.build() + + +# Customized Python binding for cpp kernels +@clear_on_fresh_cache +class CppPythonBindingsCodeCache(CppCodeCache): + cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + # kernels have no dependency on libtorch + "include_pytorch": False, + "shared": True, + } + entry_function = "kernel" + call_entry_function = "kernel({}); Py_RETURN_NONE;" + extra_parse_arg = "" + suffix_template = textwrap.dedent( + """ + // Python bindings to call {entry_func}(): + #define PY_SSIZE_T_CLEAN + #include + #include + #include + + #ifndef _MSC_VER + #if __cplusplus < 202002L + // C++20 (earlier) code + // https://en.cppreference.com/w/cpp/language/attributes/likely + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) + #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) + #endif + + // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. + // We manually link it below to workaround issues with fbcode build. 
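+    // The function pointer is filled in at module-init time from the
+    // _TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR environment variable (see PyInit below).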
+ static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); + + template static inline T parse_arg(PyObject* args, size_t n) {{ + static_assert(std::is_pointer_v, "arg type must be pointer or long"); + return static_cast(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); + }} + template <> inline int64_t parse_arg(PyObject* args, size_t n) {{ + auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1 && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return result; + }} + template <> inline uintptr_t parse_arg(PyObject* args, size_t n) {{ + auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == reinterpret_cast(-1) && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return reinterpret_cast(result); + }} + + {extra_parse_arg} + + static PyObject* {entry_func}_py(PyObject* self, PyObject* args) {{ + try {{ + if(unlikely(!PyTuple_CheckExact(args))) + throw std::runtime_error("tuple args required"); + if(unlikely(PyTuple_GET_SIZE(args) != {arg_len})) + throw std::runtime_error("requires {arg_len} args"); + {call_entry_func} + }} catch(std::exception const& e) {{ + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + }} catch(...) {{ + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + }} + }} + + static PyMethodDef py_methods[] = {{ + {{"{entry_func}", {entry_func}_py, METH_VARARGS, ""}}, + {{NULL, NULL, 0, NULL}}}}; + + static struct PyModuleDef py_module = + {{PyModuleDef_HEAD_INIT, "{entry_func}", NULL, -1, py_methods}}; + + PyMODINIT_FUNC PyInit_{entry_func}(void) {{ + const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); + if(!str_addr) {{ + PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); + return nullptr; + }} + std::istringstream iss(str_addr); + uintptr_t addr = 0; + iss >> addr; + _torchinductor_pyobject_tensor_data_ptr = + reinterpret_cast(addr); + PyObject* module = PyModule_Create(&py_module); + if (module == NULL) {{ + return NULL; + }} + #ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); + #endif + return module; + }} + """ + ) + + @classmethod + def _load_library_inner(cls, path: str, key: str) -> ModuleType: + os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( + torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] + ) + module_name = f"{key}.{cls.entry_function}" + try: + return sys.modules[module_name] + except KeyError: + pass + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + return _get_cpp_prefix_header(device) + + @classmethod + def load_pybinding_async( + cls, + argtypes: Sequence[str], + main_code: str, + device_type: str = "cpu", + num_outputs: int = -1, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + kernel_code: Optional[str] = None, + ) -> Any: + """ + Wrap a C++ function in fast Python bindings. + + Args: + argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"] + main_code: C++ source code containing ENTRY_FUNCTION(). 
Will be built at + -O3 if kernel_code is None (to maximize performance in any kernels that + are present), or -O1 otherwise (to minimize compile time). + kernel_code: If present, C++ source code that will be built at -O3 and + linked to main_code. + + Returns: + A python version of ENTRY_FUNCTION() + """ + parseargs = ", ".join( + f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" + for n, argtype in enumerate(argtypes) + ) + suffix = cls.suffix_template.format( + arg_len=len(argtypes), + call_entry_func=cls.call_entry_function.format(parseargs), + entry_func=cls.entry_function, + extra_parse_arg=cls.extra_parse_arg.format(array_len=num_outputs), + ) + get_result = cls.load_async( + main_code + suffix, + device_type, + submit_fn=submit_fn, + extra_flags=extra_flags, + optimized_code=kernel_code, + ) + result = None + + def future() -> Any: + nonlocal result + if result is None: + result = get_result() + assert isinstance(result, ModuleType) + return getattr(result, cls.entry_function) + + return future + + @classmethod + def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_pybinding_async(*args, **kwargs)() + + +@clear_on_fresh_cache +class CppWrapperCodeCache(CppPythonBindingsCodeCache): + cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + "include_pytorch": True, + "shared": True, + } + entry_function = "inductor_entry_cpp" + call_entry_function = "return inductor_entry_cpp({});" + extra_parse_arg = textwrap.dedent( + """ + #include + + static inline std::vector unpack_tensor_handle_list(PyObject* pyvec) {{ + std::vector result; + size_t result_len = PyList_GET_SIZE(pyvec); + result.reserve(result_len); + for (size_t i = 0; i < result_len; i++) {{ + // AtenTensorHandle is essentially a pointer + void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); + result.push_back(reinterpret_cast(elem)); + }} + return result; + }} + + static inline PyObject* pack_tensor_handle_list(const std::array& arr) {{ + PyObject* result = PyList_New({array_len}); + for (size_t i = 0; i < {array_len}; i++) {{ + PyObject *elem = + arr[i] == nullptr + ? Py_None + // Store AtenTensorHandle as PyCapsulate + : PyCapsule_New(reinterpret_cast(arr[i]), NULL, NULL); + PyList_SET_ITEM(result, i, elem); + }} + return result; + }} + + template <> inline std::vector parse_arg>(PyObject* args, size_t n) {{ + return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); + }} + + PyObject* inductor_entry_cpp(std::vector&& input_handles) {{ + // For outputs, we only allocate an array to hold returned tensor handles, + // not the actual output tensor storage. + std::array output_handles{{}}; + try {{ + inductor_entry_impl(input_handles.data(), output_handles.data()); + if (PyErr_Occurred()) {{ + return nullptr; + }} + return pack_tensor_handle_list(output_handles); + }} catch(std::exception const& e) {{ + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + }} catch(...) 
{{ + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + }} + }} + """ + ) + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + return _get_cpp_wrapper_header(device) + + +@clear_on_fresh_cache +class HalideCodeCache(CppPythonBindingsCodeCache): + cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {} + cache_clear = staticmethod(cache.clear) + _standalone_runtime_path: Optional[str] = None + prefix = textwrap.dedent( + """ + #include "{halideruntime_h}" + #include "{headerfile}" + #include + #include + + namespace c10 {{ + inline long div_floor_integer(long a, long b) {{ + if ((a<0) != (b<0)) {{ + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + }} + return a / b; + }} + }} + """ + ) + glue_template_cpp = prefix + textwrap.dedent( + """ + void kernel({argdefs}) {{ + {buffers} + int err = halide_kernel({buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + glue_template_cuda = prefix + textwrap.dedent( + """ + #include + static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface(); + + void kernel({argdefs}, uintptr_t stream) {{ + {buffers} + int err = halide_kernel(reinterpret_cast(stream), {buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + standalone_runtime_cuda_init = textwrap.dedent( + """ + #include "{}" + #include + + static int acquire_context(void* user_context, + void** cuda_context_out, + bool create) {{ + return cuCtxGetCurrent(reinterpret_cast(cuda_context_out)); + }} + + static int release_context(void* user_context) {{ + return 0; + }} + + static int get_stream(void* user_context, + void* cuda_context, + void** stream_out) {{ + *stream_out = user_context; + return 0; + }} + + static int register_halide_hooks() {{ + halide_set_cuda_acquire_context(&acquire_context); + halide_set_cuda_release_context(&release_context); + halide_set_cuda_get_stream(&get_stream); + return 0; + }} + + int inductor_register_halide_hooks_result = register_halide_hooks(); + """ + ) + + @classmethod + def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]: + assert arg.shape is not None + assert arg.stride is not None and len(arg.shape) == len(arg.stride) + assert arg.offset is not None + data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}" + if cuda: + device = f"reinterpret_cast({data_ptr})" + device_interface = "cuda_interface" + host = "nullptr" + flags = "halide_buffer_flag_device_dirty" + else: + device = "0" + device_interface = "nullptr" + host = f"reinterpret_cast({data_ptr})" + flags = "halide_buffer_flag_host_dirty" + + dims = [] + for size, stride in zip(arg.shape, arg.stride): + dims.append(f"halide_dimension_t(0, {size}, {stride})") + + return [ + f"halide_buffer_t {name};", + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};", + f"{name}.device = {device};", + f"{name}.device_interface = {device_interface};", + f"{name}.host = {host};", + f"{name}.flags = {flags};", + f"{name}.type = {arg.halide_type()};", + f"{name}.dimensions = {len(dims)};", + f"{name}.dim = {name}_dims;", + f"{name}.padding = nullptr;", + ] + + @classmethod + def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: + is_cuda = meta.is_cuda() + assert is_cuda is ("user_context" in meta.target) + assert "no_runtime" in meta.target + buffers = [] + buffer_names = [] + for i, arg in enumerate(meta.argtypes): + if arg.is_buffer(): + 
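+ # Materialize this buffer argument as a halide_buffer_t and pass its address; scalar args (the else branch below) are forwarded by value under their own names.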
buffer_names.append(f"&hl_buf_{i}") + buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) + else: + assert "*" not in arg.ctype + buffer_names.append(arg.name) + buffers = "\n".join([f" {line}" for line in buffers]).lstrip() + + glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp + glue_code = glue_template.format( + halideruntime_h=cls.find_header( + "HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h" + ), + headerfile=headerfile, + argdefs=", ".join( + f"{a.bindings_type()} {a.name}" + for a in meta.argtypes + if a.alias_of is None + ), + buffers=buffers, + buffer_names=", ".join(buffer_names), + ) + return glue_code + + @classmethod + @functools.cache + def config_hash(cls) -> str: + command_gen = CppBuilder( + name="O", + sources="I", + BuildOption=CppOptions(), + ) + command_line = command_gen.get_command_line() + return sha256_hash( + "\n".join( + [ + cls.glue_template_cpp, + cls.glue_template_cuda, + cls.standalone_runtime_cuda_init, + command_line, + ] + ).encode("utf-8") + ) + + @staticmethod + def _search_for_file(suffix: str, errmsg: str) -> str: + spec = importlib.machinery.PathFinder.find_spec("halide") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError("halide python bindings not installed") + try: + search = spec.submodule_search_locations[0] + for file in os.listdir(search): + if file.endswith(".so"): + try: + out = subprocess.check_output( + ["ldd", os.path.join(search, file)] + ) + except subprocess.SubprocessError: + continue + m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) + if m: + path = os.path.join(os.path.abspath(m.group(1)), suffix) + if os.path.exists(path): + return os.path.abspath(path) + except Exception as e: + raise RuntimeError(errmsg) from e + raise RuntimeError(errmsg) + + @staticmethod + @functools.cache + def find_libautoschedule(name: str) -> str: + sofile = f"libautoschedule_{name.lower()}.so" + if "HALIDE_LIB" in os.environ: + path = os.path.join(os.environ["HALIDE_LIB"], sofile) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" + ) + return HalideCodeCache._search_for_file(sofile, errmsg) + + @staticmethod + @functools.cache + def find_header(name: str) -> str: + if "HALIDE_INCLUDE" in os.environ: + path = os.path.join(os.environ["HALIDE_INCLUDE"], name) + if os.path.exists(path): + return path + if "HALIDE_LIB" in os.environ: + path = os.path.abspath( + os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") + ) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" + ) + return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) + + @classmethod + def generate_halide_async( + cls, meta: HalideMeta, source_code: str, submit_fn: Any = None + ) -> Callable[[], Any]: + dirpath = Path( + get_path( + code_hash( + source_code, + extra=repr((cls.config_hash(), meta)), + ), + "halide", + )[2] + ) + os.makedirs(dirpath, exist_ok=True) + wait_for_compile = None + genfile = str(dirpath / "generate_kernel.py") + libfile = str(dirpath / "halide_kernel.a") + headerfile = str(dirpath / "halide_kernel.h") + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + need_compile = not os.path.exists(donefile) + jobs: list[Any] = [] + if need_compile: + write_atomic(genfile, source_code) + cmd = [ + sys.executable, + genfile, + "-g", + "kernel", + "-o", + f"{dirpath}", + "-f", + "halide_kernel", + "-e", + 
"static_library,h,schedule", + ] + if meta.scheduler: + cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)]) + cmd.extend(meta.args()) + jobs.append(functools.partial(subprocess.check_call, cmd)) + + binding_types = [ + arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None + ] + if meta.is_cuda(): + binding_types.append("uintptr_t") # stream + bindings_future = cls.load_pybinding_async( + binding_types, + cls._codegen_glue(meta, headerfile), + extra_flags=(libfile, cls.build_standalone_runtime()), + submit_fn=jobs.append if need_compile else None, + device_type="cuda" if meta.is_cuda() else "cpu", + ) + + if need_compile: + jobs.append(functools.partial(touch, donefile)) + task = functools.partial(_worker_task_halide, lockfile, jobs) + if submit_fn: + wait_for_compile = submit_fn(task).result + else: + task() + + def load() -> Callable[[], Any]: + if wait_for_compile: + wait_for_compile() + return bindings_future() + + return load + + @classmethod + def generate_halide(cls, *args: Any, **kwargs: Any) -> Callable[[], Any]: + return cls.generate_halide_async(*args, **kwargs)() + + @classmethod + def build_standalone_runtime(cls) -> str: + if cls._standalone_runtime_path and os.path.exists( + cls._standalone_runtime_path + ): + return cls._standalone_runtime_path + device_type = "cuda" if torch.cuda.is_available() else "cpu" + libname = "libStandaloneHalideRuntime.so" + target = "host-cuda" if device_type == "cuda" else "host" + if cls._standalone_runtime_path: + assert not os.path.exists(cls._standalone_runtime_path) + # We hit this case in unittests when we run with fresh_cache() + # Generating a fresh runtime over and over causes errors because we initialize + # cuda hundreds of times in the same process and run out of file descriptors. + # Workaround by jail breaking the current fresh_cache(). 
+ base = default_cache_dir() + else: + base = cache_dir() + dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" + os.makedirs(dirpath, exist_ok=True) + done_file = str(dirpath / "done") + lock_file = str(dirpath / "lock") + hook_file = str(dirpath / "hooks.cpp") + a_file = str(dirpath / "standalone_halide_runtime.a") + so_file = str(dirpath / libname) + if not os.path.exists(done_file): + import halide as hl # type: ignore[import-untyped,import-not-found] + + from torch.utils._filelock import FileLock + + with FileLock(lock_file, LOCK_TIMEOUT): + if not os.path.exists(done_file): + with open(hook_file, "w") as f: + if device_type == "cuda": + f.write( + cls.standalone_runtime_cuda_init.format( + cls.find_header("HalideRuntimeCuda.h") + ) + ) + hl.compile_standalone_runtime(a_file, hl.Target(target)) + + name, output_dir = get_name_and_dir_from_output_file_path(so_file) + halide_cmd_gen = CppBuilder( + name=name, + sources=[hook_file, a_file], + output_dir=output_dir, + BuildOption=CppTorchDeviceOptions( + device_type=device_type, + ), + ) + + subprocess.check_call( + shlex.split(halide_cmd_gen.get_command_line()) + ) + touch(done_file) + assert os.path.exists(so_file) + cls._standalone_runtime_path = so_file + return so_file + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + """Header precompiling is currently disabled for halide.""" + return None + + +def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None: + from torch.utils._filelock import FileLock + + try: + with FileLock(lockfile, LOCK_TIMEOUT): + for job in jobs: + job() + except subprocess.SubprocessError as e: + if os.environ.get("HALIDE_REPRO") == "1": + cmd: list[Any] + python, script, *cmd = getattr(e, "cmd", ("", "", "")) + if os.path.basename(python).startswith("python"): + code = open(script).read() + main = " hl.main()" + assert code.count(main) == 1 + + class Out: + def __repr__(self) -> str: + return "out" + + ci = cmd.index("-o") + assert isinstance(ci, int) + cmd[ci + 1] = Out() + repl = textwrap.indent( + textwrap.dedent( + f"""\ + import sys, tempfile + with tempfile.TemporaryDirectory() as out: + sys.argv = {["repro.py", *cmd]!r} + hl.main() + """ + ), + " ", + ) + code = code.replace(main, repl) + with open("repro.py", "w") as fd: + fd.write(code.lstrip()) + raise RuntimeError(f"wrote repro.py: {e}") from e + raise + + +def touch(filename: str) -> None: + open(filename, "a").close() + + +@clear_on_fresh_cache +class PyCodeCache: + # Track the loaded modules so we can remove the on-disk artifacts when + # clearing the cache. Note also that we may load the same path more + # than once, but attach different attributes, i.e., due to different + # constant values. + modules: list[ModuleType] = [] + + # Modules loaded without extra attributes are stored here, those do not + # need to be re-loaded. 
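+ # Minimal usage sketch (illustrative): `mod = PyCodeCache.load("x = 1")` writes the source to the on-disk cache, imports it as a module, and `mod.x == 1` afterwards.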
+ modules_no_attr: dict[str, ModuleType] = {} + + linemaps: dict[str, list[tuple[Any, ...]]] = {} + + @classmethod + def write(cls, source_code: str, extra: str = "") -> tuple[str, str]: + return write(source_code, "py", extra=extra) + + @classmethod + def load(cls, source_code: str, extra: str = "") -> ModuleType: + key, path = write(source_code, "py", extra=extra) + return cls.load_by_key_path(key, path) + + @classmethod + def load_by_key_path( + cls, + key: str, + path: str, + linemap: Optional[list[tuple[int, str]]] = None, + attrs: Optional[dict[str, Any]] = None, + ) -> ModuleType: + if linemap is None: + linemap = [] + + # we only cache when attrs is None + if attrs is None and path in cls.modules_no_attr: + return cls.modules_no_attr[path] + + in_toplevel = in_toplevel_process() + mod = _reload_python_module(key, path, set_sys_modules=in_toplevel) + + # unzip into separate lines/nodes lists + if in_toplevel: + cls.linemaps[path] = list(zip(*linemap)) + + if attrs is not None: + for k, v in attrs.items(): + setattr(mod, k, v) + + if in_toplevel: + # we only cache when attrs is None + if attrs is None: + cls.modules_no_attr[path] = mod + + cls.modules.append(mod) + return mod + + @classmethod + def cache_clear(cls, purge: bool = False) -> None: + """ + Clear the in-memory module cache. If purge=True, also delete all the + corresponding on-disk source files. + """ + if purge: + for mod in cls.modules: + try: + assert mod.__file__ + os.remove(mod.__file__) + except FileNotFoundError: + pass + cls.modules.clear() + cls.modules_no_attr.clear() + + @classmethod + @functools.cache + def stack_frames_for_code( + cls, path: str, lineno: int + ) -> Optional[list[dict[str, Any]]]: + if path not in cls.linemaps: + return None + if len(cls.linemaps[path]) == 0: + return None + # [(starting_line, ), ...] 
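+ # cls.linemaps[path] was stored as list(zip(*linemap)) in load_by_key_path, i.e. two parallel sequences: generated-code starting lines and the corresponding entries whose stack traces are parsed below.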
+ lines, nodes = cls.linemaps[path] + p = bisect_right(lines, lineno) + if p == 0: + return None + entry = nodes[p - 1] + if not entry: + return None + + def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]: + # ideally fx stores stack traces as data rather than a string + # but this is not along a performance critical path + regex = r'File "(.+)", line (\d+), in (.+)\n' + matches = re.findall(regex, stack_trace) + return [ + {"filename": f, "line": int(l), "name": n} + for f, l, n in reversed(matches) + ] + + return parse_stack_trace(entry) + + +def _load_triton_kernel_from_source( + kernel_name: str, source_code: str +) -> CachingAutotuner: + return getattr(PyCodeCache.load(source_code), kernel_name) + + +def _cuda_compiler() -> Optional[str]: + if cuda_env.nvcc_exist(config.cuda.cuda_cxx): + return config.cuda.cuda_cxx + if config.is_fbcode(): + return os.path.join(build_paths.sdk_home, "bin", "nvcc") + if cuda_env.nvcc_exist(os.getenv("CUDACXX")): + return os.getenv("CUDACXX", "") + if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): + return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) + return "nvcc" + + +def _cutlass_path() -> str: + if config.is_fbcode(): + from libfb.py import parutil + + return parutil.get_dir_path("cutlass-3-headers") + else: + return config.cuda.cutlass_dir + + +def _cutlass_paths() -> list[str]: + return [ + "include", + "tools/library/include", + "tools/library/src", + "tools/util/include", + ] + + +def _clone_cutlass_paths(build_root: str) -> list[str]: + paths = _cutlass_paths() + cutlass_root = _cutlass_path() + for path in _cutlass_paths(): + old_path = os.path.join(cutlass_root, path) + new_path = os.path.join(build_root, path) + shutil.copytree(old_path, new_path, dirs_exist_ok=True) + return paths + + +def _cutlass_include_paths() -> list[str]: + cutlass_path = _cutlass_path() + return [ + # Use realpath to get canonical absolute paths, in order not to mess up cache keys + os.path.realpath(os.path.join(cutlass_path, path)) + for path in _cutlass_paths() + ] + + +@torch_key_cache +def cutlass_key() -> bytes: + """ + Compute a key representing the state of the CUTLASS library. + + Note: OSS and fbcode will have different keys. + """ + if config.is_fbcode(): + with importlib.resources.path("cutlass", "src_hash.txt") as resource_path: + with open(resource_path) as resource_file: + return resource_file.read().encode() + + combined_hash = hashlib.sha256() + build_code_hash([config.cuda.cutlass_dir], "", combined_hash) + return combined_hash.digest() + + +def _cuda_lib_options() -> list[str]: + """ + Util function for CUTLASS backend to find the correct CUDA libraries. + """ + _set_gpu_runtime_env() # cpp_extension consults the env + from torch.utils import cpp_extension + + lpaths = cpp_extension.library_paths(device_type="cuda") + if use_re_build(): + lpaths += [ + build_paths.sdk_lib, + os.path.join(build_paths.sdk_lib, "stubs"), + ] + extra_ldflags: list[str] = [] + if is_linux(): + _transform_cuda_paths(lpaths) + for path in lpaths: + if "torch/lib" in path: + # don't want to depend on pytorch + continue + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) + extra_ldflags.append("-lcuda") + extra_ldflags.append("-lcudart") + else: + raise NotImplementedError( + "Unsupported env, failed to find cuda libs! Currently only Linux is supported." 
+ ) + return extra_ldflags + + +def _nvcc_host_compiler_options() -> list[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _nvcc_arch_as_compile_option() -> str: + arch = cuda_env.get_cuda_arch() + if arch == "90": + # Required by cutlass compilation. + return "90a" + if arch == "100": + return "100a" + return arch + + +def _nvcc_compiler_options() -> list[str]: + arch = _nvcc_arch_as_compile_option() + code = [f"sm_{arch}", f"compute_{arch}"] + if config.cuda.enable_cuda_lto: + code += [f"lto_{arch}"] + options = [ + "-t=0", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-w", + f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", + config.cuda.compile_opt_level, + "-std=c++17", + "--expt-relaxed-constexpr", + "-DNDEBUG", + ] + if config.is_fbcode(): + options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) + if config.cuda.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + if config.cuda.enable_ptxas_info: + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) + "--source-in-ptx", + ] + ) # Annotate the ptx file with source information + if config.cuda.use_fast_math: + options.extend( + [ + "--use_fast_math", + "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", + ] + ) + return options + + +def cuda_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[list[str]] = None, +) -> str: + if extra_args is None: + extra_args = [] + if use_re_build(): + build_path = os.path.dirname(dst_file) + include_paths = _clone_cutlass_paths(build_path) + src_files = [os.path.basename(src_file) for src_file in src_files] + dst_file = os.path.basename(dst_file) + else: + include_paths = _cutlass_include_paths() + cuda_lib_options = _cuda_lib_options() + nvcc_host_compiler_options = _nvcc_host_compiler_options() + nvcc_compiler_options = _nvcc_compiler_options() + options = ( + nvcc_compiler_options + + extra_args + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in nvcc_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + cuda_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + log.debug("CUDA command: %s", res) + return res + + +class DLLWrapper: + """A wrapper for a dynamic library.""" + + def __init__( + self, + lib_path: str, + ) -> None: + self.lib_path = lib_path + self.is_open = False + self.DLL = cdll.LoadLibrary(lib_path) + self.is_open = True + + def close(self) -> None: + if self.is_open: + self._dlclose() + self.is_open = False + + def _dlclose(self) -> None: + f_dlclose = None + + if is_linux(): + syms = CDLL(None) + if not hasattr(syms, "dlclose"): + # Apline 
Linux + syms = CDLL("libc.so") + + if hasattr(syms, "dlclose"): + f_dlclose = syms.dlclose + elif is_windows(): + import ctypes + + kernel32 = ctypes.CDLL("kernel32", use_last_error=True) + + f_dlclose = kernel32.FreeLibrary + else: + raise NotImplementedError("Unsupported env, failed to do dlclose!") + + if f_dlclose is not None: + if is_linux(): + f_dlclose.argtypes = [c_void_p] + f_dlclose(self.DLL._handle) + elif is_windows(): + import ctypes + from ctypes import wintypes + + f_dlclose.argtypes = [wintypes.HMODULE] + f_dlclose(self.DLL._handle) + else: + log.warning( + "dll unloading function was not found, library may not be unloaded properly!" + ) + + def __getattr__(self, name: str) -> Callable[..., None]: + if not self.is_open: + raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") + + method = getattr(self.DLL, name) + + def _wrapped_func(*args: Any) -> None: + err = method(*args) + if err: + raise RuntimeError(f"Error in function: {method.__name__}") + + return _wrapped_func + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def __del__(self) -> None: + self.close() + + +@lru_cache +def binary_error_path(output_path: str) -> str: + """ + standard format for the error path + """ + return output_path + ".error" + + +@clear_on_fresh_cache +class CUDACodeCache: + """ + A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS. + This class handles writing source code to files, compiling them into shared objects, and caching + the results to avoid redundant compilations. It also manages error handling and logging for the + compilation process. + """ + + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + error_json: Optional[str] = None + + cache: dict[str, CacheEntry] = {} + aot_kernels_o: list[str] = [] + _SOURCE_CODE_SUFFIX = "cu" + + @staticmethod + def cache_clear() -> None: + CUDACodeCache.cache.clear() + CUDACodeCache.aot_kernels_o.clear() + + @staticmethod + @lru_cache(maxsize=4) + def get_kernel_binary_remote_cache( + caching_enabled: bool, caching_available: bool + ) -> Optional[Any]: + """ + Get or create the class instance of the CUTLASSKernelBinaryRemoteCache. + + Args: + caching_enabled: Whether binary remote caching is enabled + caching_available: Whether we're in fbcode environment + + Returns: + CUTLASSKernelBinaryRemoteCache: The class instance of the kernel binary remote cache + """ + if not caching_enabled: + log.debug("CUTLASSKernelBinaryRemoteCache not requested, skipping") + return None + if not caching_available: + return None + + try: + from torch._inductor.fb.kernel_binary_remote_cache import ( + CUTLASSKernelBinaryRemoteCache, + ) + + return CUTLASSKernelBinaryRemoteCache() + except ImportError: + log.debug( + "CUTLASSKernelBinaryRemoteCache not available, remote caching disabled" + ) + return None + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. 
+ """ + + if config.cuda.cutlass_hash_with_compile_cmd: + cuda_command = repr( + cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + extra = cuda_command + else: + extra = repr( + [ + # nvcc and cuda hash + _cuda_compiler(), + # cutlass flags and gcc hash + _nvcc_compiler_options(), + # flags + _nvcc_host_compiler_options(), + # cutlass key + cutlass_key(), + # hack to deal with AOTI .o compilation + ] + + [dst_file_ext] + if dst_file_ext == "o" + else [] + ) + key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None + ) -> tuple[str, str, str]: + """ + Compiles CUDA source_code into a file with dst_file_ext extension. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + error_path = binary_error_path(output_path) + binary_remote_cache = cls.get_kernel_binary_remote_cache( + caching_enabled=config.cuda.use_binary_remote_cache + and not config.force_disable_caches, + caching_available=config.is_fbcode(), + ) + if binary_remote_cache is not None: + # The remote cache implementation will only download if the file does + # not already exist locally + binary_remote_cache.get(output_path, error_path) + + if os.path.exists(error_path): + with open(error_path, encoding="utf-8") as fh: + error_json = fh.read() + cmd_parts, error_output = json.loads(error_json) + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # This ensures that a local error is uploaded to the remote cache, + # as we make no assumptions about the remote cache having the same + # information as the local cache + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + cls.cache[key] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + raise exc.CUDACompileError(cmd_parts, error_output) + if not os.path.exists(output_path): + cmd = cuda_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// CUDA Compile cmd\n// {cmd}\n") + start_time = time() + log.debug("CUDA Compilation: %s", cmd) + cmd_parts = cmd.split(" ") + try: + if use_re_build(): + from triton.fb.re_build_helper import run_build_command + + run_build_command( + cmd_parts, + os.path.dirname(input_path), + os.path.basename(output_path), + ) + else: + subprocess.check_output( + cmd_parts, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + cls._record_cuda_compile_error( + error.output.decode("utf-8"), + key, + cmd_parts, + input_path, + output_path, + binary_remote_cache, + ) + raise exc.CUDACompileError(cmd_parts, error.output) from error + except Exception as error: + if "COMPILE FAILED WITH" in str(error): + cls._record_cuda_compile_error( + str(error), + key, + cmd_parts, + input_path, + output_path, + binary_remote_cache, + ) + raise exc.CUDACompileError(cmd_parts, str(error)) from error + raise error + end_time = time() + log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. 
Compile command: {cmd}" + log.info(log_duration_msg) + + else: + log.debug( + "CUDA Compilation skipped: %s since output already exists", + input_path, + ) + # Upload to remote cache if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # will log on errors, but not fail out + binary_remote_cache.put( + output_path, config.cuda.binary_remote_cache_force_write + ) + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, None) + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key] + if cache_entry.error_json is not None: + # Restore cached Exception and raise it as if we had compiled + cmd_parts, error_output = json.loads(cache_entry.error_json) + raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + @classmethod + def _record_cuda_compile_error( + cls, + error_str: str, + key: str, + cmd_parts: list[str], + input_path: str, + output_path: str, + # Any here, as the import and type will only work in fbcode + # TODO: Make the typing hint strong here + binary_remote_cache: Any = None, + ) -> None: + error_json = json.dumps([cmd_parts, error_str]) + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, error_json) + error_path = binary_error_path(output_path) + with open(error_path, "w", encoding="utf-8") as fh: + fh.write(error_json) + + # Upload to remote cache directly from memory if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + + +@clear_on_fresh_cache +class ROCmCodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: dict[str, CacheEntry] = {} + aot_kernels_o: list[str] = [] + _SOURCE_CODE_SUFFIX = "cpp" + _logged_compiler_version = False + + @staticmethod + def cache_clear() -> None: + ROCmCodeCache.cache.clear() + ROCmCodeCache.aot_kernels_o.clear() + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + cuda_command = repr( + rocm_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None + ) -> tuple[str, str, str]: + """ + Compiles source_code into a file with dst_file_ext extension, + using the compile command specific for the ROCm platform. 
+ Returns a tuple of dst_file_path, hash_key, source_code_path + """ + if not cls._logged_compiler_version: + cls._logged_compiler_version = True + log.debug(get_compiler_version_info(str(rocm_compiler()))) + + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = rocm_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + cmd_parts = cmd.split(" ") + try: + output = subprocess.check_output( + cmd_parts, + stderr=subprocess.STDOUT, + text=True, + env=os.environ, + ) + log.debug("Compilation output: %s", output) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "Skip compiling %s: output %s already exists", + input_path, + output_path, + ) + cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +class CodeCacheFuture: + def result(self) -> Callable[..., Any]: + raise NotImplementedError + + +class LambdaFuture(CodeCacheFuture): + def __init__( + self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None + ) -> None: + self.result_fn = result_fn + self.future = future + + def result(self) -> Callable[..., Any]: + return self.result_fn() + + +class StaticAutotunerFuture(CodeCacheFuture): + """ + A statically launchable CachingAutotuner, loaded from TritonBundler + """ + + def __init__(self, static_autotuner: CachingAutotuner) -> None: + # Pickled version of CachingAutotuner + self.static_autotuner = static_autotuner + # This needs to be set in AsyncCompile.triton, in case + # we need to reload the CachingAutotuner from its source code + # We don't store the source code on the CachingAutotuner itself + # since it can be very large. 
+ self.reload_kernel_from_src: Optional[Callable[[], Any]] = None + + def result(self) -> CachingAutotuner: + assert self.reload_kernel_from_src is not None + with dynamo_timed("StaticAutotunerFuture.warm_precompile"): + self.static_autotuner.recheck_autotune_cache( + reload_kernel_from_src=self.reload_kernel_from_src + ) + self.static_autotuner.precompile( # type: ignore[union-attr] + warm_cache_only=False, + reload_kernel=self.reload_kernel_from_src, + static_triton_bundle_key=None, # no need to save again + ) + return self.static_autotuner diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c4d1002608f3931965220e3aa3f23715a01d10 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py @@ -0,0 +1,31 @@ +import re + +import torch +from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + + +# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: +# "... +# from ..codecache import CudaKernelParamCache +# ..." +# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache + + +def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: + if torch.version.hip is None and not force_hipify: + return source_codes + + def c2_repl(m: re.Match[str]) -> object: + return PYTORCH_MAP[m.group(0)] + + # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, + # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching + # keyword at the beginning of code line. However, this can happen in codegen, + # which will cause the pattern to not match. + + # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example + # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" + RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") + + source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) # type: ignore[arg-type] + return source_codes diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/block_analysis.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/block_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b418980633c7fbbf368847abca6a00e49cfeda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/block_analysis.py @@ -0,0 +1,175 @@ +import collections +import functools +import textwrap +from typing import Optional + +import sympy +from sympy import Expr, Symbol + +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from ..utils import sympy_dot, sympy_subs +from ..virtualized import V + + +class BlockPatternMatcher: + """ + Matches block indexing expressions. + """ + + @classmethod + def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: + """ + Given a sympy expression, return the subexpression comprised only of terms + involving the specified symbol. + + For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`, + this returns `x * 5 + x ** 2`. + """ + expr = cls._preprocess(expr) + return sympy.S.Zero + sum( + term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols + ) + + @staticmethod + def get_slice_numels(dims: list[Expr]) -> list[Expr]: + """ + Compute the cumulative size of each dimension's slice. 
+ This proceeds from the last dim up to the second. + """ + numels = collections.deque([sympy.S.One]) + for dim in dims[:0:-1]: + numel = dim * numels[0] + numels.appendleft(numel) + return [*numels] + + @staticmethod + def _preprocess(expr: Expr) -> Expr: + # Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y. + return expr.expand(identity=True) + + @classmethod + def match_mod_div_block_expr( + cls, + index: Expr, + index_var: Symbol, + numel: Expr, + num_dims: int, + ) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]: + """ + Matches modular indexing expressions, converting them to implied block dimensions and strides. + See triton.py for more information. + """ + index = cls._preprocess(index) + + # Pattern match to find the strides and offset. + wild = functools.partial(sympy.Wild, exclude=[index_var]) + dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)] + strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)] + + # The first dimension's index is computed by division. + # The remaining are computed by modulo. + slice_numels = cls.get_slice_numels(dims[:num_dims]) + block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [ + ModularIndexing(index_var, numel, dim) + for dim, numel in zip(dims[1:], slice_numels[1:]) + ] + + # Calculate a linear index from block indices. + match_expr = sympy_dot(strides, block_index_exprs) + + # Heuristic: if the number of dimensions is high, check that the minimum requirements + # are met before attempting an expensive full match. see triton.py:match_mod_div_block + # for more details. In short, here we check that each subexpression in sympy.Add contains + # only FloorDiv or ModularIndexing expressions. + if num_dims >= 5: + stride, denom, other = sympy.symbols("stride denominator other", cls=wild) + mod_div_pattern = stride * ModularIndexing(index_var, denom, other) + floor_div_pattern = stride * FloorDiv(index_var, denom) + first_dim_floor_div_matched = False + match_failed = False + for arg in sympy.Add.make_args(index): + if arg.match(floor_div_pattern): + # There should only be a single FloorDiv(index, denom) expression + # corresponding to the first dimension + if first_dim_floor_div_matched: + match_failed = True + break + first_dim_floor_div_matched = True + elif arg.match(mod_div_pattern): + continue + else: + match_failed = True + break + + if match_failed: + return None + + # Pattern match. + match = index.match(match_expr) + if match is None: + return None + + # Provide default values for unmatched dims and strides. + for dim in dims[1:]: + if dim not in match: + match[dim] = sympy.S.One + for stride in strides[1:]: + if stride not in match: + match[stride] = sympy.S.Zero + + sizevars = V.graph.sizevars + + def get_match(expr: Expr) -> Expr: + return sizevars.lookup_precomputed_size(match[expr]) + + # Replace wildcards with matched expressions. + dims = [dims[0]] + [get_match(dim) for dim in dims[1:]] + strides = [get_match(stride) for stride in strides] + slice_numels = cls.get_slice_numels(dims) + block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs] + + # The leading dimension is not directly matched in our expression. + # We solve for it by dividing the range tree numel by the product of + # all other dimensions. We quit if they are not known to be divisible. + assert dims[0] not in match, "Expected not to match the leading dimension!" 
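+ # Illustrative numbers (not from the original source): if the trailing dims
+ # matched as (3, 4), then slice_numels[0] == 3 * 4 == 12, so a range tree
+ # numel of 24 recovers a leading dimension of 24 / 12 == 2.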
+ if not sizevars.statically_known_multiple_of(numel, slice_numels[0]): + return None + dims[0] = numel / slice_numels[0] + + # Sanity check that we can recover the index from the matched subexpressions. + matched_index = sympy_dot(strides, block_index_exprs) + assert sizevars.statically_known_equals( + # New precomputed replacements may be generated when the `get_match` function + # above is called, but the `index` that is being matched has not been updated. + # So remove them when checking for equivalence e.g. if ps0=3*s0 and + # index=3*s0*expr, matched_index=ps0*expr, then index == matched_index + sizevars.remove_precomputed_replacements(matched_index), + sizevars.remove_precomputed_replacements(index), + ), textwrap.dedent( + f""" + Invalid match! + Index: {index} + Matched expression: {matched_index} + """ + ) + + return dims, strides, block_index_exprs + + @classmethod + def match_affine_block_expr( + cls, + index: Expr, + index_var: Symbol, + ) -> Optional[Expr]: + """ + Matches simple expressions of the form stride * index, returning the + stride. + """ + index = cls._preprocess(index) + stride = sympy.Wild("stride", exclude=[index_var]) + m = index.match(index_var * stride) + if m is None: + return None + + return m[stride] diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/common.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..15b956ba0939d421b0f9c93e358de0b6ed5d5a7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,2691 @@ +from __future__ import annotations + +import atexit +import contextlib +import dataclasses +import enum +import functools +import itertools +import logging +import math +import operator +import os +import re +import tempfile +from abc import ABC, abstractmethod +from enum import auto, Enum +from itertools import chain +from typing import ( + Any, + Callable, + cast, + ClassVar, + Generic, + NamedTuple, + Optional, + TYPE_CHECKING, + Union, +) +from typing_extensions import Self, TypeVar + +import sympy + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .. 
import config, metrics +from ..dtype_propagation import DtypePropagationOpsHandler +from ..ops_handler import BasicMathOpsMixin, DefaultHandler +from ..utils import ( + boolean_ops, + DeferredLineBase, + generate_assert, + get_current_backend, + IndentedBuffer, + ir_dataclass, + ScopedDict, + sympy_dot, + sympy_index_symbol, + sympy_subs, + triton_type, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + + +if TYPE_CHECKING: + from collections.abc import Iterator, MutableMapping, Sequence + + from torch.fx import GraphModule + + from ..custom_graph_pass import CustomGraphModulePass + from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode + from ..loop_body import LoopBody + from ..scheduler import BaseScheduling, Scheduler, SchedulerNode + from .wrapper import PythonWrapperCodegen + + _T = TypeVar("_T") + SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling] + WrapperConstructor = type[PythonWrapperCodegen] + SymbolLike = Union[str, sympy.Symbol] + + # OpVarT should really be Union[CSEVariable, str], however this + # causes typing errors in subclasses (defined in other files). + OpVarT = str + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +log = logging.getLogger(__name__) + + +def data_type_logger(msg: str) -> None: + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class FileBackedGraphModule: + """ + Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these + map back to a GraphModule instead of Python source. + """ + + gm: GraphModule + compiled_fn: Callable[..., Any] + + def __post_init__(self) -> None: + # Write the code to a file for compatibility with debugging utilities. + # The file is deleted upon program termination. + self.tempfile = tempfile.NamedTemporaryFile( + mode="w+", suffix=".py", delete=False + ) + atexit.register(os.remove, self.tempfile.name) + with self.tempfile as f: + f.write(self.value) + + @property + def __file__(self) -> str: + return self.tempfile.name + + def call(self, args: list[Any]) -> Any: + return self.compiled_fn(*args) + + @property + def value(self) -> str: + return self.gm.code + + +class WorkspaceZeroMode(enum.Enum): + UNINITIALIZED = 0 + ZERO_ON_CALL = 1 # kernel may leave workspace dirty + ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel + + @staticmethod + def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: + if a == b or b == WorkspaceZeroMode.UNINITIALIZED: + return a + if a == WorkspaceZeroMode.UNINITIALIZED: + return b + raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") + + @staticmethod + def from_bool(zero_fill: bool) -> WorkspaceZeroMode: + if zero_fill: + return WorkspaceZeroMode.ZERO_ON_CALL + return WorkspaceZeroMode.UNINITIALIZED + + +class CodegenSymbol(ABC): + """ + An IR object possibly corresponding to a variable in the wrapper code. + """ + + @abstractmethod + def get_name(self) -> str: + pass + + @abstractmethod + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + pass + + +@ir_dataclass(frozen=True) +class WorkspaceArg(CodegenSymbol): + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + + Args: + nbytes: The size of the buffer in bytes. + zero_fill: Whether the buffer should be initialized to zero. 
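+
+ Note: the fields below carry this information as count (the number of
+ dtype elements, torch.uint8 by default) and zero_mode (derived via
+ WorkspaceZeroMode.from_bool).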
+ + """ + + count: sympy.Expr + zero_mode: WorkspaceZeroMode + device: torch.device + outer_name: str + inner_name: str = "ws_ptr" + dtype: torch.dtype = torch.uint8 + + @staticmethod + def unique_name(prefix: str = "workspace_") -> str: + return f"{prefix}{next(V.graph.workspace_id)}" + + @staticmethod + def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool: + return ( + a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device + ) + + @staticmethod + def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + return WorkspaceArg( + count=a.count + b.count, + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + @staticmethod + def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: + assert ( + a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name + ) + return WorkspaceArg( + count=sympy.Max(a.count, b.count), + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code + def get_device(self) -> torch.device: + return self.device + + get_device_or_error = get_device + + def get_dtype(self) -> torch.dtype: + return self.dtype + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.get_layout().get_example() + + def get_layout(self) -> FixedLayout: + from ..ir import FixedLayout + + return FixedLayout( + device=self.device, + dtype=self.dtype, + size=[self.count], + stride=[1], + ) + + @property + def layout(self) -> FixedLayout: + return self.get_layout() + + get_output_spec = get_layout + maybe_get_output_spec = get_layout + maybe_get_layout = get_layout + + def get_offset(self) -> sympy.Expr: + return sympy.S.Zero + + def get_size(self) -> list[sympy.Expr]: + return [self.count] + + def get_stride(self) -> list[sympy.Expr]: + return [sympy.S.One] + + def get_name(self) -> str: + return self.outer_name + + def get_inputs_that_alias_output(self) -> list[str]: + return [] + + +class TritonScratchWorkspace: + def __init__(self, size: int, generate_dtype_str: Callable[..., str]): + self.size = size + self._generate_dtype_str = generate_dtype_str + + def generate_dtype_str(self) -> str: + return self._generate_dtype_str() + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.S.Zero # c++ only + alias_of: Optional[str] = None # halide only + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + @property + def alias_of(self) -> Optional[str]: + return None + + +@dataclasses.dataclass +class ConstexprArg: + name: str + + +@dataclasses.dataclass +class TMADescriptorArg: + name: str + api_type: str # "experimental" or "stable" + block_shape: Optional[list[sympy.Expr]] # only needed for "stable" + dtype: Optional[torch.dtype] # only needed for "stable" + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: SchedulingConstructor + wrapper_codegen: WrapperConstructor + cpp_wrapper_codegen: Optional[WrapperConstructor] = None + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg] + +device_codegens: dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name: str) -> str: + raise NotImplementedError + + def set_device(self, device_idx: int) -> str: + raise 
NotImplementedError + + def synchronize(self) -> str: + raise NotImplementedError + + def device_guard(self, device_idx: int) -> str: + raise NotImplementedError + + def cpp_device_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_device_guard(self) -> str: + raise NotImplementedError + + def cpp_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_aoti_stream_guard(self) -> str: + raise NotImplementedError + + def cpp_getStreamFromExternal(self) -> str: + raise NotImplementedError + + def kernel_header(self) -> str: + raise NotImplementedError + + def kernel_driver(self) -> str: + raise NotImplementedError + + def cpp_stream_type(self) -> str: + raise NotImplementedError + + def aoti_get_stream(self) -> str: + raise NotImplementedError + + def cpp_kernel_type(self) -> str: + raise NotImplementedError + + def cpp_device_ptr(self) -> str: + raise NotImplementedError + + def tma_descriptor_helpers(self) -> str: + raise NotImplementedError + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + # optionally return (scratch definition, arg name) + raise NotImplementedError + + +device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} +custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. +# +# For the Wrapper, Inductor provides a PythonWrapperCodegen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from PythonWrapperCodegen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or PythonWrapperCodegen. So the Scheduling and PythonWrapperCodegen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. 
+# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, + device_scheduling: SchedulingConstructor, + device_wrapper_codegen: WrapperConstructor, + device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, + device_custom_pass: Optional[CustomGraphModulePass] = None, +) -> None: + device_codegens[device] = DeviceCodegen( + device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen + ) + custom_backend_passes[device] = device_custom_pass + + +class BackendFeature(Enum): + FOREACH = auto() + BUCKETIZE = auto() + INPLACE_BUFFERS = auto() + MASKED_SCATTER_WITH_INDEX = auto() + SCAN = auto() + SORT = auto() + TUPLE_REDUCTION = auto() + PREFER_STORE_LOOP_ORDER = auto() + TRITON_TEMPLATES = auto() + REDUCE_TO_SINGLE_ELEMENT = auto() + + +def get_backend_features( + device: Union[torch.device, str, None], +) -> OrderedSet[BackendFeature]: + if device is None: + return OrderedSet() + init_backend_registration() + if isinstance(device, torch.device): + device_type = device.type + else: + assert isinstance(device, str), type(device) + device_type = device + device = torch.device(device_type) + scheduling_ctor = get_scheduling_for_device(device_type) + assert scheduling_ctor + scheduling = scheduling_ctor(None) + return scheduling.get_backend_features(device) + + +def has_backend_feature( + device: Union[torch.device, str, None], feature: BackendFeature +) -> bool: + """See also V.graph.has_feature""" + assert isinstance(feature, BackendFeature) + return feature in get_backend_features(device) + + +def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]: + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device( + device: str, cpp_wrapper: bool = False +) -> Optional[WrapperConstructor]: + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + return ( + wrapper_codegen_obj.cpp_wrapper_codegen + if cpp_wrapper + else wrapper_codegen_obj.wrapper_codegen + ) + return None + + +def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]: + return custom_backend_passes[device] if device in custom_backend_passes else None + + +@functools.cache +def init_backend_registration() -> None: + from .cpp import CppScheduling + from .cpp_wrapper_cpu import CppWrapperCpu + from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef + from .cpp_wrapper_gpu import CppWrapperGpu + from .cpp_wrapper_mps import CppWrapperMps + from .cuda_combined_scheduling import CUDACombinedScheduling + from .halide import HalideScheduling + from .mps import MetalScheduling + from .triton import TritonScheduling + from .wrapper import PythonWrapperCodegen + + if get_scheduling_for_device("cpu") is None: + cpu_backends = { + "cpp": CppScheduling, + "halide": HalideScheduling, + "triton": TritonScheduling, + } + register_backend_for_device( + "cpu", + lambda scheduling: cpu_backends[config.cpu_backend](scheduling), + PythonWrapperCodegen, + CppWrapperCpuArrayRef + if config.aot_inductor.allow_stack_allocation + else CppWrapperCpu, + ) + + if get_scheduling_for_device("cuda") is None: + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + cuda_backends = { + "triton": CUDACombinedScheduling, + "halide": HalideScheduling, + } + 
register_backend_for_device( + "cuda", + lambda scheduling: cuda_backends[config.cuda_backend](scheduling), + PythonWrapperCodegen, + CppWrapperGpu, + ) + + if get_scheduling_for_device("xpu") is None: + register_backend_for_device( + "xpu", + TritonScheduling, + PythonWrapperCodegen, + CppWrapperGpu, + ) + + if get_scheduling_for_device("mps") is None: + register_backend_for_device( + "mps", + MetalScheduling, + PythonWrapperCodegen, + CppWrapperMps, + ) + + private_backend = torch._C._get_privateuse1_backend_name() + if ( + private_backend != "privateuseone" + and get_scheduling_for_device(private_backend) is None + ): + from torch.utils.backend_registration import _get_custom_mod_func + + try: + device_scheduling = _get_custom_mod_func("Scheduling") + wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen") + cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen") + if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: + register_backend_for_device( + private_backend, + device_scheduling, + wrapper_codegen, + cpp_wrapper_codegen, + ) + except RuntimeError: + pass + + +def index_prevent_reordering( + index: Sequence[sympy.Expr], + index_vars: Sequence[sympy.Expr], + sizes: Sequence[sympy.Expr], +) -> list[sympy.Expr]: + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides( + device: str, device_op_overrides: DeviceOpOverrides +) -> None: + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str) -> DeviceOpOverrides: + assert isinstance(device, str), type(device) + + if not device_op_overrides_dict: + from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 + from .cuda import device_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + + return device_op_overrides_dict[device] + + +DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +def deduce_output_dtype_by_name( + op_name: str, + *args: Any, + **kwargs: Any, +) -> Optional[torch.dtype]: + """ + Given op name and a list of input dtypes, deduce the output dtype + """ + if op_name in boolean_ops(): + return torch.bool + elif op_name in ( + "to_dtype", + "index_expr", + ): + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "rand", + "randn", + ): + return torch.float + elif op_name in ( + "get_index", + "randint64", + "load_seed", + ): + return torch.int64 + elif op_name == "reduction": + return kwargs["dtype"] if "dtype" in kwargs else args[1] + elif op_name == "constant": + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "load", + "store", + "store_reduction", + ): + buf_name = args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] + return None + + +def check_dtype( + buffer: IndentedBuffer, var: CSEVariableType, dtype: torch.dtype +) -> None: + backend = get_current_backend() + if config.test_configs.runtime_triton_dtype_assert and backend == "triton": + 
buffer.writeline(f"tl.static_assert({var}.dtype == {triton_type(dtype)})") + elif config.test_configs.static_cpp_dtype_assert and backend == "cpp": + from .cpp_utils import CppCSEVariable, DTYPE_TO_CPP + + assert isinstance(var, CppCSEVariable), type(var) + if dtype == torch.bool: + if var.is_vec: + is_same_dt = f"IsVecMaskType::value" + else: + # operator&(bool, bool) returns int and it can be used as boolean in C++ + is_same_dt = f"std::is_same_v || std::is_same_v" + else: + c_var_type = f"decltype({var})" + if var.is_vec: + c_var_type = f"typename {c_var_type}::value_type" + is_same_dt = f"std::is_same_v<{c_var_type}, {DTYPE_TO_CPP[dtype]}>" + + buffer.writeline(f"static_assert({is_same_dt});") + + +class DataTypePropagation: + def __init__(self, body: LoopBody) -> None: + self.body = body + self.graphs: dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]: + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propagated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propagated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype: + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: + if node.op == "placeholder": + return None + + if node.target == "output" and len(node.args) != 1: + # we can infer output node if it only have 1 arg + return None + + if node.target == operator.getitem: + node_arg = node.args[0] + assert isinstance(node_arg, torch.fx.Node), type(node_arg) + return self.deduce_node_dtype(node_arg) + + assert isinstance(node.target, str), type(node.target) + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + if ( + output_dtype := deduce_output_dtype_by_name( + node.target, + *node.args, + **node.kwargs, + ) + ) is not None: + return output_dtype + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]: + assert graph.nodes + graph_dtype: Optional[torch.dtype] = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. 
For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self) -> Optional[torch.dtype]: + return self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]: + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]: + from ..loop_body import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode), type(node) + assert isinstance(node._body, LoopBody), type(node._body) + return DataTypePropagation.propagate_loopbody(node._body) + + +class PythonPrinter(_PythonPrinter): + def doprint( + self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True + ) -> str: + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +class OpDecompositions: + """ + Decomposes inductor ops + """ + + @staticmethod + def identity(value: OpVarT) -> OpVarT: + # used to trigger cse + return value + + @staticmethod + def reciprocal(x: OpVarT) -> OpVarT: + return ops.truediv(ops.constant(1, torch.int32), x) + + @staticmethod + def square(x: OpVarT) -> OpVarT: + return ops.mul(x, x) + + @staticmethod + def erfc(x: OpVarT) -> OpVarT: + return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) + + @staticmethod + def erfcx(x: OpVarT) -> OpVarT: + return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) + + @staticmethod + def expm1(x: OpVarT) -> OpVarT: + return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) + + @staticmethod + def log10(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) + + @staticmethod + def log2(x: OpVarT) -> OpVarT: + return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) + + @staticmethod + def exp2(x: OpVarT) -> OpVarT: + return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) + + @staticmethod + def log1p(x: OpVarT) -> OpVarT: + return ops.log(ops.add(x, ops.constant(1, torch.int32))) + + @staticmethod + def sigmoid(x: OpVarT) -> OpVarT: + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + + @staticmethod + def relu(x: OpVarT) -> OpVarT: + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def fma(x: OpVarT, y: OpVarT, z: OpVarT) -> OpVarT: + # for backends that don't override this (halide) + return ops.add(ops.mul(x, y), z) + + @staticmethod + def floor_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def trunc_to_int(a: OpVarT, dtype: torch.dtype) -> OpVarT: + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> OpVarT: + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def round_to_int(a: OpVarT, dtype: 
torch.dtype) -> OpVarT: + return ops.to_dtype(ops.round(a), dtype) + + +_RE_PAREN_NOT_NEEDED = re.compile(r"[a-z0-9_.]+|\([^)]*\)|", flags=re.IGNORECASE) + + +def _all_in_parens(string: str) -> bool: + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): + @staticmethod + def paren(string: OpVarT) -> OpVarT: + if ( + isinstance(string, CSEVariable) + or _RE_PAREN_NOT_NEEDED.fullmatch(string) + or _all_in_parens(string) + ): + # don't put extra parens for strings that are already wrapped in parens + return string + return f"({string})" + + @staticmethod + def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT: + return repr(value) + + @staticmethod + def bitwise_not(x: OpVarT) -> OpVarT: + return f"~{OpOverrides.paren(x)}" + + @staticmethod + def logical_not(a: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(a)} == 0" + + @staticmethod + def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}" + + @staticmethod + def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT: + return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" + + @staticmethod + def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT: + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + + @staticmethod + def load_seed(name: str, offset: OpVarT) -> OpVarT: + return ops.load(name, sympy.Integer(offset)) + + def indirect_indexing( + self, + var: OpVarT, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + return sympy_index_symbol(str(var)) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: check_bounds should be handled by CSEProxy" + ) + + def load(self, name: str, index: sympy.Expr) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: load should be handled by CSEProxy" + ) + + def store( + self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None + ) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store should be handled by CSEProxy" + ) + + def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None: + raise NotImplementedError( + f"{type(self).__name__}: store_reduction should be handled by CSEProxy" + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[OpVarT, tuple[OpVarT, ...]], + ) -> Union[OpVarT, tuple[OpVarT, ...]]: + raise NotImplementedError( + f"{type(self).__name__}: reduction should be handled by CSEProxy" + ) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[OpVarT, ...], tuple[OpVarT, ...]], + tuple[OpVarT, ...], + ], + 
values: tuple[OpVarT, ...], + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: scan should be handled by CSEProxy" + ) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[OpVarT, ...], + stable: bool, + descending: bool, + ) -> tuple[OpVarT, ...]: + raise NotImplementedError( + f"{type(self).__name__}: sort should be handled by CSEProxy" + ) + + def bucketize( + self, + values: OpVarT, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: OpVarT, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[OpVarT] = None, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: bucketize should be handled by CSEProxy" + ) + + def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: halide_clamp only implemented for Halide backend" + ) + + def inline_asm_elementwise( + self, + *inputs: OpVarT, + asm: str, + constraints: Optional[str] = None, + dtype: torch.dtype = torch.float32, + is_pure: bool = True, + pack: int = 1, + ) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" + ) + + def output(self, *args: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear at codegen time" + ) + + @staticmethod + def _unimplemented(name: str) -> Callable[..., OpVarT]: + def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: + raise NotImplementedError( + f"{type(self).__name__} does not implement ops.{name}" + ) + + unimplemented.__name__ = name + unimplemented.is_unimplemented = True # type: ignore[attr-defined] + return unimplemented + + @classmethod + def _is_unimplemented(cls, name: str) -> bool: + fn = getattr(cls, name, None) + default_fn = getattr(OpsHandler, name, None) + return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False) + + @classmethod + def _initialize_pointwise_overrides(cls, target: str) -> None: + assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if impl is None: + if cls._is_unimplemented(funcname): + setattr(cls, funcname, cls._unimplemented(funcname)) + else: + assert funcname not in cls.__dict__, ( + f"multiple definitions of {funcname} on {cls.__name__}" + ) + impl.__name__ = funcname + setattr(cls, funcname, staticmethod(impl)) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: Callable[..., str] + # None when not impl in libdevice/triton + triton: Optional[Callable[..., str]] = None + # None when not impl in aten/.../vec + cppvec: Optional[Callable[..., str]] = None + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + halide: Optional[Callable[..., str]] = None + mps: Optional[Callable[..., str]] = None + + +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too +pointwise_overrides_data: dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j0_forward({x})", + triton=lambda x: f"libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j1_forward({x})", + triton=lambda x: f"libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y0_forward({x})", + triton=lambda x: f"libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y1_forward({x})", + triton=lambda x: f"libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_digamma({x})", + cppvec=lambda x: f"{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_erfcx({x})", + triton=lambda x: f"libdevice.erfcx({x})", + name="special_erfcx", + ), + fma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", + cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", + triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", + name="fma", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + cppvec=lambda x: f"{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0e({x})", + cppvec=lambda x: f"{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i0_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + 
cpp=lambda x: f"modified_bessel_i1_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, + y: f"{x} == 0 ? calc_digamma({y}) : ({x} == 1 ? trigamma({y}) : calc_polygamma({y}, {x}))", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", + name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( 
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +def is_buffer_removed(name: str) -> bool: + return any( + name in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ) + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" + + def __init__(self, name: str, line: str): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self) -> Optional[str]: + if not is_buffer_removed(self.name): + return self.line + return None + + def _new_line(self, line: str) -> DeferredLine: + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: list[str] + + +@dataclasses.dataclass +class ArgName: + name: str + # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list + is_constexpr: bool = False + + def full_name(self) -> str: + return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}" + + +class RemovedArg: + def __str__(self) -> str: + return "REMOVED" + + +REMOVED = RemovedArg() + + +class KernelArgs: + @staticmethod + def _lookup( + prefix: str, + odict: Union[dict[_T, Union[str, RemovedArg]], dict[_T, str]], + name: _T, + ) -> str: + result: Union[str, RemovedArg] = odict.get(name, REMOVED) + if isinstance(result, RemovedArg): + odict[name] = new_result = f"{prefix}{len(odict)}" + return new_result + return result + + def __init__(self) -> None: + self.input_buffers: dict[str, str] = {} + self.output_buffers: dict[str, Union[str, RemovedArg]] = {} + self.inplace_buffers: dict[str, Union[InplacedBuffer, RemovedArg]] = {} + self.sizevars: dict[sympy.Expr, str] = {} + self.workspace_args: list[WorkspaceArg] = [] + + def __repr__(self) -> str: + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + @staticmethod + 
def _buffer_is_marked_removed(name: Any) -> bool: + # this function is needed by MTIA + return isinstance(name, RemovedArg) + + def input(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return cast(str, self.output_buffers[name]) + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name: str) -> str: + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return cast(InplacedBuffer, self.inplace_buffers[name]).inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name: str, output_name: str) -> None: + if input_name in V.graph.unaligned_buffers: + V.graph.unaligned_buffers.add(output_name) + assert output_name not in self.inplace_buffers, output_name + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + assert not isinstance(buf, RemovedArg) + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + alive_buffers = [ + val + for val in self.inplace_buffers.values() + if not isinstance(val, RemovedArg) + ] + removed_buffers = [ + val + for val in self.inplace_buffers.values() + if isinstance(val, RemovedArg) + ] + inplace_buffer_idx = len(unique(alive_buffers)) + len(removed_buffers) + buf = InplacedBuffer( + f"in_out_ptr{inplace_buffer_idx}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]: + """ + Allocate or extend a workspace buffer of nbytes bytes. + + This function manages the allocation of a workspace buffer. It either creates + a new WorkspaceArg or extends an existing one. + + Note: + - Calling this function will in-place mutate the args by adding or updating + a WorkspaceArg. + - The codegen for generating the Python argdefs and call_defs will check + this field and allocate the buffer accordingly. + - A new argument "ws_ptr" will be present in the generated code. + + Args: + nbytes (sympy.Expr): The number of bytes to allocate. + zero_fill (bool): Whether to initialize the buffer to zero. + + Returns: + Tuple[str, int]: A tuple containing: + - "ws_ptr": A string identifier for the workspace pointer. + - offset: An integer representing the byte offset in the workspace. + """ + arg = WorkspaceArg( + count=nbytes, + zero_mode=WorkspaceZeroMode.from_bool(zero_fill), + device=V.graph.get_current_device_or_throw(), + outer_name=WorkspaceArg.unique_name(), + ) + for i, existing_arg in enumerate(self.workspace_args): + if WorkspaceArg.can_join(existing_arg, arg): + offset = existing_arg.count + self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg) + return existing_arg.inner_name, offset + assert ( + existing_arg.inner_name != arg.inner_name + and existing_arg.outer_name != arg.outer_name + ), existing_arg + self.workspace_args.append(arg) + return arg.inner_name, 0 + + def semaphores(self, min_size: sympy.Expr) -> str: + """ + Lazily allocate a graph-wide semaphores buffer with at least min_size. 
This is a single buffer shared by + all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit. + + Warning: multiple calls to this function will return the same buffer. + + Args: + min_size: the number of int32 semaphores required + + Returns: + name of the semaphores buffer + """ + current_device = V.graph.get_current_device_or_throw() + arg = WorkspaceArg( + count=min_size, + zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH, + dtype=torch.uint32, + inner_name="sem_ptr", + outer_name=f"semaphores_{current_device.type}_{current_device.index}", + device=current_device, + ) + for existing_arg in self.workspace_args: + if existing_arg.inner_name == arg.inner_name: + assert arg == existing_arg, (arg, existing_arg) + self.workspace_args.append(arg) + return arg.inner_name + + def seed_offset(self, name: str, value: int) -> str: + assert isinstance(value, int), (type(value), value) + # here we are lifting a constant integer into an arg to the kernel to try to get additional cache hits + value = sympy.Integer(value) + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name: sympy.Symbol) -> str: + assert isinstance(name, sympy.Symbol), (type(name), name) + if name.name == "seed": + self.sizevars[name] = "seed" # don't manage the name of seeds + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self) -> Iterator[str]: + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def arg_name(self, name: str) -> Optional[str]: + """ + Returns inner name of a given outer name. 
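+ For example, an input buffer named "buf0" may map to "in_ptr0", an output
+ to "out_ptr0", and an in/out alias to "in_out_ptr0"; returns None if the
+ name is not an argument of this kernel.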
+ """ + inplaced = self.inplace_buffers.get(name, None) + if inplaced is not None and not isinstance(inplaced, RemovedArg): + return inplaced.inner_name + output_name = self.output_buffers.get(name, None) + if output_name is not None and not isinstance(output_name, RemovedArg): + return output_name + return self.input_buffers.get(name, None) + + def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str: + return buf + + def wrap_size_arg(self, size: SymbolLike) -> str: + return str(size) + + def cpp_argdefs( + self, dtype_to_cpp_type: Optional[dict[torch.dtype, str]] = None + ) -> tuple[list[str], list[str], list[str]]: + from .cpp_utils import INDEX_TYPE + + if dtype_to_cpp_type is None: + from .cpp_utils import DTYPE_TO_CPP + + dtype_to_cpp_type = DTYPE_TO_CPP + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, maybe_inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(maybe_inner, RemovedArg): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = dtype_to_cpp_type[dtype] + arg_defs.append(f"{cpp_dtype}* {maybe_inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert not self.workspace_args, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs( + self, + ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]: + arg_defs: list[ArgName] = [] + call_args: list[str] = [] + arg_types: list[Any] = [] + precompile_args: list[KernelArgType] = [] + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + arg_defs.append(ArgName(inplaced.inner_name)) + call_args.append(inplaced.other_names[-1]) + arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(V.graph.get_dtype(outer)) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(ArgName(inner)) + call_args.append(outer) + arg_types.append(type(outer)) + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + 
for arg in self.workspace_args: + arg_defs.append(ArgName(arg.inner_name)) + call_args.append(arg.outer_name) + precompile_args.append(arg) + arg_types.append(arg.dtype) + return arg_defs, call_args, precompile_args, arg_types + + def aliases(self) -> Iterator[tuple[str, str]]: + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + for other in inplaced.other_names: + if ( + other in V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield cast(str, self.output_buffers[other]), inplaced.inner_name + + def is_removed(self, name: str) -> bool: + return isinstance( + self.output_buffers.get(name, REMOVED), RemovedArg + ) and isinstance(self.inplace_buffers.get(name, REMOVED), RemovedArg) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. + def live_output_buffers(self) -> OrderedSet[str]: + live_outs: OrderedSet[str] = OrderedSet() + for inplaced in unique(self.inplace_buffers.values()): + if isinstance(inplaced, RemovedArg): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or isinstance(inner, RemovedArg): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__( + self, + name: str, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + assert isinstance(bounds, ValueRanges), type(bounds) + self.name = name + self.bounds = bounds + self.use_count = 1 # track how many times this expression is used + self.dtype = dtype + + def __str__(self) -> str: + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + return isinstance(other, CSEVariable) and other.name == self.name + + def update_on_args(self, name: str, args: Any, kwargs: Any) -> None: + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.name!r})" + + +AugmentedKeyT = TypeVar("AugmentedKeyT", default=str) +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + +if TYPE_CHECKING: + ReductionCacheKey = tuple[ + torch.dtype, + ReductionType, + Union[CSEVariable, tuple[CSEVariable, ...]], + ] + + +class CSE(Generic[CSEVariableType, AugmentedKeyT]): + """Common subexpression elimination""" + + def __init__( + self, + prefix: str = "", + suffix: str = "", + name_prefix: str = "tmp", + iter_buffers: Optional[itertools.count[int]] = None, + store_cache: Optional[MutableMapping[str, CSEVariableType]] = None, + reduction_cache: Optional[ + MutableMapping[ReductionCacheKey, CSEVariableType] + ] = None, + varname_map: Optional[dict[str, CSEVariableType]] = None, + ): + self.prefix = prefix + self.suffix = suffix + self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {} + self.name_prefix = name_prefix + self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {} + 
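+        # store_cache maps a buffer name to the CSE variable most recently stored into it,
+        # so a later load of that buffer within the same kernel can reuse the value instead
+        # of re-reading memory (see CSEProxy.load/store further below).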
self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = ( + reduction_cache or {} + ) + self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count() + self.invalidated_stores: OrderedSet[str] = OrderedSet() + self.varname_map: dict[str, CSEVariableType] = varname_map or {} + + def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None: + for name, tmp in [*self.store_cache.items()]: + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + if keep_vars: + self._cache = {k: v for k, v in self._cache.items() if v in keep_vars} + else: + self._cache = {} + + def clone(self) -> Self: + return type(self)( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + reduction_cache=self.reduction_cache, + ) + + def scoped_copy(self) -> Self: + """Return a copy of using ScopedDict so changes to *_cache aren't visible in self""" + new_cse = self.clone() + new_cse._cache = ScopedDict(self._cache) + new_cse.reduction_cache = ScopedDict(self.reduction_cache) + new_cse.store_cache = ScopedDict(self.store_cache) + return new_cse + + def augment_key(self, cache_key: str) -> AugmentedKeyT: + "Override this method to augment cache key with backend specifics" + return cast(AugmentedKeyT, cache_key) + + def put(self, cache_key: str, val: CSEVariableType) -> None: + self._cache[self.augment_key(cache_key)] = val + + def contains(self, cache_key: str) -> bool: + return self.augment_key(cache_key) in self._cache + + def try_get(self, cache_key: str) -> Optional[CSEVariableType]: + return self._cache.get(self.augment_key(cache_key), None) + + def get(self, cache_key: str) -> CSEVariableType: + return self._cache[self.augment_key(cache_key)] + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write: bool = True, + assignment: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + if isinstance(expr, OpsValue): + expr = expr.value + + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + expr.use_count += 1 + return cast(CSEVariableType, expr) + elif isinstance(expr, IndentedBuffer): + cache_key = expr.getvalue() + elif isinstance(expr, DeferredLineBase): + cache_key = expr.line + else: + assert isinstance(expr, str) + cache_key = expr + var = self.try_get(cache_key) + if not var: + var = self.newvar(bounds, dtype) + self.put(cache_key, var) + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + elif isinstance(expr, DeferredLineBase): + assert assignment + buffer.writeline( + expr._new_line(f"{self.prefix}{var} = {expr.line}{self.suffix}") + ) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + + # cpp backend cannot determine is_vec at this point + if ( + assignment + and ( + 
config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ) + and dtype is not None + and get_current_backend() != "cpp" + ): + check_dtype(buffer, var, dtype) + + else: + var.bounds = var.bounds.tighten(bounds) + var.use_count += 1 + + return var + + def newvar( + self, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) + self.varname_map[var_name] = var + return var + + def namedvar( + self, + name: str, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + ) -> CSEVariableType: + torch._check_value( + name not in self.varname_map, lambda: f"duplicate name: {name}" + ) + var = V.kernel.create_cse_var(name, bounds, dtype) + self.varname_map[name] = var + return var + + +class CodeGen: + def __init__(self) -> None: + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self) -> Self: + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class Kernel(CodeGen, Generic[CSEVariableType]): + newvar_prefix: str = "" + suffix: str = "" + overrides: Optional[Callable[[], OpsHandler[Any]]] = None + + def __init__( + self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True + ) -> None: + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + + self.num_load = 0 + self.num_reduction = 0 + + self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self._load_mask: Optional[str] = None + self._load_other: Union[None, int, float] = None + # OrderedSet in set_current_node + self.current_node: Optional[SchedulerNode] = None + self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None + + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers: dict[str, str] = {} + # Set minimum number of elements processed per thread. 
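+        # Backends may raise this value so that each thread is given more work, e.g. when
+        # very small per-thread workloads would not be profitable.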
+ self.min_elem_per_thread = 1 + self.kernel_name: Optional[str] = None + + @contextlib.contextmanager + def set_current_node(self, node: SchedulerNode) -> Iterator[None]: + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers( + self, + lb: IndentedBuffer, + cb: Optional[IndentedBuffer] = None, + sb: Optional[IndentedBuffer] = None, + ) -> Iterator[None]: + if cb is None: + cb = lb + if disallow_stores := sb is None: + sb = IndentedBuffer() + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = cse.scoped_copy() + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + if disallow_stores: + assert not sb, "unexpected store inside swap_buffers" + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError + + def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable: + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + raise NotImplementedError + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + raise NotImplementedError + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + raise NotImplementedError + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + raise NotImplementedError + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError + + @property + def assert_function(self) -> str: + raise NotImplementedError + + def indirect_assert( + self, + var: Union[CSEVariable, str], + lower: Optional[str], + upper: Optional[str], + mask: Optional[Union[CSEVariable, str]] = None, + ) -> str: + if isinstance(var, CSEVariable): + var = str(var) + assert isinstance(var, str), type(var) + assert lower is None or isinstance(lower, str) + assert upper is None or isinstance(upper, str) + if lower and upper: + # The conditions need to be in parens because of Python's operator precedence. 
+ # It'd be less error-prone to use and/or/not, which is supported by triton + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower} <= {var} < {upper}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = cond + else: + assert upper + cond = f"{var} < {upper}" + cond_print = cond + + if mask: + cond = f"({cond}) | ~({mask})" + + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError + + def __enter__(self) -> Self: + super().__enter__() + assert self.overrides + self.exit_stack.enter_context( + V.set_ops_handler(CSEProxy(self, self.overrides())) + ) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def remove_kernel_local_buffers(self) -> None: + """ + Any buffers that are both created and have a last use in the + same kernel can be removed. + + Note that V.graph.scheduler can be None when codegening triton template + kernels. + """ + scheduler = V.graph.scheduler + if not scheduler: + return + fused_node_names = OrderedSet( + scheduler.name_to_buf[buf].defining_op_name() + for buf in self.store_buffer_names + if buf in scheduler.name_to_buf + ) + names_to_remove: OrderedSet[str] = OrderedSet() + for name in self.store_buffer_names: + if ( + name not in self.must_keep_buffers + and name not in self.args.input_buffers + and scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ) + ): + names_to_remove.add(name) + + for name in names_to_remove: + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if isinstance(buf, RemovedArg): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + self.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name: str) -> None: + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + self.args.output_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def remove_inplace_buffer(self, name: str) -> None: + log.debug("removing_inplace_buffer(%r)", name) + self.args.inplace_buffers[name] = REMOVED + self.removed_buffers.add(name) + + def rename_indexing( + self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr] + ) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if symbol_is_type( + x, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable: + return CSEVariable(*args, **kwargs) + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. 
+ """ + if node is None: + return None + return self.args.arg_name(node.get_name()) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + +@functools.cache +def jinja2_env() -> Any: + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class KernelTemplate: + """ + Base class for defining kernel templates. + + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def indent_except_first( + source: str, num_indents: int, indents_spacing: int = 4 + ) -> str: + lines = source.splitlines(True) + if len(lines) > 1: + lines[1:] = [ + (" " * indents_spacing * num_indents) + line for line in lines[1:] + ] + return "".join(lines) + + @staticmethod + def _template_from_string(source: str) -> Any: + env = jinja2_env() + if env is None: + return None + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + try: + return env.from_string(source) + except TemplateSyntaxError as e: + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error: TemplateSyntaxError) -> None: + super().__init__( + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self) -> str: + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i + 1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i + 1}: {lines[i]}\n" + return error_info + + raise DetailedTemplateSyntaxError(e) from e + + @staticmethod + def _fake_get_dtype( + fake_outs: Union[list[Buffer], Buffer], + ) -> Callable[[str], torch.dtype]: + _get_dtype_real = V.graph.get_dtype + if isinstance(fake_outs, (list, tuple)): + lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs} + else: + lookup = {fake_outs.get_name(): fake_outs.get_dtype()} + + def get_dtype(name: str) -> torch.dtype: + result = lookup.get(name) + if result is not None: + return result + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str) -> None: + self.name = name + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(**kwargs)) + return None + except NotImplementedError as e: + log.info( + "Cannot Append Choice: %s. KernelTemplate type is %s", + e, + type(self), + stack_info=log.getEffectiveLevel() < logging.INFO, + ) + return e + + def generate(self, **kwargs: Any) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. 
+ """ + + raise NotImplementedError + + +class CSEProxy(DefaultHandler): + name = "CSEProxy" + + def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): + super().__init__() + from ..bounds import ValueRangeAnalysis + + self.vr_analysis = ValueRangeAnalysis() + self.kernel = kernel + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bounds = self._bound_variable(name, *args, **kwargs) + + value = getattr(self.parent_handler, name)(*args, **kwargs) + dtype_handler = DtypePropagationOpsHandler() + + backend = get_current_backend() + + output_dtype = None + if name == "masked" and backend == "triton": + output_dtype = value.dtype + elif name == "masked" and backend == "cpp": + output_dtype = V.interpreter.current_node.meta.get( + OptimizationContext.key, None + ).dtype + elif backend in ("triton", "cpp", "mps"): + dtype_op = getattr(dtype_handler, name) + output_dtype = dtype_op(*args, **kwargs) + + if backend in ("triton", "cpp"): + # maybe there are some exceptions on mps? + assert output_dtype is not None + + output_idx = 0 + + def do_cse(v: str) -> CSEVariable: + # we tree_map over the output, so we need to fetch corresponding dtype + nonlocal output_idx + var_dtype: Optional[torch.dtype] = ( + output_dtype[output_idx] + if isinstance(output_dtype, (list, tuple)) + else output_dtype + ) + output_idx += 1 + + # some cpp op implementations don't set the dtype + if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None: + v.dtype = var_dtype + + csevar = V.kernel.cse.generate( + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, + ) + + csevar.update_on_args(name, args, kwargs) + + if ( + config.test_configs.runtime_triton_dtype_assert + or config.test_configs.static_cpp_dtype_assert + ): + assert var_dtype is not None + check_dtype(V.kernel.compute, csevar, var_dtype) + return csevar + + return pytree.tree_map(do_cse, value) + + def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from ..bounds import ValueRangeAnalysis + from ..select_algorithm import TritonTemplateKernel + from .cuda.cuda_kernel import CUDATemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + if isinstance(V.kernel, CUDATemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.kernel.node_to_bounds is not None: + assert isinstance(self.kernel.node_to_bounds, dict), type( + self.kernel.node_to_bounds + ) + return self.kernel.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any(s in fx_node.target for s in ("set_indirect", "reduction", "scan")): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. 
+ + # If there is no FX bound but we know how to compute one we do so + assert not kwargs + + def arg_to_bound(x: Any) -> Any: + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(self.vr_analysis, name)(*arg_bounds) + return ValueRanges.unknown() + + def indirect_indexing( + self, + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + if isinstance(size, int): + size = sympy.Integer(size) + assert isinstance(size, sympy.Expr), (type(size), size) + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds) + + sympy_var = self.parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.kernel.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + return self.kernel.check_bounds(expr, size, lower, upper) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + if name in self.kernel.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.kernel.indirect_load(name, index) + store_cache = self.kernel.cse.store_cache + if name in store_cache: + return store_cache[name] + out = self.kernel.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. 
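+        # A freshly created CSE variable starts with use_count == 1; a hit in cse.generate()
+        # bumps it, so use_count == 1 here means this load produced a genuinely new read
+        # from memory rather than reusing a cached value.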
+ if out.use_count == 1: + self.kernel.num_load += 1 + return out + + def _update_store_cache(self, name: str, value: CSEVariable) -> None: + self.kernel.cse.store_cache[name] = value + if self.kernel.current_node and name in V.graph.name_to_buffer: + buf = self.kernel.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.kernel.cse.store_cache[other_name] = value + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.kernel.store_buffer_names.add(name) + if mode is None: + self._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + self.kernel.store(name, index, value, mode=mode) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + self.kernel.store_buffer_names.add(name) + self._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.kernel.store_reduction(name, index, value) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + self.kernel.num_reduction += 1 + return self.kernel.reduction(dtype, src_dtype, reduction_type, value) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], + tuple[CSEVariable, ...], + ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + return self.kernel.scan(dtypes, combine_fn, values) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> tuple[CSEVariable, ...]: + return self.kernel.sort(dtypes, values, stable, descending) + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to be bucketized. + boundaries: a tuple containing + (a) the name of the boundaries tensor (which must be sorted, unless + the sorting tensor is present), + (b) the length of the tensor in the last dimension (i.e. the length of + one set of boundaries), + (c) the number of elements in the underlying storage (i.e. the length + of the flattened tensor, ignoring striding), and + (d) the stride of the tensor in the last dimension. + boundary_indices: indices into a flattened version of the boundaries + tensor, of the same size and shape as "values". Each index points to + the first element in the set of boundaries to be used for the + corresponding value. + indexing_dtype: the dtype to use when indexing into the boundaries + tensor. This must be int64 or int32. This additionally specifies the + dtype of the return value. + right: see "Details" below. + sorter: an optional tuple containing + (a) the name of an optional sorting tensor, used to access unsorted + boundaries without reordering the boundaries tensor, and + (b) the stride of the tensor in the last dimension. + The values in the sorting tensor are used as indices into the *last* + dimension of the boundaries tensor, with all other indices matching. + The size of the sorting and boundaries tensors must be equivalent. 
+ sorter_indices: must be present if the sorting array is present; see + "boundary_indices" for the equivalent definition for the boundaries + tensor. + + Output: + ------- + The buckets each value belongs in, within a given set of boundaries. 0 + indicates a position before the first boundary, and len(boundaries_set) + represents a position after the last boundary. + + Details: + -------- + Given a value and a set of boundaries, calculate the bucket that each + value belongs to. This works differently in 1-D and N-D cases. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [0, 4, 4, 8], right=True + return = [[ 0, 1, 1, 1], [1, 3, 3, 4]]. + + for values [[-1, 0, 1, 2], [3, 4, 5, 9]], boundaries [[0, 4], [4, 8]], right=True + return = [[ 0, 1, 1, 1], [0, 1, 1, 2]] + + Note that in the N-D boundaries case, the shape of "values" and + "boundaries" must match in every dimension _except_ the last. + + When right == False, bucket i refers to range (boundaries[i], boundaries[i+1]]. + When right == True, bucket i refers to range [boundaries[i], boundaries[i+1]). + + Boundaries must be non-decreasing, or a sorter must be provided which + would re-index offsets in a non-decreasing order (e.g. the second output + of torch.sort(offsets)). Otherwise, the result is undefined. + """ + return self.kernel.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..a0532680a8def011eea6335d12e8cbf559e23f22 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,5566 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import math +import operator +import re +import sys +import warnings +from collections.abc import Sequence +from enum import Enum +from typing import Any, Callable, cast, Optional, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._prims_common import is_float_dtype, is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. 
import config, cpp_builder, cpu_vec_isa, ir, metrics +from ..loop_body import LoopBody +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + ExternKernelSchedulerNode, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + has_free_symbols, + is_multi_outputs_template, + is_welford_reduction, + parallel_num_threads, + Placeholder, + set_kernel_post_grad_provenance_tracing, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from ..virtualized import NullKernelHandler, ops, OpsValue, V +from .common import ( + BackendFeature, + BracesBuffer, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) +from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, + cexpr, + cexpr_index, + codegen_rand, + CppCSEVariable, + DTYPE_TO_CPP, + get_promote_dtype, + INDEX_TYPE, + LocalBufferContext, + may_unify_binary_op_mask_type, + promote_args, + template_fusion_with_epilogues_supported, + unify_mask_base_type, + value_to_cpp, +) + + +_IS_WINDOWS = sys.platform == "win32" + + +@functools.cache +def get_export_declaration(): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +NATIVE_OMP_RTYPES = OrderedSet(["+", "*", "^", "||", "min", "max"]) +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = OrderedSet( + [ + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + "argmin", + "argmax", + "any", + ] +) + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "std::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + +VECTORIZABLE_DTYPES: list[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, +] + +MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + + +def reduction_init(reduction_type, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in ("max", "argmax", "min", "argmin"): + cdtype = DTYPE_TO_CPP[dtype] + if dtype == torch.bool and reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[torch.float] + min_var = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + max_var = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + init_var = min_var if 
reduction_type in ("max", "argmax") else max_var + return ( + init_var + if reduction_type in ("max", "min") + else f"IndexValue<{cdtype}>{{0, {init_var}}}" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + if reduction_type in ("argmin", "argmax"): + if dtype == torch.bool: + scalar_type = DTYPE_TO_CPP[torch.float] + return f"IndexValue<{scalar_type}>" + return scalar_type + + +def reduction_combine( + reduction_type, + var, + next_value, + index: Optional[sympy.Symbol] = None, + src_dtype=None, +): + is_bool = src_dtype == torch.bool + if reduction_type == "sum": + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + if reduction_type in ("argmin", "argmax"): + if ( + hasattr(next_value, "dtype") + and next_value.dtype == torch.bool + and not next_value.is_vec + ): + if index is not None: + return f"{reduction_type}_combine({var}, static_cast({next_value}), {index})" + else: + return ( + f"{reduction_type}_combine({var}, static_cast({next_value}))" + ) + if index is not None: + return f"{reduction_type}_combine({var}, {next_value}, {index})" + else: + return f"{reduction_type}_combine({var}, {next_value})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in ("argmin", "argmax"): + return f"{acc}.index" + return acc + + +def move_code_under_inner_loop( + code: IndentedBuffer, + iter_var: sympy.Expr, + new_iter_var: str, + loop_start: sympy.Expr, + loop_end: sympy.Expr, +) -> BracesBuffer: + r""" + f(iter_var) is transformed to f(new_iter_var) under the inner loop + \/ + for (new_iter_var = loop_start; new_iter_var < loop_end; new_iter_var++) { + f(new_iter_var) + } + Please be careful while using this function, + as the variable defined in f(iter_var) will be invalid outside the for loop. + For example: + auto tmp0 = in_ptr[x0]; -> + for (new_x0 = start; new_x0 < end; new_x0++){ + auto tmp0 = in_ptr[new_x0]; + } + The tmp0 is invalid outside the loop. 
+ """ + transformed_code = BracesBuffer() + with contextlib.ExitStack() as stack: + transformed_code.writeline( + f"for ({INDEX_TYPE} {new_iter_var} = {cexpr_index(loop_start)};" + + f"{new_iter_var} < {cexpr_index(loop_end)}; {new_iter_var}++)" + ) + stack.enter_context(transformed_code.indent()) + for _, line in enumerate(code._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + deferred_name = None + if isinstance(line, DeferredLine): + deferred_name = line.name + line = line.line + new_line = re.sub(r"\b" + f"{iter_var}" + r"\b", f"{new_iter_var}", line) + if deferred_name: + new_line = DeferredLine(deferred_name, new_line) # type: ignore[assignment] + transformed_code.writeline(new_line) + return transformed_code + + +def reduction_prefix_array( + acc_var: Union[str, CSEVariable], + acc_type: str, + reduction_type: str, + dtype: torch.dtype, + len: Union[str, int], + init_fn, +): + """ + MSVC don't support dynamic array(VLA). So we use std::unique_ptr here. + Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler + MSVC is the only one compiler without VLA. support. Since MSVC can't get good performance here. + We just use unique_ptr make it works on MSVC. + For other compilers, we continue to use VLA to get best performance. + """ + code_buffer = IndentedBuffer() + acc_decl = ( + f"auto {acc_var}_arr = std::make_unique<{acc_type}[]>({len});" + if cpp_builder.is_msvc_cl() + else f"{acc_type} {acc_var}_arr[{len}];" + ) + code_buffer.writeline(f"{acc_decl}") + code_buffer.writelines( + [ + f"for (int i = 0; i < {len}; i++)", + "{", + f" {acc_var}_arr[i] = {init_fn(reduction_type, dtype)};", + "}", + ], + ) + return code_buffer + + +def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str): + for i, line in enumerate(buffer._lines): + assert isinstance( + line, + ( + str, + DeferredLine, + ), + ) + if isinstance(line, DeferredLine): + line.line = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line.line) + else: + buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line) + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + if not index.has(var): + # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu + # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. + # in this case, there is no dependencies between index and var. + return sympy.S.Zero + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. 
Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor", integer=True) + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus", integer=True) + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) + + +@dataclasses.dataclass +class ParallelDepth: + """ + A class representing parallel depth. + Includes the starting depth of parallelism and the depth of parallelism. 
+ """ + + parallel_depth: int + start_depth: int + + +class OuterLoopFusedSchedulerNode(FusedSchedulerNode): + @classmethod + def fuse( # type: ignore[override] + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth + ): + assert node1.scheduler is node2.scheduler + assert all( + type(node) + in ( + OuterLoopFusedSchedulerNode, + SchedulerNode, + FusedSchedulerNode, + ) + for node in (node1, node2) + ) + if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return cls( + node1.scheduler, + ( + list(node1.get_outer_nodes()) + if type(node1) is OuterLoopFusedSchedulerNode + else [ + node1, + ] + ) + + ( + list(node2.get_outer_nodes()) + if type(node2) is OuterLoopFusedSchedulerNode + else [ + node2, + ] + ), + outer_loop_fusion_depth, + ) + else: + return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] + + def __init__( + self, + scheduler: "Scheduler", + outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]], + outer_loop_fusion_depth, + ): + self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = ( + outer_fused_nodes + ) + self.outer_loop_fusion_depth = outer_loop_fusion_depth + flatten_snodes = [] + for _node in self.outer_fused_nodes: + assert isinstance(_node, (SchedulerNode, FusedSchedulerNode)) + flatten_snodes.extend(list(_node.get_nodes())) + super().__init__(scheduler, flatten_snodes) # type: ignore[arg-type] + + def get_outer_nodes(self): + return self.outer_fused_nodes + + def check_outer_fusion_loop_level_attr( + self, cpp_kernel_proxy_list, outer_loop_fusion_depth + ): + # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth. + # In the fusion stage, we only examine nodes with same vars and reduce. + # However, for nodes with same vars and reduce, the loops may still have different tile splits. + # For example (test_expr_vec_non_contiguous in test_cpu_repro.py): + # * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level. + # If the check failed, we should fall back to standard loop codegen. 
+ def _inner( + left_loop_nest: LoopNest, + right_loop_nest: LoopNest, + loop_fusion_depth: int, + current_checking_depth: int, + ) -> bool: + assert left_loop_nest.loops + assert right_loop_nest.loops + left_loop_level = left_loop_nest.loops[current_checking_depth] + right_loop_level = right_loop_nest.loops[current_checking_depth] + # Check if same loop level attr + outer_loops_attr_compare_list = [ + "var", + "size", + "offset", + "steps", + ] + if not ( + all( + getattr(left_loop_level, attr_compare) + == getattr(right_loop_level, attr_compare) + for attr_compare in outer_loops_attr_compare_list + ) + ): + return False + + assert loop_fusion_depth >= 1 + if (loop_fusion_depth := loop_fusion_depth - 1) > 0: + # Check next loop level attr + current_checking_depth = current_checking_depth + 1 + assert current_checking_depth < len(left_loop_nest.loops) + assert current_checking_depth < len(right_loop_nest.loops) + if not _inner( + left_loop_nest, + right_loop_nest, + loop_fusion_depth, + current_checking_depth, + ): + return False + + return True + + for idx in range(len(cpp_kernel_proxy_list) - 1): + left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest + right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest + if not _inner( + left_loop_nest, + right_loop_nest, + outer_loop_fusion_depth, + 0, + ): + return False + + for cpp_kernel_proxy in cpp_kernel_proxy_list: + outer_ranges = functools.reduce( + operator.mul, + cpp_kernel_proxy.ranges[:outer_loop_fusion_depth], + ) + # When the range of the first inner loop is much larger than the range of + # all outer loops, do not fuse outer loop and fallback to standard loop codegen, + # so that the inner loops with larger range have a chance to be parallelized. + # We set a conservative threshold here: + # First inner loop range / all outer loops range > 300. 
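+            # For example (illustrative): outer ranges 2 * 2 = 4 with a first inner range
+            # of 10000 gives 4 * 300 = 1200 < 10000, so we return False here and fall back
+            # to standard loop codegen.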
+ if ( + len(cpp_kernel_proxy.ranges) > outer_loop_fusion_depth + and isinstance(outer_ranges, sympy.Integer) + and isinstance( + cpp_kernel_proxy.ranges[outer_loop_fusion_depth], + sympy.Integer, + ) + and outer_ranges * 300 + < cpp_kernel_proxy.ranges[outer_loop_fusion_depth] + ): + return False + + return True + + def merge_outer_fusion_kernels( + self, + cpp_kernel_proxy_list, + ): + kernel_group = cpp_kernel_proxy_list[0].kernel_group + outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) + outer_loop_fused_kernel.inner = [ + proxy.loop_nest.from_loop_level(self.outer_loop_fusion_depth) + for proxy in cpp_kernel_proxy_list + ] + outer_fused_proxy = cpp_kernel_proxy_list[0] + outer_fused_proxy.loop_nest.kernel = outer_loop_fused_kernel + outer_fused_proxy.loop_nest.loops = outer_fused_proxy.loop_nest.loops[ + : self.outer_loop_fusion_depth + ] + return outer_fused_proxy + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +def decltype_promoted(*args): + assert not any(isinstance(arg, CppCSEVariable) and arg.is_vec for arg in args), ( + "Promotion of vector types is not supported" + ) + + if (dt := get_promote_dtype(args)) is not None: + return DTYPE_TO_CPP[dt] + else: + return f"decltype({args[0]})" + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"{decltype_promoted(a, b)}({a} + {b})" + + @staticmethod + def sub(a, b): + return f"{decltype_promoted(a, b)}({a} - {b})" + + @staticmethod + def mul(a, b): + return f"{decltype_promoted(a, b)}({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): + assert isinstance(x, CppCSEVariable) + if src_dtype is None: + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... 
+ node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node1_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_output_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node1_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + """ + On Windows std::signbit only supports float type. + Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 + """ + return ( + f"std::signbit(static_cast<float>({x}))" + if _IS_WINDOWS + else f"std::signbit({x})" + ) + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def log2(x): + return f"std::log2({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + idx_str = cexpr(V.kernel.rename_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})(0);") + code.writeline( + f"return decltype({a})(static_cast>({a}) << {b});" + ) + code.writeline("()") + return code + + @staticmethod + def bitwise_right_shift(a, b): + code = BracesBuffer() + code.writeline("[&]()") + with code.indent(): + scalar_t = DTYPE_TO_CPP[a.dtype] + code.writeline( + f"constexpr decltype({b}) max_shift = sizeof({scalar_t}) * CHAR_BIT - std::is_signed_v<{scalar_t}>;" + ) + code.writeline( + f"if ((static_cast>({b}) < 0) || ({b} >= max_shift))" + ) + with code.indent(): + code.writeline(f"return decltype({a})({a} >> max_shift);") + code.writeline(f"return decltype({a})({a} >> {b});") + code.writeline("()") + return code + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? 
{scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. + def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, (int, sympy.Expr)) + or (isinstance(arg, CppCSEVariable) and not arg.is_vec) + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + new_args = [] + for arg in args: + if isinstance(arg, (int, sympy.Expr)): + if isinstance(arg, sympy.Expr) and not arg.is_number: + arg = ops.index_expr(arg, torch.int64) + else: + arg = ops.constant(arg, torch.int64) + arg = arg.value if isinstance(arg, OpsValue) else arg + new_args.append(arg) + + # DType Promotion + if vectors: + # We have saw several data type mismatch issues related with index_expr in + # the lowering phase of torch.int8. torch.int32, torch.int64. + # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu + # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu + # 3. 
int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu + if len(new_args) == 2: + new_args = promote_args(new_args) + elif func == CppVecOverrides.where: + new_args[1:] = promote_args(new_args[1:]) + + # Broadcast scalar args to vector + if scalars and vectors: + assert isinstance(V.kernel, CppVecKernel) + new_args = [ + ( + V.kernel.broadcast(new_arg) + if ( + isinstance(new_arg, CppCSEVariable) + and not new_arg.is_vec + and func + not in [ + CppVecOverrides.rand, + CppVecOverrides.randn, + CppVecOverrides.randint64, + ] + ) + else new_arg + ) + for new_arg in new_args + ] + + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr(scalar_ops, func.__name__) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" + + @staticmethod + def ne(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + + @staticmethod + def lt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" + + @staticmethod + def gt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" + + @staticmethod + def le(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" + + @staticmethod + def ge(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + 
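+    # Illustrative note (added comment, not upstream code): these overrides emit
+    # strings over at::vec::Vectorized values rather than scalars, e.g. ops.abs on a
+    # vectorized CSE variable tmp0 lowers to "tmp0.abs()" and ops.add(tmp0, tmp1)
+    # lowers to "tmp0 + tmp1"; scalar arguments are first broadcast to vectors by the
+    # wrapper installed in __new__ above, which otherwise falls back to the scalar
+    # CppOverrides when no vector inputs are present.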
+ @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"~{a}" + + @staticmethod + def logical_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_and(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + a, b = may_unify_binary_op_mask_type(a, b) + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def load_seed(name, offset): + assert isinstance(V.kernel, CppVecKernel) + return f"{V.kernel.load(name, offset)}" + + @staticmethod + def rand(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = ( + f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" + ) + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randn(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randint64(seed, offset, low, high): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" + return codegen_rand(offset, code, rand_function, torch.int64) + + @staticmethod + def remainder(a, b): + assert a.dtype == b.dtype, ( + "remainder vec implementation expect the same inputs' dtype." 
+ ) + return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + if config.cpp.use_decompose_tanh: + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return ( + f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + ) + else: + return f"{a}.tanh()" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def log2(x): + return f"{x}.log2()" + + @staticmethod + def nextafter(x, y): + return f"{x}.nextafter({y})" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + return f"{x}.asinh()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + if is_float_dtype(a.dtype): + assert a.dtype == b.dtype, ( + "div_floor_floating_vec implementation expect the same inputs' dtype." 
+ ) + return f"div_floor_floating_vec({a}, {b})" + else: + assert all(is_integer_dtype(item.dtype) for item in [a, b]) + # a and b are integer type + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({b})" + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(V.kernel, CppVecKernel) + if b.dtype == torch.bool: + assert c.dtype == torch.bool + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + else: + return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): + assert dtype in [ + torch.bool, + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], f"{__name__} does not support {dtype}" + assert isinstance(x, CppCSEVariable) + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in DTYPE_LOWP_FP and src_dtype == torch.float: + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), 
code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + dtype = result.dtype + body_code = f"{var}()" + + def maskify_or_vecify(code): + return ( + f"{V.kernel._get_mask_type()}::from({code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({code})" + ) + + if result.is_vec: + body_code_vec = body_code + else: + body_code_vec = maskify_or_vecify(body_code) + other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) + # loading bool as VecMask + other_code_vec = maskify_or_vecify(other_code) + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec: + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if ({new_mask}.all_zero())") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + # Create cse variable to reuse kernel.overrides.where + body_vec_var = V.kernel.cse.generate( + V.kernel.compute, + body_code_vec, + ) + other_vec_var = V.kernel.cse.generate( + V.kernel.compute, + other_code_vec, + ) + assert isinstance(body_vec_var, CppCSEVariable), body_vec_var + assert isinstance(other_vec_var, CppCSEVariable), other_vec_var + body_vec_var.dtype = dtype + other_vec_var.dtype = dtype + overrides: type[Union[CppOverrides, CppVecOverrides]] = ( + V.kernel.overrides + ) # type: ignore[has-type] + code.writeline( + f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + elif result.is_vec: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
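+        # (Illustration, added comment: if the masked body produced a vectorized value,
+        # passing `result` along here lets the new csevar inherit its vectorization
+        # status and the itervars it depends on, mirroring the propagation done for
+        # other ops.)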
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = V.kernel._try_get_const_stride(index, tiling_var) + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + elif stride is not None: + idx = V.kernel.cse.generate( + V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) + ) + value = ops.to_dtype(idx, dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment] + None, index, dtype, V.kernel.compute + ) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): + return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.put(cache_key, cse_var) + return mantissa, exponent + + @classmethod + def _scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + 
for argidx, arg in enumerate(args): + if isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + vec_vars = vars(CppVecOverrides) + for name, method in vars(CppOverrides).items(): + if isinstance(method, staticmethod) and name not in vec_vars: + func = cls._scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + # Indicate when this kernel is active, for example + # {x0, {24, 26}} -> this kernel is active when x0 >= 24 and x0 < 26 + self.active_ranges: dict[sympy.Expr, tuple[sympy.Expr, ...]] = {} + # Indicate this kernel will be moved under the inner for-loop + # See move_code_under_inner_loop + self.inner_itervars: list[sympy.Symbol] = [] + self.call_ranges: Optional[tuple[sympy.Expr, ...]] = None + self.ranges: list[sympy.Expr] = [] + self.itervars: list[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + # We need this because when we run "reduction" nodes here, we lack + # "loop" information to decide whether we need a scalar init or an array init + # in the reduction prefix. Meanwhile, we have other information like + # reduction types and dtype to generate the reduction prefix. We record the information + # with a callable lambda function, and when we have enough information to finalize + # the reduction prefix, we can invoke the functions here with additional information. 
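+        # Illustrative example (added comment): for a simple sum reduction the deferred
+        # generator can later emit either a scalar init such as "float tmp_acc0 = 0;"
+        # or an array form "float tmp_acc0_arr[size];" initialized in a loop, once the
+        # surrounding loop structure is known; see _gen_reduction_prefix and
+        # finalize_reduction_prefix below.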
+ self.reduction_prefix_generators: list[Callable] = [] # type: ignore[type-arg] + self.reduction_suffix = IndentedBuffer() + self.parallel_reduction_prefix = IndentedBuffer() + self.parallel_reduction_suffix = IndentedBuffer() + self.local_reduction_init = IndentedBuffer() + self.local_reduction_stores = IndentedBuffer() + self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() + self.non_parallel_reduction_suffix = IndentedBuffer() + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.welford_helper_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="welford_helper" + ) + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: dict[tuple[str, str], str] = {} + self.reduction_var_names: list[str] = [] + + def _gen_parallel_reduction_buffers( + self, + acc, + acc_type, + reduction_type, + dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ): + if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: + self.parallel_reduction_prefix.writeline( + "int max_threads = omp_get_max_threads();" + ) + acc_local = f"{acc}_local" + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + acc_local_in_array = f"{acc}_arr[tid]" + self.local_reduction_init.writeline( + f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};" + ) + self.parallel_reduction_prefix.splice( + reduction_prefix_array( + acc, + acc_type, + reduction_type, + dtype, + num_threads, + reduction_init_fn, + ) + ) + self.local_reduction_stores.writeline(f"{acc_local_in_array} = {acc_local};") + self.parallel_reduction_suffix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};", + "}", + ], + ) + + def update_stores_with_parallel_reduction(self): + for var_name in self.reduction_var_names: + replace_acc_name(self.stores, var_name, f"{var_name}_local") + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is None + code = BracesBuffer() + with contextlib.ExitStack() as stack: + if hasattr(self, "codegen_inner_loops"): + code.splice(self.preloads) + self.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(self.loads) + code.splice(self.compute) + code.splice(self.stores) + if hasattr(self, "codegen_inner_loops"): + code.splice(self.poststores) + + if self.inner_itervars: + for idx in self.inner_itervars: + start, end = self.active_ranges[idx] + code = move_code_under_inner_loop(code, idx, f"{idx}_tail", start, end) + return code + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + 
Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. + """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def var_ranges(self): + return dict(zip(self.itervars, self.ranges)) + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + indirect = free_symbol_is_type(expr, SymT.TMP) + if indirect: + # indexing in compute + csevar = ops.index_expr(expr, torch.int64).value + buffer = V.kernel.compute + else: + # indexing in loads + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads + + size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None + + line = self.indirect_assert( + csevar, "0" if lower else None, size_str, self._load_mask + ) + self.cse.generate(buffer, line, assignment=False) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + csevar = self.cse.generate(self.loads, line, dtype=V.graph.get_dtype(name)) + csevar.update_on_args("load", (self, name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def _gen_reduction_prefix( + self, + acc: Union[CSEVariable, str], + acc_type: str, + rtype: str, + dtype: torch.dtype, + init_fn, + ): + # Generate reduction prefix + # If size is None, we will define and initialize a single reduction variable + # => float tmp_acc0 = 0; + # Otherwise, we will define and initialize a reduction array + # => float tmp_acc0_arr[size]; + # => for (int i = 0; i < size; i++) tmp_acc0_arr[i] = 0; + def inner(size: Optional[int] = None): + if size is None: + return f"{acc_type} {acc} = {init_fn(rtype, dtype)};" + else: + return reduction_prefix_array( + acc, + acc_type, + rtype, + dtype, + size, + init_fn, + ) + + return inner + + def finalize_reduction_prefix(self, size: Optional[int] = None): + for gen_fn in self.reduction_prefix_generators: + self.reduction_prefix.splice(gen_fn(size)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in ("argmax", "argmin") + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.reduction_var_names.append(f"{acc}") + self.is_reduction = True + init_dtype = src_dtype if argmax_or_argmin else dtype + acc_type = reduction_acc_type(reduction_type, init_dtype) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" + ) + + self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), ( + f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + ) + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + assert self.call_ranges is not None + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + assert isinstance(self, CppKernelProxy) + threads = parallel_num_threads() + assert self.call_ranges is not None + if isinstance(loop_nest.kernel, 
OuterLoopFusedKernel): + par_depth = loop_nest.kernel.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + else: + par_depth = self.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + + is_reduction_loop = ( + loop_nest.loops is not None + and loop_nest.loops[par_depth.start_depth].is_reduction + ) + with contextlib.ExitStack() as stack: + if par_depth.parallel_depth: + if is_reduction_loop: + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_kernel(_loop_nest: LoopNest): + def is_parallel_reduction(): + assert _loop_nest.loops + root = _loop_nest.loops[par_depth.start_depth] + return root.is_reduction and root.parallel + + kernel = _loop_nest.get_kernel() + if isinstance(kernel, OuterLoopFusedKernel): + for _loop_nest in kernel.inner: + gen_loop_nest(_loop_nest) + else: + assert isinstance(kernel, CppKernelProxy) + if _loop_nest.loops is not None and is_parallel_reduction(): + kernel.update_stores_with_parallel_reduction() + with contextlib.ExitStack() as stack: + stack.enter_context(code.indent()) + kernel.gen_body(code) + + def get_reduction_prefix_suffix(kernel, parallel=False, is_suffix=False): + if is_suffix: + suffix = kernel.reduction_suffix + if parallel: + suffix = kernel.parallel_reduction_suffix + suffix + else: + suffix = kernel.non_parallel_reduction_suffix + suffix + return suffix + else: + prefix = kernel.reduction_prefix + if parallel: + prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix + return prefix + + def gen_loop_with_reduction( + _loop_nest: LoopNest, depth: int = 0, in_reduction=False + ): + kernel = _loop_nest.get_kernel() + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + with contextlib.ExitStack() as stack_outer: + if loop.is_reduction and not in_reduction: + reduction_prefix = get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=False + ) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if is_reduction_loop and loop.parallel: + worksharing.parallel(threads) + if kernel.local_reduction_init: + assert kernel.local_reduction_stores + code.splice(kernel.local_reduction_init) + + gen_loop_at(_loop_nest, depth) + + if is_reduction_loop and loop.parallel: + if kernel.local_reduction_stores: + code.splice(kernel.local_reduction_stores) + worksharing.close() + if loop.is_reduction and not in_reduction: + code.splice( + get_reduction_prefix_suffix( + kernel, loop.parallel, is_suffix=True + ) + ) + + def gen_loop_at(_loop_nest: LoopNest, depth: int = 0): + with contextlib.ExitStack() as stack: + assert _loop_nest.loops + loop = _loop_nest.loops[depth] + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + gen_loop_nest(_loop_nest, depth + 1, loop.is_reduction) + + def gen_loop_nest( + _loop_nest: LoopNest, + depth: int = 0, + in_reduction: bool = False, + ): + if _loop_nest.loops is None or depth == len(_loop_nest.loops): # type: ignore[arg-type] + gen_kernel(_loop_nest) + else: + gen_loop_with_reduction(_loop_nest, depth, in_reduction) + + stack.enter_context(code.indent()) + + if ( + isinstance(loop_nest.kernel, OuterLoopFusedKernel) + and isinstance(V.local_buffer_context, LocalBufferContext) + and 
V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + for local_buffer in local_buffers.values(): + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" + local_buffer_name = local_buffer.get_name() + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" + ) + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" + ) + gen_loop_nest(loop_nest) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNest.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, max_parallel_depth, threads): + assert self.call_ranges is not None + ranges = self.call_ranges[ + max_parallel_depth.start_depth : ( + max_parallel_depth.start_depth + max_parallel_depth.parallel_depth + ) + ] + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. + if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return ParallelDepth( + parallel_depth=depth, start_depth=max_parallel_depth.start_depth + ) + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + def get_to_dtype_expr(self, src, dtype, src_dtype): + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" + + def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): + expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) + self.cse.put(expr, dst) + + def codegen_conditions( + self, + code: BracesBuffer, + prefix: Optional[str] = None, + var: Optional[sympy.Symbol] = None, + ): + if prefix is None: + prefix = "" + if not self.active_ranges: + return True + conditions = [] + + def gen(start, end, var): + if start == end: + return False + var_id = None + for i, _var in enumerate(self.itervars): + if var == _var: + var_id = i + break + if ( + type(self) == CppKernel + and var_id + and start == 0 + and end == self.ranges[var_id] + ): + end = 1 + conditions.append(f"{var} >= {cexpr_index(start)}") + conditions.append(f"{var} < {cexpr_index(end)}") + return True + + if var is not None: + assert var in self.active_ranges + start, end = self.active_ranges[var] + if not gen(start, end, var): + return False + else: + for _var, _range in self.active_ranges.items(): + start, end = _range + if not gen(start, 
end, _var): + return False + joined_conditions = " && ".join(conditions) + if joined_conditions: + code.writeline(f"if({prefix}({joined_conditions}))") + return True + else: + return False + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_idx, + tail_size=None, + ): + super().__init__(args, num_threads) + self.vec_isa = cpu_vec_isa.pick_vec_isa() + assert self.vec_isa + assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + self.tail_size = tail_size + self.num_elems = tail_size if tail_size else tiling_factor + + def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): + if self.index_indirect_depends_on(index, itervar): + return None + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + return None + stride = stride_at_vec_range(index, itervar, self.tiling_factor) + return stride if stride.is_number else None + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_raw_num_vectors(self, dtype: torch.dtype) -> float: + # This utility function is used to check if the vector lanes has been + # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. + return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: + if dtype == torch.bool: + return "" + num_vectors = self._get_num_vectors(dtype) + return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: + assert mask.dtype == torch.bool, repr(mask) + num_vectors = self._get_num_vectors(dtype) + return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + cpp_type = DTYPE_TO_CPP[dtype] + num_vectors = self._get_num_vectors(dtype) + load_mask_str = None + if load_mask: + if not load_mask.is_vec: + # TODO: avoid hard-code torch.float + load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" + else: + load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.bool: + # TODO: should we consider load mask here? 
+ line = f"{self._get_mask_type()}::from({loadbuf})" + else: + line = ( + f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" + ) + return line + + def _load_or_store_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, + ) -> Optional[CppCSEVariable]: + """ + Load or store a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. + :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided + :return: a CppCSEVariable that represents the loaded vector or None if it is a store. + """ + assert not store_value or var is not None, "store var must be provided" + if accu_store: + assert store_value + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.num_elems * (4 // dtype.itemsize) + else: + return self.num_elems + + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) + result_declare = ( + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" + ) + code.writeline(result_declare) + if store_value: + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, 
CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + load_mask = None + if self._load_mask is not None: + assert not store_value, "unexpected store with load mask" + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = f"{self._load_mask}.is_masked({itervar_inner})" + else: + load_mask = f"{self._load_mask} != 0" + if cpp_builder.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + + f"{itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + index_c = cexpr_index(index) + for indirect_var in replacements: + index_c = re.sub( + r"\b" + f"{indirect_var}" + r"\b", + replacements[indirect_var], + index_c, + ) + rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + if store_value: + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") + else: + code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + if not store_value: + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + if store_value: + code.writeline(";") + buffer.splice(code) + return None + else: + csevar = self.cse.generate(buffer, code, dtype=dtype) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = self._try_get_const_stride(index, tiling_var) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + elif stride == 1: + # load contiguously + line = self._get_vec_load_line(var, index, dtype, self._load_mask) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) # type: ignore[assignment] + else: + csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (self, name, index), {}) + csevar.is_vec = True + return csevar + + def _get_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + accu_store: bool = False, + ): + """ + Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles + both contiguous and non-contiguous store cases. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. 
+ """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + var_expr = f"{var} + {cexpr_index(index)}" + stride = self._try_get_const_stride(index, tiling_var) + code = IndentedBuffer() + if stride == 1: + if accu_store: + load = ( + f"{self._get_vec_type(dtype)}::loadu({var_expr})" + if dtype == torch.float and self.tail_size is None + else f"{self._get_vec_type(dtype)}::loadu({var_expr}, {cexpr_index(self.num_elems)})" + ) + value = f"({value} + {load})" + if dtype == torch.float and self.tail_size is None: + code.writeline(f"{value}.store({var_expr});") + else: + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) + else: + self._load_or_store_non_contiguous( + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store + ) + return code + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + var = self.args.output(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert isinstance(index, CppCSEVariable) and index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") + + def reduction(self, dtype, src_dtype, reduction_type, value): + # Note: For argmax and argmin on bool type, we always convert bool to float. 
+ # Fix issue: https://github.com/pytorch/pytorch/issues/143568 + assert reduction_type in VECTORIZABLE_RTYPES + argmax_or_argmin = reduction_type in ("argmax", "argmin") + horizontal_reduction = self.tiling_idx >= self.reduction_depth + init_dtype = src_dtype if argmax_or_argmin else dtype + assert isinstance(value, CppCSEVariable), value + + if not value.is_vec: + value = self.broadcast(value) + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + acc_type = reduction_acc_type(reduction_type, init_dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + assert isinstance(acc, CppCSEVariable) + acc_vec = f"{acc}_vec" + masked_acc_vec = f"masked_{acc_vec}" + self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec] + self.is_reduction = True + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc, acc_type, reduction_type, init_dtype, reduction_init + ) + ) + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + self.reduction_init_vec, + ) + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self.reduction_prefix_generators.append( + self._gen_reduction_prefix( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + self.reduction_init_vec, + ) + ) + + # use welford_helper for vec kernel + assert self.reduction_depth is not None + reduction_size = functools.reduce( + operator.mul, self.ranges[self.reduction_depth :] + ) + welford_helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + masked_welford_helper_val = f"masked_{welford_helper_val}" + welford_helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) + else sympy.Integer(0) + ) + masked_welford_helper_vec_range = ( + ( + FloorDiv(reduction_size, self.ranges[self.tiling_idx]) + if self.tiling_idx >= self.reduction_depth + else reduction_size + ) + if self.ranges[self.tiling_idx] % self.tiling_factor + else sympy.Integer(0) + ) + self._use_welford_helper( + acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ) + self._use_welford_helper( + masked_acc_vec, + masked_welford_helper_val, + masked_welford_helper_vec_range, + dtype, + ) + + # use masked acc_vec for tail vec kernel + acc_vec_ = masked_acc_vec if self.tail_size else acc_vec + welford_helper_val_ = ( + masked_welford_helper_val if self.tail_size else welford_helper_val + ) + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, welford_helper_val_)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + kwargs = { + "next_value": value, + "index": index, + "horizontal_reduction": horizontal_reduction, + "src_dtype": src_dtype, + } + self.stores.writeline( + f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, 
**kwargs)};" + ) + self._gen_parallel_reduction_buffers( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + self._gen_parallel_reduction_buffers( + acc, + acc_type, + reduction_type, + init_dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self._gen_parallel_reduction_buffers( + masked_acc_vec, + acc_type_vec, + reduction_type, + dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + tmpvar: Union[str, CSEVariable] + is_bool = dtype == torch.bool + if horizontal_reduction: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert self._get_num_vectors(dtype) in [ + 1, + 2, + ], "Welford reduction does not support VectorizedN (N>2)" + next_value = f"welford_vec_reduce_all({acc_vec})" + masked_next_value = f"welford_vec_reduce_all({masked_acc_vec})" + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" + ) + elif argmax_or_argmin: + next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" + elif is_bool: + if reduction_type in ( + "any", + "sum", + "max", + ): + next_value = f"!{acc_vec}.all_zero()" + else: + assert reduction_type == "min" + next_value = f"{acc_vec}.all_masked()" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + is_bool = dtype == torch.bool + # we are using at::vec::VecMask for bool + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + if is_welford_reduction(reduction_type): + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" + ) + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + dtype = ( + (out_dtype if out_dtype == torch.double else torch.float) + if out_dtype.is_floating_point + else torch.int64 + ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) + code = IndentedBuffer() + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + code.writeline( + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" + ) + else: + # Vertical reduction + if out_dtype != dtype: + converted_value = ( + f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}" + ) + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + 
code.writeline(f"auto {converted_value} = {convert};") + value = converted_value + code.splice(self._get_store_line(value, var, index, out_dtype)) + self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: + assert not index.is_vec + assert index.dtype is not None + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = index.dtype + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + if reduction_type in ("argmin", "argmax"): + cdtype = DTYPE_TO_CPP[scalar_type] + acc_type = self.reduction_acc_type_vec(reduction_type, dtype) + if reduction_type == "argmin": + val = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + else: + val = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + return f"{acc_type}({val})" + + if reduction_type == "any": + return f"{self._get_mask_type()}::from(0)" + + scalar_init = reduction_init(reduction_type, dtype) + vec_init = f"{vec_type}({scalar_init})" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "sum") + return f"{self._get_mask_type()}::from({scalar_init})" + return vec_init + + def reduction_acc_type_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + if reduction_type in ("argmin", "argmax"): + n_src = self._get_num_vectors(scalar_type) + n_idx = self._get_num_vectors(torch.int64) + if dtype == torch.bool: + return f"IndexValueVec<{DTYPE_TO_CPP[torch.float]}, {n_src}, {n_idx}>" + return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "any", "sum") + return f"{self._get_mask_type()}" + return vec_type + + def _welford_helper_init( + self, welford_helper_val, welford_helper_vec_range, dtype, num_threads=None + ): + vec_num_range_thread = ( + CeilDiv(welford_helper_vec_range, num_threads) + if num_threads + else welford_helper_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + chunk_size = 4096 + num_chunks = CeilDiv(vec_num_range_thread, chunk_size) + welford_helper_init_line = ( + f"WelfordHelper<{self._get_vec_type(dtype)}, {chunk_size}> {welford_helper_val}" + f"(" + f"{vec_num_range_thread_expr}" + f");" + ) + if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: + # When the number of chunks <= 1, there 
is no need to use cascade summation to improve + # reduction accuracy. We can initialize a static WelfordHelper to improve performance. + return f"static {welford_helper_init_line}" + else: + return welford_helper_init_line + + def _use_welford_helper( + self, acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ): + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + self.non_parallel_reduction_prefix.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype + ) + ) + self.local_reduction_init.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype, num_threads + ) + ) + self.non_parallel_reduction_suffix.writeline( + f"{acc_vec} = welford_combine({acc_vec}, &{welford_helper_val});" + ) + self.local_reduction_stores.writeline( + f"{acc_vec}_local = welford_combine({acc_vec}_local, &{welford_helper_val});" + ) + + def reduction_combine_vec( + self, + reduction_type, + var, + next_value, + welford_helper_val=None, + index: Optional[sympy.Symbol] = None, + horizontal_reduction: Optional[bool] = None, + src_dtype: Optional[torch.dtype] = torch.float32, + ): + is_bool = src_dtype == torch.bool + if reduction_type == "max": + if self.tail_size: + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} | {next_value}" + if is_bool + else f"at::vec::maximum({var}, {next_value})" + ) + elif reduction_type == "min": + if self.tail_size: + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} & {next_value}" + if is_bool + else f"at::vec::minimum({var}, {next_value})" + ) + elif reduction_type == "sum": + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + elif reduction_type == "prod": + if self.tail_size: + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + if self.tail_size: + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + if welford_helper_val: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{welford_helper_val})" + else: + return ( + f"welford_combine({var}, {next_value}, &{welford_helper_val})" + ) + else: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + if self.tail_size: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + elif reduction_type in ("argmin", "argmax"): + assert src_dtype is not None + cdtype = DTYPE_TO_CPP[src_dtype] + if src_dtype == torch.bool: + cdtype = DTYPE_TO_CPP[torch.float] + n_src = self._get_num_vectors(src_dtype) + n_idx = self._get_num_vectors(torch.int64) + 
t_extra = "" + arg_extra = "" + if index is not None: + assert horizontal_reduction is not None + t_extra = f", {str(horizontal_reduction).lower()}" + arg_extra = f", {index}" + if self.tail_size: + return ( + f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" + ) + else: + return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" + elif reduction_type == "any": + if isinstance(next_value, CppCSEVariable): + assert next_value.dtype == torch.bool + (next_value,) = unify_mask_base_type(V.kernel.compute, (next_value,)) + return f"{var} | {next_value}" + else: + raise NotImplementedError + + def indirect_assert(self, var, lower, upper, mask=None): + assert isinstance(var, CppCSEVariable) + assert var.dtype is not None + if not var.is_vec: + if isinstance(mask, CppCSEVariable) and mask.is_vec: + mask = f"({mask}).all_masked()" + return super().indirect_assert(var, lower, upper, mask) + lower_scalar = lower + upper_scalar = upper + if lower: + lower = f"{self._get_vec_type(var.dtype)}({lower})" + if upper: + upper = f"{self._get_vec_type(var.dtype)}({upper})" + if lower and upper: + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = f"{lower_scalar} <= {var}" + else: + assert upper + cond = f"{var} < {upper}" + cond_print = f"{var} < {upper_scalar}" + cond = f"{self._get_mask_type(var.dtype)}({cond})" + if mask: + if not mask.is_vec: + mask = f"{self._get_mask_type(var.dtype)}({mask})" + # We need not check when the mask is False + cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) + cond = f"({cond}).all_masked()" + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def get_to_dtype_expr(self, src, dtype, src_dtype): + assert isinstance(src, CppCSEVariable) + if not src.is_vec: + return super().get_to_dtype_expr(src, dtype, src_dtype) + src_cpp_type = DTYPE_TO_CPP[src_dtype] + src_num_vectors = self._get_num_vectors(src_dtype) + dst_cpp_type = DTYPE_TO_CPP[dtype] + dst_num_vectors = self._get_num_vectors(dtype) + expr = f"({src})" + if src_dtype != torch.bool and dtype == torch.bool: + expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})" + elif src_dtype == torch.bool and dtype != torch.bool: + expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()" + elif src_dtype != dtype: + if src_num_vectors == dst_num_vectors == 1: + expr = f"at::vec::convert<{dst_cpp_type}>({src})" + else: + expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})" + return expr + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. 
+ + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... + """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_indices, + inner_tail_size=None, + outer_tail_size=None, + ): + super().__init__( + args, + num_threads, + tiling_factor, + tiling_indices[1], + inner_tail_size, + ) + self.tiling_indices = tiling_indices + self.inner_tail_size = inner_tail_size + self.outer_tail_size = outer_tail_size + self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor + self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor + self.inner_is_tiling_idx = True + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store( + self, name, var, index, is_store, store_mode=None + ): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{cexpr_index(self.num_elems)}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + atomic_add = "true" if (is_store and (store_mode == "atomic_add")) else "false" + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" + ) + else: + load_or_store = ( + f"transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)},{atomic_add}>" + f"({src}, {ld_src}, {dst}, {ld_dst});" + ) + if is_store: + tile_var = self.cse.newvar() + elif not self.cse.contains(load_or_store): + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.get(load_or_store) + + if need_define: + cpp_dtype = DTYPE_TO_CPP[dtype] + # tiling_factor might be smaller than the alignment of cpp_dtype, such as + # with a vector that only holds 4 elements due to NEON 128-bit vectors and + # cpp_dtype being a 64-bit integer. 
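+ # e.g. for a 4-element tile of int64, the definition below becomes (buffer name illustrative):
+ #   alignas(std::max(std::size_t(4), alignof(int64_t))) int64_t tmp0[4*4];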
+ alignas = f"alignas(std::max(std::size_t({factor}), alignof({cpp_dtype})))" + define_line = f"{alignas} {cpp_dtype} {tile_var}[{factor}*{factor}];" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line, dtype=dtype) + csevar.update_on_args("load", (self, name, index), {}) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True, store_mode=mode + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ + torch.uint8, + torch.int8, + ]: + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + if self.inner_is_tiling_idx: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" + ) + else: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + if self.tiling_idx == self.tiling_indices[0]: + self.tail_size = self.outer_tail_size + self.num_elems = self.outer_num_elems + self.inner_is_tiling_idx = False + else: + self.tail_size = self.inner_tail_size + self.num_elems = self.inner_num_elems + self.inner_is_tiling_idx = True + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +def get_loop_body_lowp_fp(_body: LoopBody) -> tuple[Optional[torch.dtype], bool]: + """ + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to 
float. + Otherwise returns None and True. + """ + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + + _lowp_fp_type: Optional[torch.dtype] = None + _use_fp32 = False + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + _use_fp32 = True + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") + else: + _lowp_fp_type = opt_ctx.dtype + else: + _use_fp32 = True + + return _lowp_fp_type, _use_fp32 + + +class TilingSelect: + """ + Implement the heuristic to select the tiling factors and tiling indices. + In the future, we can implement advanced heuristic in a subclass. + """ + + def __init__(self): + super().__init__() + + def select_tiling( + self, + fn_list, + var_sizes_list, + ) -> tuple[list[int], list[int]]: + # TODO(jgong5): support alternative tiling factors and data types + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] + dtype = torch.float + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] + if _lowp_fp_dtype and all( + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) + for loop_body in loop_bodies[1:] + ): + dtype = _lowp_fp_dtype + + tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + tiling_indices = self._select_tiling_indices( + fn_list, var_sizes_list, tiling_factor + ) + + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + + if config.cpp.enable_tiling_heuristics: + + def _try_get_stride( + index, + itervars, + tiling_factor, + tiling_indices, + ): + itervar = itervars[tiling_indices[0]] + stride = stride_at_vec_range(index, itervar, tiling_factor) + return stride if stride.is_number else None + + def _update_negative_op_count( + node_name, non_contig_indexing_op_counter + ): + if node_name not in non_contig_indexing_op_counter: + non_contig_indexing_op_counter[node_name] = 1 + else: + non_contig_indexing_op_counter[node_name] += 1 + + def _is_valid_indices( + itervars, + tiling_indices, + ): + return ( + len(tiling_indices) == 1 + and len(itervars) > 0 + and ( + tiling_indices[0] + if tiling_indices[0] >= 0 + else tiling_indices[0] + len(itervars) + ) + < len(itervars) + ) + + itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(call_ranges)) + ] + reduction_depth = len(group) + vars, reduction_vars = ( + itervars[:reduction_depth], + itervars[reduction_depth:], + ) + op_counter: dict[str, int] = {} + # ops may cause overhead with vectorization, like non-contiguous + # index_expr, load, store + non_contig_indexing_op_counter: dict[str, int] = {} + for _body in loop_bodies: + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.target in ["index_expr", "load", 
"store"]: + # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 + index = sub_block.body.indexing_from_args( + (vars, reduction_vars) + )[_node.args[arg_idx].args[0]] + if _is_valid_indices(itervars, tiling_indices): + stride = _try_get_stride( + index, itervars, tiling_factor, tiling_indices + ) + if ( + stride is None + if _node.target == "index_expr" + else stride not in [0, 1] + ): + _update_negative_op_count( + _node.target, non_contig_indexing_op_counter + ) + if isinstance(_node.target, str) and not ( + _node.target.startswith("masked_subblock") + or _node.target + in ["ops", "output", "constant", "get_index"] + ): + if _node.target not in op_counter: + op_counter[_node.target] = 1 + else: + op_counter[_node.target] += 1 + + op_num = sum(op_counter.values()) + non_contig_indexing_op_num = sum( + non_contig_indexing_op_counter.values() + ) + ratio_threshold = 0.12 + quantity_threshold = 35 + if non_contig_indexing_op_num >= quantity_threshold or ( + op_num > 0 + and non_contig_indexing_op_num / op_num >= ratio_threshold + ): + # Too many non-contiguous load/store/index_expr which hurts the + # vectorization performance. Disable vectorization when exceeding + # the thresholds. + return [], [] + + if ( + not reduction_group + and group + and len(tiling_indices) == 1 + and not has_free_symbols( + [ + group[tiling_indices[0]], + ] + ) + and group[tiling_indices[0]] < tiling_factor / 4 + and op_num < 10 + ): + # We found that when the number of elements in the inner loop range is + # relatively small(< tiling_factor / 4) and the number of operations is + # not large(< 10), vectorization is not efficient. + # And found that `#pragma GCC ivdep` has better performance than + # `#pragma omp simd simdlen(8)` for these cases. 
+ return [], [] + + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type] + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + def _select_tiling_indices( + self, + fn_list, + var_sizes_list, + tiling_factor, + ): + all_index = [] + for fn, var_sizes in zip(fn_list, var_sizes_list): + rw = dependencies.extract_read_writes(fn, *var_sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = OrderedSet[int]() + contig_vars_list = [] + non_contig_stride_const = OrderedSet[int]() + non_contig_stride_other = OrderedSet[int]() + for index in all_index: + for var in index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + num_itervars = len(group) + len(reduction_group) + if len(contig_vars) == 0: + # no contiguous vars + return [num_itervars - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == num_itervars - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + +class CppKernelProxy(CppKernel): + # Subclass CppKernel, CppVecKernel, etc., to customize code generation. + # Override CppOverrides or CppVecOverrides to emit custom ops. + # Earlier, this meant copying codegen_functions() to use your subclasses. + # Now, use kernel_cls and vec_kernel_cls class attributes instead. + # This lets CppKernelProxy subclasses inject custom behavior cleanly. + # No need to duplicate codegen_functions() just to swap kernel classes. 
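+ # A hypothetical out-of-tree subclass (illustrative names only) would simply override
+ # these class attributes, e.g.:
+ #   class MyKernelProxy(CppKernelProxy):
+ #       kernel_cls = MyCppKernel            # a CppKernel subclass with custom scalar codegen
+ #       vec_kernel_cls = MyCppVecKernel     # a CppVecKernel subclass
+ #       tile2d_kernel_cls = MyCppTile2DKernel
+ # codegen_functions() below instantiates whatever classes these attributes point to.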
+ kernel_cls: type[CppKernel] = CppKernel + vec_kernel_cls: type[CppVecKernel] = CppVecKernel + tile2d_kernel_cls: type[CppTile2DKernel] = CppTile2DKernel + + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.kernels: list[CppKernel] = [] + + def data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, LoopBody): + return True + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) + + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): + def add_to_dtype(sub_graph: torch.fx.Graph): + def get_input_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get input dtype for nodes that may consumes lowp fp dt""" + if node.target == "store": + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target == "to_dtype_bitcast": + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype": + if len(node.args) > 3: + return node.args[3] # type: ignore[return-value] + else: + return node.kwargs.get("src_dtype", None) # type: ignore[return-value] + else: + return None + + def get_output_dtype(node: torch.fx.Node) -> Optional[torch.dtype]: + """Get output dtype for nodes that may produce lowp fp dt""" + if node.target == "load": + assert len(node.args) == 3 + return V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + elif node.target in ["to_dtype", "constant", "index_expr"]: + return node.args[-1] # type: ignore[return-value] + elif node.target == "to_dtype_bitcast": + return node.args[2] # type: ignore[return-value] + else: + return None + + def is_lowp_fp_source(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node produces output with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + return get_output_dtype(node) == dt + + def is_lowp_fp_sink(node: torch.fx.Node, dt: torch.dtype): + """Check if the given node accept input with expected low precision floating point data type.""" + assert dt in DTYPE_LOWP_FP + if input_dtype := get_input_dtype(node): + return input_dtype == dt + elif node.target == "to_dtype": + # The `src_dtype` of a `to_dtype` node might miss, in which case the node accept any input dtype. 
+ return True + else: + return False + + def is_lowp_fp_source_no_promote(node: torch.fx.Node, dt: torch.dtype): + """Check if the node is a lowp fp sources which are all directly fed to ops that accepts lowp fp input + thus no need to promote to float + """ + return is_lowp_fp_source(node, dt) and all( + is_lowp_fp_sink(user, dt) for user in node.users + ) + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if ( + _node.target in ["load", "index_expr"] + and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP + ): + # No need to promote to float if all users are ops that accepts lowp fp input + if all(is_lowp_fp_sink(user, dt) for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + metrics.cpp_to_dtype_count += 1 + elif ( + _node.target == "store" + and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP + ): + ops, name, _, value_var, _ = _node.args + if is_lowp_fp_source_no_promote(value_var, dt): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "constant" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, value, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + _node.args = (ops, value, torch.float) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + (ops, x, dt) = _node.args + if all(is_lowp_fp_sink(user, dt) for user in _node.users): # type: ignore[arg-type] + continue + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. + # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. 
+ to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + elif _node.target == "to_dtype_bitcast": + (ops, value_var, dtype, src_dtype) = _node.args + + # to_dtype_bitcast act as a lowp fp sink: + # c10::bit_cast requires the source and target have the same bitwidth. Because the input tensor's + # dtype could be promoted, e.g. from float16 to float, we have to cast the tensor to its original + # source dtype before invoking bit_cast. + if src_dtype in DTYPE_LOWP_FP: + # No need to promote to float if it is a user of a lowp fp sources + # which are all directly fed to ops that accepts lowp fp input + if not is_lowp_fp_source_no_promote(value_var, src_dtype): + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, src_dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + + # to_dtype_bitcast act as a lowp fp source: + # We also need to convert the bit-casted tensor back to float to make sure we keep using higher + # precision values for the rest of the computation. + if dtype in DTYPE_LOWP_FP: + # No need to promote to float if all users are ops that accepts lowp fp input + if not ( + all(is_lowp_fp_sink(user, dtype) for user in _node.users) + ): + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + _node.replace_all_uses_with( + to_type_node, lambda n: n is not to_type_node + ) + metrics.cpp_to_dtype_count += 1 + else: + pass + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. 
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + def legalize_lowp_fp_dtype(self, nodes): + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): + self.legalize_lowp_fp_dtype_loopbody(body) + + def codegen_functions(self, fn_list, var_sizes_list): + assert len(fn_list) == len(var_sizes_list) + kernel_group = self.kernel_group + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for fn, var_sizes in zip(fn_list, var_sizes_list): + if var_sizes in [ + (group, reduction_group), + (tuple(itertools.chain(group, reduction_group)), ()), + ]: + assert not in_suffix + fn(vars, reduction_vars) + else: + in_suffix = True + assert var_sizes == ( + group, + (), + ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + fn(vars, ()) + + scalar_kernel = codegen_kernel(self.kernel_cls) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNest.build(scalar_kernel) + + if not self.picked_vec_isa or not self.itervars: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers(False, None) + self.loop_nest.set_kernel(self) + return + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_select = TilingSelect() + tiling_factors, tiling_indices = tiling_select.select_tiling( + fn_list, var_sizes_list + ) + assert len(tiling_factors) == len(tiling_indices) + # This should be removed after full support for vectorization is implemented. 
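+ # Masked (tail) vectorization is only attempted when every dtype seen in the loop
+ # bodies supports it; otherwise the tail iterations fall back to the scalar kernel.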
+ could_masked_vec = True + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False + + _inner_loop_reduction_outer_not = False + _outer_loop = None + if tiling_indices: + inner_loop_reduction = False + outer_loop_level = tiling_indices[0] + inner_loop_level = outer_loop_level + 1 + if len(self.loop_nest.loops) > inner_loop_level: + inner_loop_reduction = self.loop_nest.loops[ + inner_loop_level + ].is_reduction + outer_loop_reduction = self.loop_nest.loops[ + outer_loop_level + ].is_reduction + _inner_loop_reduction_outer_not = ( + inner_loop_reduction and not outer_loop_reduction + ) + + if len(tiling_indices) == 1: + metrics.generated_cpp_vec_kernel_count += 1 + loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0]) + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + tail_size = loop.size - loop.tiled_size + vec_kernel.active_ranges = {loop.var: (0, loop.tiled_size)} + if config.cpp.enable_loop_tail_vec and could_masked_vec: + tail_kernel = codegen_kernel( + self.vec_kernel_cls, + tiling_factors[0], + tiling_indices[0], + tail_size, + ) + else: + tail_kernel = scalar_kernel + scalar_kernel.inner_itervars = [loop.var] + tail_kernel.active_ranges = {loop.var: (loop.tiled_size, loop.size)} + self.kernels = [vec_kernel, tail_kernel] + _outer_loop = loop + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + + metrics.generated_cpp_vec_kernel_count += 2 + outer_loop = self.loop_nest.tile( + tiling_indices[0], factor=tiling_factors[0] + ) + outer_ranges = { + "main": (0, outer_loop.tiled_size), + "tail": (outer_loop.tiled_size, outer_loop.size), + } + outer_tail_size = outer_loop.size - outer_loop.tiled_size + inner_loop = self.loop_nest.tile( + tiling_indices[1], factor=tiling_factors[0] + ) + inner_ranges = { + "main": (0, inner_loop.tiled_size), + "tail": (inner_loop.tiled_size, inner_loop.size), + } + inner_tail_size = inner_loop.size - inner_loop.tiled_size + tile2d_kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + ) + tile2d_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["main"], + } + tail_kernel = [] + if config.cpp.enable_loop_tail_vec and could_masked_vec: + for outer_r, inner_r in ( + ("main", "tail"), + ("tail", "main"), + ("tail", "tail"), + ): + _inner_tail_size = ( + inner_tail_size if inner_r == "tail" else None + ) + _outer_tail_size = ( + outer_tail_size if outer_r == "tail" else None + ) + kernel = codegen_kernel( + self.tile2d_kernel_cls, + tiling_factors[0], + tiling_indices, + _inner_tail_size, + _outer_tail_size, + ) + kernel.active_ranges = { + outer_loop.var: outer_ranges[outer_r], + inner_loop.var: inner_ranges[inner_r], + } + tail_kernel.append(kernel) + else: + vec_kernel = codegen_kernel( + self.vec_kernel_cls, tiling_factors[0], tiling_indices[0] + ) + vec_kernel.active_ranges = { + outer_loop.var: outer_ranges["main"], + inner_loop.var: inner_ranges["tail"], + } + vec_kernel.inner_itervars = [inner_loop.var] + tail_kernel.append(vec_kernel) + scalar_kernel.active_ranges = { + outer_loop.var: outer_ranges["tail"], + inner_loop.var: (0, inner_loop.size), + } + scalar_kernel.inner_itervars = [inner_loop.var, outer_loop.var] + 
tail_kernel.append(scalar_kernel) + self.kernels = [tile2d_kernel] + tail_kernel + _outer_loop = outer_loop + else: + self.kernels = [scalar_kernel] + self.aggregate_reduction_buffers( + _inner_loop_reduction_outer_not, _outer_loop + ) + self.loop_nest.set_kernel(self) + + def codegen_loop_bodies(self, loop_bodies, var_sizes_list): + for body in loop_bodies: + self.legalize_lowp_fp_dtype_loopbody(body) + DataTypePropagation.propagate_loopbody(body) + self.codegen_functions(loop_bodies, var_sizes_list) + + def codegen_nodes(self, nodes: list[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + assert len(nodes) >= 1 + + def fn(node, *index_vars): + node.decide_inplace_update() + node.mark_run() + if isinstance(V.kernel, NullKernelHandler): + return node._body(*index_vars) + else: + return node.codegen(index_vars) + + fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + + def wrap_fn(fn): + wrapped_fn = V.local_buffer_context.localize_function( + fn, + ) + wrapped_fn.original_fn = fn + return wrapped_fn + + fn_list = [wrap_fn(fn) for fn in fn_list] + + var_sizes_list = [node.group[1] for node in nodes] + self.codegen_functions(fn_list, var_sizes_list) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + def update_stores_with_parallel_reduction(self): + for kernel in self.kernels: + kernel.update_stores_with_parallel_reduction() + + def gen_body(self, code: Optional[BracesBuffer] = None): + assert code is not None + if_prefix = "C10_LIKELY" + for kernel in self.kernels: + with contextlib.ExitStack() as stack: + if kernel.codegen_conditions(code, if_prefix): + if_prefix = "C10_UNLIKELY" + stack.enter_context(code.indent()) + code.splice(kernel.gen_body()) + + def aggregate_reduction_buffers( + self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"] + ): + # CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. + # Here, we decide how to aggregate them together and place new reduction buffers + # under CppKernelProxy. 
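+ # When an inner reduction loop is tiled under a non-reduction outer loop, the vectorized
+ # main kernel and the tail kernel contribute separate reduction prefixes/suffixes that are
+ # combined below, guarded on the outer loop variable; otherwise the main kernel's
+ # reduction buffers are used directly.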
+ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): + assert len(self.kernels) >= 2 + main_loop_kernel = self.kernels[0] + tail_loop_kernel = self.kernels[-1] + assert isinstance(main_loop_kernel, self.vec_kernel_cls) + + # Prefix + if type(tail_loop_kernel) == self.kernel_cls: + # if tail loop kernel is a scalar kernel, we need to extend tmp_acc -> tmp_acc_arr[] to + # hold the temporary inner loop acc result for outer tail loop + tail_loop_kernel.finalize_reduction_prefix( + main_loop_kernel.tiling_factor + ) + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice( + tail_loop_kernel.reduction_prefix + + main_loop_kernel.reduction_prefix + ) + else: + main_loop_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_loop_kernel.reduction_prefix) + + # Suffix + suffix_buf = BracesBuffer() + with contextlib.ExitStack() as stack: + if main_loop_kernel.codegen_conditions( + suffix_buf, "C10_LIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + suffix_buf.splice(main_loop_kernel.reduction_suffix) + with contextlib.ExitStack() as stack: + if tail_loop_kernel.codegen_conditions( + suffix_buf, "C10_UNLIKELY", outer_loop.var + ): + stack.enter_context(suffix_buf.indent()) + if type(tail_loop_kernel) == self.kernel_cls: + reduction_vars = tail_loop_kernel.reduction_var_names + for name in reduction_vars: + new_name = f"{name}_arr[{outer_loop.var}_tail - {cexpr_index(outer_loop.tiled_size)}]" + replace_acc_name(tail_loop_kernel.stores, name, new_name) + replace_acc_name( + tail_loop_kernel.reduction_suffix, name, new_name + ) + suffix_buf.splice( + move_code_under_inner_loop( + tail_loop_kernel.reduction_suffix, + outer_loop.var, + f"{outer_loop.var}_tail", + outer_loop.tiled_size, + outer_loop.size, + ) + ) + else: + suffix_buf.splice(tail_loop_kernel.reduction_suffix) + self.reduction_suffix = suffix_buf + + main_kernel = self.kernels[0] + if inner_loop_reduction_outer_not: + assert outer_loop + aggregate_reduction_prefix_suffix(outer_loop) + else: + main_kernel.finalize_reduction_prefix() + self.reduction_prefix.splice(main_kernel.reduction_prefix) + self.reduction_suffix.splice(main_kernel.reduction_suffix) + self.parallel_reduction_prefix.splice(main_kernel.parallel_reduction_prefix) + self.parallel_reduction_suffix.splice(main_kernel.parallel_reduction_suffix) + self.local_reduction_init.splice(main_kernel.local_reduction_init) + self.local_reduction_stores.splice(main_kernel.local_reduction_stores) + self.non_parallel_reduction_prefix.splice( + main_kernel.non_parallel_reduction_prefix + ) + self.non_parallel_reduction_suffix.splice( + main_kernel.non_parallel_reduction_suffix + ) + + +class OuterLoopFusedKernel(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.inner: list[LoopNest] = [] + + def decide_parallel_depth(self, max_parallel_depth, threads): + kernels_parallel_depth = [] + nested_kernels: list[CppKernel] = [ + loop_nest.get_kernel() for loop_nest in self.inner + ] + # TODO(leslie-fang-intel): only enable parallel within all outer loop levels. 
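+ # Each fused inner loop nest proposes its own parallel depth; the fused kernel uses the
+ # largest proposal, capped at max_parallel_depth.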
+ for kernel in nested_kernels: + # For any ScalarKernel, VecKernel, or Tile2DKernel, + # they should all have the same call_ranges + call_ranges = kernel.call_ranges + assert call_ranges is not None + kernels_parallel_depth.append( + kernel.decide_parallel_depth( + ParallelDepth( + parallel_depth=( + len(call_ranges) - max_parallel_depth.start_depth + ), + start_depth=max_parallel_depth.start_depth, + ), + threads, + ).parallel_depth + ) + return ParallelDepth( + parallel_depth=min( + max_parallel_depth.parallel_depth, max(kernels_parallel_depth) + ), + start_depth=max_parallel_depth.start_depth, + ) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # Subclass CppKernelProxy to customize codegen without copying codegen_node(). + # Use kernel_proxy_cls to inject custom proxies in CppScheduling subclasses. + # Avoid duplicating codegen_node() just to swap in a custom kernel proxy class. + kernel_proxy_cls: type[CppKernelProxy] = CppKernelProxy + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. + MAX_FUSED_KERNEL_ARGS_NUM = 500 + backend_features = OrderedSet( + [ + BackendFeature.INPLACE_BUFFERS, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + return cls.backend_features + + def __init__(self, scheduler): + super().__init__(scheduler) + if scheduler: + self.reset_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def reset_kernel_group(self): + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif node1.is_template(): + assert not node2.is_template() + return FusedSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = OrderedSet[Any]() + for snode in node.snodes: + v, exprs = get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + ref_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + 
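+ # Recompute the sizes/body of the node with fewer itervars under the reference node's
+ # var ranges and indexing expressions, so that both nodes end up iterating over the
+ # same variables and can be fused.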
node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=ref_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + + if vars1 == vars2: + return FusedSchedulerNode.fuse(node1, node2) + + # recompute ref_node if its ranges are also changed + node_to_recomp_indexing_constraints = get_indexing_ranges_exprs( + node_to_recomp + ) + if isinstance(ref_node, SchedulerNode): + ref_node.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + else: + assert isinstance(ref_node, FusedSchedulerNode) + for snode in ref_node.snodes: + assert isinstance(snode, SchedulerNode) + snode.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + ref_node = FusedSchedulerNode(ref_node.scheduler, ref_node.snodes) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + return FusedSchedulerNode.fuse(node1, node2) + elif self.can_fuse_vertical_outer_loop(node1, node2): + return OuterLoopFusedSchedulerNode.fuse( + node1, node2, self._get_outer_loop_fusion_depth(node1, node2) + ) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? + return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. 
(s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + + assert isinstance(node_to_recomp, SchedulerNode) + if isinstance(node_to_recomp.node, ir.TemplateBuffer): + return False + assert isinstance(node_to_recomp.node, ir.ComputedBuffer) + # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges + # but without variable name + ranges2 = node_to_recomp.node.data.get_size() + ranges1 = None + if isinstance(ref_node, FusedSchedulerNode): + ranges_set = OrderedSet[tuple[Any, ...]]() + for snode in ref_node.snodes: + if isinstance(snode.node, ir.TemplateBuffer): + break + assert isinstance(snode.node, ir.ComputedBuffer) + ranges_set.add(tuple(snode.node.data.get_size())) + + if len(ranges_set) != 1: + return False + + ranges1 = list(next(iter(ranges_set))) + else: + assert isinstance(ref_node, SchedulerNode) + assert isinstance(ref_node.node, ir.ComputedBuffer) + ranges1 = ref_node.node.data.get_size() # type: ignore[assignment] + + if ranges1 != ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + if any( + isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) + ): + return False + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def can_fuse_multi_outputs_template( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if template_buf := node1.get_template_node(): + return ( + isinstance(template_buf.layout, ir.MultiOutputLayout) + and isinstance(node2.node, ir.MultiOutput) + and len(node2.node.inputs) == 1 + and node2.node.inputs[0].get_name() == template_buf.name + ) + return False + + def _get_outer_loop_fusion_depth(self, node1, node2): + DISABLE_OUTER_LOOP_FUSION = 0 + if not all( + type(node) + in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) + for node in (node1, node2) + ): + return DISABLE_OUTER_LOOP_FUSION + + _node1 = ( + node1.get_outer_nodes()[-1] + if isinstance(node1, OuterLoopFusedSchedulerNode) + else node1 + ) + assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) + _node2 = ( + node2.get_outer_nodes()[0] + if isinstance(node2, OuterLoopFusedSchedulerNode) + else node2 + ) + assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) + 
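A minimal sketch of the three "compatible ranges" conditions checked above (c1/c2/c3), with concrete integers standing in for the symbolic sizes: no reductions on either node, equal total element counts, and one side already flattened to a single dimension.

import math

def ranges_are_compatible(vars1, reduce1, vars2, reduce2):
    c1 = reduce1 == () and reduce2 == ()          # neither node reduces
    c2 = math.prod(vars1) == math.prod(vars2)     # same number of elements
    c3 = len(vars1) == 1 or len(vars2) == 1       # one side is already 1-D
    return c1 and c2 and c3

# e.g. (s0, s1, s2) = (2, 3, 4) against the flattened (s0*s1*s2,) = (24,)
assert ranges_are_compatible((2, 3, 4), (), (24,), ())
assert not ranges_are_compatible((2, 3, 4), (), (6, 4), ())   # neither side is 1-D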
+ _, (vars1, reduce1) = _node1.group + _, (vars2, reduce2) = _node2.group + if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): + # Reduction only + return DISABLE_OUTER_LOOP_FUSION + if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return ( + node1.outer_loop_fusion_depth + if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth + else DISABLE_OUTER_LOOP_FUSION + ) + outer_loop_fusion_depth = min(len(vars1), len(vars2)) + if ( + outer_loop_fusion_depth >= 1 + and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth] + ): + if any( + type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2) + ): + _compare_node = ( + node1 if type(node1) is OuterLoopFusedSchedulerNode else node2 + ) + if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth: + # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + else: + return DISABLE_OUTER_LOOP_FUSION + else: + # First 2 nodes to generate OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + return DISABLE_OUTER_LOOP_FUSION + + def can_fuse_vertical_outer_loop(self, node1, node2): + return ( + not node1.is_template() + and not node2.is_template() + and node1.get_operation_names() & node2.ancestors + and not ( + self._can_fuse_horizontal_impl(node1, node2) + and not node1.is_reduction() + ) + and self._get_outer_loop_fusion_depth(node1, node2) >= 1 + ) + + def get_fusion_pair_priority(self, node1, node2): + if self.can_fuse_vertical_outer_loop(node1, node2): + # Outer loop fusion with lower priority + return 1 + else: + return 0 + + def can_fuse_vertical(self, node1, node2): + if node2.is_template(): + # TODO(jgong5): support pre-op fusion with template + return False + if node1.is_template(): + template_fusion_supported, _ = template_fusion_with_epilogues_supported( + node1, [node2] + ) + return not node2.is_reduction() and template_fusion_supported + return ( + self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + ) or self.can_fuse_vertical_outer_loop(node1, node2) + + def try_loop_split(self, nodes: list[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. The indexing_exprs of all nodes contain only one (or more, but all the same) division, + where the divisor is an integer and not too small (the divisor > 8), the dividend is + one of the iter_vars, and this var, i.e. the dimension that needs to be split, is + contiguous in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. 
+ """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + num_div = 0 + div_expr_ = None + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + if not isinstance(expr, sympy.Expr): + continue + for div_expr in expr.find(FloorDiv): + if ( + any(div_expr.has(var) for var in original_body.iter_vars) + and div_expr != div_expr_ + ): + div_expr_ = div_expr + num_div += 1 + if num_div > 1: + return nodes + if ( + isinstance(div_expr.args[1], sympy.core.numbers.Integer) + and div_expr.args[0] in original_body.iter_vars + and name is not None + and all( + stride_at_vec_range(expr_, div_expr.args[0]) in (0, 1) + for name_, expr_ in original_body.indexing_exprs.items() + if name_ != name + ) + and div_expr.args[1] > 8 + ): + split_var = div_expr.args[0] + split_number = div_expr.args[1] + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + + def codegen_outer_loop_node( + self, + node: OuterLoopFusedSchedulerNode, + ): + """ + Generate the code for the outer loop fused scheduler node. + 1. Codegen with fused outer loop: depends on the analysis of + the outer loop fused scheduler node, with or without the local buffer. + 2. If failed, fallback to standard codegen. + """ + kernel_group = self.kernel_group + generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count + cpp_kernel_proxy_list: list[self.kernel_proxy_cls] = [] # type: ignore[name-defined] + nodes_list: list[list[SchedulerNode]] = [] + assert isinstance(node, OuterLoopFusedSchedulerNode) + + def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): + """ + Codegen code with fused outer loop and local Buffer. 
+ """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + local_buffers: list[ir.Buffer] = [] + # Map local buffer name to a list of global buffers + local_to_global_buffers: dict[str, list[ir.Buffer]] = {} + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer in + # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + visited_scheduler_nodes: OrderedSet[str] = OrderedSet() + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + assert isinstance(scheduler_node, SchedulerNode) + visited_scheduler_nodes.add(scheduler_node.get_name()) + if ( + scheduler_node.is_reduction() + or len(scheduler_node.get_outputs()) != 1 + ): + continue + + scheduler_buffer = scheduler_node.get_outputs()[0] + if all( + user.node in node.get_nodes() for user in scheduler_buffer.users + ): + global_buffer = scheduler_buffer.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + write_index_expr = scheduler_node._body.get_write_expr( + scheduler_buffer.get_name() + ) + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + isinstance(user.node, SchedulerNode) + and is_contiguous_index( + user.node._body.get_read_expr( + scheduler_buffer.get_name() + ), + ) + for user in scheduler_buffer.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and is_all_write_read_contiguous() + ): + continue + # Local Buffer is a view of global buffer + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], + ) + + def try_share_local_buffer(local_buffer_layout, local_buffers): + for local_buf in local_buffers: + if local_buffer_layout == local_buf.layout and all( + all( + user.node.get_name() in visited_scheduler_nodes + for user in V.graph.scheduler.name_to_buf[ + global_buffer.name + ].users + ) + for global_buffer in local_to_global_buffers[ + local_buf.name + ] + if global_buffer.name is not None + ): + return local_buf + return None + + local_buf_prefix = "local_buffer_data" + # Share existing local buffer + local_buffer_used = try_share_local_buffer( + local_buffer_layout, local_buffers + ) + if not local_buffer_used: + # Create new local buffer + local_buffer_used = ir.Buffer( + name=f"{local_buf_prefix}_{len(local_buffers)}", + layout=local_buffer_layout, + ) + local_buffers.append(local_buffer_used) + 
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index] + local_to_global_buffers[local_buffer_used.name].append( + global_buffer, + ) + + with LocalBufferContext(kernel_group.args) as scope: + if len(local_buffers) > 0: + for local_buffer in local_buffers: + assert local_buffer.name is not None + scope.add_local_buffer( + local_buffer, local_to_global_buffers[local_buffer.name] + ) + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] + cpp_kernel_proxy_list.append(cpp_kernel_proxy) + nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] + + if not node.check_outer_fusion_loop_level_attr( + cpp_kernel_proxy_list, node.outer_loop_fusion_depth + ): + for removed_buffer in scope.removed_buffers: + # Restore the removed buffers by this context before + # fallback to codegen without using Local Buffer + V.graph.removed_buffers.remove(removed_buffer) + return False + metrics.cpp_outer_loop_fused_inner_counts.append( + metrics.CppOuterLoopFusedCount( + len(cpp_kernel_proxy_list), + local_buffer_number=len(scope.local_buffers), + ) + ) + outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( + cpp_kernel_proxy_list, + ) + kernel_group.finalize_kernel( + outer_fusion_cpp_kernel_proxy, + [*itertools.chain.from_iterable(nodes_list)], + ) + + return True + + if not try_outer_loop_fusion_with_local_buf(node): + # Reset generated_cpp_vec_kernel_count to codegen again + metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + cpp_kernel_proxy_list.clear() + nodes_list.clear() + # Similar as comment in + # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + with torch._inductor.config.patch(inplace_buffers=False): + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + _nodes: list[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(_nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) + + def codegen_node( + self, + node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], + ): + """ + Turn an set of pre-fused nodes into a C++ kernel. 
+ """ + kernel_group = self.kernel_group + + if isinstance(node, OuterLoopFusedSchedulerNode): + self.codegen_outer_loop_node(node) + else: + nodes: list[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) + cpp_kernel_proxy = self.kernel_proxy_cls(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + assert not prologue_nodes + + # remove MultiOutput from epilogue_nodes + epilogue_nodes = [ + epilogue_node + for epilogue_node in epilogue_nodes + if isinstance(epilogue_node, (SchedulerNode, FusedSchedulerNode)) + ] + # The counter cpp_templated_kernel_counter is used for verifying if a + # a templated kernel was successfully compiled in a UT + counters["inductor"]["cpp_templated_kernel_counter"] += 1 + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template(template_node), ( + "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Optional[ir.Operation]] = [ + n.node for n in epilogue_nodes + ] + assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be instances of ir.ComputedBuffer" + ) + + def template_buffer_has_other_users( + template_buffer, outputs_by_name, epilogue_nodes + ): + if not epilogue_nodes: + return False + + assert template_buffer.get_name() in outputs_by_name + users = outputs_by_name[template_buffer.get_name()].users + return not all( + isinstance(user.node, BaseSchedulerNode) + and user.node.node in epilogue_nodes + for user in users + ) + + flag_template_buffer_has_other_users = template_buffer_has_other_users( + ctb, template_node.outputs_by_name, epilogue_ir_nodes + ) + kernel, render = ctb.make_kernel_render( + ctb, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_ir_nodes, + ) + with kernel: + if not is_multi_outputs_template(template_node.node): + template_node.mark_run() # type: ignore[attr-defined] + for node in epilogue_nodes: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + + if is_multi_outputs_template(template_node.node): + # For multi outputs template, allocate buffers for each output after the epilogue + # codegen to which determines if the buffer has been removed. 
+            assert len(template_node.outputs) == 1, (
+                "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+            )
+            for user in template_node.outputs[0].users:
+                assert isinstance(user.node, ExternKernelSchedulerNode), (
+                    "Multi outputs template should be with ExternKernelSchedulerNode"
+                )
+                assert isinstance(user.node.node, ir.MultiOutput), (
+                    "Multi outputs template has multi users with MultiOutput"
+                )
+                user.node.mark_run()
+
+        kernel.call_kernel(kernel_name, ctb)
+        V.graph.removed_buffers |= kernel.removed_buffers
+        self.free_buffers_in_scheduler()
+
+    def _get_scheduled_num_args(self):
+        return self.kernel_group.get_num_args()
+
+    def ready_to_flush(self):
+        return self._ready_to_flush
+
+    def codegen_sync(self):
+        pass
+
+    def define_kernel(self, src_code, nodes, kernel_args=None):
+        wrapper = V.graph.wrapper_code
+        fused_name = (
+            get_fused_kernel_name(nodes, config.cpp.descriptive_names)
+            if config.cpp.descriptive_names
+            else ""
+        )
+        kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
+        # below add provenance tracing info for cpu CppKernel types
+        if config.trace.enabled:
+            set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
+
+        kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
+        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
+        src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
+        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
+        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
+        src_code = src_code.replace("#pragma CMT", "//")
+
+        # Get the lines in the source code representing the function definition,
+        # excluding the first line including cpp_prefix.h.
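The extraction below pulls the `extern "C"` signature out of the generated source so it can be re-emitted as a forward declaration; a self-contained sketch of the same string surgery (the sample source is made up, and on Windows the real code additionally skips one extra ')' introduced by `get_export_declaration`):

src = 'void helper();\nextern "C" void kernel(const float* in_ptr0, float* out_ptr0)\n{\n}\n'
first_char = src.rfind('extern "C"')
last_char = src.find(")", first_char)
kernel_definition = f"{src[first_char : last_char + 1]};\n"
print(kernel_definition)   # extern "C" void kernel(const float* in_ptr0, float* out_ptr0);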
+ first_char = src_code.rfind('extern "C"') + last_char = src_code.find(")", first_char) + if _IS_WINDOWS: + # get_export_declaration introduced one more ')' in Windows + last_char = src_code.find(")", last_char + 1) + kernel_definition = f"{src_code[first_char : last_char + 1]};\n" + + compile_wrapper = IndentedBuffer() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() + if not V.graph.cpp_wrapper: + compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") + compile_wrapper.splice(src_code, strip=True) + if not V.graph.cpp_wrapper: + compile_wrapper.writeline("''')") + wrapper.define_kernel( + kernel_name, + compile_wrapper.getvalue(), + gpu=False, + cpp_definition=kernel_definition, + ) + return kernel_name + + def flush(self): + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() + self._set_flush_status(False) + + +class KernelGroup: + def __init__(self): + super().__init__() + self.args = KernelArgs() + self.loops_code = BracesBuffer() + self.ws = WorkSharing(self.loops_code) + self.stack = contextlib.ExitStack() + self.stack.enter_context(self.ws) + self.scheduled_nodes = [] + + def new_kernel(self, cls, *args): + return cls(self.args, parallel_num_threads(), *args) + + def finalize_kernel(self, new_kernel, nodes): + self.scheduled_nodes += nodes + code = self.loops_code + ws = self.ws + new_kernel.codegen_loops(code, ws) + + def get_num_args(self): + arg_defs, _call_args, _arg_types = self.args.cpp_argdefs() + args_num = len(arg_defs) + return args_num + + def codegen_group(self, name=None) -> str: + self.stack.close() + if not self.scheduled_nodes: + return "" + code = BracesBuffer() + # 1. Include header files + # TODO: support kernel profile on other platforms + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if enable_kernel_profile: + code.writelines(["#include "]) + code.writeline("#include ") + + # 2. Function definition + kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name + kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name + arg_defs, _, _ = self.args.cpp_argdefs() + arg_defs = ",\n".ljust(25).join(arg_defs) + func_export_decl = get_export_declaration() + code.writeline( + f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + ) + + # 3. 
Function body + with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) + for old, new in self.args.aliases(): + code.writeline(f"auto {old} = {new};") + code.splice(self.loops_code) + return code.getvalue() + + def call_kernel(self, wrapper, kernel_name): + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call( + kernel_name, call_args, triton=False, arg_types=arg_types + ) + + +class WorkSharing: + def __init__(self, code): + self.code = code + self.in_parallel = False + self.num_threads = None + self.stack = contextlib.ExitStack() + + def parallel(self, threads): + if self.in_parallel and threads != self.num_threads: + # wrong number of threads + self.close() + if not self.in_parallel: + self.num_threads = threads + self.in_parallel = True + if config.cpp.dynamic_threads: + self.code.writeline("#pragma omp parallel") + else: + self.code.writeline(f"#pragma omp parallel num_threads({threads})") + self.stack.enter_context(self.code.indent()) + self.code.writeline( + "int tid = omp_get_thread_num();", + ) + + def single(self): + if self.in_parallel: + self.code.writeline("#pragma omp single") + return self.in_parallel + + def close(self): + self.stack.close() + self.in_parallel = False + + def __enter__(self): + self.stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stack.__exit__(exc_type, exc_val, exc_tb) + + +@dataclasses.dataclass +class LoopLevel: + var: Optional[sympy.Expr] = None + size: Optional[sympy.Expr] = None + offset: sympy.Expr = sympy.S.Zero + # Note [tiled_size] + # We may do loop-tiling at this loop level. + # When var is in [offset, tiled_size), we will perform the vectorization kernel. + # When var is in [tiled_size, size), we will perform the scalar or masked vectorization kernel. + # for (var = offset; var < size; var += steps) { + # if (var >= offset && var < tiled_size) vec_loop_body(); + # if (var >= tiled_size && var < size) scalar_or_maskvec_loop_body(); + # } + tiled_size: sympy.Expr = sympy.S.Zero + steps: sympy.Expr = sympy.S.One + parallel: int = 0 + simd_omp: bool = False + simd_vec: bool = False + collapsed: bool = False + is_reduction: bool = False + + def __post_init__(self): + # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check + # vectorization ISA is a time-consuming and one-shot operation. It leads + # to taking a longer time to import `codegen.cpp` package because the + # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while + # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the + # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation + # overhead to the Triton backend. 
Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def tile(self, factor): + sympy_factor = sympy.Integer(factor) + loop = LoopLevel(self.var, self.size) + loop.steps = sympy_factor + loop.simd_vec = True + loop.tiled_size = FloorDiv(loop.size, sympy_factor) * sympy_factor + loop.parallel = self.parallel + loop.collapsed = False + loop.is_reduction = self.is_reduction + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = "#pragma omp for" + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}" + elif not self.is_reduction and cpp_builder.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNest: + """ + A loop-nest-like structure. It is built with the `build` method + as a loop nest and then will perform loop-tiling at some depth. + + A typical case is for vectorization, where we typically do loop-tiling + at the innermost loop level. A more complicated case is when we do + 2D tiling at both the innermost and outer levels. + """ + + loops: Optional[list[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + loops: Optional[list[LoopLevel]] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size) + if not loops: + loops = [loop] + else: + loops.append(loop) + if loop_idx >= reduction_depth: + loop.is_reduction = kernel.is_reduction + + loop_nest = LoopNest(loops) + return loop_nest + + def __bool__(self): + return bool(self.loops) + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: All reduction or non-reduction levels. + When the range of the first inner loop beyond the maximum parallel depth is much + larger than the range of all outer loops within the maximum parallel depth, + change the starting depth of parallelism to the first inner loop and recalculate + the maximum parallel depth. 
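Note [tiled_size] above in concrete numbers: `tile(factor)` sets `tiled_size = FloorDiv(size, factor) * factor`, so a loop of size 100 tiled by 16 runs the vector body on [0, 96) and the scalar or masked tail on [96, 100) (the values below are illustrative):

size, factor = 100, 16
tiled_size = (size // factor) * factor              # FloorDiv(size, factor) * factor == 96
vector_iters = list(range(0, tiled_size, factor))   # 0, 16, ..., 80 -> vec_loop_body()
tail_iters = list(range(tiled_size, size))          # 96..99 -> scalar/masked body
assert tiled_size == 96 and len(tail_iters) == size % factor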
+ """ + if self.loops is None: + return ParallelDepth(parallel_depth=0, start_depth=0) + + start_depth = 0 + max_depth = 0 + is_reduction = self.loops[0].is_reduction + num_steps = sympy.Integer(1) + for loop in self.loops: + if loop.is_reduction != is_reduction: + break + num_steps = num_steps * FloorDiv(loop.size, loop.steps) + max_depth += 1 + + def get_simd_vec_depth(loops): + # Return the first loop level which is simd_vec + for i, loop in enumerate(loops): + if loop.simd_vec: + return i + return None + + simd_vec_depth = get_simd_vec_depth(self.loops) + + # When the number of steps of the first inner loop is much larger than the number of steps of + # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. + if ( + max_depth < len(self.loops) + and isinstance(num_steps, sympy.Integer) + and isinstance(self.loops[max_depth].size, sympy.Integer) + and num_steps * 300 + < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps) + and not ( + # Disable parallel reduction under the vec loop + simd_vec_depth is not None + and max_depth > simd_vec_depth + and self.loops[max_depth].is_reduction + ) + ): + start_depth = max_depth + max_depth = 0 + is_reduction = self.loops[start_depth].is_reduction + for i in range(start_depth, len(self.loops)): + if self.loops[i].is_reduction != is_reduction: + break + max_depth += 1 + return ParallelDepth(parallel_depth=max_depth, start_depth=start_depth) + + def mark_parallel(self, par_depth): + assert par_depth.parallel_depth <= self.max_parallel_depth().parallel_depth, ( + "Parallel depth cannot exceed the maximal allowed parallel depth" + ) + assert self.loops is not None + assert len(self.loops) >= par_depth.parallel_depth + loop = self.loops[par_depth.start_depth] + loop.parallel = par_depth.parallel_depth + if loop.is_reduction: + metrics.parallel_reduction_count += 1 + for i in range(par_depth.start_depth + 1, par_depth.parallel_depth): + self.loops[i].collapsed = True + + def tile(self, depth, factor): + """ + Do loop-tiling at the `depth` level with `factor`. + for (x0 = 0; x0 < x0_end; x0++) + -> + for (x0 = 0; x0 < x0_end; x0 += factor) + See details in Note [tiled_size]. + """ + assert self.loops + self.loops[depth] = self.loops[depth].tile(factor) + return self.loops[depth] + + def get_kernel(self) -> CppKernel: + assert self.kernel + return self.kernel + + def set_kernel(self, kernel): + self.kernel = kernel + + def from_loop_level(self, level: int): + assert self.loops + assert len(self.loops) >= level + loops = None if level == len(self.loops) else self.loops[level:] + return LoopNest(loops, self.kernel) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_bmm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_bmm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..d634b3bcd01f4731bfb726c944202cb0e2ae6f20 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_bmm_template.py @@ -0,0 +1,262 @@ +# mypy: allow-untyped-defs +import contextlib +import itertools +from typing import Any, Callable, Optional +from unittest.mock import patch + +import sympy + +from .. 
import ir +from ..select_algorithm import PartialRender +from ..virtualized import V +from .common import ArgName +from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE +from .cpp_micro_gemm import LayoutType +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking + + +# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition +GEMM_SINGLE_THREAD_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_single_thread_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +GEMM_THREADED_MM_STUB = r""" +{{kernel.def_kernel( + inputs={"X": X, "W": W}, + outputs={"Y": Y_2d}, + aliases=aliases, + function_name=kernel_name+"_threaded_mm", + extra_sizevars=BY_sizevars + [b_index], + placeholder="")}}""" + +BMM_TEMPLATE = r""" +{{ template.codegen_microkernel_def() }} +{{ template.codegen_single_thread_gemm() }} +{{ template.codegen_multi_thread_gemm() }} + +extern "C" +{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}} +{ + const int64_t B = {{kernel.size(BY_2d, 0)}}; + {%- if num_threads > 1 %} + constexpr int64_t num_threads = {{num_threads}}; + int64_t B_single_thread_block = (B / num_threads) * num_threads; + + #pragma omp parallel for num_threads({{num_threads}}) + {%- else %} + int64_t B_single_thread_block = B; + {%- endif %} + for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_single_thread_mm", + "", + b_index="b_start", + )}} + } + for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { + {{template.get_gemm_function_call( + kernel, + kernel_name+"_threaded_mm", + "", + b_index="b_start", + )}} + } +} +""" + + +class CppBmmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = False, + name="bmm", + ): + """ + In order to simplify the implementation and increase code reuse, the BMM template implements + two versions of the GEMM kernel: a single-threaded version and a multi-threaded version. + GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls + for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel. + + We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM + template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template + without any changes to the GEMM template itself. 
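How the generated BMM loop above splits the batch dimension between the two GEMM flavours, with example numbers (B and num_threads below are illustrative):

B, num_threads = 10, 4
B_single_thread_block = (B // num_threads) * num_threads   # 8

# Batches [0, 8) run inside the `#pragma omp parallel for`, each calling the
# single-threaded GEMM; the remaining B % num_threads batches each call the
# multi-threaded GEMM so the leftover work still uses all cores.
single_thread_batches = list(range(B_single_thread_block))
threaded_batches = list(range(B_single_thread_block, B))
assert threaded_batches == [8, 9]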
+ """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=should_block_weights, + name=name, + ) + self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True) + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + if should_block_weight: + # Tensor is constant or not contiguous, so we will pad and block + new_size, padded_n = CppGemmTemplate.get_padded_size( + n, block_n, k, should_block_weight + ) + # Add the new batch dimension + new_size.insert(0, -1) + return new_size, padded_n + else: + new_size = [-1, k, n] + return new_size, n + + @staticmethod + def check_if_block_weight(W, micro_gemm): + assert isinstance(W, ir.IRNode) + _, n = W.get_size()[-2:] + result = ( + not W.get_layout().is_contiguous() + or W.get_name() in V.graph.constants + or ( + n % micro_gemm.register_blocking.block_n != 0 + and micro_gemm.get_b_layout != LayoutType.NORMAL + ) + ) + return result + + def get_gemm_function_call( + self, + kernel: CppTemplateKernel, + function_name: str, + placeholder: str, + b_index: str, + ) -> str: + """ + Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition, + generate a function call for the GEMM kernel. + Args: + placeholder: The string to replace the function call with + b_index: The index for slicing the 3D batch tensors + """ + + def hook(): + arg_defs, call_args, _, _ = kernel.args.python_argdefs() + for i, buf in enumerate(call_args): + if buf == self.b_index: + arg_defs[i] = ArgName(b_index) + call = f"{function_name}({', '.join(x.full_name() for x in arg_defs)});" + return call + + assert placeholder not in kernel.render_hooks + kernel.render_hooks[placeholder] = hook + return placeholder + + def get_default_reindexers(self, epilogue_nodes): + def reindexer(args): + # if epilogue nodes exist, they have 3D ranges but args are 2D, so add 0 index + return [self.b_index] + args + + return [reindexer] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> dict[str, Any]: + options = super().get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + + BX, BW, BY = options["X"], options["W"], options["Y"] + options["BX"], options["BW"], options["BY"] = BX, BW, BY + options["BY_2d"] = options["Y_2d"] + for kword in ["X", "W", "GemmOut", "Y_2d"]: + options[kword] = kernel.select(options[kword], 0, self.b_index) + for kword in ["X", "W", "Y_2d"]: + options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype] + options["b_index"] = self.b_index + options["BY_sizevars"] = [ + s + for sym in itertools.chain(BY.get_size(), BY.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ] + options["kernel_name"] = kernel.kernel_name + + return options + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + 
template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + result = self._template_from_string(BMM_TEMPLATE).render(**options) + + # Finalize the function definitions for the gemm routines + sub_mm_hooks = { + name: hook + for name, hook in kernel.render_hooks.items() + if "FOR_BMM" in name + } + result = PartialRender(result, sub_mm_hooks).finalize_all() + for name in sub_mm_hooks: + del kernel.render_hooks[name] + del kernel.args.sizevars[options["b_index"]] + return result + + def codegen_single_thread_gemm(self): + stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + {**self.render_options, "num_threads": 1} + ) + + def codegen_multi_thread_gemm(self): + stub = self._template_from_string(GEMM_THREADED_MM_STUB).render( + self.render_options + ) + return stub + self._template_from_string(GEMM_TEMPLATE).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + return "" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py new file mode 100644 index 0000000000000000000000000000000000000000..9d660c0e46da47e404ed26b8510d47351136201a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -0,0 +1,1081 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import re +from typing import Optional +from unittest.mock import patch + +import sympy + +import torch +import torch.utils + +from ...utils._ordered_set import OrderedSet +from .. 
import ir +from ..ir import TensorBox +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp_template import CppTemplate +from .cpp_utils import GemmBlocking + + +log = logging.getLogger(__name__) + +# TODO: reuse cpp codegen to generate below pointwise/reduction kernels +SOFTMAX_FUSIONS = r""" +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void {{kernel_name}}_exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + at::native::_store(out + i, tmp2); + } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; +} + +// 1) out = a * scale +// 2) max = max(out) +template +inline void {{kernel_name}}_mul_reduce_max_fusion_kernel( + const scalar_t* a, + const scalar_t& scale, + const int& size, + scalar_t* out, + scalar_t& max) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + scalar_t tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + at::native::_store(out + i, tmp1); + } + for (long i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + max = std::max( + tmp_max, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max)); +} + +template +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { + TORCH_CHECK(ptr2 == nullptr); + return ptr; +} + +template , int> = 0> +static inline scalar_t* {{kernel_name}}_conditional_data_ptr(float* ptr, scalar_t* ptr2) { + return ptr2; +} + +template +inline void {{kernel_name}}_fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } + #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) + # pragma unroll + #endif + for (; d < size; d++) { + data[d] = val; + } +} + +// out = a * scale +template +inline void {{kernel_name}}_mul_scale_kernel( + scalar_t* a, + scalar_t scale, + int64_t size) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + for (int64_t i = 0; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + at::native::_store(a + i, tmp1); + } + for (int64_t i = vec_size * (size / vec_size); i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + a[i] = tmp1; + } +} + +""" + 
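Reference semantics of the two fused softmax helpers above in plain Python (illustrative only; the C++ versions vectorize the loops and fuse the elementwise op with the reduction in a single pass):

import math

def mul_reduce_max(a, scale):
    out = [x * scale for x in a]
    return out, max(out)                    # out = a * scale, max = max(out)

def exp_reduce_sum(a, val):
    out = [math.exp(x - val) for x in a]
    return out, sum(out)                    # out = exp(a - val), val = sum(out)

row = [0.5, 2.0, -1.0]
scaled, row_max = mul_reduce_max(row, 0.125)
probs, denom = exp_reduce_sum(scaled, row_max)
# dividing each element of `probs` by `denom` completes a numerically stable softmax row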
+BRGEMM_PACK_FUNCTIONS = r""" +template +inline void {{kernel_name}}_copy_value_with_pad( + const scalar_t* value_ptr, + scalar_t* dst_ptr, + int64_t rows, + int64_t cols, + int64_t prows, + int64_t pcols, + int64_t ldi) { + auto vec_size = at::vec::Vectorized::size(); + int64_t i = 0; + for (; i < rows; i++) { + int64_t j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + auto zero_vec = at::vec::Vectorized(0); + int64_t pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + zero_vec.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + // row padding + for (; i < prows; i++) { + auto zero_vec = at::vec::Vectorized(0); + int64_t j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero_vec.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + zero_vec.store(dst_ptr + i * pcols + j, pcols - j); + } + + } +} +""" + +MICRO_GEMM_TEMPLATE = r""" +GEMM_DEFINE +""" + +ALLOCATE_BUFFER = r""" + int64_t {{buffer_name}}_dtype_itemsize = c10::is_reduced_floating_point_v<{{buffer_dtype}}> ? 2 : 4; + auto& {{buffer_name}}_allocator = *at::getCPUAllocator(); + auto {{buffer_name}}_work_data = {{buffer_name}}_allocator.allocate({{buffer_size}}*{{buffer_name}}_dtype_itemsize); + void* {{buffer_name}}_data_ptr = {{buffer_name}}_work_data.get(); + {{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr; +""" + +FLEX_ATTENTION_TEMPLATE = r""" +{{template.header().getvalue()}} +#include +#include +#include +{{template.codegen_micro_gemm(kernel.kernel_name)}} +{{template.codegen_softmax_fusion(kernel.kernel_name)}} +{{template.codegen_brgemm_pack_function(kernel.kernel_name)}} +{%- set kernel_args = {"query": query, "key": key, "value": value, + "kv_num_blocks": kv_num_blocks, "kv_indices": kv_indices, + "full_kv_num_blocks": full_kv_num_blocks, "full_kv_indices": full_kv_indices } %} +{%- set kernel_args = template.update_kernel_args(kernel_args) %} + +extern "C" +{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}} +{ + {{ kernel.maybe_codegen_profile() }} + int64_t qBlockSize = {{qBlockSize}}; + int64_t kvBlockSize = {{kvBlockSize}}; + int64_t num_thread = {{num_thread}}; + + // dtypes of kernel and internal buffers + using scalar_t = {{kernel.dtype(query)}}; + constexpr bool is_reduced_type = c10::is_reduced_floating_point_v; + using accum_t = at::opmath_type<{{kernel.dtype(query)}}>; + using Vec = at::vec::Vectorized; + accum_t scaling_factor = {{scale}}; + int64_t batchSize = {{kernel.size(query, 0)}}; + int64_t qSize = {{kernel.size(query, 1)}}; + int64_t num_head = {{kernel.size(query, 2)}}; + int64_t headSize = {{kernel.size(query, 3)}}; + int64_t batchSize_k = {{kernel.size(key, 0)}}; + int64_t num_head_k = {{kernel.size(key, 2)}}; + int64_t headSize_v = {{kernel.size(value, 3)}}; + bool is_broadcast_bs_kv = batchSize != batchSize_k; + bool is_broadcast_head_kv = num_head != num_head_k; + int64_t gqa_shards = num_head / num_head_k; + int64_t bs_shards = batchSize / batchSize_k; + + int64_t batchSize_kvi = {{kernel.size(kv_indices, 0)}}; + int64_t num_head_kvi = 
{{kernel.size(kv_indices, 1)}}; + int64_t block_num_kvi = {{kernel.size(kv_indices, 3)}}; + bool is_broadcast_bs_kvi = batchSize != batchSize_kvi; + bool is_broadcast_head_kvi = num_head != num_head_kvi; + int64_t gqa_shards_kvi = num_head / num_head_kvi; + int64_t bs_shards_kvi = batchSize / batchSize_kvi; + + int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}}; + int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}}; + int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}}; + + int64_t num_kviStrideB = {{kernel.stride(kv_num_blocks, 0)}}; + int64_t num_kviStrideH = {{kernel.stride(kv_num_blocks, 1)}}; + +{%- if has_full_kv_block %} + int64_t full_kviStrideB = {{kernel.stride(full_kv_indices, 0)}}; + int64_t full_kviStrideH = {{kernel.stride(full_kv_indices, 1)}}; + int64_t full_kviStrideQ = {{kernel.stride(full_kv_indices, 2)}}; + + int64_t full_num_kviStrideB = {{kernel.stride(full_kv_num_blocks, 0)}}; + int64_t full_num_kviStrideH = {{kernel.stride(full_kv_num_blocks, 1)}}; + auto full_kv_indices_data = full_kv_indices; + auto full_kv_num_blocks_data = full_kv_num_blocks; +{%- endif %} + + auto kv_num_blocks_data = kv_num_blocks; + auto kv_indices_data = kv_indices; + + // Strides + int64_t qStrideB = {{kernel.stride(query, 0)}}; + int64_t qStrideM = {{kernel.stride(query, 1)}}; + int64_t qStrideH = {{kernel.stride(query, 2)}}; + int64_t kStrideB = {{kernel.stride(key, 0)}}; + int64_t kStrideN = {{kernel.stride(key, 1)}}; + int64_t kStrideH = {{kernel.stride(key, 2)}}; + int64_t vStrideB = {{kernel.stride(value, 0)}}; + int64_t vStrideN = {{kernel.stride(value, 1)}}; + int64_t vStrideH = {{kernel.stride(value, 2)}}; + int64_t oStrideB = {{kernel.stride(output, 0)}}; + int64_t oStrideM = {{kernel.stride(output, 2)}}; + int64_t oStrideH = {{kernel.stride(output, 1)}}; + + int64_t kvSize = {{kernel.size(key, 1)}}; + + int64_t qSplitSize = qBlockSize; + int64_t kvSplitSize = kvBlockSize; + + + qSplitSize = qSplitSize > qSize ? qSize : qSplitSize; + kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; + + bool need_pack = false; + // Whether pack is needed for BFloat16/Half + if (is_reduced_type) { + // check platform ability + need_pack = std::is_same_v ? at::native::cpublas::could_pack(at::kBFloat16) + : at::native::cpublas::could_pack(at::kHalf); + } + if (need_pack) { + // When the number of gemm is greater than the number of pack, + // the pack overhead can be overlapped. + int64_t thresh_size = 64; + need_pack = kvSize >= thresh_size && qSize >= thresh_size; + if (need_pack) { + double pack_size = batchSize * num_head * kvSize * headSize; + double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread; + double gemm_size_per_thread = qs_per_thread * qSplitSize * kvSize * headSize; + need_pack = gemm_size_per_thread / pack_size >= 4; + } + } + // Pad is needed for packing when K is not even + bool headSize_even = headSize % 2 == 0; + int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; + int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; + int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? 
kvTail + 1 : kvTail; + int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + + // Allocate per thread temp buf (accumulate type) + int64_t _size_per_thread = + /* qk */ qSplitSize * kvSplitSize + + /* qk_max */ qSplitSize + + /* qk_sum */ qSplitSize + + /* dst */ qSplitSize * headSize_v; + + // Inputs/outputs buffers + const scalar_t* q_data = query; + const scalar_t* k_data = key; + const scalar_t* v_data = value; + scalar_t* out_data = output; + + // Buffers to store accum results, padding query and transpose/packing key/value + {{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}} + {{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}} + {{template.codegen_allocate_buffer("key_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*eheadSize*kvSize")}} + {{template.codegen_allocate_buffer("value_reorder_ptr", "scalar_t", "batchSize_k*num_head_k*kv_padding_size*headSize_v")}} + {{template.codegen_allocate_buffer("transpose_buffer_ptr", "scalar_t", "num_thread*kvSplitSize*headSize")}} + {{template.codegen_allocate_buffer("query_padding_ptr", "scalar_t", "num_thread*qSplitSize*eheadSize")}} + if (need_pack) { + // Pack K, V + at::parallel_for(0, batchSize_k * num_head_k * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = need_pack? transpose_buffer_ptr + ompIdx * kvSplitSize * headSize : nullptr; + at::native::data_index_init(begin, i, batchSize_k, j, num_head_k, l, kvSlice); + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + auto k_addr = + k_data + i * kStrideB + j * kStrideH + n * kStrideN; + auto v_addr = + v_data + i * vStrideB + j * vStrideH + n * vStrideN; + // transpose [cur_kvSplitSize, headSize] -> [headSize, cur_kvSplitSize] + at::native::utils::transpose( + cur_kvSplitSize, + headSize, + /* src_ptr */ + reinterpret_cast(k_addr), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_ptr), + /* ld_dst */ cur_kvSplitSize); + + // Pack [headSize, cur_kvSplitSize] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(transpose_ptr), + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head_k * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), + /* ld_src */ cur_kvSplitSize, + /* K */ headSize, + /* N */ cur_kvSplitSize); + + // Pack [cur_kvSplitSize, headSize_v] + at::vec::pack_vnni2( + /* src */ reinterpret_cast(v_addr), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head_k * kv_padding_size * headSize_v + + j * kv_padding_size * headSize_v + n * headSize_v), + /* ld_src */ vStrideN, + /* K */ cur_kvSplitSize, + /* N */ headSize_v); + // Move to the next query + at::native::data_index_step(i, batchSize_k, j, num_head_k, l, kvSlice); + } + }); + } + // Attention loop below + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, k = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, k, qSlice); + int ompIdx = at::get_thread_num(); + accum_t* buf_ptr = buf_data + ompIdx * _size_per_thread; + accum_t* qk_data = buf_ptr; + accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_sum_data = qk_max_data + qSplitSize; + accum_t* dst_data = qk_sum_data + qSplitSize; + scalar_t *qk_reduced_data = + is_reduced_type + ? 
buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize + : nullptr; + scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; + + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { + auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i; + auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j; + auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int kv_indice_num = *kv_logical_num_data; + std::vector kv_indice_list(kv_indice_num); + for(int kv_i = 0; kv_i < kv_indice_num; kv_i++){ + auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB + + j_kvi * kviStrideH + k*kviStrideQ + kv_i; + kv_indice_list[kv_i] = *kv_logical_data; + } + bool is_skip_kv = kv_indice_num > 0 ? false : true; +{%- if has_full_kv_block %} + auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB + + j_kvi * num_kviStrideH + k; + int full_kv_indice_num = *full_kv_logical_num_data; + std::vector full_kv_indice_list(full_kv_indice_num); + for(int kv_i = 0; kv_i < full_kv_indice_num; kv_i++){ + auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB + + j_kvi * full_kviStrideH + k*full_kviStrideQ + kv_i; + full_kv_indice_list[kv_i] = *full_kv_logical_data; + } + is_skip_kv = kv_indice_num + full_kv_indice_num > 0 ? false : true; +{%- endif %} + int64_t m = k * qSplitSize; + int64_t cur_qSplitSize = std::min(qSplitSize, qSize - m); + if (!is_skip_kv){ + // Initialize max and sum + {{kernel.kernel_name}}_fill_stub(qk_max_data, + -std::numeric_limits::infinity(), cur_qSplitSize); + {{kernel.kernel_name}}_fill_stub(qk_sum_data, + static_cast(0), cur_qSplitSize); + + if (!headSize_even && need_pack) { + // Pad query if headSize is not even + {{kernel.kernel_name}}_copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + cur_qSplitSize, + headSize, + cur_qSplitSize, + eheadSize, + qStrideM + ); + } + } + +{%- if has_full_kv_block %} + for (int64_t n_idx = 0; n_idx < kv_indice_num + full_kv_indice_num ; n_idx += 1) { + auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize; +{%- else %} + for (int64_t n_idx = 0; n_idx < kv_indice_num ; n_idx += 1) { + auto n = kv_indice_list[n_idx]*kvSplitSize; +{%- endif %} + + auto cur_n = n/kvSplitSize; + int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n); + int64_t cur_ekvSplitSize = (need_pack && cur_kvSplitSize % 2 != 0) ? cur_kvSplitSize + 1 : cur_kvSplitSize; + + // Calculate scale * q @ k.T + auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i; + auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j; + + if (!need_pack) { + auto k_addr = + k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN; + + {{kernel.kernel_name}}_kernel_micro_gemm_transpose_b(false)>( + q_data + i * qStrideB + j * qStrideH + + m * qStrideM, + k_addr, + qk_data, + cur_qSplitSize, + cur_kvSplitSize, + headSize, + qStrideM, + kStrideN, + cur_kvSplitSize); + + } else { + at::native::cpublas::brgemm( + cur_qSplitSize, + cur_kvSplitSize, + eheadSize, + headSize_even ? qStrideM : eheadSize, + cur_kvSplitSize, + cur_kvSplitSize, + false, + !headSize_even + ? 
query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i_kv * num_head_k * eheadSize * kvSize + + j_kv * eheadSize * kvSize + n * eheadSize, + qk_data, + need_pack); + } + + {{kernel.kernel_name}}_mul_scale_kernel(qk_data, scaling_factor, cur_qSplitSize*cur_kvSplitSize); + +{%- if score_mod and mask_mod %} + // TODO: reduce the number of calls of q_idx and kv_idx initialization + std::vector q_idx(cur_qSplitSize); + for (int64_t i = 0; i < cur_qSplitSize; ++i) { + q_idx[i] = m + i; + } + + std::vector kv_idx(cur_kvSplitSize); + for (int64_t i = 0; i < cur_kvSplitSize; ++i) { + kv_idx[i] = n + i; + } + + std::vector b_idx = {i}; + std::vector h_idx = {j}; + + accum_t* in_ptr0 = qk_data; + + auto in_ptr1 = b_idx.data(); + auto in_ptr2 = h_idx.data(); + auto in_ptr3 = q_idx.data(); + auto in_ptr4 = kv_idx.data(); + + // apply score mod function + { + {{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }} + accum_t* out_ptr{{score_buf_idx}} = in_ptr0; + {{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }} + } + + if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){ + // Apply block mask, fill unused with -inf + { + {{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }} + accum_t* out_ptr{{mask_buf_idx}} = in_ptr0; + {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }} + } + } + +{%- endif %} + // Update coefficients with Softmax + accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0; + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // apply scaling factor and max per row in fusion + {{kernel.kernel_name}}_mul_reduce_max_fusion_kernel( + qk_data + row * cur_kvSplitSize, + static_cast(1), + cur_kvSplitSize, + qk_data + row * cur_kvSplitSize, + tmp_max); + tmp_max = qk_max_data[row] > tmp_max ? 
qk_max_data[row] : tmp_max; + if (tmp_max == -std::numeric_limits::infinity()) { + // to avoid `nan = exp2f(-inf - (-inf))` + {{kernel.kernel_name}}_fill_stub( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + static_cast(0), cur_kvSplitSize); + } else { + tmp_sum = tmp_max; + // qk <- exp(qk - max) and sum per row + {{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel( + qk_data + row * cur_kvSplitSize, cur_kvSplitSize, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data) + row * cur_ekvSplitSize, + tmp_sum); + // exp_tmp <- exp(max[row] - max) + exp_tmp = std::exp(qk_max_data[row] - tmp_max); + // sum[row] <- sum + exp_tmp * sum[row] + qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row]; + // max[row] <- max + qk_max_data[row] = tmp_max; + // dst <- dst * exp_tmp + if (n_idx > 0) { + at::vec::map( + [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, + dst_data + row * headSize_v, + dst_data + row * headSize_v, + headSize_v); + } + } + if (need_pack && cur_kvSplitSize % 2 != 0) { + // Pad: [qSplitSize, cur_kvSplitSize] -> [qSplitSize, cur_kvSplitSize + 1] + *(qk_reduced_data + row * (1 + cur_kvSplitSize) + cur_kvSplitSize) = scalar_t(0); + } + } + // Calculate Softmax(q @ k.T) @ v + if (!need_pack) { + auto v_addr = + v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN; + // Fallback Half brgemm is slower than micro gemm + if (!std::is_same_v) { + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v, + n_idx > 0, + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + need_pack); + } else { + if (n_idx > 0) { + {{kernel.kernel_name}}_kernel_micro_gemm(true)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } else { + {{kernel.kernel_name}}_kernel_micro_gemm(false)>( + {{kernel.kernel_name}}_conditional_data_ptr(qk_data, qk_reduced_data), + v_addr, + dst_data, + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + vStrideN, + headSize_v); + } + } + } else { + int64_t psize = n / kvSplitSize * ekvSplitSize; + at::native::cpublas::brgemm( + cur_qSplitSize, + headSize_v, + cur_ekvSplitSize, + cur_ekvSplitSize, + headSize_v, + headSize_v, + n_idx > 0, + qk_reduced_data, + value_reorder_ptr + + i_kv * num_head_k * kv_padding_size * headSize_v + + j_kv * kv_padding_size * headSize_v + psize * headSize_v, + dst_data, + need_pack); + } + } + + // dst <- dst / sum[row] + // reorder MHA output with strides + for (int64_t row = 0; row < cur_qSplitSize; ++row) { + // Row sums for full masked out rows are 0, we set them to 1 + // in order to avoid NaNs in the output and instead set fully + // masked out rows to 0 + qk_max_data[row] = qk_max_data[row] == -std::numeric_limits::infinity() ? 0 : qk_max_data[row]; + qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row]; + accum_t sum_reciprocal = 1 / qk_sum_data[row]; + at::vec::map( + [sum_reciprocal, is_skip_kv](Vec x) { return is_skip_kv ? 
Vec(0.0) : x * Vec(sum_reciprocal); }, + out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, + dst_data + row * headSize_v, + headSize_v); + } + + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, k, qSlice); + } + + at::native::cpublas::brgemm_release(need_pack); + + }); +} +""" + + +class CppFlexAttentionTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.float16] + super().__init__("flex_attention", input_nodes, layout, parallel_num_threads()) + self.scale = scale + self.score_mod = score_mod + self.mask_mod = mask_mod + self.score_buf_name = ( + V.graph.register_buffer(self.score_mod) if self.score_mod else None + ) + self.mask_buf_name = ( + V.graph.register_buffer(self.mask_mod) if self.mask_mod else None + ) + + def get_idx(buf_name): + match = re.search(r"\d+", buf_name) + assert match, f"incorrect score buf name: {buf_name}" + return match.group() + + self.score_buf_idx = ( + get_idx(self.score_buf_name) if self.score_buf_name else None + ) + self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None + self.kv_block_size = kv_block_size + self.q_block_size = q_block_size + self.has_other_buffer = has_other_buffer + self.no_full_kv_block = no_full_kv_block + self.other_buffer_input_offset = 2 + if self.no_full_kv_block: + self.other_buffer_input_offset = 0 + self.fake_buffers = fake_buffers + self.len_score_other = len_score_other + self.len_mask_other = len_mask_other + self.kernel_input_name_to_buffer = kernel_input_name_to_buffer + self.block_vars = block_vars + self.extra_sizevars = list( + OrderedSet( + val + for val in self.kernel_input_name_to_buffer.values() + if isinstance(val, sympy.Symbol) + ) + ) + self.other_buf_start_idx = 5 + self.score_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset : self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other + ] + if self.has_other_buffer + else None + ) + self.mask_mod_other_buffers = ( + self.input_nodes[ + self.other_buf_start_idx + + self.other_buffer_input_offset + + self.len_score_other : + ] + if self.has_other_buffer + else None + ) + self.other_ptr_data = {} # type: ignore[var-annotated] + + def update_kernel_args(self, kernel_args): + kernel_args.update( + { + key: value + for key, value in self.kernel_input_name_to_buffer.items() + if not isinstance(value, sympy.Symbol) + } + ) + return kernel_args + + def generate_other_buffer(self, buf_list, start_offset, len_attr, kernel_args): + kernel_input_name_to_buffer_name = { + key: value if isinstance(value, sympy.Symbol) else value.get_name() + for key, value in self.kernel_input_name_to_buffer.items() + } + + def get_arg(name): + return kernel_input_name_to_buffer_name.get(name) + + def get_arg_name(name): + if isinstance(get_arg(name), sympy.Symbol): + return kernel_args.sizevars.get(get_arg(name)) + return kernel_args.input_buffers.get(get_arg(name)) + + if not self.has_other_buffer: + return "" + + if start_offset == -1: + start_offset = getattr(self, len_attr) + + length = getattr(self, len_attr) + for i in range(length): + pointer = f"in_ptr{self.other_buf_start_idx + start_offset + i}" + buffer_key = 
f"{buf_list}_{i}" + if pointer not in self.other_ptr_data: + self.other_ptr_data[pointer] = ( + get_arg_name(buffer_key), + get_arg(buffer_key), + ) + + return "\n".join( + f"auto {ptr} = {name};" for ptr, (name, _) in self.other_ptr_data.items() + ) + + def modification(self, subgraph_buffer, output_name, output_idx): + assert isinstance(subgraph_buffer, ir.ComputedBuffer) + subgraph_buffer_data = subgraph_buffer.data + from ..loop_body import LoopBody + from ..utils import sympy_index_symbol_with_prefix, SymT + from ..virtualized import V + from .cpp import CppKernelProxy, KernelGroup + + kernel_group = KernelGroup() + kernel_input_args = { + "score": "in_ptr0", + "b": "in_ptr1", + "h": "in_ptr2", + "q_idx": "in_ptr3", + "kv_idx": "in_ptr4", + } + if self.has_other_buffer: + kernel_input_args.update( + {arg: ptr for ptr, (_, arg) in self.other_ptr_data.items()} + ) + + kernel_output_args = {output_name: f"out_ptr{output_idx}"} + + args = kernel_group.args + for name, inp in kernel_input_args.items(): + args.input_buffers[name] = inp + + for name, inp in kernel_output_args.items(): + args.output_buffers[name] = inp + + for name in self.extra_sizevars: + args.sizevars[name] = f"k{name}" + + kernel_group.args = args + + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + var_sizes = tuple(subgraph_buffer.get_size()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes) + } + + dst_layout = subgraph_buffer.get_layout() + output_index = dst_layout.make_indexer()([*var_ranges.keys()]) + + def fn(*args): + V.ops.store( + output_name, + output_index, + subgraph_buffer_data.make_loader()(args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys())), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + + from ..loop_body import MemoryUsageType + + assert all( + mem.buffer_name in kernel_group.args.input_buffers + for mem in body.memory_usage[MemoryUsageType.LOAD] + ), ( + "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers" + ) + + bodies.append(body) + var_sizes_list.append((var_sizes, ())) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + output_code = kernel_group.loops_code.getvalue() + + var_q_symbol, var_kv_symbol = self.block_vars + # See [Note] Handle the case where the split sizes are not statically known. + # We don't know the value of qBlockSize and rkvBlockSize during compilation time + # thus we've represented them by symbols. + # We change the symbol strings back to "cur_qSplitSize" and "cur_kvSplitSize" + # in the generated code thus they'll be filled with the real value during runtime. 
+ if var_q_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_q_symbol], "cur_qSplitSize" + ) + if var_kv_symbol in kernel_group.args.sizevars: + output_code = output_code.replace( + kernel_group.args.sizevars[var_kv_symbol], "cur_kvSplitSize" + ) + + return output_code + + @staticmethod + def add_choices( + choices, + input_nodes, + layout, + scale, + score_mod, + mask_mod, + kv_block_size, + q_block_size, + has_other_buffer, + no_full_kv_block, + fake_buffers, + len_score_other, + len_mask_other, + kernel_input_name_to_buffer, + block_vars, + ): + def preprocessor(input_nodes, layout): + return input_nodes, layout + + def postprocessor(output): + return output + + template = DataProcessorTemplateWrapper( + CppFlexAttentionTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=score_mod, + mask_mod=mask_mod, + kv_block_size=kv_block_size, + q_block_size=q_block_size, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len_score_other, + len_mask_other=len_mask_other, + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=block_vars, + ) + template.maybe_append_choice(choices) + return template + + def apply_score_mod(self, score, b, h, q_idx, kv_idx): + return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item() + + def render( # type: ignore[override,return] + self, + kernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + if epilogue_nodes is not None and epilogue_nodes != []: + raise NotImplementedError( + "Unsupported for `epilogue_nodes` in CppFlexAttentionTemplate." 
+ ) + # Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + # -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + # Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + # Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + # -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + + query = kernel.permute(self.input_nodes[0], [0, 2, 1, 3]) + key = kernel.permute(self.input_nodes[1], [0, 2, 1, 3]) + value = kernel.permute(self.input_nodes[2], [0, 2, 1, 3]) + self.accumulate_dtype = torch.float + self.input_dtype = query.layout.dtype + + num_threads = parallel_num_threads() + buf_out = TensorBox.create(self.output_node) + if template_buffer_node is not None: + buf_out = template_buffer_node + options = dict( + query=query, + key=key, + value=value, + kv_num_blocks=self.input_nodes[3], + kv_indices=self.input_nodes[4], + full_kv_num_blocks=self.input_nodes[5] + if not self.no_full_kv_block + else None, + full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None, + score_mod_other_buffers=self.score_mod_other_buffers, + mask_mod_other_buffers=self.mask_mod_other_buffers, + scale=self.scale, + has_full_kv_block=not self.no_full_kv_block, + accumulate_dtype=self.accumulate_dtype, + query_dtype=self.input_dtype, + kvBlockSize=self.kv_block_size, + qBlockSize=self.q_block_size, + template=self, + output=buf_out, + kernel=kernel, + num_thread=num_threads, + score_mod=self.score_mod, + mask_mod=self.mask_mod, + score_buf_name=self.score_buf_name, + mask_buf_name=self.mask_buf_name, + score_buf_idx=self.score_buf_idx, + mask_buf_idx=self.mask_buf_idx, + ) + with contextlib.ExitStack() as stack: + for buf in self.fake_buffers: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options) + + def codegen_softmax_fusion(self, kernel_name: str): + # TODO: use inductor IR to rewrite those fusions + return self._template_from_string(SOFTMAX_FUSIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_brgemm_pack_function(self, kernel_name: str): + # TODO: make them general for common bmm templates + return self._template_from_string(BRGEMM_PACK_FUNCTIONS).render( + dict(kernel_name=kernel_name) + ) + + def codegen_allocate_buffer(self, buffer_name: str, buffer_dtype, buffer_size): + return self._template_from_string(ALLOCATE_BUFFER).render( + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + ) + ) + + def micro_gemm_define(self, kernel_name: str): + from torch._inductor.codegen.cpp_gemm_template import ( + CppTemplateKernel, + parallel_num_threads, + ) + from torch._inductor.codegen.cpp_micro_gemm import CppMicroGemmFP32Vec + from torch._inductor.virtualized import V + + micro_gemm_trans = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm_transpose_b", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + True, + ) + + micro_gemm = CppMicroGemmFP32Vec( + kernel_name + "_kernel_micro_gemm", + self.input_dtype, + self.input_dtype, + self.accumulate_dtype, + self.accumulate_dtype, + GemmBlocking(1, 16, 1), + 1, + True, + False, + ) + + with V.set_graph_handler(V.graph): + kernel = CppTemplateKernel("cpp_micro_gemm", parallel_num_threads()) + code_trans = micro_gemm_trans.codegen_define(kernel) + code = micro_gemm.codegen_define(kernel) + return code + code_trans + + def codegen_micro_gemm(self, 
kernel_name: str): + micro_gemm = self.micro_gemm_define(kernel_name) + GEMM_SOURCE_CODE = MICRO_GEMM_TEMPLATE.replace("GEMM_DEFINE", micro_gemm) + return self._template_from_string(GEMM_SOURCE_CODE).render() diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b794b757fb4673e6d57a0135142d365152a49a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,1777 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +from functools import lru_cache +from typing import Any, Callable, cast, Optional, TypeVar, Union +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir, lowering as L +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import ( + has_free_symbols, + is_same_mkldnn_tensor, + is_same_tensor, + parallel_num_threads, +) +from ..virtualized import ops, V +from .cpp import get_export_declaration +from .cpp_micro_gemm import ( + CppMicroBrgemm, + CppMicroGemm, + CppMicroGemmAMX, + CppMicroGemmFP32Vec, + create_micro_gemm, + is_int8_woq_gemm_small_m_dim_corner_case, + LayoutType, +) +from .cpp_template import CppTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK = r""" + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{N}}; + constexpr int64_t K = {{K}}; + constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; + constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; +{%- if is_dynamic_M %} + const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- else %} + constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; + constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; +{%- endif %} +""" + +GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED = r""" +{%- if is_dynamic_M %} + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = Mr_blocks; + const auto Nt_blocks = Nr_blocks; + const auto Kt_blocks = Kr_blocks; + {%- endif %} + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); + const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / 
Kt_blocks; +{%- else %} + constexpr int64_t Mt_blocks = {{template.thread_blocking(num_threads).block_m}}; + constexpr int64_t Nt_blocks = {{template.thread_blocking(num_threads).block_n}}; + constexpr int64_t Kt_blocks = {{template.thread_blocking(num_threads).block_k}}; + constexpr int64_t Mc_blocks = {{template.cache_blocking(num_threads).block_m}}; + constexpr int64_t Nc_blocks = {{template.cache_blocking(num_threads).block_n}}; + constexpr int64_t Kc_blocks = {{template.cache_blocking(num_threads).block_k}}; + constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; + constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; +{%- endif %} +{%- if is_woq_int4 %} + int64_t group_size = *q_group_size; +{%- endif %} + + // make sure all partitions are assigned + {{kernel.assert_function}}( + Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks, + "Not all partitions are assigned." + ); +""" + +GEMM_TEMPLATE_MULTI_THREADS_PARAMS = r""" +const int tid = omp_get_thread_num(); +const int64_t k_group_id = tid / num_Kt_blocks; +const int64_t k_slice_id = tid % num_Kt_blocks; +const int64_t n_group_id = k_group_id / num_Nt_blocks; +const int64_t n_slice_id = k_group_id % num_Nt_blocks; +const int64_t k_block_start = k_slice_id * Kt_blocks; +const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); +const int64_t n_block_start = n_slice_id * Nt_blocks; +const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); +const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); +const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); +const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; +""" + +GEMM_TEMPLATE_SINGLE_THREAD_PARAMS = r""" +constexpr int tid = 0; +constexpr int64_t k_group_id = 0; +constexpr int64_t k_slice_id = 0; +constexpr int64_t n_group_id = 0; +constexpr int64_t n_slice_id = 0; +constexpr int64_t m_block_start = 0; +constexpr int64_t n_block_start = 0; +constexpr int64_t n_block_end = Nr_blocks; +constexpr int64_t k_block_start = 0; +constexpr int64_t k_block_end = Kr_blocks; +{%- if is_dynamic_M %} +const int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +const int64_t m_block_end = Mr_blocks; +{%- else %} +constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; +constexpr int64_t m_block_end = Mr_blocks; +{%- endif %} +""" + +GEMM_TEMPLATE_M_LOOP_PARAMS = r""" +const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; +const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; +const int64_t m_start = mc * Mr; +const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); +const int64_t m_size = m_end - m_start; +""" + +GEMM_TEMPLATE_N_LOOP_PARAMS = r""" +const int64_t n_start = nc * Nr; +const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); +const int64_t n_size = n_end - n_start; +// NB: assume we pad N, nc_block_end won't exceed padded N here. 
+const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); +""" + +GEMM_TEMPLATE_MICROKERNEL_DEF = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} +""" + +GEMM_TEMPLATE_STUB_DEF = r""" +{%- if x_scale is not none %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} +{%- elif is_woq_int4 %} + {%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %} +{%- else %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} +{%- endif %} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} +""" + +GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + +{%- if maybe_k_slicing %} + std::unique_ptr[]> local_buf_ptrs; + if (num_Kt_blocks > 1) { + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]); + } +{%- endif %} + +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} +{%- endif %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- if use_local_acc %} + {%- set acc = kernel.local_buffers[acc_buf_name] %} + {{ kernel.reinit_buffer_if_null(acc_buf_name) }} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- if template.should_block_weights and not is_woq_int4 %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} +{%- else %} + {%- if is_woq_int4 %} + {%- set tile_W = kernel.slice_nd(W, [("nci * Nr", "(nci + 1) * Nr"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %} + {%- set tile_qparam = kernel.slice_nd( + qscale_and_zeros, [("k_start // group_size", "k_end // group_size"), ("nci * Nr", "(nci + 1) * Nr"), ()]) %} + {%- else %} + {%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %} + {%- set tile_qparam = None %} + {%- endif %} +{%- endif %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=False, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } else { + 
{{ micro_gemm.codegen_call(kernel, + tile_X, + tile_W, + acc_slice, + accum=True, + qscale_and_zeros=tile_qparam)|indent(28, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset( + {{ kernel.release_buffer(acc_buf_name) }}); + } else +{%- endif %} + { +{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_Kt_blocks > 1) { + #pragma omp barrier + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + // We slice M-dim and each thread in the k-slicing group works on a slice + const int64_t m_start_unsliced = mc * Mr; + const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks; + const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); + const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); + const int64_t m_size = m_end - m_start; + const int64_t m_offset = m_start - m_start_unsliced; + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + const int64_t n_start = nc * Nr; + const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); + const int64_t n_size = n_end - n_start; + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get(); + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get(); + for (int64_t m = m_offset; m < m_offset + m_size; m++) { + #pragma omp simd + for (int64_t n = 0; n < n_size; n++) { + {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; + } + } + } + {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} + {{ kernel.store_output( + tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- endif %} + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + +SMALL_M_GEMM_TEMPLATE = r""" +{{ template.codegen_gemm_stub_def() }} +{ + {{ kernel.maybe_codegen_profile() }} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W + ) }} + # pragma omp parallel + { + #pragma omp for nowait + for (int64_t nr_block_id = 0; nr_block_id < Nr_blocks; nr_block_id++) { + // Handle one output M * Nr block in each thread + int64_t n_start = nr_block_id * Nr; + int64_t n_end = (nr_block_id + 1) * Nr; +{%- if use_local_acc %} + {%- set acc_buf_name = "local_acc_buf" %} + {{ kernel.define_stack_allocated_buffer(acc_buf_name, ["M", "Nr"], acc_buf_dtype) }} + {%- set acc = kernel.local_buffers[acc_buf_name] %} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [(0, "M"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kr_block_id = 0; kr_block_id 
< Kr_blocks; kr_block_id++) { + // this loop is not parallelized + int64_t k_start = kr_block_id * Kr; + int64_t k_end = std::min((kr_block_id + 1) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [(0, "M"), ("k_start", "k_end")]) %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nr_block_id", "nr_block_id + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + if C10_UNLIKELY(kr_block_id == 0) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=False, prefetch=True)|indent(20, false) }} + } else if C10_UNLIKELY(k_end == K) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=False)|indent(20, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc, accum=True, prefetch=True)|indent(20, false) }} + } + } +{%- set tile_Y = kernel.slice_nd(Y_2d, [("0", "M"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "M"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("0", "n_start"), reindexers=reindexers + )|indent(20, false) }} + } + } +} +""" + + +def _is_int8_gemm(inputs): + return ( + isinstance(inputs[0], ir.IRNode) + and inputs[0].get_dtype() in [torch.uint8, torch.int8] + ) or ( + isinstance(inputs[0], torch.Tensor) + and inputs[0].dtype in [torch.uint8, torch.int8] + ) + + +def get_padded_n(n, block_n): + return (n + block_n - 1) // block_n * block_n + + +_T = TypeVar("_T", ir.IRNode, torch.Tensor) + + +def transpose_w(W: _T, trans_w: bool) -> _T: + """ + Transpose W based on the trans_w flag. + """ + if isinstance(W, ir.IRNode): + if trans_w: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) + else: + if trans_w: + assert isinstance(W, torch.Tensor) + W = W.transpose(0, 1) + return W + + +def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]: + """ + Expand Bias to the same size of X. + """ + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + B = ir.TensorBox(B) + assert hasattr(X, "get_size") + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + assert isinstance(X, torch.Tensor) + B = B.expand(X.shape[0], B.shape[-1]) + return B + + +def prune_tensors(input_nodes: list[ir.IRNode], new_input_nodes: list[ir.IRNode]): + """ + Prune unused tensors from `V.graph` since the GEMM Template use new packed weight. + """ + + def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor): + return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and ( + is_same_tensor(base_tensor, comp_tensor) + or is_same_mkldnn_tensor(base_tensor, comp_tensor) + ) + + def get_candidates(input_nodes, new_input_nodes): + # Only Constant Buffer like weight and bias might be changed in GEMM Template. + # The Inductor IR Node may changed, but still share the storage. 
For example: + # bias in bfloat16 case which only do the expand + return [ + node + for node in input_nodes + if ( + node not in new_input_nodes + and isinstance(node, (ir.TensorBox, ir.StorageBox)) + and node.get_name() in V.graph.constants + and not any( + ( + isinstance(new_node, (ir.TensorBox, ir.StorageBox)) + and new_node.get_name() in V.graph.constants + and share_storage( + V.graph.constants[node.get_name()], + V.graph.constants[new_node.get_name()], + ) + ) + for new_node in new_input_nodes + ) + ) + ] + + for candidate_node in get_candidates(input_nodes, new_input_nodes): + # By using the new packed weight for the GEMM template, we can prune the + # old weight if it has no other users. This saves memory but makes the FX graph + # non-retraceable. To support retracing, we can add a repack node to the + # FX graph. For example: + # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + candidate_tensor_users = 0 + candidate_tensor = V.graph.constants[candidate_node.get_name()] + for node in reversed(V.graph.graph.nodes): + # Case may happen when the candidate tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.target + ): # candidate tensor might already be deleted + comp_tensor = getattr(V.graph.module, node.target) + if isinstance(comp_tensor, torch.Tensor) and share_storage( + candidate_tensor, comp_tensor + ): + candidate_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The get_attr node has only 1 user fx node + # The candidate tensor has been used by only 1 get_attr node + if ( + node.op == "get_attr" + and node.target == candidate_node.get_name() + and len(node.users) == 1 + and candidate_tensor_users == 1 + ): + del V.graph.constants[node.target] + delattr(V.graph.module, node.target) + delattr(V.graph.graph.owning_module, node.target) + counters["inductor"]["select_algorithm_weight_prune"] += 1 + + +def gen_2d_view_of_epilogue_buf( + Y: ir.Buffer, + template_buffer: ir.Buffer, + epilogue_nodes: list[ir.IRNode], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + default_reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], +) -> tuple[ + Union[ir.Buffer, ir.ReinterpretView], + list[Optional[Callable[[list[Any]], list[Any]]]], +]: + """ + The dimension and the indexing could be different between the GEMM output, i.e. `template_buffer`, which is + 2D with MxN) and the output from the template after epilogues, i.e. `Y`. In the GEMM template code, + we are not aware of the dimension and the indexing of the epilogues and always work on 2D tiles according to + the indexing of the GEMM output. + In this function, we return a 2D buffer (`Y_2d`) according to GEMM output (reinterpreted from `Y` if needed) and + build a reindexer that converts the indexing of `Y` into `Y_2d`. 
+ """ + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + if ( + Y.get_size() == template_buffer.get_size() + and Y.get_stride() == template_buffer.get_stride() + ): + reindexers.extend(default_reindexers) + Y_2d = Y + else: + + def get_reindexer(epilogue_node, default_reindexer=None): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + if default_reindexer: + reshape_reindex = ir.fuse_reindexing(reshape_reindex, default_reindexer) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) # type: ignore[var-annotated] + return reindexer + + if default_reindexers is None: + default_reindexers = [None] * len(epilogue_nodes) + new_reindexers = [ + get_reindexer(epilogue_node, default_reindexer) + for epilogue_node, default_reindexer in zip( + epilogue_nodes, default_reindexers + ) + ] + reindexers.extend(new_reindexers) + if isinstance(Y, ir.BaseView): + storage = ir.StorageBox(Y.unwrap_view()) + else: + assert isinstance(Y, ir.Buffer) + storage = ir.StorageBox(Y) + Y_2d = ir.ReinterpretView(data=storage, layout=template_buffer.get_layout()) + return Y_2d, reindexers + + +class CppGemmTemplate(CppTemplate): + """ + GEMM Template for Inductor CPP Backend. 
+ """ + + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + should_block_weights: bool = True, + name="packed_gemm", + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] + super().__init__( + name, + input_nodes, + layout, + num_threads, + epilogue_creator=epilogue_creator, + ) + self.beta = beta + self.alpha = alpha + self.has_bias = has_bias + self.register_blocking = register_blocking + m, n = layout.size[-2:] + k = input_nodes[0].get_size()[-1] + self.m, self.n, self.k = m, n, k + self.padded_n = get_padded_n(n, self.register_blocking.block_n) + self.is_dynamic_M = has_free_symbols((m,)) + self.should_block_weights = should_block_weights + self.thread_blocking = self.make_thread_blocking_cache() + self.cache_blocking = self.make_cache_blocking_cache() + + def make_thread_blocking_cache(self): + cache = lru_cache()(self._thread_blocking) + + def thread_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return thread_blocking + + def _thread_blocking(self, num_threads: int) -> GemmBlocking: + """ + NOTE [Thread blocking in Cpp GEMM] + We use simple heuristics to decide the thread blocking: + 1. Make sure all threads are occupied as much as possible. + 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. + 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. + TODO(jgong5): allow tuning various blocking options + """ + + def get_factors(number): + factors = [] + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): + thread_block_k = math.ceil(k_blocks / k_factor) + thread_block_n = math.ceil(n_blocks / n_factor) + thread_block_m = math.ceil(m_blocks / m_factor) + return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) + + assert not self.is_dynamic_M, ( + "Unable to determine thread blocking for dynamic M." 
+ ) + register_blocking = self.register_blocking + m_blocks = math.ceil(self.m / register_blocking.block_m) + n_blocks = math.ceil(self.n / register_blocking.block_n) + k_blocks = math.ceil(self.k / register_blocking.block_k) + factors = get_factors(num_threads) + assert len(factors) > 0 + + if config.cpp.gemm_thread_factors is not None: + factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] + assert len(factors) == 3 + assert math.prod(factors) == self.num_threads + return get_blocking( + factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks + ) + + # we favor square-sized thread blocks for good data reuse + def get_better_blocking(blocking, best_blocking): + if best_blocking is None: + best_blocking = blocking + else: + block_m_size = blocking.block_m * register_blocking.block_m + block_n_size = blocking.block_n * register_blocking.block_n + best_block_m_size = best_blocking.block_m * register_blocking.block_m + best_block_n_size = best_blocking.block_n * register_blocking.block_n + if blocking.block_k > best_blocking.block_k: + best_blocking = blocking + elif ( + blocking.block_k == best_blocking.block_k + and block_m_size + block_n_size + < best_block_m_size + best_block_n_size + ): + best_blocking = blocking + return best_blocking + + best_blocking = None + # check if we can have a thread-blocking to occupy all threads without k-slicing + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for k_factor in factors: + if k_blocks >= k_factor and ( + config.cpp.gemm_max_k_slices == 0 + or k_factor <= config.cpp.gemm_max_k_slices + ): + n_factors = get_factors(num_threads // k_factor) + for n_factor in n_factors: + m_factor = (num_threads // k_factor) // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, + n_factor, + k_factor, + m_blocks, + n_blocks, + k_blocks, + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for n_factor in factors: + m_factor = num_threads // n_factor + if n_blocks >= n_factor or m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + assert best_blocking is not None + return best_blocking + + def make_cache_blocking_cache(self): + cache = lru_cache()(self._cache_blocking) + + def cache_blocking(num_threads: int) -> GemmBlocking: + return cache(num_threads) + + return cache_blocking + + def _cache_blocking(self, num_threads: int) -> GemmBlocking: + def get_cache_blocking(register_blocking, thread_blocking): + Mr = register_blocking.block_m + Nr = register_blocking.block_n + Kr = register_blocking.block_k + + Mt_blocks = thread_blocking.block_m + Nt_blocks = thread_blocking.block_n + Kt_blocks = thread_blocking.block_k + + if config.cpp.gemm_cache_blocking is not None: + blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] + assert len(blockings) == 3 + Mc_blocks, Nc_blocks, Kc_blocks = blockings + return ( + min(Mc_blocks, Mt_blocks), + min(Nc_blocks, Nt_blocks), + min(Kc_blocks, Kt_blocks), + ) + + # The ratios below are empirically determined to decide + # the effective sizes of L1 and L2. 
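+            # For example, assuming a hypothetical core with a 48 KiB L1d and a 1.25 MiB L2,
+            # the budgets computed below would be roughly 0.8 * 48 KiB ~= 38 KiB of L1 and
+            # 0.5 * 1.25 MiB ~= 640 KiB of L2 available for cache blocking.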
+ # TODO: tune the factor here + L1_limit_factor = 0.8 + L2_limit_factor = 0.5 + + L1_cache_size = ( + torch._C._cpu._L1d_cache_size() + ) # per core cache size in Bytes + assert L1_cache_size > 0, ( + f"Expect L1_cache_size > 0 but got {L1_cache_size}" + ) + L1 = L1_cache_size * L1_limit_factor + + L2_cache_size = ( + torch._C._cpu._L2_cache_size() + ) # per core cache size in Bytes + assert L2_cache_size > 0, ( + f"Expect L2_cache_size > 0 but got {L2_cache_size}" + ) + L2 = L2_cache_size * L2_limit_factor + + def get_num_byte(dtype): + return torch.tensor([], dtype=dtype).element_size() + + dtype_A = self.input_nodes[0].get_dtype() + dtype_B = self.input_nodes[1].get_dtype() + num_byte_A = get_num_byte(dtype_A) + num_byte_B = get_num_byte(dtype_B) + if dtype_A is torch.bfloat16 and dtype_B is torch.int8 and Kr != 1: + # We will cache dequantized weights (BF16) in L1D for AMX micro-kernel. + # In this case, the choice of the micro-kernel being used can't be decoupled from + # the cache blocking. + # TODO: Decouple the choice of micro-kernel from cache blocking + num_byte_B *= num_byte_A + + # NOTE [CPP GEMM Cache Blocking Algorithm] + # Our overall strategy is to + # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. + # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to + # decide Kc. We want to make Mc large enough to better reuse B. + # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A + # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. + + # Step 1: Decide Kc assuming B block is L1-reside. + size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + + Kc_blocks = Kt_blocks + if size_cache_B > L1: + Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + + if ( + config.cpp.use_small_dequant_buffer + and dtype_A is torch.bfloat16 + and dtype_B is torch.uint8 + and Mt_blocks == 1 + ): + # Make a small dequant_B buffer for woq int4 [q_group_size, Nr] + # Since when Mt_blocks == 1, L1-reside B block can't be reused by A. + if Kc_blocks * Kr >= self.q_group_size(): + Kc_blocks = self.q_group_size() // Kr + + # Step 2: Decide Mc assuming A block is L2-reside. + min_Mc_ratio = 2 # TODO(jgong5): something to tune? + min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) + assert min_Mc_blocks >= 1 + Kt_bytes = Kt_blocks * Kr * num_byte_A + if min_Mc_blocks * Mr * Kt_bytes < L2: + # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt + # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) + # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside + # in L1. + Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) + Nc_blocks = 1 + else: + # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse + # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. + Mc_blocks = Mt_blocks + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 + Kc_bytes = Kc_blocks * Kr * num_byte_A + if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: + # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, + # assuming Mc == Nc for good data reuse. + M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 + if M_max < Mc_blocks * Mr: + Mc_blocks = math.floor(M_max / Mr) + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + + return Mc_blocks, Nc_blocks, Kc_blocks + + assert not self.is_dynamic_M, ( + "Unable to determine cache blocking for dynamic M." 
+ ) + register_blocking = self.register_blocking + thread_blocking = self.thread_blocking(num_threads) + + return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) + + def log_blockings(self): + log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 + if self.is_dynamic_M: + # thread and cache blockings are determined at runtime for dynamic shapes + return + log.debug( + f"Cache blocking: {self.cache_blocking(self.num_threads)}" # noqa: G004 + ) + thread_blocking = self.thread_blocking(self.num_threads) + log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 + + def get_occupancy(): + m_blocks = math.ceil(self.m / self.register_blocking.block_m) + n_blocks = math.ceil(self.n / self.register_blocking.block_n) + k_blocks = math.ceil(self.k / self.register_blocking.block_k) + m = math.ceil(m_blocks / thread_blocking.block_m) + n = math.ceil(n_blocks / thread_blocking.block_n) + k = math.ceil(k_blocks / thread_blocking.block_k) + return (m, n, k) + + log.debug( + f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 + ) + + def maybe_k_slicing(self): + if self.num_threads == 1: + return False + if self.is_dynamic_M: + # TODO(jgong5): perhaps use size hint to decide? + return True + register_blocking = self.register_blocking + k_blocks = math.ceil(self.k / register_blocking.block_k) + thread_blocking = self.thread_blocking(self.num_threads) + return k_blocks > thread_blocking.block_k + + @classmethod + def add_choices( + cls, + choices, + layout, + input_nodes, + beta=1, + alpha=1, + has_bias=False, + trans_w=False, + input_indices=None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + ): + """ + Add choices for the GEMM template. + """ + # Fast path to save the epilogue calculation when x_scale/x_zp/w_scale are constant + use_int8_fast_compensation_path = _is_int8_gemm(input_nodes) and all( + ( + isinstance(input_nodes[idx], ir.TensorBox) + and isinstance(input_nodes[idx].data.data, ir.ConstantBuffer) + ) + for idx in [1, 2, 4] + ) + + if input_indices is None: + input_indices = list(range(len(input_nodes))) + only_one_input = ( + input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False + ) + + def reorder_and_filter(inputs, layout_or_out): + if has_bias: + assert len(input_indices) >= 3 + # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [ + inputs[x_idx], + inputs[w_idx], + inputs[inp_idx], + *[inputs[idx] for idx in input_indices[3:]], + ], layout_or_out + elif len(inputs) >= len(input_indices): + assert len(input_indices) >= 2 + return [inputs[idx] for idx in input_indices], layout_or_out + else: + # For when input is used for x and w, i.e. X@X.T or similar + # Assumes the first input is the only input + assert len(inputs) == 1 + return [inputs[0]] * len(input_indices), layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + is_mkldnn_wgt = ( + new_inputs[1].get_name() in V.graph.constants + and V.graph.constants[new_inputs[1].get_name()].is_mkldnn + ) + if is_mkldnn_wgt: + # It shouldn't happen as viewing an mkldnn tensor, we can extend the + # implementation if it does. 
+ assert not isinstance(new_inputs[1], ir.BaseView) + # Note that the layout of MKLDNN Tensor is with the wrong stride + view_size = new_inputs[1].layout.size + view_stride = new_inputs[1].layout.stride + view_offset = new_inputs[1].layout.offset + + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): + new_inputs = list(inputs) + if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor): + if has_free_symbols(view_size): + # If batch size B is dynamic, we need to set the batch size and possibly stride + assert not has_free_symbols(view_size[1:]) + view_size[:] = V.graph.sizevars.size_hints(view_size) + view_stride[:] = V.graph.sizevars.size_hints(view_stride) + # With the assumptation that W is the storage of unwrap view + # thus view it back here + new_inputs[1] = new_inputs[1].as_strided( + view_size, view_stride, view_offset + ) + + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + W = new_inputs[1] + B = new_inputs[2] if has_bias else None + W = transpose_w(W, trans_w) + B = expand_bias(B, X) # type:ignore[arg-type] + new_inputs[1] = W + if B is not None: + new_inputs[2] = B + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args( + new_inputs[0], + new_inputs[1], + mat2_transposed=cls.is_woq_int4(), + use_4x2_dim=cls.is_woq_int4(), + ) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[1].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + use_ref=not cls.is_woq_int4(), + q_group_size=cls.q_group_size(), + ) + assert micro_gemm is not None + pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm) + micro_gemm.use_local_vnni_blocking(not pre_block_weights) + + def preprocessor(inputs, layout): + new_inputs, new_layout = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(inputs, layout)) + ) + if only_one_input and isinstance(new_inputs[0], torch.Tensor): + return new_inputs[1:], new_layout + return cls.prep_weight( + new_inputs, + new_layout, + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + ) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + if W_node.get_name() not in V.graph.constants: + return output + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, new_layout = normalize_shapes( + *maybe_to_dense(new_input_nodes, layout) + ) + new_input_nodes, _ = cls.prep_weight( + new_input_nodes, + new_layout, + micro_gemm, + pre_block_weights, + use_int8_fast_compensation_path, + skip_int8_compensation=True, + ) + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + new_input_nodes[1] = W_packed_constant + + 
# Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + cls, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + should_block_weights=pre_block_weights, + name=micro_gemm.__class__.__name__, + ) + template.maybe_append_choice(choices) + return template + + @staticmethod + def get_padded_size(n, block_n, k, should_block_weight): + padded_n = get_padded_n(n, block_n) + # We assume that all GEMM weight tensors should be blocked and padded + new_size = [padded_n // block_n, k, block_n] + return new_size, padded_n + + @classmethod + def prep_weight( + cls, + inputs, + layout: ir.Layout, + micro_gemm: CppMicroGemm, + should_block_weight: bool, + use_int8_fast_compensation_path: bool = False, + skip_int8_compensation: bool = False, + ): + """ + NOTE Weight prep consists of 2 separate steps: + 1. Blocking the weight tensor into a 3D shape: [n//block_n, k, block_n] + This is always done if the weight tensor is constant, i.e. for all GEMM and some BMM. + For BMM, we also block non-contiguous weight tensors, since they would be reshaped anyway. + This assumes that blocked, contiguous weights will be more efficient for the GEMM kernel, + and is worth the overhead of reshape and blocking. + + This blocking includes additional padding, when n is not a multiple of block_n. + This padding allows a more efficient microkernel implementation. For BMM, this is only done + if reshape would happen anyway, i.e. if the weight tensor is constant, is not contiguous, + or is using AMX VNNI layout. + 2. Packing the weight tensor into a VNNI-friendly shape. For constant input, + this is done at the same time as the weight blocking. + + At compile time, the constant weight tensors are blocked and packed. For non-constant tensors (e.g. BMM) + which will be blocked (non-contiguous or VNNI-layout tensors), the weight tensor is blocked and packed at runtime. + + CppBmmTemplate overrides the methods get_padded_size, and block_weight in order to accommodate + an additional dimension for the batch size and to determine if the weight tensor should be blocked. + """ + W = inputs[1] + new_inputs = list(inputs) + if cls.is_woq_int4(): + assert ( + len(W.get_size()) == 2 + if isinstance(W, ir.IRNode) + else len(W.shape) == 2 + ) + n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape + else: + k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:] + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight) + padding = padded_n - n + + if should_block_weight and not cls.is_woq_int4(): + blocked_w = cls.block_weight(W, new_size, padding) + new_inputs[1] = cls.pack_vnni_weight(blocked_w, micro_gemm, new_size) + elif should_block_weight: + assert cls.is_woq_int4() + new_inputs[1] = cls.block_weight(W, new_size, padding) + elif isinstance(W, ir.IRNode): + # Require W layout to be fixed & contiguous, happens inplace. 
+ ir.ExternKernel.require_contiguous(W) + + if not skip_int8_compensation and _is_int8_gemm(new_inputs): + BCompensate = None + x_w_scale = None + + def _get_compensation_node(W, use_int8_fast_compensation_path): + BCompensate = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_BMatrixCompens"], + W.get_name() + "_BMatrixCompens", + ) + x_w_scale = None + if use_int8_fast_compensation_path: + x_w_scale = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_x_w_compens"], + W.get_name() + "_x_w_compens", + ) + return BCompensate, x_w_scale + + if use_int8_fast_compensation_path: + # new_inputs has been reordered: [x, w, optional[bias], x_scale, x_zp, w_scale, w_zp] + x_scale = new_inputs[-4] + x_zp = new_inputs[-3] + w_scale = new_inputs[-2] + if isinstance(W, ir.IRNode): + BCompensate, x_w_scale = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + assert all( + isinstance(item, torch.Tensor) + for item in (x_scale, x_zp, w_scale) + ) + BCompensate = BCompensate * x_scale * w_scale * x_zp + x_w_scale = x_scale * w_scale + new_inputs.append(BCompensate) + new_inputs.append(x_w_scale) + else: + if isinstance(W, ir.IRNode): + BCompensate, _ = _get_compensation_node( + W, use_int8_fast_compensation_path + ) + else: + # Use the original W, not the blocked_w in new_inputs[1] to calculate BCompensate + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + new_inputs.append(BCompensate) + return new_inputs, layout + + @staticmethod + def check_if_block_weight(W, micro_gemm): + return True + + @classmethod + def block_weight(cls, W, new_size, padding): + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if W.get_name() in V.graph.constants: + # Create a new buffer, representing the constant blocked tensor + blocked_w = ir.Buffer( + name=W.get_name(), # Borrow the registered buffer name + layout=ir.FixedLayout( + W.get_device_or_error(), + W.get_dtype(), + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + 0, + ), + ) + else: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + permute_dims = list(range(len(new_size))) + permute_dims[-2], permute_dims[-3] = permute_dims[-3], permute_dims[-2] + permute_size = list(new_size) + permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2] + blocked_w = L.constant_pad_nd(W, (0, padding)) + blocked_w = L.permute( + L.view(blocked_w, permute_size), + permute_dims, + ) + else: + assert isinstance(W, torch.Tensor) + # Pad the weight tensor and reshape it into a 3D blocked shape + blocked_size = list(new_size) + blocked_size[-2], blocked_size[-3] = blocked_size[-3], blocked_size[-2] + blocked_w = ( + torch.nn.functional.pad(W, (0, padding)) # type: ignore[assignment] + .reshape(*blocked_size) + .transpose(-3, -2) + .contiguous() + ) + return blocked_w + + @classmethod + def pack_vnni_weight(cls, W, micro_gemm, new_size): + # WOQ INT4 weights are reordered in microkernel so do not pack them here + should_pack = ( + micro_gemm.get_b_layout() != LayoutType.NORMAL + and not micro_gemm.is_woq_int4() + ) + + # These are separated into two methods to allow subclasses to override them separately + if isinstance(W, ir.IRNode): + if isinstance(W, ir.Buffer) and W.get_name() in V.graph.constants: + return W + k = 
new_size[-2] + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + if should_pack: + permute_dims = list(range(len(new_size) + 1)) + permute_dims[-1], permute_dims[-2] = permute_dims[-2], permute_dims[-1] + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = L.view( + L.permute(L.view(W, vnni_view_size), permute_dims), + new_size, + ) + W = ir.ExternKernel.realize_input(W) + W = ir.ExternKernel.require_contiguous(W) + return W + else: + k = new_size[-2] + # Apply VNNI packing to the weight tensor + if should_pack: + # TODO: Move VNNI weight packing for non-constant tensors into the template, + # to improve cache locality and avoid full-tensor copy. + layout_str = ( + "VNNI4" + if micro_gemm.get_b_layout() == LayoutType.VNNI4 + else "VNNI2" + ) + assert micro_gemm.get_b_layout() in [ + LayoutType.VNNI2, + LayoutType.VNNI4, + ], f"We only support {layout_str} for now" + vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + assert k % vnni_size == 0, ( + f"k should be divisible by vnni_size for {layout_str} layout" + ) + vnni_view_size = list(new_size) + vnni_view_size[-2] = k // vnni_size + vnni_view_size.insert(-1, vnni_size) + W = W.view(vnni_view_size).transpose(-1, -2).contiguous().view(new_size) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(W.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + W = W.as_strided(W.shape, new_stride) + return W + + def get_default_reindexers(self, epilogue_nodes): + return [None] * len(epilogue_nodes) + + def get_options( + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ) -> dict[str, Any]: + assert len(self.input_nodes) >= 2 + + int8_gemm = self.input_nodes[0].get_dtype() in [torch.uint8, torch.int8] + x_scale = None + x_zp = None + w_scale = None + w_zp = None + inp = None + q_group_size_node = None + qscale_and_zeros = None + if int8_gemm: + X, W = self.input_nodes[0], self.input_nodes[1] + bias_idx = 2 if self.has_bias else 1 + inp = self.input_nodes[bias_idx] if self.has_bias else None + x_scale = self.input_nodes[bias_idx + 1] + x_zp = self.input_nodes[bias_idx + 2] + w_scale = self.input_nodes[bias_idx + 3] + w_zp = self.input_nodes[bias_idx + 4] + Y = self.output_node + elif self.is_woq_int4(): + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + q_group_size_node = self.input_nodes[2] + qscale_and_zeros = self.input_nodes[3] + else: + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + inp = self.input_nodes[2] if self.has_bias else None + + template_buffer_has_other_users = None + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + + assert flag_template_buffer_has_other_users is not None + template_buffer_has_other_users = flag_template_buffer_has_other_users + + template_buffer = Y + gemm_output_buffer = template_buffer + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = [] + fake_buffers: list[ir.Buffer] = [] + Y_aliases: OrderedSet[str] = OrderedSet() 
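# ---------------------------------------------------------------------------
# NOTE (illustrative sketch, not part of the original template): pack_vnni_weight
# above interleaves pairs (VNNI2, typically bf16/fp16) or quadruples (VNNI4,
# typically int8) of consecutive K rows of each [K, block_n] weight block, so
# the values one dot-product instruction consumes sit adjacent in memory. A
# minimal plain-PyTorch sketch for a single 2D block, assuming K is divisible
# by vnni_size:
def _example_vnni_pack_block(Wb: torch.Tensor, vnni_size: int = 2) -> torch.Tensor:
    K, block_n = Wb.shape
    assert K % vnni_size == 0
    # [K, block_n] -> [K // vnni_size, vnni_size, block_n] -> swap the last two dims
    # so that (k, n) is immediately followed by (k + 1, n), ..., (k + vnni_size - 1, n)
    return (
        Wb.view(K // vnni_size, vnni_size, block_n)
        .transpose(-1, -2)
        .contiguous()
        .view(K, block_n)
    )
# ---------------------------------------------------------------------------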
+ + use_local_acc = ( + self.layout.dtype != torch.float + or template_buffer_has_other_users + or int8_gemm + or self.padded_n != self.n + or self.maybe_k_slicing() + or (epilogue_nodes and epilogue_nodes[-1].get_dtype() != self.layout.dtype) + ) + + # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template, + # but we'd better move it here to align with fp. + if inp is not None and self.beta != 0 and not int8_gemm: + # add an epilogue for bias add + def _bias_add_epilogue(buf): + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + epilogue_creators.append(_bias_add_epilogue) + + if self.epilogue_creator is not None: + epilogue_creators.append(self.epilogue_creator) + + # When the GEMM output buffer is localized but it has users other than the epilogue nodes, + # we need to copy the value in the GEMM output local buffer to a global buffer. + def need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + # The GEMM output buffer is a global buffer, thus copy is not needed. + if not use_local_acc: + return False + + # The possible value of template_buffer_has_other_users is (None, False, True) + # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. + # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: + # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) + # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the + # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). + if not template_buffer_has_other_users: + return False + + # When bias is not None or self.epilogue_creator is not None, + # there will be epilogue_creators after the GEMM. + # The GEMM output buffer is localized while + # the output buffer of the epilogue_creators is a global buffer. 
+ if epilogue_creators: + return False + + return True + + if need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + + def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): + dtype = self.layout.dtype + input_loader = input_buffer.make_loader() + + def copy_inner(index): + input = input_loader(index) + result = ops.to_dtype(input, dtype) + return result + + return ir.Pointwise( + device=input_buffer.get_device_or_error(), + dtype=self.layout.dtype, + inner_fn=copy_inner, + ranges=input_buffer.get_size(), + ) + + epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) + + # NOTE [How CPP GEMM template epilogues are organized] + # gemm_output_buffer + # --> zero or more in-template epilogues (created by `epilogue_creators`) --> + # template_buffer + # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> + # Y + if epilogue_creators: + assert isinstance(template_buffer, ir.IRNode) + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + gemm_output_buffer = ir.Buffer( + name=gemm_output_name, layout=template_buffer.layout + ) + current_input_buffer = gemm_output_buffer + for i, creator in enumerate(epilogue_creators): + if i == len(epilogue_creators) - 1: + buffer_name = template_buffer.get_name() + else: + buffer_name = f"{gemm_output_name}_epilogue_{i}" + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=creator(current_input_buffer), + ) + ) + fake_buffers.append(current_input_buffer) + Y_aliases.add(current_input_buffer.get_name()) + reindexers.append(None) + if i < len(epilogue_creators) - 1: + current_input_buffer = ir.Buffer( + name=buffer_name, layout=template_buffer.layout + ) + + assert isinstance(Y, (ir.Buffer, ir.ReinterpretView)) + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + + if epilogue_nodes: + if not template_buffer_has_other_users: + assert isinstance(template_buffer, ir.IRNode) + Y_aliases.add(template_buffer.get_name()) + epilogues.extend(epilogue_nodes) + assert Y.get_numel() == epilogues[-1].get_numel() + Y = cast(ir.Buffer, epilogues[-1]) + assert isinstance(template_buffer, ir.Buffer) + Y_2d, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + epilogue_nodes, + reindexers, + default_reindexers=self.get_default_reindexers(epilogue_nodes), + ) + + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X.get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X.get_dtype(), + input2_dtype=W.get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + use_ref=not self.is_woq_int4(), + q_group_size=self.q_group_size(), + ) + assert micro_gemm is not None + micro_gemm.use_local_vnni_blocking(not self.should_block_weights) + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + if isinstance(micro_gemm, CppMicroBrgemm): + counters["inductor"]["cpp_micro_brgemm_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + 
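# ---------------------------------------------------------------------------
# NOTE (illustrative sketch, not part of the original template): numerically,
# for the non-int8 case the generated kernel plus the epilogue chain described
# above accumulates alpha * X @ W in fp32, applies the optional bias_add
# epilogue with beta and any out-of-template epilogue nodes, and only then
# converts to the output dtype. A plain-PyTorch reference of those numerics,
# usable as a rough correctness oracle (epilogue is a stand-in for the fused
# epilogue_nodes):
def _example_gemm_reference(X, W, bias=None, alpha=1.0, beta=1.0,
                            out_dtype=torch.bfloat16, epilogue=None):
    acc = alpha * (X.float() @ W.float())        # local fp32 accumulator
    if bias is not None:
        acc = acc + beta * bias.float()          # in-template bias_add epilogue
    if epilogue is not None:
        acc = epilogue(acc)                      # out-of-template epilogue nodes
    return acc.to(out_dtype)                     # final conversion to Y's dtype
# ---------------------------------------------------------------------------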
options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + GemmOut=gemm_output_buffer, + aliases={alias: Y.get_name() for alias in Y_aliases}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + epilogue_nodes=epilogues, + reindexers=reindexers, + Y_2d=Y_2d, + use_local_acc=use_local_acc, + maybe_k_slicing=self.maybe_k_slicing(), + x_scale=x_scale, + x_zp=x_zp, + w_scale=w_scale, + w_zp=w_zp, + acc_buf_dtype=torch.int32 if int8_gemm else torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + fake_buffers=fake_buffers, + is_woq_int4=self.is_woq_int4(), + q_group_size=q_group_size_node, + qscale_and_zeros=qscale_and_zeros, + ) + return options + + def is_int8_woq_gemm_small_m_dim( + self, + X: ir.ReinterpretView, + W: ir.ReinterpretView, + N, + K, + micro_gemm, + ): + """Use SMALL_M_GEMM_TEMPLATE""" + return ( + isinstance(micro_gemm, CppMicroGemmFP32Vec) + and is_int8_woq_gemm_small_m_dim_corner_case( + micro_gemm, X.get_size()[0], N, K + ) + and X.get_dtype() is torch.bfloat16 + and W.get_dtype() is torch.int8 + ) + + def render( # type: ignore[override, return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + options = self.get_options( + kernel=kernel, + template_buffer_node=template_buffer_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + ) + self.render_options = options + + with contextlib.ExitStack() as stack: + for buf in options["fake_buffers"]: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + if not options["is_dynamic_M"] and self.is_int8_woq_gemm_small_m_dim( + options["X"], + options["W"], + options["N"], + options["K"], + options["micro_gemm"], + ): + template_str = SMALL_M_GEMM_TEMPLATE + else: + template_str = GEMM_TEMPLATE + return self._template_from_string(template_str).render(**options) + + def codegen_blocks( + self, + num_threads, + N, + K, + micro_gemm, + is_dynamic_M, + kernel, + GemmOut, + config, + L1_cache_size, + L2_cache_size, + X, + W, + ): + options = dict( + num_threads=num_threads, + N=N, + K=K, + micro_gemm=micro_gemm, + is_dynamic_M=is_dynamic_M, + kernel=kernel, + GemmOut=GemmOut, + config=config, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + template=self, + X=X, + W=W, + is_woq_int4=self.is_woq_int4(), + ) + template_str = GEMM_TEMPLATE_INIT_BLOCKING_BASIC_BLOCK + if not ( + not is_dynamic_M + and self.is_int8_woq_gemm_small_m_dim(X, W, N, K, micro_gemm) + ): + template_str += GEMM_TEMPLATE_INIT_BLOCKING_EXTENDED + return self._template_from_string(template_str).render(options) + + def codegen_microkernel_def(self): + return self._template_from_string(GEMM_TEMPLATE_MICROKERNEL_DEF).render( + self.render_options + ) + + def codegen_gemm_stub_def(self): + microkernel = self.codegen_microkernel_def() + return microkernel + self._template_from_string(GEMM_TEMPLATE_STUB_DEF).render( + self.render_options + ) + + def codegen_multi_threads_params(self): + return self._template_from_string(GEMM_TEMPLATE_MULTI_THREADS_PARAMS).render() + + def codegen_single_thread_params(self, 
is_dynamic_M): + options = dict( + is_dynamic_M=is_dynamic_M, + ) + return self._template_from_string(GEMM_TEMPLATE_SINGLE_THREAD_PARAMS).render( + options + ) + + def codegen_m_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_M_LOOP_PARAMS).render() + + def codegen_n_loop_params(self): + return self._template_from_string(GEMM_TEMPLATE_N_LOOP_PARAMS).render() + + @classmethod + def is_woq_int4(cls): + return False + + @classmethod + def q_group_size(cls): + return None + + +class CppWoqInt4GemmTemplateMeta(type): + def __getitem__(cls, q_group_size): + class CppWoqInt4GemmTemplateInstance(CppGemmTemplate): + def __init__( + self, + *args, + **kwargs, + ) -> None: + super().__init__( + *args, + **kwargs, + ) + + @classmethod + def is_woq_int4(cls): + return True + + @classmethod + def q_group_size(cls): + return q_group_size + + @staticmethod + def check_if_block_weight(W, micro_gemm): + # For WOQ INT4, weight is already packed + # However, for AMX microkernel, we want to change the blocking of weight + from .cpp_micro_gemm import CppMicroGemmWoQInt4Amx + + return isinstance(micro_gemm, CppMicroGemmWoQInt4Amx) + + @classmethod + def block_weight(cls, W, new_size, padding): + # This method is called only if AMX microkernels are used. + # In this case, we unpack and repack weight so that block_n=32 + # the format of packed weight is described here: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 + if isinstance(W, ir.IRNode): + # in this case, we do nothing + ir.ExternKernel.require_contiguous(W) + blocked_w = W + else: + # in this case, we unpack and repack weight + assert isinstance(W, torch.Tensor) + assert W.dim() == 2 + N = W.size(0) + K = W.size(-1) * 2 + G = cls.q_group_size() + # x and qscales_and_zeros are in bfloat16 instead of float to use the optimized kernel + # so that the unpacking process is faster + x = torch.eye(K).bfloat16() + # Here we use scale=1 and qzero=8 because we want to unpack weight + # without dequantizing it. The qzero here is 8 instead of 0 because + # int4 values are converted to [-7, 8] in the _weight_int4pack_mm_for_cpu kernel: + # https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L95 + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .bfloat16() + .expand(K // G, N, 2) + .contiguous() + ) + # shape: [K, N] + unpacked_w = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + W, + G, + qscales_and_zeros, + ).to(torch.uint8) + block_n = 32 + # shape: [N // block_n, K, block_n] + w_blocked = ( + unpacked_w.view(K, N // block_n, block_n) + .permute(1, 0, 2) + .contiguous() + ) + # pack 2 int4 -> 1 int8 + # block_n: [a0, a1, ..., a15, b0, b1, ..., b15] + # -> [(a0 & 0xf) | (b0 << 4), (a1 & 0xf) | (b1 << 4), ...] 
+ # shape: [N // block_n, K, 2, block_n // 2] + w_blocked = w_blocked.view(N // block_n, K, 2, block_n // 2) + # shape: [N // block_n, K, block_n // 2] + w_blocked_packed = (w_blocked[:, :, 0, :] & 0xF) | ( + w_blocked[:, :, 1, :] << 4 + ) + # shape: [N, K // 2] + blocked_w = w_blocked_packed.view(N, K // 2) + + return blocked_w + + return CppWoqInt4GemmTemplateInstance + + +class CppWoqInt4GemmTemplate(metaclass=CppWoqInt4GemmTemplateMeta): + pass diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc86cb3b334c158bd62e475854415ac9d8b5378 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_grouped_gemm_template.py @@ -0,0 +1,500 @@ +import contextlib +import logging +from typing import Any, Callable, cast, Optional, TypeVar +from unittest.mock import patch + +import torch +import torch.utils +from torch.utils._ordered_set import OrderedSet + +from ..._dynamo.utils import counters +from .. import config, ir +from ..kernel.mm_common import mm_args +from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper +from ..utils import parallel_num_threads +from ..virtualized import V +from .cpp import get_export_declaration +from .cpp_gemm_template import ( + CppGemmTemplate, + expand_bias, + gen_2d_view_of_epilogue_buf, + prune_tensors, + transpose_w, +) +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +{{micro_gemm.codegen_define(kernel)}} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}} +{ + {{kernel.maybe_codegen_profile()}} + {{ template.codegen_blocks( + num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0] + ) }} +{%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + {{ template.codegen_multi_threads_params()|indent(8, false) }} +{%- else %} + { + {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }} +{%- endif %} + {{ micro_gemm.codegen_init(kernel) }} +{%- set acc_buf_name_list=[] %} +{%- set acc_buf_name_prefix = "local_acc_buf_" %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} + {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %} +{%- endfor %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + {{ template.codegen_m_loop_params()|indent(12, false) }} + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + {{ template.codegen_n_loop_params()|indent(16, false) }} +{%- set acc_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %} + {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }} +{%- endfor %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + 
Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %} +{%- endfor %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set tile_W_3d_list=[] %} +{%- set tile_W_list=[] %} +{%- set acc_slice_list=[] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set acc_slice_list = acc_slice_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) + ) %} + {%- set tile_W_3d_list = tile_W_3d_list.append( + kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()]) + ) %} +{%- endfor %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_W_list = tile_W_list.append( + kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n]) + ) %} +{%- endfor %} + if (kc == k_block_start) { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False + )|indent(28, false) }} + {%- endfor %} + } else { + {%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {{ micro_gemm.codegen_call( + kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True + )|indent(28, false) }} + {%- endfor %} + } + } + } + { +{%- set tile_acc_list = [] %} +{%- set tile_Y_list = [] %} +{%- for gemm_idx in range(0, gemm_grouped_num, 1) %} + {%- set tile_acc_list = tile_acc_list.append( + kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")]) + ) %} + {%- set tile_Y_list = tile_Y_list.append( + kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")]) + ) %} +{%- endfor %} + {{ kernel.store_outputs( + tile_Y_list, + tile_acc_list, + GemmOuts, + epilogue_nodes, + offsets=("m_start", "n_start"), + reindexers=reindexers, + multi_output_buffers=multi_output_buffers + )|indent(20, false) + }} + } + } + } + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + + +def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> list[ir.IRNode]: + act_deduplicated = [] + act_deduplicated_name: OrderedSet[str] = OrderedSet() + for act_idx in range(len(act_mapping.values())): + act = act_mapping[act_idx] + if act.get_name() not in act_deduplicated_name: + act_deduplicated.append(act) + act_deduplicated_name.add(act.get_name()) + return act_deduplicated + + +class CppGroupedGemmTemplate(CppGemmTemplate): + def __init__( + self, + input_nodes: list[ir.IRNode], + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta: int = 1, + alpha: int = 1, + has_bias: bool = False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, + gemm_grouped_num: int = 1, + ) -> None: + """ + Template for Group of GEMMs: + * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc) + for their A, B, and C matrices. + * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues. + * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues. + This behavior can be extended in the future if needed. 
+ """ + super().__init__( + input_nodes, + layout, + num_threads, + register_blocking, + beta, + alpha, + has_bias, + epilogue_creator, + ) + self.act_mapping = act_mapping + self.gemm_grouped_num = gemm_grouped_num + self.output_node: list[ir.Buffer] = [ + ir.Buffer(name="buf_out" + str(idx), layout=layout) + for idx in range(gemm_grouped_num) + ] + + @classmethod + def add_choices( + cls, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[ir.IRNode], + beta: int = 1, + alpha: int = 1, + has_bias: tuple[bool, ...] = (False, False), + trans_w: bool = False, + input_indices: Optional[list[int]] = None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf + ) -> DataProcessorTemplateWrapper: + # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ... + gemm_grouped_num = len(has_bias) + assert act_mapping + act_deduplicated = get_deduplicated_act(act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + gemm_grouped_num + input_indices = list(range(len(input_nodes))) + + _T = TypeVar("_T", ir.IRNode, torch.Tensor) + _U = TypeVar("_U", ir.Layout, torch.Tensor) + + def reorder_and_filter( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + assert input_indices is not None, "input_indices must be set" + return [inputs[idx] for idx in input_indices], layout_or_out + + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + + def maybe_to_dense( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs = list(inputs) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + if isinstance(inputs[idx], torch.Tensor): + W = inputs[idx] + assert isinstance(W, torch.Tensor), "W must be a torch.Tensor" + new_inputs[idx] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_inputs: list[_T] = list(inputs) + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + new_input = new_inputs[wgt_idx] + new_inputs[wgt_idx] = transpose_w(new_input, trans_w) + for bias_idx in range(bias_start_idx, len(new_inputs)): + new_bias = expand_bias(new_inputs[bias_idx], X) + assert new_bias is not None + new_inputs[bias_idx] = new_bias + return new_inputs, layout_or_out + + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx]) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[wgt_start_idx].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + new_size, padded_n = cls.get_padded_size( + n, block_n, k, should_block_weight=True + ) + padding = padded_n - n + + def pack_weight( + inputs: list[_T], + layout_or_out: _U, + ) -> tuple[list[_T], _U]: + new_W_list = [] + new_inputs = list(inputs) + W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] + for W in W_list: + blocked_w = cls.block_weight(W, new_size, 
padding) + new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size)) + new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list + return new_inputs, layout_or_out + + def preprocessor( + inputs: list[_T], + layout: _U, + ) -> tuple[list[_T], _U]: + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) + + def postprocessor(output: _T) -> _T: + if isinstance(output, ir.TensorBox): + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + W_nodes = new_input_nodes[ + wgt_start_idx : wgt_start_idx + gemm_grouped_num + ] + W_tensor = [] + for W_node in W_nodes: + assert W_node.get_name() in V.graph.constants + W_tensor.append(V.graph.constants[W_node.get_name()]) + new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = ( + W_tensor # type: ignore[assignment] + ) + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) + # Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num): + W_packed = new_input_nodes[idx] + assert isinstance(W_packed, torch.Tensor) + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[idx] = ( + ir.InputsKernel.unwrap_storage_for_input(W_packed_constant) + ) + return output + + template = DataProcessorTemplateWrapper( + CppGroupedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + act_mapping=act_mapping, + gemm_grouped_num=gemm_grouped_num, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override,return,no-untyped-def] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert self.act_mapping + act_deduplicated = get_deduplicated_act(self.act_mapping) + wgt_start_idx = len(act_deduplicated) + bias_start_idx = wgt_start_idx + self.gemm_grouped_num + X_list = list(self.act_mapping.values()) + W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num] + inp_list = [] + cur_idx = bias_start_idx + for inp_idx in range(self.gemm_grouped_num): + inp = None + if self.has_bias[inp_idx]: + inp = self.input_nodes[cur_idx] + cur_idx += 1 + inp_list.append(inp) + + Y_list = self.output_node + multi_output_buffers = None + if template_buffer_node is not None: + W_list = template_buffer_node.inputs[ + wgt_start_idx : wgt_start_idx + self.gemm_grouped_num + ] + assert isinstance(template_buffer_node.outputs, list) + Y_list = template_buffer_node.outputs + counters["inductor"]["cpp_grouped_gemm_template"] += 1 + multi_output_buffers = template_buffer_node.outputs + + template_buffer = Y_list[0] + fake_buffers: list[ir.Buffer] = [] + Y_2d_list = Y_list + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X_list[0].get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X_list[0].get_dtype(), + input2_dtype=W_list[0].get_dtype(), + output_dtype=output_dtype, + 
compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + epilogues: list[ir.IRNode] = [] + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = [] + gemm_output_buffers: list[ir.Buffer] = [] + for out_buf_idx in range(self.gemm_grouped_num): + gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str( + out_buf_idx + ) + gemm_output_buffers.append( + ir.Buffer(name=gemm_output_name, layout=template_buffer.layout) + ) + + assert not self.epilogue_creator, ( + "epilogue_creator is not supported yet in Grouped GEMM Template" + ) + + kernel_args: dict[str, Optional[ir.IRNode]] = {} + for x_idx in range(wgt_start_idx): + kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx] + for w_idx in range(self.gemm_grouped_num): + kernel_args["W" + str(w_idx)] = W_list[w_idx] + for inp_idx in range(self.gemm_grouped_num): + kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx] + + def _bias_add_epilogue(buf: ir.IRNode, inp: ir.IRNode) -> ir.Pointwise: + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + for gemm_idx, inp in enumerate(inp_list): + if inp: + buffer_name = Y_list[gemm_idx].get_name() + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=_bias_add_epilogue(gemm_output_buffers[gemm_idx], inp), + ) + ) + reindexers.append(None) + + if epilogue_nodes: + epilogues.extend(epilogue_nodes) + for epilogue_node in epilogue_nodes: + Y = cast(ir.Buffer, epilogue_node) + _, reindexers = gen_2d_view_of_epilogue_buf( + Y, + template_buffer, + [ + epilogue_node, + ], + reindexers, + default_reindexers=[ + None, + ], + ) + + options = dict( + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + aliases={}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + acc_buf_dtype=torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + epilogue_nodes=epilogues, + GemmOuts=gemm_output_buffers, + reindexers=reindexers, + kernel_args=kernel_args, + X_list=X_list, + W_list=W_list, + gemm_grouped_num=self.gemm_grouped_num, + Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)}, + Y_2d_list=Y_2d_list, + multi_output_buffers=multi_output_buffers, + ) + with contextlib.ExitStack() as stack: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers)) + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a2b8d17df306642af5bff184139390720af09b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 
+1,2011 @@ +# mypy: allow-untyped-defs +import dataclasses +import operator +import sys +from enum import Enum +from typing import Callable, Optional + +import torch + +from .. import cpp_builder, ir +from ..cpu_vec_isa import ( + pick_vec_isa, + VecAMX, + VecAVX2, + VecAVX512, + VecISA, + VecNEON, + VecSVE256, +) +from ..utils import IndentedBuffer, parallel_num_threads +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class LayoutType(Enum): + NORMAL = 0 + VNNI2 = 1 + VNNI4 = 2 + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_restrict_keyword() -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/cpp/extension-restrict?view=msvc-170 + return "__restrict" + else: + return "__restrict__" + + +class CppMicroGemm: + """ + A class that codegens a kernel that computes small-sized matrix multiplication. + + A micro GEMM kernel is responsible for register blocking, instruction selection, + and other CPU architecture-specific optimizations. + + The subclasses need to override `codegen_define` to define the kernel function + that is called by the code generated by `codegen_call`. + """ + + # TODO(jgong5): support constant shapes and lds as template args. + DECLARE_KERNEL = r""" +template +inline void {{kernel_name}}( +{%- if kernel_extra_args_declare %} + {{kernel_extra_args_declare}} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) +""" + + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ) -> None: + self.name = name + self.input_dtype = input_dtype + assert input2_dtype is not None + self.input2_dtype = input2_dtype + self.output_dtype = output_dtype + self.compute_dtype = compute_dtype + self.register_blocking = register_blocking + self.alpha = alpha + self.pack_vnni_B_locally = False + + def get_common_options(self): + if self.input_dtype in [torch.uint8, torch.int8]: + assert self.compute_dtype == torch.int32 + assert self.output_dtype == torch.int32 + assert self.input2_dtype == torch.int8 + return { + "torch": torch, + "kernel_name": self.name, + "input_dtype": self.input_dtype, + "input2_dtype": self.input2_dtype, + "output_dtype": self.output_dtype, + "compute_dtype": self.compute_dtype, + "input_t": DTYPE_TO_CPP[self.input_dtype], + "input2_t": DTYPE_TO_CPP[self.input2_dtype], + "output_t": DTYPE_TO_CPP[self.output_dtype], + "compute_t": DTYPE_TO_CPP[self.compute_dtype], + "alpha": self.alpha, + "kernel_extra_args_declare": self.get_kernel_extra_args_declare(), + "int8_gemm": self.input_dtype in [torch.uint8, torch.int8], + "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2, + "restrict_keyword": get_restrict_keyword(), + "pack_vnni_B_locally": self.pack_vnni_B_locally, + "template": self, + "is_woq_int4": self.is_woq_int4(), + } + + def get_kernel_declaration(self): + options = self.get_common_options() + return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + + def get_kernel_extra_args_declare(self) -> str: + return "" + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return [] + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + raise NotImplementedError + + def codegen_call( + self, + kernel: 
CppTemplateKernel, + A: ir.Buffer, + B: ir.Buffer, + C: ir.Buffer, + accum: bool, + prefetch: bool = False, + **kwargs_for_extra_args, + ) -> str: + """ + Generate the code for calling the templated kernel that computes + `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. + """ + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + res = IndentedBuffer() + res.writeline( + f"{self.name}<{value_to_cpp(accum, 'bool')}, {value_to_cpp(prefetch, 'bool')}>(" + ) + with res.indent(): + kwargs_for_extra_args.update({"kernel": kernel}) + extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args) + for arg in extra_args: + res.writeline(arg) + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() + + def use_local_vnni_blocking(self, should_block_weight: bool): + self.pack_vnni_B_locally = should_block_weight + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "" + + def get_b_layout(self) -> LayoutType: + return LayoutType.NORMAL + + ALLOCATE_WEIGHT_BUFFER = r""" + {%- if is_msvc_compiler %} + // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here. + std::unique_ptr<{{buffer_dtype}}[]> heap_deq_b_buf_ptr(new {{buffer_dtype}}[{{buffer_size}}]); + {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get(); + {%- else %} + // It's safe to use a stack-allocated array since the blocking strategy would + // require us to allocate an array that's smaller than the size of L1D cache, + // and the default per thread max stack size on Linux is quite higher, + // so we need not worry about stack overflow. 
+ alignas(4096) {{buffer_dtype}} {{buffer_name}}[{{buffer_size}}]; + {%- endif %} +""" + + def codegen_allocate_weight_buffer( + self, buffer_name: str, buffer_dtype: str, *size_args + ) -> str: + buffer_size = " * ".join(map(str, size_args)) + return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render( + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + is_msvc_compiler=cpp_builder.is_msvc_cl(), + ) + ) + + def is_woq_int4(self): + return False + + +@dataclasses.dataclass +class CppMicroGemmConfig: + input_dtype: torch.dtype + input2_dtype: torch.dtype + output_dtype: torch.dtype + compute_dtype: torch.dtype + vec_isa_cls: type[VecISA] + register_blocking: GemmBlocking + extra_check: Optional[Callable[..., bool]] = None + + +micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert cls not in micro_gemm_configs, ( + f"Duplicate micro_gemm registration for {cls}" + ) + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + +def generate_gemm_config( + vec_isa_cls, + register_blockings, + input_dtype=torch.float, + input2_dtype=None, + output_dtype=None, + compute_dtype=None, + extra_check=None, +): + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if input2_dtype is None: + input2_dtype = input_dtype + return [ + CppMicroGemmConfig( + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + vec_isa_cls, + GemmBlocking(*blocking), + extra_check, + ) + for blocking in register_blockings + ] + + +class CppMicroGemmRef(CppMicroGemm): + """ + A reference implementation of the CppMicroGemm class with naive C++ code. + It is used for correctness debugging. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? 
C[m * ldc + n] : 0; + for (int64_t k = 0; k < K; ++k) { + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + C[m * ldc + n] = result; + } + } +} +""" + + def __init__( + self, name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + GemmBlocking(1, 1, 1), + alpha, + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +def is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k): + return ( + k % config.register_blocking.block_k == 0 + and n % config.register_blocking.block_n == 0 + and m < 16 + ) + + +# extra check for small M dimension for int8 WoQ case +def check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs): + return is_int8_woq_gemm_small_m_dim_corner_case(config, m, n, k) and not kwargs.get( + "dynamic_M", False + ) + + +# For int8 WoQ GEMM with small M, we use different blockings that shouldn't be used otherwise +def do_not_use_with_small_m_for_int8_woq(config, m, n, k, alpha, num_threads, **kwargs): + return not check_int8_woq_small_m_dim(config, m, n, k, alpha, num_threads, **kwargs) + + +@register_micro_gemm( + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX512, + [(8, 48, 1), (8, 32, 1), (16, 16, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX512, + [ + (4, 32, 64), + (8, 32, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.half, + output_dtype=torch.float, + ), + *generate_gemm_config( + VecAVX2, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=do_not_use_with_small_m_for_int8_woq, + ), + *generate_gemm_config( + VecAVX2, + [ + (2, 16, 64), + (4, 16, 64), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_int8_woq_small_m_dim, + ), + *generate_gemm_config( + VecNEON, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), + *generate_gemm_config( + VecSVE256, + [(4, 24, 1), (4, 16, 1), (8, 8, 1)], + input_dtype=torch.float, + input2_dtype=torch.float, + output_dtype=torch.float, + compute_dtype=torch.float, + ), +) +class 
CppMicroGemmFP32Vec(CppMicroGemm): + """ + This class generates the code for micro gemm using fp32 vec instructions for compute. + It supports input types of torch.float, torch.bfloat16, and torch.half with fp32 output. + The output of the microkernel is in FP32, but it would be converted to BF16/FP16 in the template, + if the desired output is BF16/FP16. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; + constexpr auto VLEN = Vectorized::size(); + {{kernel.assert_function}}({{block_n}} % VLEN == 0, "block_n dimension must be multiple of Vector size"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + int64_t block_n = std::min(N - n, {{block_n}}); + if (block_m == {{block_m}} && block_n == {{block_n}}) { +{%- if not trans_b %} + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- else %} + {{kernel_name}}_transpose_b_kernel<{{block_m}}, {{block_n}}, accum, prefetch>( +{%- endif %} + A + m * lda, +{%- if not trans_b %} + B + n, +{%- else %} + B + n * ldb, +{%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); +{%- if tail_n %} + } else if (block_n == {{block_n}}){ +{%- else %} + } else { +{%- endif %} + switch (block_m) { +{%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + break; +{%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + +{%- if tail_n %} + } else { + switch (block_m) { + {%- for b in range(block_m, 0, -1) %} + case {{b}}: + {%- if not trans_b %} + {{kernel_name}}_ntail_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- else %} + {{kernel_name}}_ntail_transpose_b_kernel<{{b}}, {{block_n}}, accum, prefetch>( + {%- endif %} + A + m * lda, + {%- if not trans_b %} + B + n, + {%- else %} + B + n * ldb, + {%- endif %} + C + m * ldc + n, + block_n, + K, + lda, + ldb, + ldc + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); + } + } +{%- else %} + } +{%- endif %} + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +{%- if not trans_b %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_kernel( + {%- else %} +inline void {{kernel_name}}_kernel( + {%- endif %} +{%- else %} + {%- if tail_n %} +inline void {{kernel_name}}_ntail_transpose_b_kernel( + {%- else %} +inline void {{kernel_name}}_transpose_b_kernel( + {%- endif %} +{%- endif %} + const {{input_t}}* {{restrict_keyword}} A, + const {{input2_t}}* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, +{%- if tail_n %} + int64_t N, +{%- endif %} + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized<{{compute_t}}>; +{%- if input2_dtype in [torch.bfloat16, torch.float16] %} + using VectorizedIn = at::vec::Vectorized<{{input_t}}>; +{%- endif %} + +{%- if not trans_b %} + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + 
at::vec::VectorizedN<{{compute_t}}, COLS> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + + {%- if tail_n %} + int64_t rCOLS = (N + VLEN - 1) / VLEN; + int ntail = N % VLEN; + {%- endif %} + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN, load_size); + } + {%- else %} + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + {%- endif %} + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int load_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + {%- endif %} + if constexpr (col == 0) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + k])); + {%- endif %} + } + + if constexpr (row == 0) { + {%- if tail_n %} + if (col < rCOLS) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN, load_size); + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN, load_size); + {%- endif %} + } else { + vb[col] = Vectorized(0.0f); + } + + {%- else %} + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + k * ldb + col * VLEN, VLEN); + vb[col] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + // Convert VLEN int8 elements to int32, and then fp32 + auto b32 = at::vec::convert_to_int32(B + k * ldb + col * VLEN); + if constexpr (prefetch) { + _mm_prefetch(B + (k + {{block_k}}) * ldb + col * VLEN, _MM_HINT_T0); + } + vb[col] = at::vec::convert(b32); + {%- else %} + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + {%- endif %} + {%- endif %} + + } + + constexpr int idx = row * COLS + col; + {%- if tail_n %} + if (col < rCOLS) { + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + } + {%- else %} + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + {%- endif %} + }; + + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + {%- if tail_n %} + int store_size = (col == rCOLS - 1 && ntail != 0) ? ntail : VLEN; + if (col < rCOLS) { + vc[i].store(C + row * ldc + col * VLEN, store_size); + } + {%- else %} + vc[i].store(C + row * ldc + col * VLEN); + {%- endif %} + }; + c10::ForcedUnroll{}(storec); + +{%- else %} + // Use 2 implementations for the transposed B: + // First implementation: + // Transpose first and then perform outer product calculation in sub-blocks, + // which introduces an additional transpose overhead of [K, N] compared to the non-transpose version. + // Second implementation: + // Directly perform inner product calculation in sub-blocks, + // which introduces an additional vector reduction of [M, N] compared to the non-tranpose version. + // Therefore, when M * N / (K * N) is large, the first implementation has better performance. 
+ {%- if tail_n %} + if (K % Vectorized::size() == 0 && N % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- else %} + if (K % Vectorized::size() == 0 && 24 * BLOCK_M > K) { + {%- endif %} + // First implementation: + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + int _K = K / VLEN; + Vectorized va; + at::vec::VectorizedN<{{compute_t}}, VLEN> vb; + at::vec::VectorizedN<{{compute_t}}, ROWS*COLS> vc; + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + auto unroll_loadB = [&](auto i, const {{input2_t}}* {{restrict_keyword}} src_ptr) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert<{{compute_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(src_ptr + i * ldb, VLEN); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(src_ptr + i * ldb, VLEN); + {%- endif %} + }; + auto compute_trans = [&, COLS](auto i, int k) { + constexpr int row = i % ROWS; + constexpr int col = i / ROWS; + constexpr int e_col = col * VLEN; + int idk = k * VLEN; + if constexpr (row == 0) { + c10::ForcedUnroll{}(unroll_loadB, B + e_col * ldb + idk); + at::vec::transpose_block(vb); + } + constexpr int idx = row * COLS + col; + {{kernel.unroll_pragma(16)}} + for (int j = 0; j < VLEN; j++) { + {%- if alpha != 1 %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j]) * {{alpha}}); + {%- else %} + va = Vectorized(static_cast<{{compute_t}}>(A[row * lda + idk + j])); + {%- endif %} + vc[idx] = at::vec::fmadd(va, vb[j], vc[idx]); + } + }; + for (int k = 0; k < _K; ++k) { + c10::ForcedUnroll{}(compute_trans, k); + } + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); + } else { + // Second implementation + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + constexpr auto VLEN = VectorizedIn::size(); + {%- else %} + constexpr auto VLEN = Vectorized::size(); + {%- endif %} + int _K = (K + VLEN - 1) / VLEN; + // sub-block size of BLOCK_N and BLOCK_M + constexpr int sM = {{sub_block_m}}; + constexpr int sN = {{sub_block_n}}; + {%- if tail_n %} + int bN = (N + sN - 1) / sN; + {%- else %} + constexpr int bN = (BLOCK_N + sN - 1) / sN; + {%- endif %} + constexpr int bM = (BLOCK_M + sM - 1) / sM; + + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + at::vec::VectorizedN<{{compute_t}}, 2> va; + at::vec::VectorizedN<{{compute_t}}, 2 * sN> vb; + {%- else %} + at::vec::Vectorized<{{compute_t}}> va; + at::vec::VectorizedN<{{compute_t}}, sN> vb; + {%- endif %} + at::vec::VectorizedN<{{compute_t}}, sN * sM> vmid; + + {%- if tail_n %} + int ntail = N % sN; + {%- else %} + constexpr int ntail = BLOCK_N % sN; + {%- endif %} + constexpr int mtail = BLOCK_M % sM; + int ktail = K % VLEN; + + auto compute_trans = [&](int m, int n, int k) { + {%- if tail_n %} + int e_n = (n == bN - 1 && ntail != 0) ? (N - n * sN) : sN; + {%- else %} + int e_n = (n == bN - 1 && ntail != 0) ? (BLOCK_N - n * sN) : sN; + {%- endif %} + int e_m = (m == bM - 1 && mtail != 0) ? (BLOCK_M - m * sM) : sM; + int e_k = (k == _K - 1 && ktail != 0) ? 
(K - k * VLEN) : VLEN; + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto b = VectorizedIn::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + std::tie(vb[2 * i], vb[2 * i + 1]) = at::vec::convert_to_float<{{input_t}}>(b); + {%- elif input2_dtype == torch.int8 %} + auto b32 = at::vec::convert_to_int32(B + (sN * n + i) * ldb + k * VLEN, e_k); + vb[i] = at::vec::convert(b32); + {%- else %} + vb[i] = Vectorized::loadu(B + (sN * n + i) * ldb + k * VLEN, e_k); + {%- endif %} + } + + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + auto a = VectorizedIn::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + std::tie(va[0], va[1]) = at::vec::convert_to_float<{{input_t}}>(a); + {%- elif input2_dtype == torch.int8 %} + auto a32 = at::vec::convert_to_int32(A + (sM * m + s) * lda + k * VLEN, e_k); + va = at::vec::convert(a32); + {%- else %} + va = Vectorized::loadu(A + (sM * m + s) * lda + k * VLEN, e_k); + {%- endif %} + + {%- if alpha != 1 %} + va = va * Vectorized({{alpha}}); + {%- endif %} + if (k == 0) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], Vectorized(0.0f)); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], Vectorized(0.0f)); + {%- endif %} + } + } else { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + {%- if input2_dtype in [torch.bfloat16, torch.float16] %} + vmid[sN * s + i] = at::vec::fmadd(va[0], vb[2 * i], vmid[sN * s + i]); + vmid[sN * s + i] = at::vec::fmadd(va[1], vb[2 * i + 1], vmid[sN * s + i]); + {%- else %} + vmid[sN * s + i] = at::vec::fmadd(va, vb[i], vmid[sN * s + i]); + {%- endif %} + } + } + } + + // store to C + if (k == _K - 1) { + {{kernel.unroll_pragma(sub_block_m)}} + for (int s = 0; s < e_m; s++) { + {{kernel.unroll_pragma(sub_block_n)}} + for (int i = 0; i < e_n; i++) { + auto v = at::vec::vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, vmid[sN * s + i]); + if constexpr (accum) { + auto c = *(C + (sM * m + s) * ldc + sN * n + i); + *(C + (sM * m + s) * ldc + sN * n + i) = c + v; + } else { + *(C + (sM * m + s) * ldc + sN * n + i) = v; + } + } + } + } + }; + + for (int n = 0; n < bN; ++n) { + for (int m = 0; m < bM; ++m) { + for (int k = 0; k < _K; ++k) { + compute_trans(m, n, k); + } + } + } + } +{%- endif %} +} +""" + + # set trans_b to generate gemm that supports transposed B matrix + # set tail_n to support the tail of N + # TODO add trans_b support for other micro gemms + # and move setting of trans_b to the init of CppMicroGemm + def __init__( + self, + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + tail_n=False, + trans_b=False, + ) -> None: + super().__init__( + name, + input_dtype, + input2_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha, + ) + self.tail_n = tail_n + # trans_b is only supported on platforms that + # support avx512 or avx2 since transpose_block is + # only implemented on these platforms + if trans_b: + vec_isa = pick_vec_isa() + assert issubclass(vec_isa.__class__, VecAVX512) or issubclass( + vec_isa.__class__, VecAVX2 + ) + self.trans_b = trans_b + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { 
+ "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "trans_b": False, + "tail_n": False, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + if self.trans_b: + # TODO supports tuning of sub_block_m/sub_block_n + # to get better performance for specific shapes + sub_block_m = min(1, self.register_blocking.block_m) + sub_block_n = min(4, self.register_blocking.block_n) + # update options to generate kernel with trans_b and sub-block size + options.update( + { + "trans_b": self.trans_b, + "sub_block_m": sub_block_m, + "sub_block_n": sub_block_n, + } + ) + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + # update options to generate the kernel for the tail of N + if self.tail_n: + options.update( + { + "tail_n": self.tail_n, + } + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + +# extra check for CppMicroGemmAMX +def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): + vnni_size = 4 if config.input_dtype in [torch.uint8, torch.int8] else 2 + return k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.int8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + input2_dtype=torch.int8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.bfloat16, + output_dtype=torch.float, + extra_check=check_amx_extra, + ), + *generate_gemm_config( + VecAMX, + [(32, 32, 64), (48, 16, 64)], + input_dtype=torch.uint8, + input2_dtype=torch.int8, + output_dtype=torch.int32, + compute_dtype=torch.int32, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmAMX(CppMicroGemm): + """ + This class generates the code for micro gemm using Advanced Matrix extension (AMX) + instructions available in 4th generation Intel Xeon for compute. + It supports input types of torch.bfloat16 with fp32 output. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K", block_n)}} +{%- endif %} +{%- if use_cached_dequantized_B %} + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. + // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + const auto buf_size = K * {{block_n}}; + auto load_dequantized_B = [&](int base_idx) { + // Load a tile of B & cache it in L1D. 
+ {{input2_t}}* base_addr = const_cast<{{input2_t}}*>(B) + base_idx; + for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) { + {%- for vec_idx in range(0, block_n, 32) %} + {%- if (block_n - vec_idx) >= 32 %} + auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized::loadu( + base_addr + idx_q + {{vec_idx}} , + static_cast(32) + ); + auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}}); + b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}}); + {%- else %} + auto b_int8_tail = at::vec::Vectorized::loadu( + base_addr + idx_q + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + auto b_bf16_tail = at::vec::convert<{{input_t}}>(b_int8_tail); + b_bf16_tail.store( + dequantized_B_buf + idx_dq + {{block_n - (block_n % 32)}}, + static_cast({{block_n % 32}}) + ); + {%- endif %} + {%- endfor %} + } + }; +{%- endif %} +// The ldb would not be block_n if N != block_n +{%- if use_cached_dequantized_B or pack_vnni_B_locally %} + const int64_t updated_ldb = {{block_n}}; +{%- else %} + const int64_t updated_ldb = ldb; +{%- endif %} + // TODO(jgong5): loop unroll for M and N + for (int64_t n = 0; n < N; n += {{block_n}}) { +{%- if pack_vnni_B_locally %} + // Pack non-constant weights into VNNI interleaved format in packed_B_buf + at::vec::pack_vnni2(B + n, packed_B_buf, ldb, K, {{block_n}}); +{%- elif use_cached_dequantized_B %} + // Dequantize K * block_n int8 B elements into BF16 + load_dequantized_B(n); +{%- endif %} + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; +{%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m * ldc + n, + K, + lda, + updated_ldb, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } +{%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, +{%- if use_cached_dequantized_B %} + dequantized_B_buf, +{%- elif pack_vnni_B_locally %} + packed_B_buf, +{%- else %} + B + n, +{%- endif %} + C + m_tail * ldc + n, + K, + lda, + updated_ldb, + ldc, + block_m + ); + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" + +template +inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + AMXState& amx_state, + const {{input_t}}* {{restrict_keyword}} A, +{%- if use_cached_dequantized_B %} + const {{input_t}}* {{restrict_keyword}} B, +{%- else %} + const {{input2_t}}* {{restrict_keyword}} B, +{%- endif %} + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + uint8_t tilecfg_rows +) { + // TODO(jgong5): add prefetch hint for A, B, C + auto loadconfig = [](const amx_tilecfg& cfg) { + _tile_loadconfig(&cfg); + }; + const auto last_k_offset = K / {{block_k}} * {{block_k}}; + const auto tail_k_size = K - last_k_offset; + if C10_LIKELY (last_k_offset > 0) { + amx_state.configure(tilecfg_rows, 64, {{num_rows}} / 16, {{num_columns}}, loadconfig); + } else { + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + } + auto load_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- 
set tile_idx = tile_row * num_columns + tile_col %} + _tile_loadd({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + auto zero_c = [&]() { +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_zero({{tile_idx}}); + {%- endfor %} +{%- endfor %} + }; + + if constexpr (accum) { + load_c(); + } else { + zero_c(); + } + + auto compute = [&](int k) { +{%- set tile_offset_a = num_rows // 16 * num_columns %} +{%- set tile_offset_b = tile_offset_a + num_rows // 16 %} +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx_a = tile_offset_a + tile_row %} + {%- set tile_idx_b = tile_offset_b + tile_col %} + {%- set tile_idx_c = tile_row * num_columns + tile_col %} + {%- if tile_col == 0 %} + _tile_stream_loadd({{tile_idx_a}}, A + {{tile_row * 16}} * lda + k, lda * sizeof({{input_t}})); + {%- endif %} + {%- if tile_row == 0 %} + _tile_loadd({{tile_idx_b}}, B + k * ldb + {{tile_col * 16 * vnni_size}}, ldb * {{vnni_size}} * sizeof({{input_t}})); + {%- endif %} + {%- if int8_gemm %} + {%- if input_dtype == torch.int8 %} + _tile_dpbssd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- else %} + _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- else %} + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); + {%- endif %} + {%- endfor %} +{%- endfor %} + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < last_k_offset; k += {{block_k}}) { + compute(k); + } + + auto store_c = [&]() { + // store to C +{%- for tile_row in range(num_rows // 16) %} + {%- for tile_col in range(num_columns) %} + {%- set tile_idx = tile_row * num_columns + tile_col %} + _tile_stored({{tile_idx}}, C + {{tile_row * 16}} * ldc + {{tile_col * 16}}, ldc * sizeof({{output_t}})); + {%- endfor %} +{%- endfor %} + }; + + // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead + if C10_UNLIKELY (tail_k_size > 0) { + if C10_LIKELY (last_k_offset > 0) { + store_c(); + amx_state.configure(tilecfg_rows, tail_k_size * sizeof({{input_t}}), {{num_rows}} / 16, {{num_columns}}, loadconfig); + load_c(); + } + compute(last_k_offset); + } + + store_c(); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + block_m, block_n, block_k = self.register_blocking + assert block_m % 16 == 0, "Only support block_m % 16 == 0 for AMX" + assert block_n % 16 == 0, "Only support block_n % 16 == 0 for AMX" + if self.input_dtype in [torch.uint8, torch.int8]: + assert block_k == 64, "Only support block_k = 64 for AMX INT8" + else: + assert block_k == 32, "Only support block_k = 32 for AMX Bfloat16/Float16" + num_columns = block_n // 16 + options = { + "declare_kernel": self.get_kernel_declaration(), + "use_cached_dequantized_B": ( + self.input_dtype == torch.bfloat16 + and self.input2_dtype in [torch.int8, torch.uint8] + ), + "kernel": kernel, + "block_m": block_m, + "block_n": block_n, + "block_k": block_k, + "num_columns": num_columns, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + for num_rows in range(block_m, 0, -16): + amx_kernel_options = {**options, "num_rows": num_rows} + result += KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + amx_kernel_options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + 
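        # `result` now holds one `..._amx_kernel_{rows}_{cols}` body per 16-row
        # step of block_m, followed by the dispatching entry function rendered
        # just above.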
return result + + def codegen_init( + self, + kernel: CppTemplateKernel, + ) -> str: + return "AMXState amx_state;" + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "amx_state.release([]() { _tile_release(); });" + + def get_kernel_extra_args_declare(self) -> str: + return "AMXState& amx_state," + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + return ["amx_state,"] + + def get_b_layout(self): + if self.input_dtype in [torch.uint8, torch.int8]: + return LayoutType.VNNI4 + else: + return LayoutType.VNNI2 + + +# extra check for CppMicroBrgemm +def check_brgemm_extra(config, m, n, k, alpha, num_threads, **kwargs): + assert config.input_dtype == torch.half and config.output_dtype == torch.float + vnni_size = 2 + # use brgemm for Half when amx_fp16 is supported + return torch.cpu._is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1 + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], + input_dtype=torch.half, + output_dtype=torch.float, + extra_check=check_brgemm_extra, + ), +) +class CppMicroBrgemm(CppMicroGemm): + """ + This class generates the code for micro gemm using oneDNN brgemm. + It supports input types of torch.half. + """ + + TEMPLATE_ENTRY = r""" +#include +{{declare_kernel}} { +{%- if pack_vnni_B_locally %} + {{template.codegen_allocate_weight_buffer("packed_B_buf", input2_t, "K * N")}} + at::vec::pack_vnni2(B, packed_B_buf, ldb, K, N); +{%- endif %} + at::native::cpublas::brgemm( + M, N, K, + {%- if pack_vnni_B_locally %} + lda, N, ldc, + {%- else %} + lda, ldb, ldc, + {%- endif %} + accum, + A, + {%- if pack_vnni_B_locally %} + packed_B_buf, + {%- else %} + B, + {%- endif %} + C); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + "restrict_keyword": get_restrict_keyword(), + **self.get_common_options(), + } + result = "" + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + def codegen_finalize( + self, + kernel: CppTemplateKernel, + ) -> str: + return "at::native::cpublas::brgemm_release();" + + def get_b_layout(self): + assert self.input_dtype == torch.half and torch.cpu._is_amx_fp16_supported() + return LayoutType.VNNI2 + + +def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs): + if alpha != 1: + return False + q_group_size = kwargs.get("q_group_size", None) + assert q_group_size is not None + if ( + q_group_size < 32 + or k % q_group_size != 0 + or config.register_blocking.block_k > q_group_size + ): + return False + return k % config.register_blocking.block_k == 0 and n % 64 == 0 + + +@register_micro_gemm( + # TODO: support float/half input + *generate_gemm_config( + VecAVX512, + [(4, 64, 32), (4, 64, 64), (4, 64, 128)], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_woq_int4_extra, + ), +) +class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): + """ + This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics. + It is based on the corresponding ATen kernel. 
+ Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + auto group_size = q_group_size; + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + if (block_m == {{block_m}}) { + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + } else { + switch (block_m) { + {%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + reinterpret_cast(B) + n * ldb, + C + m * ldc + n, + K, + lda, + /* ldb */ {{block_n}} / 2, + ldc, + group_size, + ScaleAndZeros + n * 2, + lds, + k_start + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + } + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + return (k_start + index) % group_size == 0; +} + +inline __m128i {{kernel_name}}_convert_int4_to_int8(const uint8_t* data) { + __m128i tmp = _mm_loadu_si64((const __m128i*)data); + __m128i bytes = _mm_cvtepu8_epi16(tmp); + const __m128i lowMask = _mm_set1_epi8(0xF); + __m128i high = _mm_andnot_si128(lowMask, bytes); + __m128i low = _mm_and_si128(lowMask, bytes); + high = _mm_slli_epi16(high, 4); + bytes = _mm_or_si128(low, high); + return bytes; +} + +template +inline void {{kernel_name}}_kernel( + const {{input_t}}* {{restrict_keyword}} A, + const uint8_t* {{restrict_keyword}} B, + {{output_t}}* {{restrict_keyword}} C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t q_group_size, + const at::BFloat16* {{restrict_keyword}} ScaleAndZeros, + int64_t lds, // leading dimension of ScaleAndZeros + int64_t k_start) { + constexpr int BLOCK_K = {{block_k}}; + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + + // number of blocks on K + const int KB = K / BLOCK_K; + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. 
+ // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + float aa = static_cast(A[row * lda + k]); + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + va = _mm512_set1_ps(aa); + } + + if constexpr (row == 0) { + if constexpr (COLS == 4) { + // when BLOCK_N = 64, handle each row at a time + // to reduce de-quantize overhead. 
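          // Layout note: the 32 packed bytes loaded below carry 64 int4 values;
          // their low nibbles expand to columns 0-31 of the 64-wide block
          // (vb[0]/vb[1]) and their high nibbles, extracted with a 4-bit right
          // shift, to columns 32-63 (vb[2]/vb[3]).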
+ if constexpr (col == 0) { + __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb)); + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0); + } + + __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); + vb[0] = _mm512_permutexvar_ps(b32, lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]); + + b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1)); + vb[1] = _mm512_permutexvar_ps(b32, lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut); + vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]); + } + } else { + __m128i b8 = {{kernel_name}}_convert_int4_to_int8(B + k * ldb + col * 8); + __m512i b32 = _mm512_cvtepu8_epi32(b8); + vb[col] = _mm512_permutexvar_ps(b32, lut); + vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]); + } + } + + constexpr int idx = row * COLS + col; + vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0, kb = 0; k < K; ++k) { + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size)) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + c10::ForcedUnroll{}(compute, k); + } + + //store to C + auto storec = [&, COLS](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "const int64_t q_group_size,\n" + " const at::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +@register_micro_gemm( + *generate_gemm_config( + VecAMX, + [ # (block_m, block_n, block_k) + (16, 32, 32), + (32, 32, 32), + ], + input_dtype=torch.bfloat16, + input2_dtype=torch.uint8, + output_dtype=torch.float, + compute_dtype=torch.float, + extra_check=check_amx_extra, + ), +) +class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): + """ + This class generates the code for WoQ int4 micro gemm using AMX intrinsics, + which are available on 4th and newer generations of Intel Xeon. + Shape of packed weight = [N // 32, K, 16], viewed as [N, K // 2] + Shape of packed ScalesAndZeros = [K // group_size, N, 2] + Reuse TEMPLATE_KERNEL of CppMicroGemmAMX. + """ + + TEMPLATE_ENTRY = r""" +inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { + return (k_start + index) % group_size == 0; +} + +{{declare_kernel}} { + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}({{block_n}} == 32, "block_n must be 32 for WOQ int4"); + + // Create a stack-allocated buffer for tiles of B. + // Except maybe for the tail-case, an AMX tile of B has 16x32 BF16 elements. 
+ // we cache K * {{block_n}} elements of dequantized B + {{template.codegen_allocate_weight_buffer("dequantized_B_buf", input_t, "K", block_n)}} + + constexpr int BLOCK_K = {{block_k}}; + constexpr int64_t BLOCK_N = {{block_n}}; + constexpr int COLS = BLOCK_N / 16; + const int PREFETCH_SIZE_K = 16 * 4; + const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K; + const int KB = K / BLOCK_K; + + __m512i b32[COLS * 2]; + __m512 vb[COLS * 2]; + __m512 scale[COLS]; + __m512 zero[COLS]; + + // Lookup table to de-quantize int4 values to bf16. + // Values are dequantized as truly int4 [-8, 7] range; + // + // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + static const __m512 lut = _mm512_set_ps( + 7.0f, 6.0f, 5.0f, 4.0f, + 3.0f, 2.0f, 1.0f, 0.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + -5.0f, -6.0f, -7.0f, -8.0f); + + // index for transpose + static const __m512i idx1 = _mm512_set_epi32( + 30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + static const __m512i idx2 = _mm512_set_epi32( + 31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + // Indices for VNNI layout conversion + __m512i idx_low = _mm512_set_epi32( + 0x17, + 0x07, + 0x16, + 0x06, + 0x15, + 0x05, + 0x14, + 0x04, + 0x13, + 0x03, + 0x12, + 0x02, + 0x11, + 0x01, + 0x10, + 0x00); + __m512i idx_high = _mm512_set_epi32( + 0x1f, + 0x0f, + 0x1e, + 0x0e, + 0x1d, + 0x0d, + 0x1c, + 0x0c, + 0x1b, + 0x0b, + 0x1a, + 0x0a, + 0x19, + 0x09, + 0x18, + 0x08); + + // load scale and zero point + auto load_scale_and_zeros = [&](int i, int _kb) { + // load 2x bfloat16 vector + __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } + + // convert to 2x f32 vector + __m512 a, b; + at::vec::cvtbf16_fp32(t, a, b); + + // transpose scale_and_zero from {16, 2} to {2, 16} + // inputs: + // a: {s0, z0, s1, z1, ..., s7, z7} + // b: {s8, z8, s9, z9, ..., s15, z15} + // output: + // scale: {s0, s1, s2, ..., s15} + // zero: {z0, z1, z2, ..., z15} + scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + }; + + // Dequantize a B block of 2 * block_n into bf16 + // So, it handles k and k+1 at the same time + auto dequantize_B = [&](int n) { + constexpr int64_t ldb_int4 = BLOCK_N / 2; // 16 + for (int k = 0, kb = 0; k < K; k += 2) { + // Since block_k must be 32 for AMX microkernels, k_start may not be + // a multiple of q_group_size. 
In that case, we need to load scales + // and zero points immediately when k == 0 here + if ({{kernel_name}}_is_block_start(k, k_start, q_group_size) || k == 0) { + c10::ForcedUnroll{}(load_scale_and_zeros, kb++); + } + + // load 256 bits = 64 elements in int4 + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); + } + + __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4)); + b32[0] = _mm512_cvtepu8_epi32(b4); + b32[1] = _mm512_srli_epi32(b32[0], 4); + vb[0] = _mm512_permutexvar_ps(b32[0] , lut); + vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]); + vb[1] = _mm512_permutexvar_ps(b32[1], lut); + vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); + + b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); + b32[0 + COLS] = _mm512_cvtepu8_epi32(b4); + b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4); + vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut); + vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]); + vb[1 + COLS] = _mm512_permutexvar_ps(b32[1 + COLS], lut); + vb[1 + COLS] = _mm512_fmadd_ps(vb[1 + COLS], scale[1], zero[1]); + + for (int i = 0; i < COLS; i++) { + // convert to VNNI + auto low = _mm512_permutex2var_ps(vb[i], idx_low, vb[i + COLS]); + auto high = _mm512_permutex2var_ps(vb[i], idx_high, vb[i + COLS]); + // convert lower 16 float32 values to bfloat16 + auto v0_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(low)); + // convert higher 16 float32 values to bfloat16 + auto v1_bf16 = reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(high)); + // combine the lower 16 and higher 16 bfloat16 values + auto v = _mm512_castsi256_si512(v0_bf16); + v = _mm512_inserti64x4(v, v1_bf16, 1); + // store the VNNI format bfloat16 values + {{input_t}}* addr = dequantized_B_buf + k * 32 + (i % 2) * 32; + _mm512_storeu_si512(addr, v); + } + } + }; + + for (int64_t n = 0; n < N; n += {{block_n}}) { + // Dequantize K * block_n int8 B elements into BF16 + dequantize_B(n); + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + int64_t m_tail = m; + {%- for num_rows in range(block_m, 0, -16) %} + {%- if num_rows != block_m %} + else + {%- endif %} + if (block_m >= {{num_rows}}) { + {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}( + amx_state, + A + m * lda, + dequantized_B_buf + n * K, + C + m * ldc + n, + K, + lda, + {{block_n}}, + ldc, + 16 + ); + block_m -= {{num_rows}}; + m_tail += {{num_rows}}; + } + {%- endfor %} + if (block_m > 0) { + {{kernel_name}}_amx_kernel_16_{{num_columns}}( + amx_state, + A + m_tail * lda, + dequantized_B_buf + n * K, + C + m_tail * ldc + n, + K, + lda, + {{block_n}}, + ldc, + block_m + ); + } + } // for m + } // for n +} +""" + + def get_kernel_extra_args_declare(self) -> str: + return ( + "AMXState& amx_state,\n" + " const int64_t q_group_size,\n" + " const c10::BFloat16* __restrict__ ScaleAndZeros,\n" + " const int64_t lds,\n" + " int64_t k_start," + ) + + def get_kernel_extra_args(self, **kwargs) -> list[str]: + assert "kernel" in kwargs + assert "qscale_and_zeros" in kwargs + kernel = kwargs["kernel"] + qscale_and_zeros = kwargs["qscale_and_zeros"] + return [ + "amx_state,", + "group_size,", + f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),", + "N * 2,", # lds + "k_start,", + ] + + def is_woq_int4(self): + return True + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + input2_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, + q_group_size=None, +) -> 
Optional[CppMicroGemm]: + """ + Based on the provided info, try to find the config of the micro-kernel that would + deliver the best performance in terms of lower latency for this case. + """ + + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( + name, + config.input_dtype, + config.input2_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls): + # For WoQ GEMM, AMX micro-kernel may not perform well if m is small. + # Exception: for dynamic shapes, we consider using the AMX micro-kernel. + if ( + dynamic_M + or input_dtype != torch.bfloat16 + or input2_dtype not in [torch.int8, torch.uint8] + ): + return False + # For WOQ INT8, use AMX for m >= block_m + # For WOQ INT4, use AMX for m >= 5 + block_m, *_ = config.register_blocking + is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx + m_threshold = 5 if is_woq_int4 else block_m + return m < m_threshold + + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k + from ..utils import has_free_symbols + + dynamic_M = has_free_symbols((m,)) + m = V.graph.sizevars.size_hint(m, fallback=1) if dynamic_M else m + assert isinstance(m, int) or m.is_number, m + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = output_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() + matched_configs = [] + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not issubclass(vec_isa.__class__, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.compute_dtype == compute_dtype + and config.input2_dtype == input2_dtype + and config.output_dtype == output_dtype + # The output_dtype here is the output dtype of the micro-kernel. + # In some cases, the actual output dtype of the op for which the micro-kernel + # is being created would be same as that of the activation, but the micro-kernels + # compute output in Float/int32, which is converted in the GEMM template. This is + # subject to change in the future. + ): + if config.extra_check is not None and not config.extra_check( + config, + m, + n, + k, + alpha, + num_threads, + dynamic_M=dynamic_M, + q_group_size=q_group_size, + ): + continue + block_m, block_n, block_k = config.register_blocking + if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq( + config, dynamic_M, cls + ): + continue + # Criteria on the ranking of configurations + # 1. ISA: AMX > VEC + # 2. Dividable by block sizes (block_m, block_n, block_k) + # 3. Number of mxn blocks is large enough to occupy all the threads + # 4. 
Register blocks are larger + isa_score = 0 + if config.vec_isa_cls == VecAMX: + isa_score += 1 + dividable_score = 0 + if m % block_m == 0: + dividable_score += 1 + if n % block_n == 0: + dividable_score += 1 + if k % block_k == 0: + dividable_score += 1 + occupancy_score = 0 + n_blocks = (n + block_n - 1) // block_n + total_mxn_blocks = n_blocks * ((m + block_m - 1) // block_m) + if n_blocks >= num_threads: + occupancy_score += 1 + if total_mxn_blocks >= num_threads: + occupancy_score += 1 + register_bytes = ( + block_m * block_n * config.compute_dtype.itemsize + + (block_m * block_k + block_k * block_n) + * config.input_dtype.itemsize + ) + matched_configs.append( + ( + (isa_score, dividable_score, occupancy_score, register_bytes), + cls, + config, + ) + ) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None + # TODO(jgong5): allow autotuning on choices of configs + return create_from_config(*max(matched_configs, key=operator.itemgetter(0))[1:]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/comm_analysis.py b/phivenv/Lib/site-packages/torch/_inductor/comm_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..a307209ce9307a8d49c141499d01c41e9ef88b08 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/comm_analysis.py @@ -0,0 +1,264 @@ +import functools +import math +from enum import IntEnum + +import sympy + +import torch + +from . import ir +from .utils import get_dtype_size, sympy_product +from .virtualized import V + + +class NCCL_COLL(IntEnum): + ALL_REDUCE = 0 + ALL_GATHER = 1 + REDUCE_SCATTER = 2 + + +class NVIDIA_GPU_TYPE(IntEnum): + VOLTA = 0 + AMPERE = 1 + HOPPER = 2 + + +@functools.lru_cache +def get_gpu_type() -> NVIDIA_GPU_TYPE: + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" + if "V100" in gpu_info: + return NVIDIA_GPU_TYPE.VOLTA + elif "A100" in gpu_info: + return NVIDIA_GPU_TYPE.AMPERE + elif "H100" in gpu_info: + return NVIDIA_GPU_TYPE.HOPPER + else: + # for other gpu types, assume Ampere + return NVIDIA_GPU_TYPE.AMPERE + + +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if not isinstance(node, ir._CollectiveKernel): + raise ValueError(f"node is not a collective kernel: {node}") + + kernel_name = node.python_kernel_name + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + else: + raise ValueError(f"Unsupported collective kernel: {kernel_name}") + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + numel = sympy_product(inp.layout.size) + if isinstance(numel, sympy.Integer): + # For ease of testing + numel = int(numel) + else: + numel = V.graph.sizevars.size_hint(numel, fallback=0) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if type(node) == ir._CollectiveKernel: + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + else: + raise TypeError(f"Unsupported collective type: {node}") + + +#################################################################################################################### +# The following code and constants are 
adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +#################################################################################################################### + + +class NCCL_HW(IntEnum): + NVLINK = 0 + PCI = 1 + NET = 2 + + +class NCCL_ALGO(IntEnum): + TREE = 0 + RING = 1 + + +class NCCL_PROTO(IntEnum): + # The ordering and enum values here matches original in + # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28 + # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990 + LL = 0 # Low-latency + # LL128 = 1 # Low-latency 128-byte + # SIMPLE = 2 + + +# Latencies in us +# len(NCCL_ALGO) x len(NCCL_PROTO) +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + num_gpus_per_node = 8 + group_size = get_collective_group_size(node) + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + coll = get_collective_type(node) + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. 
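    # Illustrative lookup (numbers straight from the tables above): a single
    # Hopper node resolves to llMaxBws[NVIDIA_GPU_TYPE.HOPPER][0] = 87.7 GB/s,
    # while a four-node job uses llMaxBws[0][2] = 20.4 GB/s together with the
    # inter-node bandwidth selected below.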
+ # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. + bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. + netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + return transport_ns + latency_ns + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ diff --git a/phivenv/Lib/site-packages/torch/_inductor/comm_lowering.py b/phivenv/Lib/site-packages/torch/_inductor/comm_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..c2fd0b568914bafc75e0dfaf8f6f6e77686a32bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/comm_lowering.py @@ -0,0 +1,360 @@ +# mypy: allow-untyped-defs +import logging +from typing import cast + +import torch +import torch.utils._pytree as pytree +from torch._inductor.utils import is_symbolic +from torch.utils._ordered_set import OrderedSet + +from . import config, ir +from .virtualized import V + + +log = logging.getLogger(__name__) + + +# NOTE [lowering-time collective optimization] +# +# In collective communication libraries such as NCCL, every rank maintains +# communication buffers that are remotely accessible by some peers. Depending +# on the underlying transport, remote accessibility may be established via +# mechanisms such as ib_reg_mr, CUDA P2P, or CUDA multicast. Typically, these +# buffers are private to the communication library by default, and +# communication ops copy user data in and out of these buffers. +# +# To prevent these copies, an optimization commonly known as "user buffer +# registration" can be employed. This allows direct establishment of remote +# accessibility on user buffers, eliminating the need for copying. 
However, +# this optimization introduces stringent usage requirements, which are +# typically hard to satisfy without being intrusive to the user code: +# +# - Establishing remote accessibility is expensive and often done ahead of +# time. In such implementations, all ranks must agree on the set of allocations +# used for every collective op. Failing to meet this requirement can +# lead to runtime errors or even silent correctness issues. +# - Even if the collective communication library supports gracefully falling +# back to "unregistered" implementations, the fallback mechanism would nullify +# the optimization. +# - Some communication mechanisms impose stricter requirements than others. For +# example, CUDA's multicast + multi-mem instructions require all ranks to agree +# not only on the allocations used for every collective but also on the offsets +# within these allocations. +# +# To support all different mechanisms with optimal results, we aim to satisfy +# the strictest requirement for this family of optimizations - we ensures that +# every collective op invocation is guaranteed to operate on the same +# allocation, at the same offset, in every iteration. +# +# For eligible collective ops, we identify communication buffers at lowering +# time and optionally choose to lower the op to a different kernel +# (ommunication libraries like NCCL handle both registered and non-registered +# buffers transparently within the same op, though some may require different +# ops for different cases). Later, the codegen will perform "persistent +# allocation" to satisfy the aforementioned constraints, and optionally, +# perform buffer planning to optimize overall memory usage. +def can_realize_as_comm_buffer( + x: ir.TensorBox, comm_buffer_type: ir.CommBufferType +) -> bool: + """ + Check if an input can be realized as a comm buffer of the specified + `comm_buffer_type`. + """ + data = _get_data(x) + + if isinstance(data, ir.Loops): + return True + + layout = data.get_output_spec() + if isinstance(layout, ir.CommBufferLayout): + return True + + if isinstance(layout, ir.FlexibleLayout) and not is_symbolic(data.get_numel()): + return True + + return False + + +def realize_as_comm_buffer( + x: ir.TensorBox, comm_buffer_type: ir.CommBufferType, group_name: str +) -> None: + """ + Realize an input as a comm buffer of the specified `comm_buffer_type`. + + Specifically, this realizes the underlying buffer if it's still unrealized + and changes the layout of the buffer to `ir.CommBufferLayout`. + """ + x.realize() + buffer = _get_data(x) + assert isinstance(buffer, ir.Buffer) + + layout = buffer.get_output_spec() + if isinstance(layout, ir.CommBufferLayout): + return + + if not isinstance(layout, ir.FlexibleLayout): + raise AssertionError( + "A buffer can only be realized as a comm buffer if it " + f"has `FlexibleLayout` (got {layout})." + ) + + if is_symbolic(buffer.get_numel()): + raise AssertionError( + "A buffer with symbolic shape cannot be converted to " + f"a comm buffer (got {layout})." 
+ ) + + buffer.layout = ir.CommBufferLayout( + layout=layout, + comm_buffer_type=comm_buffer_type, + group_name=group_name, + ) + + +def _get_data(x: ir.TensorBox) -> ir.IRNode: + if isinstance(x.data, ir.BaseView): + # TensorBox -> *View -> StorageBox -> IRNode + return x.data.unwrap_view().data + elif isinstance(x.data, ir.StorageBox): + # TensorBox -> StorageBox -> IRNode + return cast(ir.Buffer, x.data.data) + else: + raise AssertionError( + "Expect the data attr of a `TensorBox` to be either " + f"an `ir.BaseView` or `ir.StorageBox` (got {x.data})." + ) + + +_bufs_to_skip_wait = OrderedSet[tuple[int, str]]() + + +def mark_as_skip_wait(x: ir.IRNode) -> None: + """ + If a non-blocking collective is lowered as a blocking collective, the wait + node in the original graph becomes useless and we can skip the lowering it. + """ + _bufs_to_skip_wait.add((id(V.graph), x.get_name())) + + +def should_skip_wait(x: ir.IRNode) -> bool: + return (id(V.graph), x.get_name()) in _bufs_to_skip_wait + + +def _should_lower_as_one_shot_all_reduce( + inp: ir.TensorBox, reduce_op: str, group_name: str +): + from torch.distributed._symmetric_memory import is_symm_mem_enabled_for_group + + inp_size = inp.get_numel() * inp.get_dtype().itemsize + return ( + config._collective.auto_select + and is_symm_mem_enabled_for_group(group_name) + and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM) + and reduce_op in ("sum",) + and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes + ) + + +def _one_shot_all_reduce(inp: ir.TensorBox, reduce_op, group_name): + realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM, group_name) + return pytree.tree_map( + ir.TensorBox.create, + ir.FallbackKernel.create( + torch.ops.symm_mem.one_shot_all_reduce.default, + inp, + reduce_op, + group_name, + ), + ) + + +def register_comm_lowerings(): + try: + torch.ops._c10d_functional.all_reduce + except AttributeError: + log.info( + "Inductor support for distributed collectives depends on building " + "torch.distributed" + ) + return + + from .lowering import ( + add_layout_constraint, + clone, + constrain_to_fx_strides, + copy_, + register_lowering, + ) + + def register_comm_lowering(fn): + add_layout_constraint(fn, constrain_to_fx_strides) + return register_lowering(fn) + + c10d = torch.ops._c10d_functional + + @register_comm_lowering(c10d.all_reduce) # type: ignore[misc] + def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.TensorBox: + if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name): + return _one_shot_all_reduce(inp, reduce_op, group_name) + + # Lower as c10d.all_reduce_ + inp = clone(inp) + if config.reorder_for_compute_comm_overlap: + # The horizontal fusion of this clone often severely delays the + # scheduling of the all_reduce_ node. Horizontally fusing this + # clone can almost never out-perform scheduling the all_reduce_ + # earlier. Also in most cases, this clone is eliminated via + # in-place reuse. Therefore, we tell the scheduler to not fuse it. 
+ inp.realize() + V.graph.no_fuse_buffer_names.add(inp.get_name()) + inp = ir.ExternKernel.require_contiguous(inp) + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_comm_lowering(c10d.all_reduce_) # type: ignore[misc] + def _all_reduce_( + inp: ir.TensorBox, reduce_op: str, group_name: str + ) -> ir.TensorBox: + if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name): + ret = copy_( + inp, + _one_shot_all_reduce(inp, reduce_op, group_name), + ) + mark_as_skip_wait(ret) + return inp + + # Lower as c10d.all_reduce_ + inp = ir.ExternKernel.require_contiguous(inp) + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_.default, inp, reduce_op, group_name + ) + return inp + + @register_comm_lowering(c10d.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_comm_lowering(c10d.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_comm_lowering(c10d.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + ) + + @register_comm_lowering(c10d.all_gather_into_tensor_coalesced) + def _all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + c10d.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_comm_lowering(c10d.all_gather_into_tensor_out) + def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + c10d.all_gather_into_tensor_out.default, + inp, + group_size, + group_name, + out=out, + ) + return out + + @register_comm_lowering(c10d.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + ) + + @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + c10d.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_comm_lowering(c10d.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + ) + + @register_comm_lowering(c10d.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + c10d.broadcast_.default, inp, src, group_name + ) + return inp + + @register_comm_lowering(c10d.broadcast_) + def _broadcast_(inp, src, group_name): + ir._CollectiveKernel.create_inplace( + c10d.broadcast_.default, inp, src, group_name + ) + return inp + + 
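    # How these lowerings are exercised (a sketch of assumed usage, not part of
    # the lowering code itself): traced graphs pair each functional collective
    # with an explicit wait, e.g.
    #
    #   y = torch.ops._c10d_functional.all_reduce(x, "sum", group_name)
    #   y = torch.ops._c10d_functional.wait_tensor(y)
    #
    # The collective is handled by the matching lowering above, and
    # `_wait_tensor` below emits the wait kernel unless the buffer was marked
    # via `mark_as_skip_wait()` because a blocking one-shot kernel already
    # produced the result.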
@register_comm_lowering(torch.ops._dtensor.shard_dim_alltoall) + def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, + ) + ) + + @register_comm_lowering(c10d.wait_tensor) + def _wait_tensor(inp): + if should_skip_wait(inp): + return inp + + ir._WaitKernel.create_wait(c10d.wait_tensor.default, inp) + return inp diff --git a/phivenv/Lib/site-packages/torch/_inductor/comms.py b/phivenv/Lib/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e139b2ae5defc70906804ab153535bddf7fec8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/comms.py @@ -0,0 +1,1063 @@ +# mypy: allow-untyped-defs +# pyre-strict +from __future__ import annotations + +import heapq +import importlib +import logging +import operator +import sys +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, TYPE_CHECKING + +import torch +from torch._logging import trace_structured +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from . import config, ir +from .dependencies import WeakDep +from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf +from .utils import ( + contains_collective, + contains_wait, + find_recursive_deps_of_node, + find_recursive_users_of_node, + is_collective, + is_fallback_op, + is_wait, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") + +if TYPE_CHECKING: + from torch._inductor.scheduler import BaseSchedulerNode + + +def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Greedily schedules waits as late as possible. + """ + return _schedule_for_comm( + snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False + ) + + +def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Greedily schedules comms as early as possible. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False + ) + + +def reorder_compute_for_overlap( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + """ + This achieves the following overall scheduling procedure: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True + ) + + +def reorder_communication_preserving_peak_memory( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + """ + Reorders communication ops relative to computation ops to improve communication-compute overlapping and hide comm + latency. Stops moving a particular op if it reaches a point that would have increased the peak memory footprint. 
+ + Currently, follows these heuristics (subject to change or tune): + - never reorders collectives relative to one another, for SPMD safety + - has an option for per-collective prefetch limit, but does not enable it by default + - limits the total number of reorder steps to some factor of the graph size to prevent worst-case quadratic + performance + + Prerequisite: sink_comms_and_waits - ensure comm and wait nodes are scheduled as late as possible, respecting data + dependencies. That allows reorder_communication_preserving_peak_memory to take a best case peak-memory snapshot, + and then monotonically improve latency by moving collectives backward in time. + + Peak memory impact is computed in an iterative fashion. First, memory use at each timestep is computed, and global + peak memory is computed as a max over timesteps. Then, when swapping any two adjacent nodes, only the curr-memory + for the earlier of the nodes after the swap is affected. This enables checking step by step whether a swap is + peak-memory-safe, and bailing out if not. Example: + + 0 n0 C0 + 1 n1 C0 + Allocs(n1) - Frees(n1) + 2 n2 C0 + Allocs(n1) - Frees(n1) + Allocs(n2) - Frees(n2) + + 0 n0 C0 + 1 n2 C0 + Allocs(n2) - Frees(n2) <-- After moving n2 to Time 1, only time1 memory changes + 2 n1 C0 + Allocs(n2) - Frees(n2) + Allocs(n1) - Frees(n1) + + """ + reordered_snodes, node_stats = ( + _reorder_communication_preserving_peak_memory_internal(snodes) + ) + improvement = {snode: node_stats[snode].improvement for snode in node_stats} + total_improvement = sum([improvement[snode] for snode in improvement]) + total_moves = sum([node_stats[snode].moves for snode in node_stats]) + + reorder_log_str = ( + f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns" + f" after {total_moves} reorders.\n" + ) + headers = [ + "Collective node", + "initial exposed", + "final exposed", + "improvement", + "limiting factor", + "moves", + ] + rows = [ + [ + node_summary(snode), + node_reorder_info.initial_exposed, + node_reorder_info.final_exposed, + node_reorder_info.improvement, + node_reorder_info.limiting_factor, + node_reorder_info.moves, + ] + for snode, node_reorder_info in node_stats.items() + ] + if importlib.util.find_spec("tabulate"): + from tabulate import tabulate + + reorder_log_str += tabulate( + rows, + headers=headers, + ) + else: + reorder_log_str += ( + "Please `pip install tabulate` to nicely render overlap stats.\n" + ) + reorder_log_str += str(headers) + "\n" + reorder_log_str += "\n".join(map(str, rows)) + overlap_log.info(reorder_log_str) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "reorder_communication_preserving_peak_memory", + "encoding": "string", + }, + payload_fn=lambda: reorder_log_str, + ) + + return reordered_snodes + + +@dataclass +class ReorderInfo: + """ + Debug info describing how an individual snode was reordered + """ + + initial_exposed: float = -1 + final_exposed: float = -1 + limiting_factor: str = "None" + moves: int = 0 + + @property + def improvement(self): + return self.initial_exposed - self.final_exposed + + +def _reorder_communication_preserving_peak_memory_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + """ + Internal testing helper that also returns debug info. 
+ Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + # heuristic to avoid degenerating to quadratic time + MOVE_LIMIT = len(snodes) * 100 + total_moves = 0 + # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it + PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) + if config.reorder_prefetch_limit is not None: + PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( + snodes, graph_inputs + ) + peak_memory, curr_memory = estimate_peak_memory( + snodes, name_to_freeable_input_buf, graph_outputs + ) + runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} + + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + def exposed_communication_time(collective_snode, remaining_snodes): + # assumes a linear schedule and computes the overlap of the collective with the remaining nodes + comm_time = estimate_op_runtime(collective_snode) + compute_time = 0.0 + for snode in remaining_snodes: + if contains_collective(snode): + continue + if contains_wait(snode): + # TODO - if the wait is for a collective that started before this collective or on another stream, + # we can ignore it. Otherwise, it's the end of the road for overlap opportunities + break + + compute_time += runtimes[snode] + return max(0, comm_time - compute_time) + + for i, snode in enumerate(snodes): + if contains_collective(snode): + reorder_info = stats[snode] = ReorderInfo() + reorder_info.initial_exposed = reorder_info.final_exposed = ( + exposed_communication_time(snode, snodes[i + 1 :]) + ) + if total_moves >= MOVE_LIMIT: + reorder_info.limiting_factor = "move limit" + continue + for j in range(i - 1, -1, -1): + prev_snode = snodes[j] + if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): + reorder_info.limiting_factor = "prefetch limit" + break + if contains_collective(prev_snode): + reorder_info.limiting_factor = "collective ordering" + break + dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) + if any( + o.get_name() in dep_names for o in prev_snode.get_outputs() + ) and not contains_wait(prev_snode): + reorder_info.limiting_factor = "data dependency" + break + if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: + reorder_info.limiting_factor = "peak memory" + break + if reorder_info.final_exposed > runtimes[snode]: + reorder_info.limiting_factor = "sufficient overlapping" + break + reorder_info.moves += 1 + total_moves += 1 + tmp = snodes[j] + snodes[j] = snodes[j + 1] + snodes[j + 1] = tmp + # swapping nodes j and j+1 affects curr memory at j only + j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] + j_alloc = curr_memory[j] - curr_memory[j - 1] + curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc + reorder_info.final_exposed = exposed_communication_time( + snode, snodes[j + 1 :] + ) + + return snodes, stats + + +def _schedule_for_comm( + snodes: list[BaseSchedulerNode], + raise_comms: bool, + sink_waits: bool, + reorder_for_overlap: bool, +) -> list[BaseSchedulerNode]: + """ + Schedule `snodes` for various comm optimization objectives. + + Args: + snodes: the nodes to be scheduled. 
+ raise_comms: whether to greedily schedule collectives as early as possible + sink_wait: whether to greedily schedule waits as late as possible + reorder_compute_for_overlap: whether to reorder compute nodes to + optimize for compute/communication overlapping. + + Returns: + The new schedule order. + + Some notes on the synergy between different options: + - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`. + - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized. + """ + # We assign each node a tuple of scores (score_0, score_1, score_2), + # decreasing in importance, with a lower value indicating a higher ranking: + # + # - score_0: the lowest comm_idx among the comm nodes that the node blocks. + # If a node doesn't block any comm nodes, its score_0 is set to + # sys.maxsize. This score ensures that comm nodes get scheduled as early as + # possible. + # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures + # that wait nodes are deferred as late as possible. + # - score_2: the index of the node in the original topological order. This + # score provides stability in case of ties. + # + # When only raise_comms is True, only score_0 and score_2 are considered. + # When only sink_waits is True, only score_1 and score_2 are considered. + # When neither is True, the original order is yielded. + buf_name_to_snode = {} + name_to_fused_node = {} + scores_0, scores_1, scores_2 = {}, {}, {} + for idx, snode in enumerate(snodes): + for buf_name in snode.get_buffer_names(): + buf_name_to_snode[buf_name] = snode + + for op_name in snode.get_operation_names(): + name_to_fused_node[op_name] = snode + name_to_fused_node[snode.get_name()] = snode + + node_name = snode.get_name() + scores_0[node_name] = sys.maxsize + scores_1[node_name] = 0 + scores_2[node_name] = idx + + comm_idx = 0 + for snode in snodes: + if raise_comms and contains_collective(snode): + scores_0[snode.get_name()] = comm_idx + for ancestor in snode.ancestors: + anc_fused_name = name_to_fused_node[ancestor].get_name() + scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) + comm_idx += 1 + elif sink_waits and contains_wait(snode): + scores_1[snode.get_name()] = 1 + + class Runnable: + def __init__(self, snode) -> None: + self.snode = snode + name = next(iter(snode.get_operation_names())) + fused_name = name_to_fused_node[name].get_name() + self.score = ( + scores_0[fused_name], + scores_1[fused_name], + scores_2[fused_name], + ) + + def __lt__(self, other): + return self.score < other.score + + unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = { + snode: OrderedSet(dep.name for dep in snode.unmet_dependencies) + for snode in snodes + } + + ready: list[Runnable] = [] + buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet) + snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} + + for snode, deps in unmet_deps.items(): + if len(deps) == 0: + heapq.heappush(ready, Runnable(snode)) + for dep in deps: + buffer_users[dep].add(snode) + + scheduled = [] + + def schedule(snode): + """ + Schedules `snode` and put all unblocked nodes onto the ready queue. 
+ """ + scheduled.append(snode) + for buf_name in snode.get_buffer_names(): + for snode in buffer_users[buf_name]: + unmet_deps[snode].remove(buf_name) + if len(unmet_deps[snode]) == 0: + heapq.heappush(ready, Runnable(snode)) + + def get_overlapping_candidate(): + """ + Return the next node in the ready queue that's neither a collective or + a wait. + """ + candidates = [ + x + for x in ready + if not contains_collective(x.snode) and not contains_wait(x.snode) + ] + if len(candidates) == 0: + return None + return min(candidates, key=lambda x: x.score) + + def schedule_collective_for_overlap(snode): + """ + Schedules collective node `snode`, along with one or more compute nodes + to overlap with it. The strategy is described in the comment of + `reorder_compute_for_overlap`. + """ + assert contains_collective(snode) + schedule(snode) + + collective_cost = snode_to_cost[snode] + while ( + collective_cost > 0 + and (candidate := get_overlapping_candidate()) is not None + ): + ready.remove(candidate) + schedule(candidate.snode) + collective_cost -= snode_to_cost[candidate.snode] + heapq.heapify(ready) + + while len(ready): + snode = heapq.heappop(ready).snode + if reorder_for_overlap and contains_collective(snode): + schedule_collective_for_overlap(snode) + else: + schedule(snode) + + for snode, deps in unmet_deps.items(): + assert len(deps) == 0, ( + f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}" + ) + return scheduled + + +def decide_global_ordering_of_comms( + nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node +) -> list[BaseSchedulerNode]: + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). + TODO: Come up with a better approach + """ + if not torch.distributed.is_available(): + return nodes + + comm_nodes = [n for n in nodes if contains_collective(n)] + + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) + for buf in comm_nodes[i - 1].get_buffer_names(): + comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + + return nodes + + +def estimate_op_runtime(snode: BaseSchedulerNode) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def node_summary(snode): + snodes = snode.get_nodes() + if len(snodes) == 1: + detail = "" + if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): + detail = f" ({snode.node.python_kernel_name})" + layouts = [child.node.get_output_spec() for child in snode.get_nodes()] + out_tensor_info = ",".join( + [ + f" (size={layout.size}, stride={layout.stride})" + if isinstance(layout, ir.Layout) + else "" + for layout in layouts + ] + ) + try: + node_name = snode.node.maybe_get_name() + except AttributeError: + # TODO: node_summary was written without FusedSchedulerNode in mind, generally needs to be hardened + node_name = "" + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name} ({snode.get_estimated_runtime():.0f} ns)" + + # Flatten the summaries for Fused/Foreach/Grouped nodes + summaries = [] + for child_snode in snodes: + summaries.append(node_summary(child_snode)) + return f"{snode.__class__.__name__}: 
{', '.join(summaries)}" + + +def visualize_overlap(order): + # TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model + # streams and overlap. For now its mostly useful as a debug visualization. + + total_est_runtime: float = 0.0 + cur_comm_node = None + + def step_log(step, msg): + overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004 + + for step, snode in enumerate(order): + if cur_comm_node is None: + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif is_wait(snode.node): + # raise AssertionError( + # "Wait is not expected when there is no collective running" + # ) + pass + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + step_log(step, f"{node_summary(snode)}") + else: # cur_comm_node is not None + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + step_log(step, f"{node_summary(snode)}") # noqa: G004 + elif is_wait(snode.node): # end of this comm op + step_log(step, f"{node_summary(snode)}") + cur_comm_node = None + else: # overlapped compute op + step_log(step, f"| {node_summary(snode)}") + overlap_log.debug( + f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + order = snodes + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + assert callable(p), ( + f"Invalid reorder_compute_and_comm_for_overlap pass: {p} is not callable" + ) + peak_memory, _ = estimate_peak_memory( + snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs + ) + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p}, {peak_memory=} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug("", exc_info=e) + t0 = time.time() + order = p(order) # type: ignore[operator] + t = time.time() - t0 + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} (ran in {t} sec)====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug("", exc_info=e) + peak_memory, _ = estimate_peak_memory( + snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs + ) + print(f"final {peak_memory=}") + return order + + +def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): + """ + This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding + graph intermediates that were fsdp.copy_ into the unsharded params in the original graph. + + NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern + (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case + where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't + remove these resize and copy ops and thus we will have worse performance there. 
+ + In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param" + is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern + (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed. + """ + node_list = list(graph.nodes) + + # Find all graph inputs and their resize counts + graph_input_to_resized_to_full_node_idxes = defaultdict(list) + graph_input_to_resized_to_0_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + ): + assert node.args[0].op == "placeholder", f"""\ +Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} +""" + graph_input = node.args[0] + new_size = node.args[1] + if new_size > 0: + graph_input_to_resized_to_full_node_idxes[graph_input].append(idx) + else: + graph_input_to_resized_to_0_node_idxes[graph_input].append(idx) + + def check_resize_pattern(graph_input): + # Check the number of resize-to-full and resize-to-0 nodes are equal, + # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node + # always happens before the resize-to-0 node. + # This is the precondition for being able to remove all the resize and copy nodes + # for this specific unsharded param. + resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get( + graph_input, [] + ) + resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, []) + + if not len(resized_to_full_idxes) == len(resized_to_0_idxes): + log.warning( + f""" +Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}: +{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass. +""" # noqa: G004 + ) + return False + + # Check the sequence: (resize_to_full -> resize_to_0)+ + for resize_to_full_idx, resize_to_0_idx in zip( + resized_to_full_idxes, resized_to_0_idxes + ): + if resize_to_full_idx >= resize_to_0_idx: + log.warning( + f""" +For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx} +happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param. +""" # noqa: G004 + ) + return False + return True + + # Find all eligible unsharded params and their corresponding graph intermediates. + unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: + fsdp_copy_node = node + unsharded_param = node.args[0] + assert unsharded_param.op == "placeholder", f""" +Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! +Offending node: {unsharded_param}. Graph: {graph} +""" + if check_resize_pattern(unsharded_param): + unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx) + + def is_allowed_mutation(node): + return ( + node.target == torch.ops.fsdp.copy_.default + or node.target == torch.ops.inductor.resize_storage_bytes_.default + ) + + def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): + # Check whether the node is mutating any of the unsharded params or their aliases. 
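+        # Aliasing is detected by storage identity below: a mutated argument
+        # counts if its untyped storage matches an unsharded param's storage,
+        # so an in-place op on a view of the param (illustrative:
+        # `unsharded_param.view(-1).relu_()`) is flagged as well.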
+ mutated_arg_idxes = ( + [ + i + for i, x in enumerate(node.target._schema.arguments) + if x.alias_info is not None and x.alias_info.is_write + ] + if isinstance(node.target, torch._ops.OpOverload) + else [] + ) + mutated_node_arg_storages = OrderedSet( + [ + StorageWeakRef(node.args[i].meta["val"].untyped_storage()) + for i in mutated_arg_idxes + ] + ) + storages_of_unsharded_params = OrderedSet( + [ + StorageWeakRef(unsharded_param.meta["val"].untyped_storage()) + for unsharded_param in unsharded_params + ] + ) + return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0 + + # Check no user mutation on any unsharded_param + for node in node_list: + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and node.target._schema.is_mutable + and not is_allowed_mutation(node) + ): + assert not is_node_mutating_unsharded_param_or_its_alias( + node, unsharded_param_to_fsdp_copy_node_idxes.keys() + ), f"""\ +User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node} +""" + + # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`. + # + # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input. + # e.g. + # ``` + # fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1) + # ... (use of unsharded_param_1) -> Subgraph 1 + # fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2) + # ... (use of unsharded_param_1) -> Subgraph 2 + # fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3) + # ... (use of unsharded_param_1) -> Subgraph 3 + # ``` + # We must do the replacement only within each subgraph. + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + assert fsdp_copy_node.args[0] is unsharded_param + _, replacement = fsdp_copy_node.args + # subgraph_start_idx is exclusive + subgraph_start_idx = fsdp_copy_node_idx + 1 + # subgraph_end_idx is exclusive (also intentionally don't replace args in return op) + subgraph_end_idx = ( + fsdp_copy_node_idxes[i + 1] + if i < len(fsdp_copy_node_idxes) - 1 + else len(node_list) - 1 + ) + subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx] + assert not any( + is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param]) + for node in subgraph_nodes + ), f"""\ +Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true! 
+Graph: {graph} +""" + for node in subgraph_nodes: + if ( + node.op == "call_function" + and unsharded_param in node.args + and node.target != torch.ops.inductor.resize_storage_bytes_.default + ): # TODO(yf225): implement replacement in kwargs + new_args = tuple( + replacement if arg is unsharded_param else arg + for arg in node.args + ) + node.args = new_args + + # Delete `fsdp.copy_(unsharded_param, Y)` nodes + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + graph.erase_node(fsdp_copy_node) + + # Delete `resize_(unsharded_param, ...)` nodes + for node in node_list: + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes + ): + graph.erase_node(node) + + +def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: + try: + import torch.distributed.fsdp._fully_shard._fsdp_collectives + + assert torch.distributed.is_available() + # Assert existence of these ops + assert ( + torch.ops._c10d_functional.all_gather_into_tensor + and torch.ops._c10d_functional.all_gather_into_tensor_out + ) + except (ImportError, AttributeError, AssertionError): + return + + from .pattern_matcher import ( + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + ) + + """ + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + (getitem_1 = all_gather_copy_in[1];) # optional + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...); + + -> + + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + getitem_1 = all_gather_copy_in[1]; + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1); + """ + + def remove_unused_getitem(g): + # Remove `getitem_X = all_gather_copy_in[1]` which is never used. 
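+        # Before reinplacement, element 1 of `all_gather_copy_in` has no users
+        # (only element 0 feeds the collective), so it is erased here; the
+        # replacement graph below re-creates it as the `out=` buffer of
+        # all_gather_into_tensor_out.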
+ node_list = list(g.nodes) + for n in node_list: + if ( + n.target == operator.getitem + and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default + and n.args[1] == 1 + ): + g.erase_node(n) + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction( + torch.ops._c10d_functional.all_gather_into_tensor.default, + CallFunction( + operator.getitem, + CallFunction( + torch.ops.fsdp.all_gather_copy_in.default, + KeywordArg("all_gather_inputs"), + KeywordArg("inp_split_sizes"), + KeywordArg("all_gather_input_numel"), + KeywordArg("world_size"), + KeywordArg("rank"), + KeywordArg("dtype"), + KeywordArg("device"), + KeywordArg("group_name_inner"), + KeywordArg("allocate_memory_from_process_group"), + ), + KeywordArg("item_idx"), + ), + KeywordArg("group_size"), + KeywordArg("group_name"), + ), + pass_dict=graph_pass, + extra_check=lambda match: match.kwargs["item_idx"] == 0, + ) + def reinplace_all_gather(match: Match, *args, **kwargs): + def repl( + *args, + ): + copy_in_args = args[:-2] + group_size = args[-2] + group_name = args[-1] + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default( + *copy_in_args + ) + getitem = all_gather_copy_in[0] + getitem_1 = all_gather_copy_in[1] + all_gather_into_tensor = ( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + getitem, group_size, group_name, out=getitem_1 + ) + ) + return all_gather_into_tensor + + match.replace_by_example( + repl, + [ + kwargs["all_gather_inputs"], + kwargs["inp_split_sizes"], + kwargs["all_gather_input_numel"], + kwargs["world_size"], + kwargs["rank"], + kwargs["dtype"], + kwargs["device"], + kwargs["group_name_inner"], + kwargs["allocate_memory_from_process_group"], + kwargs["group_size"], + kwargs["group_name"], + ], + ) + + remove_unused_getitem(graph) + graph_pass.apply(graph) # type: ignore[arg-type] + + +def get_op_idx(snode): + assert not isinstance( + snode, + ( + torch._inductor.scheduler.FusedSchedulerNode, + torch._inductor.scheduler.GroupedSchedulerNode, + ), + ) + return int(snode.get_name()[2:]) + + +def enforce_comm_ordering_for_fsdp( + snodes: list[torch._inductor.scheduler.BaseSchedulerNode], + name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], +) -> list[torch._inductor.scheduler.BaseSchedulerNode]: + from . 
import scheduler + + new_order: list[BaseSchedulerNode] = [] + scheduled = OrderedSet[Any]() + ag_exists = False + rs_exists = False + ag_grouped_node_to_wait_grouped_node = {} + rs_grouped_node_to_wait_grouped_node = {} + snode_name_to_final_snode = {} + + def _create_group_node(snodes_to_group): + group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group) + for snode in snodes_to_group: + snode_name_to_final_snode[snode.get_name()] = group_node + snode_name_to_final_snode[group_node.get_name()] = group_node + return group_node + + # Create grouped nodes for specific sets of ops + for snode in snodes: + # Case 1: Handle AllGather + if is_collective( + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default + ) and any( + is_fallback_op( + name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default + ) + for x in snode.ancestors + ): + ag_exists = True + ag_snode = snode + ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + + # Find the "cast + copy_in + getitem + all_gather" code block + find_recursive_deps_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # Find the "all_gather + all_gather_wait_tensor + copy_out" code block + allowed_ops = OrderedSet( + [ + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + torch.ops._c10d_functional.wait_tensor.default, + torch.ops.fsdp.split_with_sizes_copy.default, + ] + ) + find_recursive_users_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + criteria_cb=lambda x: not ( + isinstance(x, scheduler.NopKernelSchedulerNode) + or ( + isinstance(x, scheduler.ExternKernelSchedulerNode) + and x.node.op_overload in allowed_ops # type: ignore[union-attr] + ) + ), + ) + + # sort nodes by original operation order + ag_related_snodes = sorted( + ag_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # In the "reuse layer" case, some ops in the 2nd all-gather code block could also + # depend on ops in the 1st all-gather code block, and we don't want to group them together. 
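+            # The loop below counts `split_with_sizes_copy` (copy_out) ops and
+            # truncates `ag_related_snodes` at the second one it sees, so nodes
+            # belonging to the next layer's all-gather block are left out of
+            # this group.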
+ end_idx_of_current_ag_block = len(ag_related_snodes) + copy_out_count = 0 + for i in range(len(ag_related_snodes)): + cur_snode = ag_related_snodes[i] + if is_fallback_op( + cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default + ): + copy_out_count += 1 + if copy_out_count > 1: + end_idx_of_current_ag_block = i + break + + ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block] + + # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(ag_related_snodes) - 1): + if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) + + # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode + ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) + + ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node + + # Case 2: Handle ReduceScatter + elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default): + rs_exists = True + rs_snode = snode + + # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block + rs_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + find_recursive_users_of_node( + rs_snode, + rs_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # sort nodes by original operation order + rs_related_snodes = sorted( + rs_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(rs_related_snodes) - 1): + if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx]) + + # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode + rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:]) + + rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node + + assert len(snode_name_to_final_snode) > 0 + if ag_exists: + assert len(ag_grouped_node_to_wait_grouped_node) > 0 + if rs_exists: + assert len(rs_grouped_node_to_wait_grouped_node) > 0 + + # Build the new node schedule, taking GroupedSchedulerNode into account + for snode in snodes: + if snode.get_name() in snode_name_to_final_snode: + snode = snode_name_to_final_snode[snode.get_name()] + if snode in scheduled: + continue + new_order.append(snode) + scheduled.add(snode) + + # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run + # before next AllGather's "copy_in then AG" group node + prev_ag_wait = None + for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items(): + if prev_ag_wait is not None: + mutating_buf = next(iter(ag_group_node.get_buffer_names())) + for o in prev_ag_wait.get_outputs(): + ag_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_ag_wait = wait_group_node + + # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run + # before next ReduceScatter's "copy_in then RS" group node + prev_rs_wait = None + for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items(): + if prev_rs_wait is not None: + mutating_buf = next(iter(rs_group_node.get_buffer_names())) + for o in prev_rs_wait.get_outputs(): + 
rs_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf) + ) + prev_rs_wait = wait_group_node + + return new_order # type: ignore[return-value] diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_fx.py b/phivenv/Lib/site-packages/torch/_inductor/compile_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..98f4c1bd322031ed0c68194dc671df8cb4913dc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_fx.py @@ -0,0 +1,2614 @@ +from __future__ import annotations + +import contextlib +import enum +import functools +import io +import itertools +import json +import logging +import os +import sys +import time +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from contextlib import AbstractContextManager +from inspect import currentframe +from itertools import count +from operator import attrgetter +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack +from unittest import mock + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +import torch.utils._pytree as pytree +from functorch.compile import min_cut_rematerialization_partition +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo import ( + compiled_autograd, + config as dynamo_config, + logging as dynamo_logging, + utils as dynamo_utils, +) +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.repro.after_aot import wrap_compiler_debug +from torch._dynamo.utils import ( + chromium_event_timed, + CompileEventLogger, + counters, + detect_fake_mode, + dynamo_timed, + flatten_graph_inputs, + get_metrics_context, + lazy_format_graph_code, + set_feature_use, +) +from torch._functorch import config as functorch_config +from torch._functorch._aot_autograd.subclass_parametrization import ( + unwrap_tensor_subclass_parameters, +) +from torch._functorch.aot_autograd import ( + aot_export_module, + make_boxed_func, + SerializableAOTDispatchCompiler, +) +from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + format_default_skip_message, + log_cudagraph_skip_and_bump_counter, + PlaceholderInfo, +) +from torch._inductor.debug import save_args_for_compile_fx_inner +from torch._inductor.output_code import ( + CompiledAOTI, + CompiledFxGraph, + CompiledFxGraphConstantsWithGm, + get_expanded_dims, + index_expanded_dims, + OutputCode, +) +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + fresh_cache, + get_all_devices, + InputType, + is_gpu, + should_assume_input_aligned, + should_use_remote_fx_graph_cache, + tensor_is_aligned, +) +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import trace_structured +from torch._utils_internal import compile_time_strobelight_meta +from torch.fx import GraphModule +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.monitor import _WaitCounter +from torch.utils._ordered_set import OrderedSet + +from .._dynamo.backends.common import aot_autograd +from .._dynamo.exc import ShortenTraceback, SkipFrame +from ..fx._lazy_graph_module import _use_lazy_graph_module 
+from ..fx.graph import _PyTreeCodeGen +from ..utils._triton import has_triton +from . import config, metrics +from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration +from .debug import DebugContext +from .decomposition import select_decomp_table +from .exc import InductorError +from .fx_passes.joint_graph import joint_graph_passes +from .fx_passes.post_grad import post_grad_passes, view_to_reshape +from .fx_passes.pre_grad import pre_grad_passes +from .graph import GraphLowering +from .ir import get_device_type, IRNode +from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 +from .triton_bundler import TritonBundler +from .utils import ( + align_inputs_from_check_idxs, + clone_preserve_strides, + copy_misaligned_inputs, + get_cloned_parameter_buffer_name, + get_first_incompatible_cudagraph_node, + maybe_get_suppress_shape_guards_ctx, + output_node, + remove_unaligned_input_idxs, + shape_env_from_inputs, +) +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + from torch._inductor.output_code import _StrideExprStr + from torch._ops import OpOverload + from torch.export.pt2_archive._package_weights import Weights + + from .ir import ExternKernelNode + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +if TYPE_CHECKING or not config.is_fbcode(): + # no-op decorator + def time_and_log(attr: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + return dynamo_utils.identity + + def log_optimus_to_scuba(*args: object, **kwargs: object) -> None: + pass + +else: + from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log + +if TYPE_CHECKING: + from torch._functorch._aot_autograd.schemas import ( + FQN, + GraphInputName, + GraphSignature, + ) + + +class FxCompileMode(enum.Enum): + NORMAL = 0 + # For testing - use the serde FxCompile scheme to debug serialization and + # deserialization of GraphMoule and CompiledFxGraph. + SERIALIZE = 1 + # Compile using a subprocess instead of in-process. + SUBPROCESS = 2 + + +# Return compile mode and use_async flag +def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]: + name = "TORCHINDUCTOR_FX_COMPILE_MODE" + value = os.environ.get(name) + if value is None: + return FxCompileMode.NORMAL, False + + use_async = False + if value.lower().startswith("async+"): + use_async = True + value = value[6:] + + try: + value = value.upper() + return FxCompileMode[value], use_async + except KeyError: + import logging + + log = logging.getLogger(__name__) + log.error( + "Invalid value of %s for %s. Expected one of %s. Using default.", + value, + name, + ", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())), + ) + # Remove from the environment so subprocesses don't ALSO complain. 
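+        # For reference, valid settings are the FxCompileMode member names
+        # (case-insensitive), optionally prefixed with "async+", e.g.
+        # (illustrative) TORCHINDUCTOR_FX_COMPILE_MODE=subprocess or
+        # TORCHINDUCTOR_FX_COMPILE_MODE=async+serialize.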
+ os.environ.pop(name) + return FxCompileMode.NORMAL, False + + +fx_compile_mode, fx_compile_async = _fx_compile_mode_default() + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs") +post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) +inductor_metrics_log = torch._logging.getArtifactLogger(__name__, "inductor_metrics") + + +def get_static_input_idxs(num_fixed: int) -> list[int]: + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. + context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return context.fw_metadata.static_input_indices + + +def record_original_output_strides(gm: GraphModule) -> None: + output_node = gm.graph.find_nodes(op="output")[0] + output_strides = [] + + if not isinstance(output_node.args[0], torch.fx.Node): + output_node_args = output_node.args[0] + else: + output_node_args = output_node.args + + for output in output_node_args: + if ( + isinstance(output, torch.fx.Node) + and (val := output.meta.get("val")) is not None + and isinstance(val, torch.Tensor) + ): + output_strides.append(val.stride()) + else: + output_strides.append(None) + output_node.meta["original_output_strides"] = output_strides + + +def _recursive_record_original_output_strides(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + _recursive_record_original_output_strides(subgraph) + + record_original_output_strides(gm) + + +def _recursive_record_user_visible_output_idxs(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + + for node in subgraph.graph.find_nodes(op="output"): + node.meta["user_visible_output_idxs"] = [ + idx + for idx in range(len(node.args[0])) + if isinstance(node.args[0][idx], torch.fx.Node) + ] + _recursive_record_user_visible_output_idxs(subgraph) + + +@functools.lru_cache(None) +def _step_logger() -> Callable[..., None]: + return dynamo_logging.get_step_logger(log) + + +@functools.cache +def _warn_tf32_disabled() -> None: + if ( + torch.cuda.is_available() + and not torch.backends.cuda.matmul.allow_tf32 + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. " + "Consider setting `torch.set_float32_matmul_precision('high')` for better performance." + ) + + +def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None: + """ + In aot_export_module (make_fx), we create get_attr nodes with name prefix + "_tensor_constant" and "_torchbind_obj". 
See Tracer.create_arg() in + torch/fx/_symbolic_trace.py + + However, this might result in name collision if the original mod already + has a different buffer with the same name. + + We resolve this potential name collision here by changing the target name + with a new number post fix. + """ + + existing_keys = OrderedSet( + [name for name, val in mod.named_parameters(remove_duplicate=False)] + ) + existing_keys.update( + OrderedSet([name for name, val in mod.named_buffers(remove_duplicate=False)]) + ) + + def find_smallest_i(graph: fx.Graph, prefix: str) -> int: + i = 0 + for node in graph.nodes: + if node.op == "get_attr" and node.target.startswith(prefix): + if len(node.target) > len(prefix): + post_fix = node.target.split(prefix)[-1] + if post_fix.isdigit(): + i = max(i, int(post_fix)) + for key in existing_keys: + if key.startswith(prefix): + if len(key) > len(prefix): + post_fix = key.split(prefix)[-1] + if post_fix.isdigit(): + i = max(i, int(post_fix)) + return i + 1 + + for node in gm.graph.nodes: + if node.op == "get_attr": + target_name = node.target + if not target_name.startswith( + "_tensor_constant" + ) and not target_name.startswith("_torchbind_obj"): + continue + + if not hasattr(mod, target_name): + continue + gm_target = attrgetter(target_name)(gm) + model_target = attrgetter(target_name)(mod) + if ( + torch.equal(gm_target, model_target) + and gm_target.dtype == model_target.dtype + ): + continue + + prefix = ( + "_tensor_constant" + if target_name.startswith("_tensor_constant") + else "_torchbind_obj" + ) + new_id = find_smallest_i(gm.graph, prefix) + new_target_name = f"{prefix}{new_id}" + node.target = new_target_name + setattr(gm, new_target_name, gm_target) + existing_keys.add(new_target_name) + + +def _unlift_graph( + mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature +) -> GraphModule: + from torch.export.unflatten import _assign_attr, _AttrKind + + _resolve_name_collision(mod, gm) + + state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {} + for name, param in mod.named_parameters(remove_duplicate=False): + state_dict[name] = param + _assign_attr( + param, + gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + for name, buffer in mod.named_buffers(remove_duplicate=False): + state_dict[name] = buffer + _assign_attr( + buffer, + gm, + name, + attr_kind=_AttrKind.BUFFER, + ) + + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + lifted_inputs: list[Optional[FQN]] = [] + + # In AOTI, module parameters and buffers are not lifted as graph inputs. + # As a result, mutation to buffers has side effect which makes their initial + # values different from Eager. So we clone them here as a copy. + # We are not cloning for parameters, although it will be needed if we want to + # support training. 
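+    # `lifted_inputs` below ends up parallel to `placeholder_nodes`: it holds
+    # the parameter/buffer FQN for lifted placeholders and None for genuine
+    # user inputs, e.g. (illustrative) ["linear.weight", "linear.bias", None].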
+ for node in placeholder_nodes: + node_name = node.name + if node_name in graph_signature.inputs_to_parameters: + parameter_name = graph_signature.inputs_to_parameters[node_name] + lifted_inputs.append(parameter_name) + elif node_name in graph_signature.inputs_to_buffers: + buffer_name = graph_signature.inputs_to_buffers[node_name] + lifted_inputs.append(buffer_name) + gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = ( + clone_preserve_strides(state_dict[buffer_name]) + ) + else: + assert node_name in graph_signature.user_inputs + lifted_inputs.append(None) + + from torch.export._unlift import _unlift + + outputs = list(gm.graph.nodes)[-1].args[0] + mutated_outputs = [] + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + output_tokens = graph_signature.output_tokens + for idx, out in enumerate(outputs): + value: Optional[Union[FQN, GraphInputName]] = None + + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if out.name in buffer_mutations: + value = buffer_mutations[out.name] + elif out.name in user_input_mutations: + value = user_input_mutations[out.name] + + mutated_outputs.append(value) + + unlifted_gm = _unlift( + gm, + lifted_inputs, + mutated_outputs, + pytree.LeafSpec(), + None, + state_dict, + {}, + ) + return unlifted_gm + + +def _get_subgraph_names( + gm: GraphModule, skip_invoke_subgraph: bool = False +) -> Generator[str, None, None]: + all_subgraph_names: OrderedSet[str] = OrderedSet( + x.target for x in gm.graph.find_nodes(op="get_attr") + ) + fx_subgraph_names: OrderedSet[str] = OrderedSet() + for child_name, child_module in gm.named_children(): + # Sometimes an owning_module can have unused children. Skip them + # by checking them from get_attr node targets. + if child_name in all_subgraph_names and isinstance( + child_module, torch.fx.GraphModule + ): + fx_subgraph_names.add(child_name) + + if skip_invoke_subgraph: + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + fx_subgraph_names.discard(node.args[0].target) + + yield from fx_subgraph_names + + +def _recursive_pre_grad_passes( + gm: GraphModule, + example_inputs: Sequence[InputType], +) -> GraphModule: + with dynamo_timed( + "_recursive_pre_grad_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="pre_grad_pass_time_us", + ): + add_passes = config.add_pre_grad_passes + remove_passes = config.remove_pre_grad_passes + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing empty set here + new_subgraph = _recursive_pre_grad_passes(subgraph, ()) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs, add_passes, remove_passes) + + +def _recursive_joint_graph_passes( + gm: GraphModule, skip_invoke_subgraph: bool = False +) -> None: + with dynamo_timed( + "_recursive_joint_graph_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="joint_graph_pass_time_us", + ): + # invoke_subgraph already runs the _recursive_joint_graph_passes. In + # AOTAutograd, `run_joint_graph_passes_on_hops` partitions the + # invoke_subgraph HOP before calling the partitioner on the outer graph. + # AOTAutograd has access to partition_fn, which internally calls the + # `_recursive_joint_graph_passes` for the subgraph. So, skip recursing + # skip_invoke_subgraph. 
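+        # In other words, when `skip_invoke_subgraph` is True the
+        # invoke_subgraph submodules are skipped here, since AOTAutograd's
+        # partition_fn has already run the joint-graph passes on them.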
+ for subgraph_name in _get_subgraph_names(gm, skip_invoke_subgraph): + subgraph = getattr(gm, subgraph_name) + _recursive_joint_graph_passes(subgraph, skip_invoke_subgraph) + joint_graph_passes(gm) + + +def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None: + with dynamo_timed( + "_recursive_post_grad_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="post_grad_pass_time_us", + ): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) + + +def split_const_gm( + gm: GraphModule, + skip_constructor: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> tuple[GraphModule, dict[str, int]]: + """ + This function takes an GraphModule input "gm". + The gm will be split into 2 components, + 1) const_gm, which consists the subgraph of gm that can be constant folded. + 2) gm (being inplace modified,) which returns the graph after constant folding. + + If an additional "lifted_constants" argument is passed in, we will assume the gm has + been lifted and run the transformation accordingly. + + When a "skip_folding_node_fn" callback is passed, we will skip constant folding on + the nodes for which the callback returns True. + + const_output_index is a mapping of corresponding node name from gm to the + output index of const_gm. + Returns (const_gm, const_output_index) + """ + from torch._inductor.constant_folding import ( + CONST_MODULE_TAG, + META_TAG, + MODULE_TAG, + replace_node_with_constant, + run_and_get_constant_graph, + ) + + const_gm = run_and_get_constant_graph( + gm, skip_constructor, lifted_constant_names, skip_folding_node_fn + ) + const_result = const_gm() if lifted_constant_names is None else None + + const_outputs = { + x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0]) + } + + to_erase_node = [] + to_replace_node = [] + const_output_index = {} + for node in gm.graph.nodes: + if node.name in const_outputs: + to_replace_node.append(node) + elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder": + to_erase_node.append(node) + + for node in to_replace_node: + new_const_name = "_FOLDED_CONST_" + node.name + replace_node_with_constant( + gm, + node, + ( + const_result[const_outputs[node.name]] # type:ignore[index] + if lifted_constant_names is None + else None + ), + new_const_name, + ) + const_output_index[new_const_name] = const_outputs[node.name] + for node in to_erase_node[::-1]: + if node.users: + for n in node.users: + assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty." 
+ else: + gm.graph.erase_node(node) + gm.recompile() + + return const_gm, const_output_index + + +def is_tf32_warning_applicable(gm: GraphModule) -> bool: + aten = torch.ops.aten + tf32_ops = OrderedSet( + [ + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + ] + ) + for target in tf32_ops: + for node in gm.graph.find_nodes(op="call_function", target=target): + if ( + isinstance(node.meta.get("val", None), torch.Tensor) + and node.meta["val"].dtype == torch.float32 + and node.meta["val"].device.type == "cuda" + ): + return True + return False + + +def maybe_disable_comprehensive_padding( + example_inputs: Sequence[InputType], +) -> AbstractContextManager[None, None]: + """ + For CPU backend, enable comprehensive padding causes some unit tests + fail due to changing number of generated kernels. Skip for now. + """ + has_gpu = any( + is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) + ) + + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: + perf_hint_log.info("Skip comprehensive padding on CPU") + return config.patch(comprehensive_padding=False) + elif config.aot_inductor.use_runtime_constant_folding: + perf_hint_log.info( + "Skip comprehensive padding for use_runtime_constant_folding" + ) + return config.patch(comprehensive_padding=False) + else: + return contextlib.nullcontext() + + +def maybe_disable_graph_partition( + cpp_wrapper: bool, aot_mode: bool +) -> AbstractContextManager[None, None]: + """ + graph partition does not support cpp_wrapper and aot_mode yet. + """ + if cpp_wrapper or aot_mode: + return config.patch(graph_partition=False) + else: + return contextlib.nullcontext() + + +def fake_tensor_prop( + gm: GraphModule, + example_inputs: Sequence[InputType], + force_allow_non_fake_inputs: bool = False, +) -> torch._subclasses.FakeTensorMode: + """ + If we can not detect fake mode from the context of inputs, create one. + + The created fake mode will be returned. + """ + # Ensure that decomps that support symbolic shapes are used + with enable_python_dispatcher(): + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) + ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + return fake_mode + + +# pass config dict back to user +def get_patched_config_dict( + config_patches: Optional[Union[str, dict[str, Any]]] = None, +) -> dict[str, Any]: + with config.patch(config_patches): + return config.get_config_copy() + + +@contextlib.contextmanager +def with_fresh_cache_if_config() -> Generator[None, None, None]: + if config.force_disable_caches: + # Don't delete the cache dir because it has to survive beyond the + # compile_fx call. Let's put the temp dirs under the default cache + # dir so they're easier to locate. 
+ with fresh_cache(dir=cache_dir(), delete=False): + yield + else: + yield + + +class _CompileFxKwargs(TypedDict, total=False): + cudagraphs: Optional[BoxedBool] + static_input_idxs: Sequence[int] + is_backward: bool + graph_id: Optional[int] + cpp_wrapper: bool + aot_mode: bool + is_inference: bool + layout_opt: Optional[bool] + extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] + boxed_forward_device_index: Optional[BoxedDeviceIndex] + + +class _CompileFxCallable(Protocol): + def __call__( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], + ) -> OutputCode: ... + + +def compile_fx_inner( + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + kwargs.setdefault("cudagraphs", None) + kwargs.setdefault("static_input_idxs", ()) + kwargs.setdefault("is_backward", False) + kwargs.setdefault("graph_id", None) + kwargs.setdefault("cpp_wrapper", False) + kwargs.setdefault("is_inference", False) + kwargs.setdefault("boxed_forward_device_index", None) + kwargs.setdefault("layout_opt", None) + kwargs.setdefault("extern_node_serializer", None) + + # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for + # compile_fx. The reason is the compilation for backward graph may happen after + # compile_fx return and we may want to use the _LazyGraphModule for compiling + # the backward graph as well. + with contextlib.ExitStack() as stack: + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) + stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)) + stack.enter_context( + dynamo_utils.dynamo_timed( + "compile_fx_inner", + phase_name="inductor_compile", + log_pt2_compile_event=True, + log_waitcounter=True, + waitcounter_name_override="compile_inductor", + dynamo_compile_column_us="inductor_cumulative_compile_time_us", + ) + ) + stack.enter_context(with_fresh_cache_if_config()) + stack.enter_context(DebugContext()) + CompileEventLogger.pt2_compile( + "inductor_compile", + is_backward=kwargs["is_backward"], + ) + return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( + gm, + example_inputs, + **kwargs, + ) + + +@time_and_log(attr="compilation time (in seconds)") +def _compile_fx_inner( + gm: GraphModule, + example_inputs: Sequence[InputType], + **graph_kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + """ + Inductor API that compiles a single graph. + + If you change the argument list for this function, make sure you + also update the call to save_args_for_compile_fx_inner below accordingly. + """ + aot_mode: bool = V.aot_compilation + + # Clean up Compiled Triton Kernels per inductor compile, as the future objects + # may not be valid for use after they are run/autotuned + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: + # trigger the real recompilation for _LazyGraphModule before returning + # the forward method. 
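+        # Descriptive note: count_calls(gm.graph) == 0 means the FX graph has no
+        # call nodes (call_function / call_method / call_module), so there is
+        # nothing for Inductor to lower. This typically happens for trivial
+        # graphs such as a no-op backward, and we return the module's own
+        # (recompiled) forward instead of generating code.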
+ from torch._dynamo.utils import CompileEventLogLevel + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(gm) + compile_id = torch._guards.CompileContext.current_compile_id() + CompileEventLogger.log_instant_event( + "backward no-op", + metadata={"compile_id": compile_id}, + log_level=CompileEventLogLevel.PT2_COMPILE, + ) + + return make_boxed_func(gm.forward) + + static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ()) + static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) + + assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), ( + f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + ) + + if graph_kwargs.get("cudagraphs") is None: + graph_kwargs["cudagraphs"] = BoxedBool(config.triton.cudagraphs) + if config.save_args: + save_args_for_compile_fx_inner( + gm, + example_inputs, + **graph_kwargs, + ) + + start = time.time() + + fx_graph_remote_cache = should_use_remote_fx_graph_cache() + + # Check if the registered backend(s) support caching. + init_backend_registration() + backends_support_caching = all( + backend.supports_caching + for backend in ( + get_wrapper_codegen_for_device(device.type, config.cpp_wrapper) + for device in get_all_devices(gm) + ) + if backend is not None + ) + + with dynamo_timed( + "fx_codegen_and_compile", log_pt2_compile_event=True, log_waitcounter=True + ): + use_cache = ( + not config.force_disable_caches + and (config.fx_graph_cache or fx_graph_remote_cache) + and not aot_mode + and backends_support_caching + ) + local = config.fx_graph_cache + remote = fx_graph_remote_cache + set_feature_use("fx_cache", use_cache) + + log.debug( + "FX cache status: use_cache=%s, local=%s, remote=%s, aot_mode=%s, force_disable_caches=%s", + use_cache, + local, + remote, + aot_mode, + config.force_disable_caches, + ) + + # TODO: This is a hack purely to get some info to extract_tensor_metadata_for_cache_key, + # figure out how to not have to modify example inputs + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and is_gpu(input.device.type) + and i in static_input_idxs + ): + input._is_inductor_static = True # type: ignore[attr-defined] + + mb_compiled_graph: Optional[OutputCode] = None + key_info = None + cache_info = None + remote_cache = None + constants = CompiledFxGraphConstantsWithGm(gm) + # TODO: this time will be slightly inconsistent with the one computed + # in prepare_key/load_with_key, dump those settings of "cache_event_time" + start_time = time.time_ns() + + if use_cache: + (key_info, cache_info) = FxGraphCache.prepare_key( + gm, example_inputs, graph_kwargs, inputs_to_check, remote + ) + + # Attempt a cache lookup + if key_info is not None: + key, debug_lines = key_info + log.debug("FX cache key generated: %s", key) + if remote: + remote_cache = FxGraphCache.get_remote_cache() + log.debug("Using remote FX cache") + mb_compiled_graph, cache_info = FxGraphCache.load_with_key( + key, + debug_lines, + example_inputs, + local, + remote_cache, + is_backward=graph_kwargs.get("is_backward", False), + constants=constants, + ) + else: + log.debug("Failed to generate FX cache key") + + # CACHE BYPASS: Compile the graph, don't save it to the cache + # (this can happen either because cache was disabled, or we + # determined the input is uncacheable) + if cache_info is None or 
cache_info["cache_state"] == "bypass": + assert mb_compiled_graph is None + log.debug( + "FX cache bypass reason: %s", + ( + cache_info.get("cache_bypass_reason", "unknown") + if cache_info is not None + else "FX cache disabled or key generation failed" + ), + ) + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + + # CACHE MISS: Compile the graph and save to cache + elif cache_info["cache_state"] == "miss": + assert mb_compiled_graph is None + assert key_info is not None + log.debug("FX cache miss, compiling and saving to cache") + TritonBundler.begin_compile() + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + mb_compiled_graph._time_taken_ns = time.time_ns() - start_time + cache_key, debug_lines = key_info + mb_compiled_graph._fx_graph_cache_key = cache_key + mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + if triton_bundler_meta is not None: + cache_info["triton_bundler_meta"] = str(triton_bundler_meta) + cache_info["time_taken_ns"] = mb_compiled_graph._time_taken_ns + log.debug("Saving compiled graph to FX cache with key: %s", cache_key) + FxGraphCache._save_graph( + cache_key, + mb_compiled_graph, + example_inputs, + local, + remote_cache, + ) + + # CACHE HIT: not much to really do, just make sure the cache key + # is recorded on the graph + else: + assert cache_info["cache_state"] == "hit" + assert mb_compiled_graph is not None + assert key_info is not None + (cache_key, debug_lines) = key_info + log.debug("FX cache hit with key: %s", cache_key) + mb_compiled_graph._fx_graph_cache_key = cache_key + mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines + + assert mb_compiled_graph is not None + compiled_graph = mb_compiled_graph + + # Logging and observability: we log a single chromium event + # and a tlparse log for every cache action. + # In the event of a bypass, we also logged to the remote table earlier + # with log_cache_bypass. 
+ cache_state = ( + cache_info["cache_state"] if cache_info is not None else "disabled" + ) + # Here for grepping: + # fx_graph_cache_hit + # fx_graph_cache_miss + # fx_graph_cache_bypass + # fx_graph_cache_disabled + CompileEventLogger.instant( + f"fx_graph_cache_{cache_state}", + metadata=cache_info or {}, + time_ns=start_time, + ) + # Add event data about cache hits/miss + # TODO: add remote cache get/put timings here too + CompileEventLogger.pt2_compile( + "inductor_compile", + cache_state=cache_state, + cache_event_time=start_time, + key=cache_info.get("key") if cache_info else None, + components=cache_info.get("components") if cache_info else None, + cache_bypass_reason=( + cache_info.get("cache_bypass_reason") + if cache_info + else "cache not enabled" + ), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + # Don't clog up the main tlparse output with disabled cache + if cache_info is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"fx_graph_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + compiled_graph.post_compile(example_inputs, constants, graph_kwargs) + + log.debug("FX codegen and compilation took %.3fs", time.time() - start) + + # Dump provenance artifacts for debugging trace + provenance_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info() + # provenance_info might be None if config.trace.enabled is not set + if provenance_info: + ( + debug_info, + node_mappings, + ) = provenance_info + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_generated_kernel_to_post_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(debug_info), + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(node_mappings), + ) + + # This message is for printing overview information of inductor mm counts, shapes,etc after lowering + if log.isEnabledFor(logging.INFO): + mm_table_data = [] + for key, value in counters["aten_mm_info"].items(): + parts = key.split("_") + if len(parts) < 3: + # Unexpected format, show as-is + mm_table_data.append([key, "-", "?", "?", "?", value]) + continue + + # Determine if this is a batched operation by checking the operation name + name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3]) + is_batched = name.endswith(("bmm", "baddbmm")) + + if is_batched and len(parts) >= 4: + # Batched operation: last 4 parts are batch, m, n, k + batch, m, n, k = parts[-4:] + name = "_".join(parts[:-4]) + mm_table_data.append([name, batch, m, n, k, value]) + else: + # Non-batched operation: last 3 parts are m, n, k + m, n, k = parts[-3:] + name = "_".join(parts[:-3]) + mm_table_data.append([name, "-", m, n, k, value]) + + log.info("Overview info of inductor aten mms: ") + log.info( + "{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001 + "Name", "B", "M", "N", "K", "Count" + ) + ) + log.info("-" * 130) + for row in mm_table_data: + log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 + log.info("-" * 130) + + # Not strictly necessary, but good to clean up straggling futures + # that are unused to reclaim memory. 
+ torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + + _step_logger()( + logging.INFO, + "torchinductor done compiling " + f"{'BACKWARDS' if graph_kwargs['is_backward'] else 'FORWARDS'} " + f"graph {graph_kwargs['graph_id']}", + ) + return compiled_graph + + +class _FxCompileStat: + # Count of successful compiles of this type + codegen_and_compile: int = 0 + + def __repr__(self) -> str: + return f"codegen_and_compile: {self.codegen_and_compile}" + + +class FxCompile(ABC): + """ + An FxCompile represents a mechanism that can turn a GraphModule into an + OutputCode. + """ + + # Some stats for logging/debugging + _compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat) + + # TODO: We should probably eventually add some kind of async version of this + # so we can kick off a compile and then go do other things - but we'll need + # to know what kind of API we want for that first. + @abstractmethod + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: ... + + @classmethod + def _reset_stats(cls) -> None: + cls._compile_stats.clear() + + +class _InProcessFxCompile(FxCompile): + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + """ + Generates the OutputCode from the GraphModule and example_inputs. + """ + # Sorry about the mess, we need graph_kwargs to continue to be able + # to propagate it further on + # TODO: _CompileFxKwargs actually has stronger types than in the + # signature, need to tighten it up + + assert "cudagraphs" in graph_kwargs and graph_kwargs["cudagraphs"] is not None + cudagraphs: BoxedBool = graph_kwargs["cudagraphs"] + static_input_idxs: Sequence[int] = graph_kwargs.get("static_input_idxs", ()) + is_backward: bool = graph_kwargs.get("is_backward", False) + graph_id: Optional[int] = graph_kwargs.get("graph_id", None) + cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False) + aot_mode: bool = V.aot_compilation + is_inference: bool = graph_kwargs.get("is_inference", False) + extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = ( + graph_kwargs.get("extern_node_serializer", None) + ) + + with ( + _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(), + dynamo_utils.preserve_rng_state(), + ): + if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: + import time + + log.warning( + "Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec + ) + time.sleep(sleep_sec) + + if is_tf32_warning_applicable(gm): + _warn_tf32_disabled() + + inductor_counters = counters["inductor"].copy() + + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000)) + + _step_logger()( + logging.INFO, + "torchinductor compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + f"graph {graph_id}", + ) + + fd = io.StringIO() + torch._dynamo.repro.after_aot.save_graph_repro( + fd, gm, example_inputs, "inductor", save_dir=None + ) + runnable_graph_str = fd.getvalue() + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: runnable_graph_str, + ) + + V.debug.fx_graph(gm, example_inputs) + # TODO: Should we actually dump this? 
It should be redundant with the aot + # structured logs... + # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False)) + + shape_env = shape_env_from_inputs(example_inputs) + + # Convert view to reshape in the graph. This is necessary primarily for + # layout optimization. Do it unconditionally for uniformity. + # + # It's needed because when we do layout optimization, an contiguous tensor + # in eager mode may becomes a channels last tensor. A view op previously + # can be applied to the contiguous tensor may not be able to be applied + # on the channels tensor any more. An error like + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + # will be printed. + # + # Replace view op to reshape op in this case. + # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this. + # + # Also this has to be done before FakeTensorProp below to avoid the failed + # .view() call. + view_to_reshape(gm) + + with dynamo_timed( + "additional_fake_tensor_prop", log_pt2_compile_event=True + ): + # It is safe to run FakeTensorProp under no_grad because by the time + # we're in inductor, we assume that AOTAutograd has already "taken care" + # of autograd, so there should be no more autograd-related API's in the + # graph. + with torch.no_grad(): + fake_mode = fake_tensor_prop(gm, example_inputs) + + _recursive_record_original_output_strides(gm) + + # pattern matcher passes might not preserve striding information + # on node.meta["val"]. if in the future we rely on these being + # correct we will need to fix. + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_post_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + with V.set_fake_mode(fake_mode): + # has some issues with memory in training + cuda_context = get_cuda_device_context(gm) + with cuda_context: + _recursive_post_grad_passes(gm, is_inference=is_inference) + V.debug.fx_graph_transformed(gm, example_inputs) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", + gm, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + # We're printing the graph to be used as a cache key - so a + # printer which is a little less readable but faster is + # appropriate. + inductor_post_grad_graph_str = gm.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + ) + # "after_post_grad_graph" is used in inductor provenance + # tracking highlighter front-end. 
+ trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_post_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: inductor_post_grad_graph_str, + ) + if config.trace.enabled: + provenance_tracking_json = ( + torch.fx.traceback.get_graph_provenance_json(gm.graph) + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_post_to_pre_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(provenance_tracking_json), + ) + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( + provenance_tracking_json + ) + + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + # TODO: Remove this when 3.9 is no longer supported + if sys.version_info < (3, 10): + num_graph_breaks = sum(counters["graph_break"].values()) + else: + num_graph_breaks = counters["graph_break"].total() + CompileEventLogger.compilation_metric( + overwrite=True, num_graph_breaks=num_graph_breaks + ) + if config.is_fbcode(): + try: + log_optimus_to_scuba( + extra_logging={ + "pt2_configs": str(get_patched_config_dict()) + } + ) + except Exception: + # TODO(T216453900): need to work around for now to support vllm + # See details in vllm/compilation/pass_manager.py. + log.warning("failed to log pt2_configs") + + with ( + V.set_fake_mode(fake_mode), + maybe_disable_comprehensive_padding(example_inputs), + maybe_disable_graph_partition(cpp_wrapper, aot_mode), + ): + const_output_index = None + const_graph = None + const_wrapper_code = None + const_kernel_code = None + + if aot_mode and config.aot_inductor.use_runtime_constant_folding: + # torchbind objects have name that starts with _torchbind_obj + # See caffe2/torch/fx/_symbolic_trace.py?lines=406 + const_gm, const_output_index = split_const_gm( + gm, + skip_folding_node_fn=lambda node: node.op == "get_attr" + and isinstance(node.target, str) + and ( + node.target.startswith("_torchbind_obj") + or isinstance(node.meta.get("val", None), FakeScriptObject) + ), + ) + + const_graph = GraphLowering( + const_gm, + example_inputs=[], + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_backward=is_backward, + is_const_graph=True, + ) + with V.set_graph_handler(const_graph): + assert cpp_wrapper, "AOT mode only supports C++ wrapper" + const_graph.run() + const_wrapper_code, const_kernel_code = ( + const_graph.codegen_with_cpp_wrapper() + ) + + graph = GraphLowering( + gm, + # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. + # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, + # we currently use fake tensors and defake them later. 
+ example_inputs=example_inputs, + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_backward=is_backward, + const_output_index=const_output_index, + const_wrapper_code=( + const_wrapper_code.value if const_wrapper_code else None + ), + const_kernel_code=( + const_kernel_code.value if const_kernel_code else None + ), + const_module=const_graph, + inputs_to_check=inputs_to_check, + ) + metrics_helper = metrics.CachedMetricsHelper() + + # We are going to start code generating runtime asserts, so make sure + # you don't start adding new ones in the lowering process + graph.freeze_runtime_asserts() + with V.set_graph_handler(graph): + graph.run(*example_inputs) + output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext + p = SymExprPrinter() + for out in graph.graph_outputs: + if ( + isinstance(out, IRNode) + and out.has_tensor_output() + and len(free_unbacked_symbols(out.get_stride())) == 0 + ): + # Convert to string for eval on the load path + output_strides.append( + tuple(p.doprint(s) for s in out.get_layout().stride) + ) + else: + output_strides.append(None) + + _check_triton_bf16_support(graph) + + # TODO: The switching between AOT mode and not here is a bit + # messy, but it's localized to the block of code below so I'm + # not going to touch it for now + + compiled_fn: Any + compiled_fn_runner = None + with dynamo_timed( + "GraphLowering.compile_to_fn", log_pt2_compile_event=True + ): + if graph.aot_mode: + from .codecache import AotCodeCompiler + + assert graph.cpp_wrapper, ( + "AOT mode only supports C++ wrapper" + ) + wrapper_code, kernel_code = graph.codegen_with_cpp_wrapper() + output_code_log.debug( + "Output wrapper code: \n%s", wrapper_code.value + ) + if kernel_code.value: + output_code_log.debug( + "Output kernel code:\n%s", kernel_code.value + ) + + serialized_extern_kernel_nodes = None + if graph.extern_kernel_nodes: + serialized_extern_kernel_nodes = ( + graph.extern_node_serializer( + graph.extern_kernel_nodes + ) + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + with dynamo_timed( + "AotCodeCompiler.compile", log_pt2_compile_event=True + ): + # Directly return the file path with the compiled code + compiled_fn = AotCodeCompiler.compile( + graph, + wrapper_code.value, + kernel_code.value, + serialized_extern_kernel_nodes, + device_type=graph.device_type, + additional_files=[ + *dict.fromkeys( + graph.wrapper_code.additional_files + + ( + const_graph.wrapper_code.additional_files + if const_graph + else [] + ) + ) + ], + ) + else: + compiled_module = graph.compile_to_module() + compiled_fn = compiled_module.call + compiled_fn_runner = getattr( + compiled_module, "runner", None + ) + + if inductor_metrics_log.isEnabledFor(logging.INFO): + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem + inductor_metrics_log.info( + "Graph Metrics:\n%s", + { + "num_bytes_accessed": num_bytes, + "nodes_num_elem": nodes_num_elem, + "node_runtimes": node_runtimes, + }, + ) + + if ( + cudagraphs + and config.triton.cudagraph_skip_dynamic_graphs + and not V.graph.disable_cudagraphs_reason + and 
torch._inductor.utils.any_is_symbolic(*example_inputs) + ): + stack_trace = None + for node in gm.graph.nodes: + meta_val = node.meta.get("val", None) + if ( + node.op == "placeholder" + or not isinstance(meta_val, torch.Tensor) + or not torch._inductor.utils.any_is_symbolic(meta_val) + ): + continue + + if stack_trace := node.meta.get("stack_trace", None): + break + disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True." + if stack_trace: + disable = f"{disable} Found from {stack_trace}\n" + else: + disable = f"{disable}\n" + V.graph.disable_cudagraphs_reason = disable + + if cudagraphs and not V.graph.disable_cudagraphs_reason: + maybe_incompat_node = get_first_incompatible_cudagraph_node(gm) + if maybe_incompat_node: + disable = f"disabling cudagraphs due to incompatible op {maybe_incompat_node.target}" + if stack_trace := maybe_incompat_node.meta.get( + "stack_trace", None + ): + disable = f"{disable} Found from {stack_trace}\n" + V.graph.disable_cudagraphs_reason = disable + + if V.aot_compilation: + assert isinstance(compiled_fn, (str, list)) + return CompiledAOTI(compiled_fn) + + # TODO: Hoist this above V.aot_compilation + if cudagraphs and not V.graph.disable_cudagraphs_reason: + from torch._inductor.cudagraph_utils import ( + check_lowering_disable_cudagraph, + ) + + V.graph.disable_cudagraphs_reason = ( + check_lowering_disable_cudagraph( + V.graph.device_node_mapping + ) + ) + + self._compile_stats[type(self)].codegen_and_compile += 1 + + return CompiledFxGraph( + compiled_fn, + graph, + gm, + output_strides, + V.graph.disable_cudagraphs_reason, + metrics_helper.get_deltas(), + counters["inductor"] - inductor_counters, + cudagraphs, + example_inputs, + static_input_idxs, + graph_kwargs, + inputs_to_check, + runnable_graph_str, + inductor_post_grad_graph_str, + compiled_fn_runner, + ) + + +def fx_codegen_and_compile( + gm: GraphModule, + example_inputs: Sequence[InputType], + # This is derivable from the other inputs to this function, but we pass it + # in explicitly because it's nontrivial to compute + inputs_to_check: Sequence[int], + **graph_kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + scheme: FxCompile + + if fx_compile_mode == FxCompileMode.NORMAL: + scheme = _InProcessFxCompile() + elif fx_compile_mode == FxCompileMode.SERIALIZE: + from .compile_fx_ext import _DebugSerdeFxCompile + + scheme = _DebugSerdeFxCompile() + elif fx_compile_mode == FxCompileMode.SUBPROCESS: + from .compile_fx_subproc import _SubprocessFxCompile + + scheme = _SubprocessFxCompile() + + if fx_compile_async: + from .compile_fx_async import _AsyncFxCompile + from .compile_fx_ext import _OutOfProcessFxCompile + + assert isinstance(scheme, _OutOfProcessFxCompile), ( + "async is only valid with an out-of-process compile mode" + ) + scheme = _AsyncFxCompile(scheme) + + return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) + + +def get_input_idxs_to_check( + inputs: Sequence[InputType], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + """ + This function runs at compile time, and generates a list of indices for which we + might need to do a copy to preserve alignment requirements. 
+ """ + ids_to_check = [] + + for i, input in enumerate(inputs): + if not isinstance(input, torch.Tensor): + # non-tensors don't need alignment + continue + if not is_gpu(input.device.type): + # right now we only care for gpu tensors + continue + with maybe_get_suppress_shape_guards_ctx(): + # suppress guards so that tensor_is_aligned and should_assume_input_aligned + # do not add guards on input's storage offset + if i in static_input_idxs and tensor_is_aligned(input): + continue + if not should_assume_input_aligned(input): + continue + + # if we get here, then + # (a) our triton code assumes that the input is aligned + # (b) we can't be sure ahead of time that the input will actually be aligned. + # therefore, at runtime, we'll need to check that the input is aligned + # (and if not, clone it to make it aligned.) + ids_to_check.append(i) + + return ids_to_check + + +def cudagraphify( + model: Callable[..., Any], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + stack_traces: list[Optional[str]], + is_backward: bool, + is_inference: bool, + constants: tuple[torch.Tensor, ...] = (), + placeholders: Sequence[PlaceholderInfo] = (), + mutated_input_idxs: tuple[int, ...] = (), +) -> Callable[..., Any]: + from torch._inductor.cudagraph_trees import ( + cudagraphify_impl as new_cudagraphify_impl, + ) + + cudagraphify_fn: Callable[..., Any] + if config.triton.cudagraph_trees: + cudagraphify_fn = functools.partial( + new_cudagraphify_impl, + device_index=device_index, + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=constants, + placeholders=placeholders, + mutated_input_idxs=mutated_input_idxs, + compile_id=torch._guards.CompileContext.current_compile_id(), + ) + else: + cudagraphify_fn = cudagraphify_impl + + compiled_fn = None + + def run(new_inputs: Sequence[InputType]) -> Any: + nonlocal compiled_fn + if compiled_fn is None: + with dynamo_utils.preserve_rng_state(): + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) # type: ignore[arg-type] + return compiled_fn(new_inputs) # type: ignore[arg-type] + + return run + + +def static_input(x: torch.Tensor) -> torch.Tensor: + """ + Copy and input while preserving strides + """ + return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) + + +def index_expanded_dims_and_copy_( + dst: torch.Tensor, + src: torch.Tensor, + expanded_dims: list[int], +) -> None: + "Index into expanded dimensions of both dst and src then copy_" + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + dst.copy_(src) + + +def cudagraphify_impl( + model: Callable[..., Any], + inputs: list[torch.Tensor], + static_input_idxs: Sequence[int] = (), +) -> Callable[[list[InputType]], Any]: + """ + Assumes inputs[static_input_idxs[i]] are always the same memory address + """ + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs: OrderedSet[int] = OrderedSet( + remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + ) + copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] + + assert isinstance(inputs, list) + + inps_expanded_dims = [ + get_expanded_dims(x) if idx not in static_input_idxs else [] + for idx, x in enumerate(inputs) + ] + + # allocate static tensor inputs + static_inputs = [ + ( + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + ) + for idx, x in 
enumerate(inputs) + ] + + # copy over input values for fresh allocations + for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): + if isinstance(x, torch.Tensor) and idx not in static_input_idxs: + index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model + with torch.cuda.stream(stream): + model(list(static_inputs)) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): + static_outputs = model(list(static_inputs)) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + if config.size_asserts: + + def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]: + assert len(static_inputs) == len(new_inputs) + for idx, (dst, src, expanded_dims) in enumerate( + zip(static_inputs, new_inputs, inps_expanded_dims) + ): + if not isinstance(dst, torch.Tensor): + continue + assert isinstance(src, torch.Tensor) + if idx in static_input_idxs: + assert dst.data_ptr() == src.data_ptr() + else: + # TODO - could make one single op of multiple slices + # and avoid dispatch. + # Could also pre-index the `dst` tensors + index_expanded_dims_and_copy_(dst, src, expanded_dims) + new_inputs.clear() + graph.replay() + return static_outputs + + else: + copy_indices = [ + idx for idx in range(len(static_inputs)) if idx not in static_input_idxs + ] + + def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]: + for idx in copy_indices: + expanded_dims = inps_expanded_dims[idx] + src = new_inputs[idx] + assert isinstance(src, torch.Tensor) + index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims) + new_inputs.clear() + graph.replay() + return static_outputs + + return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet()) + + +def compile_fx_aot( + model_: GraphModule, + example_inputs_: list[InputType], + inner_compile: _CompileFxCallable = compile_fx_inner, + config_patches: Optional[dict[str, str]] = None, +) -> Union[list[Union[str, Weights]], str]: + assert isinstance(model_, GraphModule), model_ + + # [See NOTE] Unwrapping subclasses AOT + unwrap_tensor_subclass_parameters(model_) + + config_patches: dict[str, Any] = ( + {"cpp_wrapper": True} + if config_patches is None + else {**config_patches, "cpp_wrapper": True} + ) + + output_path = config_patches.get( + "aot_inductor.output_path", config.aot_inductor.output_path + ) + + if output_path: + assert not output_path.endswith(".pt2"), ( + "The output path for aot_compile should not have an extension with .pt2 " + "this is for specifying the output path for the .so in AOTInductor. " + "If you would like to package the AOTInductor generated files " + "into a pt2, please call `torch._inductor.aoti_compile_and_package`." 
+ ) + else: + config_patches = { + **config_patches, + "aot_inductor.output_path": code_hash(model_.code), + } + + extern_node_serializer = config_patches.pop("extern_node_serializer", None) + saved_compile_id = model_.meta.get("dynamo_compile_id", None) + saved_compile_context = torch._guards.CompileContext(saved_compile_id) + with ( + V.set_aot_compilation(True), + torch._guards.compile_context(saved_compile_context), + chromium_event_timed( + "compile_fx_aot", + log_pt2_compile_event=True, + reset_event_log_on_exit=True, + ), + get_metrics_context(), + ): + compiled_artifacts = compile_fx( + model_, + example_inputs_, + inner_compile=functools.partial( + inner_compile, + extern_node_serializer=extern_node_serializer, + ), + config_patches=config_patches, + ) + + assert isinstance(compiled_artifacts, CompiledAOTI) + + return compiled_artifacts.filename + + +_graph_counter = count(0) + + +def fw_compiler_freezing( + aot_autograd_model: GraphModule, + aot_example_inputs: Sequence[InputType], + dynamo_model: GraphModule, + num_example_inputs: int, + inner_compile: Callable[..., Any], + cudagraphs: BoxedBool, + graph_id: int, + forward_device: BoxedDeviceIndex, +) -> Callable[[list[object]], Sequence[torch.Tensor]]: + from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze + + # partition_fn won't be called + _recursive_joint_graph_passes(aot_autograd_model) + + layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) + if layout_opt: + # make sure meta['val'] is properly setup + fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) + convert_conv_weights_to_channels_last(aot_autograd_model) + + opt_model, preserved_arg_indices = freeze( + dynamo_model, + aot_autograd_model, + aot_example_inputs, # type: ignore[arg-type] + ) + + aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] + + fake_mode = detect_fake_mode(aot_example_inputs) + + # for freezing, all graph outputs should be user visible + *_, model_outputs_node = opt_model.graph.nodes + model_outputs = model_outputs_node.args[0] + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node) + ] + + static_input_idxs = [] + # constant params will be real tensors, not fake + tracing_context = torch._guards.TracingContext.try_get() + unwrapped_args_offsets = [0] + max_offset_idx = 0 + if tracing_context is not None: + assert tracing_context.params_flat_unwrap_subclasses is not None + params_flat_unwrap = tracing_context.params_flat_unwrap_subclasses + max_offset_idx = max(0, len(params_flat_unwrap) - 1) + preserved_indices_params_flat = OrderedSet[int]() + unwrapped_idxs = tracing_context.params_unwrapped_to_flat_index + assert unwrapped_idxs is not None + current_offset = 0 + if len(params_flat_unwrap) > 0: + unwrapped_args_offsets = [] + + for i in range(len(params_flat_unwrap)): + if i not in preserved_arg_indices: + params_flat_unwrap[i] = None + if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]: + current_offset += 1 + else: + preserved_indices_params_flat.add(unwrapped_idxs[i]) + unwrapped_args_offsets.append(current_offset) + + # Deallocate wrapped params, if all subelements were deallocated + assert tracing_context.params_flat is not None + for i in range(len(tracing_context.params_flat)): + if i not in preserved_indices_params_flat: + tracing_context.params_flat[i] = None + + if tracing_context.fw_metadata: + static_input_idxs = 
tracing_context.fw_metadata.static_input_indices + + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): + optimized_function = inner_compile( + opt_model, + aot_example_inputs, + static_input_idxs=static_input_idxs, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=True, + boxed_forward_device_index=forward_device, + layout_opt=layout_opt, + ) + + # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper + # that drops constant-ified params + if V.aot_compilation: + return optimized_function + + def wrapper(args: list[object]) -> Sequence[torch.Tensor]: + args_new = [ + args[i - unwrapped_args_offsets[min(i, max_offset_idx)]] + for i in preserved_arg_indices + ] + args.clear() + return optimized_function(args_new) + + wrapper._boxed_call = True # type: ignore[attr-defined] + + return wrapper + + +def get_cpp_wrapper_config() -> dict[str, object]: + if config.triton.cudagraphs: + log_cudagraph_skip_and_bump_counter( + format_default_skip_message("cpp wrapper enabled") + ) + + return { + # Set autotune_at_compile_time to True as default if the option is not explicitly set + "triton.autotune_at_compile_time": ( + config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + else has_triton() + ), + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, # TODO: to be removed + "triton.store_cubin": True, + } + + +def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]: + """ + Returns a cuda device context manager if there is a single device in the graph + """ + if not torch.cuda.is_available(): + return contextlib.nullcontext() + + cuda_devices: OrderedSet[torch.device] = OrderedSet( + device for device in get_all_devices(gm) if device.type == "cuda" + ) + + return ( + torch.cuda.device(next(iter(cuda_devices))) # type: ignore[return-value] + if len(cuda_devices) == 1 + else contextlib.nullcontext() + ) + + +def compile_fx( + model_: GraphModule, + example_inputs_: Sequence[InputType], + inner_compile: Callable[..., OutputCode] = compile_fx_inner, + config_patches: Optional[dict[str, Any]] = None, + decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None, + ignore_shape_env: bool = False, +) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str], Weights]: + """ + Main entry point for compiling given FX graph. Despite the fact that this + lives in :mod:`torch._inductor`, this function is responsible for calling + into AOT Autograd (and we will eventually get a callback to + ``inner_compile`` to perform actual compilation. In other words, this + function orchestrates end-to-end compilation for the inductor backend when + you use :func:`torch.compile`. + + NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially + mutate it! Make a copy if you need to preserve the original GraphModule. + """ + + # Some arguments trigger a recursive call to compile_fx. 
Handle these + # short circuits first, before anything else + + if config_patches: + with config.patch(config_patches): + return compile_fx( + model_, + example_inputs_, + # need extra layer of patching as backwards is compiled out of scope + inner_compile=config.patch(config_patches)(inner_compile), + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + + # TODO: This probably shouldn't be a recursive call + if config.cpp_wrapper: + with ( + config.patch( + { + "cpp_wrapper": False, # reset to break recursive call to compile_fx + **get_cpp_wrapper_config(), + } + ), + V.set_real_inputs(example_inputs_), + ): + inputs_: Sequence[InputType] = example_inputs_ + + if isinstance(model_, GraphModule): + fake_inputs = [ + node.meta.get("val") + for node in model_.graph.nodes + if node.op == "placeholder" + ] + # Replace non-tensor (constant) inputs with Nones, since these are not being + # used anyways by the graph + fake_inputs = [ + inp if isinstance(inp, torch.Tensor) else None + for inp in fake_inputs + ] + + if any(v is not None for v in fake_inputs): + # Validate devices before switching to fake tensors. + for idx, fi, i in zip(count(), fake_inputs, inputs_): + if fi is not None: + assert isinstance(i, torch.Tensor) + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." + ) + inputs_ = fake_inputs # type: ignore[assignment] + from torch._export.non_strict_utils import _fakify_script_objects + + fake_mode = detect_fake_mode(inputs_) + with _fakify_script_objects(model_, inputs_, {}, fake_mode) as ( + patched_mod, + fake_args, + _, + _, + _, + ): + return compile_fx( + patched_mod, + fake_args, + inner_compile=functools.partial(inner_compile, cpp_wrapper=True), + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + + recursive_compile_fx = functools.partial( + compile_fx, + inner_compile=inner_compile, + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + + if not graph_returns_tuple(model_): + return make_graph_return_tuple( + model_, + example_inputs_, + recursive_compile_fx, + ) + + if isinstance(model_, GraphModule) and isinstance( + model_.graph._codegen, _PyTreeCodeGen + ): + # this graph is the result of dynamo.export() + return handle_dynamo_export_graph( + model_, + example_inputs_, + recursive_compile_fx, + ) + + # Do the actual work + + with ( + _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), + enable_python_dispatcher(), + torch.fx.traceback.preserve_node_meta(config.trace.enabled), + ): + # Pre-grad passes cannot be run if we weren't given a GraphModule. + # Dynamo will always produce a GraphModule, but this handles cases + # where a user directly passes a plain Module with the intention of + # having AOTAutograd trace it. + # TODO: Get rid of this? + if isinstance(model_, GraphModule): + # "before_pre_grad_graph" is used in inductor provenance + # tracking highlighter front-end. 
+ trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) + pre_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "BEFORE PRE GRAD", + model_, + include_stride=True, + include_device=True, + colored=True, + ), + ) + torch._inductor.debug._pre_grad_graph_id = id(model_.graph) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) + + # TODO: Move this before recursive pre-grad passes + # NB: This short circuit never occurs for Dynamo produced graphs + # (which are pre-flattened) + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_): + return flatten_graph_inputs( + model_, + example_inputs_, + recursive_compile_fx, + ) + + assert not config._raise_error_for_testing + + num_example_inputs = len(example_inputs_) + + # Although cudagraphs may have been enabled via config, various + # conditions (which are tested within the bowels of Inductor) may + # force cudagraphs to be disabled. This mutable box lets us retrieve + # the final determination if cudagraphs actually can be used or not. + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # See [Backward Generation Handling] + forward_device = BoxedDeviceIndex(None) + + # TODO: The modern style is to use CompileId from TracingContext to + # identify Inductor compilation. However, this CompileId cannot + # uniquely identify multiple Inductor compilations that arise from + # DDPOptimizer + graph_id = next(_graph_counter) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + + def fw_compiler_base( + gm: GraphModule, + example_inputs: Sequence[InputType], + is_inference: bool, + ) -> OutputCode: + with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): + if is_inference: + # partition_fn won't be called + _recursive_joint_graph_passes(gm) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + + model_outputs_node = output_node(gm) + if config.keep_output_stride: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + if isinstance(model_, GraphModule): + *_, orig_model_outputs_node = model_.graph.nodes + assert orig_model_outputs_node.op == "output" + orig_model_outputs, _ = pytree.tree_flatten( + orig_model_outputs_node.args + ) + num_orig_model_outputs = len(orig_model_outputs) + else: + num_orig_model_outputs = num_model_outputs + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the 
model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. + orig_output_end_idx = ( + original_output_start_index + num_orig_model_outputs + ) + # Sanity check: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx in range( + original_output_start_index, orig_output_end_idx + ) + if isinstance(model_outputs[idx], torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + # We also mark the invoke_subgraph outputs as user_visible to + # force the outputs of invoke_subgraph subgraph to follow the + # original strides + _recursive_record_user_visible_output_idxs(gm) + + return inner_compile( + gm, + example_inputs, + static_input_idxs=get_static_input_idxs(fixed), + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=is_inference, + boxed_forward_device_index=forward_device, + ) + + fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = ( + functools.partial(fw_compiler_base, is_inference=False) + ) + fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler) + + if config.freezing and not torch.is_grad_enabled(): + inference_compiler: Callable[..., Any] = functools.partial( + fw_compiler_freezing, + dynamo_model=model_, + num_example_inputs=num_example_inputs, + inner_compile=inner_compile, + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, + ) + else: + inference_compiler = functools.partial(fw_compiler_base, is_inference=True) + inference_compiler = SerializableAOTDispatchCompiler( + OutputCode, inference_compiler + ) + + def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, + ) -> tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with cuda_context: + # We can skip the invoke_subgraph because the + # entire_partition_fn is called recursively for invoke_subgraph + # in partitioning. 
+ _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) + + static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + + with dynamo_utils.dynamo_timed( + "min_cut_rematerialization_partition", log_pt2_compile_event=True + ): + return min_cut_rematerialization_partition( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + + @compile_time_strobelight_meta(phase_name="backward") + def bw_compiler( + gm: GraphModule, example_inputs: Sequence[InputType] + ) -> OutputCode: + from torch._dynamo.convert_frame import compile_lock + + with ( + dynamo_utils.dynamo_timed("compile_fx..bw_compiler"), + compile_lock, + ): + model_outputs_node = output_node(gm) + if config.bw_outputs_user_visible: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx, n in enumerate(model_outputs) + if isinstance(n, torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + fixed = count_tangents(gm) + with ( + config.patch(get_cpp_wrapper_config()) + if config.cpp_wrapper + else contextlib.nullcontext() + ): + return inner_compile( + gm, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + ) + + bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler) + + fake_mode = detect_fake_mode( + example_inputs_ + ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + if V.aot_compilation: + with functorch_config.patch(unlift_effect_tokens=True): + gm, graph_signature = aot_export_module( + model_, + example_inputs_, + trace_joint=False, + decompositions=decompositions, + ) + + from torch._export.utils import _detect_fake_mode_from_gm + + fake_mode = _detect_fake_mode_from_gm(gm) + # aot_export_module doesn't account for constant tensor attributes + # so we end up having tensors that don't have fake vals attached. + # This can happen when upstream export is non-strict where we + # preserve the original module params/buffers. Once AOTI switches + # to ep.run_decompositions() flow to lower to post-autograd opset + # this will go away. + for node in gm.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + target = attrgetter(node.target)(gm) + if isinstance(target, torch.Tensor): + node.meta["val"] = fake_mode.from_tensor( + target, static_shapes=True + ) + elif isinstance(target, torch.ScriptObject): + node.meta["val"] = ( + torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, target + ) + ) + elif isinstance(target, FakeScriptObject): + node.meta["val"] = target + + unlifted_gm = _unlift_graph(model_, gm, graph_signature) + if "dynamo_flat_name_to_original_fqn" in model_.meta: + unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[ + "dynamo_flat_name_to_original_fqn" + ] + + if "dynamo_compile_id" in model_.meta: + unlifted_gm.meta["dynamo_compile_id"] = model_.meta["dynamo_compile_id"] + + # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515) + # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into + # _sfdp_init() to register patterns. 
+ # When fallback_random is set to True, the sdpa patterns will be traced during runtime. + # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which + # will be the same as the generated FP16 patterns. + disable_amp = torch._C._is_any_autocast_enabled() + context = ( + torch._C._DisableAutocast if disable_amp else contextlib.nullcontext + ) + with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context(): + return inference_compiler(unlifted_gm, example_inputs_) + + with ( + V.set_fake_mode(fake_mode), + torch._guards.tracing(tracing_context), + compiled_autograd._disable(), + functorch_config.patch(unlift_effect_tokens=True), + ): + try: + return aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + cudagraphs=cudagraphs, + boxed_forward_device_index=forward_device, + ignore_shape_env=ignore_shape_env, + )(model_, example_inputs_) + except ShortenTraceback as e: + # We will also shorten the traceback inside dynamo. + # This is only useful if inductor is called directly with an FX graph. + raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 + + +def graph_returns_tuple(gm: GraphModule) -> bool: + """True if a FX graph returns a tuple""" + if not isinstance(gm, GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: GraphModule, + inputs: Sequence[InputType], + compile_gm: Callable[..., Any], +) -> Callable[..., Any]: + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. + """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def handle_dynamo_export_graph( + gm: GraphModule, + inputs: Sequence[InputType], + compile_gm: Callable[..., Any], +) -> Callable[..., Any]: + """ + `torch._dynamo.export` embeds pytrees in the FX graph codegen object, + convert that to a normal FX graph so inductor can compile it. 
+ """ + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) + + @functools.wraps(compiled_fn) # type: ignore[misc] + def wrapper(*args: Any) -> Any: + return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) + + return wrapper + + +def _check_triton_bf16_support(graph: GraphLowering) -> None: + def warn_and_skip(device: Optional[torch.device]) -> Never: + from torch._dynamo.exc import SkipFrame + + assert device is not None + + device_interface = get_interface_for_device(device.type) + device_props = device_interface.get_device_properties(device) + warnings.warn( + f"{device_props.name} does not support bfloat16 compilation natively, skipping" + ) + raise SkipFrame("BF16 is not supported") + + for node in itertools.chain(graph.graph_inputs.values(), graph.graph_outputs): + if not isinstance(node, IRNode): + continue + device_type = get_device_type(node) + if ( + not device_type + or not is_gpu(device_type) + or node.get_dtype() != torch.bfloat16 + ): + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device_type) + if device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(node.get_device()) + + +def _aoti_flatten_inputs( + gm: torch.fx.GraphModule, + args: Union[list[Any], tuple[Any, ...]], + kwargs: Optional[dict[str, Any]] = None, + *, + options: Optional[dict[str, Any]] = None, +) -> tuple[list[Any], dict[str, Any]]: + """ + Flatten the inputs to the graph module and return the flat inputs and options. + Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options. + """ + from .compile_fx import graph_returns_tuple + + assert graph_returns_tuple(gm), ( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs." + ) + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs or {}) + ) + + if any(isinstance(x[1], torch.ScriptObject) for x in flat_args_with_path): + from torch._dynamo.exc import UserError, UserErrorType + + raise UserError( + UserErrorType.INVALID_INPUT, + "TorchBind objects found in inputs. TorchBind object inputs are not supported in AOTInductor. 
" + "TorchBind objects can only be attributes.", + ) + + # Replace non-tensor (constant) inputs with Nones, since these are not being + # used anyways by the graph + flat_example_inputs = [ + x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path + ] + + if in_spec is not None and received_spec != in_spec: + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + return flat_example_inputs, options diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_fx_async.py b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_async.py new file mode 100644 index 0000000000000000000000000000000000000000..10640b26a84396fa532ca534290616ed49155ca3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_async.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional, TYPE_CHECKING +from typing_extensions import final, override + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._inductor.output_code import CompiledFxGraphConstants, OutputCode + +from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile +from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 + + +if TYPE_CHECKING: + from collections.abc import Sequence + from concurrent.futures import Future + + from torch._inductor.utils import InputType + from torch.fx import GraphModule + + from .compile_fx_ext import _OutOfProcessFxCompile, _WireProtocolPickledOutput + + +@dataclass +class _PostCompileData: + example_inputs: Sequence[InputType] + constants: CompiledFxGraphConstants + graph_kwargs: _CompileFxKwargs + + +# _AsyncOutputCode handles the actual management of waiting for an +# out-of-process compile to finish and then switching over to it. +@final +class _AsyncOutputCode(OutputCode): + _eager_forward: Optional[Callable[..., Any]] + _output_code: Optional[OutputCode] + _future: Optional[Future[_WireProtocolPickledOutput]] + _callback: Callable[[_WireProtocolPickledOutput], OutputCode] + _post_compile_data: Optional[_PostCompileData] = None + _boxed_call: bool # Copied from the forward/output_code + + def __init__( + self, + # eager_forward is run until the future is finished. + eager_forward: Callable[..., Any], + # this responds with the result of the out-of-process compile when it's + # ready. 
+ future: Future[_WireProtocolPickledOutput], + # this callback gets called to turn the _WireProtocolPickledOutput into an OutputCode + callback: Callable[[_WireProtocolPickledOutput], OutputCode], + ) -> None: + self._eager_forward = eager_forward + self._boxed_call = getattr(eager_forward, "_boxed_call", False) + self._output_code = None + + self._future = future + self._callback = callback + + @override + def __call__(self, *args: Any) -> Any: + if self._future is not None and self._future.done(): + args = self._switch_to_compiled_forward(args) + + if eager_forward := self._eager_forward: + _AsyncFxCompile._stat_eager_runs += 1 + return eager_forward(*args) + + else: + _AsyncFxCompile._stat_compiled_runs += 1 + assert self._output_code is not None + return self._output_code.__call__(*args) + + # Takes and returns the args (converted to the "right" boxed mode) + def _switch_to_compiled_forward(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + assert self._future is not None + + # TODO: If the future ended in an exception do we want to continue + # running eager or hit the exception now? + f, self._future = self._future, None + output_code = self._callback(f.result()) + + if pcd := self._post_compile_data: + self._post_compile_data = None + + output_code.post_compile( + pcd.example_inputs, pcd.constants, pcd.graph_kwargs + ) + + self._output_code = output_code + self._eager_forward = None + boxed_call = getattr(output_code, "_boxed_call", False) + + if self._boxed_call != boxed_call: + if self._boxed_call: + # Was boxed, now unboxed + args = args[0] if len(args) > 0 else () + else: + # Was unboxed, now boxed + args = (args,) + + self._boxed_call = boxed_call + return args + + @override + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + if self._eager_forward is not None: + self._post_compile_data = _PostCompileData( + example_inputs, constants, graph_kwargs + ) + else: + assert self._output_code is not None + self._output_code.post_compile(example_inputs, constants, graph_kwargs) + + +# Given an FxCompile for an out-of-process compile _AsyncFxCompile will run +# eager until the compiled artifact is ready then it will automatically switch +# over to using the compiled version. +@final +class _AsyncFxCompile(FxCompile): + _compile: _OutOfProcessFxCompile + + # Some debugging stats: + # Number of times we started a background compile. + _stat_bg_started: int = 0 + # Number of times we finished a background compile. + _stat_bg_finished: int = 0 + # Number of times we ran "eager" + _stat_eager_runs: int = 0 + # Number of times we ran our compiled (out-of-process) artifact + _stat_compiled_runs: int = 0 + + def __init__(self, compile: _OutOfProcessFxCompile) -> None: + self._compile = compile + + @classmethod + def _reset_stats(cls) -> None: + cls._stat_bg_started = 0 + cls._stat_bg_finished = 0 + cls._stat_eager_runs = 0 + cls._stat_compiled_runs = 0 + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + eager_output_code = _InProcessFxCompile().codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + # This is similar to _SerializedFxCompile.codegen_and_compile() but + # handles the async routing. 
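+ # Rough flow: serialize the compile inputs, kick off the compile in the
+ # child via _send_to_child_async(), and immediately hand back an
+ # _AsyncOutputCode that keeps running the eager OutputCode. Once the
+ # future resolves, the callback below deserializes the child's result and
+ # _AsyncOutputCode switches over to the compiled graph.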
+ + serialized = self._compile.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + if not serialized: + # We can't serialize - just return the eager OutputCode + return eager_output_code + + inputs, constants = serialized + + _AsyncFxCompile._stat_bg_started += 1 + f = self._compile._send_to_child_async(inputs) + + # This is called by _switch_to_compiled_forward() when f has a result... + def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: + _AsyncFxCompile._stat_bg_finished += 1 + output = pickled_output.deserialize(constants) + self._compile._postprocess(output) + return output.graph + + return _AsyncOutputCode(eager_output_code, f, callback) diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_fx_ext.py b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..6de252b225777afe59db0fdfcc44452c9cff99c8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_ext.py @@ -0,0 +1,681 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import logging +import os +import queue +import sys +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, Union +from typing_extensions import final, override, Self, TypeGuard + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache +from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + CompiledFxGraphConstantsWithGm, + OutputCode, +) +from torch._subclasses import FakeTensorMode +from torch.utils._ordered_set import OrderedSet + +from . import config +from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile, log +from .debug import DebugContext +from .graph import GraphLowering +from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 +from .virtualized import V + + +if TYPE_CHECKING: + import types + from collections.abc import Generator, Mapping, Sequence + from concurrent.futures import Future + + from torch._inductor.utils import InputType + from torch.fx import GraphModule + + +@dataclass +class _VirtualizedSerializer: + """ + This handles the data for serializing Virtualized. + """ + + # The values here get serialized. We don't grab everything because some of + # the fields can't be serialized. + aot_compilation: Any = None + choices: Any = None + local_buffer_context: Any = None + ops: Any = None + kernel: Any = None + current_node: Any = None + + @classmethod + def serialize(cls) -> _VirtualizedSerializer: + """ + Turn the current state of torch._inductor.virtualized.V into a + serializable structure. + """ + kwargs = {} + for f in dataclasses.fields(cls): + kwargs[f.name] = getattr(V, f.name) + return _VirtualizedSerializer(**kwargs) + + def patch(self) -> _VirtualizedSerializerContextManager: + """ + Returns a context manager which patches the saved values into the + current environment. While patched, any value not listed above will be + poisoned so that reads will raise an error. 
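+
+ A minimal sketch of the intended round trip (illustrative only):
+
+ snapshot = _VirtualizedSerializer.serialize()  # in the parent
+ with snapshot.patch():                         # in the child
+     ...  # reads of V.ops, V.kernel, etc. resolve to the saved values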
+ """ + return _VirtualizedSerializerContextManager(self) + + +class _VirtualizedSerializerContextManager(contextlib.ExitStack): + """ + Helper for _VirtualizedSerializer.patch() + """ + + def __init__(self, virtualized: _VirtualizedSerializer) -> None: + super().__init__() + self.virtualized = virtualized + + @override + def __enter__(self) -> Self: + super().__enter__() + + for set_name in dir(V): + if not set_name.startswith("set_"): + continue + name = set_name[4:] + name = name.removesuffix("_handler") + set_handler = getattr(V, set_name) + if hasattr(self.virtualized, name): + value = getattr(self.virtualized, name) + else: + # poison any values that we don't serialize so that any + # unset accesses are caught. + value = torch._inductor.virtualized._PoisonedVirtual + self.enter_context(set_handler(value)) + + return self + + +def _is_fallback_handler(op: object) -> bool: + try: + return op._is_fallback_handler # type: ignore[attr-defined] + except AttributeError: + return False + + +class _LoweringSerializer: + """ + This handles the data for serializing lowering.lowering + """ + + # A full implementation would make sure that all lowerings are copied over + # (or at least detected and raise a bypass when a non-standard lowering is + # used). For now we just handle tests by looking for lowerings that were + # overridden with a forced fallback. + fallbacks: OrderedSet[str] + + def __init__(self) -> None: + from . import lowering + + self.fallbacks = OrderedSet( + str(k) for k, v in lowering.lowerings.items() if _is_fallback_handler(v) + ) + + def patch(self) -> _LoweringSerializerContextManager: + return _LoweringSerializerContextManager(self) + + +class _LoweringSerializerContextManager(contextlib.ExitStack): + """ + Helper for _LoweringSerializer.patch() + """ + + def __init__(self, lowering: _LoweringSerializer) -> None: + super().__init__() + self.lowering = lowering + + @override + def __enter__(self) -> Self: + super().__enter__() + + from . import lowering + + for k, v in lowering.lowerings.items(): + name = str(k) + if name in self.lowering.fallbacks: + if not _is_fallback_handler(v): + self.enter_context(lowering.force_fallback(k)) # type: ignore[arg-type] + + return self + + +@dataclass +class _FakeTensorModeSerializer: + allow_non_fake_inputs: bool + + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs + self.shape_env = fake_mode.shape_env + + @contextlib.contextmanager + def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]: + saved_allow_non_fake_inputs = fake_mode.allow_non_fake_inputs + fake_mode.allow_non_fake_inputs = self.allow_non_fake_inputs + + yield + + fake_mode.allow_non_fake_inputs = saved_allow_non_fake_inputs + + +@dataclass +class _WireProtocolInput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (sent) from the parent to the child. 
+ """ + + gm: torch.fx.GraphModule + example_inputs: Sequence[InputType] + inputs_to_check: Sequence[int] + graph_kwargs: _CompileFxKwargs + tracing_context: Optional[torch._guards.TracingContext] + config: dict[str, object] + virtualized: _VirtualizedSerializer + deterministic_guard_for_testing: Optional[ # type: ignore[name-defined] # mypy bug + torch.testing._internal.common_utils.DeterministicGuard + ] + logger_state: _LoggerState + lowering: _LoweringSerializer + fake_tensor_mode: _FakeTensorModeSerializer + + def serialize(self) -> _WireProtocolPickledInput: + """ + Turns this object into a _WireProtocolPickledInput which can be + directly transferred across a stream. + """ + from torch.fx._graph_pickler import GraphPickler + + return _WireProtocolPickledInput(GraphPickler.dumps(self)) + + +def _current_fake_mode() -> FakeTensorMode: + fake_mode = None + if context := torch._guards.TracingContext.try_get(): + fake_mode = context.fake_mode + if fake_mode is not None: + return fake_mode + + shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv() + return FakeTensorMode(shape_env=shape_env) + + +@dataclass +class _WireProtocolPickledInput: + value: bytes + + def deserialize(self) -> _WireProtocolInput: + """ + Turn this streamable object back into a _WireProtocolInput. + """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolInput) + return result + + +@dataclass +class _WireProtocolOutput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (returned) back from the child to the parent. + """ + + graph: OutputCode + metrics: CachedMetricsDeltas + logs: list[logging.LogRecord] + warning_replay: Optional[list[warnings.WarningMessage]] + shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv] + + def serialize(self) -> _WireProtocolPickledOutput: + """ + Turns this object into a _WireProtocolPickledOutput which can be + directly transferred across a stream. + """ + from torch.fx._graph_pickler import GraphPickler + + if isinstance(self.graph, CompiledFxGraph): + self.graph.prepare_for_serialization() + return _WireProtocolPickledOutput(GraphPickler.dumps(self)) + + +@dataclass +class _WireProtocolPickledOutput: + value: bytes + + def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput: + """ + Turn this streamable object back into a _WireProtocolOutput. + """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolOutput) + if isinstance(result.graph, CompiledFxGraph): + result.graph.after_deserialization(constants) + return result + + +class _LoggerState: + """ + This class is for tracking logging that happens during an out-of-process + compile so we can "replay" those messages when the compile is done. Used as + a context manager which returns the captured logs (object). + """ + + loggers: dict[str, int] + # The actual log capturing mechanism - this should be None when we're not + # actively capturing logs. + captured_logs: Optional[_CapturedLogs] = None + + def __init__(self) -> None: + # Mapping from logger name to level. 
+ self.loggers = {} + + def filter( + logger: Union[logging.Logger, logging.PlaceHolder], + ) -> TypeGuard[logging.Logger]: + if not isinstance(logger, logging.Logger): + # Assume that Placeholders propagate + return False + # We only want to track torch._inductor logging + if not logger.name.startswith("torch._inductor"): + return False + # If this logger propagates then assume we'll track its parent + if logger.propagate: + return False + return True + + root = logging.getLogger("torch._inductor") + if sys.version_info < (3, 12): + # logging.getChildren() doesn't exist until 3.12 + logging._acquireLock() # type: ignore[attr-defined] + try: + for logger in root.manager.loggerDict.values(): + if filter(logger): + self.loggers[logger.name] = logger.level + finally: + logging._releaseLock() # type: ignore[attr-defined] + else: + q = [root] + while q: + logger = q.pop() + if filter(logger): + self.loggers[logger.name] = logger.level + q.extend(logger.getChildren()) + + def __enter__(self) -> _CapturedLogs: + assert self.captured_logs is None + self.captured_logs = _CapturedLogs(self) + self.captured_logs.apply() + return self.captured_logs + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: + assert self.captured_logs is not None + self.captured_logs.remove() + + +class _CapturedLogs: + """ + Helper for _LoggerState - this class actually attaches to the logger in + the child process and grabs the log messages themselves. + """ + + state: _LoggerState + queue: queue.Queue[logging.LogRecord] + handlers: Optional[dict[str, logging.Handler]] + + def __init__(self, state: _LoggerState) -> None: + self.state = state + # A queue of the log entries + # TODO: For memory purposes should we log to a file and then respond with that? + self.queue = queue.Queue(-1) + # Mapping from name to handler (only valid when applied) + self.handlers = None + + def finish(self) -> list[logging.LogRecord]: + assert self.handlers is None + logs = [] + try: + while True: + logs.append(self.queue.get_nowait()) + except queue.Empty: + pass + return logs + + def remove(self) -> None: + assert self.handlers is not None + handlers, self.handlers = self.handlers, None + for name, handler in handlers.items(): + logger = logging.getLogger(name) + logger.removeHandler(handler) + + def apply(self) -> None: + from logging.handlers import QueueHandler + + assert self.handlers is None + self.handlers = {} + for name, level in self.state.loggers.items(): + logger = logging.getLogger(name) + handler = QueueHandler(self.queue) + self.handlers[name] = handler + logger.addHandler(handler) + if level != logging.NOTSET: + logger.setLevel(level) + + +class _SerializedFxCompile(FxCompile): + """ + This is used to represent an FxCompile which occurs across a serialized + boundary. + """ + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + # If this code changes it's likely _AsyncFxCompile.codegen_and_compile() + # will also need to match. 
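+ # Overall flow: pickle (gm, example_inputs, ...) into a
+ # _WireProtocolPickledInput; if that is not possible, fall back to a plain
+ # in-process compile. Otherwise ship the pickled input to the child,
+ # deserialize the returned OutputCode against the saved constants, and let
+ # _postprocess() replay metrics/logs/warnings in the parent.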
+ + serialized = self.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + if not serialized: + return _InProcessFxCompile().codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + inputs, constants = serialized + output = self._send_to_child(inputs).deserialize(constants) + + self._postprocess(output) + self._compile_stats[type(self)].codegen_and_compile += 1 + + # TODO: Do we need to figure out what changed in TracingContext in the + # child and plumb that back up to the parent? + + return output.graph + + def serialize_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> Optional[tuple[_WireProtocolPickledInput, CompiledFxGraphConstantsWithGm]]: + """ + Prepare a _WireProtocolInput to compile. If None is returned then it + wasn't possible to serialize and we should fallback to in-process. + """ + try: + # _check_for_hop raises BypassFxGraphCache when it detects something + # we can't cache (or serialize) + FxGraphCache._check_for_hop(gm) + except BypassFxGraphCache as e: + log.debug("Skipping %s compile: %s", type(self), e) + return None + + context = torch._guards.TracingContext.try_get() + constants = CompiledFxGraphConstantsWithGm(gm) + logger_state = _LoggerState() + lowering = _LoweringSerializer() + + # If we're running tests then grab the DeterministicGuard (don't want to + # import this if it isn't already imported because it has side-effects) + deterministic_guard_for_testing: Optional[ # type: ignore[name-defined] # mypy bug + torch.testing._internal.common_utils.DeterministicGuard + ] = None + try: + deterministic_guard_for_testing = ( + torch.testing._internal.common_utils.DeterministicGuard._current_state() # type: ignore[attr-defined] # mypy bug + ) + except AttributeError: + pass + + fake_mode = _current_fake_mode() + fake_tensor_mode = _FakeTensorModeSerializer(fake_mode) + + try: + input = _WireProtocolInput( + gm, + example_inputs, + inputs_to_check, + graph_kwargs, + context, + config.save_config_portable(), + _VirtualizedSerializer.serialize(), + deterministic_guard_for_testing, + logger_state, + lowering, + fake_tensor_mode, + ).serialize() + return (input, constants) + except (AttributeError, BypassFxGraphCache): + # For example: AttributeError: Can't pickle local object + # 'make_opaque_unary_fn..OpaqueUnaryFn' + + # TODO: scuba record about not being able to do this? + log.warning("Unable to pickle input graph or example inputs", exc_info=True) + + return None + + @abstractmethod + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # The implementation of this should transfer `input` to the child, call + # `_run_in_child(input)` and transfer the result back. + ... + + def _postprocess(self, output: _WireProtocolOutput) -> None: + pass + + @classmethod + def _run_in_child( + cls, + pickled_input: _WireProtocolPickledInput, + extra_env: Optional[Mapping[str, str]] = None, + ) -> _WireProtocolPickledOutput: + metrics = CachedMetricsHelper() + + with contextlib.ExitStack() as stack: + if extra_env is not None: + import unittest + + stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env)) + + # Save warnings to "replay" in the parent + warning_replay = stack.enter_context(warnings.catch_warnings(record=True)) + + # TODO: Should we split the input into multiple sections where each + # section sets up state for the previous section? (i.e. 
a Config section + # which we decode and apply, followed by a FakeTensorMode section which + # we decode and apply, etc) + input = pickled_input.deserialize() + + stack.enter_context(input.virtualized.patch()) + stack.enter_context(input.lowering.patch()) + stack.enter_context(config.patch(input.config)) + captured_logs = stack.enter_context(input.logger_state) + if input.deterministic_guard_for_testing: + stack.enter_context(input.deterministic_guard_for_testing) + stack.enter_context(torch._guards.tracing(input.tracing_context)) + stack.enter_context(DebugContext()) + + fake_mode = _current_fake_mode() + stack.enter_context(input.fake_tensor_mode.patch(fake_mode)) + + output_graph = _InProcessFxCompile().codegen_and_compile( + input.gm, + input.example_inputs, + input.inputs_to_check, + input.graph_kwargs, + ) + + logs = captured_logs.finish() + + return _WireProtocolOutput( + output_graph, + metrics.get_deltas(), + logs, + warning_replay, + fake_mode.shape_env, + ).serialize() + + +# This is a debugging/testing implementation of FxCompile which serializes the +# input and output but still runs the FxCompile in-process. +@final +class _DebugSerdeFxCompile(_SerializedFxCompile): + @override + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # For debugging just serde the input and output but don't run in a + # subprocess. + return self._run_in_child(pickled_input) + + +class _OutOfProcessFxCompile(_SerializedFxCompile): + """ + Represents an FxCompile which is run outside the current process (in + either a subprocess or possibly even a separate machine). + """ + + @override + @final + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + f = self._send_to_child_async(pickled_input) + + # For debugging: If we want to print status updates... + # last = time.time() + # while not f.done(): + # print("tick...") + # time.sleep(0.125) + # now = time.time() + # if now - last > 1: + # last = now + + return f.result() + + @abstractmethod + def _send_to_child_async( + self, pickled_input: _WireProtocolPickledInput + ) -> Future[_WireProtocolPickledOutput]: ... + + def _postprocess(self, output: _WireProtocolOutput) -> None: + # Since our metrics were gathered in a subprocess make sure to add them + # here. + CachedMetricsHelper.apply_deltas(output.metrics) + + # This is used by tests to check the output for specific details. For + # remote things (subproc and RE) we need to do the `save_output_code` + # here since it didn't happen earlier in-process. In the future if this + # doesn't have "source_code" (it's a CompiledAOTI, for example) and we + # need it we'll have to grab it and serialize it separately from the + # child. + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(output.graph.source_code) # type: ignore[attr-defined] + + # And forward our collected logs. The cache is cleared when the outer + # function exits. + @functools.cache + def getLogger(name: str) -> logging.Logger: + return logging.getLogger(name) + + if output.warning_replay: + for w in output.warning_replay: + warnings.warn_explicit( + message=w.message, + category=w.category, + filename=w.filename, + lineno=w.lineno, + source=w.source, + ) + + for record in output.logs: + logger = getLogger(record.name) + logger.handle(record) + + +# For debugging - create a _FxCompile which writes the serialized data to a file +# and then exits. +# +# TODO: make this a FxCompileMode value? 
+# +# The "child runner" should look something like this: +# +# import torch +# from torch._inductor import compile_fx +# idx = 0 +# with open(f"/tmp/pytorch_compile_fx_tmp_input_{idx}.bin", "rb") as f: +# input = compile_fx._WireProtocolPickledInput(f.read()) +# result = compile_fx._SubprocessFxCompile._run_in_child(input) +# with open(f"/tmp/pytorch_compile_fx_tmp_output_{idx}.bin", "wb") as f: +# f.write(result.value) +# +@final +class _DebugFileFxCompile(_SerializedFxCompile): + file_index = 0 + + @override + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + idx = _DebugFileFxCompile.file_index + _DebugFileFxCompile.file_index += 1 + + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_input_{idx}.bin" + with open(name, "wb") as f: + f.write(pickled_input.value) + print(f"Wrote to {name}") + + if False: + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_actual_{idx}.bin" + actual = self._run_in_child(pickled_input) + with open(name, "wb") as f: + f.write(actual.value) + return actual + elif False: + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_output_{idx}.bin" + with open(name, "rb") as f: + result = _WireProtocolPickledOutput(f.read()) + print(f"Read from {name}") + return result + else: + os._exit(-1) diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_fx_subproc.py b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_subproc.py new file mode 100644 index 0000000000000000000000000000000000000000..540ed89b433c86cea31f2e527452ffe71d4b1f2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_fx_subproc.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import atexit +import functools +import os +from typing import Optional, TYPE_CHECKING +from typing_extensions import final, override + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +from torch._inductor.compile_worker.subproc_pool import ( + AnyPool, + SubprocKind, + SubprocPool, +) +from torch._inductor.utils import clear_caches + +from .compile_fx_ext import ( + _OutOfProcessFxCompile, + _WireProtocolPickledInput, + _WireProtocolPickledOutput, +) +from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 + + +if TYPE_CHECKING: + from collections.abc import Mapping + from concurrent.futures import Future + + +@final +class _SubprocessFxCompile(_OutOfProcessFxCompile): + @override + def _send_to_child_async( + self, input: _WireProtocolPickledInput + ) -> Future[_WireProtocolPickledOutput]: + # TODO: Do we need to copy across some kind of logging IDs? (ChromiumEventLogger) + + pool = self.process_pool() + + # TODO: This is probably the wrong thing to do long-term - but for now + # let's share the cache so we can identify tests broken by this later. + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + + return pool.submit( + _SubprocessFxCompile._run_in_child_subprocess, input, extra_env + ) + + @staticmethod + @functools.cache + def process_pool() -> AnyPool: + pool = SubprocPool( + # TODO: Consider raising this limit if we start using async w/ + # subprocess and want to compile multiple graphs in parallel. 
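+ # presumably the number of worker processes; this is the limit the
+ # TODO above refers to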
+ 1, + kind=SubprocKind.SPAWN, + ) + + atexit.register(pool.shutdown) + + return pool + + @classmethod + def _run_in_child_subprocess( + cls, + pickled_input: _WireProtocolPickledInput, + extra_env: Optional[Mapping[str, str]], + ) -> _WireProtocolPickledOutput: + # TODO: In subprocess mode we need to clear the inductor caches. + # The problem: + # 1. We compile in worker A which fills stuff in tmpdir + # 2. parent clears inductor caches which deletes tmpdirs and tells + # cpp_prefix_path() to clear its LRU cache + # 3. We compile a second time in subproc A - but since we never told + # cpp_prefix_path() in worker A to clear its LRU it thinks the + # tmpdir still exists and fails to compile. + # + # TODO: We probably should be using a separate tmpdir in the worker + # anyway... but we should probably still respect clear_caches() + # in the parent... maybe? + # + # TODO: We could be less aggressive by keeping a clock which gets + # incremented when we clear the cache, send the clock to the worker and + # only clear caches if the clock changed since last time. + # + clear_caches() + torch._inductor.metrics.reset() + + # TODO: turn off config.fx_graph_async_compile + + result = cls._run_in_child(pickled_input, extra_env) + return result diff --git a/phivenv/Lib/site-packages/torch/_inductor/compiler_bisector.py b/phivenv/Lib/site-packages/torch/_inductor/compiler_bisector.py new file mode 100644 index 0000000000000000000000000000000000000000..6a89abf80d7d55f66d514888b27544d6860ce924 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compiler_bisector.py @@ -0,0 +1,631 @@ +import atexit +import collections +import dataclasses +import functools +import os +import shutil +import sys +import tempfile +from dataclasses import dataclass, field +from typing import Callable, Optional + +from torch._inductor.runtime.cache_dir_utils import cache_dir + + +# Set the subdirectory name +SUBDIR_NAME = "bisect" + + +@dataclass +class Subsystem: + name: str + + +@dataclass +class BisectSubsystem(Subsystem): + pass + + +@dataclass +class BinarySubsystem(Subsystem): + pass + + +@dataclass +class ConfigChange(BinarySubsystem): + name: str = field(init=False) + config_name: str + config_field: str + config_value: object + + def __post_init__(self) -> None: + self.name = f"{self.config_name}_{self.config_field}" + + +# Dictionary of backend -> subsystems +BACKENDS: dict[str, list[Subsystem]] = { + # run dynamo without aot_autograd + "eager": [], + # run dynamo with aot_autograd, but no partitioner or decomps + "aot_eager": [], + # run dynamo with aot autograd, decompositions and partitioner + "aot_eager_decomp_partition": [ + ConfigChange("aot_eager_decomp_partition", "cse", False), + BisectSubsystem( + "decomposition" + ), # number of decompositions we apply in tracing + ], # TODO - add cse ? + # applies CrossRefFakeMode on invocation + "aot_eager_decomp_partition_crossref": [], + "inductor": [ + BisectSubsystem("joint_graph_passes"), # passes applied on joint graph + BisectSubsystem( + "post_grad_passes" + ), # passes applied individually on forward, and backward in inductor + ConfigChange("inductor", "fallback_random", True), + ConfigChange("inductor", "emulate_precision_casts", True), + ConfigChange("inductor", "layout_optimization", False), + ConfigChange("inductor", "comprehensive_padding", False), + BisectSubsystem("lowerings"), # lowering aten operators to inductor + ], # TODO - add more - fusions ? 
+} + +subsystem_call_counter: dict[str, int] = collections.Counter() +call_counter_debug_info: dict[int, str] = {} + + +def reset_counters() -> None: + subsystem_call_counter.clear() + call_counter_debug_info.clear() + + +@functools.cache +def get_env_val(env_str: str) -> Optional[str]: + return os.environ.get(env_str, None) + + +@dataclasses.dataclass +class BisectionResult: + """ + backend: torch.compile backend responsible for failure + subsystem: optional, registered component identified for failure + bisect_number: optional, number of times the subsystem needed to be applied to trigger failure + debug_info: associated info of the triggering bisect application of subsystem + """ + + backend: str + subsystem: Optional[str] = None + bisect_number: Optional[int] = None + debug_info: Optional[str] = None + + +class CompilerBisector: + """ + This class iteratively runs torch.compile backends (eager, aot_eager, inductor) to find the + first backend that can repro an issue. + + Once it discovers the offending backend it will iteratively disable subsystems within the backend. + For subsystems which are applied repeatedly, such as the number of post grad passes or number + of lowering of nodes to inductor ir, it will bisect to find the offending application. + + The idiomatic way to run it is with `do_bisect`. You can also use it by setting the env flags + `TORCH_BISECT_BACKEND`, `TORCH_BISECT_SUBSYSTEM` and `TORCH_BISECT_MAX`. + + It also supports a CLI interface, although this is less well tested. + + You must run python compiler_bisector.py [start | good | bad | end] + """ + + bisection_enabled: bool = False + + in_process_cache: Optional[str] = None + + @classmethod + def get_dir(cls) -> str: + return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}" + + @classmethod + def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as file: + file.writelines(lines) + + @classmethod + def read_lines_from_file(cls, file_path: str) -> list[str]: + if os.path.exists(file_path): + with open(file_path) as file: + return file.readlines() + return [] + + @classmethod + def update_run_state( + cls, backend_name: str, subsystem: Subsystem, run_state: str + ) -> None: + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem.name}_run_state.txt" + ) + if isinstance(subsystem, ConfigChange): + assert run_state == "test_disable" + cls.set_config_values( + backend_name, + subsystem.name, + {subsystem.config_field: subsystem.config_value}, + ) + + cls.write_lines_to_file(file_path, [run_state]) + + @classmethod + def set_config_values( + cls, backend: str, subsystem: str, config_data: dict[str, object] + ) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + lines = [f"{k}={v}\n" for k, v in config_data.items()] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_range( + cls, backend_name: str, subsystem_name: str, low: int, high: int + ) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt" + ) + 
lines = [f"low={low}\n", f"high={high}\n"] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_backend(cls) -> Optional[str]: + """ + Returns the active backend, if any + """ + if val := get_env_val("TORCH_BISECT_BACKEND"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("backend="): + return line.strip().split("=")[1] + return None + + @classmethod + def get_subsystem(cls) -> Optional[str]: + """ + Returns the active subsystem, if any + """ + + if val := get_env_val("TORCH_BISECT_SUBSYSTEM"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("subsystem="): + out = line.strip().split("=")[1] + return out if out else None + return None + + @classmethod + def get_subsystem_object(cls, backend_name: str, subsystem_name: str) -> Subsystem: + return next(obj for obj in BACKENDS[backend_name] if obj.name == subsystem_name) + + @classmethod + def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]: + """ + Returns the current stage of bisecting, if Any + """ + + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt" + ) + lines = cls.read_lines_from_file(file_path) + if lines: + out = lines[0].strip() + assert out in ("test_disable", "find_max_bounds", "bisect") + return out + return None + + @classmethod + def get_bisect_range( + cls, backend_name: str, subsystem_name: str + ) -> tuple[int, int]: + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt" + ) + lines = cls.read_lines_from_file(file_path) + low = None + high = None + for line in reversed(lines): + if line.startswith("low="): + low = int(line.strip().split("=")[1]) + elif line.startswith("high="): + high = int(line.strip().split("=")[1]) + + if low is not None and high is not None: + break + + if low is None or high is None: + raise RuntimeError( + f"Trying to get bisect range when it is not set: subsystem {subsystem_name}" + ) + + return low, high + + @classmethod + def update_config_change(cls, backend: str, subsystem: ConfigChange) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem.name}_config.txt") + lines = [ + f"config_name={subsystem.config_name}\n", + f"config_field={subsystem.config_field}\n", + f"config_value={subsystem.config_value}\n", + ] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]: + backend = cls.get_backend() + subsystem = cls.get_subsystem() + + if not backend or not subsystem: + return None + + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + + if not os.path.exists(file_path): + return None + + lines = cls.read_lines_from_file(file_path) + config_data = {} + for line in lines: + key, value = line.strip().split("=", 1) + config_data[key] = eval(value) + + return config_data + + @classmethod + def delete_bisect_status(cls) -> None: + # in process_cache we have created if it exists, just the subdirectory of non created dir + dir_name = cls.in_process_cache if cls.in_process_cache else cls.get_dir() + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + print("Bisection status deleted.") + else: + print("No bisection status found.") + + @classmethod + def get_system_counter(cls, name: str, increment: bool = True) -> int: + 
global subsystem_call_counter + curr = subsystem_call_counter[name] + if increment: + subsystem_call_counter[name] += 1 + return curr + + @classmethod + def disable_subsystem( + cls, + backend: str, + subsystem: str, + debug_info: Optional[Callable[[], str]] = None, + ) -> bool: + if not cls.bisection_enabled: + return False + + if cls.get_backend() != backend: + return False + + if cls.get_subsystem() != subsystem: + return False + + if val := get_env_val("TORCH_BISECT_MAX"): + counter = cls.get_system_counter(subsystem, increment=True) + return counter > int(val) + + run_state = cls.get_run_state(backend, subsystem) + if run_state == "test_disable": + # First run, disable completely + return True + elif run_state == "find_max_bounds": + # Second run, update bisection range and return True to enable the subsystem + cls.update_bisect_range( + backend, + subsystem, + 0, + cls.get_system_counter(subsystem, increment=True), + ) + return False + else: + assert run_state == "bisect" + # If the environment variable is not set, use the bisection range midpoint + low, high = cls.get_bisect_range(backend, subsystem) + # if high - low <= 2: + midpoint = (low + high) // 2 + call_counter = cls.get_system_counter(subsystem) + + if ( + call_counter >= low + and call_counter <= high + and (low - high) <= 2 + and debug_info is not None + ): + call_counter_debug_info[call_counter] = debug_info() + + return call_counter > midpoint + + @classmethod + def advance_subsystem( + cls, curr_backend: str, curr_subsystem: Subsystem + ) -> Optional[Subsystem]: + """ + Tries to move to the next subsystem within the current system. + """ + print(f"Disabling {curr_subsystem.name} did not fix the issue.") + + current_subsystems = BACKENDS[curr_backend] + current_subsystem_index = next( + i + for i, subsystem in enumerate(current_subsystems) + if subsystem.name == curr_subsystem.name + ) + + if current_subsystem_index < len(current_subsystems) - 1: + next_subsystem = current_subsystems[current_subsystem_index + 1] + cls.update_bisect_status(curr_backend, next_subsystem.name) + cls.update_run_state(curr_backend, next_subsystem, "test_disable") + print( + f"Moving to the next subsystem: {curr_backend} - {next_subsystem.name}" + ) + return next_subsystem + else: + print( + f"All subsystems in {curr_backend} have been checked. The issue is not in this system." + ) + return None + + @classmethod + def advance_backend(cls, curr_backend: str) -> Optional[str]: + """ + Tries Move to the next backend. + """ + current_system_index = list(BACKENDS.keys()).index(curr_backend) + + if current_system_index < len(BACKENDS) - 1: + curr_backend = list(BACKENDS.keys())[current_system_index + 1] + cls.update_bisect_status(curr_backend, "") + print(f"Moving to the next system: {curr_backend}") + return curr_backend + else: + return None + + @classmethod + def process_subsystem( + cls, + curr_backend: str, + curr_subsystem: Subsystem, + fn: Callable[[], bool], + cli_interface: bool = True, + ) -> bool: + """ + Process the current subsystem. Returns True if the issue is found, False otherwise. 
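+
+ The persisted run state drives what each call does: "test_disable" checks
+ whether fully disabling the subsystem fixes the issue, "find_max_bounds"
+ counts how many times the subsystem is applied to get an upper bound, and
+ "bisect" binary-searches over that application count.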
+ """ + assert isinstance(curr_subsystem, Subsystem) + while True: + run_state = cls.get_run_state(curr_backend, curr_subsystem.name) + reset_counters() + if run_state == "test_disable": + if not fn(): + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + return False + curr_subsystem = next_subsystem + else: + if isinstance(curr_subsystem, ConfigChange): + print( + f"Setting config {curr_subsystem.config_name} field {curr_subsystem.config_field} " + f"to {curr_subsystem.config_value} fixed the issue" + ) + else: + print(f"Disabling {curr_subsystem.name} fixed the issue.") + if isinstance(curr_subsystem, BinarySubsystem): + return True + print("Starting bisect by getting upper bound.") + cls.update_run_state( + curr_backend, curr_subsystem, "find_max_bounds" + ) + elif run_state == "find_max_bounds": + if fn(): + raise RuntimeError( + f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem.name}." + ) + else: + _, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + print(f"Upper bound of {high} found for {curr_backend}.") + cls.update_run_state(curr_backend, curr_subsystem, "bisect") + elif run_state == "bisect": + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + midpoint = (low + high) // 2 + print( + f"Bisecting {curr_backend} - {curr_subsystem.name} (Range: [{low}, {high}], Midpoint: {midpoint})" + ) + if fn(): + cls.update_bisect_range( + curr_backend, curr_subsystem.name, midpoint + 1, high + ) + else: + cls.update_bisect_range( + curr_backend, curr_subsystem.name, low, midpoint + ) + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + if low == high: + print( + f"Binary search completed for {curr_backend} - {curr_subsystem.name}. The bisect number is {low}. " + f"Debug info: {call_counter_debug_info.get(low, 'not found')}" + ) + return True + else: + raise RuntimeError(f"Unexpected run_state {run_state}") + + if cli_interface: + sys.exit(0) + + @classmethod + def initialize_system(cls) -> None: + curr_backend = next(iter(BACKENDS.keys())) + curr_subsystem = "" + cls.update_bisect_status(curr_backend, curr_subsystem) + print(f"Starting bisection process with system: {curr_backend}") + + @classmethod + def do_bisect( + cls, fn: Callable[[], bool], cli_interface: bool = False + ) -> Optional[BisectionResult]: + """ + Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure. 
+ """ + + if not cli_interface: + bisection_enabled_orig = cls.bisection_enabled + cls.delete_bisect_status() + cls.bisection_enabled = True + cls.in_process_cache = tempfile.mkdtemp() + + def cleanup() -> None: + cls.bisection_enabled = bisection_enabled_orig + cls.delete_bisect_status() + cls.in_process_cache = None + + cleanup_handler = atexit.register(cleanup) + + class DisableBisect: + def __del__(self) -> None: + cleanup() + atexit.unregister(cleanup_handler) + + _cleanup = DisableBisect() + + curr_backend = cls.get_backend() + curr_subsystem_name = cls.get_subsystem() + + if not curr_backend: + cls.initialize_system() + curr_backend = cls.get_backend() + assert curr_backend is not None + curr_subsystem_name = cls.get_subsystem() + + curr_subsystem = ( + cls.get_subsystem_object(curr_backend, curr_subsystem_name) + if curr_subsystem_name is not None + else None + ) + while True: + assert curr_backend is not None + reset_counters() + if curr_subsystem: + result = cls.process_subsystem( + curr_backend, curr_subsystem, fn, cli_interface=cli_interface + ) + if result: + curr_subsystem = cls.get_subsystem_object( + curr_backend, + cls.get_subsystem(), # type: ignore[arg-type] + ) + + if isinstance(curr_subsystem, BinarySubsystem): + return BisectionResult( + curr_backend, + curr_subsystem.name, + 0, + curr_subsystem.name, + ) + + low, _ = cls.get_bisect_range(curr_backend, curr_subsystem.name) + return BisectionResult( + curr_backend, + curr_subsystem.name, + low, + call_counter_debug_info.get(low, None), + ) + + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + print( + f"The issue is in the {curr_backend} system, but could not identify subsystem." + ) + assert curr_backend is not None + return BisectionResult(curr_backend) + + curr_subsystem = next_subsystem + else: + if fn(): + next_backend = cls.advance_backend(curr_backend) + if not next_backend: + print("All systems have been checked.") + return None + + curr_backend = next_backend + else: + current_subsystems = BACKENDS[curr_backend] + if current_subsystems: + curr_subsystem = current_subsystems[0] + cls.update_bisect_status(curr_backend, curr_subsystem.name) + cls.update_run_state( + curr_backend, curr_subsystem, "test_disable" + ) + print( + f"The issue is in the {curr_backend} system. Moving to the first subsystem: {curr_subsystem}" + ) + else: + print(f"The issue is in the {curr_backend} system.") + return BisectionResult(curr_backend) + + if cli_interface: + sys.exit(0) + + +def command_line_usage() -> None: + if len(sys.argv) < 2: + print("Usage: python bisect_update.py ") + sys.exit(1) + + bisection_manager = CompilerBisector() + command = sys.argv[1] + + if command == "end": + bisection_manager.delete_bisect_status() + sys.exit(0) + + if command == "start": + bisection_manager.delete_bisect_status() + bisection_manager.initialize_system() + sys.exit(0) + + if command not in ["good", "bad"]: + print("Invalid command. 
Must be 'good', 'bad', 'start', or 'end'.") + sys.exit(1) + + def test_function() -> bool: + return command == "good" + + if not bisection_manager.get_backend(): + raise ValueError("Must call start prior to good or bad") + + bisection_manager.do_bisect(test_function, cli_interface=True) + + +def get_is_bisection_enabled() -> bool: + return ( + CompilerBisector.get_subsystem() is not None + or CompilerBisector.get_backend() is not None + ) + + +CompilerBisector.bisection_enabled = get_is_bisection_enabled() + +if __name__ == "__main__": + command_line_usage() diff --git a/phivenv/Lib/site-packages/torch/_inductor/config.py b/phivenv/Lib/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..04e5124805a7db9373e14089b5b3d78cd537a615 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/config.py @@ -0,0 +1,1770 @@ +import os +import sys +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.custom_graph_pass +from torch._environment import is_fbcode +from torch.utils._config_module import Config, get_tristate_env, install_config_module + + +inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1" +can_inplace_pad_graph_input = False # ease testing + + +def fx_graph_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") + + +def vec_isa_ok_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "1": + return True + if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "0": + return False + return None + + +def autotune_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") + + +def bundled_autotune_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE") + + +def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]: + return get_tristate_env( + "TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE", + True if not is_fbcode() else None, + ) + + +def static_cuda_launcher_default() -> bool: + STATIC_CUDA_LAUNCHER_VERSION = 1 + + if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ: + return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1" + elif is_fbcode(): + version = torch._utils_internal.justknobs_getval_int( + "pytorch/inductor:static_cuda_launcher_version" + ) + return version <= STATIC_CUDA_LAUNCHER_VERSION + else: + # Default true in OSS + return True + + +def prologue_fusion_enabled() -> bool: + ENABLE_PROLOGUE_FUSION_VERSION = 0 + + if "TORCHINDUCTOR_PROLOGUE_FUSION" in os.environ: + return os.environ.get("TORCHINDUCTOR_PROLOGUE_FUSION") == "1" + elif is_fbcode(): + jk_name = "pytorch/inductor:prologue_fusion_version" + version = torch._utils_internal.justknobs_getval_int(jk_name) + return version <= ENABLE_PROLOGUE_FUSION_VERSION + else: + return True + + +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "1") == "1" +) + +# add some debug printouts +debug = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# precompilation timeout +precompilation_timeout_seconds: int = 60 * 60 + +# use fx aot graph codegen cache +fx_graph_cache: bool = Config( + justknob="pytorch/remote_cache:enable_local_fx_graph_cache", + 
env_name_force="TORCHINDUCTOR_FX_GRAPH_CACHE", + default=True, +) + +# use remote fx aot graph codegen cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() + +# should we bundle triton caching into fx graph cache +bundle_triton_into_fx_graph_cache: Optional[bool] = ( + bundle_triton_into_fx_graph_cache_default() +) + +non_blocking_remote_cache_write: bool = Config( + justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write_v2", + env_name_force="TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE", + default=True, +) + +# Enable autotune local cache. +# +# See bundled_autotune_remote_cache for the effect this flag has on the bundled +# remote cache. +autotune_local_cache: bool = True + +# Enable autotune remote cache. +# +# Enables/disables the autotune remote cache regardless of the state of +# autotune_local_cache. If both local and remote are enabled then on write both +# are written and on read local is checked first and only on a cache miss is +# remote read. +# +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +autotune_remote_cache: Optional[bool] = autotune_remote_cache_default() + +# Enable bundled autotune cache. +# +# Enables/disables the bundled autotune cache regardless of the state of +# autotune_remote_cache. However it does depend on the local cache for local +# state management - as a result if the local cache is disabled this will also +# disable the bundled autotune cache. +# +# False: Disables the cache +# True: Enables the cache (requires autotune_local_cache) +# None: Not set -- Off for OSS, JustKnobs based for internal +bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() + +# Force disabled all inductor level caching -- This will override any other caching flag +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES", + default=False, +) + +# Unsafe way to skip dynamic shape guards to get faster cache load +unsafe_skip_cache_dynamic_shape_guards: bool = False + +# Unsafe way to mark non torch functions as safe to cache +# dictionary is from function name -> cache key +# Any function name in the dictionary will be allowed to be cacheable +# by AOTAutogradCache and FxGraphCache. +# changing the cache key value will change the resulting +# FXGraphCache key. +# Example usage: +# torch._inductor.config.unsafe_marked_cacheable_functions = { +# 'torch.ops.my_function' : torch.__version__ +# } +# The above example causes the custom op torch.ops.my_function to be cacheable, +# and for cache keys to be keyed by the current torch version +unsafe_marked_cacheable_functions: dict[str, str] = {} + +# sleep in inductor for testing +sleep_sec_TESTING_ONLY: Optional[int] = None + +# The default layout constraint for user-defined triton kernels. +# See "The default layout constraint for custom operators" for options. +triton_kernel_default_layout_constraint: Literal[ + "needs_fixed_stride_order", "flexible_layout" +] = "needs_fixed_stride_order" + +# use cpp wrapper instead of python wrapper +# incompatible with disable_cpp_codegen +cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# controls whether to compile entry and kernel separately for cpp_wrapper mode. 
+# turn on this option to compile entry and kernel separately and minimize compile time of the entry part. +# see https://github.com/pytorch/pytorch/pull/148773 +# Note: compiling entry and kernel separately may have a non-negligible impact on the performance. +# see https://github.com/pytorch/pytorch/issues/156037 +cpp_wrapper_build_separate: bool = ( + os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1" +) + +# Controls automatic precompiling of common include files for codecache.CppCodeCache +# (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is +# controlled by a separate flag. +cpp_cache_precompile_headers: bool = not is_fbcode() + +online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1" + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" +scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1" + +# Disable by default in fbcode +alignment_asserts = ( + os.environ.get("TORCHINDUCTOR_ALIGNMENT_ASSERTS", "0" if is_fbcode() else "1") + == "1" +) + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# Enable to allow using ftz variant of exponenet instruction in triton codegen. +use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1" + +# Enable bfloat16 atomic adds (fbcode only until upstreamed to triton) +bfloat16_atomic_adds_enabled = True + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool for both intermediates and outputs +memory_pool: Literal["none", "intermediates", "outputs", "combined"] = os.environ.get( + "TORCHINDUCTOR_MEMORY_POOL", "intermediates" +) # type: ignore[assignment] + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates epilogues +epilogue_fusion = True + +# fuse pointwise into template prologues +prologue_fusion = prologue_fusion_enabled() + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# set to True to enable the back-to-back GEMM pass +b2b_gemm_pass = False + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# Implement CustomGraphPass to allow Inductor to graph compiled artifacts +# to which your custom passes have been applied: +post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None +post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None + +# Registers a custom joint graph pass. +joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None +joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None + +# Registers a custom pregrad pass. 
Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pass to be run right before fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_pre_fusion_custom_pass: Optional[ + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Registers a custom pass to be run right after fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_post_fusion_custom_pass: Optional[ + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Deprecated +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +# batch fusion options: +# batch_linear +# batch_linear_lhs +# batch_layernorm +# batch_tanh +# batch_relu +# batch_sigmoid + +# split cat fusion options: +# normalization_pass +# remove_split_with_size_one_pass +# merge_getitem_cat_pass +# merge_stack_tahn_unbind +# merge_splits_pass +# mutate_cat_pass +# split_cat_pass +pre_grad_fusion_options: dict[str, dict[str, Any]] = {} + +# Post grad fusion and options, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: dict[str, dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down Rn_BLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# DEPRECATED. This setting is ignored. +use_mixed_mm = True + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers,about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation. +# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# DEPRECATED. This setting is ignored. 
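# Illustrative sketch: opting into a couple of the pre-grad batch fusions listed
# above. The fusion names come from the comment above; an empty dict requests a
# fusion with its default options, and whether it helps is model dependent.
#
#   import torch._inductor.config as inductor_config
#
#   inductor_config.pre_grad_fusion_options = {
#       "batch_linear": {},
#       "batch_layernorm": {},
#   }
#   # post_grad_fusion_options is configured the same way for post-grad fusions.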
+mixed_mm_choice: Literal["default", "triton", "aten", "heuristic"] = "heuristic" + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +reorder_for_compute_comm_overlap_passes: list[ + Union[ + str, + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ], + ] +] = [ + "reorder_compute_for_overlap", + "sink_waits", + "raise_comms", +] + +# Maximum number of positions to advance a given collective, unlimited by default +reorder_prefetch_limit: Optional[int] = None + +# enable operator reordering for peak memory optimization +reorder_for_peak_memory = True + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# use Inductor's experimental benchmarker (runtime/benchmarking.py) +# to benchmark kernels during autotuning, otherwise fall back to +# Triton's `do_bench`. the experimental benchmarker may produce +# results that are not consistent with `do_bench`'s results +use_experimental_benchmarker: bool = Config( + default=True, + env_name_force="TORCHINDUCTOR_USE_EXPERIMENTAL_BENCHMARKER", + justknob="pytorch/inductor:use_experimental_benchmarker", +) + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# disable decomposek autotune choice for gemm +disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" + +# Modifies the number of autotuning choices displayed, set to None for all +autotune_num_choices_displayed: Optional[int] = 10 + +# enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph +graph_partition = False + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) + +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs). +# CUTLASS: Cutlass templates and kernels (NVidia GPUs only). 
+# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only). +# CPP: CPP templates and kernels for CPU. +max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" +).upper() + + +# As above, specify candidate backends for conv autotune. +# NB: in some cases for 1x1 convs we emit as matmul, +# which will use the backends of `max_autotune_gemm_backends` +max_autotune_conv_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON" +).upper() + + +# Specify the size of the search space for GEMM autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + +# Specify the size of the search space for flex attention autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + +# DEPRECATED. This setting is ignored. +autotune_fallback_to_aten = False + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# The following three timeouts are applicable if autotune_in_subproc is True: + +# Max time that a valid benchmark result may take during autotuning +max_autotune_subproc_result_timeout_seconds = 60.0 +# DEPRECATED. This setting is ignored. +max_autotune_subproc_graceful_timeout_seconds = 0.0 +# DEPRECATED. This setting is ignored. 
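# Illustrative sketch: turning GEMM autotuning on and narrowing the candidate
# backends. torch.compile's "max-autotune" mode flips the max_autotune flag for
# you; the explicit assignments are the manual equivalent. `model` is a stand-in.
#
#   import torch
#   import torch._inductor.config as inductor_config
#
#   inductor_config.max_autotune_gemm = True
#   inductor_config.max_autotune_gemm_backends = "ATEN,TRITON"   # skip CUTLASS/CK/CPP
#   inductor_config.max_autotune_gemm_search_space = "DEFAULT"
#
#   compiled = torch.compile(model, mode="max-autotune")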
+max_autotune_subproc_terminate_timeout_seconds = 0.0 + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and +# generate the learned heuristic to code which is shipped with the compiler +# Specify a list of comma separated optimizations to collect data for +autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") +# Specify a list of comma separated optimizations to use learned heuristics for +autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") + + +def run_autoheuristic(name: str) -> bool: + return collect_autoheuristic(name) or use_autoheuristic(name) + + +def collect_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_collect.split(",") + + +def use_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_use.split(",") + + +# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py. +# If set to another path, autoheuristic will instead log results to the given path. +autoheuristic_log_path = os.environ.get( + "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT" +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. +keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since kernel with mixed layout inputs may run much slower then one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True +assume_unaligned_fallback_output = ( + os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" +) + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion. 
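# Illustrative sketch of how the comma-separated strings above feed the helper
# functions defined below. The optimization names used here are example values.
#
#   import torch._inductor.config as inductor_config
#
#   inductor_config.autoheuristic_collect = "pad_mm,mixed_mm"
#   inductor_config.autoheuristic_use = "mixed_mm"
#
#   assert inductor_config.collect_autoheuristic("pad_mm")
#   assert inductor_config.use_autoheuristic("mixed_mm")
#   assert inductor_config.run_autoheuristic("mixed_mm")   # collected or used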
+debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion: bool = ( + os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" +) + +# If fusing two nodes only save less then score_fusion_memory_threshold memory, +# we should not bother fusing the nodes. +# +# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242 +# Previously we fuse two nodes because of common read of a scalar tensor. +# If we skip it, the loop ordering after fusion mechanism kicks in and can +# brings more savings. +# +# For the cases loop ordering after fusion does not help, we don't lose much. +score_fusion_memory_threshold = 10 + +# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel +benchmark_epilogue_fusion = ( + os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1" +) + +# Take how many of the top triton kernels to benchmark epilogue +max_epilogue_benchmarked_choices = 1 + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# how many nodes to attempt pairwise fusion with in a buffer group +max_fusion_buffer_group_pairwise_attempts = 64 + +# max number of inputs to generate cat as a pointwise op with masked loads +max_pointwise_cat_inputs = 8 + +# force concat to be generated as a pointwise op with masked loads +force_pointwise_cat = False + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# For reductions with a small output size (usually 1, e.g. x.sum()) there is not enough +# parallelism to saturate the GPU. We have two ways of handling this, either `split_reductions` +# or `triton.cooperative_reductions` which are mutually exclusive. +# split_reductions: uses multiple kernels to gain more parallelism +# triton.cooperative_reductions: uses cross thread-block synchronization to gain more parallelism +# enabling both of these will implicitly disable split_reductions +split_reductions = True + +# When we do split reduction, this number control the minimum value for +# num_split. Too small num_split make the split reduction less efficient. +# It's a much bigger problem when we compile a dynamic shape kernel with +# non-representative inputs. 
+min_num_split = int(os.environ.get("TORCHINDUCTOR_MIN_NUM_SPLIT", 0)) + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# compute CSE bounds on variables that do not appear in the FX graph +compute_all_bounds = False + +# enable the combo kernel that combines data-independent kernels (additional +# to foreach kernels) into a single one (Experimental) +combo_kernels = False +# benchmark combo kernels and only allow ones with perf gains +benchmark_combo_kernel = False +# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach, +# 2 - enable for all +combo_kernels_autotune = 1 +# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable +# for all except for foreach, 2 - enable for all +combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = False + +# constant folding on the joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# Mode to emulate PyTorch eager numerics when doing lower precision compute +# (fp16, bf16). PyTorch eager computes bf16/fp16 by upcasting inputs to fp32 +# and downcasting after. When two low precision operators are fused together, +# Inductor will elide the downcast-upcast pairs (effectively a precision +# truncation) that would occur between these two operators. Typically, +# Inductor's behavior should be closer to fp64 ref numerics. However, with +# this knob you can ensure the downcast-upcast are preserved so that you can +# emulate the eager numerics. +emulate_precision_casts = ( + os.environ.get("TORCHINDUCTOR_EMULATE_PRECISION_CASTS", "0") == "1" +) + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# This pattern matches a special usage of scatter +# 1. It's applied to a constant tensor +# 2. The index tensor has size 1 in the scatter dimension +# Such pattern generates a sparse matrix when the const tensor is all-zero. +# We can lower this pattern to a pointwise kernel for more fusion opportunities +# and saving memory footprint. +optimize_scatter_upon_const_tensor = ( + os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1" +) + +# options in caffe2/torch/_inductor/fx_passes/pre_grad.py +add_pre_grad_passes: Optional[str] = None +remove_pre_grad_passes: Optional[str] = None + + +# The multiprocessing start method to use for inductor workers in the codecache. +def decide_worker_start_method() -> str: + if "TORCHINDUCTOR_WORKER_START" in os.environ: + start_method = os.environ["TORCHINDUCTOR_WORKER_START"] + else: + start_method = "subprocess" + assert start_method in ( + "subprocess", + "fork", + "spawn", + ), f"Invalid start method: {start_method}" + return start_method + + +worker_start_method: str = decide_worker_start_method() + +# Whether to log from subprocess workers that are launched. 
+worker_suppress_logging: bool = Config( + justknob="pytorch/compiler:worker_suppress_logging", + env_name_force="TORCHINDUCTOR_WORKER_SUPPRESS_LOGGING", + default=True, +) + +# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned +# on by DDP and should not be set by the users. +_fuse_ddp_communication = False +_fuse_ddp_bucket_size = 25 + +# Flag to control which fusion passes to apply. Functions in the list will +# be applied in order. There are two different different fusion passes +# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default +# one is "fuse_ddp_with_concat_op". Users can also change this to a customized +# fusion function. +# +# The fusion currently does not support multiple DDP with different PG or +# data type. This feature will be added in the future PRs. +# +# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp +# overlapping. At this moment, this pass performs better than +# reorder_for_compute_comm_overlap_passes but we will add the logic of +# "schedule_comm_wait" in the future and remove the one here. +_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [ + "fuse_ddp_with_concat_op", + "schedule_comm_wait", +] + +_micro_pipeline_tp: bool = False + + +class _collective: + auto_select: bool = False + one_shot_all_reduce_threshold_bytes: int = 128 * 1024 + + +def parallel_compile_enabled_internally() -> bool: + """ + TODO: Remove when parallel compiled is fully enabled internally. For rollout, use a + knob to enable / disable. The justknob should not be performed at import, however. + So for fbcode, we assign compile_threads to 'None' below and initialize lazily in + async_compile.py. + """ + ENABLE_PARALLEL_COMPILE_VERSION = 1 + + jk_name = "pytorch/inductor:enable_parallel_compile_version" + version = torch._utils_internal.justknobs_getval_int(jk_name) + return ENABLE_PARALLEL_COMPILE_VERSION >= version + + +def decide_compile_threads() -> int: + """ + Here are the precedence to decide compile_threads + 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's win32 platform + 3. decide by the number of CPU cores + """ + import logging + + # Defined locally so install_config_module doesn't try to parse + # as a config option. + log = logging.getLogger(__name__) + + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + log.info("compile_threads set to %d via env", compile_threads) + elif sys.platform == "win32": + compile_threads = 1 + log.info("compile_threads set to 1 for win32") + elif is_fbcode() and not parallel_compile_enabled_internally(): + compile_threads = 1 + log.info("compile_threads set to 1 in fbcode") + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + compile_threads = min(32, cpu_count) + log.info("compile_threads set to %d", compile_threads) + + return compile_threads + + +# TODO: Set directly after internal rollout. 
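# Illustrative sketch: exercising the precedence documented in
# decide_compile_threads(). The script name is hypothetical; the variable has to
# be set before torch._inductor is first imported to take effect.
#
#   TORCHINDUCTOR_COMPILE_THREADS=1 python train.py
#
#   # or from Python, before importing torch:
#   import os
#   os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
#   import torch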
+compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads() + +# Whether or not to enable statically launching CUDA kernels +# compiled by triton (instead of using triton's own launcher) +use_static_cuda_launcher: bool = static_cuda_launcher_default() + +# Attempt to statically launch user defined triton kernels +# Requires use_static_cuda_launcher +static_launch_user_defined_triton_kernels: bool = Config( + justknob="pytorch/inductor:static_launch_user_defined_triton_kernels", + env_name_force="TORCHINDUCTOR_STATIC_LAUNCH_USER_DEFINED_TRITON_KERNELS", + default=False, +) + +# Raise error if we bypass the launcher +strict_static_cuda_launcher: bool = ( + os.environ.get("TORCHINDUCTOR_STRICT_STATIC_CUDA_LAUNCHER", "0") == "1" +) + +# gemm autotuning global cache dir +global_cache_dir: Optional[str] +if is_fbcode(): + try: + from libfb.py import parutil + + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except (ValueError, ImportError): + global_cache_dir = None + +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Control if we will do padding for pointwise/reductions +comprehensive_padding = ( + os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1" +) +pad_channels_last = False + +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignment=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + +# Whether to treat output of the backward graph as user visible. +# For user visible outputs, inductor will make sure the stride matches with eager. 
+bw_outputs_user_visible = True + +# Whether to always use shape padding if it is enabled and possible +force_shape_pad: bool = False + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. +profile_bandwidth_output: Optional[str] = os.environ.get( + "TORCHINDUCTOR_PROFILE_OUTPUT", None +) +# Switch to do_bench_using_profiling to exclude the CPU overheads +profile_bandwidth_with_do_bench_using_profiling = ( + os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1" +) + + +# TODO: remove later +# incompatible with cpp_wrapper +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + +# assume_aligned_inputs means that we assume that inputs will be aligned; we generate +# code using this assumption, and clone tensors before use if they aren't aligned. +# In the common case, most inputs will be aligned. +assume_aligned_inputs: bool = False + +# For the user-written Triton kernels compiled with the model, ignore the unsupported +# arguments passed to the @triton.autotune in the user's code; this is unsafe, as +# ignoring the unsupported args may lead to unexpected autotuning behavior: don't +# set unless you know what you're doing. +unsafe_ignore_unsupported_triton_autotune_args: bool = False + +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. Incompatible with cpp_wrapper. +check_stack_no_cycles_TESTING_ONLY: bool = False + +# When True, complex_memory_overlap always reports True +always_complex_memory_overlap_TESTING_ONLY: bool = False + +# enable linear binary folding +enable_linear_binary_folding = ( + os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1" +) + + +# Adds NVTX annotations around training phases +annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" + +# Enable caching codegen of triton templates. 
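# Illustrative sketch: freezing an eval-mode module so Inductor may inline its
# weights as constants. The module below is a stand-in; note that frozen weights
# can no longer be updated afterwards.
#
#   import torch
#   import torch._inductor.config as inductor_config
#
#   inductor_config.freezing = True
#
#   model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval()
#   with torch.no_grad():
#       compiled = torch.compile(model)
#       compiled(torch.randn(4, 16))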
+enable_caching_generated_triton_templates: bool = False + + +# config specific to codegen/cpp.py +class cpp: + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = ( + os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1" + ) + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. + dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" + + simdlen: Optional[int] = None + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) + + cxx: tuple[Literal[None], str] = ( + None, # download gcc12 from conda-forge if conda is installed + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + ) # type: ignore[assignment] + + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1" + ) + + # enable weight prepacking to get a better performance; may lead to large memory footprint + weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1" + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. Default None. + vec_isa_ok: Optional[bool] = get_tristate_env("TORCHINDUCTOR_VEC_ISA_OK") + + # similar to config.triton.descriptive_names + descriptive_names: Literal["torch", "original_aten", "inductor_node"] = ( + "original_aten" + ) + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = int( + os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16") + ) + + # Make scatter_reduce fallback when reduce is sum to avoid performance regression + # using atomic_add. + fallback_scatter_reduce_sum = ( + os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1" + ) + + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1" + ) + + # Use ffp-contract when compiling + # Options: "off" (default), "on", "fast" + # Per https://godbolt.org/z/bf4bvfc9r , clang/gcc has different behavior for "fast" + enable_floating_point_contract_flag = os.environ.get( + "TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "off" + ) + + # Disable the tiling select heuristic + enable_tiling_heuristics = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1" + ) + + # Enable the Grouped GEMM Fusion + enable_grouped_gemm_template = False + + # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls + # the maximal parallelism of K-slicing. Since K-slicing requires extra thread + # synchronization and buffers, the maximal number of slices is limited to + # mitigate the sync overhead and memory usage. + # When set to 0, the number of slices is unlimited. 
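    # Illustrative sketch: common CPU-codegen tweaks using the attributes of this
    # class. Values are examples only; threads = -1 (the default) means
    # torch.get_num_threads() is used.
    #
    #   import torch._inductor.config as inductor_config
    #
    #   inductor_config.cpp.threads = 8                   # fix the thread count in generated kernels
    #   inductor_config.cpp.simdlen = None                # autodetect the vector ISA width
    #   inductor_config.cpp.enable_kernel_profile = True  # surface per-kernel timings in the profiler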
+ gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1")) + + # For perf tuning and debugging purpose, configure the pre-defined cache blocking for + # MxNxK dims respectively. The blockings are separated by comma and the unit is + # the number of register blocks. + # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively. + gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None) + + # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for + # MxNxK dims respectively. The factors are separated by comma and their product + # should be the same as the total number of threads. + # For example, if the total number of threads is 56, "7,4,2" means the work is + # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. + gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + + # Whether to enable concat linear for cpu device + # Currently concat linear on CPU not always have benefit, depends on linear'shape or + # computing resource. We set this default to False to avoid regressions. User and + # enable this feature by their need. + enable_concat_linear = False + + # Whether to use decomposed tanh for cpu device + # Disable by default due to https://github.com/pytorch/pytorch/issues/148241 + use_decompose_tanh = ( + os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1" + ) + + # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] + use_small_dequant_buffer = False + + +# config specific to codegen/triton.py +class triton: + # Use cudagraphs on output code + cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1" + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # Should we skip cudagraphing graphs with dynamic shape inputs + # If False, we will re-record a graph for each unique set of shape inputs + cudagraph_skip_dynamic_graphs = False + + # Specify dynamic shapes to capture cudagraphs and skip cudagraph for other shapes. + # Default to None, which means we capture cudagraphs for all shapes. + cudagraph_capture_sizes: Optional[tuple[Union[int, tuple[int, ...]]]] = None + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # Enable cudagraph support for mutated inputs from prior cudagraph pool + cudagraph_support_input_mutation = False if is_fbcode() else True + + # Maximal number of allowed cudagraph re-record for a function and + # a cudagraph node due to static input tensor address changes or + # cudagraph managed tensor data pointer changed. + # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit + # note: we are conservative here and choose a large limit. 
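    # Illustrative sketch: enabling CUDA graph capture for compiled regions. The
    # "reduce-overhead" mode turns on this class's cudagraphs flag; the options
    # form sets the same flag explicitly. `model` is a stand-in.
    #
    #   import torch
    #
    #   compiled = torch.compile(model, mode="reduce-overhead")
    #   # equivalent explicit form:
    #   compiled = torch.compile(model, options={"triton.cudagraphs": True})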
+ cudagraph_unexpected_rerecord_limit = 128 + + # Warn loudly when the number of cudagraphs due to dynamic shape + # exceeds this limit + cudagraph_dynamic_shape_warn_limit: Optional[int] = 50 + + # synchronize after cudagraph invocation + force_cudagraph_sync = False + + # always run cudagraphs in the eager warmup stage + # instead of recording and executing cudagraphs + force_cudagraphs_warmup = False + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. + debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # TODO - enable by default + coalesce_tiling_analysis: bool = ( + os.environ.get( + "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0" + ) + == "1" + ) + + # limit tiling dimensions + # - max_tiles=1 disables tiling + # - max_tiles=2 + # - max_tiles=3 is experimental and may have bugs + # higher values are unsupported + + # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise. + # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes. + max_tiles: Optional[int] = None + + # Prefer higher dimensional tilings. This simplifies indexing expressions, making + # it easier to identify block pointers. + prefer_nd_tiling: bool = False + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # Tune the generated Triton kernels at compile time instead of first time they run + # Setting to None means uninitialized + autotune_at_compile_time: Optional[bool] = None + + # We use random tensors for autotune by default. Setting this as true will let us + # use inputs from sample inputs to autotune user defined triton kernels. + # Side effect for this option is increased memory footprint during first pass compilation. + autotune_with_sample_inputs: bool = False + + # Allows tiling reductions into multiple dimensions. + # For best results, this should be used with prefer_nd_tiling. + tile_reductions: bool = False + + # should we stop a fusion to allow better tiling? + tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. + unique_kernel_names = ( + os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES", "1") == "1" + ) + + # similar to the option above, but this is specific to user defined kernels, + # while unique_kernel_name is for kernels generated by inductor. + # We have this option because sometimes we reuse user's kernel code with different + # configs which would result in the same name. + # Note: This MODIFIES the user's kernel function name within inductor phase. + unique_user_kernel_names = ( + os.environ.get("TORCHINDUCTOR_UNIQUE_USER_KERNEL_NAMES", "0") == "1" + ) + + # should we put op names in kernel names + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. 
pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names: Literal["torch", "original_aten", "inductor_node"] = ( + "original_aten" + ) + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # For small output size reductions uses cross thread-block synchronization to gain more parallelism + cooperative_reductions = ( + os.environ.get("TORCHINDUCTOR_COOPERATIVE_REDUCTIONS", "0") == "1" + ) + + # used for debugging cooperative reduction codegen, always generate cooperative_reductions + force_cooperative_reductions = False + + # 0: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel: Literal[0, 1, 2, 3] = int( + os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0") + ) # type: ignore[assignment] + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1" + + # Minimum R0_BLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config spilling a small amount + # of registers being benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/triton-lang/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. + # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/store + use_block_ptr = False + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental) + codegen_upcast_to_fp32 = True + + # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 + # with a version of triton new enough to support TMA + enable_persistent_tma_matmul = ( + os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" + ) + # Skip L1 cache for buffers that are used only once. Disabled by default + skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1" + + # During autotuning, if one of the kernels/configs fails for some reason, + # Inductor will usually skip it (and assign its latency to inf). + # For testing it's helpful to be able to assert that none of the configs fail. 
+ # Note: it may also need to be used with config.compile_threads = 1 + disallow_failing_autotune_kernels_TESTING_ONLY = False + + +class aot_inductor: + """ + Settings for Ahead-Of-Time Inductor Compilation + """ + + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. + output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + + # Annotate generated main wrapper function, i.e. AOTInductorModel::run_impl, + # to use which cpp compiler optimization level, default to O1 + compile_wrapper_opt_level = os.environ.get( + "AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL", "O1" + ) + + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + # 3: enable printing kernel names only (useful for pinpointing troublesome kernels) + debug_intermediate_value_printer: Literal["0", "1", "2", "3"] = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" + ) # type: ignore[assignment] + + # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2 + filtered_kernel_names = os.environ.get( + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None + ) + + # Serialized tree spec for flattening inputs + # TODO: Move this into metadata + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + # TODO: Move this into metadata + serialized_out_spec = "" + + # flag to decide whether to create a submodule for constant graph. + use_runtime_constant_folding: bool = False + + # flag to force weight to be appended to the shared library and mapped by the runtime + # rather than embedded into the data section. Needed to support 1B+ parameter models + force_mmap_weights: bool = False + + package: bool = False + package_cpp_only: bool = False + + # Dictionary of metadata users might want to save to pass to the runtime. + # TODO: Move this somewhere else, since it's no longer really a config + metadata: dict[str, str] = {} + + # fbcode only. Whether to raise error if C++ codegen is too big to optimize + raise_error_on_ignored_optimization: bool = ( + os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1" + ) + + # dump an aoti minifier if program errors + dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1" + + # Compiler compilation debug info + # 1: Dumps the original graph out to repro.py if compilation fails + # 2: Dumps a minifier_launcher.py if aoti fails. + # 3: Always dumps a minifier_launcher.py. Good for segfaults. + # 4: Dumps a minifier_launcher.py if the accuracy fails. + repro_level: int = int(os.environ.get("AOTINDUCTOR_REPRO_LEVEL", 2)) + + # Dictionary of presets that can be passed in + presets: dict[str, Any] = {} + + # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests + # should be run with this flag both on and off to make sure we have coverage. 
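    # Illustrative sketch: steering AOTInductor through these settings. The
    # packaging entry point shown below exists in recent releases but its exact
    # signature has moved around, so treat the call as a sketch; `model` and the
    # example input are stand-ins.
    #
    #   import torch
    #   import torch._inductor.config as inductor_config
    #
    #   inductor_config.aot_inductor.use_runtime_constant_folding = True
    #   inductor_config.aot_inductor.metadata = {"model_version": "1"}
    #
    #   ep = torch.export.export(model, (example_input,))
    #   torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")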
+ allow_stack_allocation: bool = False + + # Enables an alternate DSO interface (the "minimal ArrayRef interface") intended + # to maximize performance for use cases that it can accommodate at the expense of + # generality. In brief: + # - inputs and outputs are ArrayRefTensor (note that strides are required, but the + # tensor must be contiguous) + # - constant handling is unchanged because it is not a per-inference-iteration bottleneck + # + # When the DSO is generated in this mode, the usual interface will also be supported, + # but performance for that interface may be degraded. + use_minimal_arrayref_interface: bool = False + + # Experimental. Flag to control whether to include weight in .so + package_constants_in_so: bool = True + + # Experimental. Flag to control whether to package weight separately on disk + package_constants_on_disk: bool = False + + # Experimental. Controls automatic precompiling of common AOTI include files. + precompile_headers: bool = not is_fbcode() + + # Embed generated kernel binary files into model.so + embed_kernel_binary: bool = False + + # Generate kernel files that support multiple archs + # For CUDA, this means generating fatbin files for kernels, and the fatbin files + # contains PTX and SASS for the current architecture. + emit_multi_arch_kernel: bool = False + + # If not None, the generated files with use this name in file stem. + # If None, we will use a hash to name files. + model_name_for_generated_files: Optional[str] = None + + # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict + custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {} + # custom op libs that have implemented C shim wrappers + custom_op_libs: Optional[list[str]] = None + + +class cuda: + """Settings for cuda backend, today this consists of cutlass""" + + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune. 
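    # Illustrative sketch: letting CUTLASS compete during GEMM autotuning. CUTLASS
    # only participates if it is listed in max_autotune_gemm_backends (defined
    # earlier in this file); the profiling cap below is an example value.
    #
    #   import torch._inductor.config as inductor_config
    #
    #   inductor_config.max_autotune_gemm = True
    #   inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,CUTLASS"
    #   inductor_config.cuda.cutlass_max_profiling_configs = 4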
+ cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4, 8] + + # Whether to use CUTLASS EVT for epilogue fusion + cutlass_epilogue_fusion_enabled = ( + os.environ.get("CUTLASS_EPILOGUE_FUSION", "0") == "1" + ) + + # Whether to only use TMA-compatible kernels in CUTLASS + cutlass_tma_only = False + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. + cutlass_backend_min_gemm_size: int = 1 + + # enable generation of inline standalone runner in CUDA CPP generated code + # which allows to compile the generated code into a standalone executable. + generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1" + ) + + # Keep only Cutlass op configs which contain this regular expression pattern + # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs + cutlass_op_allowlist_regex: Optional[str] = os.environ.get( + "TORCHINDUCTOR_CUTLASS_ALLOWLIST" + ) + + # Note: Names of Cutlass ops names can be obtained by calling + # op.configuration_name() on a Cutlass op instance, for example those + # returned from cutlass_utils.gen_ops() or the op argument passed to + # CUTLASSGemmTemplate.render(...) + + # Filter Cutlass configs which contain this regular expression pattern + # Set this to "pingpong" to avoid numerical issues + # caused by the op ordering of the "pingpong" memory access + # pattern used by some Cutlass Kernels. + cutlass_op_denylist_regex: Optional[str] = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DENYLIST" + ) + + # Non-negative integer which determines how many kernels are instantiated. + # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. + # increasing first digit reduces schedule / mixed type pruning, + # increasing second digit generates more cluster sizes, + # increasing third digit generates more MMA multipliers, + # increasing fourth digit generates more instruction shapes. + cutlass_instantiation_level: str = os.environ.get( + "TORCHINDUCTOR_CUTLASS_INSTANTIATION_LEVEL", "0" + ) + + # Experimental. Only for H100 for now. Flag to control whether to use presets. + # Format looks like: "0,1,3" for using presets 0, 1, and 3. Presets can be + # controlled by some cutlass instantiation level flags (e.g. 0, 1111, 2222, ...) + cutlass_presets: Optional[str] = os.environ.get("TORCHINDUCTOR_CUTLASS_PRESETS") + + # use compile command to create kernel .cu and .so name + cutlass_hash_with_compile_cmd: bool = ( + os.environ.get("TORCHINDUCTOR_CUTLASS_HASH_WITH_COMPILE_CMD", "0") == "1" + ) + + # Experimental. Prescreen top x configs before tuning on swizzle. + cutlass_prescreening: bool = ( + os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1" + ) + + # Specify which operations should use CUTLASS backend + # Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none. 
+ # Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm + cutlass_enabled_ops: str = os.environ.get( + "TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all" + ) + + # Whether to consult the binary remote cache + use_binary_remote_cache: bool = True + + # Whether to upload compiled kernels to remote cache + upload_to_binary_remote_cache: bool = False + + # Whether to force upload if the key already exists + # Use this to overwrite and handle cache pollution + binary_remote_cache_force_write: bool = False + + +class rocm: + # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. + # If empty, the `native` arch is used + arch: list[str] = [] + + # Enable the CK backend for CDNA2 and CDNA3 only (for now) + # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors + ck_supported_arch: list[str] = ["gfx90a", "gfx942"] + + # Optimization level, use to balance compilation speed and runtime performance. + # The type will not necessarily be comprehensive and won't be enforced at runtime. + compile_opt_level: Literal[ + "-O0", "-O1", "-O2", "-O3", "-Os", "-Oz", "-Omin", "-Ofast", "-Omax" + ] = "-O2" + + # Flag to keep debug information in compiled objects + is_debug = False + + # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.) + save_temps = False + + # Flag to add `-ffast-math`` to compile flags + use_fast_math = True + + # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags + flush_denormals = True + + # Flag to print register and LDS usage during compilation + print_kernel_resource_usage = False + + # Path to ROCm installation, if None, use env variable ROCM_HOME. + # In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set. + rocm_home: Optional[str] = None + + # Path to Composable Kernel library. + # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`. + ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR") + + # generate standalone executables for instances generated with the CK backend + generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1" + ) + + # Deprecated, use CK and/or CK-tile specific settings + n_max_profiling_configs: Optional[int] = None + + # Number of op instance choices to trade off between runtime perf and compilation time + # For CK Kernels + ck_max_profiling_configs: Optional[int] = None + + # Number of op instance choices to trade off between runtime perf and compilation time + # For CK-Tile Kernels + ck_tile_max_profiling_configs: Optional[int] = None + + # Flag to use a short list of CK instances which perform well across a variety of shapes. + # Currently RCR and F16 only + use_preselected_instances: bool = False + + # List to determine kBatch parameters to sweep over. 
By default, we calculate one in splitK + # scenarios, and run on kBatch=1 in non-splitK scenarios + kBatch_sweep: Optional[list[int]] = None + + # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this + split_k_threshold: int = 16 + + +# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) +cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" + +# Backend to use for CUDA codegen either "triton" or "halide" (experimental) +cuda_backend: Literal["triton", "halide"] = "triton" + + +class halide: + # Base halide target to use for CPU devices + cpu_target = "host" + + # Base halide target to use for CUDA devices + gpu_target = "host-cuda" + + # Halide autoscheduler to use, choices are: + # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) + scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Anderson2021" + ) + scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Adams2019" + ) + + # Controls `no_asserts` flag passed to Halide target (warning: can false positive) + asserts = False + + # Controls `debug` flag passed to Halide target + debug = False + + # Enable (or fallback on) scan kernels such as cumsum + # Halide autoschedulers struggle with these kernels + scan_kernels = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # save real tensors + save_real_tensors = os.environ.get("TORCH_COMPILE_DEBUG_SAVE_REAL", "0") == "1" + + # Save debug information to a temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels + # to workaround the above failure. 
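    # Illustrative sketch: the master switch above is the usual entry point for
    # these artifacts; individual flags can also be flipped directly. The script
    # name is a stand-in.
    #
    #   TORCH_COMPILE_DEBUG=1 python repro.py      # dumps a torch_compile_debug/ directory
    #
    #   # or per-flag from Python:
    #   import torch._inductor.config as inductor_config
    #   inductor_config.trace.enabled = True
    #   inductor_config.trace.graph_diagram = True   # also emit the post-fusion SVG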
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overridden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" + + # Save mapping info from inductor generated triton kernel to post_grad fx nodes + log_inductor_triton_kernel_to_post_grad_node_info: bool = True + + +_save_config_ignore: list[str] = [ + # workaround: "Can't pickle " + "trace.upload_tar", + "joint_custom_pre_pass", + "joint_custom_post_pass", + "pre_grad_custom_pass", + "aot_inductor.repro_level", + "aot_inductor.dump_aoti_minifier", + "post_grad_custom_pre_pass", + "post_grad_custom_post_pass", + "_fuse_ddp_communication_passes", + "_pre_fusion_custom_pass", +] + +_cache_config_ignore_prefix: list[str] = [ + # trace functions are not relevant to config caching + "trace", + # uses absolute path + "cuda.cutlass_dir", + # not relevant + "worker_start_method", + "compile_threads", + # see CustomGraphPass; these are handled specially + "post_grad_custom_post_pass", + "post_grad_custom_pre_pass", + "_fuse_ddp_communication_passes", + "_pre_fusion_custom_pass", + # tests assume that changes here don't invalidate cache + "always_complex_memory_overlap_TESTING_ONLY", + # cache related options are not relevant to cache results + "fx_graph_cache", + "fx_graph_remote_cache", + "autotune_local_cache", + "autotune_remote_cache", +] + +# External callable for matmul tuning candidates +external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] + + +class test_configs: + force_extern_kernel_in_multi_template: bool = False + + max_mm_configs: Optional[int] = None + + runtime_triton_dtype_assert = False + static_cpp_dtype_assert = False + + # regex to control the set of considered autotuning + # choices (aka configs) by name and / or description + autotune_choice_name_regex: Optional[str] = None + autotune_choice_desc_regex: Optional[str] = None + + graphsafe_rng_func_ignores_fallback_random = False + + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/constant_folding.py b/phivenv/Lib/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1c04896e8267a800debd9afa1968b20aa60b55 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,415 @@ +import collections +from typing import Any, Callable, Optional + +import torch +import torch.utils._pytree as pytree +from torch._inductor.freezing_utils import maybe_set_is_frozen_param +from torch.utils._ordered_set import OrderedSet + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
+# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + +_dont_constant_fold: list[torch.fx.node.Target] = [] + + +def add_dont_constant_fold(op: torch.fx.node.Target) -> None: + global _dont_constant_fold + _dont_constant_fold.append(op) + + +def clear_dont_constant_fold() -> None: + global _dont_constant_fold + _dont_constant_fold.clear() + + +def replace_node_with_constant( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + constant: Optional[torch.Tensor] = None, + name: Optional[str] = None, +) -> None: + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 # type: ignore[assignment] + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 # type: ignore[assignment, operator] + + gm._frozen_param_count = i + 1 # type: ignore[assignment, operator] + + with g.inserting_before(node): + if constant is not None: + new_input_node = g.create_node("get_attr", qualname, (), {}) + else: + # this is the case for lifted constants + new_input_node = g.create_node("placeholder", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + new_input_node.name = node.name + + if constant is not None: + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + # mark any constants created during freezing + maybe_set_is_frozen_param(constant) + + +def is_const_source( + node: torch.fx.Node, lifted_constant_names: Optional[list[str]] +) -> bool: + return node.op == "get_attr" or node.name in (lifted_constant_names or ()) + + +class ConstantFolder(torch.fx.Interpreter): + def __init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + ) -> None: + super().__init__(gm) + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + self.lifted_constant_names = lifted_constant_names + self.deferred_value = object() + self.skip_folding_node_fn = skip_folding_node_fn + + def _support_dynamic_shape(self) -> bool: + # ConstantFolder not support dynamic shape now + return False + + def _deduce_value(self, node: torch.fx.Node) -> Any: + if self.lifted_constant_names is None: + return super().run_node(node) + # if lifted_constant_names is passed in, no concrete value is available + # so we just check if all inputs have values + if self.skip_folding_node_fn is not None and self.skip_folding_node_fn(node): + return self.unknown_value + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + for inp in flattened_node_inps: + if ( + isinstance(inp, torch.fx.Node) + and inp.name not in (self.lifted_constant_names or ()) + and self.env[inp] != self.deferred_value + ): + return self.unknown_value + return self.deferred_value + + def is_impure(self, node: 
torch.fx.node.Node) -> bool: + def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: + return ( + node.target == torch.ops.prims.convert_element_type.default # type: ignore[return-value] + and isinstance(node.args[0], torch.fx.Node) + and "val" in node.args[0].meta + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ) + + if ( + is_woq_int8_pattern(node) + or ( + node.target == torch.ops.aten.permute.default + and len(node.users) == 1 + and is_woq_int8_pattern(next(iter(node.users))) + ) + ) and is_const_source( + node.args[0], # type: ignore[arg-type] + self.lifted_constant_names, + ): + # Case 1: int8_weight -> dq -> bf16_weight + # Case 2: int8_weight -> permute -> dq -> bf16_weight + return True + + quant_registered = ( + getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) + is not None + ) + if quant_registered and node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.convert_element_type.no_fuse, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + + if node.target in _dont_constant_fold: + return True + return False + + def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]: + last_non_output_use = collections.defaultdict(list) + seen_uses = OrderedSet[torch.fx.Node]() + output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] + + for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if node.target == "output": + continue + + def add_use(inp: torch.fx.Node) -> None: + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node: torch.fx.Node) -> Any: + if node.target == "output": + # because we remove nodes from env on last non output use, + # re-define them now or we'll get error in interpreter + def set_env(arg: torch.fx.Node) -> None: + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. 
+ if any( + type(self.unknown_value) == type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target == aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and not is_const_source(node, self.lifted_constant_names) + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + return self.unknown_value + + out = self._deduce_value(node) + + if isinstance(out, torch._C.ScriptObject): + return out + + if out == self.unknown_value: + return self.unknown_value + + if not is_const_source(node, self.lifted_constant_names) and ( + isinstance(out, torch.Tensor) or out == self.deferred_value + ): + if out != self.deferred_value and out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self) -> Any: # type: ignore[override] + env: dict[torch.fx.Node, Any] = {} + self.insert_placerholder_values(env) + return super().run(initial_env=env) + + def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + env[n] = self.unknown_value # type: ignore[assignment] + if self.lifted_constant_names is None: + return + for n in self.module.graph.nodes: # type: ignore[union-attr] + if n.name in (self.lifted_constant_names or ()): + env[n] = self.deferred_value + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.find_nodes(op="get_attr"): + if len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, 
node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag( + gm: torch.fx.GraphModule, + skip_constructors: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder( + gm, + skip_constructors=skip_constructors, + lifted_constant_names=lifted_constant_names, + skip_folding_node_fn=skip_folding_node_fn, + ) + cf.run() + + for node in gm.graph.nodes: + if skip_folding_node_fn is not None and skip_folding_node_fn(node): + node.meta[META_TAG] = MODULE_TAG + continue + if ( + is_const_source(node, lifted_constant_names) + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph( + gm: torch.fx.GraphModule, + skip_constructors: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag( + gm, skip_constructors, lifted_constant_names, skip_folding_node_fn + ) + + def untag(node: torch.fx.Node) -> bool: + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + return used_to_fold + + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. + for node in gm.graph.nodes: + if node.op == "get_attr" or (node.name in (lifted_constant_names or ())): + untag(node) + + new_graph = torch.fx.Graph() + + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/phivenv/Lib/site-packages/torch/_inductor/cpp_builder.py b/phivenv/Lib/site-packages/torch/_inductor/cpp_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..85f79e5d99e2d674366f86a1eecc4ffdac7ec1eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/cpp_builder.py @@ -0,0 +1,1861 @@ +# This CPP builder is designed to support both Windows and Linux OS. 
+# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import errno +import functools +import json +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import tempfile +import textwrap +import warnings +from collections.abc import Sequence +from ctypes import cdll +from ctypes.util import find_library +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config, exc +from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion + + +if config.is_fbcode(): + from triton.fb.build import _run_build_command, build_paths + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def use_global_cache() -> bool: # type: ignore[misc] + return False + + +# Windows need setup a temp dir to store .obj files. +_BUILD_TEMP_DIR = "CxxBuild" +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + +SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else () + +log = logging.getLogger(__name__) + + +# =============================== toolchain =============================== +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path + + +@functools.cache +def check_compiler_exist_windows(compiler: str) -> None: + """ + Check if compiler is ready, in case end user not activate MSVC environment. 
+ """ + try: + subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT) + except FileNotFoundError as exc: + raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass + + +def get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + check_compiler_exist_windows(compiler) + else: + if config.is_fbcode(): + return build_paths.cc + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +def get_ld_and_objcopy(use_relative_path: bool) -> tuple[str, str]: + if _IS_WINDOWS: + raise RuntimeError("Windows is not supported yet.") + else: + if config.is_fbcode(): + ld = build_paths.ld + objcopy = ( + build_paths.objcopy_fallback + if use_relative_path + else build_paths.objcopy + ) + else: + ld = "ld" + objcopy = "objcopy" + return ld, objcopy + + +def convert_cubin_to_obj( + cubin_file: str, + kernel_name: str, + ld: str, + objcopy: str, +) -> str: + obj_file = cubin_file + ".o" + # Convert .cubin to .o + cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}" + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + # Rename .data to .rodata + cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}" + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + # By default objcopy will create *_start, *_size, *_end symbols using the full path + # Rename to use the unique kernel name + file_name = re.sub(r"[\W]", "_", cubin_file) + cmd = ( + objcopy + + f" --redefine-sym _binary_{file_name}_start=__{kernel_name}_start " + + f"--redefine-sym _binary_{file_name}_size=__{kernel_name}_size " + + f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end " + + obj_file + ) + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + return obj_file + + +@functools.cache +def _is_apple_clang(cpp_compiler: str) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +@functools.cache +def _is_clang(cpp_compiler: str) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. + if sys.platform == "darwin": + return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +@functools.cache +def _is_gcc(cpp_compiler: str) -> bool: + # Since "clang++" ends with "g++", the regex match below would validate on it. 
+ if _is_clang(cpp_compiler): + return False + return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler)) + + +@functools.cache +def _is_msvc_cl(cpp_compiler: str) -> bool: + if not _IS_WINDOWS: + return False + + try: + output_msg = ( + subprocess.check_output([cpp_compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + return "Microsoft" in output_msg.splitlines()[0] + except FileNotFoundError: + return False + + return False + + +@functools.cache +def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + + return is_intel_compiler + except FileNotFoundError: + return False + except subprocess.SubprocessError: + # --version args not support. + return False + + return False + + +@functools.cache +def is_gcc() -> bool: + return _is_gcc(get_cpp_compiler()) + + +@functools.cache +def is_clang() -> bool: + return _is_clang(get_cpp_compiler()) + + +@functools.cache +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + +@functools.cache +def is_apple_clang() -> bool: + return _is_apple_clang(get_cpp_compiler()) + + +@functools.cache +def is_msvc_cl() -> bool: + return _is_msvc_cl(get_cpp_compiler()) + + +@functools.cache +def get_compiler_version_info(compiler: str) -> str: + env = os.environ.copy() + env["LC_ALL"] = "C" # Don't localize output + try: + version_string = subprocess.check_output( + [compiler, "-v"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + try: + version_string = subprocess.check_output( + [compiler, "--version"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + return "" + # Multiple lines to one line string. 
+ version_string = version_string.replace("\r", "_") + version_string = version_string.replace("\n", "_") + return version_string + + +# =============================== cpp builder =============================== +def _append_list(dest_list: list[str], src_list: list[str]) -> None: + dest_list.extend(copy.deepcopy(item) for item in src_list) + + +def _remove_duplication_in_list(orig_list: list[str]) -> list[str]: + new_list: list[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir: str) -> None: + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError( # noqa: TRY200 (Use `raise from`) + f"Fail to create path {path_dir}" + ) + + +def _remove_dir(path_dir: str) -> None: + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def _run_compile_cmd(cmd_line: str, cwd: str) -> None: + cmd = shlex.split(cmd_line) + try: + subprocess.run( + cmd, cwd=cwd, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + output = e.stdout.decode("utf-8") + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + + +def run_compile_cmd(cmd_line: str, cwd: str) -> None: + with dynamo_timed("compile_file"): + _run_compile_cmd(cmd_line, cwd) + + +def normalize_path_separator(orig_path: str) -> str: + if _IS_WINDOWS: + return orig_path.replace(os.sep, "/") + return orig_path + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Actually, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__( + self, + compiler: str = "", + definitions: Optional[list[str]] = None, + include_dirs: Optional[list[str]] = None, + cflags: Optional[list[str]] = None, + ldflags: Optional[list[str]] = None, + libraries_dirs: Optional[list[str]] = None, + libraries: Optional[list[str]] = None, + passthrough_args: Optional[list[str]] = None, + aot_mode: bool = False, + use_relative_path: bool = False, + compile_only: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + self._compiler = compiler + self._definitions: list[str] = definitions or [] + self._include_dirs: list[str] = include_dirs or [] + self._cflags: list[str] = cflags or [] + self._ldflags: list[str] = ldflags or [] + self._libraries_dirs: list[str] = libraries_dirs or [] + self._libraries: list[str] = libraries or [] + # Some args are hard to abstract to OS compatible, passthrough directly. 
+ self._passthrough_args: list[str] = passthrough_args or [] + + # Optionally, the path to a precompiled header which should be included on the + # build command line. + self.precompiled_header: Optional[str] = None + + self._aot_mode: bool = aot_mode + self._use_relative_path: bool = use_relative_path + self._compile_only: bool = compile_only + self._precompiling: bool = precompiling + self._preprocessing: bool = preprocessing + + def _process_compile_only_options(self) -> None: + if self._compile_only: + self._libraries_dirs = [] + self._libraries = [] + + def _remove_duplicate_options(self) -> None: + self._definitions = _remove_duplication_in_list(self._definitions) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthrough_args = _remove_duplication_in_list(self._passthrough_args) + + def _finalize_options(self) -> None: + self._process_compile_only_options() + self._remove_duplicate_options() + + def get_compiler(self) -> str: + return self._compiler + + def get_definitions(self) -> list[str]: + return self._definitions + + def get_include_dirs(self) -> list[str]: + return self._include_dirs + + def get_cflags(self) -> list[str]: + return self._cflags + + def get_ldflags(self) -> list[str]: + return self._ldflags + + def get_libraries_dirs(self) -> list[str]: + return self._libraries_dirs + + def get_libraries(self) -> list[str]: + return self._libraries + + def get_passthrough_args(self) -> list[str]: + return self._passthrough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_relative_path(self) -> bool: + return self._use_relative_path + + def get_compile_only(self) -> bool: + return self._compile_only + + def get_precompiling(self) -> bool: + return self._precompiling + + def get_preprocessing(self) -> bool: + return self._preprocessing + + def save_flags_to_json(self, file: str) -> None: + attrs = { + "compiler": self.get_compiler(), + "definitions": self.get_definitions(), + "include_dirs": self.get_include_dirs(), + "cflags": self.get_cflags(), + "ldflags": self.get_ldflags(), + "libraries_dirs": self.get_libraries_dirs(), + "libraries": self.get_libraries(), + "passthrough_args": self.get_passthrough_args(), + "aot_mode": self.get_aot_mode(), + "use_relative_path": self.get_use_relative_path(), + "compile_only": self.get_compile_only(), + } + + with open(file, "w") as f: + json.dump(attrs, f) + + +def _get_warning_all_cflag(warning_all: bool = True) -> list[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]: + if _IS_WINDOWS: + """ + On Windows, only c++20 can support `std::enable_if_t`. + Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950 + Note: + Only setup c++20 for Windows inductor. 
I tried to upgrade all project to c++20, but it is failed: + https://github.com/pytorch/pytorch/pull/131504 + """ + std_num = "c++20" + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: + if _IS_WINDOWS: + cflags = [ + "wd4819", + "wd4251", + "wd4244", + "wd4267", + "wd4275", + "wd4018", + "wd4190", + "wd4624", + "wd4067", + "wd4068", + "EHsc", + ] + else: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + ignored_optimization_argument = ( + "Werror=ignored-optimization-argument" + if config.aot_inductor.raise_error_on_ignored_optimization + else "Wno-ignored-optimization-argument" + ) + cflags.append(ignored_optimization_argument) + return cflags + + +def _get_ffast_math_flags() -> list[str]: + # ffast-math is equivalent to these flags as in + # https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468 + # however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have + # -ffast-math -fno-unsafe-math-optimizations because the flags for runtime + # are added by linking in crtfastmath.o. This is done by the spec file which + # only does globbing for -ffast-math. + flags = [ + "fno-trapping-math", + "funsafe-math-optimizations", + "ffinite-math-only", + "fno-signed-zeros", + "fno-math-errno", + ] + + if is_gcc(): + flags.append("fexcess-precision=fast") + + return flags + + +def _get_optimization_cflags( + cpp_compiler: str, min_optimize: bool = False +) -> list[str]: + if _IS_WINDOWS: + return ["O1" if min_optimize else "O2"] + else: + wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level + cflags = ( + ["O0", "g"] + if config.aot_inductor.debug_compile + else [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] + ) + cflags += _get_ffast_math_flags() + cflags.append("fno-finite-math-only") + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") + + if sys.platform != "darwin": + # on macos, unknown argument: '-fno-tree-loop-vectorize' + if _is_gcc(cpp_compiler): + cflags.append("fno-tree-loop-vectorize") + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + else: + cflags.append("march=native") + + return cflags + + +def _get_shared_cflag(do_link: bool) -> list[str]: + if _IS_WINDOWS: + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. 
+ https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + return ["DLL", "MD"] + if not do_link: + return ["fPIC"] + if platform.system() == "Darwin" and "clang" in get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + return ["shared", "fPIC"] + + +def get_cpp_options( + cpp_compiler: str, + do_link: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + min_optimize: bool = False, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + + cflags = ( + _get_shared_cflag(do_link) + + _get_optimization_cflags(cpp_compiler, min_optimize) + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_os_related_cpp_cflags(cpp_compiler) + ) + + passthrough_args.append(" ".join(extra_flags)) + + return ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. + This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. + """ + + def __init__( + self, + compile_only: bool = False, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_relative_path: bool = False, + compiler: str = "", + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + compile_only=compile_only, + use_relative_path=use_relative_path, + precompiling=precompiling, + preprocessing=preprocessing, + ) + self._compiler = compiler if compiler else get_cpp_compiler() + + ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + do_link=not (compile_only or precompiling or preprocessing), + extra_flags=extra_flags, + warning_all=warning_all, + min_optimize=min_optimize, + ) + + _append_list(self._definitions, definitions) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + _append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthrough_args, passthrough_args) + self._finalize_options() + + +def _get_glibcxx_abi_build_flags() -> list[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +def _get_torch_cpp_wrapper_definition() -> list[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"] + + +def _use_custom_generated_macros() -> list[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> list[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler: str, + aot_mode: bool, + use_relative_path: bool, +) -> tuple[list[str], 
list[str], list[str]]: + cflags: list[str] = [] + include_dirs: list[str] = [] + passthrough_args: list[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthrough_args + + if config.is_fbcode(): + # TODO(T203137008) Can we unify these flags with triton_cc_command? + cflags.append("nostdinc") + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + include_dirs.append(build_paths.sleef_include) + include_dirs.append(build_paths.openmp_include) + include_dirs.append(build_paths.python_include) + include_dirs.append(build_paths.cc_include) + include_dirs.append(build_paths.libgcc_include) + include_dirs.append(build_paths.libgcc_arch_include) + include_dirs.append(build_paths.libgcc_backward_include) + include_dirs.append(build_paths.glibc_include) + include_dirs.append(build_paths.linux_kernel_include) + include_dirs.append("include") + + if aot_mode and not use_relative_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthrough_args.append(" --rtlib=compiler-rt") + passthrough_args.append(" -fuse-ld=lld") + passthrough_args.append(f" -Wl,--script={linker_script}") + passthrough_args.append(" -B" + build_paths.glibc_lib) + passthrough_args.append(" -L" + build_paths.glibc_lib) + + return cflags, include_dirs, passthrough_args + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]: + macros: list[str] = [] + build_flags: list[str] = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. + macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro()) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode(): + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args( + include_pytorch: bool, aot_mode: bool +) -> tuple[list[str], list[str], list[str]]: + from torch.utils.cpp_extension import include_paths, TORCH_LIB_PATH + + include_dirs = include_paths() + libraries_dirs = [TORCH_LIB_PATH] + libraries = [] + if sys.platform != "darwin" and not config.is_fbcode(): + libraries = ["torch", "torch_cpu"] + if not aot_mode: + libraries.append("torch_python") + + if _IS_WINDOWS: + libraries.append("sleef") + + return include_dirs, libraries_dirs, libraries + + +def _get_python_include_dirs() -> list[str]: + include_dir = Path(sysconfig.get_path("include")) + # On Darwin Python executable from a framework can return + # non-existing /Library/Python/... 
include path, in which case + # one should use Headers folder from the framework + if not include_dir.exists() and platform.system() == "Darwin": + std_lib = Path(sysconfig.get_path("stdlib")) + include_dir = (std_lib.parent.parent / "Headers").absolute() + if not (include_dir / "Python.h").exists(): + warnings.warn(f"Can't find Python.h in {str(include_dir)}") + return [str(include_dir)] + + +def _get_python_related_args() -> tuple[list[str], list[str]]: + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_lib_path = [ + str( + ( + Path(sysconfig.get_path("include", scheme="nt")).parent / "libs" + ).absolute() + ) + ] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python_include) + + return python_include_dirs, python_lib_path + + +@functools.cache +def is_conda_llvm_openmp_installed() -> bool: + try: + command = "conda list llvm-openmp --json" + output = subprocess.check_output(command.split()).decode("utf8") + return len(json.loads(output)) > 0 + except (subprocess.SubprocessError, FileNotFoundError): + return False + + +@functools.cache +def homebrew_libomp() -> tuple[bool, str]: + try: + # check if `brew` is installed + if shutil.which("brew") is None: + return False, "" + # get the location of `libomp` if it is installed + # this is the location that `libomp` **would** be installed + # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details + libomp_path = ( + subprocess.check_output(["brew", "--prefix", "libomp"]) + .decode("utf8") + .strip() + ) + # check if `libomp` is installed + omp_available = os.path.exists(libomp_path) + return omp_available, libomp_path + except subprocess.SubprocessError: + return False, "" + + +@functools.cache +def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + +@functools.cache +def perload_icx_libomp_win(cpp_compiler: str) -> None: + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implemented more math libraries than clang, for performance proposal. + We need preload them like openmp library. 
+ """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) + + +def _get_openmp_args( + cpp_compiler: str, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]: + cflags: list[str] = [] + ldflags: list[str] = [] + include_dir_paths: list[str] = [] + lib_dir_paths: list[str] = [] + libs: list[str] = [] + passthrough_args: list[str] = [] + if _IS_MACOS: + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + cflags.append("Xclang") + cflags.append("fopenmp") + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not _is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): + libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. 
+ """ + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp_include) + + openmp_lib = build_paths.openmp_lib_so + fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" + passthrough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args + + +def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]: + macros = [] + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + return macros + + +def get_cpp_torch_options( + cpp_compiler: str, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + use_relative_path: bool, + use_mmap_weights: bool, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + + torch_cpp_wrapper_definitions = _get_torch_cpp_wrapper_definition() + use_custom_generated_macros_definitions = _use_custom_generated_macros() + + ( + sys_libs_cflags, + sys_libs_include_dirs, + sys_libs_passthrough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_relative_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthrough_args, + ) = _get_openmp_args(cpp_compiler) + + cxx_abi_passthrough_args = _get_glibcxx_abi_build_flags() + fb_macro_passthrough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights) + + definitions = ( + torch_cpp_wrapper_definitions + + use_custom_generated_macros_definitions + + isa_macros + + fb_macro_passthrough_args + + mmap_self_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthrough_args = ( + sys_libs_passthrough_args + + isa_ps_args_build_flags + + cxx_abi_passthrough_args + + omp_passthrough_args + ) + + return ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is inherited 
from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_relative_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + compiler: str = "", + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_relative_path=use_relative_path, + compiler=compiler, + min_optimize=min_optimize, + precompiling=precompiling, + preprocessing=preprocessing, + ) + + self._aot_mode = aot_mode + + ( + torch_definitions, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthrough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + use_relative_path=use_relative_path, + use_mmap_weights=use_mmap_weights, + ) + + _append_list(self._definitions, torch_definitions) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthrough_args, torch_passthrough_args) + self._finalize_options() + + +def _set_gpu_runtime_env() -> None: + if ( + config.is_fbcode() + and torch.version.hip is None + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.sdk_home + + +@functools.lru_cache(8) +def _find_libcudart_static(path: str) -> Optional[Path]: + lib_dirs = list(Path(path).rglob("libcudart_static.a")) + if lib_dirs: + return lib_dirs[0].resolve().parent + log_msg = f'"libcudart_static.a" not found under {path}' + log.info(log_msg) + return None + + +def _transform_cuda_paths(lpaths: list[str]) -> None: + # This handles two cases: + # 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs + # 2. 
Linux machines may have CUDA installed under either lib64/ or lib/ + for i, path in enumerate(lpaths): + if "CUDA_HOME" in os.environ and path.startswith(os.environ["CUDA_HOME"]): + lib_dir: Optional[Path] = _find_libcudart_static(path) + if lib_dir is None: + continue + lpaths[i] = str(lib_dir) + stub_dir = lib_dir / "stubs" + if stub_dir.exists(): + lpaths.append(str(stub_dir)) + + +def get_cpp_torch_device_options( + device_type: str, + aot_mode: bool = False, + compile_only: bool = False, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.sdk_home + + _set_gpu_runtime_env() + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths(device_type) + libraries_dirs = cpp_extension.library_paths(device_type) + if device_type == "cuda": + definitions.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode(): + libraries += ["amdhip64"] + else: + libraries += ["c10_hip", "torch_hip"] + definitions.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode(): + libraries += ["cuda"] + else: + libraries += ["c10_cuda", "cuda", "torch_cuda"] + _transform_cuda_paths(libraries_dirs) + + if device_type == "xpu": + definitions.append(" USE_XPU") + # Suppress multi-line comment warnings in sycl headers + cflags += ["Wno-comment"] + libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"] + if not find_library("ze_loader"): + raise OSError( + "Intel GPU driver is not properly installed, please follow the instruction " + "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support." + ) + + if device_type == "mps": + definitions.append(" USE_MPS") + + if config.is_fbcode(): + include_dirs.append(build_paths.sdk_include) + + if aot_mode and device_type == "cuda": + if torch.version.hip is None: + if not compile_only: + # Only add link args, when compile_only is false. + passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + if config.aot_inductor.custom_op_libs: + libraries += config.aot_inductor.custom_op_libs + + return ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppTorchDeviceOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda/xpu device related build args. 
+ """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + device_type: str = "cuda", + aot_mode: bool = False, + compile_only: bool = False, + use_relative_path: bool = False, + use_mmap_weights: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_relative_path=use_relative_path, + use_mmap_weights=use_mmap_weights, + extra_flags=extra_flags, + min_optimize=min_optimize, + precompiling=precompiling, + preprocessing=preprocessing, + ) + + device_definitions: list[str] = [] + device_include_dirs: list[str] = [] + device_cflags: list[str] = [] + device_ldflags: list[str] = [] + device_libraries_dirs: list[str] = [] + device_libraries: list[str] = [] + device_passthrough_args: list[str] = [] + + ( + device_definitions, + device_include_dirs, + device_cflags, + device_ldflags, + device_libraries_dirs, + device_libraries, + device_passthrough_args, + ) = get_cpp_torch_device_options( + device_type=device_type, aot_mode=aot_mode, compile_only=compile_only + ) + _append_list(self._definitions, device_definitions) + _append_list(self._include_dirs, device_include_dirs) + _append_list(self._cflags, device_cflags) + _append_list(self._ldflags, device_ldflags) + _append_list(self._libraries_dirs, device_libraries_dirs) + _append_list(self._libraries, device_libraries) + _append_list(self._passthrough_args, device_passthrough_args) + self._finalize_options() + + def _finalize_options(self) -> None: + super()._finalize_options() + if config.is_fbcode(): + # Re-order library search paths in case there are lib conflicts + # that also live in the FBCode python lib dir. + _, python_lib_dirs = _get_python_related_args() + assert len(python_lib_dirs) == 1, f"Python lib dirs: {python_lib_dirs}" + if python_lib_dirs[0] in self._libraries_dirs: + self._libraries_dirs.remove(python_lib_dirs[0]) + self._libraries_dirs.append(python_lib_dirs[0]) + + +def get_name_and_dir_from_output_file_path( + file_path: str, +) -> tuple[str, str]: + """ + This function help prepare parameters to new cpp_builder. + Example: + input_code: /tmp/tmpof1n5g7t/5c/c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc.cpp + name, dir = get_name_and_dir_from_output_file_path(input_code) + Run result: + name = c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc + dir = /tmp/tmpof1n5g7t/5c/ + + put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'. + CppBuilder --> get_target_file_path will format output path according OS: + Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so + Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll + """ + name_and_ext = os.path.basename(file_path) + name, _ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports multiple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the target file will output to. + 2. 
The default value is empty string, and then the use current dir as output dir. + 3. Final target file: output_dir/name.ext + """ + + @staticmethod + def __get_python_module_flags() -> tuple[str, str]: + extension = ".pyd" if _IS_WINDOWS else ".so" + output_flags = "/Fe" if _IS_WINDOWS else "-o" + return extension, output_flags + + @staticmethod + def __get_object_flags() -> tuple[str, str]: + extension = ".obj" if _IS_WINDOWS else ".o" + output_flags = "/c /Fo" if _IS_WINDOWS else "-c -o" # codespell:ignore + return extension, output_flags + + @staticmethod + def __get_precompiled_header_flags() -> tuple[str, str]: + extension = ".pch" if _IS_WINDOWS or not is_gcc() else ".gch" + output_flags = "/Fp" if _IS_WINDOWS else "-o" + return extension, output_flags + + @staticmethod + def __get_preprocessor_output_flags() -> tuple[str, str]: + extension = ".i" + output_flags = "/EP /P" if _IS_WINDOWS else "-E -P -o" + return extension, output_flags + + def __init__( + self, + name: str, + sources: Union[str, list[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definitions_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthrough_parameters_args = "" + + # When relative path is used, we need to maintain the source dir list. + self._orig_source_paths = [] + self._output_dir = "" + self._target_file = "" + + self._use_relative_path: bool = False + self._aot_mode: bool = False + + self._name = name + + # Code start here, initial self internal variables firstly. + self._build_option = BuildOption + self._compiler = BuildOption.get_compiler() + self._use_relative_path = BuildOption.get_use_relative_path() + self._aot_mode = BuildOption.get_aot_mode() + + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + self._precompiling = BuildOption.get_precompiling() + self._preprocessing = BuildOption.get_preprocessing() + # Only one of these options (if any) should be true at any given time. + assert sum((self._compile_only, self._precompiling, self._preprocessing)) <= 1 + self._do_link = not ( + self._compile_only or self._precompiling or self._preprocessing + ) + + # MSVC produces two files when precompiling: the actual .pch file, as well as an + # object file which must be linked into the final library. This class assumes + # only one output file of note, so for now we'll error out here. + assert not _IS_WINDOWS or not self._precompiling, ( + "Cannot currently precompile headers on Windows!" + ) + + if self._compile_only: + file_ext, output_flags = self.__get_object_flags() + elif self._precompiling: + file_ext, output_flags = self.__get_precompiled_header_flags() + elif self._preprocessing: + file_ext, output_flags = self.__get_preprocessor_output_flags() + else: + file_ext, output_flags = self.__get_python_module_flags() + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + relative_target_file = ( + os.path.basename(self._target_file) + if self._use_relative_path + else self._target_file + ) + if _IS_WINDOWS: + if self._preprocessing: + # The target file name is automatically determined by MSVC. 
+ self._output = output_flags + else: + self._output = f"{output_flags}{relative_target_file}" + else: + self._output = f"{output_flags} {relative_target_file}" + + if isinstance(sources, str): + sources = [sources] + + if config.is_fbcode() and (not self._aot_mode or self._use_relative_path): + # Will create another temp directory for building, so do NOT use the + # absolute path. + self._orig_source_paths = list(sources) + sources = [os.path.basename(i) for i in sources] + + if self._precompiling: + assert len(sources) == 1 + # See above; we can currently assume this is not on MSVC. + self._sources_args = f"-x c++-header {sources[0]}" + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for definition in BuildOption.get_definitions(): + if _IS_WINDOWS: + self._definitions_args += f"/D {definition} " + else: + self._definitions_args += f"-D {definition} " + + if precompiled_header := BuildOption.precompiled_header: + if _IS_WINDOWS: + log.warning( + "Precompiled header support for MSVC is currently unavailable; ignoring %s", + precompiled_header, + ) + else: + self._include_dirs_args = f"-include {precompiled_header} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f'/I "{inc_dir}" ' + else: + self._include_dirs_args += f"-I{shlex.quote(inc_dir)} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for passthrough_arg in BuildOption.get_passthrough_args(): + self._passthrough_parameters_args += f"{passthrough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler: str, + sources: str, + include_dirs_args: str, + definitions_args: str, + cflags_args: str, + ldflags_args: str, + libraries_args: str, + libraries_dirs_args: str, + passthrough_args: str, + output: str, + ) -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definitions_args} {cflags_args} " + f"{sources} {passthrough_args} {output}" + ) + if self._do_link: + cmd += f" /LD /link {libraries_dirs_args} {libraries_args} {ldflags_args}" + cmd = normalize_path_separator(cmd) + else: + cmd = ( + f"{compiler} {sources} {definitions_args} {cflags_args} " + f"{include_dirs_args} {passthrough_args} {output}" + ) + if self._do_link: + cmd += f" {ldflags_args} {libraries_args} {libraries_dirs_args}" + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + include_dirs_args=self._include_dirs_args, + definitions_args=self._definitions_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthrough_args=self._passthrough_parameters_args, + output=self._output, + ) + return command_line + + def get_target_file_path(self) -> str: + return 
normalize_path_separator(self._target_file) + + def build_fbcode_re( + self, + ) -> None: + with dynamo_timed("compile_file"): + command = self.get_command_line().split() + try: + output_path = self._target_file + # When we build remotely, we need to make sure to carefully copy any files + # that are required during the compilation process into our build directly. + # This is where all of the ATen/c10/Torch includes come from. + torch_includes_path = os.path.join(_TORCH_PATH, "include") + with tempfile.TemporaryDirectory() as tmp_dir: + # Copy everything to tmp compilation folder + shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) + for src in self._orig_source_paths: + shutil.copy(src, os.path.join(tmp_dir, os.path.basename(src))) + dest_include_path = os.path.join(tmp_dir, "include") + shutil.copytree(torch_includes_path, dest_include_path) + # Run the build + tmp_output_path = _run_build_command( + command, tmp_dir, os.path.basename(output_path) + ) + # Copy output from the build + if os.path.exists(output_path): + os.remove(output_path) + shutil.copy(tmp_output_path, output_path) + if output_path.endswith(".o"): + os.chmod(output_path, 0o644) + elif output_path.endswith(".so"): + os.chmod(output_path, 0o755) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + raise exc.CppCompileError(command, output) from e + + def build(self) -> None: + """ + It is must need a temporary directory to store object files in Windows. + After build completed, delete the temporary directory to save disk space. + """ + if self._use_relative_path: + # remote build uses relative path + return self.build_fbcode_re() + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + run_compile_cmd(build_cmd, cwd=_build_tmp_dir) + _remove_dir(_build_tmp_dir) + + def save_compile_cmd_to_cmake( + self, + cmake_path: str, + device_type: str, + ) -> None: + """ + Save global cmake settings here, e.g. compiler options. + If targeting CUDA, also emit a custom function to embed CUDA kernels. + """ + + definitions = " ".join(self._build_option.get_definitions()) + contents = textwrap.dedent( + f""" + cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + project(aoti_model LANGUAGES CXX) + set(CMAKE_CXX_STANDARD 17) + + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) + + # Set a shared library target + add_library(aoti_model SHARED) + + # Add macro definitions + target_compile_definitions(aoti_model PRIVATE {definitions}) + + # Add compile flags + target_compile_options(aoti_model PRIVATE {self._cflags_args}) + # Backend specific flags + target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) + if device_type == "cuda" and torch.version.hip is None: + from torch._inductor.codecache import _nvcc_arch_as_compile_option + + current_arch = _nvcc_arch_as_compile_option() + contents += textwrap.dedent( + f""" + enable_language(CUDA) + find_package(CUDAToolkit REQUIRED) + + find_program(OBJCOPY_EXECUTABLE objcopy) + if(NOT OBJCOPY_EXECUTABLE) + message(FATAL_ERROR "objcopy not found. 
Cannot embed fatbin as object file") + endif() + + set(KERNEL_TARGETS "") + set(KERNEL_OBJECT_FILES "") + # Function to embed a single kernel + function(embed_gpu_kernel KERNEL_NAME PTX_FILE) + set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin) + set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}}) + set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o) + set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}}) + + # --- Define UNIQUE C symbol names --- + set(SYMBOL_START __${{KERNEL_NAME}}_start) + set(SYMBOL_END __${{KERNEL_NAME}}_end) + set(SYMBOL_SIZE __${{KERNEL_NAME}}_size) + string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}}) + set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start) + set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end) + set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size) + + # --- PTX to FATBIN Command & Target --- + add_custom_command( + OUTPUT ${{FATBIN_FILE}} + COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} + -gencode arch=compute_80,code=compute_80 + -gencode arch=compute_{current_arch},code=sm_{current_arch} + DEPENDS ${{PTX_FILE}} + ) + + # --- FATBIN to Object File (.o) Command --- + add_custom_command( + OUTPUT ${{OBJECT_FILE}} + COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}} + COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents + ${{OBJECT_FILE}} + COMMAND ${{OBJCOPY_EXECUTABLE}} + --redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}} + --redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}} + --redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}} + ${{OBJECT_FILE}} + DEPENDS ${{FATBIN_FILE}} + ) + add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}}) + + # --- Add to a list for linking later --- + set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE) + set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE) + endfunction() + + """ + ) + + with open(cmake_path, "w") as f: + f.write(contents) + + def save_src_to_cmake(self, cmake_path: str, src_path: str) -> None: + # Remove the directory part of file_path + src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name + with open(cmake_path, "a") as f: + f.write(f"target_sources(aoti_model PRIVATE {src_path})\n") + + def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None: + # TODO: make this work beyond CUDA + with open(cmake_path, "a") as f: + for asm_file in asm_files: + kernel_name = Path(asm_file).name.split(".")[0] + asm_file = f"${{CMAKE_CURRENT_SOURCE_DIR}}/{Path(asm_file).name}" + contents = textwrap.dedent( + f""" + embed_gpu_kernel({kernel_name} {asm_file}) + """ + ) + f.write(contents) + f.write("add_dependencies(aoti_model ${KERNEL_TARGETS})\n") + f.write( + "target_link_libraries(aoti_model PRIVATE ${KERNEL_OBJECT_FILES})\n" + ) + + def save_link_cmd_to_cmake(self, cmake_path: str) -> None: + lflags = " ".join(self._build_option.get_ldflags()) + libs = " ".join(self._build_option.get_libraries()) + contents = textwrap.dedent( + f""" + # Add linker flags + target_link_options(aoti_model PRIVATE {lflags}) + + # Add libraries + target_link_libraries(aoti_model PRIVATE {libs}) + """ + ) + + assert os.path.exists(cmake_path), ( + f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist" + ) + with open(cmake_path, "a") as f: + f.write(contents) diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/cpu_vec_isa.py b/phivenv/Lib/site-packages/torch/_inductor/cpu_vec_isa.py new file mode 100644 index 0000000000000000000000000000000000000000..8b63f4f867cc50b0286ea5da22dbd545955e2799 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/cpu_vec_isa.py @@ -0,0 +1,447 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import os +import platform +import re +import subprocess +import sys +import warnings +from typing import Any, Callable, Union + +import torch +from torch._inductor import config + + +_IS_WINDOWS = sys.platform == "win32" + + +def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: + # ISA dry compile will cost about 1 sec time each startup time. + # Please check the issue: https://github.com/pytorch/pytorch/issues/100378 + # Actually, dry compile is checking compile capability for ISA. + # We just record the compiler version, isa options and pytorch version info, + # and generated them to output binary hash path. + # It would optimize and skip compile existing binary. + from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler + + compiler_info = get_compiler_version_info(get_cpp_compiler()) + torch_version = torch.__version__ + fingerprint = f"{compiler_info}={isa_flags}={torch_version}" + return fingerprint + + +class VecISA: + _bit_width: int + _macro: list[str] + _arch_flags: str + _dtype_nelements: dict[torch.dtype, int] + + # Note [Checking for Vectorized Support in Inductor] + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, + # making the runtime check unnecessary. 
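+
+    # A rough illustration of how this check is typically exercised (a sketch only;
+    # pick_vec_isa() and the flag/macro accessors are defined later in this file):
+    #
+    #     isa = pick_vec_isa()            # e.g. VecAVX512(), VecAVX2(), or VecNEON()
+    #     if isa:                         # __bool__ runs the cached dry-compile check
+    #         arch_flags = isa.build_arch_flags()
+    #         macros = isa.build_macro()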
+ _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE) +#include +#include +#endif + +alignas(64) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" # noqa: B950 + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self) -> int: + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float) -> int: + return self._dtype_nelements[dtype] + + def build_macro(self) -> list[str]: + return self._macro + + def build_arch_flags(self) -> str: + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + def check_build(self, code: str) -> bool: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write + from torch._inductor.cpp_builder import ( + CppBuilder, + CppTorchOptions, + normalize_path_separator, + ) + + key, input_path = write( + code, + "cpp", + extra=_get_isa_dry_compile_fingerprint(self._arch_flags), + ) + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, + ) + try: + # Check if the output file exist, and compile when not. + output_path = normalize_path_separator( + x86_isa_help_builder.get_target_file_path() + ) + if not os.path.isfile(output_path): + x86_isa_help_builder.build() + + # Check build result + subprocess.check_call( + [ + sys.executable, + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + cwd=output_dir, + stderr=subprocess.DEVNULL, + env={ + **os.environ, + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + }, + ) + except Exception: + return False + + return True + + def __bool__(self) -> bool: + return self.__bool__impl(config.cpp.vec_isa_ok) + + @functools.cache # noqa: B019 + def __bool__impl(self, vec_isa_ok) -> bool: + if vec_isa_ok is not None: + return vec_isa_ok + + if config.is_fbcode(): + return True + + return self.check_build(VecISA._avx_code) + + +@dataclasses.dataclass +class VecNEON(VecISA): + _bit_width = 128 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h + _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VEC256_WITH_SLEEF"] + _arch_flags = "" # Unused + _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8} + + def __str__(self) -> str: + if config.is_fbcode(): + return "neon" + return "asimd" # detects the presence of advanced SIMD on armv8-a kernels + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecSVE256(VecISA): + # this function can be repurposed for SVE with variable vec length + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_SVE", + "CPU_CAPABILITY_SVE256", + "AT_BUILD_ARM_VEC256_WITH_SLEEF", + "__ARM_FEATURE_BF16", + ] + _arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256" + + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + if config.is_fbcode(): + return "neon" + 
return "asimd" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = ["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecAMX(VecAVX512): + _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" + + def __str__(self) -> str: + return super().__str__() + " amx_tile" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + _amx_code = """ +#include +#include + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +extern "C" void __amx_chk_kernel() { + amx_tilecfg cfg = {0}; + _tile_loadconfig(&cfg); + _tile_zero(0); + _tile_dpbf16ps(0, 1, 2); + _tile_dpbusd(0, 1, 2); +} +""" + + @functools.cache # noqa: B019 + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): + return True + return False + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecZVECTOR(VecISA): + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] + _arch_flags = "-mvx -mzvector" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "zvector" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecVSX(VecISA): + _bit_width = 256 # VSX simd supports 128 bit_width, but aten is emulating it as 256 + _macro = ["CPU_CAPABILITY_VSX"] + _arch_flags = "-mvsx" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "vsx" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = [""] + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self) -> bool: # type: ignore[override] + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +def x86_isa_checker() -> list[str]: + supported_isa: list[str] = [] + + def _check_and_append_supported_isa( + dest: list[str], isa_supported: bool, isa_name: str + ) -> None: + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. 
+ """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + amx_tile = torch.cpu._is_amx_tile_supported() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile") + + return supported_isa + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [ + VecAMX(), + VecAVX512(), + VecAVX2(), + VecNEON(), + VecSVE256(), +] + + +def get_isa_from_cpu_capability( + capability: Union[str, None], + vec_isa_list: list[VecISA], + invalid_vec_isa: InvalidVecISA, +): + # AMX setting is not supported in eager + # VecAMX will be prioritized for selection when setting ATEN_CPU_CAPABILITY to avx512 + # TODO add sve256 support + capability_to_isa_str = { + "default": "INVALID_VEC_ISA", + "zvector": "zvector", + "vsx": "vsx", + "avx2": "avx2", + "avx512": "avx512", + } + if capability in capability_to_isa_str.keys(): + isa_str = capability_to_isa_str[capability] + if isa_str == "INVALID_VEC_ISA": + return invalid_vec_isa + for vec_isa in vec_isa_list: + if isa_str in str(vec_isa): + return vec_isa + + if capability: + warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") + + return vec_isa_list[0] + + +# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.cache +def valid_vec_isa_list() -> list[VecISA]: + isa_list: list[VecISA] = [] + if sys.platform == "darwin" and platform.processor() == "arm": + isa_list.append(VecNEON()) + + if sys.platform not in ["linux", "win32"]: + return isa_list + + arch = platform.machine() + if arch == "s390x": + with open("/proc/cpuinfo") as _cpu_info: + while True: + line = _cpu_info.readline() + if not line: + break + # process line + featuresmatch = re.match(r"^features\s*:\s*(.*)$", line) + if featuresmatch: + for group in featuresmatch.groups(): + if re.search(r"[\^ ]+vxe[\$ ]+", group): + isa_list.append(VecZVECTOR()) + break + elif arch == "ppc64le": + isa_list.append(VecVSX()) + elif arch == "aarch64": + if torch.backends.cpu.get_cpu_capability() == "SVE256": + isa_list.append(VecSVE256()) + else: + isa_list.append(VecNEON()) + + elif arch in ["x86_64", "AMD64"]: + """ + arch value is x86_64 on Linux, and the value is AMD64 on Windows. 
+ """ + _cpu_supported_x86_isa = x86_isa_checker() + isa_list.extend( + isa + for isa in supported_vec_isa_list + if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa + ) + + return isa_list + + +def pick_vec_isa() -> VecISA: + if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]): + return VecAVX2() + + _valid_vec_isa_list: list[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY + # to control CPU vec ISA + if config.cpp.simdlen is None: + return get_isa_from_cpu_capability( + os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa + ) + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa diff --git a/phivenv/Lib/site-packages/torch/_inductor/cudagraph_trees.py b/phivenv/Lib/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..8624008d2947498edc84b088b6fdb71b26f4abe1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2575 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. 
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict +from contextlib import AbstractContextManager +from enum import auto, Enum +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union + +import torch.fx +from torch import Tensor +from torch._dynamo.callback import CallbackTrigger +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch._inductor.cudagraph_utils import ( + check_for_mutation, + CheckInvariantStatus, + FunctionID, + log_cudagraph_skip_and_bump_counter, + log_data_ptr_mismatch, + maybe_warning_due_to_dynamic_shape, + ModelType, + OutputType, + PlaceholderInfo, + WrappedFunction, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils.weak import TensorWeakRef + + +if TYPE_CHECKING: + from collections.abc import Generator, Iterator, Sequence + + from torch._guards import CompileId + from torch._inductor.utils import InputType + from torch.types import _bool + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int +S = TypeVar("S", bound="StorageWeakRefWrapper") + + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled as _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def _set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + + id: int + + +def clear_cublass_cache() -> None: + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
+ """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager() -> Generator[None, None, None]: + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying() -> Generator[None, None, None]: + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording() -> Generator[None, None, None]: + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording() -> AbstractContextManager[None]: + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. + + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index: int) -> None: + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. 
+ self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self) -> None: + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldn't set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self) -> None: + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self) -> None: + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. + + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]) -> None: + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin() -> None: + "Indicates that a new iteration of inference or training is about to begin." 
+ + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees() -> None: + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local: Any, attr_name: str) -> Any: + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int) -> TreeManagerContainer: + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists: bool = True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def is_cudagraph_capture_sizes(int_key: Union[int, tuple[int, ...]]) -> bool: + """ + Returns true if all dynamic shapes should be captured or the dynamic shape + int_key should be captured. + """ + return ( + config.triton.cudagraph_capture_sizes is None + or int_key in config.triton.cudagraph_capture_sizes + ) + + +def cudagraphify_impl( + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int], + *args: Any, + **kwargs: Any, +) -> ModelType: + fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + has_warn = False + + del inputs + + def deferred_cudagraphify(inputs: list[InputType]) -> OutputType: + nonlocal has_warn + + int_key = get_ints(inputs) + + if not is_cudagraph_capture_sizes(int_key): + return model(inputs) + + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + log.info("recording cudagraph tree for symint key %s", int_key) + + if not has_warn: + has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + # cudagraph will already clones input locally, no need to copy back + mutated_input_idxs: OrderedSet[int] = OrderedSet() + fn = align_inputs_from_check_idxs( + fn, inputs_to_check=check_input_idxs, mutated_input_idxs=mutated_input_idxs + ) + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +@contextlib.contextmanager +def dynamo_timed_cudagraph( + name: str, + compile_id: Optional[CompileId], + 
mode: Optional[CompilationMode], +) -> Generator[Any, None, None]: + """ + Makes usages of dynamo_timed in this file less verbose. NOTE: This CM sums + all durations into a single column in the dynamo_compile table. Use only if + you consider the timed region to be part of the runtime overhead associated + with the compiler. + """ + with dynamo_timed( + name, + log_pt2_compile_event=True, + compile_id=compile_id, + is_backward=mode == CompilationMode.BACKWARD, + dynamo_compile_column_us="runtime_cudagraphify_time_us", + ): + yield + + +def cudagraphify( + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: tuple[torch.Tensor, ...] = (), + placeholders: tuple[PlaceholderInfo, ...] = (), + mutated_input_idxs: tuple[int, ...] = (), + compile_id: Optional[CompileId] = None, +) -> tuple[ModelType, OutputType]: + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + with dynamo_timed_cudagraph("cudagraphify.get_container", compile_id, mode): + manager = get_container(device_index).get_tree_manager() + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + placeholders, + mutated_input_idxs, + compile_id, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> None: + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. 
+ """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr( + cls: type[StorageWeakRefWrapper], + cdata: Any, + data_ptr: int, + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> StorageWeakRefWrapper: + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata: Any) -> None: + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self) -> None: + self.extra_ref_check = None + + def expired(self) -> bool: + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + # if extra_ref_check is not None we expect an additional reference + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + return (stor_count - (self.extra_ref_check is not None)) == 0 + + def __repr__(self) -> str: + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def _use_cuda_memory_pool_manager( + device: int, mem_pool: tuple[int, int], stream: torch.cuda.Stream +) -> Generator[None, None, None]: + """ + Context manager to use cuda graph pool for new allocations. If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. + """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + # Begin allocate to mem pool for all memory allocation on the current thread. + # This is thread safe since a thread can only warmup or record 1 cudagraph + # at the same time. 
+ torch._C._cuda_beginAllocateCurrentThreadToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = list[list[bool]] + +StackTraces = list[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. + + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. + - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]], + cuda_graphs_pool: tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + id: GraphID, + ) -> None: + self.wrapped_function = wrapped_function + self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: list[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + self.id = id + + def run(self, new_inputs: Any) -> OutputType: + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. We should only be returning freshly created + # storages in path_live_weakrefs. 
+ existing_path_data_ptrs = OrderedSet( + [t.data_ptr() for t in self.path_live_weakrefs() if t()] + ) + + def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]: + non_cudagraph_inps = [ + weakref.ref(t.untyped_storage()) + for t in itertools.chain(new_inputs, self.wrapped_function.constants) + if isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ] + return non_cudagraph_inps + + non_cudagraph_inps_storages = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with ( + torch.cuda.device(self.device_index), + disable_conv_cache_emptying(), + clear_cublas_manager(), + _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), + get_history_recording(), + ): + out = self.wrapped_function.model(new_inputs) + + # We need to know which outputs are allocated within the cudagraph pool + # so that we can deallocate them at the beginning of the next cudagraph step, + # and set their access to error. + # We use a weakref to the inputs storage, in case a block which was previously + # allocated to the general caching allocator pool gets reallocated to a private pool. + + non_cudagraph_inps_storage_ptrs = OrderedSet[Any]() + for storage in non_cudagraph_inps_storages: + s = storage() + if s is not None: + non_cudagraph_inps_storage_ptrs.add(s._cdata) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o: Any) -> bool: + return ( + isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs) + + return out + + @property + def _path_from_root( + self, + ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]: + nodes = [] + node: Union[CUDAGraphNode, CUDAWarmupNode] = self + while node: + nodes.append(node) + node = node.parent # type: ignore[assignment] + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output # type: ignore[misc] + + def all_outputs_are_dead(self) -> bool: + return not list(self.path_live_weakrefs()) + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + for storage_weak_ref in self.path_live_weakrefs(): + if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr(): + return True + return False + + +# Aliases for List that say what the indices denote +InputList = list # input indexes +OutputList = list # output indexes +LevelList = list # levels (distance from root of tree) + + +class OutputAliasInfo: + pass + + +class _UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the 
graph output aliases an output of a prior graph" + + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex) -> None: + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index: int) -> None: + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. + + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. + + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: list[InputType], + cuda_graphs_pool: tuple[int, int], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + mode: Optional[CompilationMode], + compile_id: Optional[CompileId], + ) -> None: + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # Enable re-record a cudagraph when static tensor address changed. + # if not we should error when it changed. + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + or torch._inductor.config.triton.cudagraph_support_input_mutation + ) + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. 
+ + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[Optional[StackTraces]] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: list[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + # (depth, offset) of live tensors which are alias of previous graph outputs + self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [ + ( + self._is_alias_of_live_recorded_tensor(t) + if isinstance(t, torch.Tensor) + else None + ) + for t in inputs + ] + + # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput + # and also aliases an output of the current CUDAGraphNode + self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs) + + self.static_input_idxs: list[int] = list( + OrderedSet(wrapped_function.static_input_idxs) + | OrderedSet(self.cudagraph_managed_idxs) + ) + + self.non_static_input_idx: LevelList[int] = [ + i for i in range(len(inputs)) if i not in self.static_input_idxs + ] + + counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len( + self.non_static_input_idx + ) + + self.non_managed_static_input_idxs: LevelList[int] = [ + i + for i in wrapped_function.static_input_idxs + if i not in self.cudagraph_managed_idxs + ] + + def maybe_get_static_data_ptr( + idx: int, + inputs: list[InputType], + static_input_idxs: list[int], + ) -> Optional[int]: + inp = inputs[idx] + if isinstance(inp, torch.Tensor) and idx in static_input_idxs: + return inp.data_ptr() + return None + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + maybe_get_static_data_ptr(i, inputs, self.static_input_idxs) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. 
+ + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: list[list[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of Tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. + self.expected_dead_indices_before_graph: list[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: list[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: list[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)] + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # TODO: register_generator_state should potentially take explicit device + with torch.cuda.device(self.device): + for rng_state in rng_states: + self.graph.register_generator_state(rng_state) + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: list[InputType] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! + # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. 
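+ # Rough call sequence, mirroring record_function in CUDAGraphTreeManager further below
+ # (argument names abbreviated):
+ #   node = CUDAGraphNode(wrapped_fn, graph_id, parent, inputs, pool, dev, traces, stream, mode, cid)
+ #   outputs = node.run_first_inputs(inputs)  # inputs were consumed during recording
+ #   ... later invocations go through node.run(new_inputs), which copies inputs and replays the graph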
+ + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. + # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. + self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + with dynamo_timed_cudagraph("CUDAGraphNode.record", compile_id, mode): + self.recording_outputs: Optional[OutputType] = self._record( + wrapped_function.model, recording_inputs + ) + self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. 
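+ # Each entry of outputs_metadata is either an int, None, or the dict built by
+ # _tensor_metadata further below, roughly:
+ #   {"nbytes": ..., "data_ptr": ..., "size": ..., "stride": ..., "dtype": ...,
+ #    "device": ..., "storage_offset": ...}
+ # which is enough to rebuild each output tensor over pool memory at replay time.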
+ assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_inputs_and_remove_from_src( + self, dsts: list[InputType], srcs: list[InputType] + ) -> None: + dst_tensors = [] + src_tensors = [] + for idx in self.non_static_input_idx: + if not isinstance(srcs[idx], torch.Tensor): + continue + expanded_dims = self.expanded_dims[idx] + dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims)) # type: ignore[arg-type] + src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims)) # type: ignore[arg-type] + srcs[idx] = None # type: ignore[call-overload] + # Fails on empty lists + if dst_tensors: + torch._foreach_copy_(dst_tensors, src_tensors) + + def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None: + # avoid checking managed tensor static points since we already checked those in check_invariants + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) + ): + # this should error + error_msg = log_data_ptr_mismatch( + self.wrapped_function.placeholders, + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + CheckInvariantStatus.StaticInputIdxMismatch, + ) + torch._check(False, lambda: error_msg) + + def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType: + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + assert outputs is not None + return outputs + + def run(self, new_inputs: list[InputType]) -> OutputType: + self.check_static_inputs_are_stable(new_inputs) + + self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) + + self.run_graph() + + outputs = self.reconstruct_outputs() + new_inputs.clear() + + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_after_invocation() + + if config.triton.force_cudagraph_sync: + torch.cuda.synchronize() + + # Reset this to run the check in the future + self.static_inputs_stable = False + + return outputs + + def reconstruct_outputs(self) -> OutputType: + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are also cleared in checkpointing, so if we checkpoint this node + # and then execute it again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: OutputType = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # this output represents a fresh allocated tensor. + # We return the same TensorImpl from run to run to avoid overhead. 
+ # autograd.Function will reset the Autograd meta of output tensors + # as part of aot_autograd, but _backward_hooks are stored on tensors separately, + # so we need to manually reset hooks. + if cached_t._backward_hooks is not None: + cached_t._backward_hooks = None + + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[dict[str, Any], int, None], + ) -> Union[UntypedStorage, None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> list[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self) -> None: + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self) -> bool: + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType: + "Record the model" + + def static_input_iter() -> Generator[torch.Tensor, None, None]: + for i in self.wrapped_function.static_input_idxs: + _inp = inputs[i] + if isinstance( + _inp, torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(_inp): + yield _inp + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isn't set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with ( + 
preserve_rng_state(), + torch.cuda.device(self.device), + clear_cublas_manager(), + torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), + get_history_recording(), + ): + static_outputs = model(inputs) + + # running model should reclaim memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + return static_outputs + + def _add_first_outputs( + self, + outputs: OutputType, + static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper], + ) -> None: + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + ( + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ), + ) + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + + for idx, inp_path_ref in enumerate( + self.live_cudagraph_managed_path_refs + ): + if path_ref == inp_path_ref: + self.preserved_aliased_inputs[idx] = True + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len(outputs), ( + "Wrong number of stack traces passed in" + ) + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + 
self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None: + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self) -> None: + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) # type: ignore[arg-type] + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. + # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. 
+ + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i: int) -> bool: + self_loc = self_ref() + if self_loc is None: + return False + return self_loc.get_output_refcount(i) == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index: int) -> int: + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self) -> Optional[CUDAGraphNode]: + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent # type: ignore[assignment] + + @property + def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. + data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + _storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: list[PathOutputIndex], + output_refs: list[list[Optional[StorageWeakRefWrapper]]], + ) -> bool: + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None: + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: list[list[bool]], curr: list[list[bool]] + ) -> list[PathOutputIndex]: + "Find indices where the two lists differ." 
+ dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: list[list[Optional[StorageWeakRefWrapper]]], + ) -> list[list[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex] + ) -> None: + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = OrderedSet[Any]() + live_storage_weak_ptrs = OrderedSet[Any]() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> list[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. 
+ """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] # type: ignore[index] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self) -> None: + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self) -> None: + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self) -> None: + "Clear the path state in this current executing node" + # this doesn't actually do anything right now, leaving it as placeholder + + @staticmethod + def _tensor_metadata( + x: torch.Tensor, ignore_storage_offset: bool = True + ) -> dict[str, Any]: + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make the storage resizable ? + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type] + + def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage: + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs: list[InputType] + ) -> list[InputType]: + """ + Allocate inputs for non static, non cudagraph managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: list[InputType] = [] + + with ( + warnings.catch_warnings(record=True), + torch.cuda.device(self.device), + _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ), + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, (int, torch.Generator)) + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + else: + recording_inputs.append(inp) + + self._copy_inputs_and_remove_from_src(recording_inputs, inputs) + + return recording_inputs + + def check_invariants( + self, inputs: list[InputType] + ) -> tuple[CheckInvariantStatus, Callable[..., str]]: + """ + Checks if this node can be run. 
The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. + """ + + _logger = functools.partial( + log_data_ptr_mismatch, + self.wrapped_function.placeholders, + inputs, + self.static_input_data_ptrs, + ) + + # previously managed data pointers remain stable + # this is on the hot path so moved to C++. equivalent to: + # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs)) + if not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.cudagraph_managed_idxs, + ): + status = CheckInvariantStatus.CudagraphManagedIdxMismatch + _logger = functools.partial( + _logger, + self.cudagraph_managed_idxs, + status, + ) + return status, _logger + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch + return status, lambda: f"{status}" + + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.static_input_idxs, + ) + ): + status = CheckInvariantStatus.StaticInputIdxMismatch + _logger = functools.partial( + _logger, + self.static_input_idxs, + status, + ) + return status, _logger + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. + for idx in self.cudagraph_managed_idxs: + if not self.preserved_aliased_inputs[idx]: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. 
Please file an issue.", + ) + return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}" + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any: + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]: + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames: list[Any]) -> str: + formatted_traceback = [ + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + for entry in frames + ] + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool( + device: int, + pool_id: tuple[int, int], + live_storages_ptrs: list[StorageWeakRefWrapper], +) -> None: + assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, unique_storages): + return + + # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead, + # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages + gc.collect() + torch.cuda.synchronize() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + if len(allocated_not_in_live_storages) != 0: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. + """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + and checks required invariants, and manages warmups of graphs. 
+ + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor lifespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are part way through training + your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup in the cuda graph memory pool. As with recording, warmup needs the state of live tensors + to be accurately reflected, so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int) -> None: + # roots are functions which have no dependencies on another node. I.e., + # when they are first invoked, none of their inputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency. + self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {} + + self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: OrderedSet[FunctionID] = OrderedSet() + torch._C._set_cached_tensors_enabled(True) + + # warn only once if a function mutates inputs + self.warned_mutation: OrderedSet[FunctionID] = OrderedSet() + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would have used the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with ( + warnings.catch_warnings(record=True), + torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ), + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # mapping from graph_id to (function id to mutation type hint) since we are + # specializing on a particular combination of Parent Node -> Function ID.
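+ # e.g. (hypothetical values) {GraphID(3): {FunctionID(7): True}} records that when
+ # function 7 runs as a child of graph 3 it mutates inputs that are neither static nor
+ # cudagraph-managed, so that combination falls back to running the eager function.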
+ self.non_cudagraph_managed_mutation_hint: dict[ + Optional[GraphID], dict[FunctionID, bool] + ] = defaultdict(dict) + self.warmup_node_counter = itertools.count(start=-1, step=-1) + + # mapping from graph_id to (function id to re-record count). We fall back to + # eager function if a function is re-recorded frequently on a node. + self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict( + lambda: defaultdict(lambda: 0) + ) + + # whether the current node is in a state of warmup, recording, or execution. If + # there is no current node the state will be ExecutionState.NONE. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function. Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocator. If the generation has been incremented, + # this will also be None, as we ignore those allocations. + self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. We are willing to ignore live outputs + # of a previous generation when checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: dict[FunctionID, CompilationMode] = {} + self.id_to_compile_id: dict[FunctionID, Optional[CompileId]] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for all backwards to have been invoked. Instead we wait + # for a single backward invocation. Triggering a backward pass typically doesn't lead to + # another torch.compile invocation, making it less likely for the generation to increase + # between multiple backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...)
+ # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + self.mode: Optional[CompilationMode] = None + + self.disable_invalidate_aliases = ( + False + if not torch._environment.is_fbcode() + else torch._utils_internal.justknobs_check( + "pytorch/inductor:disable_cudagraph_alias_invalidation" + ) + ) + + def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType: + assert self.graph is not None, "Running CUDAGraph after shutdown" + self.mode = self.id_to_mode[function_id] + self.compile_id = self.id_to_compile_id[function_id] + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + if self.mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif self.mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self) -> None: + self.running_forwards_with_pending_backwards = False + self.mode = CompilationMode.BACKWARD + + def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]: + return ( + self.current_node._is_cuda_graph_recorded_tensor + if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)) + else lambda _: False + ) + + def new_warmup_node_id(self) -> GraphID: + return GraphID(next(self.warmup_node_counter)) + + def _update_non_cudagraph_managed_mutation( + self, function_id: FunctionID, inputs: list[InputType] + ) -> None: + node_id = self._get_node_id() + if maybe_mutation_str := check_for_mutation( + self.ids_to_funcs[function_id], + inputs, + self._get_cuda_graph_recorded_tensor_checker(), + ): + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True + # warn once per function_id + if function_id in self.warned_mutation: + return + self.warned_mutation.add(function_id) + log_cudagraph_skip_and_bump_counter(maybe_mutation_str) + else: + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False + + def _get_node_id(self) -> Optional[GraphID]: + if self.current_node is None: + return None + elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)): + return self.current_node.id + else: + raise RuntimeError(f"Unknown node type {type(self.current_node)}") + + def exceed_rerecord_limit( + self, node_id: Optional[GraphID], function_id: FunctionID + ) -> bool: + if torch._dynamo.config.inline_inbuilt_nn_modules: + return False + + return ( + self.num_rerecord[node_id][function_id] + > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit + ) + + def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType: + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + node_id = self._get_node_id() + if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + + # Early exit if the function mutates inputs which are neither parameters/buffers nor + # cudagraph recorded tensors. This check should happen after `try_end_curr_recording` + # and `try_end_curr_warmup` which may change self.current_node. 
+ if self.non_cudagraph_managed_mutation_hint[node_id][ + function_id + ] or self.exceed_rerecord_limit(node_id, function_id): + return self.ids_to_funcs[function_id].model(new_inputs) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) + or self.in_warmup + or config.triton.force_cudagraphs_warmup + ): + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. + # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + assert not isinstance(self.current_node, CUDAWarmupNode) + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + unexpected_rerecord, unexpected_rerecord_reason = False, lambda: "" + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + status, status_logger = child.check_invariants(new_inputs) + if status == CheckInvariantStatus.SUCCESS: + return self.execute_node(child, new_inputs) + + if ( + status == CheckInvariantStatus.StaticInputIdxMismatch + or status == CheckInvariantStatus.CudagraphManagedIdxMismatch + ): + unexpected_rerecord = True + unexpected_rerecord_reason = status_logger + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][ + function_id + ]: + return self.ids_to_funcs[function_id].model(new_inputs) + + # nb: run before checkpointing because checkpointing is slow, and we will + # be using the eager caching allocator pool which does not require live + # accounting of tensors in cudagraph allocator + if unexpected_rerecord: + curr_node_id = self._get_node_id() + self.num_rerecord[curr_node_id][function_id] += 1 + if self.exceed_rerecord_limit(curr_node_id, function_id): + _id = curr_node_id.id if curr_node_id else None + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraph due to function {function_id.id} exceeding max " + f"re-recording limit " + f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) " + f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}." 
+ ) + return self.ids_to_funcs[function_id].model(new_inputs) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self) -> None: + """ + Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. + """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function( + self, new_inputs: list[InputType], function_id: FunctionID + ) -> OutputType: + assert not isinstance(self.current_node, CUDAWarmupNode) + with torch._dynamo.callback_handler.install_callbacks( + CallbackTrigger.CUDAGRAPH_RECORDING, str(self.compile_id) + ): + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + self.mode, + self.compile_id, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node( + self, node: CUDAGraphNode, new_inputs: list[InputType] + ) -> OutputType: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager( + self, new_inputs: list[InputType], function_id: FunctionID + ) -> OutputType: + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + self.new_warmup_node_id(), + ) + self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int], + stack_traces: Optional[StackTraces], + mode: CompilationMode, + constants: tuple[torch.Tensor, ...], + placeholders: tuple[PlaceholderInfo, ...], + mutated_input_idxs: tuple[int, ...], + 
compile_id: Optional[CompileId], + ) -> tuple[ + ModelType, + OutputType, + ]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + list(static_input_idxs), + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + placeholders, + mutated_input_idxs, + ) + self.id_to_mode[id] = mode + self.id_to_compile_id[id] = compile_id + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self) -> bool: + return self.path_state == ExecutionState.RECORDING + + @property + def in_warmup(self) -> bool: + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in self.roots.values(): + yield from nodes + + @property + def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]: + return self._current_node + + @current_node.setter + def current_node( + self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] + ) -> None: + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self) -> None: + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step() -> bool: + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def in_new_torch_compile_invocation(self) -> bool: + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. + """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. 
+ """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID) -> None: + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + assert self.current_node is not None + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None: + "Warn if we in a potential loop where we are unable to hit fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + assert self.current_node is not None + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = OrderedSet( + [ + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, (self.current_node,)) + if n.parent is not None + ] + ) + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + @staticmethod + def format_dealloc_msg(stack_trace: Optional[str]) -> str: + stack_trace = ( + stack_trace.strip() if stack_trace else "[Could not find stack trace]" + ) + return ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + + def dealloc_current_path_weakrefs(self) -> None: + assert self.current_node is not None + # TODO: we could also allow the these weak refs to continue to be allocated, + # but that adds some complications. + + stor_stack_trace: dict[int, Optional[str]] = {} + for node in self.current_node._path_from_root: + assert node.stack_traces is not None + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + torch._C._set_storage_access_error_msg( + ten, self.format_dealloc_msg(stack_trace) + ) + + # we would to enable the following assertion, but an internal model failed with a command + # that does not repro. 
len(node.outputs_weakrefs) == len(node.stack_traces) + # so, pessimistically assume that they might differ by doing the debug info + # loop separately from the dealloc loop + if self.disable_invalidate_aliases: + continue + + for storage_ref, stack_trace in zip( + node.outputs_weakrefs, node.stack_traces + ): + if not storage_ref: + continue + + stor_stack_trace[storage_ref.data_ptr()] = stack_trace + + deleted = OrderedSet[Any]() + for storage_ref in self.current_node.path_live_weakrefs(): + _storage_deref = storage_ref() + if _storage_deref and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + + msg = self.format_dealloc_msg( + stor_stack_trace.get(storage_ref.data_ptr()) + ) + torch._C._free_And_Remove_DeleterFn(_storage_deref) + + if self.disable_invalidate_aliases: + continue + + torch._C._set_storage_data_ptr_access_error_msg(_storage_deref, msg) + + def clear_current_path_state_and_set_to_none(self) -> None: + assert isinstance(self.current_node, CUDAGraphNode) + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self) -> None: + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existent live storages. + """ + assert isinstance(self.current_node, CUDAGraphNode) + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: list[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + # path_live_weakrefs guarantees that t() will not be None + live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers] # type: ignore[misc] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages, live_storages_weak_refs + ) + + # NB: deduplicate aliased outputs + for ptr in OrderedSet(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + storage_ptr = wrapper() + assert storage_ptr is not None + assert torch._C._has_Standard_Deleter(storage_ptr) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> list[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + # path_live_weakrefs() guarantees that t() will not be None + return [t() for t in self.current_node.path_live_weakrefs()] # type: ignore[misc] diff --git a/phivenv/Lib/site-packages/torch/_inductor/cudagraph_utils.py b/phivenv/Lib/site-packages/torch/_inductor/cudagraph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1659125614e56125a7971d213311172deb1c5c6c --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/_inductor/cudagraph_utils.py @@ -0,0 +1,414 @@ +# mypy: disallow-untyped-defs +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch._dynamo.utils import counters, get_metrics_context +from torch._inductor.utils import GraphPartitionMap, InputType +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +OutputType = list[Optional[Union[int, torch.Tensor]]] +ModelType = Callable[[list[InputType]], OutputType] + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + + id: int + + +@dataclasses.dataclass(frozen=True) +class PlaceholderInfo: + """ + A serializable version of torch.fx.Node that contains information + pertinent to placeholder stack traces. We use these in logging and error messages + related to cudagraphs, and will cache these results. + """ + + name: str + stack_trace: Optional[str] + # This field is recursive, but never cyclic (since a node never uses itself) + users: list[PlaceholderInfo] + mutating_use_stack_trace: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: tuple[torch.Tensor, ...] 
+ placeholders: Sequence[PlaceholderInfo] + mutated_input_idxs: Sequence[int] + + +def get_mutating_use_stack_trace_from_node( + placeholder_node: torch.fx.Node, +) -> Optional[str]: + # reinplaced uses might have a single, non-copy_ use + if len(placeholder_node.users) == 1: + return next(iter(placeholder_node.users)).meta.get("stack_trace", None) + + for use in placeholder_node.users: + if use.target == torch.ops.aten.copy_.default: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + + return None + + +def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]: + return placeholder_info.mutating_use_stack_trace + + +def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo: + name = placeholder_node.name + stack_trace = placeholder_node.meta.get("stack_trace", None) + users = [] + mutating_use_stack_trace = None + # Only recurse to users once, since we only care about user's stack traces + if placeholder_node.op == "placeholder": + users = [to_placeholder_info(i) for i in placeholder_node.users] + mutating_use_stack_trace = get_mutating_use_stack_trace_from_node( + placeholder_node + ) + + return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace) + + +def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]: + return [ + to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder" + ] + + +def format_default_skip_message(reason: str) -> str: + return f"skipping cudagraphs due to {reason}" + + +def get_mutation_stack_trace( + placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int] +) -> str: + stack_trace: Optional[str] = "" + + for idx in mutation_indices: + placeholder = placeholders[idx] + if stack_trace := get_mutating_use_stack_trace(placeholder): + break + + msg = format_default_skip_message( + f"mutated inputs ({len(mutation_indices)} instances)" + ) + if stack_trace: + return f"{msg}. Found from : \n {stack_trace}" + + return msg + + +def check_for_mutation( + func: WrappedFunction, + inputs: list[InputType], + is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], +) -> Optional[str]: + # doesn't work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + # checking if mutation is only on parameters/static inputs + mutation_indices: Sequence[int] = [ + idx + for idx in func.mutated_input_idxs + if not ( + idx in func.static_input_idxs + or is_cuda_graph_recorded_tensor(inputs[idx]) # type: ignore[arg-type] + ) + ] + else: + mutation_indices = func.mutated_input_idxs + + static_inputs_log.debug( + "check mutation static input indices: %s", func.static_input_idxs + ) + static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices) + + return ( + get_mutation_stack_trace(func.placeholders, mutation_indices) + if mutation_indices + else None + ) + + +def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]: + for use in node.users: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + return None + + +def check_multiple_devices_or_any_cpu_nodes( + device_node_mapping: dict[torch.device, torch.fx.Node], +) -> Optional[str]: + # meta tensors are supported since there is no compute + device_node_mapping.pop(torch.device("meta"), None) + + if torch._inductor.config.graph_partition: + # graph partition supports splitting on cpu op. So we can ignore cpu nodes. 
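[Editor note] A hedged illustration of the situation check_for_mutation guards against (exact log text and fallback behavior can vary by version): an in-place mutation of a graph input that is neither a parameter nor a recorded static input is what produces the "mutated inputs" skip message assembled above.

import torch

def step(x):
    x.add_(1)   # mutates a non-static graph input
    return x * 2

compiled = torch.compile(step, mode="reduce-overhead")
x = torch.ones(4, device="cuda")
# Expected to run without CUDA graphs and log a perf hint along the lines of
# "skipping cudagraphs due to mutated inputs ...", built by get_mutation_stack_trace.
compiled(x)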
+ device_node_mapping.pop(torch.device("cpu"), None) + + if cpu_node := device_node_mapping.get(torch.device("cpu")): + msg = f"cpu device ({cpu_node.name})" + if stack_trace := _get_use_stack_trace(cpu_node): + return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}") + + return format_default_skip_message(msg) + + if ( + len(device_node_mapping) == 1 + and next(iter(device_node_mapping.keys())).type == "cuda" + ): + return None + + keys_repr = (repr(key) for key in device_node_mapping.keys()) + return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") + + +def check_lowering_disable_cudagraph( + device_node_mapping: dict[torch.device, torch.fx.Node], +) -> Optional[str]: + return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) + + +def log_cudagraph_skip_and_bump_counter(msg: str) -> None: + perf_hint_log.warning(msg) + counters["inductor"]["cudagraph_skips"] += 1 + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set("cudagraph_skip_reason", msg, overwrite=True) + + +@dataclasses.dataclass +class BoxedDeviceIndex: + value: Optional[int] + + def set(self, device_idx: Optional[int]) -> None: + assert device_idx is None or isinstance(device_idx, int) + self.value = device_idx + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + gm: torch.fx.GraphModule, + mutated_inputs: OrderedSet[str], + mutated_input_idxs: OrderedSet[int], + static_input_idxs: Sequence[int], +) -> Optional[str]: + default_msg = format_default_skip_message("mutated inputs") + + # doesn't work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = OrderedSet(static_input_idxs) + # checking if mutation is only on parameters/static inputs + mutation_indices = [idx for idx in mutated_input_idxs if idx not in unique_idxs] + has_mutation = len(mutation_indices) != 0 + if not has_mutation: + return None + placeholders = get_placeholder_info(gm.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + else: + has_mutation = len(mutated_inputs) != 0 + return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. 
+ """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None + + +class CheckInvariantStatus(Enum): + # Check invariant succeeded + SUCCESS = 1 + + # Previously managed data pointers are not stable + CudagraphManagedIdxMismatch = 2 + + # Static tensor input addresses are not stable + StaticInputIdxMismatch = 3 + + # Expected dead indices before graph are live + ExpectedDeadIndicesBeforeGraphMismatch = 4 + + def __str__(self) -> str: + if self.name == "CudagraphManagedIdxMismatch": + return "cudagraph managed tensor data pointer changed" + elif self.name == "StaticInputIdxMismatch": + return "static input data pointer changed" + elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch": + return "expected dead indices before graph are live" + else: + return f"{self.name}: {self.value}" + + +def log_data_ptr_mismatch( + placeholders: Sequence[PlaceholderInfo], + inputs: list[InputType], + recorded_data_ptr: Sequence[Optional[int]], + target_idxs: Sequence[int], + mismatch: CheckInvariantStatus, +) -> str: + """ + Logs the mismatch between input data pointers and recorded data pointers. + This checks only idxs in target_idxs. + """ + assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), ( + "length mismatch between inputs, recorded_data_ptr, and placeholders" + ) + + t_tensors = [inputs[i] for i in target_idxs] + t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] + error_msg = f"{mismatch}.\n" + for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)): + assert isinstance(tensor, torch.Tensor) + index = target_idxs[i] + if tensor.data_ptr() != data_ptr: + placeholder = placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. " + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + return error_msg + + +def maybe_warning_due_to_dynamic_shape( + fn_cache: dict[tuple[int, ...], Callable[..., Any]], + new_int_key: Any, +) -> bool: + num_cudagraphs = len(fn_cache.keys()) + 1 + + def warn_msg() -> str: + return ( + "CUDAGraph supports dynamic shapes by recording a new graph for each " + "distinct input size. Recording too many CUDAGraphs may lead to " + f"extra overhead. We have observed {num_cudagraphs} distinct sizes. " + "Please consider the following options for better performance: " + "a) padding inputs to a few fixed number of shapes; or b) set " + "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. " + "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None " + "to silence this warning." + ) + + if ( + torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + and num_cudagraphs + > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + ): + perf_hint_log.warning(warn_msg()) + return True + + return False + + +@dataclasses.dataclass(frozen=True) +class CudagraphCachedInfo: + """ + Info needed to realign inputs + """ + + placeholders: Sequence[PlaceholderInfo] + stack_traces: list[Optional[str]] + cudagraph_fail_reasons: list[str] + + +@dataclasses.dataclass(frozen=True) +class CudagraphMetadata: + """ + Metadata for recording a CUDA graph. 
+ """ + + placeholders: Sequence[PlaceholderInfo] + static_input_idxs: OrderedSet[int] + mutated_input_idxs: OrderedSet[int] + stack_traces: list[Optional[str]] + constants: dict[str, torch.Tensor] + + +def get_partition_cudagraph_metadata( + partition_map: GraphPartitionMap, + metadata: CudagraphMetadata, +) -> CudagraphMetadata: + """ + Convert the cudagraph metadata at the graph level to the graph partition level, + given the graph partition info (i.e., mapping from partition input/output index + to graph input/output index). + """ + + partition_placeholders = [] + partition_static_input_idxs: OrderedSet[int] = OrderedSet() + partition_mutated_input_idxs: OrderedSet[int] = OrderedSet() + for partition_input_idx, graph_input_idx in enumerate( + partition_map.input_index_mapping + ): + if graph_input_idx in metadata.static_input_idxs: + partition_static_input_idxs.add(partition_input_idx) + + if graph_input_idx in metadata.mutated_input_idxs: + partition_mutated_input_idxs.add(partition_input_idx) + + if graph_input_idx is not None: + placeholder = metadata.placeholders[graph_input_idx] + else: + # create a dummy placeholder info since this partition input is not a graph input + placeholder = PlaceholderInfo( + name=f"partition_{partition_map.id}_placeholder_{partition_input_idx}", + stack_trace=None, + users=[], + mutating_use_stack_trace=None, + ) + partition_placeholders.append(placeholder) + + partition_stack_traces = [] + for graph_output_idx in partition_map.output_index_mapping: + if graph_output_idx is not None: + partition_stack_traces.append(metadata.stack_traces[graph_output_idx]) + else: + partition_stack_traces.append(None) + + partition_constants = { + name: metadata.constants[name] for name in partition_map.constant_names + } + + return CudagraphMetadata( + partition_placeholders, + partition_static_input_idxs, + partition_mutated_input_idxs, + partition_stack_traces, + partition_constants, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/custom_graph_pass.py b/phivenv/Lib/site-packages/torch/_inductor/custom_graph_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..fabe988d17fb287b4ee1e0fefd582f71b71030cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/custom_graph_pass.py @@ -0,0 +1,104 @@ +import hashlib +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Callable, Optional, Union +from typing_extensions import TypeAlias + +import torch.fx.graph + + +class CustomGraphPass(ABC): + """ + Implement this interface for custom Graph passes: + + 1) The __call__() method contains the implementation of the custom pass. + + 2) The uuid() method enables inductor to cache compiled graphs when your custom + passes are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom passes would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom pass + implementation. That approach ensures that any change to the implementation will + mean a new uuid. 
+ + ** IMPORTANT ** If your custom pass's behavior depends on some external state, then + you'll need to implement something more complicated (or disable caching). + + EXAMPLE: + + class MyCustomGraphPass(CustomGraphPass): + def __call__(self, graph: torch.fx.graph.Graph) -> None: + # my custom graph optimization pass + # ... + + def uuid(self) -> Optional[Any]: + return get_hash_for_files((__file__,)) + + """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. Return None + to skip inductor code caching entirely. + """ + + +class CustomGraphModulePass(ABC): + """ + Implement this interface for custom Graph passes: + + 1) The __call__() method contains the implementation of the custom pass. + + 2) The uuid() method enables inductor to cache compiled graphs when your custom + passes are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom passes would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom pass + implementation. That approach ensures that any change to the implementation will + mean a new uuid. + """ + + @abstractmethod + def __call__(self, gm: torch.fx.GraphModule) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. Return None + to skip inductor code caching entirely. + """ + + +CustomGraphPassType: TypeAlias = Optional[ + Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]] +] + + +@lru_cache(1) +def get_hash_for_files(paths: tuple[str], extra: str = "") -> bytes: + """ + Helper to compute a unique string by hashing the contents of a list of files. 
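[Editor note] Beyond the EXAMPLE in the docstring above, a custom pass still has to be handed to inductor to run. A hedged sketch follows; post_grad_custom_post_pass is believed to be the inductor config hook accepting a CustomGraphPassType, but the exact hook name and placement are version-dependent assumptions:

import torch
import torch._inductor.config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files

class NoopPass(CustomGraphPass):
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        # A real pass would rewrite `graph` in place here.
        pass

    def uuid(self):
        # Tie cache validity to this source file, as recommended above.
        return get_hash_for_files((__file__,))

# Install the pass so inductor applies it after its own post-grad passes.
inductor_config.post_grad_custom_post_pass = NoopPass()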
+ """ + hasher = hashlib.sha256() + hasher.update(extra.encode("utf-8")) + for path in paths: + with open(path, "rb") as f: + hasher.update(path.encode("utf-8")) + hasher.update(f.read()) + return hasher.digest() diff --git a/phivenv/Lib/site-packages/torch/_inductor/debug.py b/phivenv/Lib/site-packages/torch/_inductor/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f77bb8a8927f667c221abfd19355ecb282ce24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/debug.py @@ -0,0 +1,961 @@ +import collections +import contextlib +import copy +import dataclasses +import functools +import io +import itertools +import json +import logging +import os +import os.path +import pickle +import pstats +import shutil +import traceback +from collections.abc import Iterator +from typing import Any, Callable, IO, Optional, Union +from unittest.mock import patch + +import torch +from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled +from torch import fx as fx +from torch._dynamo.repro.after_aot import save_graph_repro +from torch._dynamo.utils import get_debug_dir +from torch._logging import getArtifactLogger +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.fx.passes.tools_common import legalize_graph +from torch.types import FileLike +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map + +from . import config, ir # noqa: F811, this is needed +from .scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + OutputNode, + SchedulerNode, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion") +ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion") +SchedulerNodeList = list[Any] +BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) +GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] + + +@functools.cache +def has_dot() -> bool: + return shutil.which("dot") is not None + + +def draw_buffers( + nodes: list[BaseSchedulerNode], + print_graph: bool = False, + fname: Optional[str] = None, +) -> None: + """ + Draw a graph in fname.svg. + """ + if not has_dot(): + log.warning("draw_buffers() requires `graphviz` package") + return + + if fname is None: + fname = get_graph_being_compiled() + + graph = create_fx_from_snodes(nodes) + + for node in graph.nodes: + if "fusion_meta" not in node.meta: + continue + group = node.meta["fusion_meta"].group + if isinstance(group, tuple): + if isinstance(group[1], int): + group = (group[1],) + else: + group = group[1] + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + node.meta["tensor_meta"] = metadata + + if print_graph: + print(graph) + + gm = GraphModule({}, graph) + legalize_graph(gm) + gm.graph.lint() + draw_graph( + gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape + ) + + +def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph: + """ + Creates a FX Graph from a list of SchedulerNode objects. 
+ """ + + def get_fake_func(name: str) -> Callable[..., int]: + def func1(*args: Any) -> int: + return 0 + + func1.__name__ = name + return func1 + + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) + + buf_to_fx_node = {} + node_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + outputs = [] + group: Any = None + # create call_function node for each Buffer and Kernel + for snode in snodes: + if snode.is_extern(): + node_type = "extern" + group = node_type + elif snode.is_template(): + node_type = "template" + group = node_type + elif isinstance(snode, NopKernelSchedulerNode): + node_type = "nop" + group = node_type + elif isinstance(snode, SchedulerNode): + node_type = "compute" + group = snode.group + elif isinstance(snode, FusedSchedulerNode): + node_type = "fused" + group = snode.group + else: + raise RuntimeError("Unknown node type") + + fused_name = torch._inductor.utils.get_fused_kernel_name( + snode.get_nodes(), "original_aten" + ) + func_name = f"{node_type}: {fused_name}" + node_func = get_fake_func(func_name) + kwargs = {} + if hasattr(snode, "get_device"): + kwargs = {"device": snode.get_device()} + fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type] + + def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: + if isinstance(snode, FusedSchedulerNode): + return any(in_output(x) for x in snode.snodes) + return any( + isinstance(user.node, OutputNode) + for buf in snode.get_outputs() + for user in buf.users + ) + + if in_output(snode): + outputs.append(fx_node) + name = snode.get_name() + fx_node.name = name + + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) + + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node + + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in snodes: + name = snode.get_name() + deps = snode.read_writes.reads + + fx_node = node_to_fx_node[name] + new_args = [] + for dep in deps: + if dep.name in buf_to_fx_node: + dep_node = buf_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) + buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + return graph + + +def update_orig_fx_node_name_to_buf_name( + nodes: Optional[SchedulerNodeList], + node_name_to_buf_name: dict[str, str], + parent_buf_name: Optional[str] = None, + n_origins: int = 0, +) -> None: + if nodes is None: + return + for node in nodes: + # for FusedSchedulerNode, traverse recursively into get_nodes() + buf_name = node.get_name() + children_nodes = node.get_nodes() + if children_nodes is not None and len(children_nodes) > 1: + update_orig_fx_node_name_to_buf_name( + children_nodes, + node_name_to_buf_name, + buf_name if parent_buf_name is None else parent_buf_name, + ) + continue + else: + assert len(children_nodes) == 1 and children_nodes[0] == node + + ir_node = node.node + if ir_node is None or ir_node.origins is None: + continue + for origin in ir_node.origins: + node_name = origin.name + # when buf1 and buf2 both have origin=node1 + # we draw node1 according to buf1 + if node_name not in node_name_to_buf_name: + node_name_to_buf_name[node_name] = ( + buf_name if parent_buf_name is None else parent_buf_name + ) + + +def get_node_name_to_buf_meta( + 
node_name_to_buf_name: dict[str, str], +) -> dict[str, BufMeta]: + buf_name_to_n_node = {} + for node_name, buf_name in node_name_to_buf_name.items(): + if buf_name not in buf_name_to_n_node: + buf_name_to_n_node[buf_name] = OrderedSet([node_name]) + else: + buf_name_to_n_node[buf_name].add(node_name) + + node_name_to_buf_meta = {} + for node_name, buf_name in node_name_to_buf_name.items(): + n_node = len(buf_name_to_n_node[buf_name]) + node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node) + return node_name_to_buf_meta + + +def annotate_orig_fx_with_snodes( + gm: torch.fx.GraphModule, + snodes: SchedulerNodeList, +) -> None: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + node_name_to_buf_name: dict[str, str] = {} + update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) + if node_name_to_buf_name is None: + return + node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name) + for node in gm.graph.nodes: + if node.name in node_name_to_buf_meta: + node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name) + + +@contextlib.contextmanager +def enable_aot_logging() -> Iterator[None]: + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + import torch._functorch.aot_autograd + + log = logging.getLogger(torch._functorch.aot_autograd.__name__) + + stack = contextlib.ExitStack() + if not compile_debug: + try: + yield + finally: + stack.close() + return + + # Enable all graphs to be logged to a file by setting the flags to True + # and the log level of the file logger to DEBUG + stack.enter_context(patch("functorch.compile.config.debug_partitioner", True)) + + path = os.path.join(get_debug_dir(), "torchinductor") + os.makedirs(path, exist_ok=True) + + fh = logging.FileHandler( + os.path.join( + path, + f"aot_{get_aot_graph_name()}_debug.log", + ) + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(fh) + try: + yield + finally: + log.removeHandler(fh) + stack.close() + + +# Used for provenance tracking +# They are not stored in DebugContext because they are not set in +# _inductor_triton_kernel_to_post_grad_node_info's Debug Context +_inductor_post_to_pre_grad_nodes: dict[str, Any] = {} +_pre_grad_graph_id: Optional[int] = None + + +class DebugContext: + _counter = itertools.count() + + # Used for provenance tracking + _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} + + @staticmethod + def create_debug_dir(folder_name: str) -> Optional[str]: + debug_dir = config.trace.debug_dir or get_debug_dir() + for n in DebugContext._counter: + dirname = os.path.join( + debug_dir, + "torchinductor", + f"{folder_name}.{n}", + ) + if not os.path.exists(dirname): + os.makedirs(dirname) + return dirname + return None + + def __init__(self) -> None: + self._prof = None + self._path = None + self._stack = contextlib.ExitStack() + + def copy(self, new_path: str) -> None: + if not self._path: + return + assert new_path.endswith(".debug"), new_path + from filelock import FileLock + + try: + with FileLock(f"{new_path}.lock"): + if os.path.exists(new_path): + shutil.rmtree(new_path) + shutil.copytree(self._path, new_path) + except OSError: + log.warning( + "Failed to copy debug files from %s to %s", self._path, new_path + ) + + def fopen( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> IO[Any]: + assert self._path + return open(os.path.join(self._path, filename), write_mode, *args, 
**kwargs) + + @contextlib.contextmanager + def fopen_context( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> Iterator[IO[Any]]: + assert self._path + with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f: + yield f + + def filename(self, suffix: str) -> str: + assert self._path + return os.path.join(self._path, suffix) + + def upload_tar(self) -> None: + if config.trace.upload_tar is not None: + import tarfile + + assert self._path + tar_file = os.path.join( + self._path, f"{os.path.basename(self._path)}.tar.gz" + ) + with tarfile.open(tar_file, "w:gz") as tar: + tar.add(self._path, arcname=os.path.basename(self._path)) + config.trace.upload_tar(tar_file) + + def __enter__(self) -> None: + if config.debug: + log = logging.getLogger("torch._dynamo") + prev_level = log.level + log.setLevel(logging.DEBUG) + + def reset_log_level(level: Any) -> None: + log.setLevel(level) + + self._stack.callback(reset_log_level, prev_level) + + self._stack.enter_context(V.set_debug_handler(self)) + + if not config.trace.enabled: + return + + self._path = self.create_debug_dir(get_aot_graph_name()) # type: ignore[assignment] + + if config.trace.debug_log: + self._setup_log_capture("debug.log", logging.DEBUG) + if config.trace.info_log: + self._setup_log_capture("info.log", logging.INFO) + + def _setup_log_capture( + self, + filename: str, + level: int, + ) -> None: + log = logging.getLogger("torch._inductor") + fd = self._stack.enter_context(self.fopen(filename)) + ch = logging.StreamHandler(fd) + ch.setLevel(level) + ch.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(ch) + log.setLevel(min(log.level, level)) + self._stack.callback(log.removeHandler, ch) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self._prof: + self._prof.disable() + self._save_profile_data() + + if self._path: + self.upload_tar() + log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path) + self._stack.close() + + def _save_profile_data(self) -> None: + assert self._prof + self._prof.dump_stats(self.filename("compile.prof")) + with self.fopen("compile.stats") as fd: + stats = pstats.Stats(self._prof, stream=fd) + stats.strip_dirs() + stats.sort_stats("cumtime") + stats.print_stats(100) + stats.sort_stats("tottime") + stats.print_stats(100) + + def __getattr__(self, name: str) -> Optional[Callable[..., None]]: + if config.trace.enabled and getattr(config.trace, name): + try: + return getattr(DebugFormatter(self), name) + except Exception: + log.warning("Ignoring exception in debug code", exc_info=True) + return None + else: + + def ignored(*args: Any, **kwargs: Any) -> None: + pass + + return ignored + + +class DebugFormatter: + def __init__(self, handler: DebugContext) -> None: + self.fopen = handler.fopen + self.fopen_context = handler.fopen_context + self.filename = handler.filename + self.handler = handler + + def fx_graph( + self, + gm: torch.fx.GraphModule, + inputs: list[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_runnable.py") as fd: + save_dir = None + if torch._inductor.config.trace.save_real_tensors: + inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs) + save_dir = os.path.dirname(fd.name) + + # dont try to use stable hash torchinductor compilation if saving real tensors + # and avoid recursively trying to save real tensors inside of the inductor 
compilation + # regardless + stable_hash = torch._inductor.config.trace.save_real_tensors + with torch._inductor.config.patch( + {"trace.enabled": False, "trace.save_real_tensors": False} + ): + save_graph_repro( + fd, + gm, + inputs, + "inductor", + save_dir=save_dir, + stable_hash=stable_hash, + ) + + with self.fopen("fx_graph_readable.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def fx_graph_transformed( + self, + gm: torch.fx.GraphModule, + inputs: list[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_transformed.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None: + with self.fopen("ir_pre_fusion.txt") as fd: + fd.write(self._write_ir(nodes)) + + def ir_post_fusion(self, nodes: SchedulerNodeList) -> None: + with self.fopen("ir_post_fusion.txt") as fd: + fd.write(self._write_ir(nodes)) + + @staticmethod + def _write_ir(nodes: SchedulerNodeList) -> str: + buf = io.StringIO() + for node in nodes: + buf.write(node.debug_str()) + buf.write("\n\n\n") + return buf.getvalue() + + def graph_diagram(self, nodes: SchedulerNodeList) -> None: + draw_buffers(nodes, fname=self.filename("graph_diagram.svg")) + + def draw_orig_fx_graph( + self, + gm: torch.fx.GraphModule, + nodes: SchedulerNodeList, + ) -> None: + annotate_orig_fx_with_snodes(gm, nodes) + draw_graph( + gm, + fname=self.filename("orig_fx_graph_diagram.svg"), + clear_meta=False, + prog=GRAPHVIZ_COMMAND_SCALABLE, + parse_stack_trace=True, + dot_graph_shape=config.trace.dot_graph_shape, + ) + + def output_code(self, filename: str, extension: str = "py") -> None: + shutil.copy(filename, self.filename(f"output_code.{extension}")) + + def log_inductor_triton_kernel_to_post_grad_node_info( + self, filename: str = "inductor_generated_kernel_to_post_grad_nodes.json" + ) -> tuple[dict[str, list[str]], dict[str, Any]]: + debug_info = {} + with self.fopen(filename, "w") as fd: + log.info("Writing provenance tracing debugging info to %s", fd.name) + debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info + json.dump(debug_info, fd) + node_mapping = {} + if _pre_grad_graph_id: + with self.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + node_mapping = create_node_mapping( + _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info + ) + json.dump(node_mapping, fd) + return debug_info, node_mapping + + def log_autotuning_results( + self, + name: str, + input_nodes: list[ir.IRNode], + timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 + elapse: float, + precompile_elapse: float, + prescreening_elapse: Optional[float], + ) -> None: + from .ir import FixedLayout + + def build_node_info(node: ir.IRNode) -> dict[str, str]: + if hasattr(node, "name"): + node_name = node.name + else: + node_name = "" + node_info = { + "name": node_name, + "type": type(node).__name__, + } + try: + layout = node.get_output_spec() + if isinstance(layout, FixedLayout): + offset = 0 + try: + offset = int(layout.offset) + except Exception: + try: + offset = V.graph.sizevars.size_hint( + layout.offset, fallback=0 + ) + except Exception: + pass + static_layout = FixedLayout( + layout.device, + dtype=layout.dtype, + size=[*V.graph.sizevars.size_hints(layout.size)], + stride=[*V.graph.sizevars.size_hints(layout.stride)], + offset=offset, + ) + node_info["layout"] = str(static_layout) + else: + node_info["layout"] = str(layout) + except Exception: + pass + try: + node_info["dtype"] = str(node.get_dtype()) + 
except Exception: + pass + try: + node_info["device"] = str(node.get_device()) + except Exception: + pass + try: + node_info["stride"] = str( + V.graph.sizevars.size_hints(node.get_stride()) + ) + except Exception: + pass + try: + node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size())) # type: ignore[arg-type] + except Exception: + pass + try: + node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel())) + except Exception: + pass + if hasattr(node, "data") and isinstance(node.data, ir.IRNode): + node_info["data"] = build_node_info(node.data) + return node_info + + general_properties = { + "op_name": name, + "cuda_device_name": torch.cuda.get_device_name(), + "cuda_device_count": torch.cuda.device_count(), + "input_nodes": [build_node_info(node) for node in input_nodes], + "autotuning_time": elapse, + "precompile_time": precompile_elapse, + "prescreening_time": prescreening_elapse, + } + with self.fopen_context( + "autotuning_result_json_list.txt", "at", encoding="utf-8" + ) as fd: + for caller, time in timings.items(): + info_dict = dict(caller.info_dict()) + info_dict.update(general_properties) + info_dict["benchmark_result"] = time + json.dump(info_dict, fd) + fd.write("\n") + + +def log_ir_pre_fusion(nodes: SchedulerNodeList) -> None: + if ir_pre_fusion_log.isEnabledFor(logging.INFO): + ir_pre_fusion_log.info("BEFORE FUSION\n%s", DebugFormatter._write_ir(nodes)) + + V.debug.ir_pre_fusion(nodes) + + +def log_ir_post_fusion(nodes: SchedulerNodeList) -> None: + if ir_post_fusion_log.isEnabledFor(logging.INFO): + ir_post_fusion_log.info("AFTER FUSION\n%s", DebugFormatter._write_ir(nodes)) + + V.debug.ir_post_fusion(nodes) + + +@dataclasses.dataclass +class TensorMetadataHolder: + tensor_metadata: TensorMetadata + device: torch.device + + +save_args_cnt = itertools.count() + + +def create_node_mapping( + pre_grad_graph_id: int, + post_to_pre_grad_nodes_json: dict[str, Any], + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between: + + - pre_grad graph nodes and post_grad graph code nodes, and vice versa + - triton kernel name and post_grad graph code nodes, and vice versa + """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "preToPost": {}, + "postToPre": {}, + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + log.info("Creating node mappings for provenance tracking") + + if not isinstance(post_to_pre_grad_nodes_json, dict): + log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") + return empty_return + + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + if not isinstance(pre_grad_graph_id, int): + log.error("Provenance tacking error: pre_grad_graph_id is not an int") + return empty_return + + pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) + post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) + + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) + + def check_format(node: dict[str, Any]) -> bool: + if not isinstance(node, dict): + 
log.error( + "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dict" + ) + return False + if "graph_id" not in node or "name" not in node or "from_node" not in node: + log.error( + "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong format" + ) + return False + return True + + for outer_key, node_array in post_to_pre_grad_nodes_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: post_to_pre_grad_nodes_json value is not a list" + ) + return empty_return + for node in node_array: + if not check_format(node): + return empty_return + # Check the current node first + if node.get("graph_id") == pre_grad_graph_id: + pre_to_post[node["name"]].add(outer_key) + post_to_pre[outer_key].add(node["name"]) + + # Check nested from_node array recursively, add node with the right graph_id to the map + stack = [(n, outer_key) for n in node.get("from_node", [])] + while stack: + current_node, parent_key = stack.pop() + if not check_format(current_node): + return empty_return + if current_node.get("graph_id") == pre_grad_graph_id: + pre_to_post[current_node["name"]].add(parent_key) + post_to_pre[parent_key].add(current_node["name"]) + stack.extend( + (n, parent_key) for n in current_node.get("from_node", []) + ) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(pre_to_post) + convert_sets_to_lists(post_to_pre) + convert_sets_to_lists(post_to_cpp_code) + return { + "preToPost": pre_to_post, + "postToPre": post_to_pre, + "cppCodeToPost": triton_kernel_to_post_grad_json, + "postToCppCode": post_to_cpp_code, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + log.error("Unexpected error in create_node_mapping: %s", e) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error( + "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json + ) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + log.error(traceback.format_exc()) + return empty_return + + +def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: + """ + This function is used to save arguments for a compile_fx_inner function call + to the file system. Later on one can replay the compile_fx_inner call + with the saved arguments using load_args_and_run_compile_fx_inner. + """ + + folder = "/tmp/inductor_saved_args" + if not os.path.exists(folder): + os.mkdir(folder) + + def handle_tensor(x: Any) -> Any: + """ + Pickle FakeTensor will result in error: + AttributeError: Can't pickle local object 'WeakValueDictionary.__init__..remove' + + Convert all Tensor to metadata. This may also makes pickle faster. + """ + if isinstance(x, torch.Tensor): + return TensorMetadataHolder(_extract_tensor_metadata(x), x.device) + else: + return x + + args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs)) + + fn_name = "compile_fx_inner" + path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl" + with open(path, "wb") as f: + pickle.dump((args_to_save, kwargs_to_save), f) + + if log.isEnabledFor(logging.DEBUG): + message = f""" +Arguments for a compile_fx_inner call is saved to {path}. 
To replay the call, +run the following: + +from torch._inductor.debug import load_args_and_run_compile_fx_inner +load_args_and_run_compile_fx_inner({path!r}) + """ + # call print rather than log.debug. log.debug will print message + # prefix for each line which makes the code snippet harder to be + # copied. + # Not a big deal since the code is already been guarded by checking + # the log level. + print(message) + + +def load_args_and_run_compile_fx_inner(path: str) -> Any: + from torch._inductor.compile_fx import compile_fx_inner + + with open(path, "rb") as f: + args, kwargs = pickle.load(f) + + def handle_tensor(x: Any) -> Any: + if isinstance(x, TensorMetadataHolder): + return torch._dynamo.testing.rand_strided( + x.tensor_metadata.shape, + x.tensor_metadata.stride, + x.tensor_metadata.dtype, + x.device, + ) + else: + return x + + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + with fake_mode, config.patch("save_args", False): + args, kwargs = tree_map(handle_tensor, (args, kwargs)) + return compile_fx_inner(*args, **kwargs) + + +def aot_inductor_minifier_wrapper( + func: Callable[..., str], + exported_program: torch.export.ExportedProgram, + *, + inductor_configs: dict[str, Any], + package_path: Optional[FileLike] = None, +) -> str: + from torch._dynamo.debug_utils import AccuracyError + from torch._dynamo.repro.aoti import dump_to_minify + from torch._inductor import config + from torch._inductor.compile_fx import _aoti_flatten_inputs + + use_minifier = config.aot_inductor.dump_aoti_minifier + + gm = exported_program.module() + assert isinstance(gm, torch.fx.GraphModule) + + args, kwargs = exported_program.example_inputs + + try: + if use_minifier and config.aot_inductor.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + exported_program, + "aot_inductor", + options=inductor_configs, + ) + if use_minifier and config.aot_inductor.repro_level == 4: + # Check for accuracy + # We will first flatten the inputs before compiling and checking for accuracy. + # This is ok because we will flatten the inputs in the minifier anyway. 
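[Editor note] A hedged note on driving aot_inductor_minifier_wrapper: the two knobs it reads are config.aot_inductor.dump_aoti_minifier and config.aot_inductor.repro_level (names taken from the code above); the level semantics below restate what this function implements and may change across versions:

import torch._inductor.config as inductor_config

inductor_config.aot_inductor.dump_aoti_minifier = True
# As used above: level 3 dumps the original exported program up front (useful for
# segfaults), level 4 additionally compiles a flattened copy and checks accuracy,
# and on an exception level 1 emits a "run" repro while other levels emit "minify".
inductor_config.aot_inductor.repro_level = 4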
+ gm_copy = copy.deepcopy(gm) + example_inputs_copy = copy.deepcopy(exported_program.example_inputs) + config_copy = copy.deepcopy(inductor_configs) + flat_example_inputs, config_copy = _aoti_flatten_inputs( + gm_copy, + example_inputs_copy[0], + example_inputs_copy[1], + options=config_copy, + ) + tuple_inputs = tuple(flat_example_inputs) + flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False) + func( + flattened_ep.module(), + tuple_inputs, + inductor_configs=config_copy, + package_path=package_path, + load_and_run=True, + check_accuracy="accuracy", + ) + + return func( + gm, + args, + kwargs, + inductor_configs=inductor_configs, + package_path=package_path, + load_and_run=use_minifier, + ) + except AccuracyError as e: + dump_to_minify( + exported_program, + "aot_inductor_accuracy", + command="minify", + options=inductor_configs, + ) + log.warning("Accuracy failed") + raise e + except Exception as e: + if use_minifier: + command = "minify" + + if config.aot_inductor.repro_level == 1: + command = "run" + + dump_to_minify( + exported_program, + "aot_inductor", + command=command, + options=inductor_configs, + ) + raise e diff --git a/phivenv/Lib/site-packages/torch/_inductor/decomposition.py b/phivenv/Lib/site-packages/torch/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..1c512649ed8dd1e1718bae8dc90517e624128776 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/decomposition.py @@ -0,0 +1,1152 @@ +# mypy: allow-untyped-decorators +import functools +import logging +import math +import operator +import sys +import typing +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec, TypeAlias + +import torch +import torch._decomp as decomp +import torch._prims_common as utils +import torch.ao.quantization.fx._decomposed +from torch._decomp import ( + core_aten_decompositions, + get_decompositions, + remove_decompositions, +) +from torch._decomp.decompositions import ( + _grid_sampler_2d as decomp_grid_sampler_2d, + _index_add, + embedding_dense_backward as decomp_embedding_dense_backward, + pw_cast_for_opmath, + pw_cast_for_opmath_non_tensor_args, +) +from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._dynamo.utils import counters +from torch._environment import is_fbcode +from torch._higher_order_ops.out_dtype import out_dtype +from torch._inductor.utils import pad_listlike +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + type_to_dtype, +) +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + statically_known_true, +) + +from . 
import config, inductor_prims +from .utils import ( + is_gpu, + needs_fallback_due_to_atomic_add_limitations, + use_scatter_fallback, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_GenericOperator: TypeAlias = Union[ + torch._ops.OperatorBase, torch._ops.OpOverloadPacket +] + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +quantized_decomposed = torch.ops.quantized_decomposed + +inductor_decompositions = get_decompositions( + [ + aten._adaptive_avg_pool2d_backward, + aten.index_select, + aten.addmv, + aten.arange, + aten.bitwise_and_, + aten.bitwise_or_, + aten.clamp_min_, + aten.dist, + aten.elu, + aten.empty_like, + aten.flip, + aten.gelu, + aten.hardtanh, + aten.lcm, + aten.leaky_relu, + aten.linalg_vector_norm, + aten._log_softmax, + aten.max_pool2d_with_indices_backward, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten._batch_norm_with_update, + aten._batch_norm_with_update_functional, + aten._batch_norm_no_update, + aten.batch_norm_backward, + aten.native_batch_norm, + aten.native_group_norm, + aten.native_layer_norm, + aten.nll_loss2d_backward, + aten.permute_copy, + aten.rrelu_with_noise_backward, + aten._softmax, + aten.sin_, + aten.sqrt_, + out_dtype, + aten._to_copy, + aten.tril_indices, + aten.triu_indices, + aten.unbind_copy.int, + aten.upsample_bilinear2d.vec, + quantized.linear_dynamic_fp16_unpacked_weight, + _quantized.wrapped_quantized_linear, + ] +) +decompositions = {**core_aten_decompositions(), **inductor_decompositions} + +# Remove unwanted decompositions included via the core ATen decompositions from +# the Inductor decomp table. +decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [ + aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py + aten._softmax_backward_data, + aten.clamp_max, + aten.clamp_min, + aten.embedding_dense_backward, # we fall back on xpu + aten.index_add, # we conditionally call this decomp + aten.glu, # inductor lowers this directly + aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.slice_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.split.Tensor, # inductor lowers this directly + aten.squeeze, # inductor lowers this directly + aten.sum, # inductor lowers this directly + aten.unbind, # inductor lowers this directly + aten.baddbmm, # upcasts to fp32, perf issue +] + +remove_decompositions(decompositions, decomps_to_exclude) + + +def register_decomposition( + ops: Union[_GenericOperator, list[_GenericOperator]], +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + for op in ops if isinstance(ops, list) else [ops]: + if op in decompositions: + log.warning("duplicate decomp: %s", ops) + return decomp.register_decomposition(ops, decompositions) + + +@register_decomposition([aten.embedding_dense_backward]) +def _embedding_dense_backward( + grad_output: torch.Tensor, + indices: torch.Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +) -> torch.Tensor: + # TODO: check if XE4 still need this fallback + # check torch.xpu.get_device_properties(grad_output.device).architecture + if grad_output.is_xpu: + return NotImplemented + # We can write a util 
function to update decomp table if we have more ops to fallback. + return decomp_embedding_dense_backward( + grad_output, indices, num_weights, padding_idx, scale_grad_by_freq + ) + + +# TODO: for now, inductor doesn't handle asserts +# because the condition is symbol -> tensor in the graph. +@register_decomposition([aten._assert_async.msg]) +def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +# Following `assert_async_msg_decomp` and implement as non-op. +@register_decomposition([aten._functional_assert_async.msg]) +def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size( + symbol: torch.SymInt, + *, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> None: + return + + +@register_decomposition([aten.clamp]) +@pw_cast_for_opmath_non_tensor_args +def clamp( + x: torch.Tensor, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> torch.Tensor: + if min is not None: + x = x.clamp_min(min) + if max is not None: + x = x.clamp_max(max) + return x + + +@register_decomposition([aten.full]) +def full( + size: list[Union[int, torch.SymInt]], + fill_value: torch.types.Number, + **kwargs: Any, +) -> torch.Tensor: + dtype = kwargs.get("dtype") + if dtype is None: + kwargs["dtype"] = type_to_dtype(type(fill_value)) + return torch.full(size, fill_value, **kwargs) + return NotImplemented + + +@register_decomposition([aten.index_add]) +def index_add( + x: torch.Tensor, + dim: int, + index: torch.Tensor, + tensor: torch.Tensor, + *, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + # If we are not in fbcode and dtype is bfloat16 + # fallback to index_add kernel + # see https://github.com/pytorch/pytorch/issues/137425 for details + if not is_fbcode() and x.dtype == torch.bfloat16: + return NotImplemented + else: + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +# Not really sure how to put this into the main library. 
PrimTorch wants +# empty_permuted to go to the prim, and typically users don't really want +# to decompose to empty_strided (but inductor is OK with it, because we are +# cool with strides and everything goes to empty_strided) +@register_decomposition([aten.empty_permuted.default]) +def empty_permuted( + size: list[Union[int, torch.SymInt]], + physical_layout: list[int], + **kwargs: Any, +) -> torch.Tensor: + perm = [0] * len(size) + for p, l in enumerate(physical_layout): + perm[l] = p + return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm) + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes: list[int], + stride: Union[int, list[int]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + transposed: bool, + output_padding: list[int], + groups: int, + output_mask: list[bool], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not output_mask[2] or not is_gpu(grad_output.device.type): + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + +@register_decomposition([aten.round.decimals]) +def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor: + ten_pow_decimals = 10.0**decimals + return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) + + +@register_decomposition([aten.bmm]) +@pw_cast_for_opmath +def bmm( + self: torch.Tensor, + batch2: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: + if statically_known_true(self.shape[1] == 1) or statically_known_true( + batch2.shape[2] == 1 + ): + out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) + return out + if self.device.type == "cpu": + if statically_known_true(self.size(1) == 1) and statically_known_true( + batch2.size(-1) == 1 + ): + counters["inductor"]["decompose_bmm"] += 1 + return torch.sum( + self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True + ).unsqueeze(1) + return NotImplemented + + +@register_decomposition([aten.addmm]) +@pw_cast_for_opmath +def addmm( + self: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + beta: torch.types.Number = 1, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + if self.device.type == "cpu": + if statically_known_true(mat1.size(0) == 1) and statically_known_true( + mat2.size(-1) == 1 + ): + counters["inductor"]["decompose_addmm"] += 1 + out = torch.sum( + mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return alpha * out + beta * self + if ( + statically_known_true(mat1.size(0) == 1) + and guard_or_false(mat2.size(0) <= 16) + and guard_or_false(mat2.size(1) <= 16) + ): + counters["inductor"]["decompose_addmm"] += 1 + out = (mat1.T * mat2).sum(dim=0, keepdim=True) + return alpha * out + beta * self + return NotImplemented + + +@register_decomposition([aten.mm]) +@pw_cast_for_opmath +def mm( + self: torch.Tensor, + input2: torch.Tensor, + out_dtype: Optional[torch.dtype] = 
None, +) -> torch.Tensor: + # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. + # todo: Look into why and fix it (hopefully) + + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: + if statically_known_true(self.shape[0] == 1) or statically_known_true( + input2.shape[1] == 1 + ): + return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) + if self.device.type == "cpu": + if ( + statically_known_true(self.size(-1) == 1) + and statically_known_true(self.size(0) > 0) + and statically_known_true(input2.size(0) == 1) + and (self.dtype == input2.dtype) + and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32) + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) + if statically_known_true(self.size(0) == 1) and statically_known_true( + input2.size(-1) == 1 + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.sum( + self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return NotImplemented + + +# This pass does two things: +# - Eliminate cat when there is only one tensor input +# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we +# don't remove ALL empty tensors, only the naughty ones) +@register_decomposition([aten.cat.default]) +def cat( + tensors: list[torch.Tensor], + dim: int = 0, +) -> torch.Tensor: + def non_empty_tensor(x: torch.Tensor) -> bool: + # For better or worse, this is a valid cat: + # + # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)]) + # + # We'd like to eliminate naughtiness like this for downstream passes + # like split_cat. The easiest way is to just drop such inputs + # (guarding that they are non-zero). + # + # Is it permissible for this filtering to be size-oblivious? A case + # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0 + # happened to be zero, we would have liked to have filtered it out. + # But actually, the ONLY way this could have passed is if u0 == 0, + # so by the time we get here we have already installed a deferred + # runtime assert forcing u0 to be zero. So if this hasn't happened, + # we know that the unbacked SymInt has appropriate size and there are + # no problems. 
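[Editor note] The repeated-input fast path in the cat decomposition just below replaces the concat with an unsqueeze/expand/flatten; a small eager-mode sanity check of that equivalence (CPU is fine, not part of this patch):

import torch

t = torch.arange(6.0).reshape(2, 3)
dim, n = 0, 3
shape = list(t.shape)
shape.insert(dim, n)
# Mirrors: inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()
fast = t.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()
assert torch.equal(fast, torch.cat([t, t, t], dim=dim))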
+ if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0): + return False + + if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0): + return False + + return True + + filtered_tensors = list(filter(non_empty_tensor, tensors)) + + if len(filtered_tensors) == 1: + # check dtype promotion + promoted_dtype = elementwise_dtypes( + *tensors, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + )[1] + filtered_t = filtered_tensors[0] + return ( + filtered_t.clone() + if promoted_dtype == filtered_t.dtype + else filtered_t.to(dtype=promoted_dtype) + ) + elif 1 < len(filtered_tensors) < len(tensors): + # on the first call, when we remove empty tensors, we redispatch recursively + return aten.cat.default(filtered_tensors, dim) + + # optimization, avoid concat for single, repeated input + if len(filtered_tensors) > 1 and all( + t is filtered_tensors[0] for t in filtered_tensors + ): + inp = filtered_tensors[0] + shape = list(inp.shape) + dim = dim + len(inp.shape) if dim < 0 else dim + shape.insert(dim, len(filtered_tensors)) + return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone() + + # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed) + return NotImplemented + + +@register_decomposition([aten.angle]) +def angle(x: torch.Tensor) -> torch.Tensor: + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + _, dtype = elementwise_dtypes( + x, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device) + ret = torch.where(x < 0, pi, 0.0) + return torch.where(torch.isnan(x), float("nan"), ret) + + +@register_decomposition([aten.add]) +def add( + x: torch.Tensor, + y: torch.Tensor, + *, + alpha: Optional[torch.types.Number] = None, +) -> torch.Tensor: + # Require both x and y to be complex tensors. + x_is_complex_tensor = torch.is_tensor(x) and x.is_complex() + y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() + if not x_is_complex_tensor or not y_is_complex_tensor: + return NotImplemented + z = y + if alpha is not None: + z = alpha * y + complex_type = torch.promote_types(x.dtype, y.dtype) + + # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem + # when broadcasting the add. + def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: + """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]""" + # Get the current shape of the tensor + *initial_dims, last_dim = tensor.shape + + # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)` + # doubles the last dimension for complex numbers. + if last_dim % 2 != 0: + raise AssertionError( + "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]" + ) + + # Reshape the tensor + new_shape = (*initial_dims, last_dim // 2, 2) + reshaped_tensor = tensor.view(new_shape) + return reshaped_tensor + + # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation. 
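+    # Adding zero is a cheap way to force materialization: the results are
+    # plain complex tensors, so the real-dtype view arithmetic below operates
+    # on the actual values rather than a lazily-conjugated wrapper.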
+ x = x + 0 + z = z + 0 + + x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) + z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) + result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) + return result + + +@register_decomposition([aten.conj_physical]) +def conj_physical(self: torch.Tensor) -> torch.Tensor: + if self.is_complex(): + return NotImplemented + return self + + +@register_decomposition([aten.lift, aten.detach_]) +def lift(self: torch.Tensor) -> torch.Tensor: + return self + + +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition(aten.amax) +def amax( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy( + self: torch.Tensor, + dim: int, + start: int, + length: int, +) -> torch.Tensor: + return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default( + self: torch.Tensor, + size: list[Union[int, torch.SymInt]], +) -> torch.Tensor: + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype( + self: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return self.to(dtype).clone() + + +def get_like_layout( + tensor: torch.Tensor, + memory_format: Optional[torch.memory_format] = None, +) -> torch.memory_format: + # TODO: _to_copy tensor to stride permutation + if memory_format is torch.preserve_format or memory_format is None: + return utils.suggest_memory_format(tensor) + else: + return memory_format + + +@register_decomposition(aten.rand_like) +def rand_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.rand( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randn_like) +def randn_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.randn( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.full_like) +def full_like( + self: torch.Tensor, + fill_value: Union[int, float], + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, 
+) -> torch.Tensor: + return torch.full( + [*self.size()], + fill_value, + dtype=dtype or self.dtype, + layout=layout or self.layout, + device=device or self.device, + requires_grad=requires_grad, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.default) +def randint_like( + self: torch.Tensor, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + 0, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.low_dtype) +def randint_like_low( + self: torch.Tensor, + low: int, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + low, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint.default) +def randint( + high: int, + size: list[Union[int, torch.SymInt]], + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low(0, high, size, **kwargs) + + +@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default) +def linear_dynamic_fp16_unpacked_weight( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( + input, packed_weight, bias, weight.size()[0] + ) + + +@register_decomposition(_quantized.wrapped_quantized_linear.default) +def wrapped_quantized_linear( + input: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + bias: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, + out_channel: int, +) -> torch.Tensor: + packed_weight = torch.ops._quantized._wrapped_linear_prepack( + weight, weight_scale, weight_zero_point, bias + ) + return torch.ops._quantized._wrapped_quantized_linear_prepacked( + input, + input_scale, + input_zero_point, + packed_weight, + out_scale, + out_zero_point, + out_channel, + ) + + +@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) +def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor: + def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor: + x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3)) + if sys.byteorder == "little": + return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None] + else: + return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None] + + scales = bitcast_u8_to_f32(packed[..., -8:-4]) + offsets = bitcast_u8_to_f32(packed[..., -4:]) + return packed[..., :-8].to(torch.float32) * scales + offsets + + +@register_decomposition([aten.grid_sampler_2d]) +@pw_cast_for_opmath +def grid_sampler_2d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> torch.Tensor: + # We do not expand the grid (_expand_grid=False) on cpu for performance reasons + # Experimenting locally it was found 
that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + _expand_grid = not ( + a.device == torch.device("cpu") + and interpolation_mode == 0 + and a.is_contiguous(memory_format=torch.contiguous_format) + ) + + output = decomp_grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + _expand_grid=_expand_grid, + ) + return output + + +@register_decomposition(aten._foreach_addcmul.Scalar) +def _foreach_addcmul_scalar( + self: list[torch.Tensor], + left_tensors: list[torch.Tensor], + right_tensors: list[torch.Tensor], + scalar: float = 1, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_addcdiv.Scalar) +def _foreach_addcdiv_scalar( + self: list[torch.Tensor], + left_tensors: list[torch.Tensor], + right_tensors: list[torch.Tensor], + scalar: float = 1, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_lerp.Scalar) +def _foreach_lerp_scalar( + start_tensors: list[torch.Tensor], + end_tensors: list[torch.Tensor], + weight: torch.types.Number, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.Scalar( + aten._foreach_sub.List(end_tensors, start_tensors), weight + ), + ) + + +@register_decomposition(aten._foreach_lerp.ScalarList) +def _foreach_lerp_scalarlist( + start_tensors: list[torch.Tensor], + end_tensors: list[torch.Tensor], + scalars: list[torch.types.Number], +) -> list[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.ScalarList( + aten._foreach_sub.List(end_tensors, start_tensors), scalars + ), + ) + + +@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd) +@register_decomposition(aten.miopen_batch_norm) +def miopen_batch_norm( + input: torch.Tensor, + weight: torch.Tensor, + bias: typing.Optional[torch.Tensor], + running_mean: typing.Optional[torch.Tensor], + running_var: typing.Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + + if training: + return (a, b, c) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + ) + + +@functools.cache +def fast_random_decomps() -> dict[Any, Callable[..., Any]]: + return {**decompositions, **extra_random_decomps} + + +# TODO(aakhundov): replace this (and the above) Any by more +# specific type and fix all the cascading mypy errors +def select_decomp_table() -> dict[Any, Callable[..., Any]]: + """decomps can change based on config""" + if config.fallback_random: + return decompositions + return fast_random_decomps() + + +@register_decomposition(aten.masked_scatter) +def masked_scatter( + self: torch.Tensor, + mask: torch.Tensor, + source: torch.Tensor, +) -> torch.Tensor: + from .codegen.common import BackendFeature, has_backend_feature + + if has_backend_feature(self.device, 
BackendFeature.MASKED_SCATTER_WITH_INDEX): + # This two-step algorithm is the same as eager CUDA, for eager CPU we + # use a 1-shot serial iteration. + self, mask = aten.broadcast_tensors([self, mask]) + source_idx = mask.reshape(-1).cumsum(0) - 1 + self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source)) + result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0) + return torch.where(mask_flat, result, self_flat).view(self.shape) + return NotImplemented + + +@register_decomposition(quantized_decomposed.choose_qparams.tensor) +def choose_qparams_tensor( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = torch.aminmax(input) + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.max(scale, torch.Tensor([eps])) + zero_point = quant_min - torch.round(min_val / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale.to(torch.float64), zero_point.to(torch.int64) + + +@register_decomposition(aten.put) +def put( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + flattened = self.flatten() + flattened = torch.index_put( + flattened, [index], source.reshape(index.shape), accumulate + ) + return flattened.reshape(self.shape) + + +@register_decomposition(aten.put_) +def put_( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + out = aten.put(self, index, source, accumulate=accumulate) + return self.copy_(out) + + +@register_decomposition(aten._softmax_backward_data.default) +@pw_cast_for_opmath +def _softmax_backward_data( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + new_grad_output = grad_output * output + sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True) + # grad_input = new_grad_output - output * sum_new_grad + grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input.contiguous() + + +@register_decomposition(aten.index_reduce) +def index_reduce( + self: torch.Tensor, + dim: int, + index: torch.Tensor, + src: torch.Tensor, + reduction_type: str, + *, + include_self: bool = True, +) -> torch.Tensor: + if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations( + self.dtype + ): + true_division = self.dtype.is_floating_point or self.dtype.is_complex + ones = torch.ones_like(src) + if include_self: + out = self + counts = torch.ones_like(self).index_add(dim, index, ones) + else: + out = self.index_fill(dim, index, 0) + counts = torch.zeros_like(self).index_add(dim, index, ones) + counts = counts.masked_fill(counts < 1, 1) + out = out.index_add(dim, index, src) + return out / counts if true_division else out // counts + + if use_scatter_fallback( + aten.scatter_reduce_.two, + reduction_type, + self.dtype, + src.dtype, + src.device.type, + True, + ): + return NotImplemented + + repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel() + index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim]) + perm = (*range(self.ndim - dim, self.ndim), 0, 
*range(1, self.ndim - dim)) + scatter_index = ( + index.to(torch.int64) + .repeat_interleave(repeats) + .reshape(index_shape) + .permute(perm) + ) + return self.scatter_reduce( + dim, + scatter_index, + src, + reduction_type, + include_self=include_self, + ) + + +def _max_pool_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + ceil_mode: bool, + dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + if dilation == 1: + dilation = [1] * dim + + if padding == 0: + padding = [0] * dim + + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, dim) + dilation = pad_listlike(dilation, dim) + padding = pad_listlike(padding, dim) + stride = pad_listlike(stride, dim) + + window_size = functools.reduce(operator.mul, kernel_size) + # We fallback when using non-default dilation or when the window size is too large + if ( + torch._inductor.lowering.should_fallback_max_pool_with_indices( + kernel_size, n_dim=dim + ) + or window_size > torch.iinfo(torch.int8).max + ): + return NotImplemented + + vals, offsets = prims._low_memory_max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + indices = prims._low_memory_max_pool_offsets_to_indices( + offsets, + kernel_size, + x.shape[-dim:], + stride, + padding, + dilation, + ) + return vals, indices + + +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=2 + ) + + +@register_decomposition(aten.max_pool3d_with_indices) +def max_pool3d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=3 + ) + + +@register_decomposition(aten.adaptive_max_pool2d) +def adaptive_max_pool2d( + x: torch.Tensor, output_size: list[int] +) -> tuple[torch.Tensor, torch.Tensor]: + *batch, h_in, w_in = x.shape + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return x.new_empty(o_size), x.new_empty(o_size, dtype=torch.int64) + + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return aten.max_pool2d_with_indices(x, kernel_size) + + return NotImplemented + + +@register_decomposition(aten.searchsorted.Scalar) +def searchsorted_scalar( + sorted_sequence: torch.Tensor, + self: torch.types.Number, + *, + out_int32: bool = False, + right: bool = False, + side: Optional[str] = None, + sorter: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return aten.searchsorted( + sorted_sequence, + torch.tensor([self], device=sorted_sequence.device), + out_int32=out_int32, + right=right, + side=side, + sorter=sorter, + )[0] + + +@register_decomposition(aten.rrelu_with_noise_functional) +def rrelu_with_noise_functional( + self: torch.Tensor, + noise: torch.Tensor, + lower: float = 0.125, + upper: float = 0.3333333333333333, + training: bool = False, + generator: 
Optional[torch.Generator] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if training: + not_positive = self <= 0 + r = aten.uniform(self, lower, upper, generator=generator) + output = torch.where(not_positive, self * r, self) + noise_out = torch.where(not_positive, r, 1) + return output, noise_out + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu(self, negative_slope), torch.Tensor() diff --git a/phivenv/Lib/site-packages/torch/_inductor/dependencies.py b/phivenv/Lib/site-packages/torch/_inductor/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..853982280b4ae7766ae32d246b2eeec04a73c221 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/dependencies.py @@ -0,0 +1,822 @@ +import abc +import dataclasses +import itertools +import logging +import re +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import Self +from unittest.mock import patch + +import sympy + +import torch +from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet + +from ..utils._sympy.symbol import make_symbol, SymT +from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler +from .utils import ( + get_dtype_size, + reduction_num_outputs, + sympy_index_symbol, + sympy_str, + sympy_subs, + VarRanges, +) +from .virtualized import ReductionType, V + + +T = TypeVar("T") + +log = logging.getLogger(__name__) +is_indirect = re.compile(r"indirect|tmp").search + + +class Dep(abc.ABC): + name: str + index: sympy.Expr + + @abc.abstractmethod + def rename(self, renames: dict[str, str]) -> Self: + pass + + @abc.abstractmethod + def get_numel(self) -> sympy.Expr: + pass + + @abc.abstractmethod + def numbytes_hint(self) -> int: + pass + + @abc.abstractmethod + def has_unbacked_symbols(self) -> bool: + pass + + @abc.abstractmethod + def is_contiguous(self) -> bool: + pass + + def normalize_with_stride_order(self, prefix: str = "t") -> Self: + return self + + +@dataclasses.dataclass(frozen=True) +class MemoryDep(Dep): + name: str + index: sympy.Expr + var_names: tuple[sympy.Symbol, ...] + size: tuple[sympy.Expr, ...] + mode: Optional[str] = None + + def __repr__(self) -> str: + maybe_mode = "" + if self.mode is not None: + maybe_mode = f", {self.mode}" + return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}{maybe_mode})" + + @property + def num_vars(self) -> int: + return len(self.var_names) + + def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]: + """ + Can return None if not able to decide loop orders. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity and it may not matter that much. + # + # For size == 1, it cause cause tie for strides of different dimensions. + # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder + # we can dependencies.index_vars_squeeze which should already sqeeuze + # the size == 1 dimensions. 
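+        # Illustrative: with sizes (1, 8), the size-1 dimension can carry an
+        # arbitrary stride, so the stride tuple no longer identifies a unique
+        # loop order; rather than guess, give up on reordering.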
+ if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expression + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(OrderedSet(self_strides)) != len(self_strides) or len( + OrderedSet(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May happen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if OrderedSet(self_strides) != OrderedSet(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [stride_to_index[s] for s in other_strides] + + assert OrderedSet(order) == OrderedSet(range(0, self.num_vars)) + return order + + def get_offset(self) -> sympy.Expr: + """ + Return the offset by setting every variable to be 0. + """ + return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The different to normalize_with_stride_order is, + this method does not reorder loops while normalize_with_stride_order reorder + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + + def normalize_with_stride_order(self, prefix: str = "t") -> "MemoryDep": + r""" + Used to decide if two MemoryDep does not equal due to different loop orders. + More specifically, when dep1 and dep2 are not equal, we can normalize + both and check if they are equal after that. If yes, then the mismatch is + caused by different loop orders. 
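+        For example (illustrative, same buffer): '64*d0 + d1' over sizes
+        {d0: 32, d1: 64} and 'd0 + 64*d1' over sizes {d0: 64, d1: 32} describe
+        the same access pattern in different loop orders and normalize to the
+        same canonical dep.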
+ """ + # import here to avoid circular import + from torch._inductor import ir + + strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + + # pick a loop order with stride ordered decreasingly + order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + stride_reorder = ir.same_reorder(order) + sizes = self.size + var_names = self.var_names + + new_reordered_sizes = stride_reorder(sizes) + new_reordered_var_names = stride_reorder(var_names) + + new_simplified_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + new_reordered_var_names, + new_reordered_sizes, + index_prevent_reordering( + [self.index], new_reordered_var_names, new_reordered_sizes + ), + ) + + # now let's create new symbols with the passed in prefix + var_ranges, add_var = var_builder(prefix) + replacement = dict( + zip( + new_reordered_var_names, + reindex([add_var(x) for x in new_simplified_sizes]), + ) + ) + new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR + + out = MemoryDep( + self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()) + ) # type: ignore[arg-type] + return out + + @property + def ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + """{c0: 128, c1: 512, ...}""" + return dict(zip(self.var_names, self.size)) + + def simplify_with_ranges(self) -> "MemoryDep": + return MemoryDep( + name=self.name, + index=V.graph.sizevars.simplify_with_ranges(self.index, self.ranges), + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + + def get_numel(self) -> sympy.Expr: + if self.is_indirect(): + numel = V.graph.get_numel(self.name) + else: + vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) + numel = sympy.S.One + for var, size in zip(self.var_names, self.size): + if var in vars: + numel = numel * size + return numel # type: ignore[return-value] + + def rename(self, renames: dict[str, str]) -> "MemoryDep": + if self.name in renames: + return MemoryDep( + renames[self.name], + self.index, + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + return self + + def numbytes_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + except NotImplementedError: # NoneLayout + return 0 + + def has_unbacked_symbols(self) -> bool: + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + if isinstance(self.index, sympy.Integer): + return True + return isinstance(self.index, sympy.Symbol) and self.index in self.var_names + + def stride1_for_last_dim(self, result_for_complex_expression: bool = True) -> bool: + """ + Whether the stride for the last dimension is 1. + """ + # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16 + # will exercise thru this corner case. + if len(self.var_names) == 0: + return True + + terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index] + + last_sym = self.var_names[-1] + for term in terms: + if term == last_sym: + return True + + # Having a >1 stride for the last dimension is bad for perf + # return False. 
+ if ( + isinstance(term, sympy.Mul) + and len(term.args) == 2 + and term.args[1] == last_sym + and isinstance(term.args[0], (int, sympy.Integer)) + and term.args[0] > 1 + ): + return False + + return result_for_complex_expression + + def is_scalar(self) -> bool: + if isinstance(self.index, sympy.Symbol): + return self.index not in self.var_names and not self.is_indirect() + return isinstance(self.index, (int, sympy.Integer)) + + def is_indirect(self) -> bool: + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] + + +@dataclasses.dataclass(frozen=True) +class StarDep(Dep): + name: str + mode: Optional[str] = None + + # depends on the entire buffer + @property + def index(self) -> sympy.Expr: + raise NotImplementedError("StarDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return V.graph.get_numel(self.name) # type: ignore[return-value] + + def rename(self, renames: dict[str, str]) -> "StarDep": + if self.name in renames: + return StarDep(renames[self.name], self.mode) + return self + + def numbytes_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + except NotImplementedError: + return 0 # NoneLayout, MultiOutputLayout, etc + + def has_unbacked_symbols(self) -> bool: + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return False + + def is_scalar(self) -> bool: + return False + + def is_indirect(self) -> bool: + return False + + +# Used for tracking mutation ordering +# if A reads a buffer and B mutates it +# B must be ordered after A +# +# This is useful for a variety of reasons. +# For example, if A's read is never actually used, we can eliminate it. +# Another case is if A's buffer ends up being fused away, we never need to +# materialize that buffer +@dataclasses.dataclass(frozen=True) +class WeakDep(Dep): + # Fake dependency on unused buffer + name: str + # Buffer that is doing the mutation + mutating_buf: str + + @property + def index(self) -> sympy.Expr: + raise NotImplementedError("WeakDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return sympy.S.One + + def rename(self, renames: dict[str, str]) -> "WeakDep": + if self.name in renames: + return WeakDep(renames[self.name], self.mutating_buf) + return self + + def numbytes_hint(self) -> int: + return 1 # Purely inserted for ordering, not an actual dep + + def has_unbacked_symbols(self) -> bool: + return False + + def is_contiguous(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class IndexExprDep: + index: sympy.Expr # type: ignore[assignment] + var_names: tuple[sympy.Symbol, ...] + size: tuple[sympy.Expr, ...] 
+ + +@dataclasses.dataclass +class ReadWrites: + reads: OrderedSet[Dep] + writes: OrderedSet[Dep] + index_exprs: OrderedSet[IndexExprDep] + range_vars: Optional[list[sympy.Expr]] = None + var_ranges: Optional[VarRanges] = None + + def rename(self, renames: dict[str, str]) -> "ReadWrites": + return ReadWrites( + OrderedSet(dep.rename(renames) for dep in self.reads), + OrderedSet(dep.rename(renames) for dep in self.writes), + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def with_read(self, dep: Union[Dep, OrderedSet[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, OrderedSet)) + if not isinstance(dep, OrderedSet): + dep = OrderedSet([dep]) + return ReadWrites( + OrderedSet.union(self.reads, dep), + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def merge(self, other: "ReadWrites") -> "ReadWrites": + reads = OrderedSet.union(self.reads, other.reads) + writes = OrderedSet.union(self.writes, other.writes) + index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) + return ReadWrites(reads - writes, writes, index_exprs) + + @staticmethod + def merge_list(read_writes: list["ReadWrites"]) -> "ReadWrites": + all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) + all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes + all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) + return ReadWrites(all_reads, all_writes, all_index_exprs) + + def remove_reads(self, rem_reads: OrderedSet[Dep]) -> "ReadWrites": + return ReadWrites( + self.reads - rem_reads, + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def reads_and_writes(self) -> Iterable[Dep]: + return itertools.chain(self.reads, self.writes) + + def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]: + """ + Integer index is used for load_seed. + """ + names: OrderedSet[str] = OrderedSet() + for dep in self.reads_and_writes(): + if not isinstance(dep, MemoryDep): + continue + if not ignore_integer_index or not isinstance( + dep.index, (int, sympy.Integer) + ): + names.add(dep.name) + return names + + +class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + super().__init__() + self._reads: OrderedSet[Dep] = OrderedSet() + self._writes: OrderedSet[MemoryDep] = OrderedSet() + self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() + self._var_ranges: VarRanges = var_ranges + self._should_normalize: bool = normalize + + @staticmethod + def drop_unused_symbols( + index: Union[int, sympy.Expr], + var_names: list[sympy.Expr], + sizes: list[sympy.Expr], + ) -> None: + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]: + # Try to further simplify the indexes even if simplify_loops didn't + # convert it to the simplest form because of the interference from + # different indexing formulas. 
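+        # Roughly: re-run loop simplification on this single index, rename the
+        # surviving loop variables with the canonical prefix ("c") so numbering
+        # differences don't matter, then drop variables the index never uses.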
+ index_vars = [*var_ranges.keys()] + sizes = tuple(var_ranges.values()) # type: ignore[assignment] + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, + sizes, + index_prevent_reordering([index], index_vars, sizes), + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + new_vars, add_var = var_builder(canonicalization_prefix()) + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + index = sympy_subs(sympy.expand(index), replacement) + + new_vars = [*new_vars.keys()] + new_sizes = [*new_sizes] + cls.drop_unused_symbols(index, new_vars, new_sizes) + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + + def canonicalize( + self, index: sympy.Expr + ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + + def load(self, name: str, index: sympy.Expr) -> str: + self._reads.add(MemoryDep(name, *self.canonicalize(index))) + return f"load({name}, {sympy_str(index)})" + + def load_seed(self, name: str, index: int) -> str: + assert isinstance(index, int) + return self.load(name, sympy.Integer(index)) + + def store( + self, name: str, index: sympy.Expr, value: str, mode: Optional[str] = None + ) -> str: + self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode)) + return f"store({name}, {sympy_str(index)}, {value}, {mode})" + + def store_reduction(self, name: str, index: sympy.Expr, value: str) -> str: + return self.store(name, index, f"store_reduction({value})") + + def index_expr(self, index: sympy.Expr, dtype: Optional[torch.dtype]) -> str: + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) + return f"index_expr({sympy_str(index)}, {dtype})" + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> None: + """Records the names of the buffers that bucketize will read from.""" + self._reads.add(StarDep(boundaries[0])) + if sorter is not None: + self._reads.add(StarDep(sorter[0])) + + +class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + parent_handler = _RecordLoadStoreInner( + var_ranges=var_ranges, normalize=normalize + ) + super().__init__(parent_handler=parent_handler) + + +# TODO: check call sites +def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: + cnt = itertools.count() + var_ranges: VarRanges = {} + + def add_var(length: sympy.Expr) -> sympy.Symbol: + v = sympy_index_symbol(f"{prefix}{next(cnt)}") + var_ranges[v] = length + return v + + return var_ranges, add_var + + +def index_vars_no_squeeze( + *argsizes: Sequence[sympy.Expr], prefix: str +) -> 
tuple[list[list[sympy.Symbol]], VarRanges]: + var_ranges, add_var = var_builder(prefix) + args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes] + return args, var_ranges + + +def index_vars_squeeze( + *argsizes: Sequence[sympy.Expr], prefix: str = "d" +) -> tuple[list[list[sympy.Expr]], VarRanges]: + from .ir import SqueezeView + + var_ranges, add_var = var_builder(prefix) + args: list[list[sympy.Expr]] = [] + new_sizes: list[list[sympy.Expr]] = [] + for size in argsizes: + new_size, reindex = SqueezeView.squeezer(size) + new_sizes.append(new_size) + args.append(reindex(list(map(add_var, new_size)))) + return args, var_ranges + + +def extract_read_writes( + fn: Callable[..., Any], + *argsizes: Sequence[sympy.Expr], + normalize: bool = False, + prefix: str = "d", + hidden_args: Sequence[list[sympy.Expr]] = (), +) -> ReadWrites: + args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) + + from .loop_body import LoopBody + + if isinstance(fn, LoopBody): + inner = extract_loop_body_with_args( + fn, [*args, *hidden_args], var_ranges, normalize + ) + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler + + if normalize: + range_vars = [] # Number of vars could differ due to normalization + else: + range_vars = [*itertools.chain.from_iterable(args)] + + return ReadWrites( + OrderedSet(inner._reads), + OrderedSet(inner._writes), + inner._index_exprs, + range_vars, + var_ranges, + ) + + +def extract_loop_body_with_args( + fn: Any, + args: list[list[sympy.Expr]], + var_ranges: VarRanges, + normalize: bool = False, +) -> _RecordLoadStoreInner: + from .loop_body import MemoryUsageType + + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args(args) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: make_symbol(SymT.TMP, i) for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] + entry.mode, + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + # All that matters is that we record the buffer name, so place it in the + # "boundaries" name position to ensure that it's recorded. 
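+        # The remaining arguments are ignored by the dependency recorder, so
+        # None placeholders are sufficient here.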
+ inner.bucketize( + None, + (entry.buffer_name, None, None, None), + None, + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + return inner + + +def extract_input_node_reduction_ranges( + input_node: "torch._inductor.ir.IRNode", +) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]: + """ + Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. + It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. + In this case, reduction_sizes of the Reduction nodes need to be the same. + Otherwise returns (None, None). + """ + + from .ir import ComputedBuffer, ExternKernel, Loops + + size: Optional[list[sympy.Expr]] + reduction_size: Optional[list[sympy.Expr]] + + if isinstance(input_node.get_defining_op(), ComputedBuffer): + # Input node has already been realized. Return its size and reduction_size. + size = [*input_node.get_size()] + reduction_size = [*input_node.get_reduction_size()] + if len(reduction_size) > 0: + return (size, reduction_size) + else: + return (None, None) + + if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] + # Other IRNodes do not have reduction_ranges. + return (None, None) + + # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? + # The current method still uses reduction ranges from the dependent realized node, which is not ideal. + # Is there a way to check whether there are permutations in between? + reads = input_node.get_reads() + reduction_size: Optional[list[sympy.Expr]] = None + size: Optional[list[sympy.Expr]] = None + while reduction_size is None and len(reads) > 0: + seen: OrderedSet[str] = OrderedSet() + new_reads: list[Dep] = [] + for read in reads: + if not isinstance(read, MemoryDep): + continue + if read.name in seen: + continue + seen.add(read.name) + buffer = V.graph.try_get_buffer(read.name) + if buffer is None: + continue + op = buffer.get_defining_op() + if op is None or isinstance(op, ExternKernel): + continue + + if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0: + if reduction_size is None: + reduction_size = [*op.get_reduction_size()] + size = [*op.get_size()] + elif reduction_size != [*op.get_reduction_size()] or size != [ + *op.get_size() + ]: + return (None, None) + else: + new_reads.extend(op.get_reads()) + if reads == new_reads: + return (size, reduction_size) + else: + reads = OrderedSet(new_reads) + return (size, reduction_size) + + +def canonicalization_prefix() -> str: + return "c" + + +# ops handler which computes all the free symbols for an IR +class FreeSymbolsOpsHandler(DefaultHandler): + symbols: OrderedSet[sympy.Symbol] + + def __init__(self, unbacked_only: bool = True) -> None: + self.symbols = OrderedSet() + self.get_symbols = free_unbacked_symbols if unbacked_only else free_symbols + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= self.get_symbols(a) + + def indirect_indexing( + self, + index_var: Any, + size: Union[int, sympy.Expr], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) + self.symbols |= self.get_symbols(size) + return 
sympy_index_symbol(f"({str(index_var)})") + + def frexp(self, x: Any) -> tuple[None, ...]: + return (None,) * 2 + + def scan( + self, dtypes: Any, combine_fn: Any, values: Sequence[Any] + ) -> tuple[None, ...]: + return (None,) * len(values) + + def sort( + self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any + ) -> tuple[None, ...]: + return (None,) * len(values) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[None, tuple[None, ...]], + ) -> Union[None, tuple[None, ...]]: + num_values = reduction_num_outputs(reduction_type) + return (None,) * num_values if num_values > 1 else None + + def masked(self, mask: Any, body: Callable[..., Any], other: Any) -> None: + assert callable(body), "masked body must always be callable." + # The body can make additional calls, for e.g. ops.indirect_indexing + body() + + +def extract_free_symbols( + fn: Callable[..., Any], + index: Sequence[sympy.Expr], + rindex: Optional[Sequence[sympy.Expr]] = None, + unbacked_only: bool = True, +) -> OrderedSet[sympy.Symbol]: + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + handler = FreeSymbolsOpsHandler(unbacked_only) + # NB: I cargo culted the allow_indexing patch here, I don't understand why + # people do this all over + with ( + V.set_ops_handler(handler), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + fn(*args) + return handler.symbols diff --git a/phivenv/Lib/site-packages/torch/_inductor/dtype_propagation.py b/phivenv/Lib/site-packages/torch/_inductor/dtype_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffa38849688ee8f6e786b1695874533d495c7da --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/dtype_propagation.py @@ -0,0 +1,380 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Sequence +from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union + +import sympy + +import torch +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype +from torch.utils._ordered_set import OrderedSet + +from .ops_handler import OP_NAMES, OpsHandler +from .utils import upcast_compute_type +from .virtualized import OpsValue, V + + +T = TypeVar("T") + + +class DTypeVar(Protocol): + @property + def dtype(self) -> torch.dtype: ... 
+ + +DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue] + + +# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective +# So first decompose CSEVars -> tuple before calling this + + +@functools.cache +def get_promoted_dtype( + *args: Sequence[tuple[torch.dtype, bool]], + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, +): + def construct_input(inp): + if inp[1]: + return torch.empty([], dtype=inp[0]) + else: + return torch.empty([1], dtype=inp[0]) + + inps = [construct_input(arg) for arg in args] + _, dtype = torch._prims_common.elementwise_dtypes( + *inps, + type_promotion_kind=( + type_promotion_kind + if type_promotion_kind + else ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ) + return dtype + + +def promote_types( + args: Sequence[DTypeArg], + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, +): + dtype_prop_candidates = [] + + for arg in args: + assert not isinstance(arg, str) + if isinstance(arg, OpsValue): + arg = arg.value + assert isinstance(arg, torch._prims_common.Number) or hasattr(arg, "dtype") + + if isinstance(arg, torch._prims_common.Number): + dtype_prop_candidates.append((type_to_dtype(type(arg)), True)) + continue + + dtype_prop_candidates.append((arg.dtype, getattr(arg, "is_scalar", False))) + + dtype = get_promoted_dtype( + *dtype_prop_candidates, + type_promotion_kind=type_promotion_kind, + ) + + return dtype + + +class DtypePropagationOpsHandler: + """ + Propagate dtype from args to output + """ + + # Singleton DtypePropagationOpsHandler, because we meta program over a number of op rules. + # Those are only defined after other inductor state has run. + + _instance: Optional["DtypePropagationOpsHandler"] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + for op, rule in torch._inductor.utils.op_dtype_propagation_rules.items(): + fn = ( + functools.partial(self.return_dtype, dtype=rule.override_return_dtype) + if rule.override_return_dtype + else functools.partial( + self.op_dtype_rule, type_promotion_kind=rule.type_promotion_kind + ) + ) + setattr(self, op, fn) + + # Set pointwise operation rules + for op in torch._inductor.codegen.common.pointwise_overrides_data.values(): + if not hasattr(self, op.name): + setattr( + self, + op.name, + functools.partial( + self.op_dtype_rule, type_promotion_kind=op.type_promotion_kind + ), + ) + + # Set boolean operation rules + for op in torch._inductor.utils.boolean_ops(): + if not hasattr(self, op): + setattr( + self, op, functools.partial(self.return_dtype, dtype=torch.bool) + ) + + unimplemented_ops = OP_NAMES - OrderedSet(dir(self)) + torch._check( + len(unimplemented_ops) == 0, + lambda: f"Unimplemented dtype rule for ops: {unimplemented_ops}", + ) + + # metaprogrammed in __init__ + + @staticmethod + def op_dtype_rule( + *args: DTypeArg, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND + ) -> torch.dtype: + return promote_types(args, type_promotion_kind=type_promotion_kind) + + @staticmethod + def return_dtype(*args: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + # op rules + + @staticmethod + def constant(value: torch.types.Number, dtype: torch.dtype) -> torch.dtype: + return upcast_compute_type(dtype) + + @staticmethod + def load_seed(name: str, offset: int) -> torch.dtype: + return upcast_compute_type(V.graph.get_dtype(name)) + + @staticmethod + def randint64(seed: int, offset: int, low: int, high: int) -> 
torch.dtype: + return torch.int64 + + @staticmethod + def masked( + mask: DTypeArg, body: Callable[[], DTypeArg], other: DTypeArg + ) -> torch.dtype: + from .loop_body import LoopBodyBlock + + assert isinstance(body, LoopBodyBlock), "body must be a LoopBodyBlock" + # TODO - we avoid calling this in codegen, needs work for non codegen use cases + loads = body.graph.find_nodes(op="call_method", target="load") + if len(loads) <= 1: + return promote_types([other]) + + return upcast_compute_type(V.graph.get_dtype(loads[-1].args[1])) + + @staticmethod + def where(a: DTypeArg, b: DTypeArg, c: DTypeArg) -> torch.dtype: + return promote_types([b, c]) + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> torch.dtype: + # TODO - TODO - rationalize index_expr. The dtype is not always used and we are inconsistent about int32 or int64 + # in lowerings. cpp just uses the dtype + if dtype not in (torch.int32, torch.int64) or not hasattr( + V.kernel, "index_dtype" + ): + return upcast_compute_type(dtype) + + return V.kernel.get_index_dtype_as_torch_dtype() + + @staticmethod + def to_dtype( + x: DTypeArg, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ) -> torch.dtype: + return upcast_compute_type(dtype) if use_compute_types else dtype + + @staticmethod + def to_dtype_bitcast( + x: DTypeArg, dtype: torch.dtype, src_dtype: torch.dtype + ) -> torch.dtype: + return upcast_compute_type(dtype) + + @staticmethod + def gelu(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def mul(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def truediv(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def pow(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def mod(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def indirect_indexing( + x: DTypeArg, size: int, check: bool = True, wrap_neg: bool = True + ) -> torch.dtype: + return torch.int64 + + @staticmethod + def randn(seed: int, offset: int) -> torch.dtype: + return torch.float + + @staticmethod + def rand(seed: int, offset: int) -> torch.dtype: + return torch.float + + @staticmethod + def store_reduction(name: str, index, value: DTypeArg) -> None: + return None + + @staticmethod + def reduction( + dtype: torch.dtype, src_dtype: torch.dtype, reduction_type: str, value: DTypeArg + ) -> torch.dtype: + return dtype + + @staticmethod + def store(name: str, index, value: DTypeArg, mode: Optional[str] = None) -> None: + return None + + @staticmethod + def load(name: str, index) -> torch.dtype: + return upcast_compute_type(V.graph.get_dtype(name)) + + @staticmethod + def floor(x: DTypeArg) -> torch.dtype: + return promote_types( + [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def ceil_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def int_truediv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types( + [x, y], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def scan( + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[[tuple[T, ...], tuple[T, ...]], tuple[T, ...]], + values: tuple[T, ...], + ) -> tuple[torch.dtype, ...]: + return dtypes + + @staticmethod + def fmod(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def 
round_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def identity(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def frexp(x: DTypeArg) -> tuple[torch.dtype, torch.dtype]: + # TODO - need to handle multiple outputs + return (promote_types([x]), torch.int32) + + @staticmethod + def sort( + dtypes: tuple[torch.dtype, ...], + values: tuple[T, ...], + stable: bool, + descending: bool, + ) -> tuple[torch.dtype, ...]: + return dtypes + + @staticmethod + def trunc(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def bucketize( + values: DTypeArg, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: DTypeArg, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> torch.dtype: + return indexing_dtype + + @staticmethod + def rshift(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def round(x: DTypeArg) -> torch.dtype: + return promote_types( + [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def trunc_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def floor_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def truncdiv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def floordiv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def halide_clamp(value, size, check): + # TODO - way of registering dtype for op in backend + return torch.int32 + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + return dtype + + @staticmethod + def lshift(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + return None + + def output(self, *args: DTypeArg) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear here" + ) + + def placeholder(self, index: int) -> torch.dtype: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear here" + ) + + +if TYPE_CHECKING: + + class _typecheck_DtypePropagation(DtypePropagationOpsHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong diff --git a/phivenv/Lib/site-packages/torch/_inductor/exc.py b/phivenv/Lib/site-packages/torch/_inductor/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..181b47dc57f8deca3f23af8c5f8ceeedf5d44b56 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/exc.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import os +import tempfile +import textwrap +from functools import lru_cache +from typing import Any, Optional, TYPE_CHECKING + +from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback + + +if TYPE_CHECKING: + import types + + from torch.cuda import _CudaDeviceProperties + +if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": + + @lru_cache(None) + def _record_missing_op(target: Any) -> None: + with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: + fd.write(str(target) + "\n") + +else: + + def _record_missing_op(target: Any) -> None: # type: ignore[misc] + pass + + +class OperatorIssue(RuntimeError): + 
@staticmethod + def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str: + lines = [f"target: {target}"] + [ + f"args[{i}]: {arg}" for i, arg in enumerate(args) + ] + if kwargs: + lines.append(f"kwargs: {kwargs}") + return textwrap.indent("\n".join(lines), " ") + + +class MissingOperatorWithoutDecomp(OperatorIssue): + def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None: + _record_missing_op(target) + super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") + + +class MissingOperatorWithDecomp(OperatorIssue): + def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None: + _record_missing_op(target) + super().__init__( + f"missing decomposition\n{self.operator_str(target, args, kwargs)}" + + textwrap.dedent( + f""" + + There is a decomposition available for {target} in + torch._decomp.get_decompositions(). Please add this operator to the + `decompositions` list in torch._inductor.decomposition + """ + ) + ) + + +class LoweringException(OperatorIssue): + def __init__( + self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any] + ) -> None: + super().__init__( + f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" + ) + + +class SubgraphLoweringException(RuntimeError): + pass + + +class InvalidCxxCompiler(RuntimeError): + def __init__(self) -> None: + from . import config + + super().__init__( + f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}" + ) + + +class CppWrapperCodegenError(RuntimeError): + def __init__(self, msg: str) -> None: + super().__init__(f"C++ wrapper codegen error: {msg}") + + +class CppCompileError(RuntimeError): + def __init__(self, cmd: list[str], output: str) -> None: + if isinstance(output, bytes): + output = output.decode("utf-8") + + super().__init__( + textwrap.dedent( + """ + C++ compile error + + Command: + {cmd} + + Output: + {output} + """ + ) + .strip() + .format(cmd=" ".join(cmd), output=output) + ) + + +class CUDACompileError(CppCompileError): + pass + + +class TritonMissing(ShortenTraceback): + def __init__(self, first_useful_frame: Optional[types.FrameType]) -> None: + super().__init__( + "Cannot find a working triton installation. " + "Either the package is not installed or it is too old. " + "More information on installing Triton can be found at: https://github.com/triton-lang/triton", + first_useful_frame=first_useful_frame, + ) + + +class GPUTooOldForTriton(ShortenTraceback): + def __init__( + self, + device_props: _CudaDeviceProperties, + first_useful_frame: Optional[types.FrameType], + ) -> None: + super().__init__( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, " + "which is used as the backend. 
Triton only supports devices of CUDA Capability >= 7.0, " + f"but your device is of CUDA capability {device_props.major}.{device_props.minor}", + first_useful_frame=first_useful_frame, + ) + + +class InductorError(BackendCompilerFailed): + backend_name = "inductor" + + def __init__( + self, + inner_exception: Exception, + first_useful_frame: Optional[types.FrameType], + ) -> None: + self.inner_exception = inner_exception + ShortenTraceback.__init__( + self, + f"{type(inner_exception).__name__}: {inner_exception}", + first_useful_frame=first_useful_frame, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/extern_node_serializer.py b/phivenv/Lib/site-packages/torch/_inductor/extern_node_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb05f59544309b5c84ac8094fd4ff5909bc1da2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/extern_node_serializer.py @@ -0,0 +1,24 @@ +import json + +from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder +from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode + + +def serialize_extern_kernel_node( + extern_kernel_node: inductor_ExternKernelNode, +) -> ExternKernelNode: + assert isinstance(extern_kernel_node.node, Node) + return ExternKernelNode( + name=extern_kernel_node.name, + node=extern_kernel_node.node, + ) + + +def extern_node_json_serializer( + extern_kernel_nodes: list[inductor_ExternKernelNode], +) -> str: + serialized_nodes = ExternKernelNodes( + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + ) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) diff --git a/phivenv/Lib/site-packages/torch/_inductor/freezing.py b/phivenv/Lib/site-packages/torch/_inductor/freezing.py new file mode 100644 index 0000000000000000000000000000000000000000..cae6938699f4b9384f5f8df5d485497e51ad0b14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/freezing.py @@ -0,0 +1,288 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import itertools +import logging +import weakref +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code +from torch._functorch.aot_autograd import MutationType +from torch._functorch.compile_utils import fx_graph_cse +from torch._inductor.constant_folding import constant_fold, replace_node_with_constant +from torch._inductor.freezing_utils import enter_freezing, record_has_frozen_params +from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from torch._inductor.fx_passes.post_grad import view_to_reshape + +from . import config + + +aten = torch.ops.aten +prims = torch.ops.prims + +log = logging.getLogger(__name__) + + +def replace_params_with_constants( + gm: torch.fx.GraphModule, + flat_params: list[Any], + fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, +) -> list[int]: + """ + Replaces the parameters of a PyTorch GraphModule with constants wherever possible. + Returns a list of indices representing the input parameters that were not converted to constants. 
+ """ + params = gm.graph.find_nodes(op="placeholder") + fake_inp_nodes = params[: len(params)] + preserved_arg_indices = [] + aliased_input_args = [ + out_info.base_idx + for out_info in fw_metadata.output_info + if out_info.base_idx is not None + ] + + # TODO (tmanlaibaatar) figure out why this is different + # from mutated_inp_runtime_indices + mutated_inps = [ + i + for i, m in enumerate(fw_metadata.input_info) + if m.mutation_type + in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] + + static_indices_new = [] + static_indices_offset = 0 + for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): + if i in mutated_inps or i in aliased_input_args: + preserved_arg_indices.append(i) + if i in fw_metadata.static_input_indices: + new_static_index = i - static_indices_offset + static_indices_new.append(new_static_index) + else: + replace_node_with_constant(gm, node, real_input) + static_indices_offset += 1 + # add on non param inputs + preserved_arg_indices.extend(range(len(flat_params), len(params))) + # is this necessary ? + fw_metadata.static_input_indices = static_indices_new + gm.recompile() + return preserved_arg_indices + + +def freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: list[torch._subclasses.FakeTensor], +) -> tuple[torch.fx.GraphModule, list[int]]: + """ + Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation + and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. + + Assumes that this function is run in dynamo tracing post aot_autograd. + + Args: + dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule. + aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen. + example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process. + + Returns: + Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices + of the inputs that were preserved (not turned into constants). + """ + with enter_freezing(): + return _freeze(dynamo_gm, aot_autograd_gm, example_inputs) + + +def _freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: list[torch._subclasses.FakeTensor], +) -> tuple[torch.fx.GraphModule, list[int]]: + # We have convert conv's weight to channels last which may meet error for .view + # when doing fake_tensor_prop. So we need to convert view to reshape first. + # See the details in fx_codegen_and_compile of compile_fx.py. + view_to_reshape(aot_autograd_gm) + + if tracing_context := torch._guards.TracingContext.try_get(): + fw_metadata = tracing_context.fw_metadata + assert tracing_context.params_flat_unwrap_subclasses is not None + params_flat = tracing_context.params_flat_unwrap_subclasses + assert fw_metadata is not None and params_flat is not None + + preserved_arg_indices = replace_params_with_constants( + aot_autograd_gm, params_flat, fw_metadata + ) + else: + inputs = aot_autograd_gm.graph.find_nodes(op="placeholder") + preserved_arg_indices = list(range(len(inputs))) + + # TODO - further restrict cse ? 
right now needed to dedup aliasing ops + cse_graph = fx_graph_cse(aot_autograd_gm.graph) + aot_autograd_gm.graph = cse_graph + aot_autograd_gm.recompile() + + aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] + freezing_passes(aot_autograd_gm, aot_example_inputs) + + constant_fold(aot_autograd_gm) + # invalidate nn Modules + if config.freezing_discard_parameters: + invalidate_eager_modules() + discard_traced_gm_params(dynamo_gm) + + log.debug( + "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True) + ) + + record_has_frozen_params(aot_autograd_gm) + return aot_autograd_gm, preserved_arg_indices + + +class ErasedTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, name, owning_mod): + return super().__new__(cls, elem.to(device="meta")) + + def __init__(self, elem, name: Optional[str], mod) -> None: + self.erased_name = name + self.owning_mod_ref = weakref.ref(mod) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + erased_tensors = [ + e + for e in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(e, ErasedTensor) + ] + assert len(erased_tensors) > 0 + e = erased_tensors[0] + + raise RuntimeError( + f"Trying to run Pytorch Eager Module after Dynamo Freezing. " + "The original parameters have been discarded for memory efficiency. " + f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}" + ) + + +def invalidate_eager_modules(): + with torch.utils._python_dispatch._disable_current_modes(): + for ( + mod + ) in torch._guards.TracingContext.get().module_context.nn_modules.values(): + if not isinstance(mod, torch.nn.Module): + continue + + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), + mod.named_buffers(recurse=False), + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True + setattr(mod, attr_name, e_t) + + +def discard_traced_gm_params(mod: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), mod.named_buffers(recurse=False) + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True + setattr(mod, attr_name, e_t) + + +def enforce_output_layout(gm: torch.fx.GraphModule): + """ + Make sure the output node's layout does not change due to compiler optimizations + by adding aten.as_strided nodes with the expected strides. + + Only used for inference so we can assume all graph outputs are model outputs. + """ + *_, output_node = gm.graph.nodes + out_list = output_node.args[0] + with gm.graph.inserting_before(output_node): + for n in out_list: + if not isinstance( + n.meta["val"], torch.Tensor + ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]): + continue + + # add a node to enforce eager layout + ft = n.meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n, ft.stride()) + ) + + # can not call + # n.replace_all_uses_with(new_node) + # since it will replace the usage of n in new_node itself. 
+ output_node.replace_input_with(n, new_node) + + gm.graph.lint() + gm.recompile() + + +def enforce_as_strided_input_layout(gm: torch.fx.GraphModule): + """ + Make sure the as_strided node's input's layout does not change due to compiler + optimizations, because the as_strided strides info depends on input tensor stride info. + """ + + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops] + for n in strided_nodes: + with gm.graph.inserting_before(n): + # add a node to enforce eager layout + ft = n.args[0].meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n.args[0], ft.stride()) + ) + n.replace_input_with(n.args[0], new_node) + + gm.graph.lint() + gm.recompile() + + +def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): + """ + Convert 4d convolution weight tensor to channels last format. + + This pass is performed before freezing so the added nodes can be constant + folded by freezing. + """ + with dynamo_timed("convert_conv_weights_to_channels_last"): + convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default] + for conv in convs: + weight_node = conv.args[1] + if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ + "val" + ].is_contiguous(memory_format=torch.channels_last): + # not a 4d tensor or already channels last, skip + continue + + with gm.graph.inserting_before(conv): + new_node = gm.graph.call_function( + aten.clone.default, + (weight_node,), + {"memory_format": torch.channels_last}, + ) + conv.replace_input_with(weight_node, new_node) + + enforce_as_strided_input_layout(gm) + enforce_output_layout(gm) diff --git a/phivenv/Lib/site-packages/torch/_inductor/freezing_utils.py b/phivenv/Lib/site-packages/torch/_inductor/freezing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f54dee5777fef4906b62c17f3435e959f2435113 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/freezing_utils.py @@ -0,0 +1,55 @@ +import contextlib +import threading +from collections.abc import Generator +from typing import Any + +import torch + + +_TLS = threading.local() + + +def _freezing_active() -> bool: + return getattr(_TLS, "freezing_active", False) + + +@contextlib.contextmanager +def enter_freezing() -> Generator[Any, None, None]: + """ + Context manager to designate when freezing is active. + """ + prev = _freezing_active() + _TLS.freezing_active = True + try: + yield + finally: + _TLS.freezing_active = prev + + +def record_has_frozen_params(gm: torch.fx.GraphModule) -> None: + """ + Mark the gm as having frozen params. + """ + gm._has_frozen_params = True # type: ignore[assignment] + + +def has_frozen_params(gm: torch.fx.GraphModule) -> bool: + """ + Return True if the gm has frozen parameters. + """ + return getattr(gm, "_has_frozen_params", False) + + +def maybe_set_is_frozen_param(t: torch.Tensor) -> None: + """ + Mark the provided tensor as a frozen param if freezing is active. + """ + if _freezing_active(): + t._is_frozen_param = True # type: ignore[attr-defined] + + +def is_frozen_param(t: torch.Tensor) -> bool: + """ + Return True if the tensor is a frozen param. 
+ """ + return getattr(t, "_is_frozen_param", False) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fuzzer.py b/phivenv/Lib/site-packages/torch/_inductor/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa61c3d11ce98cca7c695f3f3584d613b70e3c34 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fuzzer.py @@ -0,0 +1,1000 @@ +import importlib +import itertools +import logging +import pickle +import random +import signal +import string +import sys +import traceback +from collections.abc import KeysView, Sequence +from enum import Enum +from functools import partial, wraps +from types import FrameType +from typing import ( + Any, + Callable, + get_args, + get_origin, + Literal, + Optional, + TypeVar, + Union, +) + +import torch +from torch._inductor.custom_graph_pass import CustomGraphPass +from torch._inductor.scheduler import BaseSchedulerNode +from torch.utils._config_module import _ConfigEntry, ConfigModule +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +def is_type(type_hint, comp_type) -> bool: # type: ignore[no-untyped-def] + """ + Determines if type_hint is comp_type. There are some type annotations that this doesn't work for. + I think it's because some Type annotations are Type Objects and some are Special Forms, but not sure. + There's definite room for improvement to make this more general for someone who deeply understands + Python types. + """ + return type_hint is comp_type or get_origin(type_hint) is comp_type + + +def is_optional_type(type_hint) -> bool: # type: ignore[no-untyped-def] + """ + Special case of is_type. + """ + origin = get_origin(type_hint) + + if origin is Union: + args = get_args(type_hint) + return type(None) in args + + return False + + +def is_callable_type(type_hint) -> bool: # type: ignore[no-untyped-def] + """ + Special Case of is_type. + """ + return type_hint.__name__ == "Callable" + + +class DummyPass(CustomGraphPass): + """ + A Dummy pass to be used by ConfigFuzzer + """ + + def __call__(self, graph: torch.fx.graph.Graph) -> None: + return None + + def uuid(self) -> Optional[Any]: + return None + + +T = TypeVar("T") + + +class TypeExemplars: + """ + This class returns examples of a Type, given its class name. + """ + + TYPE_EXEMPLARS: dict[str, Any] = { + CustomGraphPass.__name__: DummyPass(), + torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(), + BaseSchedulerNode.__name__: BaseSchedulerNode(None), # type: ignore[arg-type] + } + + @staticmethod + def example(t: type[T]) -> Optional[T]: + """ + Return an example of a class. + """ + return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None) + + @staticmethod + def contains(t: type[T]) -> bool: + return t.__name__ in TypeExemplars.TYPE_EXEMPLARS + + +def check_halide_import() -> bool: + """checks if we have halide available""" + try: + importlib.import_module("halide") + return True + except ModuleNotFoundError: + return False + + +if check_halide_import(): + CUDA_BACKEND = ["triton", "halide"] +else: + CUDA_BACKEND = ["triton"] + + +class Status(Enum): + """ + The Status return value enum for Config Fuzzer + """ + + # ConfigFuzzer skipped the test + SKIPPED = "skipped" + # ConfigFuzzer compiled and ran the test and function it passed. 
+ PASSED = "passed" + # ConfigFuzzer failed to compile the test function + FAILED_COMPILE = "failed_compile" + # ConfigFuzzer compiled the test function and running it raised an exception + FAILED_RUN_COMPILE_EXCEPTION = "failed_run_compile_exception" + # ConfigFuzzer ran eager and it raised an exception + FAILED_RUN_EAGER_EXCEPTION = "failed_run_eager_exception" + # ConfigFuzzer compiled the test function, but the return value indicated that the compiled value didn't match the + # value from eager (or however else you set up the comparison in the test function) + FAILED_RUN_RETURN = "failed_run_return" + + def failing(self) -> bool: + """ + Convenience method to check whether these status represent failure. + """ + return ( + self == Status.FAILED_COMPILE + or self == Status.FAILED_RUN_EAGER_EXCEPTION + or self == Status.FAILED_RUN_COMPILE_EXCEPTION + or self == Status.FAILED_RUN_RETURN + ) + + +# Sometime the types of configs aren't expressive enough to be captured by python type system, so the options can be +# manually specified here: +# TODO this needs to be indexed to the module, like inductor or dynamo, for name collisions +TYPE_OVERRIDES: dict[str, list[Any]] = { + "cuda_backend": CUDA_BACKEND, + "post_grad_fusion_options": [ + { + "batch_linear_post_grad": { + "shape_broadcast_batch_linear": True, + "fuse_nodes_with_same_users": True, + }, + "batch_aten_mul": {"fuse_nodes_with_same_parent": False}, + "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": True}, + "batch_aten_add": {"fuse_nodes_with_same_parent": True}, + "normalization_aten_pass": {}, + "unbind_stack_aten_pass": {}, + }, + { + "batch_aten_add": {}, + "batch_aten_mul": {}, + "batch_aten_sub": {}, + "batch_aten_div": {}, + "group_linear": {"require_fbgemm": True}, + }, + ], + "autoheuristic_collect": ["pad_mm", "mixed_mm"], + "autoheuristic_use": ["pad_mm", "mixed_mm"], + "traceable_tensor_subclasses": [OrderedSet()], + "nontraceable_tensor_subclasses": [OrderedSet()], +} +SamplingType = Callable[[str, type[Any], Any], Any] + + +class SamplingMethod(Enum): + """ + This class handles the process of assigning concrete values to type annotations. So a type annotation of + ```python + foo: Optional[int] = None + ``` + Will be assigned an int if the dispatch function gets TOGGLE, or a 50/50 split between an int and None if it gets + RANDOM. + """ + + TOGGLE = "TOGGLE" # toggle to the opposite value + RANDOM = "RANDOM" # randomly choose an option + + @staticmethod + def _generate_value_for_type( + random_sample: bool, field_name: str, type_hint: type[Any], default: Any + ) -> Any: + """ + Generates a value of a type based on the setting. + """ + # look for name in type overrides + if field_name in TYPE_OVERRIDES: + return random.choice(TYPE_OVERRIDES[field_name]) + + if type_hint == bool: + return random.choice([True, False]) if random_sample else not default + elif type_hint == int: + # NOTE initially tried to use negation of the value, but it doesn't work because most types are ints + # when they should be natural numbers + zero. Python types to cover these values aren't super convenient. 
+ return random.randint(0, 1000) + elif type_hint == float: + return random.uniform(0, 1000) + elif type_hint == str: + characters = string.ascii_letters + string.digits + string.punctuation + return "".join( + random.choice(characters) for _ in range(random.randint(1, 20)) + ) + elif is_type(type_hint, list): + elem_type = getattr( + type_hint, + "__args__", + [type(default[0])] if default and len(default) else [type(None)], + )[0] + new_default = default[0] if default and len(default) > 0 else None + return [ + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + ] + elif is_type(type_hint, set): # noqa: set_linter + indexable = list(default) + elem_type = getattr( + type_hint, + "__args__", + [type(indexable[0])] if default and len(default) else [type(None)], + )[0] + new_default = indexable[0] if default and len(default) > 0 else None + return { # noqa: set_linter + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + } + elif is_type(type_hint, OrderedSet): + indexable = list(default) + elem_type = getattr( + type_hint, + "__args__", + [type(indexable[0])] if default and len(default) else [type(None)], + )[0] + new_default = indexable[0] if default and len(default) > 0 else None + return OrderedSet( + [ + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + ] + ) + elif is_type(type_hint, dict): + key_type, value_type = getattr( + type_hint, + "__args__", + map(type, next(iter(default.items()))) + if (default is not None and len(default)) + else (type(None), type(None)), + ) + if default is not None and len(default.items()) > 0: + default_key, default_val = next(iter(default.items())) + else: + default_key, default_val = None, None + return { + SamplingMethod._generate_value_for_type( + random_sample, field_name, key_type, default_key + ): SamplingMethod._generate_value_for_type( + random_sample, field_name, value_type, default_val + ) + for _ in range(random.randint(0, 3)) + } + elif is_type(type_hint, Union): + # do whatever is not the type of default + try: + assert len(type_hint.__args__) > 1 + except AttributeError as err: + raise ValueError("Union type with no args") from err + if random_sample: + new_type = random.choice(type_hint.__args__) + else: + new_type = random.choice( + [t for t in type_hint.__args__ if t != type(default)] + ) + try: + new_default = new_type() + except Exception: # noqa: E722 + # if default constructor doesn't work, try None + new_default = None + + return SamplingMethod._generate_value_for_type( + random_sample, field_name, new_type, new_default + ) + elif is_type(type_hint, tuple): + args = getattr( + type_hint, + "__args__", + tuple(map(type, default)), + ) + zipped = zip(args, default) + return tuple( + map( # noqa: C417 + lambda x: SamplingMethod._generate_value_for_type( + random_sample, field_name, x[0], x[1] + ), + zipped, + ) + ) + elif is_type(type_hint, Literal): + try: + if random_sample: + return random.choice(type_hint.__args__) + else: + choices = [t for t in type_hint.__args__ if t != default] + if choices: + return random.choice(choices) + else: + return default + except AttributeError as err: + raise ValueError("Literal type with no args") from err + elif is_optional_type(type_hint): + try: + elem_type = type_hint.__args__[0] + except AttributeError as err: + raise ValueError("Optional 
type with no args") from err + if random_sample: + return random.choice( + [ + None, + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, default + ), + ] + ) + else: + if default is None: + return SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, None + ) + else: + return None + elif type_hint is type(None): + return None + elif is_callable_type(type_hint): + try: + return_type = list(type_hint.__args__)[-1] + except AttributeError as err: + raise ValueError("Callable type with no args") from err + + @wraps(lambda *args, **kwargs: None) + def dummy_function(*args, **kwargs): # type: ignore[no-untyped-def] + return SamplingMethod._generate_value_for_type( + random_sample, field_name, return_type, None + ) + + return dummy_function + elif type_hint == torch._ops.OpOverload: + return torch.ops.aten.add.default + elif TypeExemplars.contains(type_hint): + return TypeExemplars.example(type_hint) + elif type_hint == Any: + return 1 if not default == 1 else 2 + else: + raise ValueError(f"Unable to process type {type_hint}. PRs welcome :)") + + @staticmethod + def dispatch(sm: "SamplingMethod") -> SamplingType: + """ + Returns a function that will generate values from a type, based on the SamplingMethod passed in. + """ + if sm == SamplingMethod.RANDOM: + return partial(SamplingMethod._generate_value_for_type, True) + elif sm == SamplingMethod.TOGGLE: + return partial(SamplingMethod._generate_value_for_type, False) + else: + raise ValueError(f"malformed sampling method: {sm}") + + +class Default: + """ + Singleton default object that will cause the ConfigFuzzer to always use the default value set in the config. + """ + + +DEFAULT = Default() + +# The combination of config settings being set (based on their strings) +ComboType = tuple[str, ...] + + +class ResultType: + """ + The mapping of the combo strings to the result status after running the config fuzzer. + """ + + _vals: dict[ComboType, Status] + + def __repr__(self) -> str: + return f"ResultType[{self._vals}]" + + def __init__(self) -> None: + self._vals = {} + + def __len__(self) -> int: + return len(self._vals) + + def num_ran(self) -> int: + """ + Returns how many combos actually ran (weren't skipped). + """ + ret = len(self._vals) + for status in self._vals.values(): + if status == Status.SKIPPED: + ret -= 1 + return ret + + def set(self, combo: ComboType, status: Status) -> None: + combo = tuple(sorted(combo)) + self._vals[combo] = status + + def lookup(self, combo: ComboType) -> Optional[Status]: + combo = tuple(sorted(combo)) + return self._vals.get(combo, None) + + def keys(self) -> KeysView[ComboType]: + return self._vals.keys() + + +# Type that maps config strings to their default value +ConfigType = dict[str, Any] +# Callable that returns a bool +FactoryOutputType = Callable[[], bool] +# input function factory +FactoryType = Callable[[], FactoryOutputType] + +# Why are some configs disabled by default? Because if we don't the fuzzer produces uninteresting results. +# It will always hone-in on these failures, even with the most basic model, making it useless for +# debugging more complex models. +# +# More explicit explanations are below: +# Out of Scope: We can't fuzz, say, the cuda version because that comes from the environment and will +# produce a failure if not aligned with env. +# Known Failure: Disabled due to known failure. Hopefully re-enable. Known failures are listed in the +# docstring of this file. 
+# Required: Required for the fuzzer to operate (removing caching, etc.) +# FSDP: Flag meant for FSDP that fails in non FSDP envs. Re-enable these if you're testing FSDP. +# Typing: disabled because the type annotation of the config isn't constrained enough to produce +# meaningful fuzz values. These could be improved. +# Timing: These take too long to compile, feel free to enable. +MODULE_DEFAULTS: dict[str, ConfigType] = { + "torch._inductor.config": { + "force_disable_caches": True, # Required + "cpp.cxx": DEFAULT, # Out of Scope + "TYPE_CHECKING": DEFAULT, # Not a config + "max_autotune_pointwise": DEFAULT, # Timing + "max_autotune_gemm": DEFAULT, # Timing, re-enable when autotune speed improvements merged. + "max_autotune_gemm_backends": DEFAULT, # Timing + "max_autotune_conv_backends": DEFAULT, # Timing + "max_autotune_gemm_search_space": DEFAULT, # Timing + "max_autotune_subproc_result_timeout_seconds": DEFAULT, # Timing + "max_autotune_subproc_graceful_timeout_seconds": DEFAULT, # Timing + "max_autotune_subproc_terminate_timeout_seconds": DEFAULT, # Timing + "aot_inductor.presets": DEFAULT, # Typing + "cuda.arch": DEFAULT, # Out of Scope + "cuda.version": DEFAULT, # Out of Scope + "cuda.cutlass_dir": DEFAULT, # Out of Scope + "cuda.cuda_cxx": DEFAULT, # Out of Scope + "rocm.arch": DEFAULT, # Out of Scope + "rocm.ck_supported_arch": DEFAULT, # Out of Scope + "rocm.ck_dir": DEFAULT, # Out of Scope + "rocm.rocm_home": DEFAULT, # Out of Scope + "check_stack_no_cycles_TESTING_ONLY": DEFAULT, # Testing + "sleep_sec_TESTING_ONLY": DEFAULT, # Testing + "triton.inject_relu_bug_TESTING_ONLY": DEFAULT, # Testing + "reorder_for_compute_comm_overlap": DEFAULT, # FSDP + "enabled_metric_tables": DEFAULT, # Typing + "triton.debug_sync_graph": DEFAULT, # Known Failure + "triton.debug_sync_kernel": DEFAULT, # Known Failure + "profile_bandwidth_regex": DEFAULT, # Known Failure + "disable_cpp_codegen": DEFAULT, # Known Failure + "trace.save_real_tensors": DEFAULT, # Known Failure + "pre_grad_fusion_options": DEFAULT, # Typing + "external_matmul": DEFAULT, # Typing, need to add this to type overrides or type exemplars. + "test_configs.autotune_choice_name_regex": DEFAULT, # Typing + "test_configs.autotune_choice_desc_regex": DEFAULT, # Typing + "cpp.enable_floating_point_contract_flag": DEFAULT, # Typing + "post_grad_custom_pre_pass": DEFAULT, # Typing + "post_grad_custom_post_pass": DEFAULT, # Typing + "reorder_for_compute_comm_overlap_passes": DEFAULT, # Typing + "joint_custom_post_pass": DEFAULT, # Typing + "joint_custom_pre_pass": DEFAULT, # Typing + "pre_grad_custom_pass": DEFAULT, # Typing + }, + "torch._dynamo.config": { + "traceable_tensor_subclasses": DEFAULT, # Typing + "nontraceable_tensor_subclasses": DEFAULT, # Typing + "compiled_autograd_kwargs_override": DEFAULT, # Typing + "fail_on_recompile_limit_hit": DEFAULT, # fails in combo with suppress_errors + "suppress_errors": DEFAULT, + }, +} + + +class ConfigFuzzer: + """ + This tool makes it easy to search through config state-space with a minimal reproduction or test, either for + debugging or just bug hunting. + It has two entry points: + - bisect, which randomly flips configs and tries to find the minimal reproduction upon failure. + - fuzz_n_tuple, which tries every combination of n configs. This grows quickly as a function of n, so beware. + bisect is recommended, but fuzz_n_tuple can give you peace of mind that a new config will compose with + every other config. 
+ + The main interface is a function factory that will return Callables to be torch.compiled. This function factory + should return a test function when it's called. Said test function returns a boolean, which determines whether + the ConfigFuzzer considers it a successful run or not. Throwing an exception from within the function will be + considered a failure as well. + + # Example usage: + + ```python + import torch._inductor.config as cfg + + + def create_simple_test_model_gpu() -> FactoryOutputType: + batch_size = 32 + seq_length = 50 + hidden_size = 768 + + def test_fn() -> bool: + inp = torch.randn(batch_size, seq_length, hidden_size, device="cuda") + weight = torch.randn(hidden_size, hidden_size, device="cuda") + matmul_output = inp @ weight + final_output = torch.nn.LayerNorm(hidden_size, device="cuda")(matmul_output) + return True + + return test_fn + + + fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2) + + # Test every pair of configs: + results = fuzzer.fuzz_n_tuple(n, max_combinations=10000000) + + visualize_results(n, results) + + # Test random configs with bisection: + ret = fuzzer.bisect(num_attempts=10) + + # reproduce a failing config + fuzzer.reproduce( + [{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}] + ) + ``` + + The list of known failures on inductor config are: + cpp_wrapper, triton_debug_sync_graph + cpp_wrapper, triton_debug_sync_kernel + cpp_wrapper, disable_cpp_codegen + combo_kernels, benchmark_combo_kernel, profile_bandwidth, profile_bandwidth_regex + trace.enabled, trace.save_real_tensors + """ + + sample: SamplingType + default: ConfigType + + def __init__( + self, + config_module: ConfigModule, + test_model_fn_factory: FactoryType, + seed: int, + default: Optional[ConfigType] = None, + sm: SamplingMethod = SamplingMethod.TOGGLE, + test_timeout: int = 3600, + ): + """ + Args: + config_module: The module containing the configs to fuzz + test_model_fn_factory: Function that returns a test model, which runs and returns True if successful, or + the outputs if they should be compared with eager + seed: Randomness seed. + default: Default values for the config. Inductor has preset based on know failures. + sm: How type value samples are generated, default TOGGLE. + test_timeout: max time a test can take. 
+ """ + if sys.version_info < (3, 10): + log.error("Only python 3.10 and later supported") + return + self.seed = seed + self.test_timeout = test_timeout + self.detailed_results: dict[ComboType, dict[str, Any]] = {} + self.config_module = config_module + self.test_model_fn_factory = test_model_fn_factory + self.fields: dict[str, _ConfigEntry] = self.config_module._config + self.sample = SamplingMethod.dispatch(sm) + + if default is None: + if self.config_module.__name__ in MODULE_DEFAULTS: + self.default = MODULE_DEFAULTS[self.config_module.__name__] + else: + raise ValueError("No default passed to ConfigFuzzer.") + else: + self.default = default + + def __repr__(self) -> str: + return ( + f"ConfigFuzzer(config_module={self.config_module}, " + f"test_model_fn_factor={self.test_model_fn_factory}, seed={self.seed}, default={self.default})" + ) + + def _set_config(self, field_name: str, value: Any) -> None: + """Set a config value in the module.""" + setattr(self.config_module, field_name, value) + + def _reset_configs(self) -> None: + """Reset all configs to their default values.""" + for field_name, field_obj in self.fields.items(): + self._set_config(field_name, field_obj.default) + + def new_config(self) -> ConfigType: + """creates a new config from the default""" + ret = { + name: val if val != DEFAULT else self.fields[name].default + for name, val in self.default.items() + } + return ret + + def reproduce(self, configs: Sequence[ConfigType]) -> ResultType: + """entrypoint to reproduce any failure""" + results = ResultType() + for conf in configs: + self._reproduce_single_helper(conf, results) + return results + + def _reproduce_single_helper(self, conf: ConfigType, results: ResultType) -> None: + print(f"Starting repro of {conf}") + new_config = self.new_config() + new_config.update(conf) + self.test_config(results, new_config) + print(f"Status of {conf}:\n{results.lookup(tuple(conf.keys()))}") + + def reproduce_single(self, config: ConfigType) -> ResultType: + results = ResultType() + self._reproduce_single_helper(config, results) + return results + + def _fuzz_helper(self, results: ResultType, combo: ComboType) -> Status: + print(combo) + if st := results.lookup(combo): + # we already processed this config + return st + + config = self.new_config() + + skip = False + for field_name in combo: + if field_name in config: + # don't break here because we need to build the config dict + skip = True + if field_name.startswith("_"): + skip = True + field = self.fields[field_name] + value = self.sample(field_name, field.value_type, field.default) + config[field_name] = value + if skip: + results.set(combo, Status.SKIPPED) + return Status.SKIPPED + + return self.test_config(results, config) + + def fuzz_n_tuple(self, n: int, max_combinations: int = 1000) -> ResultType: + """ + Test every combination of n configs. + + returns a dict of this shape: {(config-1, config-2... 
config-n): status} + """ + results = ResultType() + print(f"Starting {n}-tuple testing with seed {self.seed}") + random.seed(self.seed) + + for combo in itertools.combinations(self.fields, n): + st = self._fuzz_helper(results, combo) + if st != Status.SKIPPED: + max_combinations -= 1 + if max_combinations <= 0: + print("Reached maximum combinations limit") + break + + return results + + def save_state(self, filename: str = "fuzzer_state.pkl") -> None: + """Save the current fuzzer state to a file""" + with open(filename, "wb") as f: + pickle.dump( + {"results": self.results, "detailed_results": self.detailed_results}, f + ) + + def load_state(self, filename: str = "fuzzer_state.pkl") -> None: + """Load fuzzer state from a file""" + with open(filename, "rb") as f: + state = pickle.load(f) + self.results = state["results"] + self.detailed_results = state.get("detailed_results", {}) + + def timeout_handler(self, signum: int, frame: Optional[FrameType]) -> None: + raise TimeoutError("Test execution timed out") + + def test_config(self, results: ResultType, config: ConfigType) -> Status: + """ + Tests a config by calling the function produced by the factory function. + """ + original_handler = signal.signal(signal.SIGALRM, self.timeout_handler) + signal.alarm(self.test_timeout) + print(f"Testing config {config}") + config_tuple = tuple(config.keys()) + if ret := results.lookup(config_tuple): + signal.signal(signal.SIGALRM, original_handler) + return ret + + def print_config() -> None: + for field, value in config.items(): + print(f"{field} = {value}") + + def get_error_info(exc: Exception) -> dict[str, Any]: + return { + "exception": str(exc), + "traceback": traceback.format_exc(), + "config": config.copy(), + } + + def handle_return( + message: str, + return_status: Status, + print_traceback: bool, + exc: Optional[Exception], + ) -> Status: + signal.signal(signal.SIGALRM, original_handler) + print(f"{message} with config combination:") + print_config() + if exc: + self.detailed_results[config_tuple] = get_error_info(exc) + if print_traceback: + traceback.print_exc() + results.set(config_tuple, return_status) + return return_status + + # reset config + torch._dynamo.reset() + self._reset_configs() + for name, value in config.items(): + self._set_config(name, value) + + # try running eager + test_model_fn = self.test_model_fn_factory() + try: + test_model_fn() + except Exception as exc: # noqa: E722 + return handle_return( + "Eager exception", Status.FAILED_RUN_EAGER_EXCEPTION, True, exc + ) + + # try compilation + try: + test_model_fn2 = self.test_model_fn_factory() + comp = torch.compile(test_model_fn2, backend="inductor") + except Exception as exc: # noqa: E722 + return handle_return( + "Exception compiling", Status.FAILED_COMPILE, True, exc + ) + + # try running compiled + try: + compile_result = comp() + except Exception as exc: # noqa: E722 + return handle_return( + "Exception running compiled", + Status.FAILED_RUN_COMPILE_EXCEPTION, + True, + exc, + ) + + # bool return value means don't compare with eager + if not compile_result: + return handle_return( + "Function returned False", Status.FAILED_RUN_RETURN, False, None + ) + else: + return handle_return("Function succeeded", Status.PASSED, False, None) + + def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]: + """ + Test configs and bisect to minimal failing configuration. 
+ """ + print(f"Starting random testing with bisection, seed {self.seed}, and p {p}") + random.seed(self.seed) + self._reset_configs() + results = ResultType() + ret: list[ConfigType] = [] + + for attempt in range(num_attempts): + print(f"Random attempt {attempt + 1}/{num_attempts}") + + config = self.new_config() + + for field_name, config_entry in self.fields.items(): + if ( + field_name not in config + and not field_name.startswith("_") + and "TESTING_ONLY" not in field_name + and random.random() < p + ): + value = self.sample( + field_name, config_entry.value_type, config_entry.default + ) + config[field_name] = value + + status = self.test_config(results, config) + if status not in OrderedSet([Status.PASSED, Status.SKIPPED]): + if minimal_failing_config := self._bisect_failing_config( + results, config + ): + print(f"Minimum failing config: {minimal_failing_config}") + ret.append(minimal_failing_config) + + return ret + + def _bisect_failing_config( + self, results: ResultType, failing_config: ConfigType + ) -> Optional[ConfigType]: + return self._bisect_failing_config_helper(results, list(failing_config.items())) + + def _bisect_failing_config_helper( + self, results: ResultType, failing_config: list[tuple[str, Any]] + ) -> Optional[ConfigType]: + """ + Bisect a failing configuration to find minimal set of configs that cause failure. + + Splits it into halves, then fourths, then tries dropping configs one-by-one. + """ + print(f"bisecting config: {failing_config}") + + if not failing_config: + return None + + def test(x: list[tuple[str, Any]]) -> Status: + d = dict(x) + result = self.test_config(results, d) + return result + + if len(failing_config) <= 1: + return dict(failing_config) if test(failing_config).failing() else None + + random.shuffle(failing_config) + + mid = len(failing_config) // 2 + first_half = failing_config[:mid] + second_half = failing_config[mid:] + if test(first_half).failing(): + return self._bisect_failing_config_helper(results, first_half) + if test(second_half).failing(): + return self._bisect_failing_config_helper(results, second_half) + + if len(failing_config) >= 8: + low = len(failing_config) // 4 + high = mid + low + quart1 = failing_config[low:] + if test(quart1).failing(): + return self._bisect_failing_config_helper(results, quart1) + quart2 = failing_config[:low] + second_half + if test(quart2).failing(): + return self._bisect_failing_config_helper(results, quart2) + quart3 = first_half + failing_config[:high] + if test(quart3).failing(): + return self._bisect_failing_config_helper(results, quart3) + quart4 = failing_config[high:] + if test(quart4).failing(): + return self._bisect_failing_config_helper(results, quart4) + # try dropping one value at a time + for i in range(len(failing_config)): + new_list = [x for j, x in enumerate(failing_config) if j != i] + if test(new_list).failing(): + return self._bisect_failing_config_helper(results, new_list) + # we have the minimal set + return dict(failing_config) + + +def visualize_results( + n: int, results: ResultType, filename: str = "results.html" +) -> None: + """ + Creates an HTML document representing the results of running the fuzzer with fuzz_n_tuple, with n = 2. + """ + # TODO support more dimensions + assert n == 2 + assert len(results) > 0 + + input_set: OrderedSet[str] = OrderedSet({}) + for key in results.keys(): + input_set.add(key[0]) + input_set.add(key[1]) + input_list = sorted(input_set) + + # Start the HTML content + html_content = """ + + + + + + Fuzzer Visualization + + + +

Fuzzer Visualization
+ + + """ + + html_content += "<tr><th>\\</th>" + for col_name in input_list: + col = "<br>".join(col_name) + html_content += f"<th>{col}</th>" + html_content += "</tr>" + + # Add table rows + for row_name in input_list: + html_content += f"<tr><th>{row_name}</th>" + for col_name in input_list: + # Determine the status class for the cell + status_enum = results.lookup((row_name, col_name)) + status_class = "" + status_val = "" + if status_enum == Status.SKIPPED: + status_class = "skipped" + status_val = "-" + elif status_enum == Status.PASSED: + status_class = "passed" + status_val = "O" + elif status_enum == Status.FAILED_RUN_EAGER_EXCEPTION: + status_class = "failed" + status_val = "e" + elif status_enum == Status.FAILED_RUN_COMPILE_EXCEPTION: + status_class = "failed" + status_val = "E" + elif status_enum == Status.FAILED_RUN_RETURN: + status_class = "failed" + status_val = "R" + elif status_enum == Status.FAILED_COMPILE: + status_class = "failed" + status_val = "C" + else: + status_class = "skipped" + status_val = "-" + + html_content += f'<td class="{status_class}">{status_val}</td>' + html_content += "</tr>" + + html_content += """ + </table> + </body> + </html>
+ + + """ + + with open(filename, "w") as file: + file.write(html_content) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_utils.py b/phivenv/Lib/site-packages/torch/_inductor/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56cd454d5c80a5e3f6550835b18555e0b42bc18e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_utils.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import operator +from collections import defaultdict +from typing import Any, Callable, Optional + +import sympy + +import torch +import torch.fx +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + statically_known_true, + sym_eq, +) +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map +from torch.utils.flop_counter import flop_registry + +from .virtualized import V + + +# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. +# Works for length 2 patterns with 1 module and 1 function/method. +def matches_module_function_pattern( + pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: dict[str, torch.nn.modules.Module], +) -> bool: + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function or call_method + if node.op != "call_function" and node.op != "call_method": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +class FakeTensorUpdater: + """ + The main idea here is that it's difficult to maintain accurate fake + tensors (our primary form of metadata) for each node in our graph as we + transform it. + + The most reliable way to obtain this information is by rerunning + faketensor propagation. However, in general, faketensor propagation is + fairly expensive. So, instead we'd like to only rerun faketensor + propagation on nodes that have changed. + + In order to detect which nodes have changed, we first hash its node, + target, and argument lists (which are immutable in FX). + + Then, whenever we call incremental_update, we check which FX nodes have a + new hash, and recompute the faketensor metadata for that node. Then, we + continue to recursively compute the faketensors for all users until the + fake tensors stop changing. 
+ """ + + def __init__(self, graph: torch.fx.Graph) -> None: + self.processed_hashes = OrderedSet[Any]() + self.graph = graph + + for node in self.graph.nodes: + self.processed_hashes.add(self.hash_node(node)) + + def hash_node(self, node: torch.fx.Node): + # todo(chilli): Not a great hash function + return (node, node.target, id(node.args), id(node.kwargs)) + + def incremental_update(self): + existing_storages: defaultdict[Optional[int], int] = defaultdict(int) + for node in self.graph.nodes: + existing_storages[get_node_storage(node)] += 1 + + def is_intlist_same(new, old): + return statically_known_true(sym_eq(new, old)) + + def is_fake_tensor_same(new, old): + if type(new) != type(old): + return False + if isinstance(new, (list, tuple)): + if len(new) != len(old): + return False + return all( + is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) + ) + if new is None: + return old is None + if not isinstance(new, torch.Tensor): + assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), ( + f"Unknown type {type(new)} in {self.graph}" + ) + return ( + new.node.shape_env._maybe_evaluate_static( + sympy.Eq(new.node.expr, old.node.expr) + ) + == sympy.true + ) + if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: + return False + if new.layout == torch.strided and ( + not is_intlist_same(new.stride(), old.stride()) + or not statically_known_true( + new.storage_offset() == old.storage_offset() + ) + ): + return False + + if new.device != old.device: + return False + + if get_storage(new) == get_storage(old): + return True + + # This is the case where it returns a completely fresh storage that's used nowhere else. + if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + ): + return True + return False + + def should_process_node(node): + # node.target for nodes returning true from this function + # are called under fake mode and does not work for inductor + # lowerings. We check if the node.target is an aten operator + # or operator.getitem which is used when returning multiple + # tensors from an op. 
+ return node.op == "call_function" and ( + isinstance(node.target, torch._ops.OpOverload) + or node.target == operator.getitem + ) + + to_process = OrderedSet[int]() + for node in self.graph.nodes: + if ( + self.hash_node(node) in self.processed_hashes + and id(node) not in to_process + ): + continue + + if not should_process_node(node): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + with V.fake_mode, enable_python_dispatcher(): + new_fake_tensor = node.target(*args, **kwargs) + if "val" in node.meta and is_fake_tensor_same( + new_fake_tensor, node.meta["val"] + ): + continue + + rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor) + + node.meta["val"] = new_fake_tensor + if (shape_env := V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor) + ): + # Refresh the bindings to the new symbols + node.meta["unbacked_bindings"] = symbol_to_path + + existing_storages[get_node_storage(node)] += 1 + + to_process.update([id(user) for user in node.users]) + + self.processed_hashes.add(self.hash_node(node)) + + +def get_storage(t: torch.Tensor) -> int: + return t.untyped_storage()._cdata + + +def get_node_storage(node: torch.fx.Node) -> Optional[int]: + if "val" not in node.meta: + return None + if not isinstance(node.meta["val"], torch.Tensor): + return None + if not torch._C._has_storage(node.meta["val"]): + return None + return get_storage(node.meta["val"]) + + +def get_fake(x): + if isinstance(x, torch.fx.Node): + if "val" not in x.meta: + return x + return x.meta["val"] + return x + + +def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]: + """ + First value returns a boolean if any of the input nodes don't have a faketensor. + """ + args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) + if any( + isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) + ): + return False, args, kwargs + return True, args, kwargs + + +def is_node_realized(node: torch.fx.Node) -> bool: + """Returns true if a node is always realized when lowered to inductor IR. + + NOTE: This may return some false negatives. e.g. it doesn't + handle buffers realized heuristically during lowering, or + buffers realized indirectly through view ops. + """ + from torch._inductor.lowering import fallbacks, needs_realized_inputs + + def is_buffer(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target is operator.getitem: + # For nodes with multiple outputs, we get the fx graph: + # foo = torch.ops.aten.foo(...) 
+ # getitem = foo[0] + # getitem_1 = foo[1] + # where we need to check if foo is a fallback kernel + return is_buffer(node.args[0]) # type: ignore[arg-type] + return node.op in ("placeholder", "output") or node.target in fallbacks + + if is_buffer(node): + return True + + def realizes_inputs(node: torch.fx.Node) -> bool: + return node.op == "output" or node.target in needs_realized_inputs + + if any(realizes_inputs(user) for user in node.users): + return True + + # Otherwise, assume node isn't realized + return False + + +def count_flops_fx(node: torch.fx.Node) -> Optional[int]: + if isinstance(node.target, str): + return None + with FakeTensorMode(allow_non_fake_inputs=True): + success, args, kwargs = get_fake_args_kwargs(node) + + if success: + with torch.utils.flop_counter.FlopCounterMode( + display=False + ) as flop_counter_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + return counted_flops + return None + + +def countable_fx(node: torch.fx.Node) -> bool: + """ + Whether or not we can count the flops of an FX node. + """ + assert isinstance(node, torch.fx.Node) + if not hasattr(node, "target"): + return False + target = node.target + if not hasattr(target, "overloadpacket"): + return target in flop_registry + packet = target.overloadpacket + return packet in flop_registry diff --git a/phivenv/Lib/site-packages/torch/_inductor/graph.py b/phivenv/Lib/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..74ea677b8606ca88deb5f3570d26e12d15e63169 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/graph.py @@ -0,0 +1,2432 @@ +from __future__ import annotations + +import contextlib +import functools +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union + +import sympy +from sympy import Expr + +import torch +import torch._logging +import torch.fx +from torch import device, Tensor +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.utils import get_layout_constraint_tag +from torch._logging import LazyString, trace_structured +from torch._prims_common import ( + compute_required_storage_length, + make_channels_last_strides_for, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import full_aoti_runtime_assert +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + _get_placeholder_expr, + free_unbacked_symbols, + has_free_symbols, + resolve_unbacked_bindings, + RuntimeAssert, + ShapeEnv, + SympyBoolean, + SymTypes, +) +from torch.fx.node import Node +from torch.utils._mode_utils import no_dispatch +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo + +from . 
import config, ir, metrics +from .codegen.common import ( + BackendFeature, + DeviceOpOverrides, + FileBackedGraphModule, + get_backend_features, + get_device_op_overrides, + get_wrapper_codegen_for_device, + init_backend_registration, + WorkspaceArg, +) +from .exc import ( + CppWrapperCodegenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .ir import ( + Constant, + DonatedBuffer, + FixedLayout, + get_device_type, + GraphPartitionSignature, + InputBuffer, + Pointwise, + Reduction, + StorageBox, + TensorBox, + TorchBindObject, +) +from .lowering import ( + constrain_to_fake_tensors, + constrain_to_fx_strides, + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + lowerings, + make_fallback, + maybe_layout_constraints, + needs_realized_inputs, + require_contiguous, + tag_to_layout_constraint, + unsupported_output_tensor, +) +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler +from .sizevars import SizeVarAllocator +from .utils import ( + convert_shape_to_inductor, + gather_origins, + get_cloned_parameter_buffer_name, + get_donated_idxs, + get_sympy_Expr_dtype, + GraphPartitionMap, + is_same_tensor, + maybe_get_suppress_shape_guards_ctx, + normalize_name, + should_assume_input_aligned, + SUPPORTED_MKLDNN_DEVICES, + ValueWithLineMap, +) +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + from types import ModuleType + + from torch._higher_order_ops.effects import _EffectType + from torch.fx import GraphModule + from torch.fx.graph import Graph + + from .codegen.wrapper import PythonWrapperCodegen + from .scheduler import BaseSchedulerNode + + CompiledModule = Union[ModuleType, FileBackedGraphModule] + +from torch._inductor.codecache import output_code_log + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + +aten = torch.ops.aten + +_post_grad_graph_counter = itertools.count() + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args: Any, **kwargs: Any) -> None: + pass + + +def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), ( + "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + ) + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op: Any) -> bool: + magic_ops = OrderedSet(method_to_operator(m) for m in magic_methods) + return op in magic_ops + + +def getattr_recursive( + obj: GraphModule, target: str +) -> Union[Tensor, torch._C.ScriptObject, GraphModule]: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]: + ret: dict[Node, tuple[int, ...]] = {} + output_node = g.find_nodes(op="output")[0] + + if 
"user_visible_output_idxs" not in output_node.meta: + return ret + + if not isinstance(output_node.args[0], torch.fx.Node): + output_node_args = output_node.args[0] + else: + output_node_args = output_node.args + + for idx, node in enumerate(output_node_args): + if idx in output_node.meta["user_visible_output_idxs"]: + ret[node] = output_node.meta["original_output_strides"][idx] + return ret + + +def mark_nodes_dislike_padding( + g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]] +) -> None: + """ + Nodes like convolution/convolution_backward want its input to be dense. + If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. + + The pass finds nodes that dislike padding. These are nodes that can be reached + from a convolution/convolution_backward in the backward direction without + going thru a reduction. + """ + if not config.comprehensive_padding: + return + ops_dislike_padding = OrderedSet( + [ + aten.convolution, + aten.convolution_backward, + aten._scaled_mm, + ] + ) + # what's a better way to collect the reduction ops? + ops_like_padding = OrderedSet( + [ + aten.var_mean, + aten.sum, + aten.mean, + aten.prod, + aten.any, + aten.amin, + aten.amax, + aten.min, + aten.max, + aten.argmin, + aten.argmax, + aten.scatter_reduce, + ] + ) + + def _get_overload_packet( + node: torch.fx.Node, + ) -> Optional[torch._ops.OpOverloadPacket]: + return ( + node.target._overloadpacket + if node.op == "call_function" + # hasattr on OpOverloadPacket is slow, do isinstance first + and isinstance(node.target, torch._ops.OpOverload) + and hasattr(node.target, "_overloadpacket") + else None + ) + + for cur in reversed(g.nodes): + if isinstance( + cur.target, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + ): + cur.meta["dislike_padding"] = True + continue + + if ( + isinstance(cur.target, torch._ops.OpOverload) + and get_layout_constraint_tag(cur.target) + == torch._C.Tag.needs_exact_strides + ): + cur.meta["dislike_padding"] = True + continue + + op = _get_overload_packet(cur) + if not op: + continue + if op in ops_dislike_padding: + cur.meta["dislike_padding"] = True + + if cur.meta.get("dislike_padding", False): + # propagate + for prior in cur.all_input_nodes: + prior_op = _get_overload_packet(prior) + if not prior_op: + continue + if prior_op not in ops_like_padding: + prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. 
+ if not config.pad_outputs and cur in user_visible_output_strides: + cur.meta["dislike_padding"] = True + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: list[ir.IRNode] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Optional[Sequence[object]] = None, + shape_env: Optional[ShapeEnv] = None, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[ + Callable[[list[ir.ExternKernelNode]], Any] + ] = None, + is_inference: bool = False, + is_backward: bool = False, + is_const_graph: bool = False, + const_output_index: Optional[dict[str, int]] = None, + const_wrapper_code: Optional[str] = None, + const_kernel_code: Optional[str] = None, + const_module: Optional[GraphLowering] = None, + name: Optional[str] = None, + inputs_to_check: Optional[Sequence[int]] = None, + ) -> None: + super().__init__(gm) + self.example_inputs = example_inputs + self.layout_opt = ( + layout_opt + if layout_opt is not None + else self.decide_layout_opt(gm, is_inference=is_inference) + ) + self.num_channels_last_conv = 0 + self.is_inference = is_inference + self.is_backward = is_backward + self.is_const_graph = is_const_graph + self.const_wrapper_code = const_wrapper_code + self.const_kernel_code = const_kernel_code + self.const_module = const_module + self.inputs_to_check = inputs_to_check + + self.extra_traceback = False # we do our own error wrapping + if shape_env is None: + shape_env = ShapeEnv() + self.reuse_shape_env = False + else: + self.reuse_shape_env = True + self._shape_env = shape_env + # We're going to mutate ras_by_symbol as we finish generating them + self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = ( + shape_env.deferred_runtime_asserts.copy() + ) + self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() + self.sizevars = SizeVarAllocator(shape_env) + self.graph_input_names: list[str] = [] + self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} + self.graph_inputs_original: dict[str, InputBuffer] = {} + self.partition_maps: Optional[list[GraphPartitionMap]] = None + self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet() + self.device_types: OrderedSet[str] = ( + const_module.device_types if const_module else OrderedSet() + ) + self.device_idxs: OrderedSet[int] = ( + const_module.device_idxs if const_module else OrderedSet() + ) + self.device_type = "cpu" + + # Inplace padding may require Inductor to allocate slightly larger + # tensor for padding. 
+ self.buffer_to_padded_size: dict[str, list[int]] = {} + + self.buffers: list[ir.Buffer] = [] + self.operations: list[ir.Operation] = [] + self.const_output_index: dict[str, int] = ( + const_output_index if const_output_index else {} + ) + self.folded_constants: OrderedSet[str] = ( + OrderedSet(const_output_index.keys()) + if const_output_index + else OrderedSet() + ) + self.constants: dict[str, torch.Tensor] = ( + const_module.constants if const_module else {} + ) + self.named_buffers: dict[str, torch.Tensor] = ( + const_module.named_buffers if const_module else {} + ) + self.named_parameters: dict[str, torch.Tensor] = ( + const_module.named_parameters if const_module else {} + ) + self.torchbind_constants: dict[ + str, Union[torch._C.ScriptObject, FakeScriptObject] + ] = {} + self.seen_subgraphs: dict[str, ir.Subgraph] = {} + self.constant_reprs: dict[str, str] = {} + self.removed_operations: OrderedSet[str] = OrderedSet() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.removed_inplace_buffers: OrderedSet[str] = OrderedSet() + self.mutated_buffers: OrderedSet[str] = OrderedSet() + self.never_reuse_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] + self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment] + # See `ProxyExecutor Design Note` in ir.py for more details + self.extern_kernel_nodes: list[ir.ExternKernelNode] = [] + + from torch._inductor.extern_node_serializer import extern_node_json_serializer + + self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = ( + extern_node_serializer + if config.is_fbcode() and extern_node_serializer + else extern_node_json_serializer + ) + + self.current_node: torch.fx.Node = None # type: ignore[assignment] + self.lists: dict[str, list[str]] = {} + self.mutated_inputs: OrderedSet[str] = OrderedSet() + self.mutated_input_idxs: list[int] = [] + self.name_to_buffer: dict[str, ir.Buffer] = {} + self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list) + self.name_to_op: dict[str, ir.Operation] = {} + self.creation_time = time.time() + self.name = name # type: ignore[assignment] + self.cpp_wrapper = cpp_wrapper + + # record multi_kernel choice for cpp_wrapper so the second pass knows + # which sub-kernel is picked. Copy cpp_wrapper to another variable + # since cpp_wrapper flag is OrderedSet to false for the first pass of codegen. + self.record_multi_kernel_choice = cpp_wrapper + self.multi_kernel_to_choice: dict[str, str] = {} + + self.aot_mode = aot_mode + self.graph_id = graph_id + self.post_grad_graph_id = next(_post_grad_graph_counter) + self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] + + # record intermediate results for input of UsedDefinedTritonKernels + # This will be used if autotuning is done in one pass. 
+ self.autotuning_inputs: Optional[list[torch.Tensor]] = None + self.autotuning_mapping: Optional[dict[str, dict[str, int]]] = None + self.autotuning_grids: Optional[dict[str, Any]] = None + + # current_device is set only during codegen of a device-specific kernel + # a graph can have many devices + self.current_device: Optional[torch.device] = None + + self.nodes_prefer_channels_last = ( + self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() + ) + self._warned_fallback = OrderedSet(["aten.convolution_backward"]) + self.user_visible_output_strides = get_user_visible_output_strides(gm.graph) + mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides) + self.cache_key: str = "" # This is the cache key for the compiled artifact + self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored + self.cache_linemap: list[ + tuple[int, str] + ] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run + # Used if lowering encounters cases where cudagraphs are not supported + self.disable_cudagraphs_reason: Optional[str] = None + + # only keeping one node per device for stack trace purposes + self.device_node_mapping: dict[torch.device, torch.fx.Node] = {} + self.orig_gm: torch.fx.GraphModule = gm.__copy__() + for k, v in self.orig_gm.named_buffers(): + self.named_buffers[k] = v + for k, v in self.orig_gm.named_parameters(): + self.named_parameters[k] = v + self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr] + "dynamo_flat_name_to_original_fqn", {} + ) + self.allocated_constant_name: dict[str, str] = ( + const_module.allocated_constant_name if const_module is not None else {} + ) + init_backend_registration() + self.get_backend_features = functools.lru_cache(None)(get_backend_features) + + self.effectful_ops: dict[_EffectType, ir.Buffer] = {} + # Track the buffers that we know is unaligned + # This can either be a graph input or the output of fallback + # kernels. + self.unaligned_buffers: OrderedSet[str] = OrderedSet() + self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet() + + self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet() + # more aggressive prologue fusion + self.invoke_quant_ops: OrderedSet[str] = OrderedSet() + + # Below field is related to printing debug intermediate tensor values info for debugging + self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() + + # state used by for Kernel.workspace + self.workspace_id = itertools.count() + + # track the current placeholder index that we are processing + self.placeholder_idx = -1 + + self.bw_donated_idxs = get_donated_idxs() + + def freeze_runtime_asserts(self) -> None: + self._shape_env.freeze_runtime_asserts() + + def symbolic_sizes_strides( + self, ex: torch.Tensor + ) -> tuple[Sequence[Union[int, Expr]], Sequence[Union[int, Expr]]]: + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? 
+ # NB: This is using the legacy default behavior from + # create_symbolic_sizes_strides_storage_offset but we hope we can + # just delete this entirely + source = ConstantSource( + f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" + ) + ( + size, + stride, + _, + ) = self._shape_env.create_symbolic_sizes_strides_storage_offset( + ex, + source, + ) + + r_size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] + r_stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] + return r_size, r_stride + + def static_sizes_strides( + self, ex: torch.Tensor + ) -> tuple[list[sympy.Expr], list[sympy.Expr]]: + """ + Primarily used to weights + """ + size = [sympy.Integer(i) for i in ex.size()] + stride = [sympy.Integer(i) for i in ex.stride()] + return size, stride + + def get_allocation_size( + self, + node: Union[ + ir.TensorBox, ir.StorageBox, ir.Buffer, WorkspaceArg, ir.TorchBindObject + ], + ) -> Sequence[Expr]: + if isinstance(node, ir.TensorBox): + node = node.data # type: ignore[assignment] + if isinstance(node, ir.StorageBox): + node = node.data # type: ignore[assignment] + if ( + isinstance(node, ir.ComputedBuffer) + and node.name in self.buffer_to_padded_size + ): + return self.buffer_to_padded_size[node.name] + else: + return node.get_size() + + def get_allocation_storage_size( + self, node: Union[ir.Buffer, WorkspaceArg, ir.TorchBindObject] + ) -> Expr: + layout = node.get_layout() + size = self.get_allocation_size(node) # consider inplace padding + stride = layout.stride + offset = layout.offset + return compute_required_storage_length(size, stride, offset) # type: ignore[arg-type] + + def has_feature( + self, + device: Union[torch._inductor.ir.IRNode, device, None], + feature: BackendFeature, + ) -> bool: + assert isinstance(feature, BackendFeature), feature + return feature in self.get_backend_features(get_device_type(device)) + + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + @contextlib.contextmanager + def set_current_device(self, device: torch.device) -> Iterator[None]: + prior = self.current_device + self.current_device = device + try: + yield + finally: + self.current_device = prior + + def get_training_phase(self) -> str: + if self.is_inference: + return "inference" + if self.is_backward: + return "backward" + return "forward" + + @staticmethod + def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: + """ + Decide if we should enable layout optimization for this graph based on + heuristics. + """ + if not config.layout_optimization: + return False + + if config.force_layout_optimization: + return True + + conv_nodes = [ + n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default + ] + nconv = len(conv_nodes) + + if nconv == 0: + return False + + # For cpu backend and mkldnn enabled, we always use channels_last for better performance. 
+ if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and all( + n.args[idx].meta["val"].device.type in SUPPORTED_MKLDNN_DEVICES + for n in conv_nodes + for idx in [0, 1] + ) + ): + return True + + # Following models are skipped due to this: + # jx_nest_base + # volo_d1_224 + if len(list(gm.graph.nodes)) >= 300 * nconv: + log.debug("Skipped layout opt because only a few conv") + return False + + if any( + has_free_symbols(n.args[idx].meta["val"]) + for n in conv_nodes + for idx in [0, 1] + ): + log.debug( + "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670" + ) + return False + + def is_grouped(n: Any) -> bool: + meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator] + assert isinstance(meta_val, torch.Tensor) + return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator] + + def is_in_out_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator] + ) + + def is_small_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator] + ) + + # only grouped convolutions benchmarked as slower in conv samples for inference only + if is_inference: + from torch.utils.flop_counter import FlopCounterMode + + flop_counts: dict[str, float] = defaultdict(float) + for node in conv_nodes: + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( + node + ) + + if success: + with FlopCounterMode(display=False) as flop_counter_mode: + with V.fake_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") + + # average benchmarked channels last speedup / slowdown, < 1 is speedup. + # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ + # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb + GROUPED_MULTIPLIER = 1.358 + DEFAULT_MULTIPLIER = 0.823 + IN_OUT_MULTIPLIER = 0.725 + SMALL_MULTIPLIER = 0.783 + + total_flops = sum(flop_counts.values()) + # TODO - get different values per hardware + weighted_flops = ( + flop_counts["grouped"] * GROUPED_MULTIPLIER + + flop_counts["small"] * SMALL_MULTIPLIER + + flop_counts["in_out"] * IN_OUT_MULTIPLIER + + flop_counts["default"] * DEFAULT_MULTIPLIER + ) + do_layout_opt = weighted_flops <= total_flops + if not do_layout_opt: + log.debug( + "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", + total_flops, + weighted_flops, + ) + return do_layout_opt + + # Channels last layout can dramatically hurt grouped conv perf. E.g. + # Conv with arguments like + # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 2} + # slows down 31x using channels last.. + + # But a lot of timm models use depthwise separable convolution which will + # result in grouped convolution with in-channel size == 1. 
+ # For those grouped convolution, channels last still helps a lot. + # E.g. + # Conv with arguments + # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 58} + # get 1.86x speedup with channels last layout. + # + # The following heuristics skip using channels-last if the model contains + # grouped convolution with in-channels > 1. + if any(map(is_grouped, conv_nodes)): + log.debug( + "Skip layout opt because found grouped convolution with >1 in_channels!" + ) + return False + + # For some models that contain convolution with larger in-channel than out-channel, applying + # channels last hurts performance. + # Following models are skipped due to this: + # - pytorch_unet + # - phlippe_densenet (slightly worse) + # - Background_Matting (1.22x -> 0.821x) + # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x) + if any(map(is_in_out_channel, conv_nodes)): + log.debug( + "Skip layout opt because some convolutions have smaller out_channel" + ) + return False + + # Following models are skipped due to this: + # - functorch_maml_omniglot + if all(map(is_small_channel, conv_nodes)): + log.debug("Skip layout opt because all convolution channels are too small") + return False + + return True + + def qualify_name(self, name: str) -> str: + """Prepend the given name with the graph name if any.""" + if self.name is not None: + return f"{self.name}_{name}" + return name + + def make_subgraph( + self, + gm: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + subgraph_name: str, + ) -> SubgraphLowering: + """ + Make a subgraph of the current graph with all inherited parts, except + the graph module (`gm`) and `example_inputs`. The subgraphs are lowered + separately and lifted into a separate function in the parent output + wrapper code. The subgraph name is qualified by the parent graph's + name. Note that the lifting of subgraph is supported for python wrapper + only. For cpp wrapper, we inline the subgraphs in the parent wrapper. + """ + return SubgraphLowering( + parent=self, + gm=gm, + example_inputs=example_inputs, + shape_env=self._shape_env, + cpp_wrapper=self.cpp_wrapper, + aot_mode=self.aot_mode, + extern_node_serializer=self.extern_node_serializer, + is_inference=self.is_inference, + is_backward=self.is_backward, + name=self.qualify_name(subgraph_name), + ) + + def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: + """ + The rule to decide if an node prefer channels last is simple. + 1. if it's input/output of a convolution + 2. if one of its user prefers channels last + + We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; + Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers + channels last. + + Consider the scenario: conv -> batch-norm -> relu -> conv + Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: + 1. the output of batch-norm should be channels last initially since its input is a conv's output. + Forcing the batch-norm's output to be contiguous results in the first copy + 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. + We need convert it to channels last layout which results in the second copy. + With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies + can be saved. 
+ """ + output_set = OrderedSet[Node]() + for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if n.target == torch.ops.aten.convolution.default: + output_set.add(n) + continue + + for user in n.users: + if user in output_set: + output_set.add(n) + break + + # need a second pass to add downstream nodes of those channel last nodes to the sets. + # This pass is especially needed to avoid mix-layout kernel inputs in backward pass. + # + # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned + # from the fwd graph. Without this second pass, we will force relu's output to be contiguous. + # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last + # tensors and passed to a kernel. + # + # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x. + # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x . + # This also helps the following models: + # - res2net101_26w_4s + # - res2net50_14w_8s + # - sebotnet33ts_256 + for n in self.module.graph.nodes: # type: ignore[union-attr] + if n in output_set: + output_set.update(n.users) + + return output_set + + def warn_fallback(self, name: str) -> None: + if name not in self._warned_fallback: + self._warned_fallback.add(name) + perf_hint_log.info("Using FallbackKernel: %s", name) + + def add_device_info(self, device: torch.device) -> None: + self.device_types.add(device.type) + if device.index is not None: + self.device_idxs.add(device.index) + if V.graph.current_node and device not in self.device_node_mapping: + self.device_node_mapping[device] = V.graph.current_node + + @property + def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode: + return V.fake_mode + + def try_get_buffer( + self, buffer_name: str + ) -> Optional[Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]]: + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name] + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name] + if buffer_name in self.constants: + data = V.graph.constants[buffer_name] + return ir.ConstantBuffer( + name=buffer_name, + layout=ir.FixedLayout( + data.device, data.dtype, *V.graph.static_sizes_strides(data) + ), + ) + + return None + + def add_symbol_graph_input(self, symbol: sympy.Expr) -> None: + raise RuntimeError("Should not be called for the main graph") + + def get_buffer( + self, buffer_name: str + ) -> Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]: + buf = self.try_get_buffer(buffer_name) + if buf is not None: + return buf + raise RuntimeError(f"Failed to find buffer matching name {buffer_name}") + + def get_dtype(self, buffer_name: str) -> torch.dtype: + if buffer_name in self.constants: + return self.constants[buffer_name].dtype + # For a mutation op we should return the dtype of the buffer being mutated + if ( + hasattr(self.scheduler, "mutation_real_name") + and buffer_name in self.scheduler.mutation_real_name + ): + mutated_buf = self.scheduler.mutation_real_name[buffer_name] + if mutated_buf in self.name_to_buffer: + return self.name_to_buffer[mutated_buf].get_dtype() + if mutated_buf in self.graph_inputs: + return self.graph_inputs[mutated_buf].get_dtype() + if buffer_name in self.name_to_buffer: + return self.name_to_buffer[buffer_name].get_dtype() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_dtype() + m = 
re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) + raise KeyError(f"could not find {buffer_name}") + + def get_numel(self, buffer_name: str) -> Union[int, Expr]: + if buffer_name in self.constants: + return self.constants[buffer_name].numel() + if buffer_name in self.name_to_buffer: + buf = self.name_to_buffer[buffer_name] + if not buf.has_tensor_output(): + return 1 + return buf.get_numel() + if buffer_name in self.graph_inputs: + return self.graph_inputs[buffer_name].get_numel() + raise KeyError(f"could not find {buffer_name}") + + def run(self, *args: Any) -> Any: # type: ignore[override] + with dynamo_timed("GraphLowering.run"): + return super().run(*args) + + def register_operation(self, op: ir.Operation) -> str: + assert op.operation_name is None, f"Operation registered twice: {op}" + assert isinstance(op, ir.Operation) + name = self.qualify_name(f"op{len(self.operations)}") + self.operations.append(op) + self.name_to_op[name] = op + op.operation_name = name + return name + + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + self.name_to_buffer[name] = buffer + device = buffer.get_device() + if ( + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 + device is not None + and not ( + isinstance(buffer, ir.ComputedBuffer) + and buffer.is_zero_elements() + and device == torch.device("cpu") + ) + ): + self.add_device_info(device) + + if set_name: + buffer.name = name + return name + + def register_operation_list(self, operation_names: list[str]) -> str: + name = self.qualify_name("list_" + "_".join(operation_names)) + self.lists[name] = operation_names + return name + + def register_users_of( + self, node_output: Union[Iterable[ir.IRNode], ir.IRNode] + ) -> None: + def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None: + if isinstance(value, (list, tuple)): + for x in value: + register(x) + if isinstance(value, ir.TensorBox): + for read_name in value.get_read_names(): + self.name_to_users[read_name].append(value) + + register(node_output) + + def mark_buffer_mutated(self, name: str) -> None: + """ + When a buffer is mutated we need to make sure all the reads to + the old version are realized before the mutation happens. + """ + assert isinstance(name, str) + self.mutated_buffers.add(name) + + if name not in self.name_to_users: + return + + for user in self.name_to_users[name]: + user.realize() + + def get_original_value_of_constant(self, name: str) -> torch.Tensor: + """ + In AOTI, module buffers may have been mutated during the tracing and compilation. + Thus we need to read from previously stored original buffers, to make sure the + generated model.so uses correct initial values. 
+ """ + assert name in self.allocated_constant_name and name in self.constants, ( + "Can not find the original value for " + name + ) + orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name]) + return ( + self.module.meta[orig_name] # type: ignore[index] + if orig_name in self.module.meta # type: ignore[operator] + else self.constants[name] + ) + + def allocate_non_dup_const_name( + self, name: Optional[str], data: Union[Tensor] + ) -> str: + if not config.aot_inductor.use_runtime_constant_folding: + for constant_name, value in self.constants.items(): + if is_same_tensor(data, value): + return constant_name + + if name is None: + name = f"constant{len(self.constants)}" + orig_name = name + if name[0].isdigit(): + name = f"constant_{name}" + name = self.qualify_name(name) + # We may generate a var name for each constant in the codegen. + # Let's only keep sane characters. + prefix = normalize_name(name) + name = prefix + cnt = 0 + while name in self.constants: + name = f"{prefix}_{cnt}" + cnt += 1 + self.constants[name] = data + self.constant_reprs[name] = ( + f"{data.device!r} {data.dtype!r} " + f"{tuple(data.size())!r} {tuple(data.stride())!r} " + f"{hash(data):x}" + ) + self.allocated_constant_name[name] = orig_name # type: ignore[assignment] + return name + + def add_tensor_constant( + self, data: Tensor, name: Optional[str] = None + ) -> TensorBox: + new_name = self.allocate_non_dup_const_name(name, data) + return TensorBox.create( + ir.ConstantBuffer( + name=new_name, + layout=FixedLayout( + data.device, data.dtype, *self.static_sizes_strides(data) + ), + ) + ) + + def constant_name(self, name: str, device_override: Optional[torch.device]) -> str: + """ + We AOT copy constants to the devices they are needed on. + If device_override doesn't match the constant's device, then + copy it and return a different name. + """ + if self.constants[name].device == device_override or device_override is None: + return name + with torch.utils._python_dispatch._disable_current_modes(): + # caller might have OrderedSet fake tensor mode which will create a fake tensor + # when calling .to, so unset modes here + return self.allocate_non_dup_const_name( + f"{name}_{device_override.type}{device_override.index or 0}", + self.constants[name].to(device_override), + ) + + def placeholder( + self, + target: str, # type: ignore[override] + args: tuple[object], # type: ignore[override] + kwargs: dict[str, object], + ) -> Union[Expr, TensorBox, None]: + self.placeholder_idx += 1 + example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + target = self.qualify_name(target) + if isinstance(example, SymTypes): + # TODO fix partitioning issue and re-enable for backward + # https://github.com/pytorch/pytorch/issues/155468. 
+ if not V.graph.is_backward: + expr = _get_placeholder_expr(example.node) + else: + expr = example.node.expr + self.graph_inputs[target] = expr + self.graph_input_names.append(target) + return expr + elif isinstance(example, (int, bool, float)): + expr = sympy.sympify(example) + self.graph_inputs[target] = expr + self.graph_input_names.append(target) + return expr + elif isinstance(example, FakeScriptObject): + obj = TorchBindObject(name=target, value=example) + self.graph_inputs[target] = obj + self.graph_input_names.append(target) + return obj + elif example is None: + self.graph_input_names.append(target) + return None + if isinstance(example, BackwardState): + # Ignored arg, must be unused + # Alternately we could filter this out in AotAutograd + self.graph_input_names.append(target) + return None + # See note: Note: [Generator arguments in AOTDispatcher] + elif isinstance(example, torch.Generator): + assert ( + len(V.graph.current_node.users) == 1 + and next(iter(V.graph.current_node.users)).target + is torch._prims.rng_prims.graphsafe_run_with_rng_state + ) + gen = ir.GeneratorState(name=target, device=example.device) + self.graph_inputs[target] = gen # type: ignore[assignment] + self.graph_input_names.append(target) + return gen + + assert isinstance(example, torch.Tensor), example + # todo(chilli): We can remove the last check once we turn buffers into + # static shape tensors. That's a hack to workaround Inductor believing + # the buffer should be static but us passing in a fake tensor with + # symbolic shapes. + if not example._has_symbolic_sizes_strides: + # the first N inputs are weights + sizes, strides = self.static_sizes_strides(example) + else: + sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment] + + if ( + self.is_backward + and self.bw_donated_idxs + and self.placeholder_idx in self.bw_donated_idxs + ): + tensor = TensorBox.create( + DonatedBuffer( + name=target, + layout=FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + else: + # TODO(jansel): handle input aliasing + tensor = TensorBox.create( + InputBuffer( + name=target, + layout=FixedLayout(example.device, example.dtype, sizes, strides), + ) + ) + + self.graph_inputs[target] = tensor + self.graph_input_names.append(target) + self.graph_inputs_original[target] = tensor.data.data + if self.current_node.users: # cudagraphs should work with an unused CPU input + self.add_device_info(example.device) + + # Note: [Input Alignment handling in Inductor] + # Alignment matters for generating efficient code. Some operations, + # e.g. vectorized loads, can only be performed on aligned inputs. + # + # But if we codegen assuming aligned inputs and then get unaligned + # inputs at runtime, then we are forced to clone - which is bad for + # both perf and memory usage. + # + # One option would be to guard on storage_offset%ALIGNMENT, and then + # codegen based on this. But storage_offset guards turned out to be + # expensive and cause recompiles; Instead, we're generating code + # based on the alignment of the example input without guarding. 
+ with maybe_get_suppress_shape_guards_ctx(): + if not should_assume_input_aligned(example): + self.unaligned_buffers.add(target) + return tensor + + def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override] + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + # hasattr on OpOverloadPacket is slow, check isinstance first + if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr( + target, "_inductor_lowering_function" + ): + # passthrough lowerings from .pattern_matcher + return target(*args, **kwargs) + + if target not in lowerings: + assert isinstance(target, torch._ops.OpOverload), ( + f"{target} is not an OpOverload" + ) + base_name = target.name().split(".")[0] + if base_name in FALLBACK_ALLOW_LIST: + make_fallback(target, warn=False, override_decomp=True) + elif config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.info( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + + tag = get_layout_constraint_tag(target, with_default=False) + if ( + tag is None + and torch._library.utils.is_builtin(target) + and self.is_backward + ): + # for implicit fallback ATen ops during backward, if there + # is no layout constraint tag, we conservatively require contiguous + # input since some eager kernels do not + # support non-contiguous inputs. Otherwise they may silently cause + # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452 + # We only do this For ATen ops and for backward. + # + # TODO: should really switch to "needs_fixed_stride" constraint on these + # and identify them one by one. + decided_constraint = require_contiguous # type: ignore[assignment] + else: + tag = get_layout_constraint_tag(target, with_default=True) + decided_constraint = tag_to_layout_constraint(tag) + + make_fallback(target, layout_constraint=decided_constraint) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) # type: ignore[index] + + n = self.current_node + layout_constraints = maybe_layout_constraints(target) + if layout_constraints: + old_args, old_kwargs = args, kwargs + if layout_constraints is constrain_to_fake_tensors: + # only constrain_to_fake_tensor if this exists. + # otherwise, no constraints at all: the implication is + # that this operator was inserted by a custom pass + # so we'll give them the freedom. + if "eager_input_vals" in n.meta: + fake_args, fake_kwargs = n.meta["eager_input_vals"] + + # (fake_args, fake_kwargs) might not align with (args, kwargs). 
+ # we need to normalize them based on the schema + assert isinstance(target, torch._ops.OpOverload) + + def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]: + result = torch.fx.operator_schemas.normalize_function( + target, args, kwargs + ) + assert result is not None + return result[0], result[1] + + fake_args, fake_kwargs = normalize(fake_args, fake_kwargs) + args, kwargs = normalize(args, kwargs) + old_args, old_kwargs = normalize(old_args, old_kwargs) + + args, kwargs = constrain_to_fake_tensors( + args, kwargs, fake_args, fake_kwargs + ) + else: + args, kwargs = layout_constraints(n, *args, **kwargs) + + out = lowerings[target](*args, **kwargs) # type: ignore[index] + + if layout_constraints: + # layout_constraints are allowed to make new copies of the inputs. + # if they do, and if the target is mutable, then we need to + # write the new values back into the original inputs. + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs).with_traceback( + e.__traceback__ + ) from None + + @staticmethod + def can_inline_constant(t: torch.Tensor) -> bool: + """ + True if this is a small constant attr that will be inlined. + """ + return len(t.shape) == 1 and t.shape[0] <= 8 + + def get_attr( + self, + target: str, # type: ignore[override] + args: tuple[()], # type: ignore[override] + kwargs: dict[str, object], + ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: + # this is a constant + value = getattr_recursive(self.module, target) # type: ignore[arg-type] + + if isinstance(value, torch.fx.GraphModule): + # Reuse the existing subgraph if we have seen it before already. + if target in self.seen_subgraphs: + return self.seen_subgraphs[target] + + out = ir.Subgraph(name=target, graph_module=value) + self.seen_subgraphs[target] = out + return out + + if isinstance(value, torch._C.ScriptObject): + self.torchbind_constants[target] = value + self.constant_reprs[target] = "" + return TorchBindObject(name=target, value=value) + elif isinstance(value, FakeScriptObject): + self.torchbind_constants[target] = value + self.constant_reprs[target] = "" + return TorchBindObject(name=target, value=value) + + assert isinstance(value, torch.Tensor) + if ( + config.aot_inductor.use_runtime_constant_folding + or config.always_keep_tensor_constants + or unsupported_output_tensor(value) + ): + return self.add_tensor_constant(value, target) + + with no_dispatch(): + if value.shape == (): + return Constant( + value=value.item(), dtype=value.dtype, device=value.device + ) + if self.can_inline_constant(value): + log.debug("Inlining constant: %s ", str(target)) + # tensor lowering has constant inlining logic + from .lowering import tensor + + return tensor(value.tolist(), dtype=value.dtype, device=value.device) + + return self.add_tensor_constant(value, target) + + def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn: + raise AssertionError + + def output( + self, + target: str, # type: ignore[override] + args: tuple[object], # type: ignore[override] + kwargs: dict[str, object], + ) -> None: + result = super().output(target, args, kwargs) # type: ignore[arg-type] + if not isinstance(result, (tuple, list)): + # nested subgraphs can have singleton outputs + result = (result,) + assert isinstance(result, (tuple, list)), type(result) + assert all( + 
isinstance( + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + sympy.logic.boolalg.Boolean, + int, + ir.EffectfulKernel, + ir.ShapeAsConstantBuffer, + ), + ) + for x in result + ), result + + fx_node_args = V.graph.current_node.args[0] # type: ignore[arg-type] + if not isinstance(fx_node_args, (tuple, list)): + # nested subgraphs can have singleton outputs + fx_node_args = (fx_node_args,) + result = [ir.ExternKernel.realize_input(x) for x in result] + result_correct_strides = [] + + assert len(fx_node_args) == len(result) + for r, fx_node in zip(result, fx_node_args): + if not isinstance(r, (ir.TensorBox, ir.BaseView)): + result_correct_strides.append(r) + elif isinstance(r.get_output_spec(), ir.CommBufferLayout): + # Active references to persistent comm buffers are not allowed + # outside of graphs + result_correct_strides.append(ir.ExternKernel.copy_input(r)) + else: + # AOT Autograd tries to detect stride divergence of inductor from output metadata. + # Here, we try to avoid spurious divergence by matching insignificant strides such as + + # should have already been realized + assert torch._inductor.ir.is_storage_and_layout(r) + meta_strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in fx_node.meta["val"].stride() + ] + result_correct_strides.append( + ir.try_match_insignificant_strides(r, meta_strides) + ) + + self.graph_outputs = result_correct_strides + value: ir.IRNode + for name, value in self.graph_inputs.items(): + if isinstance(value, TorchBindObject): + continue + assert isinstance( + value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState) + ), f"Unsupported inductor graph input type: {type(value)}" + if not isinstance(value, TensorBox): + continue + value.realize() + assert isinstance(value, TensorBox) + value = value.data + assert isinstance(value, ir.StorageBox) + value_storage_box = value + value = value.data + if not isinstance(value, InputBuffer) or value.get_name() != name: + # one of our inputs was mutated, need to turn that into a copy + ir.MutationLayoutSHOULDREMOVE.realize_into( + value, self.graph_inputs_original[name] + ) + # replace output with mutated input + try: + ind = self.graph_outputs.index(value_storage_box) + self.graph_outputs[ind] = self.graph_inputs_original[name] + except ValueError: + pass + + self.finalize() + log.debug( + "Force channels last inputs for %d conv for the current graph with id %d", + self.num_channels_last_conv, + self.graph_id if self.graph_id is not None else -1, + ) + + def finalize(self) -> None: + for buf in self.buffers: + buf.decide_layout() + + @contextmanager + def set_current_node(self, node: torch.fx.Node): # type: ignore[no-untyped-def] + old = self.current_node + try: + self.current_node = node + yield + finally: + self.current_node = old + + @contextmanager + def set_current_wrapper_code(self) -> Iterator[None]: + old = self.wrapper_code + try: + yield + finally: + self.wrapper_code = old + + def propagate_mutation( + self, + fx_node: torch.fx.Node, + old_args: tuple[Any], + old_kwargs: dict[str, Any], + new_args: tuple[Any], + new_kwargs: dict[str, Any], + ) -> None: + """Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs. + + Assumes we may have cloned old_args/old_kwargs into new_args/new_kwargs + and then called fx_node(*new_args, **new_kwargs). + + If fx_node mutates any of new_args/new_kwargs, and they are different from + old_args/old_kwargs, then we need to update the original tensor. 
+ """ + assert len(old_args) == len(new_args) + assert len(old_kwargs) == len(new_kwargs) + + if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation: + kwargs = fx_node.kwargs["kwargs"] + assert isinstance(kwargs, dict) + mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + old_kwargs["kernel_idx"], + old_kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + old_kwargs["tma_descriptor_metadata"], + ) + for name in mutated: + old_arg = old_kwargs["kwargs"][name] + new_arg = new_kwargs["kwargs"][name] + if old_arg is new_arg: + continue + + self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + return + + assert isinstance(fx_node.target, torch._ops.OpOverload) + + def maybe_propagate( + schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode + ) -> None: + if old_arg is new_arg: + return + if schema_arg.alias_info is not None and schema_arg.alias_info.is_write: + # The lowering for copy_ is smart enough to "replace" old_arg with + # new_arg in all future uses so a copy_ kernel never gets emitted. + # old_arg, new_arg may be immutable_list + if isinstance(old_arg, ir.IRNode): + old_arg = (old_arg,) # type: ignore[assignment] + new_arg = (new_arg,) # type: ignore[assignment] + + for old_arg_item, new_arg_item in zip(old_arg, new_arg): # type: ignore[call-overload] + if old_arg_item is new_arg_item: + continue + self.call_function( + torch.ops.aten.copy_.default, (old_arg_item, new_arg_item), {} + ) + + schema = fx_node.target._schema + for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)): + schema_arg = schema.arguments[idx] + maybe_propagate(schema_arg, old_arg, new_arg) + + schema_kwargs = {arg.name: arg for arg in schema.arguments} + + for key in old_kwargs.keys(): + old_arg = old_kwargs[key] + new_arg = new_kwargs[key] + schema_arg = schema_kwargs[key] + maybe_propagate(schema_arg, old_arg, new_arg) + + def run_node(self, n: torch.fx.Node) -> object: + def debug(msg: str) -> None: + log.debug("lowering %s %s", LazyString(n.format_node), msg) # type: ignore[arg-type] + + from torch._inductor.compiler_bisector import CompilerBisector + + buffer_watermark = len(self.buffers) + operation_watermark = len(self.operations) + + # origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n]) + origins: OrderedSet[Any] = OrderedSet([n]) + is_call_function = n.op == "call_function" + if is_call_function: + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ( + ir.IRNode.current_origins(origins), + self.set_current_node(n), + V.set_current_node(n), + ): + if ( + n.op == "call_function" + and n.target + not in (operator.getitem, torch._higher_order_ops.invoke_subgraph) + and ( + fallback_node_due_to_unsupported_type(n) + or CompilerBisector.disable_subsystem( + "inductor", "lowerings", lambda: repr(n) + ) + ) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, # type: ignore[possibly-undefined] + **kwargs, # type: ignore[possibly-undefined] + ) + elif ( + n.op == "call_function" + and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + and config.triton_kernel_default_layout_constraint != "flexible_layout" + ): + debug("user_defined_triton_kernel_layout_constraints") + if ( + config.triton_kernel_default_layout_constraint + == "needs_fixed_stride_order" + ): + old_args = args # type: ignore[possibly-undefined] + old_kwargs 
= kwargs # type: ignore[possibly-undefined] + + if eager_input_vals := n.meta.get("eager_input_vals"): + inp_args = eager_input_vals[0] + inp_kwargs = eager_input_vals[1] + args, kwargs = constrain_to_fake_tensors( + args, kwargs, inp_args, inp_kwargs + ) + else: + args, kwargs = constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + else: + raise RuntimeError( + f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}" + ) + elif is_magic_method(n.target): + # TODO: this is sus, it probably should be handled in the + # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 + debug("is_magic_method") + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): + result = n.meta["val"].node.expr + else: + result = super().run_node(n) + else: + debug("") + result = super().run_node(n) + + # require the same stride order for dense outputs, + # 1. user-land view() will not throw because inductor + # output different strides than eager + # long term the solution is to make view() always succeed + # with infallible strides. + # 2: as_strided ops, we need make sure its input has same size/stride with + # eager model to align with eager behavior. + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + torch.ops.aten.resize.default, + torch.ops.aten.resize_as.default, + ] + is_output = any(user.op == "output" for user in n.users) + is_user_visible = n in self.user_visible_output_strides + is_input_for_as_strided = any( + user.target in as_strided_ops for user in n.users + ) + + if n.meta.get("inductor_realize_to_strides", False) and isinstance( + result, TensorBox + ): + result.realize() + strides = n.meta["val"].stride() + sym_strides = torch._inductor.utils.any_is_symbolic(*strides) + if result.maybe_get_stride() != strides and not sym_strides: + stride_order = ir.get_stride_order(strides) + result = ir.ExternKernel.require_stride_order(result, stride_order) + if ( + is_output + and isinstance(result, TensorBox) + and isinstance(result.data, ir.BaseView) + ): + # Realize so that outputs are correctly aliased + result.realize() + + if (is_output or is_input_for_as_strided) and isinstance( + n.meta["val"], torch.Tensor + ): + if is_user_visible: + strides = self.user_visible_output_strides.get(n) + else: + strides = n.meta["val"].stride() + + if strides is not None and len(strides) > 0: + allow_padding = ( + config.pad_outputs or not is_user_visible + ) and not is_input_for_as_strided + dense = torch._prims_common.is_non_overlapping_and_dense( + n.meta["val"] + ) + unbacked_symbols_in_strides = ( + len(free_unbacked_symbols(strides)) > 0 + ) + if ( + not unbacked_symbols_in_strides + and dense + and len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and not is_user_visible + and not is_input_for_as_strided + ): + strides = ir.FlexibleLayout.stride_ordered_for_memory_format( + result.get_size(), torch.channels_last + ) + if not unbacked_symbols_in_strides and len(strides): + # To avoid converting possible view ops to a copy kernel, we use the previous + # require_exact_strides to handle views. But ultimately it's better to require + # the right strides at the tensor definition. 
+ if n.meta["val"]._is_view() or isinstance( + result.data, ir.BaseView + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(strides), + allow_padding=allow_padding, + ) + else: + strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in strides + ] + result = ir.ExternKernel.require_exact_strides( + result, strides, allow_padding=allow_padding + ) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. + num_users = len(OrderedSet(n.users)) + if num_users > 1 and isinstance(result, TensorBox): + for user in n.users: + if user.target in needs_realized_inputs: + result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometimes result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. + need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + need_fixed_channels_last_layout = [] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ] + need_fixed_channels_last_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(n.meta["val"].stride()), + allow_padding=True, + ) + if ( + user.target in need_fixed_channels_last_layout + and n is user.args[0] + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order( + make_channels_last_strides_for(n.meta["val"].shape) + ), + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. 
+ if curr.has_large_inner_fn(threshold=100): + result.realize() + + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. + if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): + if isinstance(result.data.data, ir.Loops): + result.data.data._post_init_setattr("origin_node", n) + elif isinstance(result.data.data, ir.Buffer): + result.data.data._post_init_setattr("origin_node", n) + if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( + result.data.data.data, ir.Loops + ): + result.data.data.data._post_init_setattr("origin_node", n) + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, ir.MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], ir.Buffer): + result.data.data.inputs[0]._post_init_setattr("origin_node", n) + + self.register_users_of(result) + + new_unbacked_defs = OrderedSet[sympy.Symbol]() + for buf in self.buffers[buffer_watermark:]: + new_unbacked_defs |= buf.get_unbacked_symbol_defs() + for op in self.operations[operation_watermark:]: + new_unbacked_defs |= op.get_unbacked_symbol_defs() + + shape_env = V.graph.sizevars.shape_env + + # An input can be an unbacked symint, i.e. when mark_unbacked is used. + # In that case, add it to new_unbacked_defs. + if ( + n.op == "placeholder" + and isinstance(result, sympy.Symbol) + and shape_env.is_unbacked_symint(result) + ): + new_unbacked_defs.add(result) + + def format_new_defs() -> str: + r = [ + f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n" + for buf in self.buffers[buffer_watermark:] + ] + r.extend( + f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n" + for op in self.operations[operation_watermark:] + ) + return "***\n".join(r) + + # We do not skip unbacked symints that are inputs for backward; see the note below. + if V.graph.is_backward and n.op == "placeholder": + return result + + # Note [Backwards runtime asserts] + # Backwards poses an interesting problem for deferred runtime + # asserts. In the easy case, we may solely close over data + # dependent sized tensors, and there are no binding sites for + # unbacked SymInts. In this case, we can just drop all the + # runtime asserts on the floor: no non-placeholder bindings, no + # problem. + # + # However, it is *possible* for a fresh runtime assert to show up + # between forwards and backwards. Right now, the freezing process + # that happens when we lower forwards means that we will freeze + # runtime asserts, and then the moment the backwards lowering + # process attempts to add a new deferred runtime assert, we will + # fail. Let's say you remove that assert. Now when we get here, + # we need to make sure we actually emit these asserts (because we + # can't emit them in forwards, we already compiled it). So we + # have to do something here. But we don't want to reemit ALL + # deferred runtime asserts, we only want to emit the NEW ones. + # That would require some sort of stratification in the ShapeEnv. + # This is all doable, it just hasn't been done yet.
+ + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) + ) + assert unbacked_bindings is not None + # When we do lowering, it is possible we reallocate unbacked SymInts. + # So we need to line up the unbacked SymInts when performing the test + # here + # + # In principle, we could permit lowering to introduce MORE unbacked + # SymInts: as long as all the old unbacked ones are accounted for, + # it's fine for inductor to introduce extra calls to item()/unbacked() + # whatever. This actually happens in practice when an unbacked SymInt + # gets memoized away; naively, when Inductor reprocesses a kernel, it + # doesn't know that the memo still applies, and ends up allocating a + # new symbol. However, this is generally a bad thing: we may still + # end up needing to test equalities on the symbols, and a fresh + # symbol is likely to hit lots of GuardOnDataDependent errors that + # we already know facts for. + renamed_unbacked_bindings = OrderedSet( + V.fake_mode.shape_env.unbacked_renamings.get(s, s) + for s in unbacked_bindings.keys() + ) + assert new_unbacked_defs >= renamed_unbacked_bindings, ( + f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" + f"fx node is: {n.format_node()}\n" + f"new operations are:\n\n{format_new_defs()}" + ) + self.create_deferred_runtime_asserts(n, new_unbacked_defs) + return result + + def create_deferred_runtime_asserts( + self, n: torch.fx.Node, new_unbacked_defs: OrderedSet[sympy.Symbol] + ) -> None: + # [NOTE] Codegen runtime asserts in Inductor + # + # We need to generate runtime asserts directly in Inductor instead + # of just reusing the asserts from input graphs because we reuse the + # same ShapeEnv as before. In particular, on subsequent graph passes, + # we would immediately turn all of these assertions into noops, + # because when we evaluated their expressions, we would see that + # because we had a deferred runtime assert in the ShapeEnv, we + # know "oh, of course this expression is True" already. + # One example is below: + # + # class Model(torch.nn.Module): + # def forward(self, a, b, c): + # nz = torch.nonzero(a) + # ones = a.new_ones([nz.size(0), b.size(0)]) + # torch._check(ones.size(0) >= 1) + # equals = torch.add(ones, c) + # return equals + # torch._dynamo.mark_dynamic(c, 0) + # When we reuse the ShapeEnv in Inductor lowering, the check that checks + # a and nonzero have the same shape would be evaluated to True after we resolve + # unbacked bindings using the ShapeEnv. + # See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor. + # + # + # In addition to the Inductor generated runtime asserts, we also + # need the runtime asserts from the input graph, because some derived + # runtime asserts on backed symints are not generated in Inductor. One example is + # this: `y = x.reshape(100, -1).clone()`. x.shape[0] needs to be a multiple of 100. + # See test_aoti_runtime_asserts_backed_symint in test_aot_inductor. 
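# A minimal standalone sketch (plain sympy, no ShapeEnv) of the range-to-assert
# conversion performed below: finite integer bounds on a fresh unbacked symbol become
# scalar assertions, while an infinite bound is skipped.
import sympy

i0 = sympy.Symbol("i0", integer=True)
lower, upper = 2, sympy.oo        # e.g. a size-like unbacked symint known to be >= 2

def convertible(bound) -> bool:
    # stands in for the is_convertible helper below: only finite integer
    # bounds can be emitted as runtime assertions
    try:
        int(bound)
        return True
    except (TypeError, OverflowError):
        return False

asserts = [i0 >= lower] if convertible(lower) else []
if convertible(upper):
    asserts.append(i0 <= upper)
print(asserts)                    # [i0 >= 2]; the infinite upper bound is dropped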
+ + def make_assert(expr: SympyBoolean, msg: str) -> None: + assert_op = ir.AssertScalar(expr, msg) + self.register_buffer(assert_op, set_name=True) + self.register_operation(assert_op) + + if ( + full_aoti_runtime_assert() + and n.target == torch.ops.aten._assert_scalar.default + and self.aot_mode + ): + node_args, _ = self.fetch_args_kwargs_from_env(n) + if node_args[0] != True: # noqa: E712 + make_assert(node_args[0], f"{node_args[0]} to be True") + else: + # bound_unbacked_symbols tracks the symbols that are created so far, + # we use it to make sure that runtime assertions are added after all + # symbols used in them are defined. + self.bound_unbacked_symbols |= new_unbacked_defs + + shape_env = V.graph.sizevars.shape_env + + # Emit code for runtime asserts that can be inserted at this point. + for i0 in new_unbacked_defs: + ras = self.ras_by_symbol.pop(i0, []) + # NB: size-like not needed, we won't retrace + vr = shape_env.var_to_range[i0] + if not shape_env._default_unspecified_value_range().issubset(vr): + + def is_convertible(s: Expr) -> bool: + if s in (int_oo, -int_oo): + return False + try: + int(s) + return True + except TypeError: + return False + + if is_convertible(vr.lower): + make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}") + if is_convertible(vr.upper): + make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}") + + for ra in ras: + fvs = free_unbacked_symbols(ra.expr) + missing = fvs - self.bound_unbacked_symbols + if missing: + i1 = min(missing, key=str) + self.ras_by_symbol.setdefault(i1, []).append(ra) + else: + make_assert(ra.expr, f"{ra.expr}") + + def validate_can_generate_cpp_wrapper(self) -> None: + if config.disable_cpp_codegen: + raise CppWrapperCodegenError("C++ codegen is disabled") + + if sys.platform not in ("linux", "darwin", "win32"): + raise CppWrapperCodegenError(f"Unsupported platform {sys.platform}") + + def init_wrapper_code( + self, + is_subgraph: bool = False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> None: + device_types = self.device_types.copy() + device_types.discard("cpu") + device_types.discard("meta") + # TODO(Eikan): Only support mixing cpu and other device now. 
+ assert len(device_types) <= 1, "Does not support mixing {}".format( + "+".join(device_types) + ) + only_cpu = len(device_types) == 0 + self.device_type = "cpu" if only_cpu else device_types.pop() + + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() + + self.device_ops = get_device_op_overrides(self.device_type) + wrapper_code_gen_cls = get_wrapper_codegen_for_device( + self.device_type, self.cpp_wrapper + ) + assert wrapper_code_gen_cls is not None, ( + f"Device {self.device_type} not supported" + ) + self.wrapper_code = wrapper_code_gen_cls.create( + is_subgraph, + subgraph_name, + parent_wrapper_code, + partition_signatures, + ) + + if self.const_module: + self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter + + def extract_autotune_inputs( + self, example_inputs: list[Union[int, float, torch.Tensor]] + ) -> None: + import copy + + cloned_gm = copy.deepcopy(self.orig_gm) + example_inputs = copy.deepcopy(example_inputs) + triton_nodes = [] + for node in cloned_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + ): + triton_nodes.append(node) + + # Store grid related nodes + grid_inputs: list[torch.fx.Node] = [] + visited_grids: dict[torch.fx.Node, int] = {} + # Store kwargs related nodes + triton_inputs: dict[str, Any] = {} + kwargs_inputs: list[torch.fx.Node] = [] + visited_kwargs: dict[Any, int] = {} + for node in triton_nodes: + # first check whether we have fx node in grid settings. + for grid in node.kwargs["grid"]: + for val in grid: + if val in visited_grids: + continue + + if isinstance(val, torch.fx.Node): + visited_grids[val] = len(grid_inputs) + grid_inputs.append(val) + + kwargs = node.kwargs["kwargs"] + # identify which args might be mutated, those should be cloned. 
+ mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + node.kwargs["kernel_idx"], + node.kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + node.kwargs["tma_descriptor_metadata"], + ) + + new_kwargs: dict[str, int] = {} + with cloned_gm.graph.inserting_before(node): + for k, v in kwargs.items(): + if k in mutated: + new_node = cloned_gm.graph.call_function(torch.clone, args=(v,)) + new_kwargs[k] = len(kwargs_inputs) + kwargs_inputs.append(new_node) + continue + + if v in visited_kwargs: + new_kwargs[k] = visited_kwargs[v] + continue + visited_kwargs[v] = len(kwargs_inputs) + kwargs_inputs.append(v) + new_kwargs[k] = visited_kwargs[v] + triton_inputs[node.name] = new_kwargs + + new_outputs = kwargs_inputs + grid_inputs + for node in cloned_gm.graph.nodes: + if node.op == "output": + node.args = (tuple(new_outputs),) + break + + cloned_gm.recompile() + runner = torch.fx.Interpreter(cloned_gm) + returned_outputs = runner.run(example_inputs) + # Extract and store the grid for autotuning + if len(grid_inputs) > 0: + grid_outputs = returned_outputs[len(kwargs_inputs) :] + self.autotuning_grids = {} + for node in triton_nodes: + dynamic_grid = False + new_grids: list[tuple[Any]] = [] + for grid in node.kwargs["grid"]: + new_grid = [] + for val in grid: + if not isinstance(val, torch.fx.Node): + new_grid.append(val) + continue + dynamic_grid = True + new_grid.append(grid_outputs[visited_grids[val]]) + new_grids.append(tuple(new_grid)) + + if dynamic_grid: + self.autotuning_grids[node.name] = new_grids + # Store the kwargs input for autotuning + self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)] + self.autotuning_mapping = triton_inputs + + def codegen_with_cpp_wrapper( + self, + ) -> tuple[ValueWithLineMap, ValueWithLineMap]: + """ + For GPU, Triton kernels are autotuned and stored as cubin files + """ + if any(device in self.device_types for device in ["cuda", "xpu"]): + + def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]: + def materialize( + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor], + ) -> Union[int, float, torch.Tensor]: + if x is None: + return None + elif isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance(x, torch.Tensor), ( + "Unknown type when creating real inputs" + str(type(x)) + ) + return x + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and not isinstance( + V.real_inputs, NullHandler + ): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) + for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + # In the backward pass, V.real_inputs is not OrderedSet. + # Generating random inputs based on self.example_inputs sometimes can be problematic, + # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process. 
+ real_inputs = [ + materialize(x) # type:ignore[arg-type] + for x in ( + self.example_inputs # type:ignore[union-attr] + if isinstance(V.real_inputs, NullHandler) + else V.real_inputs + ) + ] + + if self.mutated_inputs: + from .compile_fx import clone_preserve_strides + + mutated_input_idxs = [ + idx + for idx, name in enumerate(self.graph_inputs) + if name in self.mutated_inputs + and isinstance(real_inputs[idx], torch.Tensor) + ] + for idx in mutated_input_idxs: + # clone mutated Tensor inputs to avoid mutating them in + # the first pass of the CPP wrapper-based compilation, as + # this will lead to a side effect on the example inputs: + # e.g. if torch.compile(f)(x) if called on input-mutating + # f, the inputs x will be mutated twice in the process: + # once here, and again when running the compiled model; + # this will also lead to a numerically incorrect output + mutated_inp = real_inputs[idx] + assert isinstance(mutated_inp, torch.Tensor) + real_inputs[idx] = clone_preserve_strides(mutated_inp) + del mutated_inp + return real_inputs + + if config.triton.autotune_at_compile_time: + # If autotune_at_compile_time is True, we can do the codegen in one-pass + # We will construct the autotuning values if user defined kernel exists. + if config.triton.autotune_with_sample_inputs: + user_defined_kernels = False + for op in self.operations: + if isinstance(op, ir.UserDefinedTritonKernel): + user_defined_kernels = True + break + if user_defined_kernels: + real_inputs = extract_real_inputs() + self.extract_autotune_inputs(real_inputs) + return self.codegen() + else: + # first pass + self.cpp_wrapper = False + compiled = self.compile_to_module().call + + real_inputs = extract_real_inputs() + with torch.utils._python_dispatch._disable_current_modes(): + compiled(real_inputs) + del real_inputs + + # second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.removed_operations.clear() + self.inplaced_to_remove.clear() + V.graph.sizevars.precomputed_replacements.clear() + V.graph.sizevars.inv_precomputed_replacements.clear() + metrics.reset() + with config.patch({"triton.autotune_at_compile_time": False}): + return self.codegen() + else: + # cpu + return self.codegen() + + def _update_scheduler(self) -> None: + """ + (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN + files should be generated (to avoid biasing any benchmarks and pessimizing + fusion decisions). + """ + from .scheduler import Scheduler + + with config.patch("triton.store_cubin", False): + self.scheduler = Scheduler(self.operations) + + def codegen(self) -> tuple[ValueWithLineMap, ValueWithLineMap]: + with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True): + self.init_wrapper_code() + + self._update_scheduler() + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.wrapper_code.push_codegened_graph(self) + self.scheduler.codegen() + + log.debug( + "Finished codegen for all nodes. The list of kernel names available: %s", + V.graph.all_codegen_kernel_names, + ) + + result = self.wrapper_code.generate(self.is_inference) + self.wrapper_code.pop_codegened_graph() + return result + + def codegen_subgraph(self, parent_graph: GraphLowering) -> None: + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kernels). 
The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True): + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + + self._update_scheduler() + self.scheduler.codegen() + + def count_bytes( + self, + ) -> tuple[ + int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]] + ]: + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in self.scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + + return total_bytes, node_counts, node_runtimes + + # No-op to be patched for unit tests + save_output_code: Optional[Callable[[str], None]] = None + + def compile_to_module(self) -> CompiledModule: + with dynamo_timed( + "GraphLowering.compile_to_module", + phase_name="code_gen", + log_pt2_compile_event=True, + dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", + ): + return self._compile_to_module() + + def _compile_to_module(self) -> CompiledModule: + # If we're here, we don't have to worry about the kernel code, which is only + # returned separately in AOTInductor mode. + wrapper_code, _ = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + + if isinstance(wrapper_code, ValueWithLineMap): + mod = self._compile_to_module_lines(wrapper_code) + elif isinstance(wrapper_code, FileBackedGraphModule): + mod = wrapper_code + else: + raise NotImplementedError( + f"Unrecognized wrapper code type: {type(wrapper_code)}" + ) + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. 
Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + + return mod + + def _compile_to_module_lines( + self, wrapper_code: ValueWithLineMap + ) -> CompiledModule: + from .codecache import PyCodeCache + + if config.triton.autotune_at_compile_time: + # sanitize docstrings in kernel defs (#155006) + kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue() + kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"') + + tuning_code = ( + '"""\n' + + "Compile-time auto-tuning block: \n" + + kernel_autotune_defs + + self.wrapper_code.kernel_autotune_calls.getvalue() + + '"""\n' + ) + wrapper_code.value = tuning_code + wrapper_code.value + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(wrapper_code.value) + output_code_log.debug("Output code: \n%s", wrapper_code.value) + + inductor_meta = autotune_cache.inductor_meta_from_config() + AutotuneCacheBundler.begin_compile(inductor_meta, code=wrapper_code.value) + + try: + linemap = [ + (line_no, node.stack_trace) # type: ignore[attr-defined] + for line_no, node in wrapper_code.line_map + ] + key, path = PyCodeCache.write(wrapper_code.value) + output_code_log.debug("Output code written to: %s", path) + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! + payload_fn=lambda: wrapper_code.value, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: {"filename": path}, + payload_fn=lambda: wrapper_code.value, + ) + with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True): + mod = PyCodeCache.load_by_key_path( + key, + path, + linemap=linemap, # type: ignore[arg-type] + attrs={**self.constants, **self.torchbind_constants}, + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap # type: ignore[assignment] + + if config.benchmark_harness and config.profile_bandwidth_output: + # run the inputs code gen to get the bandwidth info + mod.benchmark_compiled_module(times=1, repeat=1) + + return mod + + def get_output_names(self) -> list[str]: + names = [] + shape_counter = itertools.count(0) + none_counter = itertools.count(0) + for node in self.graph_outputs: + if isinstance(node, ir.NoneAsConstantBuffer): + names.append(f"{self.name}_none{next(none_counter)}") + elif isinstance(node, ir.ShapeAsConstantBuffer): + names.append(f"{self.name}_shape{next(shape_counter)}") + else: + names.append(node.get_name()) + return names + + def is_unspec_arg(self, name: str) -> bool: + # dynamo wraps unspec variable as 0d CPU tensor, + # need to convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs.keys() + and self.graph_inputs[name].get_numel() == 1 + and len(self.graph_inputs[name].get_size()) == 0 + and get_device_type(self.graph_inputs[name]) == "cpu" + ) or name in self.zero_dim_cpu_tensor_list + + +class SubgraphLowering(GraphLowering): + """ + Mostly a helper class for the subgraph lowering. The main goal is to call + init_wrapper_code with the subgraph related arguments. 
+ """ + + def __init__(self, parent: GraphLowering, *args: Any, **kwargs: Any) -> None: + self.parent = parent + super().__init__(*args, **kwargs) + + def init_wrapper_code( + self, + is_subgraph: bool = False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> None: + super().init_wrapper_code( + is_subgraph=True, + subgraph_name=self.name, + parent_wrapper_code=self.parent.wrapper_code, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/hooks.py b/phivenv/Lib/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..c6cb8db9dc290f76499ef7a4b69fab2a48aaaeae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Callable, TYPE_CHECKING + + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + finally: + INTERMEDIATE_HOOKS = hooks diff --git a/phivenv/Lib/site-packages/torch/_inductor/index_propagation.py b/phivenv/Lib/site-packages/torch/_inductor/index_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc26e257358f67f3a77ff468b953c501750c2af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/index_propagation.py @@ -0,0 +1,370 @@ +# mypy: allow-untyped-defs +"""This file implements the IndexPropagation ops handler, which wraps an +underlying handler to add a limited form of constant propagation, as well as +propagation of sympy expressions downstream of ops.index_expr calls. + +For example, say we have the IR: + + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) + +The underlying handler would just see: + + ops.load("buf0", x * 2) + +This is limited by the set of operators handled in the sympy expression +printers. So simple operations like minimum and maximum cannot be translated to +SymPy expressions yet, despite sympy.Min and sympy.Max existing. 
+ +""" + +import itertools +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Literal, Optional, overload, Union +from typing_extensions import TypeAlias + +import sympy + +import torch +from torch._prims_common import dtype_to_type, is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .ops_handler import DefaultHandler +from .sizevars import statically_known_true +from .utils import generate_assert +from .virtualized import V + + +_ExprType = Union[sympy.Expr, float, int, bool] + + +def _is_constant(val: _ExprType): + if isinstance(val, sympy.Basic): + return val.is_number + return isinstance(val, (int, float, bool)) + + +def upper_bound(val: _ExprType): + return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val + + +@dataclass +class TypedExpr: + """A SymPy expression with associated type""" + + expr: _ExprType + dtype: torch.dtype + + def is_constant(self): + return _is_constant(self.expr) + + def __post_init__(self): + if _is_constant(self.expr): + self.expr = dtype_to_type(self.dtype)(self.expr) + + +class SymPyOps: + """An ops handler where all IR values are SymPy expressions + + When a value cannot be represented as a SymPy expression, the method is + either not defined, or returns NotImplemented + + """ + + @staticmethod + def identity(value: Any) -> Any: + return value + + @staticmethod + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def to_dtype( + value: TypedExpr, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> TypedExpr: + return TypedExpr(value.expr, dtype) + + @staticmethod + def abs(x: TypedExpr) -> TypedExpr: + return TypedExpr(abs(x.expr), x.dtype) # type: ignore[arg-type] + + @staticmethod + def square(x: TypedExpr) -> TypedExpr: + return TypedExpr(x.expr * x.expr, x.dtype) + + @staticmethod + def add(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr + y.expr, result_type) + + @staticmethod + def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr - y.expr, result_type) + + @staticmethod + def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr * y.expr, result_type) + + @staticmethod + def neg(x: TypedExpr) -> TypedExpr: + return TypedExpr(-x.expr, x.dtype) + + @staticmethod + def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + return TypedExpr(FloorDiv(x.expr, y.expr), result_type) + + @staticmethod + def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) + return TypedExpr(result_expr, result_type) + + @staticmethod + def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + 
+ x_expr = sympy.sympify(x.expr) + y_expr = sympy.sympify(y.expr) + # In these cases, remainder in Python == remainder in C++, so this transformation + # is sound + if ( + x_expr.is_nonnegative is not None + and x_expr.is_nonnegative == y_expr.is_positive + ): + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) + return TypedExpr(result_expr, result_type) + return NotImplemented + + @staticmethod + def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Min(x.expr, y.expr), result_type) + + @staticmethod + def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Max(x.expr, y.expr), result_type) + + +@dataclass +class IndexPropVar: + value: Any # Either an IR value, or TypedExpr if is_symbolic is true + is_symbolic: bool = False + + @staticmethod + def new_symbolic(expr: TypedExpr) -> "IndexPropVar": + return IndexPropVar(expr, is_symbolic=True) + + def __post_init__(self): + assert not self.is_symbolic or isinstance(self.value, TypedExpr), ( + "Symbolic IndexPropVar must contain a TypedExpr" + ) + + +IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] + + +class IndexPropagation(DefaultHandler): + """Ops wrapper that tries to propagate constant and index_expr values through the computation. + + This aims to maximize the compile time simplification possible, and convert + indirect indexing from arange into normal static indexing. + + """ + + def __init__( + self, + inner: Any, + iter_ranges: dict[sympy.Symbol, sympy.Expr], + indirect_var_ranges: dict[sympy.Symbol, sympy.Expr], + ) -> None: + self._inner = inner + self.shape_env = V.graph.sizevars.shape_env + + var_to_range = { + k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items() + } + self.var_to_range = tuple( + itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items()) + ) + # NOTE: this is intentionally kept as a reference so the caller can + # update it in-place + self.indirect_var_ranges = indirect_var_ranges + + axioms = [] + for x, s in iter_ranges.items(): + axioms.append(0 <= x) + axioms.append(x < s) + self.axioms = tuple(axioms) + self.shape_env.get_axioms() + + def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any: + # Construct a new constant/index_expr from the SymPy expression + if _is_constant(expr): + val = dtype_to_type(dtype)(expr) + return self._inner.constant(val, dtype) + return self._inner.index_expr(expr, dtype) + + def unwrap(self, a: Union[Any, IndexPropVar]) -> Any: + if isinstance(a, (list, tuple)): + return tuple(self.unwrap(v) for v in a) + + if not isinstance(a, IndexPropVar): + return a + + # Prefer the sympy representation if possible + if a.is_symbolic: + return self.materialize_expr(a.value.expr, a.value.dtype) + + return a.value + + def wrap(self, a) -> IndexPropResult: + if isinstance(a, (list, tuple)): + return tuple(self.wrap(v) for v in a) + return IndexPropVar(a) + + @overload + def fallback( + self, + name: Literal["indirect_indexing"], + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> IndexPropVar: ... + + @overload + def fallback( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: ... 
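# A standalone sympy sketch of the negative-index wrapping that indirect_indexing
# (further below) applies when the index cannot be proven non-negative; Piecewise
# plays the role of the Where helper used there.
import sympy

i, size = sympy.symbols("i size", integer=True)
wrapped = sympy.Piecewise((i + size, i < 0), (i, True))
print(wrapped.subs({i: -1, size: 8}))  # 7 -> wraps like Python negative indexing
print(wrapped.subs({i: 3, size: 8}))   # 3 -> already in range, left unchanged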
+ + def fallback( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: + # Fallback to the wrapped handler + new_args = [self.unwrap(a) for a in args] + new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()} + return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) + + def propagate_sympy( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: + # Build a new SymPy expression from this ops call + def unwrap(a: Union[Any, IndexPropVar]) -> Any: + if not isinstance(a, IndexPropVar): + return a + return a.value + + new_args = [unwrap(a) for a in args] + new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} + new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs) + is_valid_expr = new_expr is not NotImplemented and ( + # Inductor doesn't expect floating point in sympy expressions, but + # allow floating point constants to be propagated + new_expr.is_constant() or new_expr.expr.is_integer + ) + if not is_valid_expr: + return self.fallback(name, args, kwargs) + return IndexPropVar.new_symbolic(new_expr) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) + + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) + + return self.propagate_sympy(name, args, kwargs) + + def statically_true(self, e): + """ + Given some iter_ranges, return a function that given an expression, returns whether + it is true or false using value ranges, guard knowledge and runtime_asserts. + + FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts + If this is an issue, just use guards in `self.axioms`. + + The proper way of handling this would be to have a global shape_env that adds + runtime_asserts as they happen in the code. 
Then, it should be used in SimplifyIndexing + to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also + for indirect_indexing + """ + var_to_range = ( + *self.var_to_range, + *( + (k, ValueRanges(0, upper_bound(v) - 1)) + for k, v in self.indirect_var_ranges.items() + ), + ) + return statically_known_true(self.shape_env, e, self.axioms, var_to_range) + + def indirect_indexing( + self, + index: Union[Any, IndexPropVar], + size: Any, + check: bool = True, + wrap_neg=True, + ) -> Any: + if isinstance(index, IndexPropVar) and index.is_symbolic: + # If we find something we can convert into a direct indexing we do so + # We still need to (perhaps) wrap the expression and add bound checks + # We want to do this "constant folding", as we don't allow to fuse + # kernels into indirect indexing + + expr = sympy.sympify(index.value.expr) + + # TODO Perhaps move this logic to the simplify indexing pass + def wrap_expr(expr): + # Positive, negative, mixed + if self.statically_true(0 <= expr): + return expr + elif self.statically_true(expr < 0): + return expr + size + else: + return Where(expr < 0, expr + size, expr) + + # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr + can_prove_lower = self.statically_true(0 <= expr) or self.statically_true( + -size <= expr + ) + can_prove_upper = self.statically_true(expr < size) + if wrap_neg: + expr = wrap_expr(expr) + if generate_assert(check): + self.fallback( + "check_bounds", + (expr, size), + dict(lower=not can_prove_lower, upper=not can_prove_upper), + ) + return expr + + indirect_var = self.fallback( + "indirect_indexing", (index, size, check, wrap_neg), {} + ).value + return indirect_var diff --git a/phivenv/Lib/site-packages/torch/_inductor/inductor_prims.py b/phivenv/Lib/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..debd36a9b72c227e444e891df21a677183074bba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import operator +from typing import Optional, TYPE_CHECKING + +import torch +from torch import _prims, Tensor + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + if isinstance(return_type, tuple): + + def meta(*args, **kwargs): + return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs)) + + else: + + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + amax = torch.amax(x, dim, keepdim=True) + return amax, torch.sum(torch.exp(x - amax), dim, keepdim=True) + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for 
use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index].clone(), + doc="Extract a single seed from the result of inductor_seeds()", +) +# inductor_random() doesn't accept a dtype. +# instead, its lowering always burns in float32, and conversions to a different type +# are explicit in the graph. We therefore need this impl (used during tracing) to hardcoded +# the dtype, so it always faithfully produces a float32 tensor during tracing, +# even if the default dtype is set to something else. +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)( + size, device=seed.device, dtype=torch.float32 + ), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) 
self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) +fma = make_prim( + "fma(Tensor a, Tensor b, Tensor c) -> Tensor", + lambda a, b, c: (a * b) + c, + doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication", +) +prepare_softmax_online = make_prim( + "prepare_softmax_online(Tensor a, int dim) -> (Tensor, Tensor)", + eager_prepare_softmax, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Prepare the softmax by computing the max and sum.", +) + + +def _flattened_index_to_nd(indices, width): + import sympy + + from torch.utils._sympy.functions import FloorDiv + + dim = len(width) + + if dim == 1: + return [indices] + elif dim >= 2: + m = functools.reduce(operator.mul, width[1:]) + if isinstance(indices, sympy.Expr) or isinstance(m, sympy.Expr): + ih = FloorDiv(indices, m) + else: + ih = indices // m + indices_new = indices - (ih * m) + return [ih, *_flattened_index_to_nd(indices_new, width[1:])] + else: + raise ValueError(f"Unknown dim: {dim}") + + +def _flatten_index(indices, width): + result = indices[0] + for d in range(1, len(indices)): + result = width[d] * result + indices[d] + return result + + +def _low_memory_max_pool_with_offsets_aten( + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + dim = len(kernel_size) + if dim == 2: + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + else: + vals, indices = torch.ops.aten.max_pool3d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + idhw = _flattened_index_to_nd(indices, self.shape[-dim:]) + + dhw_inc = [] + + for d in range(dim): + bh_shape = [1] * self.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + indices.shape[-dim + d], dtype=torch.int64, device=self.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + h_inc = (idhw[d] - hbase) // dilation[d] + dhw_inc.append(h_inc) + + offsets = _flatten_index(dhw_inc, kernel_size) + + return vals, offsets.to(torch.int8) + + +def _low_memory_max_pool_offsets_to_indices_aten( + offsets, + kernel_size, + input_size, + stride, + padding, + dilation, +): + dim = len(kernel_size) + offsets = offsets.to(torch.int64) + dhw_inc = _flattened_index_to_nd(offsets, kernel_size) + + idhw = [] + for d in range(dim): + bh_shape = [1] * offsets.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + offsets.shape[-dim + d], dtype=torch.int64, device=offsets.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + idhw.append(hbase + dhw_inc[d] * dilation[d]) + + return _flatten_index(idhw, input_size) + + +_low_memory_max_pool_with_offsets = make_prim( + "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool_with_offsets_aten, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Instead of returning indices, returns indices offsets.", +) + +_low_memory_max_pool_offsets_to_indices = make_prim( + "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", # noqa: B950 + _low_memory_max_pool_offsets_to_indices_aten, + doc="Convert small int offsets to 
regular indices.", +) diff --git a/phivenv/Lib/site-packages/torch/_inductor/ir.py b/phivenv/Lib/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..c247122f6ded4169fd0687afbfebe15dd313d894 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/ir.py @@ -0,0 +1,8451 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import operator +import textwrap +import traceback +import typing +from collections.abc import Generator, Iterable, Sequence +from contextlib import AbstractContextManager, nullcontext +from enum import Enum +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + Literal, + Optional, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import assert_never, Never, TypeAlias +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer, Symbol + +import torch._export.serde.schema as export_schema +import torch._library.utils as library_utils +import torch._logging +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._inductor import metrics +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + StrideType, +) +from torch._subclasses.fake_tensor import get_schema_info +from torch.fx.experimental.symbolic_shapes import ( + _remove_effect_token_unbacked_bindings, + compute_unbacked_bindings, + free_symbols, + free_unbacked_symbols, + IterateExprs, + rebind_unbacked, + resolve_unbacked_bindings, + ShapeEnv, + statically_known_true, + SymTypes, +) +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import SymT + +from . 
import config, dependencies +from .codegen.common import ( + BackendFeature, + CodegenSymbol, + get_scheduling_for_device, + index_prevent_reordering, +) +from .dependencies import ( + Dep, + extract_free_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .loop_body import LoopBody +from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode +from .runtime.benchmarking import benchmarker +from .runtime.hints import DeviceProperties, ReductionHint +from .utils import ( + argsort, + argsort_sym, + cache_on_self, + ceildiv, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + do_bench_using_profiling, + dtype_from_size, + get_dtype_size, + get_kernel_metadata, + GPU_ALIGN_BYTES, + ir_dataclass, + is_dynamic, + is_gpu, + sympy_dot, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, + tensor_is_aligned, +) +from .virtualized import ops, OpsValue, V + + +if TYPE_CHECKING: + from torch._library.fake_class_registry import FakeScriptObject + from torch.fx.node import Node + + from .codegen.cuda.cuda_template import CUDATemplate + from .graph import GraphLowering + from .utils import IndentedBuffer + +else: + CUDATemplate: TypeAlias = object + + +try: + import triton + + triton_version = triton.__version__ + has_triton = True +except ImportError: + triton_version = None + has_triton = False + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_V = TypeVar("_V") + +_IntLike: TypeAlias = Union[int, Expr] +_NumLike: TypeAlias = Union[int, float, Expr] + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. 
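As a schematic (hypothetical) example, once realized, a lowering of `x.t() + 1` might produce a chain like
    TensorBox -> StorageBox -> ComputedBuffer   (wrapping the Pointwise for the add, owning new storage)
whose load of the transposed input goes through
    TensorBox -> PermuteView -> StorageBox -> InputBuffer("x")   (the transpose, a view),
with the exact classes depending on how the lowering chooses to realize things.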
+ +Computation is represented by Operation nodes, with each operation producing 1 +or more output Buffers. In the case of mutations, these will be new Buffers that have the +mutated buffer listed in its get_mutation_names(). + +It is also possible to have an InputBuffer for which there is no corresponding Operation, +e.g. it may be a graph input or compile time constant. + +""" + + +_NodeOrNodes: TypeAlias = Union[ + int, + "TensorBox", + dict[str, "TensorBox"], + "Symbol", + "IRNode", + Sequence[ + Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]] + ], +] + + +def _is_static(x: object) -> bool: + return isinstance(x, (int, Integer)) + + +@dataclasses.dataclass(frozen=True) +class GraphPartitionSignature: + # symbol inputs that are necessary for codegen + symbol_inputs: OrderedSet[sympy.Symbol] + + # mapping from partition input name to IRNode or Expr. Need the name str since + # we cannot get name from Expr. + input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]] + output_nodes: list[IRNode] + + # mapping from partition input name to a boolean for whether deallocating it + # in the partition function + input_deallocation: dict[str, bool] + skip_cudagraph: bool + + # name of constants read/written by the graph partition + constant_names: list[str] + + +def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None: + def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if nodes is None: + pass + elif isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + int, + EffectfulKernel, + ShapeAsConstantBuffer, + ), + ), ( + f"Found {type(nodes)}, which is not a supported top level IR node. 
See [Note: Inductor IR]" + ) + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name: str) -> Callable[..., OpsValue]: + assert isinstance(name, str) + + def fn(*args: object, **kwargs: object) -> OpsValue: + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing( + reindex1: Callable[[Sequence[_U]], Sequence[_V]], + reindex2: Callable[[Sequence[_T]], Sequence[_U]], +) -> Callable[[Sequence[_T]], Sequence[_V]]: + def reindex(index: Sequence[_T]) -> Sequence[_V]: + return reindex1(reindex2(index)) + + return reindex + + +def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: + if unbacked_only: + return free_unbacked_symbols(x) + else: + return free_symbols(x) + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] +NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] + + +def get_fill_order( + seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None +) -> Sequence[int]: + """ + Convert strides to fill order (argsort) + """ + if shape_env is None: + sorted_idx: Sequence[int] = argsort(seq) + else: + # argsort_sym handles unbacked symints (with the help of the shape_env) + sorted_idx = argsort_sym(shape_env, seq) + return sorted_idx + + +def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]: + """ + Convert stride order to fill order + For channel last format, + + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order( + seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None +) -> Sequence[int]: + """ + Convert strides to stride order + """ + sorted_idx: Sequence[int] = get_fill_order(seq, shape_env) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +@overload +def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ... + + +@overload +def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ... 
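# A standalone worked example of the orderings computed by get_fill_order /
# get_stride_order above, using the channels-last strides of an NCHW tensor;
# it reproduces NHWC_STRIDE_ORDER.
import torch

strides = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last).stride()
print(strides)                  # (60, 1, 15, 3)

fill_order = sorted(range(len(strides)), key=lambda d: strides[d])  # argsort
stride_order = [0] * len(strides)
for rank, dim in enumerate(fill_order):
    stride_order[dim] = rank

print(fill_order)               # [1, 3, 2, 0]: dims from smallest to largest stride
print(stride_order)             # [3, 0, 2, 1] == NHWC_STRIDE_ORDER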
+ + +def ir_node_to_tensor( + x: Optional[IRNode], guard_shape: bool = True +) -> Optional[torch.Tensor]: + if x is None: + return None + + shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] + else: + stride = FlexibleLayout.contiguous_strides(size) + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + stride = convert_shape_to_symint(stride) + with V.graph.sizevars.shape_env.suppress_guards(): + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def may_convert_to_optional( + value: Optional[Sequence[_T]], +) -> Optional[Sequence[Optional[_T]]]: + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {std::nullopt} instead of {} + return [None] + return value + + +def get_device_type( + x: Union[IRNode, OutputSpec, torch.device, None, str], +) -> Optional[str]: + if isinstance(x, str) or x is None: + return x + elif isinstance(x, torch.device): + return x.type + elif isinstance(x, (IRNode, OutputSpec)): + return get_device_type(x.get_device()) + assert_never(f"get_device_type({x}: {type(x).__name__})") + + +def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool: + device = get_device_type(x) + # Special case cpu and cuda as using the method below + # to determine if the scheduler is a triton scheduler subclass + # requires instantiating a scheduler for them + if device in ["cpu", "cuda"]: + if getattr(config, f"{device}_backend") == "triton": + return True + return False + if ( + device is None + or (device_scheduling := get_scheduling_for_device(device)) is None + ): + return False + from .codegen.triton import TritonScheduling + + assert isinstance(device_scheduling, type) + return issubclass(device_scheduling, TritonScheduling) + + +def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool: + return get_device_type(x) == "cpu" + + +def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> bool: + if not isinstance(x, IRNode) or x.maybe_get_stride() is None: + return False + + aligned_strides = all( + (V.graph.sizevars.size_hint_or_throw(x.get_stride()[i]) % alignment) == 0 + for i in range(len(x.get_stride()) - 1) + ) + # if the last dim size is <= 1, stride doesn't matter + aligned_last_dim = ( + V.graph.sizevars.size_hint_or_throw(x.get_stride()[-1]) == 1 + or V.graph.sizevars.size_hint_or_throw(x.get_size()[-1]) <= 1 + ) + return aligned_last_dim and aligned_strides + + +def significant_strides_equal( + strides1: Sequence[_IntLike], + strides2: Sequence[_IntLike], + shape: Sequence[_IntLike], +) -> bool: + """ + Returns true if the strides are equal, ignoring dimensions of size 1 . 
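For example, with shape (1, 3), strides (5, 1) and (3, 1) compare as equal: the stride of the leading size-1 dimension is insignificant.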
+ """ + assert len(shape) == len(strides1) and len(strides1) == len(strides2) + for dim, s1, s2 in zip(shape, strides1, strides2): + if V.graph.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type] + continue + + if not V.graph.sizevars.statically_known_equals( + s1, s2 + ) and not V.graph.sizevars.symbolic_hint(s1) == V.graph.sizevars.symbolic_hint( + s2 + ): + return False + + return True + + +def try_match_insignificant_strides( + tensor: Union[TensorBox, BaseView], + strides: Sequence[Union[int, torch.SymInt]], +) -> Union[TensorBox, BaseView]: + """ + Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant + dimensions - size 0 or 1 - will be updated. + + If there are real stride differences (NHWC vs NCHW), or the tensor is not realized, then the input will be returned + """ + if not is_storage_and_layout(tensor): + return tensor + + if all( + V.graph.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(strides, tensor.get_stride()) + ): + return tensor # type: ignore[arg-type] + + if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()): + return tensor + + storage, old_layout = as_storage_and_layout(tensor) + new_stride = [*old_layout.stride] + for i, s in enumerate(tensor.get_size()): + if V.graph.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type] + new_stride[i] = strides[i] + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + old_layout.size, + new_stride, + old_layout.offset, + ) + return TensorBox(ReinterpretView(data=storage, layout=new_layout)) + + +def gm_original_output_strides(gm: torch.fx.GraphModule) -> None: + output_node = gm.graph.find_nodes(op="output")[0] + output_node.meta["user_visible_output_idxs"] = [ + idx for idx, _ in enumerate(output_node.args) + ] + from torch._inductor.compile_fx import record_original_output_strides + + record_original_output_strides(gm) + + +def get_symbolic_inputs(inputs: list[Buffer]) -> list[Expr]: + sym_vars: OrderedSet[Expr] = OrderedSet() + for inp in inputs: + sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False) + sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False) + + return list(sym_vars) + + +class IRNode: + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + + # NB: These are kinda weird, + origins: OrderedSet[Any] = dataclasses.field(init=False) + traceback: Optional[list[str]] = dataclasses.field(init=False) + origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]: + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + def _post_init_setattr(self, attr: str, value: Any) -> None: + # Intended for use in __post_init__ for enforcing an invariant on a dataclass + # If you must, can also be used for setting provenance info + # We would like to try and minimize these usages though + object.__setattr__(self, attr, value) + + def __post_init__(self) -> None: + self._post_init_setattr("origins", OrderedSet(self._current_origins)) + self._post_init_setattr( + "traceback", traceback.format_stack() if config.debug_ir_traceback else None + ) + self._post_init_setattr("origin_node", None) + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_traceback(self) -> Optional[list[str]]: + return self.traceback + + 
def get_origin_node(self) -> Optional[torch.fx.Node]: + return self.origin_node + + def get_defining_op(self) -> Optional[Operation]: + return None + + def common_repr(self, shorten: bool = True) -> Sequence[str]: + origins = f"origins={getattr(self, 'origins', '')}" + if shorten and len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." + return [origins] + + def str_helper( + self, lines: Sequence[object], shorten: bool = True, multiline: bool = True + ) -> str: + lines = list(lines) + list(self.common_repr(shorten)) + lines = list(map(str, lines)) + if multiline: + new_lines = indent(",\n".join(lines)) + return f"{type(self).__name__}(\n{new_lines}\n)" + else: + return f"{type(self).__name__}({lines})" + + def get_dtype(self) -> torch.dtype: + return self.dtype + + def maybe_get_dtype(self) -> Optional[torch.dtype]: + try: + return self.get_dtype() + except NotImplementedError: + return None + + def get_layout(self) -> Layout: + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def maybe_get_layout(self) -> Optional[Layout]: + try: + return self.get_layout() + except NotImplementedError: + return None + + def get_output_spec(self) -> OutputSpec: + return self.get_layout() + + def maybe_get_output_spec(self) -> Optional[OutputSpec]: + try: + return self.get_output_spec() + except NotImplementedError: + return None + + def has_tensor_output(self) -> bool: + """True for single tensor output (excludes MultiOutput)""" + return isinstance(self.maybe_get_output_spec(), Layout) + + def get_size(self) -> Sequence[Expr]: + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + def maybe_get_size(self) -> Optional[Sequence[_IntLike]]: + try: + return self.get_size() + except NotImplementedError: + return None + + @property + def shape(self) -> Union[_IntLike, sympy.Rel, Sequence[_IntLike]]: + return self.get_size() + + def get_numel(self) -> Expr: + return sympy_product(self.get_size()) + + def is_zero_elements(self) -> bool: + return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) + + def realize(self) -> Optional[str]: + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. 
+ """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + def get_device(self) -> Optional[torch.device]: + return None + + def get_device_or_error(self) -> torch.device: + device = self.get_device() + assert device is not None + return device + + def has_exceeded_max_reads(self) -> bool: + return False + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + raise NotImplementedError(type(self).__name__) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + raise NotImplementedError(type(self).__name__) + + def get_stride(self) -> Sequence[_IntLike]: + raise NotImplementedError(type(self).__name__) + + def maybe_get_stride(self) -> Optional[Sequence[_IntLike]]: + try: + return self.get_stride() + except NotImplementedError: + return None + + def get_name(self) -> str: + raise NotImplementedError(type(self).__name__) + + def maybe_get_name(self) -> Optional[str]: + try: + return self.get_name() + except NotImplementedError: + return None + + def is_input_buffer(self) -> bool: + try: + return self.get_name() in V.graph.graph_inputs + except NotImplementedError: + return False + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + return False + + def mark_reuse(self, users: int) -> None: + pass + + def realize_hint(self) -> None: + pass + + def unwrap_view(self) -> IRNode: + raise NotImplementedError(type(self).__name__) + + def freeze_layout(self) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_stride_order( + self, order: list[int], allow_padding: bool = False + ) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_fill_order(self, order: list[int]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_exact_strides( + self, exact_strides: list[_IntLike], allow_padding: bool = False + ) -> None: + raise NotImplementedError(type(self).__name__) + + def get_read_writes(self) -> dependencies.ReadWrites: + raise NotImplementedError(type(self).__name__) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + def num_reads(self) -> int: + return len(self.get_reads()) + + def get_storage_numel(self) -> _IntLike: + raise NotImplementedError(type(self).__name__) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + raise NotImplementedError(type(self).__name__) + + def get_reduction_type(self) -> Optional[str]: + raise NotImplementedError(type(self).__name__) + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + raise NotImplementedError(type(self).__name__) + + def is_extern(self) -> bool: + return False + + def is_no_op(self) -> bool: + return False + + def constant_to_device(self, device: torch.device) -> IRNode: + raise NotImplementedError(type(self).__name__) + + def get_mutation_names(self) -> Sequence[str]: + raise NotImplementedError(type(self).__name__) + + def get_operation_name(self) -> str: + raise NotImplementedError(type(self).__name__) + + def get_inputs_that_alias_output(self) -> Sequence[str]: + raise NotImplementedError(type(self).__name__) + + if TYPE_CHECKING: + + @property + def dtype(self) -> torch.dtype: ... 
+ + +@ir_dataclass(frozen=False) +class Operation: + def __post_init__(self) -> None: + self.operation_name: Optional[str] = None + + def get_device(self) -> Optional[torch.device]: + raise NotImplementedError + + def get_origin_node(self) -> Optional[torch.fx.Node]: + assert hasattr(self, "origin_node") + return self.origin_node + + def get_origins(self) -> OrderedSet[Any]: + assert hasattr(self, "origins") + return self.origins + + def get_operation_name(self) -> str: + assert self.operation_name is not None + return self.operation_name + + def is_extern(self) -> bool: + return False + + def is_no_op(self) -> bool: + return False + + def get_read_writes(self) -> dependencies.ReadWrites: + raise NotImplementedError + + def is_user_of(self, name: str) -> bool: + return name in self.get_read_names() + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + def get_outputs(self) -> list[Buffer]: + raise NotImplementedError + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + """ + When unbacked_only=True: + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + + When unbacked_only=False: + Similar to `unbacked_only=True` but including all free symbols + instead of only free unbacked symbols. + """ + return OrderedSet() + + def get_workspace_size(self) -> int: + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. 
+ """ + return 0 + + +@ir_dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: Sequence[_IntLike] + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.ranges), + self.inner_fn_free_symbols(unbacked_only), + ) + + def _to_str(self, names: Sequence[str]) -> str: + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __post_init__(self) -> None: + super().__post_init__() + + def __str__(self) -> str: + return self._to_str(("ranges",)) + + __repr__ = __str__ + + def get_device(self) -> Optional[torch.device]: + return self.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return self.origin_node + + def get_size(self) -> Sequence[Expr]: + return self.ranges + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + @classmethod + def create(cls, *args: Any, **kwargs: Any) -> TensorBox: + origin_node = kwargs.pop("origin_node", None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + # Need to explicitly set origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) + return TensorBox.create(r) + + @staticmethod + def _index(ranges: Sequence[_IntLike], prefix: SymT = SymT.INDEX) -> Sequence[Expr]: + return [ + sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n) + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self) -> OpCountResult: + opcounter = OpCounterCSE(V.MockHandler()) + with ( + V.set_ops_handler(opcounter), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + self.inner_fn(*self.inner_fn_args()) + return opcounter.getvalue() + + def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]: + return (self._index(self.ranges),) + + @cache_on_self + def inner_fn_str(self) -> str: + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + if threshold is None: + threshold = 0 + threshold = max(threshold, config.realize_opcount_threshold) + return self.inner_fn_opcount().num_ops > threshold + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + return extract_free_symbols(self.inner_fn, index, unbacked_only=unbacked_only) + + def get_reads(self) -> OrderedSet[Dep]: + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def num_reads(self) -> int: + return len(self.inner_fn_opcount().read_buffers) + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" 
+ ) + + def get_reduction_type(self) -> Optional[str]: + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device: torch.device) -> IRNode: + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" + ) + + +def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue: + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +@ir_dataclass +class Pointwise(Loops): + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return [] + + def get_reduction_type(self) -> Optional[str]: + return None + + def store_output( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + ) -> None: + loader = self.make_loader() + return ops.store(output_name or "unnamed", indexer(vars), loader(vars)) + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise( + device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges + ) + + +@ir_dataclass +class Scatter(Pointwise): + output_indexer: Callable[[Sequence[Expr]], Expr] + scatter_mode: StoreMode = None + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + output_indexer=self.output_indexer, + scatter_mode=self.scatter_mode, + ) + + def store_output( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + ) -> None: + loader = self.make_loader() + if output_name is None: + output_name = "unnamed" + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn( + reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True +) -> Callable[..., object]: + if reduction_type in REDUCTION_COMBINE_FN: + return REDUCTION_COMBINE_FN[reduction_type] + + elif reduction_type in ("argmax", "argmin"): + + def argmax_combine_fn( + a: tuple[object, object], b: tuple[object, object] + ) -> tuple[OpsValue, OpsValue]: + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + + equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + tie = ( + ops.lt(a_index, b_index) + if arg_break_ties_left + else 
ops.gt(a_index, b_index) + ) + mask = ops.logical_or(mask, ops.logical_and(equal, tie)) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + return argmax_combine_fn + + elif reduction_type == "welford_combine": + + def welford_combine_fn( + a: tuple[OpsValue, OpsValue, OpsValue], + b: tuple[OpsValue, OpsValue, OpsValue], + ) -> tuple[OpsValue, OpsValue, OpsValue]: + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + return welford_combine_fn + + else: + raise NotImplementedError(f"unknown reduction_type={reduction_type}") + + +@ir_dataclass +class Reduction(Loops): + reduction_ranges: Sequence[_IntLike] + reduction_type: ReductionType + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self) -> str: + return self._to_str(("ranges", "reduction_ranges", "reduction_type")) + + __repr__ = __str__ + + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges) + ) + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return self.reduction_ranges + + def get_reduction_type(self) -> Optional[str]: + return self.reduction_type + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + reduction_vars: Sequence[Symbol], + ) -> None: + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) + + def index_length(self) -> int: + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[Expr]]: + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.R0_INDEX) + return (index, rindex) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.R0_INDEX) + return extract_free_symbols( + self.inner_fn, index, rindex, unbacked_only=unbacked_only + ) + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. 
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + reduction_ranges=self.reduction_ranges, + reduction_type=self.reduction_type, + src_dtype=self.src_dtype, + reduction_hint=ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., OpsValue], + ranges: Sequence[_IntLike], + reduction_ranges: Sequence[_IntLike], + reduction_type: Union[ReductionType, Literal["scan"]], + reduction_numel: Expr, + input_node: Optional[IRNode] = None, + ) -> tuple[ReductionHint, _IntLike]: + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = reduction_type == "scan" or ( + not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) + and reduction_type + not in ( + "argmax", + "argmin", + ) + and config.split_reductions + ) + if not (_is_static(reduction_numel_hint) and _is_static(numel_hint)): + # We don't support unbacked symints + return ReductionHint.DEFAULT, 1 + + props = DeviceProperties.create(device) + num_sm = props.multi_processor_count + min_elements_per_thread = 32 + if should_split: + inner_reduction_splits: Callable[[int, int], int] = functools.partial( + V.choices.reduction_split_factor, device, inner_reduction=True + ) + outer_reduction_splits: Callable[[int, int], int] = functools.partial( + V.choices.reduction_split_factor, device, inner_reduction=False + ) + else: + + def inner_reduction_splits( + reduction_numel_hint: int, + numel_hint: int, + ) -> int: + return 1 + + outer_reduction_splits = inner_reduction_splits + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if input_node is not None and isinstance(input_node, TensorBox): + with patch.object(FlexibleLayout, "allow_indexing", True): + ( + new_ranges, + new_reduction_ranges, + ) = extract_input_node_reduction_ranges(input_node) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly. 
+ return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type if reduction_type != "scan" else "sum", + src_dtype=src_dtype, + reduction_hint=ReductionHint.DEFAULT, + ) + + def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare + assert read_writes.range_vars is not None + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = getattr(buf.layout, "stride", None) + buf.decide_layout() + if getattr(buf.layout, "stride", None) != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges1 = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + j = V.graph.sizevars.simplify_with_ranges(i, ranges1) + strides = V.graph.sizevars.stride_hints( + j, reduction_vars, list(ranges1.keys()) + ) + outer = all(s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn( + inner_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], OpsValue], + reduction_ranges: Sequence[_IntLike], + reduction_type: str, + src_dtype: torch.dtype, + ) -> Callable[[Sequence[_IntLike]], OpsValue]: + """Convert inner_fn from a reduction to an pointwise""" + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index: Sequence[_IntLike]) -> Any: + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any] + if reduction_type in ("argmin", "argmax"): + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() + + def value_fn( + index: Sequence[_IntLike], rindex: Sequence[_IntLike] + ) -> tuple[OpsValue, OpsValue]: + rindex = [sympy.expand(i) for i in rindex] + return ( + 
inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + def create( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], + reduction_ranges: Sequence[Expr], + reduction_type: ReductionType, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ) -> TensorBox: + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val: object) -> Union[bool, float, int]: + if dst_dtype == torch.bool: + return bool(val) + elif dst_dtype.is_floating_point: + assert isinstance(val, typing.SupportsFloat) + return float(val) + else: + assert isinstance(val, typing.SupportsInt) + return int(val) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert reduction_type in rtypes_to_inits.keys(), ( + f"{reduction_type} not supported for zero-dimension tensors!" + ) + + def const_fn(index: int) -> OpsValue: + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index: int) -> OpsValue: + return ops.constant(0, dst_dtype) + + else: + + def fn(index: int) -> OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create( + device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges + ) + + if ( + isinstance(reduction_numel, Integer) + and V.graph.sizevars.size_hint_or_throw(reduction_numel) + < config.unroll_reductions_threshold + and (sympy_product(ranges) != 1 or is_gpu(device.type)) + ): + # NB: This works around https://github.com/pytorch/pytorch/issues/140457 + # since turning reductions into pointwise ops can exacerbate this problem + return Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges=ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + + def _maybe_increase_split(split: int) -> int: + # don't apply min_num_split constraint for static shape case. 
+ if _is_static(reduction_numel): + return split + if split > 1: + return max(split, config.min_num_split) + else: + return split + + split = _maybe_increase_split(split) + + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + input_node, + ) + + return TensorBox.create( + Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + ) + + @staticmethod + def default_accumulator( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + if reduction_type in ("max", "argmax"): + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return False + else: + return torch.iinfo(dtype).min + if reduction_type in ("min", "argmin"): + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return True + else: + return torch.iinfo(dtype).max + + zero = False if is_boolean_dtype(dtype) else 0 + one = True if is_boolean_dtype(dtype) else 1 + return { + "sum": zero, + "prod": one, + "xor_sum": zero, + "any": zero, + "welford_reduce": (zero, zero, zero), + "welford_combine": (zero, zero, zero), + "online_softmax_reduce": (float("-inf"), zero), + }[reduction_type] + + @staticmethod + def default_value( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + if reduction_type == "welford_reduce": + return 0 + return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: _IntLike, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def check_for_split_dense_dim_reindexing( + cls, reduction_numel: _IntLike, input_node: Optional[IRNode] + ) -> Optional[int]: + """ + If we are reducing over the full tensor, and it is non-dense in the last dimension, + reindex so we reduce over the dense dimension. 
initially just handle complete + reduction case + """ + if input_node is None: + return None + + if not V.graph.sizevars.statically_known_equals( + input_node.get_numel(), reduction_numel + ): + return None + + input_node.realize() + try: + # finalize layout + as_storage_and_layout(input_node) + except NotImplementedError: + return None + + strides = input_node.get_stride() + + for i, s in enumerate(strides[:-1]): + if V.graph.sizevars.statically_known_equals(s, 1): + return i + + return None + + @classmethod + def _multilayer_wrap_loader( + cls, + loader: Callable[..., OpsValue], + reduction_ranges: Sequence[_IntLike], + reduction_numel: _IntLike, + split: _IntLike, + block_size: _IntLike, + default: Union[_NumLike, Sequence[_NumLike]], + input_node: Optional[IRNode] = None, + ) -> Callable[..., object]: + dense_index = cls.check_for_split_dense_dim_reindexing( + reduction_numel, input_node + ) + reindex = View.dynamic_reshape_indexer( + reduction_ranges, [reduction_numel], dense_index + ) + need_mask = not V.graph.sizevars.statically_known_true( + sympy.Eq(reduction_numel % split, 0) + ) + + def wrapper_fn( + index: Sequence[Symbol], reduction_index: Sequence[Symbol] + ) -> OpsValue: + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body() -> OpsValue: + return loader(new_index, reindex([indices])) + + if need_mask: + index_dtype = dtype_from_size(reduction_numel) + mask = ops.lt( + ops.index_expr(indices, index_dtype), + ops.index_expr(reduction_numel, index_dtype), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader: Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: Sequence[Integer], + new_reduction_ranges: Sequence[Integer], + ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]: + assert all(r == 1 for r in original_ranges), ( + f"Only enabled for numel_hint == 1, found {original_ranges=}" + ) + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn( + merged_index: Sequence[sympy.Expr], + new_reduction_index: Sequence[sympy.Expr], + ) -> OpsValue: + original_idx = merged_index[: len(original_ranges)] + new_index = merged_index[len(original_ranges) :] + return loader( + original_idx, + reindex(tuple(new_index) + tuple(new_reduction_index)), + ) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: list[Expr], + new_reduction_ranges: list[Integer], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + ) -> TensorBox: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel. 
keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn( + index: Sequence[_IntLike], reduction_index: Sequence[_IntLike] + ) -> OpsValue: + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device=device, + dtype=dst_dtype, + inner_fn=intermediate_fn, + ranges=original_ranges, + reduction_ranges=new_ranges[len(original_ranges) :], + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], + reduction_ranges: Sequence[Expr], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + input_node: Optional[IRNode] = None, + ) -> TensorBox: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, + reduction_ranges, + reduction_numel, + split, + block_size, + default, + input_node, + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: list[Integer], + new_reduction_ranges: list[Integer], + reduction_type: ReductionType, + reduction_hint: ReductionHint, + ) -> TensorBox: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + [*original_ranges, *new_ranges], + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +INNER_FN_TY = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] + + +class MultiOutputReduction(Reduction): + output_index: int + + def __init__( + self, + device: torch.device, + dst_dtype: torch.dtype, + inner_fns: Union[INNER_FN_TY, Sequence[INNER_FN_TY]], + ranges: Sequence[Integer], + reduction_ranges: Sequence[Integer], + reduction_type: ReductionType, + src_dtype: torch.dtype, + reduction_hint: 
ReductionHint, + output_index: int, + ): + if callable(inner_fns): + inner_fns = (inner_fns,) + + loader: Callable[[Sequence[Expr], Sequence[Expr]], Any] + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader( + idx: Sequence[Expr], reduction_idx: Sequence[Expr] + ) -> tuple[OpsValue, ...]: + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device=device, + dtype=dst_dtype, + inner_fn=loader, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + self.output_index = output_index + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + reduction_vars: Sequence[Symbol], + ) -> None: + values = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + assert isinstance(values, (tuple, list)), f"{type(values)}" + value = values[self.output_index] + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) + + +class OnlineSoftmaxReduction(MultiOutputReduction): + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], + reduction_ranges: Sequence[Expr], + num_output: int, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ) -> Sequence[TensorBox]: + """ + Create the reduction disregarding splitting. + """ + results = tuple( + TensorBox.create( + MultiOutputReduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + "online_softmax_reduce", # type: ignore[arg-type] + src_dtype, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(num_output) + ) + for t in results: + t.realize() + return results + + +class WelfordReduction(MultiOutputReduction): + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: list[Integer], + reduction_ranges: list[Integer], + reduction_type: ReductionType, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ) -> Sequence[TensorBox]: + assert reduction_type in ("welford_reduce", "welford_combine") + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val: int) -> TensorBox: + def inner_fn(idx: Sequence[Expr]) -> OpsValue: + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy( + loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], + ) -> TensorBox: + def inner_fn(idx: Sequence[Expr]) -> OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) != 1 + # ): + # return 
Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype, + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: list[Integer], + reduction_ranges: list[Integer], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + ) -> Sequence[TensorBox]: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.statically_known_true( + sympy.Eq(reduction_numel % split, 0) + ) + + if need_mask and reduction_type != "welford_combine": + # If we need mask, then "welford_reduce" doesn't work because + # masked inputs shouldn't count towards the welford weight + + def constant( + idx: Sequence[Expr], reduction_idx: Sequence[Expr], value: int + ) -> OpsValue: + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + def intermediate_loader_fn( + index: Sequence[Expr], + reduction_index: Sequence[Expr], + loader: Callable[[Sequence[Expr]], OpsValue], + ) -> OpsValue: + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + 
reduction_hint, + ) + + +@ir_dataclass +class Scan(Loops): + scan_ranges: list[Integer] + size: list[Integer] + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]] + reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: tuple[torch.dtype, ...] + inner_fns: tuple[Callable[..., Any], ...] + + # HACK we mimic reduction + + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_free_symbol_uses(unbacked_only) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.scan_ranges) + ) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.size) + ) + ) + + def __post_init__(self) -> None: + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[_IntLike]], Never], + vars: Sequence[Expr], + scan_vars: Sequence[Symbol], + ) -> None: + idx = self.reindex(vars, scan_vars) + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) + result = ops.scan(self.dtypes, self.combine_fn, values) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) + + def get_reduction_type(self) -> Optional[str]: + # return self.scan_op + return "custom" + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return self.scan_ranges + + def get_size(self) -> Sequence[Expr]: + return self.size + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + def index_length(self) -> int: + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]: + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtypes: tuple[torch.dtype, ...], + inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...], + size: list[Integer], + axis: int, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + *, + # Whether we have the option to fallback to aten + can_fallback_to_aten: bool = True, + **kwargs: Any, + ) -> Sequence[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SCAN): + return [None] * len(dtypes) + + if len(dtypes) > 1 and not V.graph.has_feature( + device, BackendFeature.TUPLE_REDUCTION + ): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + assert len(dtypes) == len(inner_fns) + + # Scan with a single element is just a copy + if sizevars.statically_known_true(sympy.Le(scan_numel, 1)): + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + 
inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtypes[0], + inner_fn=inner_fns[0], + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan + if num_splits > 1: + supports_split = ( + torch.version.hip is None or (has_triton and triton_version >= "3.3.0") + ) and (len(dtypes) == 1) + if not supports_split: + if can_fallback_to_aten: + # Fallback to ATen + return [None] * len(dtypes) + else: + num_splits = 1 + else: + scan_type = SplitScan + + def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]: + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + results = [ + TensorBox.create( + scan_type( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[Sequence[Expr]], OpsValue], + axis: int, + pointwise_ranges: list[Integer], + scan_ranges: list[Integer], + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + scan_numel: Expr, + ) -> tuple[ReductionHint, _IntLike]: + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx: Sequence[Expr], reduction_idx: Sequence[Expr]) -> OpsValue: + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + reduction_ranges=scan_ranges, + reduction_type="scan", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. +@ir_dataclass +class SplitScan(Scan): + pass + + +@ir_dataclass +class Sort(Loops): + # Sorts a tuple of key, value pairs + sort_ranges: list[Integer] + size: list[Integer] + reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: tuple[torch.dtype, ...] + inner_fns: tuple[Callable[..., Any], ...] 
+ + stable: bool + descending: bool + + # HACK we mimic reduction + + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return ( + super().get_free_symbol_uses(unbacked_only) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.sort_ranges) + ) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.size) + ) + ) + + def __post_init__(self) -> None: + assert len(self.ranges) + len(self.sort_ranges) == len(self.size) + super().__post_init__() + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Expr], + vars: Sequence[Expr], + reduction_vars: Sequence[Expr], + ) -> None: + idx = self.reindex(vars, reduction_vars) + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) + result = ops.sort(self.dtypes, values, self.stable, self.descending) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) + + def get_reduction_type(self) -> Optional[str]: + return "sort" + + def get_reduction_size(self) -> Sequence[Expr]: + return self.sort_ranges + + def get_size(self) -> Sequence[Expr]: + return self.size + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + def index_length(self) -> int: + return len(self.ranges) + len(self.sort_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[Expr]]: + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtypes: tuple[torch.dtype, ...], + inner_fns: tuple[Callable[[list[Expr]], Any], ...], + size: list[Integer], + axis: int, + stable: bool, + descending: bool, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + **kwargs: Any, + ) -> Sequence[Optional[TensorBox]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + sort_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SORT): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + sort_numel = sizevars.simplify(sympy_product(sort_ranges)) + + # Heuristic, smallest rblock where triton usually outperforms aten.sort + # It also isn't bandwidth bound so fusion is unlikely to help. 
+ max_rblock = 512 + is_persistent_kernel = ( + config.triton.persistent_reductions + and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock)) + ) + if not is_persistent_kernel: + # We only support persistent triton kernels + return [None] * len(dtypes) + + assert len(dtypes) == len(inner_fns) + + # Sort with a single element is just a copy + if sizevars.statically_known_true(sympy.Le(sort_numel, 1)): + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]: + assert len(sort_index) == len(sort_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *sort_index, *index[axis:]] + + results = [ + TensorBox.create( + Sort( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + sort_ranges=sort_ranges, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + stable=stable, + descending=descending, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + +def is_storage_and_layout(x: IRNode) -> bool: + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x: IRNode) -> bool: + try: + _buffer, layout = as_storage_and_layout(x, freeze=False) + # pad the stride here so we will NOT claim an tensor as contiguous + # if a padding is gonna happen. + if layout.should_pad_strides(): + layout.pad_strides() + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout( + x: IRNode, + freeze: bool = True, + want_contiguous: bool = False, + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + allow_padding: bool = False, + exact_strides: Optional[Sequence[Union[int, Integer]]] = None, +) -> tuple[StorageBox, Layout]: + """ + Try to simplify x into a StorageBox and a Layout. + + allow_padding only affect how we apply stride_order. When allow_padding + is True, we have the freedom to add padding when applying the stride_order. 
+ """ + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if isinstance(x, StorageBox): + _, layout = as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x, x.data.get_layout() + if isinstance(x, Buffer): + if freeze: + if want_contiguous: + x.freeze_layout() + assert x.get_layout().is_contiguous() + elif stride_order is not None: + x.freeze_layout_with_stride_order( + stride_order, allow_padding=allow_padding + ) + elif exact_strides is not None: + x.freeze_layout_with_exact_strides( + exact_strides, allow_padding=allow_padding + ) + else: + x.decide_layout() + return StorageBox(x), x.get_layout() + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +def is_stride_order_storage_and_layout( + x: IRNode, stride_order: Sequence[Union[int, Integer]] +) -> bool: + try: + _buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +def is_unaligned(node: IRNode) -> bool: + if isinstance(node, (TensorBox, StorageBox)): + return is_unaligned(node.data) + + if isinstance(node, ReinterpretView): + layout = node.layout + has_unaligned_layout = not statically_known_true( + layout.offset * get_dtype_size(layout.dtype) % GPU_ALIGN_BYTES == 0 + ) + return is_unaligned(node.data) or has_unaligned_layout + + if isinstance(node, Buffer): + return node.get_name() in V.graph.unaligned_buffers + + # assume to be aligned otherwise + return False + + +@ir_dataclass +class BaseView(IRNode): + data: IRNode + + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return self.data.get_free_symbol_uses(unbacked_only) + + def make_reindexer(self) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx: Sequence[Expr]) -> Expr: + return inner(reindex(idx)) + + return indexer + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx: Sequence[Expr]) -> OpsValue: + return inner(reindex(idx)) + + return loader + + @property + def dtype(self) -> torch.dtype: + return self.data.get_dtype() + + def get_layout(self) -> Layout: + return self.data.get_layout() + + def get_device(self) -> Optional[torch.device]: + return self.data.get_device() + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + def get_name(self) -> str: + return self.data.get_name() + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.get_size() + + def mark_reuse(self, users: int) -> None: + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self) -> bool: + return self.data.has_exceeded_max_reads() + + def realize(self) -> Optional[str]: + return self.data.realize() + + def realize_hint(self): # type: ignore[no-untyped-def] + return 
self.data.realize_hint() + + def get_storage_numel(self): # type: ignore[no-untyped-def] + return self.data.get_storage_numel() + + def is_extern(self) -> bool: + return self.data.is_extern() # type: ignore[attr-defined] + + def is_module_buffer(self) -> bool: + return self.data.is_module_buffer() # type: ignore[attr-defined] + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_reads(self) -> OrderedSet[Dep]: + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), # type: ignore[arg-type] + ).reads + + def unwrap_view(self): # type: ignore[no-untyped-def] + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise( + device=device, + dtype=self.get_dtype(), + inner_fn=loader, + ranges=self.get_size(), + ) + + +@ir_dataclass +class ExpandView(BaseView): + size: list[Expr] + + @staticmethod + def _normalize_size(x, new_size): # type: ignore[no-untyped-def] + """Replace `-1` with correct sizes""" + sizevars = V.graph.sizevars + new_size = list(map(sympy.expand, new_size)) + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(old_size[i], 1), size_oblivious=True + ): + pass + else: + # Sanity check: Expect broadcast compatibility + # + # NB: new_size[i] == old_size[i] is expected to already be + # guarded because the meta formula was expected to have taught + # us this equality. 
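+ # e.g. expanding old_size [1, 4] to new_size [3, 4]: dim 0 is skipped because
+ # the old size is 1 (a broadcast), and dim 1 passes because 4 - 4 == 0.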
+ assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, ( + f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + ) + return new_size + + @classmethod + def create(cls, x, new_size): # type: ignore[no-untyped-def] + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.S.Zero] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append( + stride + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) + else sympy.S.Zero + ) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + ) + return ReinterpretView(data=storage, layout=new_layout) + + return ExpandView(data=x, size=new_size) + + def get_size(self) -> Sequence[Expr]: + return self.size + + def make_reindexer(self): # type: ignore[no-untyped-def] + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex(index): # type: ignore[no-untyped-def] + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.S.Zero + return index + + return reindex + + +@ir_dataclass +class PermuteView(BaseView): + dims: list[Expr] + + @classmethod + def create(cls, x, dims): # type: ignore[no-untyped-def] + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + ) + return ReinterpretView(data=storage, layout=new_layout) + + return PermuteView(data=x, dims=dims) + + @classmethod + def _map_neg_dims(cls, dims): # type: ignore[no-untyped-def] + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self) -> Sequence[Expr]: + assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet( + range(len(self.dims)) + ) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer(self): # type: ignore[no-untyped-def] + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] + assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) + + def reindex(index): # type: ignore[no-untyped-def] + return [index[i] for i in inv] + + return reindex + + +@ir_dataclass +class SqueezeView(BaseView): + @classmethod + def create(cls, x, *, dim=None): # type: ignore[no-untyped-def] + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert isinstance(dim, int), "expected integer dim argument" + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + if size != 1: + new_size.append(size) + new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + ) + return ReinterpretView(data=storage, layout=new_layout) + + if dim is None: +
# redirect to a generic view + return View.create(x, [s for s in x.get_size() if s != 1]) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) + + @staticmethod + def squeezer(size: Sequence[sympy.Expr]): # type: ignore[no-untyped-def] + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index = [sympy.S.Zero] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data) -> None: # type: ignore[no-untyped-def] + raise AssertionError("use SqueezeView.create()") + + +@ir_dataclass +class GenericView(BaseView): + size: list[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): # type: ignore[no-untyped-def] + return self.reindex + + def reindex_str(self) -> str: + index_old = [ + sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size)) + ] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self) -> str: + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create(cls, x, new_size, reindex): # type: ignore[no-untyped-def] + return cls(data=x, size=list(new_size), reindex=reindex) + + def get_size(self) -> Sequence[Expr]: + return self.size + + +@ir_dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx, size): # type: ignore[no-untyped-def] + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x, new_size): # type: ignore[no-untyped-def, override] + assert isinstance(new_size, (tuple, list)) + old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index): # type: ignore[no-untyped-def] + return tuple([0] * len(old_size)) + + return cls(data=x, size=list(new_size), reindex=fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + # TODO: unbacked should not diverge from backed in determining striding + # Need to require contiguous here instead of realize, see: + # https://github.com/pytorch/pytorch/issues/145561 + x = ExternKernel.require_contiguous(x) + + storage, old_layout = as_storage_and_layout(x, want_contiguous=True) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + ) + return ReinterpretView(data=storage, layout=new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return 
cls(data=x, size=list(new_size), reindex=reindex) + + @staticmethod + def resolve_negative_size(old_size, new_size): # type: ignore[no-untyped-def] + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.S.One + new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) + break + + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer( + cls, + old_size: Sequence[_IntLike], + new_size: Sequence[_IntLike], + dense_dim: Optional[int] = None, # type: ignore[no-untyped-def] + ) -> Callable[[Sequence[_T]], Sequence[_V]]: + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim) + except (AssertionError, IndexError): + # optimistic algorithm failed, lets do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None): # type: ignore[no-untyped-def] + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + # TODO: These symbols may not escape, if they don't assert so and + # treat them as temporary + vars = [ + sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size)) + ] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + # process the dense dim first + reordering_dense_dim = ( + dense_dim is not None + and dense_dim != len(stack_old) - 1 + and len(new_size) == 1 + ) + if reordering_dense_dim: + assert dense_dim is not None # mypy + old_dim = stack_old.pop(dense_dim) + stack_old.append(old_dim) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.S.Zero) + stack_new.append((var, size_new)) # re-add + elif size_new == 1: + stack_old.append(size_old) # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.guard_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.S.One + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.guard_equals(size_new, size_old) + else: + raise AssertionError + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.guard_equals(size_old, 1) + view_expr.append(sympy.S.Zero) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.guard_equals(size_new, 1) + + if dense_dim is not None and len(new_size) == 1: + view_expr.reverse() + # Move the last expression (dense dim) to its original position + dense_expr = view_expr.pop() + 
view_expr.insert(dense_dim, dense_expr) + else: + view_expr.reverse() + + assert len(view_expr) == len(old_size) + + def reindex(index): # type: ignore[no-untyped-def] + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr) + + return reindex + + +@ir_dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: Layout + + def __post_init__(self) -> None: + super().__post_init__() + if isinstance(self.data, BaseView): + object.__setattr__(self, "data", self.data.unwrap_view()) + + def __str__(self) -> str: + return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self) -> str: + return self.data.get_name() + + def get_device(self) -> Optional[torch.device]: + return self.layout.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self.layout.dtype + + def get_size(self) -> Sequence[Expr]: + return list(self.layout.size) + + def get_stride(self): # type: ignore[no-untyped-def] + return list(self.layout.stride) + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + indexer = self.layout.make_indexer() + tmp_loader = ops.load(self.get_name(), indexer(index)) + if self.layout.dtype != self.data.dtype: + return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype) + else: + return tmp_loader + + return loader + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.layout.make_indexer() + + def get_layout(self) -> Layout: + return self.layout + + def freeze_layout(self): # type: ignore[no-untyped-def] + pass + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return ( + get_free_symbols(self.layout.size, unbacked_only) + | get_free_symbols(self.layout.stride, unbacked_only) + | get_free_symbols(self.layout.offset, unbacked_only) + ) + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer.writeline if writer is not None else V.graph.wrapper_code.writeline, + dtype=self.layout.dtype, + ) + + def num_reads(self) -> int: + return 1 + + +@ir_dataclass +class DtypeView(BaseView): + """Pretend our storage has a different type""" + + target_dtype: torch.dtype + + @classmethod + def create(cls, x, new_dtype): # type: ignore[no-untyped-def] + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + new_dtype, + old_layout.size, + old_layout.stride, + old_layout.offset, + ) + return ReinterpretView(data=storage, layout=new_layout) + return DtypeView(data=x, target_dtype=new_dtype) + + def __str__(self) -> str: + return self.str_helper([self.data, self.target_dtype]) + + __repr__ = __str__ + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self.target_dtype + + def get_size(self) -> Sequence[Expr]: + return self.data.get_size() + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + inner = self.data.make_loader() + + def 
loader(idx): # type: ignore[no-untyped-def] + return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) + + return loader + + +class SliceView(View): + @classmethod + def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def] + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + min_func = sympy.Min + max_func = sympy.Max + else: + min_func = sizevars.evaluate_min + max_func = sizevars.evaluate_max + + def clamp(x, lower, upper): # type: ignore[no-untyped-def] + clamped_lower = ( + x if sizevars.statically_known_geq(x, lower) else max_func(x, lower) + ) + clamped_full = ( + clamped_lower + if sizevars.statically_known_leq(clamped_lower, upper) + else min_func(clamped_lower, upper) + ) + return clamped_full + + def clamp_wrap(val, lower, upper, default): # type: ignore[no-untyped-def] + if val is None: + return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-untyped-def, override] + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + new_size = list(x.get_size()) + + # NB: Ordinarily we default to clamping. + # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid + # failing in this situation is ok, since invalid sizes could trigger silent errors. 
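+ # e.g. for dim_size == 5, a slice with start == -3 and end == 10**9 normalizes
+ # to start == 2 (after adding dim_size) and end == 5 (after clamping).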
+ if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + ) + return ReinterpretView(data=storage, layout=new_layout) + + def reindex(index): # type: ignore[no-untyped-def] + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(data=x, size=new_size, reindex=reindex) + + +@ir_dataclass +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self) -> Sequence[Expr]: + return () + + def get_device(self) -> Optional[torch.device]: + return self.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + +@ir_dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self) -> Optional[str]: + pass + + def constant_to_device(self, device: torch.device) -> IRNode: + return Constant(value=self.value, dtype=self.dtype, device=device) + + +@ir_dataclass +class IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device: torch.device) -> IRNode: + return IndexingConstant(index=self.index, dtype=self.dtype, device=device) + + +def is_contiguous_strides_for_shape( + stride: Sequence[_IntLike], shape: Sequence[_IntLike] +) -> bool: + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) + + +def get_align_for_dtype(dtype: torch.dtype) -> int: + return config.padding_alignment_bytes // dtype.itemsize + + +class OutputSpec: + """Abstract base for Layout, MultiOutputLayout, NoneLayout. 
+ Represents the memory layout of the output of an Operation.""" + + def get_device(self) -> Optional[torch.device]: + raise NotImplementedError(type(self).__name__) + + def storage_size(self) -> int: + raise NotImplementedError(type(self).__name__) + + +@ir_dataclass +class Layout(OutputSpec): + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: list[Expr], + stride: Optional[list[Expr]] = None, + offset: Expr = Integer(0), + ) -> None: + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + self.device = device + self.dtype = dtype + assert len(size) == len(stride), f"size={size}, stride={stride}" + assert all(isinstance(s, (Expr, int)) for s in size) + self.size: list[Expr] = size + self.stride: list[Expr] = stride + self.offset: Expr = offset + + def __str__(self) -> str: + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + + device_index_str = "" if self.device.index is None else f":{self.device.index}" + return ( + f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" + ) + + __repr__ = __str__ + + def get_device(self) -> torch.device: + return self.device + + def get_example(self) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(self.size), + convert_shape_to_symint(self.stride), + dtype=self.dtype, + device=self.device, + ) + + def is_contiguous(self) -> bool: + return is_contiguous_strides_for_shape(self.stride, self.size) + + @staticmethod + def is_channels_last_contiguous( + shape: Sequence[_IntLike], strides: Sequence[_IntLike] + ) -> bool: + ndim = len(shape) + if ndim not in [4, 5] or shape[1] == 1: + return False + for left, right, size in zip( + strides, make_channels_last_strides_for(shape), shape + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self) -> bool: + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order) -> bool: # type: ignore[no-untyped-def] + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they dont affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): # type: ignore[no-untyped-def] + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = stride[i] + # check if it is in ascending order + for i in range(len(order) - 1): + expr = stride_ordered[i] > stride_ordered[i + 1] + if not isinstance(expr, bool): + expr = V.graph._shape_env.evaluate_expr( + stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True + ) + if expr: + return False + return True + + def is_channels_last_stride_ordered(self): # type: ignore[no-untyped-def] + # create channels_last order(NCHW, NCDHW, the C is the first order). 
+ order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + @staticmethod + def _pad_strides(in_strides, size, dtype): # type: ignore[no-untyped-def] + """ + The padding does not change stride order but makes sure all strides larger + than the threshold are multiple of align. + """ + align = get_align_for_dtype(dtype) + if len(in_strides) == 0: + return in_strides + + if not config.pad_channels_last and Layout.is_channels_last_contiguous( + size, in_strides + ): + return in_strides + + current_fx_node = V.get_current_node() + if hasattr(current_fx_node, "meta") and current_fx_node.meta.get( + "dislike_padding", False + ): + return in_strides + + # get_stride_order does not work with dynamic shape. Also we can not + # statically decide if a padding is needed or how much padding we should + # do for dynamic shape. + # + # Skip padding the strides for dynamic shape for now. + if not all( + isinstance(s, (int, sympy.Integer)) + for s in itertools.chain(in_strides, size) + ): + return in_strides + + stride_order = get_stride_order(in_strides) + fill_order = stride_order2fill_order(stride_order) + + new_strides = [0 for _ in range(len(in_strides))] + # since we pad when the layout is flexible, we can decide the + # smallest stride to be 1. + new_strides[fill_order[0]] = 1 + + padded = False + for rank, idx in enumerate(fill_order[1:], start=1): + prev_idx = fill_order[rank - 1] + stride = new_strides[prev_idx] * size[prev_idx] + + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride + + if not padded: + # Consider a tensor with shape [256, 1, 5, 5] + # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides + # [25, 25, 5, 1]. 
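+ # In that case only the stride of the size-1 dimension would change, which
+ # does not alter the actual memory layout, so keep the original strides when
+ # no real padding happened.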
+ return in_strides + + metrics.num_comprehensive_padding += 1 + return new_strides + + def pad_strides(self): # type: ignore[no-untyped-def] + assert isinstance(self, FlexibleLayout) + assert self.stride is not None + self.stride = self._pad_strides(self.stride, self.size, self.dtype) + + def should_pad_strides(self): # type: ignore[no-untyped-def] + return config.comprehensive_padding and isinstance(self, FlexibleLayout) + + def as_fixed(self): # type: ignore[no-untyped-def] + if isinstance(self, FixedLayout): + return self + + if self.should_pad_strides(): + self.pad_strides() + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + ) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + assert FlexibleLayout.allow_indexing, ( + f"convert {type(self).__name__} to FixedLayout first" + ) + return self.as_fixed().make_indexer() + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return ( + self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + """A closure containing math to read a given element""" + + def indexer(index): # type: ignore[no-untyped-def] + assert len(index) == len(self.stride) + assert len(index) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" + + allow_indexing = False + + # WARNING! This doesn't handle zero size tensors correctly + @staticmethod + def contiguous_strides(sizes): # type: ignore[no-untyped-def] + if len(sizes) == 0: + return [] + reversed_strides = [sympy.S.One] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes, order): # type: ignore[no-untyped-def] + """ + Create a stride based on the order the dimensions should be filled in. + + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order) + next_stride = sympy.S.One + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes, order): # type: ignore[no-untyped-def] + """ + Create a stride based on the sorted order of a permuted range. + + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def stride_ordered_for_memory_format(sizes, memory_format): # type: ignore[no-untyped-def] + """ + Create a stride based on a memory format. 
+ + Memory format is translated into a stride order, + so channels_last is the same as: + FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) + + This interface does not support memory_format `torch.preserve_format`, + which should be used to deduce a format from another source. + """ + if memory_format == torch.channels_last: + return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER) + elif memory_format == torch.channels_last_3d: + return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER) + elif memory_format == torch.contiguous_format: + return FlexibleLayout.contiguous_strides(sizes) + else: + log.debug( + "stride_ordered_for_memory_format, unsupported memory_format: %s", + memory_format, + ) + raise NotImplementedError + + @staticmethod + def same_ordered(sizes, stride): # type: ignore[no-untyped-def] + """ + Create a stride that has the same stride order as the given stride + + For example, if the given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + def as_stride_order(self, order, allow_padding=False): # type: ignore[no-untyped-def] + new_stride = self.stride_ordered(self.size, order) + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_exact_strides(self, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] + new_stride = exact_strides + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_fill_order(self, order): # type: ignore[no-untyped-def] + new_stride = self.fill_ordered(self.size, order) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def as_same_order(self, stride): # type: ignore[no-untyped-def] + new_stride = self.same_ordered(self.size, stride) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + ) + + def __init__(self, device, dtype, size, stride_order=None) -> None: # type: ignore[no-untyped-def] + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides) + + +class NonOwningLayout(Layout): + """Is a view into the storage of another tensor""" + + def __init__(self, view: Union[BaseView, TensorBox]) -> None: + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self): # type: ignore[no-untyped-def] + offset = self.view.get_layout().offset + if offset == 0: + return True + from .utils import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) + + +class CommBufferType(Enum): + SYMM_MEM = "symm_mem" +
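+
+# Illustrative sketch (hypothetical helper, not part of the upstream source):
+# a worked example of how the FlexibleLayout stride helpers defined above
+# relate for an NCHW tensor of shape [2, 3, 4, 5].
+def _example_flexible_layout_strides() -> None:
+    sizes = [2, 3, 4, 5]
+    # Row-major (contiguous) strides: the innermost dimension varies fastest.
+    assert FlexibleLayout.contiguous_strides(sizes) == [60, 20, 5, 1]
+    # Fill order [1, 3, 2, 0] fills C first (stride 1), then W, H, N, which is
+    # exactly the channels-last layout for these sizes.
+    assert FlexibleLayout.fill_ordered(sizes, [1, 3, 2, 0]) == [60, 1, 15, 3]
+    # The same layout expressed as a stride order (rank of each dim's stride).
+    assert FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) == [60, 1, 15, 3]
+    # And the same again, derived from a torch memory format.
+    assert FlexibleLayout.stride_ordered_for_memory_format(
+        sizes, torch.channels_last
+    ) == [60, 1, 15, 3]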
+ +class CommBufferLayout(FixedLayout): + """ + A layout that signifies the buffer is a comm buffer. + In terms of striding, the layout is identical to `FixedLayout`. + + Buffers with this layout do not participate in in-place reuse - it can be + neither the source nor the target for in-place reuse. + + For detailed motivation and usage of this layout, see + NOTE [lowering-time collective optimization]. + """ + + comm_buffer_type: CommBufferType + group_name: str + + def __init__( + self, + layout: FlexibleLayout, + comm_buffer_type: CommBufferType, + group_name: str, + ): + if not isinstance(layout, FlexibleLayout): + raise AssertionError( + "A `CommBufferLayout` can only be initialized with " + f"a `FlexibleLayout` (got {layout})." + ) + + fixed = layout.as_fixed() + super().__init__( + device=fixed.device, + dtype=fixed.dtype, + size=fixed.size, + stride=fixed.stride, + offset=fixed.offset, + ) + self.comm_buffer_type = comm_buffer_type + self.group_name = group_name + + +@ir_dataclass +class NoneLayout(OutputSpec): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. + # + # If you have an ir.Node with NoneLayout, you probably need to setup + # dependencies manually in scheduler + + device: Optional[torch.device] + size: list[int] = dataclasses.field(default_factory=lambda: [0]) + stride: list[int] = dataclasses.field(default_factory=lambda: [0]) + + def storage_size(self) -> int: + return 0 + + def as_fixed(self): # type: ignore[no-untyped-def] + return self + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class MutationLayoutSHOULDREMOVE(Layout): + def __init__(self, target: IRNode) -> None: + super().__init__( + target.get_device_or_error(), + target.get_dtype(), + target.get_size(), # type: ignore[arg-type] + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @property + def stride(self) -> list[Expr]: + return self.real_layout().stride + + @stride.setter # type: ignore[override] + def stride(self, value: Never) -> None: + pass # ignore setting of stride + + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> Buffer: + def unwrap_views(target): # type: ignore[no-untyped-def] + if isinstance(target, MutationLayoutSHOULDREMOVE): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance(result, Buffer), ( + "MutationLayoutSHOULDREMOVE must refer to a buffer" + ) + return result + + def real_layout(self): # type: ignore[no-untyped-def] + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): # type: ignore[no-untyped-def] + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! + V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. 
+ # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + src = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayoutSHOULDREMOVE(dst) + return src.data + + def as_fixed(self): # type: ignore[no-untyped-def] + return self + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.target.make_indexer() + + +@ir_dataclass(frozen=False) +class Buffer(IRNode, CodegenSymbol): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: OutputSpec + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this! + + def __post_init__(self) -> None: + super().__post_init__() + self._post_init_setattr("origin_node", None) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.get_layout().make_indexer() + + def get_name(self) -> str: + assert self.name, self + return self.name + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + if isinstance(self.layout, Layout): + return self.layout.get_example() + raise NotImplementedError(type(self.layout).__name__) + + def get_device(self) -> Optional[torch.device]: + return self.get_output_spec().get_device() + + def get_defining_op(self) -> Optional[Operation]: + return None + + @property + def dtype(self) -> torch.dtype: + return self.get_layout().dtype + + def get_size(self) -> Sequence[Expr]: + return [*self.get_layout().size] + + def get_stride(self) -> list[Expr]: + return [*self.get_layout().stride] + + def get_offset(self) -> Expr: + return self.get_layout().offset + + def get_layout(self) -> Layout: + if isinstance(self.layout, Layout): + return self.layout + raise NotImplementedError(type(self.layout).__name__) + + def get_output_spec(self) -> OutputSpec: + return self.layout + + def get_storage_numel(self): # type: ignore[no-untyped-def] + return self.get_numel() + + def freeze_layout(self): # type: ignore[no-untyped-def] + if isinstance(self.layout, Layout) and not isinstance( + self.layout, NonOwningLayout + ): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order(self, order, allow_padding=False) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) + + def freeze_layout_with_fill_order(self, order) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def] + self, exact_strides, allow_padding=False + ) -> None: + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_exact_strides( + exact_strides, allow_padding=allow_padding + ) + + def is_zero_elements(self): # type: ignore[no-untyped-def] 
+ return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index): # type: ignore[no-untyped-def] + indexer = self.make_indexer() + return ops.load(self.name or "unnamed", indexer(index)) + + return loader + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.get_name() + + def decide_layout(self): # type: ignore[no-untyped-def] + pass + + def get_inputs_that_alias_output(self) -> Sequence[str]: + if isinstance(self.layout, NonOwningLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self) -> Sequence[str]: + if isinstance(self.layout, MutationLayoutSHOULDREMOVE): + return [self.layout.target.get_name()] + return () + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet([self.get_name()]) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def realize(self) -> Optional[str]: + pass + + def should_allocate(self) -> bool: + # Returns False by default. + return False + + +@ir_dataclass(frozen=False) +class OperationBuffer(Buffer, Operation): + # An operation that produces a single output buffer + def get_outputs(self) -> list[Buffer]: + return [self] + + def get_defining_op(self) -> Operation: + return self + + # Skip implementation in Buffer + get_operation_name = Operation.get_operation_name + + def __post_init__(self) -> None: + Buffer.__post_init__(self) + Operation.__post_init__(self) + + +class InputBuffer(Buffer): + def num_reads(self) -> int: + return 1 + + +class DonatedBuffer(InputBuffer): + """ + Represents a donated buffer which is a saved tensor that is not alias to any + fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace + reuse the input tensor memory during backward since it might be used in another + function. However, donated buffer can be inplace reused during backward + to save memory. 
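+ For example, an intermediate activation that is saved only for the backward pass can have its storage reused in place by the backward graph once it is no longer needed.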
+ """ + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + indexer = self.get_layout().make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device: torch.device) -> IRNode: + return ConstantBuffer( + name=V.graph.constant_name(self.get_name(), device), layout=self.layout + ) + + +@ir_dataclass +class NoneAsConstantBuffer(IRNode): + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return V.graph.wrapper_code.none_str + + def get_output_spec(self) -> OutputSpec: + return NoneLayout(device=None) + + def has_tensor_output(self) -> bool: + return False + + +@ir_dataclass +class ShapeAsConstantBuffer(IRNode): + expr: Expr + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.expr, unbacked_only) + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return V.graph.wrapper_code.codegen_sizevar(self.expr) + + def has_tensor_output(self) -> bool: + return False + + +@ir_dataclass(frozen=False) +class ComputedBuffer(OperationBuffer): + data: Loops + + def get_computed_buffer_name(self) -> Optional[str]: + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exist, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + def num_reads(self) -> int: + return self.data.num_reads() + + def get_reads(self) -> OrderedSet[Dep]: + return self.data.get_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_read_writes(self) -> dependencies.ReadWrites: + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), # type: ignore[arg-type] + self.data.get_reduction_size(), # type: ignore[arg-type] + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), # type: ignore[arg-type] + ) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # there way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first. 
+ return ( + get_free_symbols(self.get_size(), unbacked_only) + | get_free_symbols(self.get_stride(), unbacked_only) + | get_free_symbols(self.get_offset(), unbacked_only) + | self.data.get_free_symbol_uses(unbacked_only) + ) + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + if ( + not self.get_reduction_type() + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + ): + # inline this op rather than generating ops.load() + return self.data.make_loader() + return super().make_loader() + + def get_store_function(self) -> Callable[..., None]: + indexer = self.get_layout().as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan, Sort)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self) -> Optional[list[int]]: + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. + """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + # only consider reads to buffer of same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0}) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, (Scan, Sort)): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self) -> None: + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + @cache_on_self + def get_default_sizes_body( + self, + ) -> tuple[ + tuple[list[sympy.Expr], list[sympy.Expr]], + LoopBody, + tuple[list[sympy.Expr], list[sympy.Expr]], + ]: + args, var_ranges = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + *args, + ) + index_vars = [] + reduce_vars: list[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]: + """ + This is a main place where we do loop 
transformations in a + backend-agnostic way. + + Here we: + 1) Remove any 1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. + """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + memory_addrs = [*body.get_write_exprs()] + if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): + memory_addrs.extend(body.get_read_exprs()) + + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): # type: ignore[no-untyped-def] + sizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs + ) + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + + if simplify_loops: + sizes, reindex2, _prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 + return sizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + should_merge_loops = ( + not is_gpu(get_device_type(self)) or not config.loop_ordering_after_fusion + ) + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, + should_merge_loops, + ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. 
+ reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size, should_merge_loops + ) + + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, + reduce_ranges, + prefix="p", + ) + body = LoopBody( + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( # type: ignore[no-untyped-def] + index_vars, + support_vars, + sizes, + memory_addrs, + priority_idx=None, + ): + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return self.data.get_reduction_size() + + def get_reduction_type(self) -> Optional[str]: + return self.data.get_reduction_type() + + def is_no_op(self) -> bool: + return self.data.is_zero_elements() + + def should_allocate(self) -> bool: + return True + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(OperationBuffer): + """ + Represents a Triton (in the future other type) of template operator + that we can fuse an epilogue onto. 
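+ For example, a matmul emitted from a Triton template can have a subsequent pointwise op (such as a relu) fused into the same kernel as its epilogue.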
+ """ + + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[..., Any], + ) -> None: + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def get_read_writes(self) -> dependencies.ReadWrites: + return self.extract_read_writes(normalize=True) + + def extract_read_writes(self, normalize): # type: ignore[no-untyped-def] + name = self.get_name() + indexer = self.get_layout().make_indexer() + + def dummy(index, rindex): # type: ignore[no-untyped-def] + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=normalize + ) + + for inp in self.inputs: + indexer = inp.layout.make_indexer() + + def dummy(index, rindex): # type: ignore[no-untyped-def] + assert len(rindex) == 0 + ops.load(inp.get_name(), indexer(index)) + + deps.reads |= dependencies.extract_read_writes( + dummy, inp.get_size(), (), normalize=True + ).reads + + return deps + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return sympy.S.One + + def get_reduction_type(self) -> Optional[str]: + return None + + def should_allocate(self) -> bool: + return True + + def simplify_and_reorder( # type: ignore[no-untyped-def] + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + return ( + ( + self.get_size(), + (), + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, + mutated_inputs: Optional[Iterable[IRNode]] = None, + allowed_prologue_inps: Optional[OrderedSet[str]] = None, + ) -> None: + """ + NOTE:[TritonTemplates with multiple outputs] + We want the ability for TritonTemplates to output multiple tensors. Triton + kernels have no notion of outputs and this is done by creating tensors that + are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't + support creating multinode outputs for triton templates. + We work around this by creating an extra input buffer during the lowering + and we mark them as mutated inputs. 
+ """ + super().__init__(layout, inputs, make_kernel_render) + self.mutated_inputs = mutated_inputs + self.outputs: list[Buffer] = [self] + if mutated_inputs is not None: + # Ensure that the mutated inputs are only allowed for certain nodes + allowed_set = ( + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + ) + current_node = V.graph.current_node.target + assert current_node in allowed_set, ( + f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + ) + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device=device), buf, self) + for buf in mutated_inputs + ] + + self.allowed_prologue_inps = ( + allowed_prologue_inps if allowed_prologue_inps else OrderedSet() + ) + + self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None + self.subgraph_outs: Optional[list[Optional[IRNode]]] = None + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + res = super().get_free_symbol_uses(unbacked_only) + subgraph_outs = self.subgraph_outs if self.subgraph_outs else [] + subgraph_inps = self.subgraph_inps if self.subgraph_inps else [] + + for inp in subgraph_inps: + if isinstance(inp, sympy.Expr): + res.update(get_free_symbols(inp, unbacked_only)) + elif isinstance(inp, IRNode): + res.update(inp.get_free_symbol_uses(unbacked_only)) + else: + assert inp is None + + for out in subgraph_outs: + if isinstance(out, IRNode): + res.update(out.get_free_symbol_uses(unbacked_only)) + else: + assert out is None + + return res + + def get_outputs(self) -> list[Buffer]: + return self.outputs + + def get_allowed_prologue_inps(self) -> OrderedSet[str]: + return self.allowed_prologue_inps + + def __str__(self) -> str: + out = f"TritonTemplateBuffer(layout={self.layout})" + return out + + +PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Children classes: TritonTemplateCaller, CUDATemplateCaller. + """ + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + description: str, + ) -> None: + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + # An additional description used to describe the choice (useful for + # knowing what autotuning is choosing) + self.description = description + + def benchmark(self, *args, out) -> float: # type: ignore[no-untyped-def] + algo = self.to_callable() + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: algo(*args)) + return benchmarker.benchmark(algo, args, {"out": out}) + + def call_name(self) -> str: + raise NotImplementedError + + def to_callable(self): # type: ignore[no-untyped-def] + raise NotImplementedError + + def kernel_hash_key(self) -> str: + """ + Hash key for the underlying kernel. By default, we assume there are no + runtime params, so kernel hash key defaults to choice caller's hash key. 
+ """ + return self.hash_key() + + def hash_key(self) -> str: + raise NotImplementedError + + def output_node(self) -> TensorBox: + raise NotImplementedError + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + def autoheuristic_id(self) -> str: + return "unsupported_choice" + + +class TritonTemplateCallerBase(ChoiceCaller): + def get_make_kernel_render(self) -> Any: + raise NotImplementedError + + +class MultiTemplateBuffer(TritonTemplateBuffer): + """ + Represents a Buffer with multiple backing implementation choices. + + Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential + epilogue we will benchmark each of the choices with the epilogue to determine an implementation. + Otherwise, the fastest base choice will be chosen. + """ + + def __init__( + self, + layout: Layout, + inputs: list[IRNode], + choice_timings_fn: Callable[[], dict[ChoiceCaller, float]], + unfiltered_choices: list[ChoiceCaller], + allowed_prologue_inps: OrderedSet[str], + ) -> None: + super().__init__( + layout=layout, + inputs=inputs, + make_kernel_render=None, + allowed_prologue_inps=allowed_prologue_inps, + ) + self._choice_timings_fn = choice_timings_fn + self._choice_timings: Optional[dict[ChoiceCaller, float]] = None + self.original_inputs = inputs + self._output_plannable = all( + isinstance(choice, TritonTemplateCallerBase) + or ( + isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller) + and choice.has_out_variant + ) + for choice in unfiltered_choices + ) + + @property + def output_plannable(self) -> bool: + """ + Are all possible choices TritonTemplates or Extern Kernels with out variants + """ + return self._output_plannable + + @property + def choice_timings(self) -> dict[ChoiceCaller, float]: + if self._choice_timings is None: + self._choice_timings = self._choice_timings_fn() + return self._choice_timings + + @contextlib.contextmanager + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): # type: ignore[no-untyped-def] + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.layout == caller.layout + + render = self.make_kernel_render + self.make_kernel_render = caller.get_make_kernel_render() + try: + yield + finally: + self.make_kernel_render = render + + def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None: + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) + assert self.get_size() == caller.layout.size + assert self.get_stride() == caller.layout.stride + self.make_kernel_render = caller.get_make_kernel_render() + + def get_min_choice(self) -> tuple[ChoiceCaller, float]: + min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] + return (min_choice, self.choice_timings[min_choice]) + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: CUDATemplate, + supports_epilogue_fusion: bool, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. 
+ self.workspace_size = workspace_size + self.template = template + self.supports_epilogue_fusion = supports_epilogue_fusion + + def get_workspace_size(self): # type: ignore[no-untyped-def] + return self.workspace_size if self.workspace_size is not None else 0 + + def emulate_store_fn(self) -> None: + for output in self.get_outputs(): + ops.store(output.get_name(), None, None) + + +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def] + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + self.outputs: Optional[list[Buffer]] = None + + def get_layout(self) -> Layout: + if isinstance(self.layout, MultiOutputLayout): + assert isinstance(self.outputs, Iterable) + first_output = self.outputs[0] + assert isinstance(first_output, Buffer) + layout = first_output.layout + assert isinstance(layout, Layout) + return layout + else: + return super().get_layout() + + +@ir_dataclass(frozen=False) +class InputsKernel(OperationBuffer): + inputs: list[Buffer] + + def get_read_writes(self) -> dependencies.ReadWrites: + reads = OrderedSet[dependencies.Dep]() + StarDep = dependencies.StarDep + for input in self.inputs: + if isinstance(input, list): + reads.update(StarDep(x.get_name()) for x in input) + elif isinstance(input, ShapeAsConstantBuffer): + # Skip creating dependency for symbolics as they're visible globally + continue + else: + reads.add(StarDep(input.get_name())) + + writes = OrderedSet[dependencies.Dep]( + StarDep(buf.get_name()) for buf in self.get_outputs() + ) + + return dependencies.ReadWrites( + reads=reads, + writes=writes, + index_exprs=OrderedSet(), + ) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + @classmethod + def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + if isinstance(x, TorchBindObject): + return x + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): # type: ignore[no-untyped-def] + inputs_new = [] + for x in inputs: + if isinstance(x, list): + x = [InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self) -> bool: + return True + + def num_reads(self) -> int: + return 1 + + +class NopKernel(InputsKernel): + def is_no_op(self) -> bool: + return True + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. 
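+
+ For example (hypothetical shapes): concatenating two [2, 3] inputs along dim=0
+ yields one [4, 3] output buffer, and create() realizes each input directly into
+ the corresponding slice of that buffer via realize_into(); the concat node
+ itself is a NopKernel that only rearranges storage.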
+ """ + + @classmethod + def create(cls, inputs, dim): # type: ignore[no-untyped-def] + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.guard_equals( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride = FlexibleLayout.contiguous_strides(new_size) + if config.comprehensive_padding: + # Ensure the output stride matches the alignment requirements + output_stride = Layout._pad_strides( + output_stride, new_size, inputs[0].dtype + ) + + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if isinstance( + layout, FixedLayout + ) and Layout.is_channels_last_contiguous(layout.size, layout.stride): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + assert isinstance(fx_node_args, list) + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + ): + output_stride = make_channels_last_strides_for(new_size) + + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + op_names = [] + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], + SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and is_gpu(inputs[i].get_device().type) + and not is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + V.graph.register_operation(concat_kernel) + + return kernel + + @classmethod + def can_realize_into_without_copy(cls, src, dst=None): # type: ignore[no-untyped-def] + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data, dst) + + if isinstance(src.data, MultiTemplateBuffer): + if ( + not isinstance(src.data.layout, FixedLayout) + or not src.data.output_plannable + ): + return False + + # we call can_realize_into_without_copy in cat lowering before we've decided + # on output format, optimistically assume layout matches + if dst is 
None: + return True + + # otherwise, check equality of layouts + if not len(src.get_stride()) == len(dst.get_stride()): + return False + + return all( + V.graph.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(src.get_stride(), dst.get_stride()) + ) + + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def realize_into(cls, src, dst): # type: ignore[no-untyped-def] + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. + if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(data=storage, layout=layout) + assert isinstance(dst, ReinterpretView), dst + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src, dst): + src.data.layout = NonOwningLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self) -> bool: + return True + + +@ir_dataclass(frozen=False) +class ExternKernel(InputsKernel): + constant_args: tuple[Any, ...] = () + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. 
+ ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[list[dict[str, Any]]] = None + kwarg_properties: Optional[dict[str, dict[str, Any]]] = None + unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( + default_factory=dict + ) + mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list) + + def __init__( # type: ignore[no-untyped-def] + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ) -> None: + super().__init__( + name=name, + layout=layout, + inputs=inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.op_overload = op_overload + self.set_cpp_kernel_name(cpp_kernel_name) + self.set_python_kernel_name(python_kernel_name) + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.collect_arg_kwarg_properties() + self.unbacked_bindings = {} + self.mutation_outputs = [] + self.fx_node = V.graph.current_node + + def get_outputs(self) -> list[Buffer]: + return [self, *self.mutation_outputs] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def collect_arg_kwarg_properties(self): # type: ignore[no-untyped-def] + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.allarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicitly passed in. 
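+ # When a real OpOverload schema is available, derive ordered_kwargs_for_cpp_kernel
+ # (if not explicitly provided) from the schema's kwarg-only arguments; otherwise
+ # schema_kwargs stays empty.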
+ if isinstance(self.op_overload, torch._ops.OpOverload): + if not self.ordered_kwargs_for_cpp_kernel: + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + self.schema_kwargs = [ + x for x in self.op_overload._schema.arguments if x.kwarg_only + ] + else: + self.schema_kwargs = [] + + def decide_layout(self): # type: ignore[no-untyped-def] + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment(self, wrapper) -> None: # type: ignore[no-untyped-def] + origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.make_comment(origin_str) + + def codegen(self, wrapper): # type: ignore[no-untyped-def] + raise NotImplementedError + + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: + self.cpp_kernel_name = cpp_kernel_name + if not V.graph.cpp_wrapper or not isinstance( + self.op_overload, torch._ops.OpOverload + ): + return + + kernel = self.op_overload + if self.cpp_kernel_name is None: + # Try to construct cpp_kernel_name from op_overload + if kernel.namespace == "aten": + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, std::optional output_size=std::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # std::optional dim=std::nullopt, std::optional output_size=std::nullopt) + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + self.cpp_kernel_name = f"at::_ops::{opname}::call" + else: + self.cpp_kernel_name = kernel._schema.name + + def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None: + self.python_kernel_name = python_kernel_name + if python_kernel_name is not None: + return + + kernel = self.op_overload + if kernel is None: + pass + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + self.python_kernel_name = ( + f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" + ) + + def get_kernel_name(self): # type: ignore[no-untyped-def] + device = d.type if (d := self.get_device()) else V.graph.device_type + return ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name, device) # type: ignore[attr-defined] + if V.graph.cpp_wrapper + else self.python_kernel_name + ) + + @staticmethod + def copy_input(x): # type: ignore[no-untyped-def] + pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel( # type: ignore[no-untyped-def] + cls, kernel, *args, **kwargs + ) -> tuple[ + Any, + list[Any], + list[Any], + Callable[[Any, Any], Any], + Optional[dict[sympy.Symbol, pytree.KeyPath]], + ]: + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + # tensor_args can be either tensor or torchbind objects + tensor_args = [] + non_tensor_args: list[Any] = [] + for arg in args_flat: + is_arg_tensor.append( + isinstance(arg, IRNode) and not isinstance(arg, GeneratorState) + ) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, 
hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): # type: ignore[no-untyped-def] + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # Rerun fake tensor propagation, because Inductor may have changed the + # strides of inputs and we need to determine accurately what the + # output stride will be. + example_args: list[ + Union[ + torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator + ] + ] = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + # if x is a view of a constant, we need to realize the view + # (we can't pass the constant into the kernel directly) + if not isinstance(x, BaseView) and x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + elif ( + not isinstance(x, BaseView) + and x.get_name() in V.graph.torchbind_constants + ): + example_args.append(V.graph.torchbind_constants[x.get_name()]) + elif isinstance(x, TorchBindObject): + example_args.append(x.get_value()) + elif isinstance(x, torch._inductor.ir.GeneratorState): + device_index = x.device.index + assert x.device.type == "cuda" and device_index is not None + example_args.append( + torch.cuda.default_generators[device_index].clone_state() + ) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None + if shape_env := V.fake_mode.shape_env: + node_meta_val = V.current_node.meta.get("val") + ctx = nullcontext() + if V.current_node.target == torch._higher_order_ops.effects.with_effects: + # remove the first effect token in meta["val"] and meta["unbacked_bindings"] + node_meta_val = node_meta_val[1] + ctx = _remove_effect_token_unbacked_bindings(V.current_node) # type: ignore[assignment] + + with ctx: + rebind_unbacked(shape_env, V.current_node, example_output) + unbacked_bindings = compute_unbacked_bindings( + shape_env, example_output, node_meta_val + ) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, tuple)) + else example_output + ) + for t in example_out_li: + if isinstance(t, torch.Tensor) and t.is_sparse: + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + return ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) + + @classmethod + def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. 
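+
+ As a hypothetical example, a transpose view of a contiguous [4, 8] buffer can be
+ re-expressed as a ReinterpretView of size [8, 4] with strides [1, 8] over the
+ same storage. If the view's index formula cannot be written as strides plus an
+ offset, NotImplementedError is raised and callers such as realize_input() fall
+ back to copying the input.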
+ """ + assert isinstance(x, BaseView) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. + if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view.layout, FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), + prefix="r", # type: ignore[arg-type] + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device_or_error(), + dtype=x.get_dtype(), + size=x.get_size(), # type: ignore[arg-type] + stride=strides, + offset=offset, + ), + ) + + @classmethod + def realize_input(cls, x): # type: ignore[no-untyped-def] + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(expr=x) + if isinstance(x, Constant): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return ReinterpretView( + data=cls.realize_input(x.data), layout=x.get_layout() + ) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)): + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x): # type: ignore[no-untyped-def] + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return cls.copy_input(x) + + @classmethod + def require_strides( # type: ignore[no-untyped-def] + cls, + x, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding=False, + ): + assert order is not None or exact_strides is not None + # Layout generally doesn't matter, but some consuming external ops might have requirements + if x.get_numel() in (0, 1) and not exact_strides: + return x + + # require x to have the layout + if is_storage_and_layout(x): + if isinstance(x.get_layout(), FlexibleLayout): + if order: + # If the the FlexibleLayout 
already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. + + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=( + get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) + else order + ), + allow_padding=allow_padding, + ) + return x + else: + # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=None, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x + elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return ( + try_match_insignificant_strides(x, exact_strides) + if exact_strides is not None + else x + ) + elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" + ) + elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( + (order and x.get_layout().real_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, + x.get_layout().real_layout().stride, + x.get_size(), + ) + ) + ): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) # type: ignore[attr-defined] + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + if order: + return cls.require_stride_order( + x, order, allow_padding=allow_padding + ) + elif exact_strides: + return cls.require_exact_strides( + x, exact_strides, allow_padding=allow_padding + ) + except NotImplementedError: + pass + + # Preserve ExpandView representation that would be lost during copy_input + # Without representation of the expand in inductor IR, in codegen we end up + # launching a grid for the full size tensor and doing redundant computation + # across expanded dims. 
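+ # Hypothetical example: if exact_strides requires stride 0 along dim 0 of an
+ # [8, 128] input, that dim is sliced down to 1, the [1, 128] tensor is copied,
+ # and the result is re-expanded to [8, 128] below, so the copy materializes 128
+ # elements rather than 1024.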
+ # TODO: could also be good to have a codegen fix to recognize overlapping elements + + expanded_dims: Optional[list[int]] = None + orig_size = x.get_size() + if exact_strides is not None: + sizevars = V.graph.sizevars + expanded_dims = [ + i + for i in range(len(x.get_size())) + if sizevars.statically_known_equals(exact_strides[i], 0) + and sizevars.statically_known_geq(x.get_size()[i], 2) + ] + + for dim in expanded_dims: + x = torch._inductor.lowering.slice_(x, dim, 0, 1) + + # Although this is a clone, inductor is good about fusing clones into previous + # operations if they weren't realized and their layouts were flexible. + x = cls.copy_input(x) + + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if order: + assert is_stride_order_storage_and_layout(x, order) + elif expanded_dims: + assert orig_size is not None and exact_strides is not None + x = torch._inductor.lowering.expand(x, orig_size) + # the expand will sometimes may change insignificant strides, so match them back + return try_match_insignificant_strides(x, exact_strides) + + return x + + @classmethod + def require_exact_strides(cls, x, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] + return cls.require_strides( + x, exact_strides=exact_strides, allow_padding=allow_padding + ) + + @classmethod + def require_stride_order(cls, x, order, allow_padding=False): # type: ignore[no-untyped-def] + return cls.require_strides(x, order=order, allow_padding=allow_padding) + + @classmethod + def require_channels_last(cls, x): # type: ignore[no-untyped-def] + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_channels_last_3d(cls, x): # type: ignore[no-untyped-def] + return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): # type: ignore[no-untyped-def] + def is_mkldnn_tensor(x): # type: ignore[no-untyped-def] + def safe_get_name(x): # type: ignore[no-untyped-def] + try: + return x.get_name() + except (AttributeError, NotImplementedError): + return None + + return ( + safe_get_name(x) in V.graph.constants + and V.graph.constants[safe_get_name(x)].is_mkldnn + ) + + # TODO move this to the more proper places + if is_mkldnn_tensor(x): + return x + else: + return cls.require_exact_strides( + x, FlexibleLayout.contiguous_strides(x.get_size()) + ) + + @classmethod + def require_contiguous_strides(cls, x): # type: ignore[no-untyped-def] + # TODO: combine this with require_contiguous after + # https://github.com/pytorch/pytorch/pull/148235 lands. + return cls.require_exact_strides( + x, FlexibleLayout.contiguous_strides(x.get_size()) + ) + + def apply_constraint(self) -> None: + pass + + def fill_non_provided_args(self, args, kwargs): # type: ignore[no-untyped-def] + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being OrderedSet. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. 
More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + + def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def] + if V.graph.cpp_wrapper: + result = [] + # Aten ops follow the convention that tensor args are before non-tensor args, + # in which case the following 'len(self.inputs) + i' logic works. But this + # may not be true for other ops, and if that is the case, caller needs to + # pass in a list of const arg names for arg_properties lookup. + name_to_arg_properties = None + if names and self.arg_properties: + assert len(self.constant_args) == len(names), ( + "names passed to codegen_const_args does not match self.constant_args" + ) + name_to_arg_properties = { + arg.get("name"): arg for arg in self.arg_properties + } + + for i, x in enumerate(self.constant_args): + if name_to_arg_properties is not None: + prop = name_to_arg_properties.get(names[i]) # type: ignore[index] + type_ = prop.get("type") if prop else None + else: + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[idx].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) + return result + else: + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): # type: ignore[no-untyped-def] + if V.graph.cpp_wrapper and self.op_overload is not None: + # cpp wrapper needs special logic to fill in missing args with default values + inputs = self.fill_non_provided_args( + [*self.inputs, *self.constant_args], self.kwargs + ) + # fill_non_provided_args has handled constant args, so no need to codegen for that later + need_codegen_constant_args = False + else: + inputs = self.inputs + need_codegen_constant_args = True + + args = [] + for i, x in enumerate(inputs): + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len(self.arg_properties), ( + "Invalid access to ExternKernel.arg_properties" + ) + type_ = self.arg_properties[i].get("type") + args.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) + else: + args.append(V.graph.wrapper_code.val_to_arg_str(x)) + if need_codegen_constant_args: + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name, **kwargs): # type: ignore[no-untyped-def] + """Given an argument name, queries for values in (in order): + 1. any provided kwargs for this function. + 2. the class self.kwargs member. + 3. 
any available default arguments in self.allarg_properties.""" + if arg_name in kwargs: + return kwargs.get(arg_name) + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if self.allarg_properties and arg_name in self.allarg_properties: + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + raise AssertionError(f"{arg_name} not in self.allarg_properties") + + def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def] + if V.graph.cpp_wrapper: + if self.op_overload is not None and len(self.schema_kwargs) == 0: + # All the args should have been generated by fill_non_provided_args in codegen_args + return [] + + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + if skip_out and arg_name == "out": + # ExternKernelOut has its own logic for inserting the out parameter + continue + + v = self.get_kwargs_value(arg_name) + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties + else None + ) + kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_)) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" + for k, v in self.kwargs.items() + ] + return kwargs + + def get_op_name(self) -> str: + if self.fx_node is not None: + target = self.fx_node.target + op_namespace = getattr(target, "__module__", "unknown_namespace") + op_namespace = op_namespace.replace("._ops.", ".ops.") + op_namespace = op_namespace.rsplit(".", 1)[0] + op_name = f"{op_namespace}.{target}" + else: + op_name = "unknown_op" + return op_name + + def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] + if config.size_asserts and not V.graph.cpp_wrapper: + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(self.get_size()) == 0: + return + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + op_name = self.get_op_name() + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})" + ) + + def codegen_alignment_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] + if config.alignment_asserts and not V.graph.cpp_wrapper: + name = self.get_name() + aligned = name not in V.graph.unaligned_buffers + op_name = self.get_op_name() + if aligned: + wrapper.writeline( + f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})" + ) + else: + wrapper.writeline( + f"# buffer {name} (op: {op_name}) is assumed to be not aligned" + ) + + def get_group_stride(self): # type: ignore[no-untyped-def] + """ + get output sizes and strides, for template_codegen + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self): # type: ignore[no-untyped-def] + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + # TODO: I can't tell if the symbols here are temporary + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) + return index, tuple(new_sizes) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + maybe_get_symbols = ( + maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols + ) + r = OrderedSet[sympy.Symbol]() + for arg in self.constant_args: + r |= maybe_get_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_get_symbols(arg) + return r + + def __str__(self) -> str: + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@ir_dataclass(frozen=False) +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.generate_extern_kernel_out(self) + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ) -> None: + 
super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self) -> bool: + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device) -> None: + limits = torch.iinfo(torch.int64) + super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, + # but the signature is different from is at::randint_out. Again, + # we can simplify the code when only keeping an ABI-compatible version. + cpp_kernel_name="at::_ops::randint_low_out::call", + op_overload=aten.randint.low_out, + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.generate_extern_kernel_alloc(self) + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ) -> None: + super().__init__( + None, + layout, + self.unwrap_storage(inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by pass each individual + # output through the abi-compatible interface. + self.outputs: Sequence[Any] = [] + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self) -> bool: + return False + + def apply_constraint(self): # type: ignore[no-untyped-def] + raise NotImplementedError + + +class MutationOutput(Buffer): + """ + An output buffer that represents the mutation of a pre-existing buffer + """ + + def __init__(self, layout, mutated_node, mutating_node: Operation) -> None: # type: ignore[no-untyped-def] + super().__init__(name=None, layout=layout) + mutated_node_name = mutated_node.get_name() + V.graph.mark_buffer_mutated(mutated_node_name) + self.mutation_names = [mutated_node_name] + self.mutating_node: Operation = mutating_node + self.name = V.graph.register_buffer(self) + + def get_defining_op(self) -> Operation: + return self.mutating_node + + def get_mutation_names(self) -> Sequence[str]: + return self.mutation_names + + def should_allocate(self) -> bool: + return False + + +class TMADescriptor(ExternKernel): + """ + An IR node representing a generic host-side TMA descriptor in the Triton API + Mostly useful for user-defined Triton kernels relying on host-side TMA; + but can, in principle, be used for Inductor's Triton templates, too. 
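+ Since descriptors are immutable, create() caches them in _CACHE keyed on
+ (id(tensor), tma_meta) and reuses an existing node for repeated requests.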
+ + See TMADescriptorExperimental and TMADescriptorStable for the two implementations + (the old API and the new API) + """ + + # as TMA descriptors are immutable, + # we can dedup them by the input args + _CACHE: dict[Any, TMADescriptor] = {} + + @classmethod + def _create_impl( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + assert len(tma_meta) == 2 + if tma_meta[0] == "experimental": + return TMADescriptorExperimental(tensor, *tma_meta[1]) + else: + assert tma_meta[0] == "stable" + return TMADescriptorStable(tensor, *tma_meta[1]) + + @classmethod + def create( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + key = (id(tensor), tma_meta) + if key not in cls._CACHE: + cls._CACHE[key] = cls._create_impl(tensor, tma_meta) + return cls._CACHE[key] + + def __init__(self, tensor: IRNode, inputs, constant_args): # type: ignore[no-untyped-def] + super().__init__( + None, + # link back to the underlying tensor in terms of ownership + # to avoid getting the underlying tensor deleted *before* + # the TMADescriptor node can be deleted. + NonOwningLayout( + ReinterpretView( + data=tensor, + layout=tensor.get_layout(), + ) + ), + inputs, + tuple(constant_args), + None, + ) + + self.tensor = tensor + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.generate_tma_descriptor(self) + + def get_tensor(self) -> IRNode: + return self.tensor + + +class TMADescriptorExperimental(TMADescriptor): + """ + the new host-side TMA Descriptor API: + (the ones obtained via create_{1d,2d}_tma_descriptor calls). + + See also TMADescriptorStable for the new API. + """ + + def __init__( + self, + tensor: IRNode, + dims: list[Union[int, torch.SymInt]], + block_dims: list[Union[int, torch.SymInt]], + element_size: Optional[int] = None, + ) -> None: + assert len(dims) in (1, 2) + assert len(dims) == len(block_dims) + + if element_size is None: + element_size = tensor.get_dtype().itemsize + + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + self.rank = len(self.dims) + + inputs = [tensor] + constant_args = [ + *self.dims, + *self.block_dims, + self.element_size, + ] + + super().__init__( + tensor=tensor, + inputs=inputs, + constant_args=constant_args, + ) + + +class TMADescriptorStable(TMADescriptor): + """ + the new host-side TMA descriptor API + (the ones obtained via TensorDescriptor.from_tensor). + + See also TMADescriptorExperimental for the old API. 
+ """ + + def __init__(self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]): + self.block_shape = block_shape + + super().__init__( + tensor=tensor, + inputs=[tensor], + constant_args=block_shape, + ) + + +class SubgraphBuffer(ExternKernel): + def __init__( + self, + layout: Layout, + input_nodes: list[Buffer], + gm: torch.fx.GraphModule, + example_inputs: list[Any], + subgraph_name: str, + ): + super().__init__(None, layout, input_nodes) + self.gm = gm + self.example_inputs = example_inputs + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name) + + sym_inputs = get_symbolic_inputs(self.inputs) + + for sym_inp in sym_inputs: + self.subgraph.graph_inputs[sym_inp.name] = sym_inp + self.subgraph.graph_input_names.append(sym_inp.name) + + self.sym_inputs = [sym_var.name for sym_var in sym_inputs] + + import torch._inductor.config as inductor_config + + with V.set_graph_handler(self.subgraph): + # Don't bother autotuning on Triton here + with inductor_config.patch( # type: ignore[no-untyped-def] + max_autotune=False, + max_autotune_gemm=False, + max_autotune_gemm_backends="ATEN", + ): + self.subgraph.run(*self.example_inputs) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + class CodegenGraph: + def __init__(self, graph: GraphLowering): + self.graph = graph + self.name = graph.name + + outer_inputs = [t.codegen_reference() for t in self.inputs] + wrapper.codegen_subgraph_with_flattened_outputs( + CodegenGraph(self.subgraph), + [*self.sym_inputs, *outer_inputs], + [self.name], + ) + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_metadata(self): # type: ignore[no-untyped-def] + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + restore_value_args: list[str] = [] + reset_to_zero_args: list[str] = [] + if isinstance(kernel, Autotuner): + # https://github.com/triton-lang/triton/pull/5083 + # changes kernel.restore_idx to kernel.restore_value + if hasattr(kernel, "restore_idx"): + restore_value_args.extend( + kernel.fn.arg_names[i] for i in kernel.restore_idx + ) + else: + assert hasattr(kernel, "restore_value") + restore_value_args.extend(kernel.restore_value) + + if hasattr(kernel, "reset_idx"): + for i in kernel.reset_idx: + reset_to_zero_args.append(kernel.fn.arg_names[i]) + else: + assert hasattr(kernel, "reset_to_zero") + reset_to_zero_args.extend(kernel.reset_to_zero) + + configs = kernel.configs + kernel = kernel.fn + return kernel, configs, restore_value_args, reset_to_zero_args + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + from torch._inductor.utils import triton_version_uses_attrs_dict + + ( + kernel, + configs, + restore_value_args, + reset_to_zero_args, + ) = self.get_kernel_and_metadata() + + # Definition of kernel + ( + new_name, + triton_meta, + extra_launch_args, + ) = wrapper.define_user_defined_triton_kernel( + kernel, + configs, + self.kwargs, + restore_value_args, + reset_to_zero_args, + self.grid, + ) + named_args = { + k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel + } + constexpr_names = OrderedSet([kernel.arg_names[i] for i in kernel.constexprs]) + + args: list[Any] = [] + arg_types: list[Any] = [] + raw_keys_filtered: list[Any] = [] + raw_args_filtered: list[Any] = [] + for name, arg in itertools.chain( + 
named_args.items(), zip(itertools.repeat(""), extra_launch_args) + ): + raw_keys_filtered.append(name) + raw_args_filtered.append(arg) + if isinstance(arg, IRNode): + args.append(arg.codegen_reference()) + arg_types.append(arg.get_dtype()) + elif isinstance(arg, (int, float, bool, sympy.Expr)): + args.append(arg) + arg_types.append(type(arg)) + elif name in constexpr_names: + # insert a dummy value for constexpr args of unsupported type + # constexprs will end up getting baked into the kernel at compile time + args.append(-1) + arg_types.append(int) + elif arg is None: + """ + Filter out None args. + + see https://github.com/pytorch/pytorch/issues/115344 + + Two cases for a None arg: + 1. The arg is already tl.constexpr, so leave it in + 2. The arg is not tl.constexpr so we have to remove it + """ + if triton_version_uses_attrs_dict(): + args.append(-1) + arg_types.append(int) + else: + raw_keys_filtered.pop() + raw_args_filtered.pop() + else: + raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}") + + self.codegen_comment(wrapper) + wrapper.generate_kernel_call( + new_name, + args, + arg_types=arg_types, + raw_args=raw_args_filtered, + raw_keys=raw_keys_filtered, + triton_meta=triton_meta, + triton=True, + device=self.get_device(), + original_fxnode_name=self.fx_node.name, + ) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # add unbacked symbols used in the grid to the ones used + # in the kwargs (the latter is generated by ExternKernel) + return super().get_free_symbol_uses(unbacked_only) | get_free_symbols( + self.grid, unbacked_only + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( # type: ignore[no-untyped-def] + self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args + ) -> None: + inputs = [] + kwargs = {} + constant_args = [] + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + if k in tma_descriptor_metadata: + t = TMADescriptor.create(t, tma_descriptor_metadata[k]) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + self.device = inputs[0].get_device() + + super().__init__( + None, + NoneLayout(device=self.device), + inputs, + tuple(constant_args), + kwargs, + ) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, configs, _, _ = self.get_kernel_and_metadata() + + # If we are autotuning, not all arguments will be passed + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors + + autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {} + self.mutable_args = [ + kernel_args[key] + for key in identify_mutated_tensors( + kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata + ) + ] + + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=self.device), buf, self) + for buf in self.mutable_args + ] + V.graph.register_operation(self) + + def get_outputs(self) -> list[Buffer]: + return list(self.mutation_outputs) + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + (x,) = (t.codegen_reference() for t in self.inputs) + + if 
V.graph.cpp_wrapper: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, *constant_args) -> None: # type: ignore[no-untyped-def] + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage([x]), + constant_args, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + (dst, src, non_blocking) = self.codegen_args() + wrapper.codegen_device_copy(src, dst, non_blocking) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args, + ) -> None: + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name="aoti_torch_copy_", + ) + V.graph.mark_buffer_mutated(inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create(cls, dst, src, non_blocking: bool = False): # type: ignore[no-untyped-def] + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(device=dst.get_device()), + inputs, + constant_args, + ) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def has_side_effects(self) -> bool: + return True + + +class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable, new_size) -> None: # type: ignore[no-untyped-def] + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(device=variable.get_device()), + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = 
"torch::inductor::resize_storage_bytes_" + V.graph.never_reuse_buffers.add(variable.data.get_name()) + + +class SetSourceTensorKernel(ExternKernelAlloc): + def __init__(self, self_tensor, storage_tensor) -> None: # type: ignore[no-untyped-def] + storage_tensor.freeze_layout() + super().__init__( + storage_tensor.get_layout(), + [self_tensor, storage_tensor], + python_kernel_name="torch.ops.aten.set_.source_Tensor", + op_overload=torch.ops.aten.set_.source_Tensor, + ) + V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) + V.graph.never_reuse_buffers.add(storage_tensor.get_name()) + V.graph.never_reuse_buffers.add(self.get_name()) + device = storage_tensor.get_device() + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=device), self_tensor, self), + MutationOutput(NoneLayout(device=device), storage_tensor, self), + ] + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return [self.inputs[0].get_name(), self.inputs[1].get_name()] + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. + It also handle the case `src` being a scalar properly. + """ + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + reduce = self.kwargs["reduce"] + if V.graph.cpp_wrapper: + # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum + get_operator_enum = {"add": "sum", "multiply": "prod"} + if reduce in get_operator_enum: + reduce = get_operator_enum[reduce] + + if self.src_is_tensor: + (x, index, src) = (t.codegen_reference() for t in self.inputs) + else: + (x, index) = (t.codegen_reference() for t in self.inputs) + src = self.constant_args[1] + wrapper.generate_scatter_fallback( + x, + [x, self.constant_args[0], index, src], + self.cpp_kernel_name, + self.python_kernel_name, + self.src_is_tensor, + reduce, + self.codegen_kwargs(), + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( # type: ignore[no-untyped-def] + self, + op_overload, + x, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ) -> None: + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: tuple[Any, ...] 
+ if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=str(op_overload), + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) + indices = [] + iter_valid_indices = iter(valid_indices) + for i, _ in enumerate(self.indices): + if self.indices[i] is not None: + indices.append(next(iter_valid_indices)) + else: + indices.append(V.graph.wrapper_code.none_str) + + wrapper.generate_index_put_fallback( + self.get_kernel_name(), x, indices, values, *self.codegen_const_args() + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__(self, op_overload, x, indices, values, accumulate) -> None: # type: ignore[no-untyped-def] + self.indices = indices + valid_indices = [i for i in indices if i is not None] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = "aoti_torch_index_put_out" + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(self.inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x, device, non_blocking): # type: ignore[no-untyped-def] + if ( + not x.is_extern() + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), + ), + [cls.realize_input(x)], + constant_args, + ) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + args = self.codegen_args() + assert len(args) == 2 + if self.output_view: + wrapper.codegen_device_copy( + args[0], self.output_view.codegen_reference(), args[1] + ) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1]) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. 
+ """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__(self, sym, keypath, data) -> None: # type: ignore[no-untyped-def] + data.realize() + super().__init__( + None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data]) + ) + self.sym = sym + self.keypath = keypath + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.sym]) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__(self, scalar, msg) -> None: # type: ignore[no-untyped-def] + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(device=torch.device("cpu")), + # InputsKernel(inputs) + [], + ) + self.scalar = scalar + self.msg = msg + + def has_side_effects(self) -> bool: + return True + + def get_free_symbol_uses(self, unbacked_only: bool = False): # type: ignore[no-untyped-def] + return get_free_symbols(self.scalar, unbacked_only) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + if not config.scalar_asserts: + return + # NB: It is EXTREMELY important not to simplify the scalar under assertion here, + # because simplify is done with respect to runtime asserts. So if you have + # "u0 == 0" in the runtime asserts, if you subsequently try to + # simplify(u0 == 0), you will get True (because we've already runtime assert'ed + # that it's true). But we're code generating the actual runtime assert here!! + symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False))) + if V.graph.cpp_wrapper: + symbol_str = f"std::to_string({symbol})" + sizevar = V.graph.wrapper_code.codegen_cpp_sizevar( + self.scalar, simplify=False + ) + # TODO: when we start compiling in C++20, annotate with [[unlikely]]. + wrapper.writeline( + f'if (!({sizevar})) {{ throw std::runtime_error("Expected {self.msg} but received " + {symbol_str}); }}' + ) + else: + sizevar = V.graph.wrapper_code.codegen_python_sizevar( + self.scalar, simplify=False + ) + wrapper.writeline(f"if not ({sizevar}):") + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@ir_dataclass(frozen=False) +class ExternKernelNode: + name: str + node: export_schema.Node + + +class FallbackKernel(ExternKernelAlloc): + """ + A class that represents a fallback kernel for handling operators that are not + directly support by inductor. It currently supports functional ops, view ops, + inplace aten ops, and mutating ops that are auto-functionalizable. 
+ """ + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ) -> None: + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + + self.use_runtime_dispatch = False + self.unbacked_bindings = unbacked_bindings + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), + ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] + + # args that are aliased + self.alias_names: list[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: list[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt-in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support three types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. + # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info, arg) -> None: # type: ignore[no-untyped-def] + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) + if library_utils.is_tensor_like_type(info.type): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. 
+ assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + + def add_alias(t) -> None: # type: ignore[no-untyped-def] + self.alias_names.append(t.get_name()) + if info.alias_info.is_write: + self.mutation_outputs.append( + MutationOutput(NoneLayout(device=t.get_device()), t, self) + ) + + if library_utils.is_tensorlist_like_type(info.type): + if arg is not None: + for optional_tensor_arg in arg: + add_alias(optional_tensor_arg) + else: + assert library_utils.is_tensor_like_type(info.type) + add_alias(arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + + if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state: + for arg in self.constant_args: + if isinstance(arg, GeneratorState): + read_writes = read_writes.with_read( + dependencies.StarDep(arg.get_name()) + ) + + return read_writes + + def codegen_unbacked_symbol_defs(self, wrapper) -> None: # type: ignore[no-untyped-def] + return wrapper.codegen_unbacked_symbol_defs_for_outputs( + self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None) + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + resolved = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + assert resolved is not None + return resolved.keys() # type: ignore[return-value] + else: + return OrderedSet() + + def codegen_args(self): # type: ignore[no-untyped-def] + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self) -> str: + return self.ref + + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) + args = [ + V.graph.wrapper_code.val_to_arg_str(x, param.real_type) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device(tensor_args, example_output): # type: ignore[no-untyped-def] + non_torch_bind_tensor_args = ( + [t for t in tensor_args if not isinstance(t, TorchBindObject)] + if tensor_args + else None + ) + if non_torch_bind_tensor_args: + devices = [arg.get_device() for arg in tensor_args if arg.get_device()] + return devices[0] + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + device_set = OrderedSet( + FallbackKernel.find_device(None, x) for x in example_output + ) + # Remove None + devices = [device for device in device_set if device] + if len(devices) == 1: + return devices[0] + for device in devices: + if is_gpu(device.type): + return device + return devices[0] + return None + + def has_side_effects(self): # type: ignore[no-untyped-def] + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + return False + return get_schema_info(self.op_overload).is_mutable() + + def get_inputs_that_alias_output(self): # type: ignore[no-untyped-def] + return self.alias_names + + def get_mutation_names(self) -> Sequence[str]: + assert len(self.mutation_names) <= 1 + return self.mutation_names + + 
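+ # The method below serializes this fallback call into an ExternKernelNode whose
+ # export_schema.Node records the op target plus named input and output arguments
+ # (metadata stays empty); in AOT mode the node is appended to
+ # V.graph.extern_kernel_nodes so a host-side proxy executor can replay the call,
+ # while the JIT cpp-wrapper path returns early without serializing.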
def export_extern_kernel_node(self): # type: ignore[no-untyped-def] + """ + ProxyExecutor Design Note + We export the ExternFallbackNodes (for custom ops) into a serialized file + and run it with a host side proxy executor to address the ABI problem + This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. + Detailed design doc can be found at + https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + """ + log.debug( + "Extern kernel node added for node %s with target %s.", + self.get_name(), + self.op_overload, + ) + + assert isinstance(self, FallbackKernel) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + self.get_kwargs_value(key, **kwargs) + for key in self.ordered_kwargs_for_cpp_kernel + ] + target = self.op_overload + + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] + + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(target, args, kwargs) + + # serialize_outputs + def handle_single_output(return_type, output): # type: ignore[no-untyped-def] + if isinstance(return_type, (torch.TensorType, torch.NoneType)): + # For single Tensor or None + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + if isinstance(return_type, torch.TensorType): + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + else: # NoneType + assert out is None + return export_schema.Argument.create(as_none=True) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + elif isinstance(return_type, torch.OptionalType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For OptionalTensor + if output is None: + return export_schema.Argument.create( + as_optional_tensor=export_schema.OptionalTensorArgument.create( + as_none=True + ) + ) + else: + return export_schema.Argument.create( + as_optional_tensor=export_schema.OptionalTensorArgument.create( + as_tensor=export_schema.TensorArgument( + name=output.get_name() + ) + ) + ) + elif isinstance(return_type, torch.IntType): + return export_schema.Argument.create(as_int=output) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): + returns = target.schema(args[0], args[1]).returns # type: ignore[union-attr] + else: + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + # NOTE: [special handling of all_reduce_coalesced_'s return value] + # all_reduce_coalesced_ return a list of tensors via self.mutation_outputs + outputs = self.outputs if self.outputs else self.mutation_outputs + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + # Not generating output args for self.mutation_outputs + output_arguments = [ + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + 
name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + if V.graph.cpp_wrapper: + from torchgen.aoti.fallback_ops import inductor_fallback_ops + + if str(kernel) not in inductor_fallback_ops: + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update fallback_ops.py. + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + elif kernel.namespace == "_quantized": # type: ignore[union-attr] + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) + elif V.graph.cpp_wrapper: + # For non-aten OpOverload, i.e. custom ops + # If the op is in custom_ops_to_c_shims, generate direct function call + self.use_runtime_dispatch = ( + kernel not in config.aot_inductor.custom_ops_to_c_shims + ) + + # Handle the special case where a complex number is input to a C-shim kernel for + # a scalar input. The torchgen'ed shim API will use type "double", which is + # incompatible with complex numbers, forcing a fallback to runtime dispatch. + if ( + V.graph.cpp_wrapper + and isinstance(kernel, torch._ops.OpOverload) + and not self.use_runtime_dispatch + ): + + def is_number(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return is_number(t.getElementType()) + return isinstance(t, torch.NumberType) + + # Using unflatten_args is a bit of a hack, but all the complex arguments we + # care about are in self.constant_args, and calling unflatten_args puts them + # in the correct order without triggering codegen. + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed + # to be set, since this is an OpOverload kernel. 
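+ # Walk positional args followed by ordered kwargs in schema order; if any Python
+ # complex value is bound to a parameter typed as a Number (or Optional[Number]),
+ # switch to runtime dispatch, since the torchgen'ed shim exposes such scalars as double.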
+ args_iter = itertools.chain( + args, + ( + self.get_kwargs_value(k, **kwargs) + for k in self.ordered_kwargs_for_cpp_kernel + ), + ) + self.use_runtime_dispatch = any( + isinstance(v, complex) and is_number(a.real_type) + for v, a in zip(args_iter, kernel._schema.arguments) + ) + + self.codegen_comment(wrapper) + if self.use_runtime_dispatch: + exported_args = self.export_extern_kernel_node() + wrapper.generate_fallback_kernel_with_runtime_lookup( + self.get_name(), + self.python_kernel_name, + lambda: [*self.codegen_args(), *self.codegen_kwargs()], + self.op_overload, + exported_args, + # NOTE: [special handling of all_reduce_coalesced_'s return value] + self.outputs if self.outputs else self.mutation_outputs, + ) + else: + wrapper.generate_fallback_kernel(self) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) + + self.codegen_unbacked_symbol_defs(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor): # type: ignore[no-untyped-def] + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context: AbstractContextManager[None] = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] + ) + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, *args, **kwargs) + + # We need this extra check for input alignment since the example + # inputs we created are always aligned. + has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args) + + device = cls.find_device(tensor_args, example_output) + + if not device and isinstance( + kernel, torch._higher_order_ops.torchbind.CallTorchBind + ): + # use CPU device for torchbind methods that don't take in or output any tensor, e.g. 
size() + device = torch.device("cpu") + + if example_output is None: + packed = cls( + NoneLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + else: + assert device, "Not sure where to find device info" + packed = cls( + MultiOutputLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + def generate_output(output, indices): # type: ignore[no-untyped-def] + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + buf = MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + if ( + config.assume_unaligned_fallback_output + or has_unaligned_input + or not tensor_is_aligned(output) + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return buf + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert output is None, ( + f"FallbackKernel output type {type(output)} is not supported" + ) + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] + else: + packed.outputs = [outputs] + return outputs + + def apply_constraint(self): # type: ignore[no-untyped-def] + return super().apply_constraint() + + +@ir_dataclass(frozen=False) +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + # Signal to codegen that our output buffer isn't safe to reuse + return [self.inputs[0].get_name()] + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + *, + unbacked_bindings=None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + +@ir_dataclass +class MultiOutputLayout(OutputSpec): + device: torch.device + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class MultiOutput(ExternKernel): + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_multi_output(self) + if not self.skip_size_stride_alignment_checks: + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) + + def __init__( # type: ignore[no-untyped-def] + self, + layout: OutputSpec, + input, + indices: list[tuple[Any, ...]], + skip_size_stride_alignment_checks=False, + ) -> None: + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.indices = indices + self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return self.inputs[0].get_free_symbol_uses(unbacked_only) + + def should_allocate(self) -> bool: + if len(self.inputs) == 1 and ( + isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM + ): + return True + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return [ + 
inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) + and len(inp.get_inputs_that_alias_output()) > 0 + ] + + +# We just use a normal dataclass for MutableBox/TensorBox/StorageBox since +# they're mainly lowering-time constructs that we expect to mutate and such. +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def has_exceeded_max_reads(self) -> bool: + return self.data.has_exceeded_max_reads() + + def get_device(self) -> Optional[torch.device]: + return self.data.get_device() + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + return self.data.make_loader() + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.data.make_indexer() + + def get_stride(self) -> Sequence[_IntLike]: + return self.data.get_stride() + + def get_name(self) -> str: + return self.data.get_name() + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + return self.data.has_large_inner_fn(threshold) + + def mark_reuse(self, users: int) -> None: + return self.data.mark_reuse(users) + + def realize_hint(self) -> None: + return self.data.realize_hint() + + def unwrap_view(self) -> IRNode: + return self.data.unwrap_view() + + def is_input_buffer(self) -> bool: + return self.data.is_input_buffer() + + def freeze_layout(self) -> None: + return self.data.freeze_layout() + + def freeze_layout_with_stride_order( + self, order: list[int], allow_padding: bool = False + ) -> None: + return self.data.freeze_layout_with_stride_order(order, allow_padding) + + def freeze_layout_with_fill_order(self, order: list[int]) -> None: + return self.data.freeze_layout_with_fill_order(order) + + def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + return self.data.freeze_layout_with_same_order(stride) + + def freeze_layout_with_exact_strides( + self, exact_strides: list[_IntLike], allow_padding: bool = False + ) -> None: + return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding) + + def get_read_writes(self) -> dependencies.ReadWrites: + return self.data.get_read_writes() + + def get_reads(self) -> OrderedSet[Dep]: + return self.data.get_reads() + + def num_reads(self) -> int: + return self.data.num_reads() + + def get_storage_numel(self) -> _IntLike: + return self.data.get_storage_numel() + + def get_reduction_type(self) -> Optional[str]: + return self.data.get_reduction_type() + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return self.data.get_reduction_size() + + def is_extern(self) -> bool: + return self.data.is_extern() + + def is_no_op(self) -> bool: + return self.data.is_no_op() + + def constant_to_device(self, device: torch.device) -> IRNode: + return self.data.constant_to_device(device) + + def get_mutation_names(self) -> Sequence[str]: + return self.data.get_mutation_names() + + def get_operation_name(self) -> str: + return self.data.get_operation_name() + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return self.data.get_inputs_that_alias_output() + + def realize(self) -> Optional[str]: + return self.data.realize() + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return self.data.get_free_symbol_uses(unbacked_only) + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_defining_op(self) -> Optional[Operation]: + return self.data.get_defining_op() + + def codegen_reference(self, writer: 
Optional[IndentedBuffer] = None) -> str: + return self.data.codegen_reference(writer) + + @property + def layout(self) -> OutputSpec: + # we intentionally call get_output_spec (rather than get_layout) since Buffer.layout is an OutputSpec + return self.data.get_output_spec() + + def get_layout(self) -> Layout: + return self.data.get_layout() + + def get_output_spec(self) -> OutputSpec: + return self.data.get_output_spec() + + def get_size(self) -> Sequence[Expr]: + return self.data.get_size() + + @property + def dtype(self): # type: ignore[no-untyped-def] + return self.data.dtype + + def __str__(self) -> str: + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data): # type: ignore[no-untyped-def] + if isinstance(data, ShapeAsConstantBuffer): + return data + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + def is_input_buffer(self) -> bool: + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def is_module_buffer(self): # type: ignore[no-untyped-def] + return ( + isinstance(self.data, (ConstantBuffer)) + and self.data.get_name() in V.graph.constants + ) + + def realize(self) -> Optional[str]: + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() + assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type( + self.data + ) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + V.graph.register_operation(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self) -> None: + """ + Called on buffers we expect to be forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.data.inner_fn_opcount().nontrivial_read_count > 1 + ): + self.realize() + + def has_exceeded_max_reads(self) -> bool: + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + ) + + def should_realize_on_reuse(self, users): # type: ignore[no-untyped-def] + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. 
+ """ + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( + self.num_reads() > config.realize_reads_threshold + or self.has_large_inner_fn() + ) + return False + + def mark_reuse(self, users: int) -> None: + if self.should_realize_on_reuse(users): + self.realize() + + def num_reads(self): # type: ignore[no-untyped-def] + return self.data.num_reads() + + +@ir_dataclass(frozen=False) +class Subgraph(IRNode): + name: str + graph_module: torch.fx.GraphModule + graph: Optional[GraphLowering] = None + + +def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: + buffers = [ + buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer + for buffer in buffers + ] + # assuming the same buffer is represented by the same IRNode object + return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers) + + +@ir_dataclass(frozen=False) +class InvokeSubgraph(ExternKernel): + """ + Ir node for the invoke_subgraph HOP. + """ + + subgraph: Optional[Subgraph] = None + operands: Optional[list[TensorBox]] = None + outputs: Optional[list[MultiOutput]] = None + + def __init__( + self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout + ) -> None: + super().__init__( + name=None, + layout=layout, + inputs=operands, + ) + self.subgraph = subgraph + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def] + from .lowering import constrain_to_fake_tensor + + # TODO(anijain2305) - Support sym expr as operands in future. + current_node = V.graph.current_node + + fake_operands = None + if eager_input_vals := current_node.meta.get("eager_input_vals"): + # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph + fake_operands = eager_input_vals[0][2:] + else: + # For the partitioned backward graph, we do not have + # eager_input_vals. Here, we rely on the recorded example values. + fx_operands = current_node.args[2:] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + # Realize the inputs. Also intermediates can have different strides than + # the inputs of the subgraph. So, force the intermediates to have same + # strides as that of subgraph inputs. 
+ operands = [cls.realize_input(x) for x in operands] + + new_operands = [] + for idx, operand in enumerate(operands): + if isinstance(operand, ShapeAsConstantBuffer): + new_operands.append(operand) + else: + new_operands.append( + constrain_to_fake_tensor(operand, fake_operands[idx]) + ) + + operands = new_operands + + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + outputs = subgraph.graph.graph_outputs + + # Find the device - operands could be integers from shapes, so we can't + # use operands[0] + device = None + for operand in operands: + if not isinstance(operand, ShapeAsConstantBuffer): + device = operand.get_device() + break + assert device is not None + + invoke_subgraph = InvokeSubgraph( + subgraph=subgraph, + operands=operands, + layout=MultiOutputLayout(device=device), + ) + + def create_output(output: IRNode, ind: int): # type: ignore[no-untyped-def] + if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): + return output + else: + return MultiOutput( + FixedLayout( + device=output.get_device(), # type: ignore[arg-type] + dtype=output.get_dtype(), + size=output.get_size(), # type: ignore[arg-type] + stride=output.get_stride(), # type: ignore[arg-type] + offset=output.get_layout().offset, + ), + invoke_subgraph, # type: ignore[has-type] + [(list, ind)], + skip_size_stride_alignment_checks=True, + ) + + outputs = [create_output(output, i) for i, output in enumerate(outputs)] + invoke_subgraph.outputs = outputs + return outputs + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_invoke_subgraph(self) + + +@ir_dataclass(frozen=False) +class Conditional(ExternKernel): + predicate: Optional[IRNode] = None + operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[list[MultiOutput]] = None + + def __init__( + self, + predicate: IRNode, + operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + true_subgraph: Subgraph, + false_subgraph: Subgraph, + layout: MultiOutputLayout, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + self.predicate = predicate + self.operands = operands + self.true_subgraph = true_subgraph + self.false_subgraph = false_subgraph + + sym_args, tensor_args = _split_by_sym_type([predicate] + operands) + + super().__init__( + name=None, + layout=layout, + inputs=tensor_args, + constant_args=sym_args, + ) + if unbacked_bindings is not None: + self.unbacked_bindings = unbacked_bindings + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( # type: ignore[no-untyped-def] + cls, + predicate: TensorBox, + true_fn: Subgraph, + false_fn: Subgraph, + operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + ): + predicate = cls.realize_input(predicate) + operands = [cls.realize_input(x) for x in operands] + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + for subgraph in (true_fn, false_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with 
V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] + + for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): + if _has_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. " + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)): + assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o) + assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o) + assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o) + + device = next( + o.get_device() + for o in [predicate] + operands + if not isinstance(o, ShapeAsConstantBuffer) + ) + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, + V.graph.current_node.meta.get("unbacked_bindings", None), + ) + assert device is not None, "cannot determine device" + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + layout=MultiOutputLayout(device=device), + unbacked_bindings=unbacked_bindings, + ) + + def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.expr]: + if isinstance(s, int): + return s + return s.node.expr + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=[_maybe_expr(sz) for sz in merged_output.size()], + stride=[_maybe_expr(sz) for sz in merged_output.stride()], + offset=output.get_layout().offset, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, (output, merged_output) in enumerate( + zip(true_outputs, V.graph.current_node.meta["val"]) + ) + ] + + conditional.outputs = outputs # type: ignore[assignment] + return outputs + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_conditional(self) + wrapper.codegen_unbacked_symbol_defs_for_outputs( + self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {}) + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + resolved = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + assert resolved is not None + return resolved.keys() # type: ignore[return-value] + else: + return OrderedSet() + + +def _split_by_sym_type( + args: list[Any], +) -> tuple[list[ShapeAsConstantBuffer], list[Any]]: + non_sym_args = [] + sym_args = [] + for arg in args: + if isinstance(arg, ShapeAsConstantBuffer): + sym_args.append(arg.expr) + else: + non_sym_args.append(arg) + + return sym_args, non_sym_args + + +@ir_dataclass(frozen=False) +class WhileLoop(ExternKernel): + carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + cond_subgraph: Optional[Subgraph] = None + body_subgraph: Optional[Subgraph] = None + outputs: Optional[list[MultiOutput]] = None + + def __init__( + self, + carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + additional_inputs: list[Union[TensorBox, 
ShapeAsConstantBuffer]], + cond_subgraph: Subgraph, + body_subgraph: Subgraph, + layout: MultiOutputLayout, + ) -> None: + self.carried_inputs = carried_inputs + self.additional_inputs = additional_inputs + self.cond_subgraph = cond_subgraph + self.body_subgraph = body_subgraph + + sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs) + super().__init__( + name=None, + layout=layout, + inputs=tensor_args, + constant_args=sym_args, + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( # type: ignore[no-untyped-def] + cls, + cond_fn: Subgraph, + body_fn: Subgraph, + carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + ): + from torch._higher_order_ops.utils import check_input_alias_and_mutation + + def _require_exact_strides( + tensor_boxes: list[TensorBox | ShapeAsConstantBuffer], + fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]], + ) -> list[TensorBox | ShapeAsConstantBuffer]: + assert len(tensor_boxes) == len(fake_tensors) + ret = [] + for tb, fk in zip(tensor_boxes, fake_tensors): + if isinstance(fk, torch.Tensor): + ret.append( + ExternKernel.require_exact_strides( + tb, fk.stride(), allow_padding=False + ) + ) + else: + ret.append(tb) + return ret + + fx_carried_inputs = V.graph.current_node.args[-2] + fx_additional_inputs = V.graph.current_node.args[-1] + fx_all_inputs = fx_carried_inputs + fx_additional_inputs # type: ignore[operator] + fake_all_inputs = [x.meta["val"] for x in fx_all_inputs] # type: ignore[union-attr] + fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr] + fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr] + + carried_inputs = [cls.realize_input(x) for x in carried_inputs] + carried_inputs = _require_exact_strides(carried_inputs, fake_carried_inputs) + additional_inputs = [cls.realize_input(x) for x in additional_inputs] + additional_inputs = _require_exact_strides( + additional_inputs, fake_additional_inputs + ) + all_inputs = carried_inputs + additional_inputs + + for subgraph in (cond_fn, body_fn): + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fx_all_inputs, # type: ignore[arg-type] + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_all_inputs) + # For body_fn, we require its output to have the exact same stride + # as inputs because the previous output is the input of next iteration. + # + # This cannot be automatically done in graph lowering because body_fn's graph outputs + # are not user-facing so the special handling for strides of user-facing output in graph + # lowering is not applicable. + if subgraph is body_fn: + assert len(subgraph.graph.graph_outputs) == len( + fake_carried_inputs + ) + subgraph.graph.graph_outputs = _require_exact_strides( # type: ignore[assignment] + subgraph.graph.graph_outputs, # type: ignore[arg-type] + fake_carried_inputs, + ) + + cond_outputs = cond_fn.graph.graph_outputs # type: ignore[union-attr] + body_outputs = body_fn.graph.graph_outputs # type: ignore[union-attr] + + if _has_aliased_buffers(body_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.while_loop. 
" + f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}" + ) + + # make sure cond_fn returns a boolean scalar Tensor + assert len(cond_outputs) == 1, cond_outputs + p = cond_outputs[0] + if not isinstance(p, ShapeAsConstantBuffer): + assert p.get_dtype() == torch.bool, p + assert len(p.get_size()) == 0, p + + assert len(all_inputs) > 0, ( + "torch.while_loop is assumed to have at least one operand." + ) + + device = all_inputs[0].get_device() + + assert device is not None # to make linter happy + # make sure carried_inputs and body outputs are structurally equivalent + assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs) + for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): + + def _guard_list_equals( + lhs_exprs: Sequence[Union[int, Any]], + rhs_exprs: Sequence[Union[int, Any]], + ) -> None: + for lhs, rhs in zip(lhs_exprs, rhs_exprs): + V.graph.sizevars.guard_equals(lhs, rhs) + + _guard_list_equals(op.get_size(), bo.get_size()) + _guard_list_equals(op.get_stride(), bo.get_stride()) + # assume all carried_inputs and outputs are on the same device + # as the MultiOutputLayout below requires single device + assert op.get_device() == bo.get_device(), (i, op, bo, device) + assert op.get_dtype() == bo.get_dtype(), (i, op, bo) + assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo) + + while_loop = WhileLoop( + carried_inputs=carried_inputs, + additional_inputs=additional_inputs, + cond_subgraph=cond_fn, + body_subgraph=body_fn, + # asserted above that there is at least one operand + layout=MultiOutputLayout(device=device), + ) + + assert body_fn.graph is not None and isinstance( + body_fn.graph.module, torch.fx.GraphModule + ) # to make linter happy + + # Handling input mutations + mutated_idxs = check_input_alias_and_mutation( + body_fn.graph.module, fake_all_inputs + )[3] + mutated_idx_set = OrderedSet(mutated_idxs) + mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set] + real_outputs = { + idx: out + for idx, out in enumerate(body_outputs) + if idx not in mutated_idx_set + } + real_outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + while_loop, + [(list, idx)], + ) + for idx, output in real_outputs.items() + ] + while_loop.outputs = real_outputs + while_loop.mutation_outputs = [ + MutationOutput(inp.layout, inp, while_loop) # type: ignore[union-attr] + for inp in mutated_inputs + ] + + outputs_iter = iter(real_outputs) + mutated_inputs_iter = iter(mutated_inputs) + all_outputs = [ + next(mutated_inputs_iter) if idx in mutated_idx_set else next(outputs_iter) + for idx in range(len(body_outputs)) + ] + for inp, out in zip(carried_inputs, all_outputs): + if inp.get_name() in V.graph.graph_inputs: + # if a carried input of the while_loop is a graph input, + # it can be returned as is when the number of iterations + # is zero. due to this, we can't (generally) reuse the + # output buffers corresponding to the graph inputs, as + # the inputs may end up being mutated. 
+ V.graph.never_reuse_buffers.add(out.get_name()) + return all_outputs + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_while_loop(self) + + +class EffectfulKernel(FallbackKernel): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + + from torch._higher_order_ops.effects import get_effect_key + + uncovered_args = [ + a.value if isinstance(a, TorchBindObject) else a for a in tensor_args + ] + effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs) + assert effect_type is not None + self.effect_type = effect_type + self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None) + V.graph.effectful_ops[effect_type] = self + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + + if self.prev_effect_buffer is not None: + read_writes.reads.add( + dependencies.StarDep(self.prev_effect_buffer.get_name()) + ) + + return read_writes + + def has_side_effects(self) -> bool: + return True + + +class NonTensorObj(IRNode): + pass + + +@ir_dataclass +class TorchBindObject(NonTensorObj): + name: str + value: Union[FakeScriptObject, torch.ScriptObject] + + def get_name(self) -> str: + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + def get_value(self) -> Union[FakeScriptObject, torch.ScriptObject]: + return self.value + + def get_real_obj(self) -> torch.ScriptObject: + if isinstance(self.value, torch.ScriptObject): + return self.value + else: + return self.value.real_obj + + def get_buf_bytes(self) -> int: + # Returns the sum of all tensors in the flattened object + real_script_obj = self.get_real_obj() + flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] + flat_elems = pytree.tree_flatten(flat_dict)[0] + flat_sizes = [ + x.element_size() * x.numel() + for x in flat_elems + if isinstance(x, torch.Tensor) + ] + return functools.reduce(operator.add, flat_sizes, 0) + + +@ir_dataclass +class GeneratorState(NonTensorObj): + name: str + device: torch.device + + def get_name(self) -> str: + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + +class _CollectiveKernel(FallbackKernel): + def should_allocate(self) -> bool: + return False + + def has_side_effects(self) -> bool: + return True + + # This is identical to FallbackKernel.set_cpp_kernel(), minus the + # part that checks against input aliasing and mutation. + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: + assert type(self.op_overload) is torch._ops.OpOverload, ( + "Setting cpp kernel needs a valid op_overload" + ) + kernel = self.op_overload + self.cpp_kernel_name = kernel._schema.name + + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in kernel._schema.arguments if x.kwarg_only + ] + + # NOTE: [In-Place Collective Safety] + # Between the initiation and completion of an in-place collective, the + # input buffers are subject to both volatile reads and volatile writes. + # They must not be read, written to or reused by another kernel. To ensure + # the constraints, we model collective -> wait_tensor as as two-step + # mutation of the input buffers. 
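+ # Concretely: create_inplace() below registers a MutationOutput for every input
+ # buffer (step one of the two-step mutation), and the matching wait_tensor
+ # lowering (_WaitKernel.create_wait) registers a mutation of its own, which
+ # together keep those buffers from being read, written, or reused in between.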
+ @classmethod + def create_inplace( # type: ignore[no-untyped-def] + cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs + ) -> None: + with V.graph.fake_mode: + ( + _example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + device = tensor_args[0].get_device() + packed = cls( + NoneLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + inps = pytree.tree_leaves(inputs) + packed.mutation_outputs.extend( + [MutationOutput(NoneLayout(device=device), buf, packed) for buf in inps] + ) + + # For inplace collective ops, the input is guaranteed to be alias of the returned value of op. + packed.alias_names.extend([inp.get_name() for inp in inps]) + if "out" in kwargs: + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device=device), kwargs["out"], packed) + ) + # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op. + packed.alias_names.append(kwargs["out"].get_name()) + + # NOTE: [Out-of-Place Collective Safety] + # Between the initiation and completion of an out-of-place collective: + # + # Input buffers: + # - Are subject to volatile reads + # - Can be read by another kernel + # - Must not be written to or reused by another kernel + # + # Output buffers: + # - Are subject to volatile writes + # - Must not be read, written to or reused by another kernel + # + # To ensure the safety of input buffers without sacrificing read + # availability, we add input buffers as read deps of wait_tensor kernels. + # + # To ensure the safety of output buffers, we model wait_tensor as a + # mutation to the output buffer. Note we also assumes the user program being + # correct and the output buffer is not consumed by kernels other than + # wait_tensor. + # + # TODO(yifu): add a pre-grad pass to validate the correctness of collective + # usage in the user program. 
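+ # create_out_of_place() below materializes the collective's result either as a
+ # MultiOutputLayout with one MultiOutput per returned tensor (list results) or
+ # as a single tensor layout, and conservatively marks outputs as unaligned when
+ # the example output is not known to be aligned.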
+ @classmethod + def create_out_of_place( # type: ignore[no-untyped-def] + cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs + ): + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + packed = cls( + MultiOutputLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + for buf, tensor in zip(packed.outputs, example_output): + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + tensor + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + example_output + ): + V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type] + packed.outputs = [packed] + return packed + + +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): # type: ignore[no-untyped-def] + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel, inp: TensorBox) -> None: # type: ignore[no-untyped-def] + with V.graph.fake_mode: + ( + _example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inp) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + packed = cls( + NoneLayout(device=inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device=inp.get_device()), inp, packed) + ) + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. 
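+ # Each volatile read is added as a StarDep so the scheduler keeps the
+ # collective's input buffers alive (and unwritten) until this wait_tensor runs.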
+ volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r = OrderedSet[sympy.Symbol]() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return OrderedSet() + + +def maybe_free_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_symbols(s) + elif isinstance(s, (tuple, list)): + r = OrderedSet[sympy.Symbol]() + for t in s: + r |= maybe_free_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_symbols(s) + else: + return OrderedSet() diff --git a/phivenv/Lib/site-packages/torch/_inductor/jagged_lowerings.py b/phivenv/Lib/site-packages/torch/_inductor/jagged_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0f1c04f42187f5ca77ce70055ad068e4da29a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/jagged_lowerings.py @@ -0,0 +1,268 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import sympy + +import torch + +from .ir import Pointwise, TensorBox +from .lowering import fallback_handler, is_integer_type, register_lowering +from .virtualized import ops + + +# pyre-ignore[2,3] +def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len): + # jagged_len + 1 is used as the upper bound, + # because the last sequence length may be zero + begin_idx = ops.indirect_indexing( + offsets_loader([batch_idx]), + jagged_len + 1, + ) + end_idx = offsets_loader([batch_idx + 1]) + jagged_idx = begin_idx + seq_idx + return jagged_idx, end_idx + + +def get_inverse_offsets( + offsets: TensorBox, + jagged_len: Union[int, sympy.Expr], + realize: bool = True, +) -> TensorBox: + """ + Returns "inverse_offsets" - the inverse of the offsets array. + offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). + inverse_offsets maps jagged index to batch index. + + e.g. for offsets [0, 3, 4, 9, 10] this will return + inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3] + + For the given offsets, the computed inverse_offsets are cached + on the first call and reused in the further calls. + """ + + if hasattr(offsets, "inverse_offsets"): + # inverse_offsets are already computed + # for these offsets: can reuse + return offsets.inverse_offsets + + # ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise + # kernels, i.e. we need to realize it before using. 
In other words, we need + # offsets to be in global memory so that we can binary search over the + # entire tensor + offsets.realize() + device: torch.device = offsets.get_device_or_error() + dtype: torch.dtype = offsets.get_dtype() + + # pyre-ignore[2,3] + def inner_fn(index): + idx = index[0] + bucket = ops.bucketize( + values=ops.index_expr(idx, dtype), + boundaries=( + offsets.get_name(), + offsets.get_size()[-1], + offsets.get_size()[0] * offsets.get_stride()[0], + offsets.get_stride()[-1], + ), + boundary_indices=0, + indexing_dtype=dtype, + right=True, + ) + # ops.bucketize above returns 1-based bucket indices, + # but we need 0-based, hence we subtract 1 from batch + return bucket - 1 + + inverse_offsets = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[jagged_len], + ) + + if realize: + # "freeze" the node so that it doesn't get inlined downstream. + inverse_offsets.realize() + + # cache inverse_offsets for further reuse + offsets.inverse_offsets = inverse_offsets # type: ignore[attr-defined] + + return inverse_offsets + + +def jagged_idx_to_dense_idx( + jagged_idx, # pyre-ignore[2] + inverse_offsets_loader, # pyre-ignore[2] + offsets_loader, # pyre-ignore[2] + batch_size: Union[int, sympy.Expr], + max_seq_len: Union[int, sympy.Expr], + offsets_dtype: torch.dtype, +) -> tuple[sympy.Expr, sympy.Expr]: + batch_idx = ops.indirect_indexing( + inverse_offsets_loader([jagged_idx]), + batch_size + 1, + ) + batch_start = offsets_loader([batch_idx]) + seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start + # check=False because there may be sequences longer than max_seq_len + seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False) + return batch_idx, seq_idx + + +def register_jagged_ops(): + # pyre-ignore[56] + @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) + def _jagged_to_padded_dense_forward( + jagged_values: TensorBox, + jagged_offsets: list[TensorBox], + max_lengths: list[int], # list of ints/SymInts + padding_value: float = 0.0, + ) -> TensorBox: + device = jagged_values.get_device_or_error() + dtype = jagged_values.get_dtype() + + jagged_values_size = jagged_values.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_values_size) != 2 + or len(jagged_offsets[0].get_size()) != 1 + or len(max_lengths) != len(jagged_offsets) + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler( + torch.ops.aten._jagged_to_padded_dense_forward.default, + add_to_fallback_set=False, + )( + jagged_values, + jagged_offsets, + max_lengths, + padding_value, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_len = offsets.get_size()[0] + offsets_dtype = offsets.get_dtype() + batch_size = offsets_len - 1 + max_seq_len = max_lengths[0] + embedding_len = jagged_values_size[1] + jagged_len = jagged_values_size[0] + + output_size = [batch_size, max_seq_len, embedding_len] + + values_loader = jagged_values.make_loader() + offsets_loader = offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # dense tensor size: [B, N, D] + batch_idx, seq_idx, emb_idx = index + jagged_idx, end_idx = dense_idx_to_jagged_idx( + batch_idx=batch_idx, + seq_idx=seq_idx, + offsets_loader=offsets_loader, + jagged_len=jagged_len, + ) + return ops.masked( + ops.lt( + ops.index_expr(jagged_idx, offsets_dtype), + end_idx, + ), + lambda: values_loader([jagged_idx, emb_idx]), + 
padding_value, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + def _dense_to_jagged_forward_impl( + fallback_op, # pyre-ignore[2] + dense: TensorBox, + jagged_offsets: list[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + device = dense.get_device_or_error() + dtype = dense.get_dtype() + + dense_size = dense.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_offsets[0].get_size()) != 1 + or len(dense_size) != 3 + or jagged_len is None + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler(fallback_op, add_to_fallback_set=False)( + dense, + jagged_offsets, + jagged_len, + ) + + offsets: TensorBox = jagged_offsets[0] + offsets_dtype = offsets.get_dtype() + batch_size = dense_size[0] + max_seq_len = dense_size[1] + embedding_len = dense_size[-1] + + output_size = [jagged_len, embedding_len] + + dense_loader = dense.make_loader() + offsets_loader = offsets.make_loader() + + inverse_offsets = get_inverse_offsets( + offsets=offsets, + jagged_len=jagged_len, + ) + inverse_offsets_loader = inverse_offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # jagged tensor size: [sum_B(N_B), D] + jagged_idx, emb_idx = index + batch_idx, seq_idx = jagged_idx_to_dense_idx( + jagged_idx=jagged_idx, + offsets_loader=offsets_loader, + inverse_offsets_loader=inverse_offsets_loader, + batch_size=batch_size, + max_seq_len=max_seq_len, + offsets_dtype=offsets_dtype, + ) + return ops.masked( + ops.lt( + ops.index_expr(seq_idx, offsets_dtype), + ops.index_expr(max_seq_len, offsets_dtype), + ), + lambda: dense_loader([batch_idx, seq_idx, emb_idx]), + 0.0, # jagged sequence longer than max_seq_len + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + # pyre-ignore[56] + @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward) + def _dense_to_jagged_forward( + dense: TensorBox, + jagged_offsets: list[TensorBox], + jagged_len: Optional[int] = None, + ) -> TensorBox: + return _dense_to_jagged_forward_impl( + fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, + dense=dense, + jagged_offsets=jagged_offsets, + jagged_len=jagged_len, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/loop_body.py b/phivenv/Lib/site-packages/torch/_inductor/loop_body.py new file mode 100644 index 0000000000000000000000000000000000000000..58a23c1a57b8e706fd246db0daf454e5bc9d580b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/loop_body.py @@ -0,0 +1,702 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import functools +import itertools +import re +from enum import auto, Enum +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, TypeVar + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.fx.proxy import Scope, TracerBase +from torch.utils._sympy.symbol import SymT + +from . 
import config, dependencies +from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler +from .utils import ( + cache_on_self, + reduction_num_outputs, + sympy_index_symbol_with_prefix, + sympy_subs, +) +from .virtualized import ops, V + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +T = TypeVar("T") + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.cache + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +# We don't need the nn.Module and constant handling in Tracer +class LightTracer(TracerBase): + def __init__(self): + super().__init__() + self.graph = torch.fx.Graph(tracer_cls=self.__class__) # type: ignore[arg-type] + self.scope = Scope("", None) + self.module_stack = {} # type: ignore[assignment] + self.node_name_to_scope = {} + + +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. 
+ """ + + indexing_exprs: dict[str, sympy.Expr] + indexing_exprs_name: dict[sympy.Expr, str] + submodules: dict[str, Any] + subblocks: dict[str, LoopBodyBlock] + indirect_vars: list[sympy.Symbol] + indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + memory_usage: dict[MemoryUsageType, list[MemoryEntry]] + op_counts: collections.Counter[str] + + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} + self.op_counts = collections.Counter() + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage + self.op_counts = other.op_counts + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def has_op(self, name: str): + return self.op_counts.get(name, 0) > 0 + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. 
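As a plain-sympy illustration of the rewrite `merge_loops` is after (sizes and symbol names below are made up, and the real pass goes through `_simplify_loops` rather than direct substitution): two adjacent dimensions can be fused into one loop variable whenever every indexing expression already walks them contiguously.

import sympy

x0, x1 = sympy.symbols("x0 x1", integer=True, nonnegative=True)
sizes = (32, 64)
index = x0 * 64 + x1  # contiguous 2-D indexing expression over `sizes`

# a single fused loop variable z covering 32 * 64 iterations reproduces it
for z in (0, 5, 63, 64, 2047):
    assert index.subs({x0: z // 64, x1: z % 64}) == z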
+ ( + ( + iter_vars, + reduce_vars, + ), + var_ranges, + ) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t") + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="p" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="t", # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="p", # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_all_read_expr(self, buffer_name): + # reversed to match old behavior + out = [] + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return 
out + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + + def get_all_write_expr(self, buffer_name): + out = [] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return out + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. + """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + + __repr__ = debug_str + + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all(v not in self.var_ranges for v in index), ( + f"{self.var_ranges=}, {indices=}" + ) + replacements = dict(zip(self.var_ranges.keys(), index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + def __call__(self, *indices): + self.indexing = self.indexing_from_args(indices) + result = self.root_block() + self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def 
bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, however in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. + """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): + self.body = body + + tracer = LightTracer() + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + + handler: Any = CountOps( + CaptureIndexing(proxy_ops, body, tracer), + body.op_counts, + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy + + +class CountOps(DefaultHandler): + def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]): + self._inner = inner + self._counts = counts + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + self._counts[name] += 1 + return getattr(self._inner, name)(*args, **kwargs) + + +class CaptureIndexing(WrapperHandler): + name = "CaptureIndexing" + + def __init__( + self, + inner: OpsHandler[Any], + body: LoopBody, + tracer: LightTracer, + ): + super().__init__(inner) + self.body = body + self.tracer = tracer + + def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any): + return self.tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + def _simplify(self, expr: sympy.Expr) -> sympy.Expr: + return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges) + + def load(self, name: str, index: sympy.Expr): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + self.body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, 
value, mode=None): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + num_outputs = reduction_num_outputs(reduction_type) + if num_outputs > 1: + return tuple(result[i] for i in range(num_outputs)) + return result + + def index_expr(self, index, dtype): + index = self._simplify(index) + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = self._add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + boundaries = ( + boundaries[0], + self._add_index( + boundaries[1], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[2], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[3], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + ) + if sorter is not None: + sorter = ( + sorter[0], + self._add_index( + sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] + ), + ) + + return self._inner.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return self.tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + def scan( + self, + dtype_proxy, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = self.tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + def indirect_indexing(self, index_proxy, size, check=True, 
wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg) + self.tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + def output(self, *result): + self.tracer.create_proxy("output", "output", result, {}) diff --git a/phivenv/Lib/site-packages/torch/_inductor/lowering.py b/phivenv/Lib/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..cead87004b1257d41e64132531ef6d98a34cc72c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,7151 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +import warnings +from collections import defaultdict +from collections.abc import Iterable, Sequence +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec +from unittest.mock import patch + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters +from torch._higher_order_ops.associative_scan import associative_scan_op +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._prims_common import ( + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing + +from .._dynamo.utils import import_submodule +from . 
import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + DtypeView, + ExpandView, + IndexingConstant, + IRNode, + is_triton, + OnlineSoftmaxReduction, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_gpu, + is_pointwise_use, + is_view, + needs_fallback_due_to_atomic_add_limitations, + pad_listlike, + register_op_dtype_propagation_rules, + register_op_requires_libdevice_fp64, + sympy_product, + use_scatter_fallback, +) +from .virtualized import ops, V + + +if TYPE_CHECKING: + from .ops_handler import ReductionType + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = OrderedSet( + [ + "torchvision::roi_align", + "aten::index_add", + ] +) + +log = logging.getLogger(__name__) +lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} +# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints +_maybe_layout_constraints: dict[ + torch._ops.OpOverload, Optional[Callable[..., Any]] +] = {} +fallbacks = OrderedSet[torch._ops.OpOverload]() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs = OrderedSet[torch._ops.OpOverload]() +foreach_ops = OrderedSet[torch._ops.OpOverload]( + [torch._higher_order_ops._foreach_map] # type: ignore[list-item] +) +# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload +# so why is it in foreach_ops? +inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]() +inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +quantized_decomposed = torch.ops.quantized_decomposed + + +def cur_node_has_non_foreach_users(): + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + return True + + return False + + +# group by device, whether any of the inputs are dynamic +# note arg_pairs may or may not be a pair +# foreach_map for example just passes output buffers here +def group_foreach_args(arg_pairs: Iterable[Union[tuple[Any, Any], Any]]): + out = defaultdict(list) + unpack_args = False + for i, args in enumerate(arg_pairs): + if not isinstance(args, Iterable): + unpack_args = True + args = (args,) + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert device is not None, "foreach op should have at least one tensor arg" + if unpack_args: + (args,) = args + out[(device, use_foreach)].append((i, args)) + return out + + +def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: + """Get layout constraints. Returns None if there are no layout constraints.""" + if not isinstance(fn, torch._ops.OpOverload): + # Only OpOverloads have layout constraints. 
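The tag-based constraints that `_maybe_layout_constraints` lazily records come from operator tags declared alongside an op's schema. A hedged sketch of how a custom op might opt in (the `demo::my_op` namespace and schema are made up for illustration):

import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical test namespace
lib.define(
    "my_op(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),  # mapped to constrain_to_fx_strides below
)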
+ return None + if fn in _maybe_layout_constraints: + return _maybe_layout_constraints[fn] + return None + + +def tag_to_layout_constraint(tag): + if tag == torch._C.Tag.needs_exact_strides: + return constrain_to_fake_tensors + if tag == torch._C.Tag.needs_contiguous_strides: + return require_contiguous_strides + if tag == torch._C.Tag.needs_fixed_stride_order: + return constrain_to_fx_strides + if tag == torch._C.Tag.flexible_layout: + return None + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + +def assert_nyi(cond, msg): + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs(fn): + if isinstance(fn, (list, set, tuple, OrderedSet)): # noqa: set_linter + return [add_needs_realized_inputs(x) for x in fn] + needs_realized_inputs.add(fn) + if isinstance(fn, torch._ops.OpOverloadPacket): + needs_realized_inputs.update( + getattr(fn, overload) for overload in fn.overloads() + ) + + +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + _maybe_layout_constraints[getattr(fn, overload)] = constraint + else: + _maybe_layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.as_strided_copy, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + aten.max_pool2d_with_indices, + aten.max_pool3d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? 
+ # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: int): + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x): + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x): + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, (Number, sympy.Basic)): + return inp + else: + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def in_namespace(op, namespace): + if isinstance(op, torch._ops.OpOverloadPacket): + return namespace in op._qualified_op_name + elif isinstance(op, torch._ops.OpOverload): + return namespace in op.name() + return False + + +def transform_args( + args: list[Any], + kwargs: dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, +) -> tuple[list[Any], dict[str, Any]]: + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return args, kwargs + + if type_promotion_kind or convert_input_to_bool: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME this is a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") + ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) + dtype = get_promoted_dtype( + *promoting_args, + type_promotion_kind=type_promotion_kind, # type: ignore[arg-type] + ) + + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + + # sometimes args are an immutable list so we can't mutate them + def promote(arg): + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(value=arg.value, dtype=dtype, device=device) + else: + return arg + + args = [promote(a) for a in args] + kwargs = {k: promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + size = 
list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): + args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]): + kwargs[k] = x + + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) + + return args, kwargs + + +def _register_foreach_lowering(aten_fn, decomp_fn): + """ + Add a foreach lowering to lowerings dict. + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, + decomp_fn, + broadcast, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool, + lowering_dict, +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: list[Any] = list(args) + kwargs: dict[str, Any] = dict(kwargs) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = list(args[0]) + + if not all( + (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + # explicitly assert for "out=" ops for better error messages + assert not any(x == "out" for x in kwargs.keys()), ( + "out= ops aren't yet supported" + ) + + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowering_dict.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind: Optional[ + ELEMENTWISE_TYPE_PROMOTION_KIND + ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + lowering_dict=lowerings, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + lowering_dict=lowering_dict, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. 
+ """ + output = [] + for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): + if V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(y, 1), size_oblivious=True + ): + output.append(x) + elif V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(x, 1), size_oblivious=True + ): + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert override_return_dtype is None or type_promotion_kind is None, ( + "only one of override_return_dtype or type_promotion_kind may be given" + ) + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, type_promotion_kind=type_promotion_kind + ) + + def const_func(x): + if isinstance(x, sympy.Basic): + return ir.IndexingConstant( + index=x, dtype=dtype, device=decode_device(None) + ) + else: + return ir.Constant(value=x, dtype=dtype, device=decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant( + value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error() + ), + list(ex.get_size()), + ) + ) + elif isinstance(x, sympy.Basic): + out.append( + ExpandView.create( + IndexingConstant( + index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error() + ), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: TensorBox, alpha=None): + if triton_fallback is not None and any( + isinstance(inp, IRNode) and is_triton(inp) for inp in inputs + ): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + inputs = list(inputs) + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. 
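In eager terms, the emulation set up below reproduces rounding that fused kernels would otherwise skip: low-precision pointwise math is accumulated in fp32, so the result is cast back down and up again to match eager's intermediate bf16/fp16 rounding. A rough eager-mode sketch of that round trip (illustrative only):

import torch

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)

acc = a.float() * b.float()        # fused kernels compute in fp32
emulated = acc.to(torch.bfloat16)  # downcast to re-introduce bf16 rounding...
restored = emulated.float()        # ...before any later fp32 op consumes it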
+ low_pr_fp = (torch.bfloat16, torch.float16) + emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + and dtype in low_pr_fp + ) + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for inp_index, load in enumerate(loaders): + out = load(index) + inp_dtype = inputs[inp_index].get_dtype() + if emulate_precision_casts and inp_dtype in low_pr_fp: + downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False) + out = ops.to_dtype(downcast, inp_dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_precision_casts: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. + downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + if is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + device = override_device or device + + return Pointwise.create( + device=device, # type: ignore[arg-type] + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def inner(*inputs: list[list[TensorBox]], alpha=1): + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + or cur_node_has_non_foreach_users() + ) + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert a_list_input is not None, ( + "at least one input must be a list to a foreach op" + ) + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_foreach_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + operation_list: list[str] = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if ( + V.graph.has_feature(device, BackendFeature.FOREACH) + and use_foreach + and realize_outputs + ): + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + return inner + + +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None) +def _foreach_map(subgraph, *args, **kwargs): + """ + This lowers an invocation of foreach_map + The way this works is that an arbitrary N-arg func is provided by the user, looped over by the + polyfill 
with the same semantics as a foreach op (a loop applying an n-ary function to n args) + and then traced into a subgraph by dynamo. + This code allows us to inline the subgraph into the main graph lowering using the PontwiseSubgraphLowering. + The graph outputs represent the vertically fused sequence of ops, and then register_operation_list + below registers the buffers as horizontally fuseable in the scheduler. + """ + from .subgraph_lowering import PointwiseSubgraphLowering + + inputs = args + + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*inputs) + + sub_outputs = pw_subgraph.graph_outputs + # group outputs by device and register as foreach + assert sub_outputs # mypy lol + groups = group_foreach_args(sub_outputs) + + outputs = [None] * len(sub_outputs) + for (device, use_foreach), group in groups.items(): + operation_list: list[str] = [] + for ( + output_ind, + output, + ) in group: + outputs[output_ind] = output + + if V.graph.has_feature(device, BackendFeature.FOREACH) and use_foreach: + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. + dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + # fallback to aten eager implementation for differing bitwidths + return fallback_handler(aten.view.dtype)(x, dtype) + else: + return TensorBox(DtypeView.create(x, dtype)) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device, non_blocking=False): + return to_device(x, device, copy=True, non_blocking=non_blocking) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + 
override_fn_when_input_bool=None, + allow_alpha=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = ops_wrapper(name) + + register_op_dtype_propagation_rules( + name, type_promotion_kind, override_return_dtype + ) + + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] # type: ignore[index] + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] # type: ignore[index] + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] + dtype = get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: list[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + if len(sizes) != len(target) or any( + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + ) + 
for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not ( + d in dims + and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) + ): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint_or_throw( + sympy_product(x.get_size()) + ) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. 
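The reuse factor passed to `mark_reuse` a few lines below is simply the ratio of output to input elements; with concrete, made-up sizes:

import math

input_size, expanded_size = [1, 64], [1024, 64]
reuse = math.prod(expanded_size) // math.prod(input_size)
assert reuse == 1024  # each input element is re-read 1024 times after the expand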
However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint_or_throw(sympy_product(sizes)) + // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.S.Zero + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size): + old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input but skip for unbacked symints since it'll + # choke on the size hint. 
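# Worked example of the reuse heuristic applied just below (values are
# illustrative only): repeating a (2, 3) tensor with repeats=(2, 2) gives a
# (4, 6) output, so each input element is re-read
#   size_hint(4 * 6) // size_hint(2 * 3) == 24 // 6 == 4
# times, and mark_reuse(4) can cause the input to be realized instead of
# being recomputed for every read.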
+ x.mark_reuse( + V.graph.sizevars.size_hint_or_throw(sympy_product(new_size)) + // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, stride, storage_offset=None): + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # as_strided ignores views + x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + old_layout.device, + old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = 
masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> TensorBox: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None +) +def 
quantized_decomposed_quantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> TensorBox: + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> TensorBox: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.round(input * ops.reciprocal(_scale)) + _zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + 
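# Worked example of the affine quantization formula above (all values are
# hypothetical): with scale=0.1, zero_point=10, quant_min=0, quant_max=255
# and an input element x = 0.74:
#   round(0.74 * reciprocal(0.1)) + 10  ->  round(7.4) + 10  ->  17
#   clamp(17, 0, 255)                   ->  17
# which the return below then casts to the target integer dtype.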
return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> TensorBox: + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + cpu_device = inputs[0].get_device().type == "cpu" + if cpu_device and all( + input.get_dtype() in [torch.int8, torch.uint8] for input in inputs + ): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. + for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # fusing reducutions into computed concat buffer can cause regressions. 
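# Illustrative example (hypothetical user code): for
#   torch.cat([x.sum(-1, keepdim=True), y], dim=-1)
# the first input reads from a reduction, so `fusable_reduction` below is
# True and the horizontal-fuse path later in this function skips
# pointwise_cat in favor of the copy-based ConcatKernel.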
+ fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + if config.force_pointwise_cat: + return pointwise_cat(inputs, dim) + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. + if cpu_device: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount().num_ops + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + def additional_pointwise_ops(op: torch._ops.OpOverload): + return op in (aten.cat.default, aten.constant_pad_nd.default) + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all( + is_pointwise_use(use, additional_pointwise_ops) + for use in V.current_node.users + ) + # fuse in case we will be used in a pointwise node, and there are any inputs we + # we can prevent materialization of. + fuse_pointwise_use = ( + any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses + ) + + # horizontal fuse in case all inputs will require a copy kernel anyway. 
+ # only horizontally fuse pointwise kernels + horizontal_fuse_cat = all( + should_lower_cat_input(inp) for inp in inputs + ) and not any(can_fuse_reduction(t) for t in inputs) + if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1] + offset, original_shape[dim2] + ), + 0, # type: ignore[arg-type] + ) + else: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1], original_shape[dim2] - offset + ), + 0, # type: ignore[arg-type] + ) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + sizes_ = sizes + + # If sizes is an integer (or a SymInt), we turn it into a list of sizes + # by computing what the actual size of each chunk should be. + if not isinstance(sizes, (list, tuple)): + x_size = x.get_size()[dim] + chunks = V.graph.sizevars.evaluate_static_shape( + FloorDiv(x_size + sizes - 1, sizes) + ) + sizes_ = [sizes] * chunks + # The last chunk might have a smaller size than the rest. + sizes_[-1] = x_size - (chunks - 1) * sizes + + # From this point, we assume that the sum of the sizes of all chunks + # equals the size of the base tensor. + result = [] + start = 0 + for size in sizes_: + end = start + size + # No need for clamping here, since we compute the exact + # start and end values. 
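# Worked example (illustrative): splitting a dimension of size 10 with an
# integer sizes=3 gives chunks = (10 + 3 - 1) // 3 == 4 and per-chunk sizes
# [3, 3, 3, 1], so the slices appended below cover [0:3), [3:6), [6:9), [9:10).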
+ result.append(slice_(x, dim, start, end, clamp=False)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) + result = [select(x, dim, i) for i in range(x_size)] + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint_or_throw(dim_size) > 0: + x.mark_reuse( + sizevars.size_hint_or_throw(CeilDiv(new_dim_size * size, dim_size)) + ) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.S.One) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 + a = slice_(x, dim, 0, new_len) + b = slice_(x, dim, new_len, new_len * 2) + return mul(a, sigmoid(b)) + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + return pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + # This lets us detect that a lowering is a fallback handler. + handler._is_fallback_handler = True # type: ignore[attr-defined] + + return handler + + +@functools.cache +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. 
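# For example (illustrative): complex tensors are generally routed to the
# eager fallback by fallback_node_due_to_unsupported_type below (with a
# one-time performance warning), apart from a few view-style exceptions such
# as aten.view_as_complex; meta tensors and most float8_e8m0fnu uses are
# likewise treated as unsupported.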
+def unsupported_input_tensor(t: torch.Tensor, node=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + _warn_complex_not_supported() + return True + + if t.is_meta: + return True + + if t.dtype == torch.float8_e8m0fnu: + if not node: + return True + + # allow bitcast, views, memory movement, but not arithmetic + # TODO: delete once triton adds native support + return not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target + in ( + aten.view.dtype, + aten.cat.default, + aten.clone.default, + aten._scaled_mm.default, + ) + or (isinstance(node.target, torch._ops.OpOverload) and is_view(node.target)) + ) + + return False + + +def unsupported_output_tensor(t: torch.Tensor, node=None): + "Do not support writing tensor but can read from it" + supported_complex_views = ( + aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ) + if node is not None and node.target in supported_complex_views and t.is_complex(): + return False + if unsupported_input_tensor(t, node): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + if node.op == "placeholder": + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. + if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(inp_out_node, is_output): + if not isinstance(inp_out_node, torch.fx.Node): + return False + + if "val" not in inp_out_node.meta: + return False + + for meta in pytree.tree_leaves(inp_out_node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, node): + return True + else: + if unsupported_input_tensor(meta, node): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, is_output=False): + return True + + return check_skip_condition(node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False): + assert op not in decompositions or override_decomp, ( + f"both a fallback and a decomp for same op: {op}" + ) + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + and not override_decomp + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." + " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." 
+ " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. + # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) + x.realize() + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + 
+def warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator", None) is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device_or_error() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device_or_error() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ops.index_expr(low, torch.int64), + ops.index_expr(high, torch.int64), + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +def 
_boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]: + return ( + tb.get_name(), + tb.get_size()[-1], + tb.get_size()[0] * tb.get_stride()[0], + tb.get_stride()[-1], + ) + + +def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]: + return tb.get_name(), tb.get_stride()[-1] + + +@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None) +def searchsorted( + sorted_sequence: TensorBox, + self: TensorBox, + *, + out_int32: bool = False, + right: bool = False, + side: Optional[str] = None, + sorter: Optional[TensorBox] = None, +) -> TensorBox: + validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731 + tb, BackendFeature.BUCKETIZE + ) + if ( + not validate_bucketize(sorted_sequence) + or not validate_bucketize(self) + or (sorter is not None and not validate_bucketize(sorter)) + ): + return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)( + sorted_sequence, + self, + out_int32=out_int32, + right=right, + side=side, + sorter=sorter, + ) + + # If side is present, override the value of right if needed. This assumes that + # validation of the two options being non-contradictory is already done by the + # searchsorted meta-function. + if side is not None and side == "right": + right = True + + index_dtype = torch.int32 if out_int32 else torch.int64 + values_loader = self.make_loader() + + # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to + # realize it into global memory; or in other words, we can't guarantee that + # sorted_sequence.get_name() (used below) will exist unless we call + # sorted_sequence.realize(). + sorted_sequence.realize() + + if sorter is not None: + sorter.realize() + + if len(sorted_sequence.get_size()) == 1: + + def inner_fn(idx): + val = values_loader(idx) + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + 0, + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else 0, + ) + + else: + + def inner_fn(idx): + val = values_loader(idx) + + # Get index to the beginning of the sorted sequence within a flattened + # version of the array. 
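# Worked example (illustrative): for a sorted_sequence of shape (B, N) with
# contiguous strides (N, 1), the value at index idx = (b, j) must be searched
# against row b, whose first element sits at flat offset
#   sum(stride * i for stride, i in zip((N,), (b,))) == b * N
# which is exactly the expression built by get_flattened_index below.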
+ def get_flattened_index(tb: TensorBox): + strides = tb.get_stride() + return ops.index_expr( + functools.reduce( + operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1])) + ), + index_dtype, + ) + + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + get_flattened_index(sorted_sequence), + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else get_flattened_index(sorter), + ) + + device = self.get_device() + result = Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=self.shape, + ) + # see [NOTE: inductor bucketize realize] + result.realize() + + return result + + +@register_lowering( + aten.bucketize, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH +) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not ( + V.graph.has_feature(input, BackendFeature.BUCKETIZE) + and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) + ): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist unless + # we call boundaries.realize(). + boundaries.realize() + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + _boundaries_helper(boundaries), + 0, + index_dtype, + right, + ) + + return indices + + result = Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + # [NOTE: inductor bucketize realize] + # bucketize_binary_search is relatively expensive, so we don't want to re-compute + # it unnecessarily. If we run bucketize() and then broadcast the result, we don't + # want this to be fused into a large number of duplicate bucketize() computations + # for each of the elements in the result. + # + # If no broadcasting occurs, fusions can still occur in scheduler.py + result.realize() + + return result + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous_strides(_, *args, **kwargs): + # TODO: combine this with require_contiguous after + # https://github.com/pytorch/pytorch/pull/148235 lands. 
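# Illustrative usage of these layout-constraint helpers (the specific op is
# just an example taken from the fallbacks registered later in this file):
#   make_fallback(aten._cudnn_rnn, require_dense)
# When the fallback kernel is lowered, each ir.IRNode in its args/kwargs is
# rewritten through the corresponding ir.ExternKernel.require_* call by
# pytree.tree_map_only, as in the body below.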
+ args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fake_tensor(arg, fake_arg): + if isinstance(arg, ir.IRNode): + meta_stride_expr = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride() + ] + return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr) + if isinstance(arg, dict): + return { + key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg.keys() + } + elif isinstance(arg, (tuple, list)): + return type(arg)( + constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg) + ) + return arg + + +def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs): + args = tuple( + constrain_to_fake_tensor(arg, fake_arg) + for arg, fake_arg in zip(args, fake_args) + ) + kwargs = {k: constrain_to_fake_tensor(v, fake_kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order( + fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env + ) + return ir.ExternKernel.require_stride_order(arg, stride_order) + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(idx, arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + meta_stride_expr = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride() + ] + + stride_order = ir.get_stride_order(meta_val.stride()) + + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + if ( + fx_node.target + == aten._scaled_dot_product_efficient_attention_backward.default + and idx in (0, 5) + ): + assert len(stride_order) == 4 + # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default + # are for out and gradient_out. They have to be in + # (3, 1, 2, 0) stride order. Otherwise the kernel will crash. + # Check https://github.com/pytorch/pytorch/issues/138772 + stride_order = (3, 1, 2, 0) + + if not meta_val.is_cuda: + return ir.ExternKernel.require_stride_order(arg, stride_order) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + # effn_attn_fwd does requires dense last dim, not just alignment + effn_attn_fwd_bias = ( + fx_node.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + and idx == 3 + ) + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if effn_attn_fwd_bias: + out_size = list(arg.get_size()) + + expanded_dims = [] + # We require a dense last dimension, but the other strides + # can be expanded, which results in a smaller tensor + maybe_stride = arg.maybe_get_stride() + for i in range(len(arg.get_size()) - 1): + if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or ( + maybe_stride is not None + and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0) + ): + expanded_dims.append(i) + + # Now, pad strides to alignment + out_strides = [-1] * len(out_size) + out_strides[-1] = 1 + stride = 1 + for i in range(len(out_size) - 2, -1, -1): + if out_strides[i + 1] != 0: + stride = stride * out_size[i + 1] + + # the expanded dims still need to be aligned, if they are, + # we can make them expanded by setting the stride equal to 0 + if i in expanded_dims: + if V.graph.sizevars.statically_known_equals( + out_strides[i + 1] % ALIGNMENT, 0 + ): + out_strides[i] = 0 + continue + + if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0): + stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT + + out_strides[i] = stride + + return ir.ExternKernel.require_exact_strides(arg, out_strides) + + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + def is_aligned(x): + return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(idx, arg, fx_arg) + for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)) + ) + kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten._scaled_dot_product_attention_math_for_mps) # @malfet + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? 
+make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp) +if torch.xpu.is_available(): + make_fallback( + aten.embedding_dense_backward, warn=False + ) # (XPU-only and faster than decomp) + + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) + +# 2) Medium +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten._addmm_activation, warn=False) + +make_fallback(aten._grouped_mm, require_dense) + +# Need templated kernel. Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) +# see: https://github.com/pytorch/pytorch/pull/121354 +make_fallback(aten.resize_) +make_fallback(aten.resize_as_) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) +make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) 
+make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) +make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state) + + +# Implemented / Half implemented +# Scans. Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_fused_attention_overrideable.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_fused_attention_overrideable_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) + +# index_reduce requires fallback when use_scatter_fallback(...) returns True +make_fallback(aten.index_reduce) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. 
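# Illustrative eager-level behavior being matched (values hypothetical):
#   dst = torch.zeros(4, dtype=torch.float16)
#   src = torch.randn(4, dtype=torch.float32)
#   dst.copy_(src)   # dst stays float16; src is converted, dst is not promoted
# The lowering below mirrors this by casting/moving `src` to self's dtype and
# device instead of applying elementwise type promotion.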
+@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + x = src + if self.get_device() != src.get_device(): + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(data=x, layout=layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + 
ops.constant(0, torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: list[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) + return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) == 1, unbacked_bindings + # NB: Have to be very careful here. V.graph.current_node.meta["val"] + # seemingly also contains a symbol which you want to do binding for, + # but it actually isn't. 
In particular, if we have later performed + # a deferred runtime assert saying that u0 == s0, you will actually + # see s0 from expr! This is bad because we need to actually generate + # the assert that says u0 == s0, so we need to know where to get u0 + # from (this call). In particular, we must use unbacked_bindings, which + # is guaranteed to have the original, unreplaced symbol in question. + # + # NB2: Another thing we have to be very careful about are symbol bindings + # that require nontrivial refinement, e.g., when you have a binding site + # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division + # in order to appropriately bind u0. This is communicated via the keypath + # in unbacked_bindings, and we need to hold onto it in order to generate + # code appropriately for this case. + binding_sym, keypath = next(iter(unbacked_bindings.items())) + buffer = ir.DynamicScalar(binding_sym, keypath, data) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + # NB: the replaced expr is OK to use directly downstream, we want + # simplifications in this case! + val = V.graph.current_node.meta["val"] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return val.node.expr + else: + return sympy.sympify(val) + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + # NB: These will be handled at codegen time + # Not sure if we are guaranteed to be able to serve out truth from the + # deferred_runtime_asserts, TODO: try this assert out + # See [NOTE] Codegen runtime asserts in Inductor + # assert bool(data.scalar), data + return None + + +@register_lowering(aten._assert_tensor_metadata) +def _assert_tensor_metadata( + a, size=None, stride=None, dtype=None, *, device=None, layout=None +): + return None + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +@register_lowering(aten.full_like, type_promotion_kind=None) +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! 
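+        # For example (illustrative): torch.full((2, 3), 5) reaches this code via
+        # full() further below as tensor_constructor(5)((2, 3), ...), and _full
+        # then builds a Pointwise whose inner_fn returns ops.constant(5, dtype)
+        # over ranges [2, 3]; nothing is materialized until the node is realized.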
+ for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). + """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, decode_device(device), dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, + None, + dtype=dtype, + layout=layout, + device=decode_device(device), + pin_memory=pin_memory, + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + device = decode_device(device) + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size)) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, 
dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, + stride, + dtype=dtype, + layout=layout, + device=decode_device(device), + pin_memory=pin_memory, + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint_or_throw(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. Return an empty array with the same shape + return new_empty(x, index.get_size()) + + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + if offset: + x = expand(x, [1]) + size = [1] + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + if sparse: + return fallback_handler(aten.embedding.default)( + weight, indices, padding_idx, scale_grad_by_freq, sparse + ) + + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), ( + f"indices must be int64, byte or bool. 
Got {[i.get_dtype() for i in indices if i is not None]}" + ) + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, + wrap_neg=True, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. + non_consecutive_tensors = False + for previous, current in zip(tensor_indices, tensor_indices[1:]): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + wrap_neg=wrap_neg, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + output_size, inner_fn, _ = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check, wrap_neg=True): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) 
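+    # Worked example of the shape rule documented in index_output_size_and_inner_fn
+    # above (illustrative only, standard advanced-indexing semantics assumed):
+    #   a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
+    #   x = torch.tensor([1, 2])
+    #   a[:, x, x, :, :].shape == (3, 2, 6, 7)  # consecutive: index dims stay in place
+    #   a[:, x, :, x, :].shape == (2, 3, 5, 7)  # non-consecutive: pulled to the front
+    # output_size below is assembled to follow exactly this rule.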
+ + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if check and 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, index_inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + wrap_neg=wrap_neg, + ) + + def inner_fn(idx): + return x_loader(index_inner_fn(idx)) + + return output_size, inner_fn, index_inner_fn + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put, type_promotion_kind=None) +def index_put(x, indices, values, accumulate=False): + return index_put_impl_( + clone(x), indices, values, accumulate, check=True, may_realize=False + ) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_( + clone(x), indices, values, accumulate, check=False, may_realize=False + ) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_( + self, indices, values, accumulate, check=True, may_realize=True + ) + + +@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_( + self, indices, values, accumulate, check=False, may_realize=True + ) + + +def index_put_impl_(self, indices, values, accumulate, check, may_realize=False): + if may_realize: + + def try_get_name(x): + if isinstance(x, ir.TensorBox): + x = x.data + if isinstance(x, ir.BaseView): + x = x.unwrap_view() + if isinstance(x, ir.StorageBox): + x = x.data + return x.get_name() if isinstance(x, ir.Buffer) else None + + def indice_slice_from_randperm(indice): + # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660 + # For this specific pattern, indices is unique as coming from torch.randperm. + # However, as the content of the indices is unknown, we have to check this specific pattern. 
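+            # The pattern being matched looks roughly like (illustrative only):
+            #     perm = torch.randperm(n, device=x.device)
+            #     x.index_put_((perm[:k],), values)
+            # randperm produces unique indices, so the self/values overlap
+            # concern handled below does not force `values` to be materialized.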
+ if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView): + indice = indice.data.unwrap_view() + return ( + isinstance(indice, ir.StorageBox) + and isinstance(indice.data, ir.ExternKernel) + and getattr(indice.data, "fx_node", None) + and indice.data.fx_node.target == torch.ops.aten.randperm.default + ) + return False + + if try_get_name(self) in values.get_read_names() and not all( + indice_slice_from_randperm(indice) for indice in indices + ): + # Fix issue: https://github.com/pytorch/pytorch/issues/138908 + # When self and values have memory overlapping, indices may + # contain duplicate values, potentially causing incorrect results since + # the load of `values` might contain modified value from the store of `self`. + # To address this, store values in a temporary buffer in such cases. + values.realize() + + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in (torch.bool, torch.uint8) + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in (torch.bool, torch.uint8): + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + 
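+# A minimal, self-contained sketch (not part of the lowering machinery; the
+# helper name below is purely illustrative and never called) of the eager-mode
+# equivalence that the masked-fill dispatch in index_put_impl_ above relies on:
+# a single boolean index with a single fill value behaves like torch.where.
+def _example_bool_index_put_as_masked_fill():
+    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
+    mask = x > 5                    # single boolean index
+    value = torch.tensor(-1.0)      # single value -> masked-fill style dispatch
+    assert torch.equal(x.index_put((mask,), value), torch.where(mask, value, x))
+    # With accumulate=True the value is added first, then selected by the mask.
+    assert torch.equal(
+        x.index_put((mask,), value, accumulate=True),
+        torch.where(mask, x + value, x),
+    )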
+fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _, _unsafe_index_fn = index_impl_helper( + self, indices, check=False, wrap_neg=False + ) + mask_loader = mask.make_loader() + self_loader = self.make_loader() + + def inner_fn(idx): + if mask.dtype != torch.bool: + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + else: + mask_val = mask_loader(idx) + return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. currently only triton + # supports masked stores and cpp backend does not. + return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + op_overload: torch._ops.OpOverload, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + src_is_tensor = isinstance(src, TensorBox) + if use_scatter_fallback( + op_overload, + reduce, + self.get_dtype(), + cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)), + src.get_device().type if src_is_tensor else "not impl", + src_is_tensor, + ): + ir.ScatterFallback( + op_overload, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in (None, "add", "multiply") + if reduce is None: + op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + fallback_result = scatter_fallback( + op_overload, self, dim, index, src, reduce=reduce + ) + if fallback_result is not None: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): + return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return 
scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in (None, "sum", "prod", "mean", "amax", "amin") + assert ( + len(aten.scatter_reduce_.overloads()) == 1 + and "two" in aten.scatter_reduce_.overloads() + ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" + + if isinstance(src, Number): + src = full_like(self, src) + + fallback_result = scatter_fallback( + aten.scatter_reduce_.two, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + if index.get_numel() == 0: + return self + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False + ) + return indirect_idx + + def fn(idx): + if src_loader: + return src_loader(idx) + else: + # src is a scalar + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if ndim == 0: + self = view(self, []) + return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 
1.0 / scale + + def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def _upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +def inplace_constant_pad_nd( + x: TensorBox, padding: Sequence[int], fill_value: float +) -> Optional[TensorBox]: + """ + This optimization changes the semantics of padding from 'clone' + style to 'view' style. + + Thanks to functionalization, this change can still maintain numerical + correctness. + """ + + def _padding_can_be_fused(): + """ + Conservatively check if padding can be fused with downstream op. + 1. if the downstream op is a sum, then there is little benefit to + do inplace padding + 2. if the downstream op is a matmul, doing inplace padding can + save membw. 
+ """ + current_node = V.graph.current_node + if current_node is None: + return True # be conservative + users = tuple(current_node.users) + if len(users) == 1 and users[0].target in ( + aten.mm.default, + aten.addmm.default, + ): + return False + + return True # be conservative + + if _padding_can_be_fused(): + return None + + # Only handle 2D case for now + if len(padding) != 4 or len(x.get_size()) != 2: + return None + + # No harm to realize since we already know that + # the op can not be fused into the single user. + # It need to be realized later anyways. + x.realize() + + # If x is a view (e.g. a SliceView), realizing it just realizing the + # underlying storage. x itself is still a view. + if ( + not isinstance(x, ir.TensorBox) + or not isinstance(x.data, ir.StorageBox) + or not ( + isinstance(x.data.data, ir.ComputedBuffer) + or ( + config.can_inplace_pad_graph_input + and isinstance(x.data.data, ir.InputBuffer) + ) + ) + or not x.data.data.name + ): + return None + x.freeze_layout() + + _, layout = ir.as_storage_and_layout(x) + strides = layout.stride + if strides[1] != 1: + return None + + if padding[0] != 0 or padding[2] != 0 or padding[3] != 0: + return None + + npad = padding[1] + if npad == 0: + return None + + stride0 = strides[0] + rowsize = layout.size[1] + + if stride0 < rowsize + npad: + return None + + bufname = x.data.data.name + padded_size = [layout.size[0], layout.size[1] + npad] + V.graph.buffer_to_padded_size[bufname] = padded_size + resized_x = as_strided( + x, + padded_size, + layout.stride, + layout.offset, + ) + + sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad) + fill_(sliced_x, fill_value) + + counters["inductor"]["inplace_padding"] += 1 + return resized_x + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + if config.inplace_padding: + out = inplace_constant_pad_nd(x, padding, fill_value) + if out: + return out + # fall through if can not inplace the padding + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: list[tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, _high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + 
ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition( + x, fill_value, padding=None, pad_fill_value=1.0, dim=None +): + h = x.get_size()[-dim:] + x_loader = x.make_loader() + padding_h = padding or [0] * dim + + def load(index): + prefix = index[:-dim] + ih = index[-dim:] + + mask = functools.reduce( + ops.and_, + [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( + [*prefix, *ih] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) + ) + + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None): + if dilation is None: + dilation = [1] * len(padding) + + x_out = FloorDiv( + x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1), + stride[i], + ) + + if ceil_mode: + x_alt = FloorDiv( + x + + 2 * padding[i] + - dilation[i] * (kernel_size[i] - 1) + + 2 * (stride[i] - 1), + stride[i], + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +def should_fallback_max_pool_with_indices(kernel_size, *, n_dim): + kernel_size = pad_listlike(kernel_size, n_dim) + window_size = functools.reduce(operator.mul, kernel_size) + return window_size > 25 + + +def max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None +): + if padding == 0: + padding = [0] * n_dim + if dilation == 1: + dilation = [1] * n_dim + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, n_dim) + stride = pad_listlike(stride, n_dim) + padding = pad_listlike(padding, n_dim) + dilation = pad_listlike(dilation, n_dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == n_dim + assert len(stride) == n_dim + assert len(padding) == n_dim + assert len(dilation) == n_dim + assert len(x.get_size()) in (n_dim + 1, n_dim + 2) + + use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim) + if assert_fallback is not None: + assert use_fallback == assert_fallback + + return kernel_size, stride, padding, dilation, use_fallback + + +def _max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + *, + n_dim, +): + x.realize_hint() + batch = x.shape[:-n_dim] + dhw = x.shape[-n_dim:] + + dhw_out, ceil_mode = zip( + *[ + pooling_size( + dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation + ) + for d in range(n_dim) + ] + ) + + dtype = x.dtype + min_value = ( + False + if dtype is torch.bool + else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) + ) + + new_size = list(batch) + list(dhw_out) + if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation): + x_loader = 
constant_boundary_condition(x, min_value, dim=n_dim) + else: + x_loader = x.make_loader() + + def fn_inner(idx, reduction_idx): + prefix = idx[:-n_dim] + bh = idx[-n_dim:] + ih = [ + (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] + for i in range(n_dim) + ] + return x_loader([*prefix, *ih]) + + result = Reduction.create( + reduction_type="max", + input_node=x, + device=x.get_device(), + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + offsets = Reduction.create( + reduction_type="argmax", + input_node=x, + device=x.get_device(), + dst_dtype=torch.int64, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + result.realize() + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + offsets.realize() + + return result, offsets + + +@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode=False, +): + n_dim = len(kernel_size) + + # assert we are not on a fallback path, the inductor decomp should have guaranteed this + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, + kernel_size, + stride, + padding, + dilation, + n_dim, + assert_fallback=False, + ) + + with config.patch(unroll_reductions_threshold=25): + result, offsets = _max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim=n_dim, + ) + return result, to_dtype(offsets, torch.int8) + + +def _pool_offsets_to_indices( + offsets: TensorBox, + kernel_size: Sequence[Union[int, torch.SymInt]], + input_size: Sequence[Union[int, torch.SymInt]], + increments_to_index: Callable[ + [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]], + torch._inductor.virtualized.OpsValue, + ], +) -> TensorBox: + n_dim = len(kernel_size) + offsets_loader = offsets.make_loader() + window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size)) + + def offsets_to_indices(idx): + offset = offsets_loader(idx) + offset_sympy = ops.indirect_indexing(offset, window_size) + reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size) + idhw = increments_to_index(idx, reduction_idx) + return ops.index_expr( + inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64 + ) + + indices = Pointwise.create( + device=offsets.get_device(), + dtype=torch.int64, + inner_fn=offsets_to_indices, + ranges=offsets.get_size(), + ) + return indices + + +@register_lowering( + prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None +) +def _low_memory_max_pool_offsets_to_indices( + offsets, kernel_size, input_size, stride, padding, dilation +): + # TODO: Generalize to other max pooling flavors + n_dim = len(kernel_size) + + def increments_to_index(idx, reduction_idx): + bh = idx[-n_dim:] + return [ + (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] + for i in range(n_dim) + ] + + return _pool_offsets_to_indices( + offsets, kernel_size, input_size, increments_to_index + ) + + +def _max_pool_with_indices( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim, +): + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim=n_dim + ) 
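+    # Worked example of pooling_size above (illustrative): for an input length of
+    # 10 with kernel_size=3, stride=2, padding=1, dilation=1,
+    #   floor((10 + 2*1 - 1*(3 - 1) + (2 - 1)) / 2) = floor(11 / 2) = 5 outputs;
+    # with ceil_mode=True the alternative floor(12 / 2) = 6 is kept because the
+    # extra window still starts within the input plus left padding.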
+ + out, offsets = _max_pool_with_offsets( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim + ) + + indices = _low_memory_max_pool_offsets_to_indices( + offsets, + kernel_size, + x.shape[-n_dim:], + stride, + padding, + dilation, + ) + + return out, indices + + +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) +def max_pool2d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2 + ) + + +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None) +def max_pool3d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3 + ) + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + gO_stride = grad_output.maybe_get_stride() + x_stride: Optional[Sequence[Any]] + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + x_stride = x.maybe_get_stride() + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + if any(d != 1 for d in dilation): + # dilation NYI + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + *_batch, _height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
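+        # Note: h_window_size / w_window_size above bound how many pooling
+        # windows can contain one input element per dimension; e.g. kernel_size=3,
+        # stride=2 gives at most 2 per dimension (illustrative), so fn() below
+        # only inspects h_window_size * w_window_size candidate pooled positions
+        # for each input gradient element.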
+ return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + out = Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + if is_channels_last: + return ir.ExternKernel.require_channels_last(out) + else: + return out + + +def pad_adaptive_loader(x, pad_val=0.0): + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def 
fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + result = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if result is None: + result = val + else: + result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + if x.get_dtype() == torch.int64: + # not supported in eager + raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'") + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_adaptive_avg_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? + return rv + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + if x.get_dtype() == torch.int64: + # not supported in eager + raise RuntimeError("adaptive_max_pool2d not implemented for Long") + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + + if h_in % h_out == 0 and w_in % w_out == 0: + # This is handled by a decomposition + raise ValueError + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
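+        # The adaptive windows come from start_index/end_index below (the same
+        # formulas used in _adaptive_avg_pool2d above); e.g. for input length 10
+        # and output length 4 (illustrative):
+        #   start = floor(i * 10 / 4)       -> 0, 2, 5, 7
+        #   end   = ceil((i + 1) * 10 / 4)  -> 3, 5, 8, 10
+        # i.e. windows [0, 3), [2, 5), [5, 8), [7, 10), each no wider than
+        # h_kernel_max = ceildiv(10 + 4 - 1, 4) = 4.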
+ return fallback_adaptive_max_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_fn_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_fn_max_idx, + ranges=new_size, + ) + return rv, ri + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + samples_loader = samples.make_loader() + + def load(prefix, i): + sample = samples_loader([*prefix, ndims - 1 - dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + diff = ops.index_expr(in_sz - kernel_sz, torch.int64) + out_sz_expr = ops.index_expr(out_sz - 1, torch.int64) + alpha = ops.truediv( + ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64) + ) + alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha) + seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha) + seq_i = ops.to_dtype(seq_i, torch.int64) + mask = ops.lt(i_expr, out_sz_expr) + return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2) + + +@register_lowering(aten.fractional_max_pool3d) +def fractional_max_pool3d(x, kernel_size, output_size, random_samples): + return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3) + + +def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim): + x.realize_hint() + batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:] + + with config.patch(unroll_reductions_threshold=25): + dhw_index_fn = [ + _fractional_pooling_offsets( + samples=random_samples, + in_sz=inp_dhw, + out_sz=output_size, + kernel_sz=kernel_size, + ndims=n_dim, + dim=d, + ) + for d in range(n_dim) + ] + + x_loader = x.make_loader() + + def fn_inner(idx, reduction_idx): + prefix = idx[:-n_dim] + return x_loader([*prefix, *increments_to_index(idx, reduction_idx)]) + + def increments_to_index(idx, reduction_idx): + prefix = idx[:-n_dim] + bdhw = idx[-n_dim:] + return [ + dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d] + for d in range(n_dim) + ] + + new_size = list(batch) + list(output_size) + dtype = x.get_dtype() + result = Reduction.create( + reduction_type="max", + input_node=x, + device=x.get_device(), + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + offsets = Reduction.create( + reduction_type="argmax", + 
input_node=x, + device=x.get_device(), + dst_dtype=torch.int64, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + result.realize() + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + offsets.realize() + + indices = _pool_offsets_to_indices( + offsets, kernel_size, x.shape, increments_to_index + ) + return result, indices + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, scales_h=None, scales_w=None +): + x.realize_hint() + + *_batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) + + *_batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) +fallback_avg_pool3d = fallback_handler( + aten.avg_pool3d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=2, + ) + + +@register_lowering(aten.avg_pool3d, type_promotion_kind=None) +def avg_pool3d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=3, + ) + + +def _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0] * dim + kernel_size = pad_listlike(kernel_size, dim) + stride = pad_listlike(stride, dim) + padding = pad_listlike(padding, dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == dim + assert len(stride) == dim + assert len(padding) == dim + assert len(x.get_size()) in (dim + 1, dim + 2) + + x.realize_hint() + batch = x.get_size()[:-dim] + h = x.get_size()[-dim:] + + h_out, ceil_modes = zip( + *[ + pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) + for i in range(dim) + ] + ) + + if any(padding) or any(ceil_modes): + x_loader = constant_boundary_condition(x, 0.0, dim=dim) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + 
list(h_out) + dtype = x.get_dtype() + + window_size = functools.reduce(operator.mul, kernel_size) + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + if dim == 2: + fallback = fallback_avg_pool2d + elif dim == 3: + fallback = fallback_avg_pool3d + else: + raise ValueError(f"Unknown dim: {dim}") + + return fallback( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + prefix = idx[:-dim] + b = idx[-dim:] + total = None + for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): + inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] + val = loader([*prefix, *inp]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + divisor = divisor_override if divisor_override else window_size + if dtype.is_floating_point: + scale = 1 / divisor + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + + def fn(idx): + # C style integer division as done in native/cpu/AvgPoolKernel.cpp + return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype)) + + else: + + def fn(idx): + bh = idx[-dim:] + + divide_factors = [] + for i in range(dim): + hstart = bh[i] * stride[i] - padding[i] + hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) + if not count_include_pad: + hstart = sympy.Max(hstart, 0) + hend = sympy.Min(hend, h[i]) + factor = ops.index_expr(hend - hstart, torch.int32) + divide_factors.append(factor) + divide_factor = functools.reduce(ops.mul, divide_factors) + if dtype.is_floating_point: + return ops.truediv(fn_sum(idx, x_loader), divide_factor) + # C style integer division as done in native/cpu/AvgPoolKernel.cpp + return ops.truncdiv(fn_sum(idx, x_loader), divide_factor) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? 
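+    # rv is the averaged output; when count_include_pad=False each window is divided by its number of in-bounds elements (e.g. 4 rather than 9 at a corner output of a 3x3 kernel with padding=1)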
+ return rv + + +fallback_avg_pool2d_backward = fallback_handler( + aten.avg_pool2d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) +def avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + grad_output.realize_hint() # we will read this many times, so make sure it is computed + + *_, height, width = x.get_size() + + _h_out, ceil_mode1 = pooling_size( + height, 0, kernel_size, stride, padding, ceil_mode + ) + _w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + + had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 + + *_, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + h_window_size = max( + max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, 
ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *_batch, depth, height, width = x.get_size() + + _d_out, ceil_mode_d = pooling_size( + depth, 0, kernel_size, stride, padding, ceil_mode + ) + _h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + _w_out, ceil_mode_w = pooling_size( + width, 2, kernel_size, stride, padding, ceil_mode + ) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. 
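+        # (125 corresponds to a 5x5x5 accumulation window; larger windows fall back to the ATen kernel below)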
+ return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + ) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = 
[axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) + assert len(OrderedSet(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = OrderedSet[int](_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + return inner_loader(new_index) + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.S.One + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: ReductionType, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, # type: ignore[attr-defined] + Reduction, + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtypes=(x.get_dtype(),), + inner_fns=(x.make_loader(),), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - 
correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean + + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if use_two_step_variance(x, axis=axis, keepdim=keepdim) + else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = 
fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + val = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ).data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/triton-lang/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if 
both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: + """Try convert an arbitrary IR node into an ir.Constant value""" + + # First try unwrapping the IRNode to see if it is already an ir.Constant + # Optional step, but avoids unnecessary inner_fn evaluation. + if isinstance(x, ir.MutableBox): + return get_constant_value(x.data) + if isinstance(x, ir.BaseView): + return get_constant_value(x.unwrap_view()) + if isinstance(x, ir.Constant): + return x + + # If the unwrapped node is not an ir.Constant, try evaluating inner_fn + # to see if the returned value is from an `ops.constant` call + if not isinstance(x, ir.Loops): + return None + + handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) + with ( + V.set_ops_handler(handler), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + ): + out = x.inner_fn(*x.inner_fn_args()) + + assert isinstance(out, torch._inductor.virtualized.OpsValue) + if isinstance(out.value, ir.Constant): + return out.value + return None + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. +@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + if (divisor := get_constant_value(b)) is not None: + # Replace divide by constant with multiply by reciprocal + if divisor.value == 0: + reciprocal = math.copysign(float("inf"), divisor.value) + else: + reciprocal = 1.0 / divisor.value + return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) +fallback_cummax = fallback_handler(aten.cummax.default) +fallback_cummin = fallback_handler(aten.cummin.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + 
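+    # cumsum is lowered as an inclusive scan whose combine function is elementwise addition; if the scan lowering cannot be generated, result is None and we fall back to the ATen cumsum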
+ def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.add(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.mul(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.cummax, type_promotion_kind=None) +def cummax(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmax", dtype=dtype, arg_break_ties_left=False + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = ( + x.make_loader(), + lambda idx: ops.index_expr(idx[axis], torch.int64), + ) + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] + if values is None: + return fallback_cummax(x, dim=axis) + return values, indices + + +@register_lowering(aten.cummin, type_promotion_kind=None) +def cummin(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmin", dtype=dtype, arg_break_ties_left=False + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = ( + x.make_loader(), + lambda idx: ops.index_expr(idx[axis], torch.int64), + ) + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] + if values is None: + return fallback_cummin(x, dim=axis) + return values, indices + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, 
keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + +sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False) + + +@register_lowering(aten.sort.stable, type_promotion_kind=None) +def sort_stable(x, *, stable=None, dim=-1, descending=False): + if stable is None: + stable = False + + shape = x.get_size() + device = x.get_device() + dim = canonicalize_dim(len(shape), dim) + if len(shape) == 0: + return clone(x), _full(0, device, torch.int64, shape) + + dim_size = shape[dim] if len(shape) else 1 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + indices = iota( + dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + ) + view_shape = [1] * len(shape) + if len(shape): + view_shape[dim] = dim_size + indices = view(indices, view_shape) + indices = expand(indices, shape) + + values, indices = ir.Sort.create( + device=device, + dtypes=(x.dtype, indices.dtype), + inner_fns=(x.make_loader(), indices.make_loader()), + size=shape, + axis=dim, + stable=stable, + descending=descending, + ) + if values is None: + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + assert indices is not None + return values, to_dtype(indices, torch.int64) + + +@register_lowering(aten.sort.default, type_promotion_kind=None) +def sort(x, dim=-1, descending=False): + return sort_stable(x, stable=False, dim=dim, descending=descending) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket): + register_op_requires_libdevice_fp64(op.__name__) + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + +rsqrt = register_pointwise_numeric(aten.rsqrt) +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = 
register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.log2) +register_pointwise_numeric(aten.nextafter) + +from .codegen.common import BackendFeature, pointwise_overrides_data + + +def _get_pointwise_overrides(ns, name): + data = pointwise_overrides_data[name] + op = getattr(ns, data.name, None) + if op is None: + return + + def make_triton_fallback(op): + if data.triton is None: + return fallback_handler(op) + + if isinstance(op, torch._ops.OpOverloadPacket): + for olname in op.overloads(): + ol = getattr(op, olname) + yield ol, data.type_promotion_kind, make_triton_fallback(ol) + else: + yield op, data.type_promotion_kind, make_triton_fallback(op) + + +for name in pointwise_overrides_data: + for op, 
type_promotion_kind, triton_fallback in _get_pointwise_overrides( + aten, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides( + prims, name + ): + register_pointwise( + op, + name=name, + type_promotion_kind=type_promotion_kind, + triton_fallback=triton_fallback, + ) + + +foreach_add_list = register_foreach_pointwise( + aten._foreach_add.List, add, allow_alpha=True +) +foreach_add_scalar = register_foreach_pointwise( + aten._foreach_add.Scalar, add, allow_alpha=True +) +register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) +foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +register_foreach_pointwise(aten._foreach_mul.Tensor, mul) +foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) +register_foreach_pointwise(aten._foreach_sub.List, sub) +register_foreach_pointwise(aten._foreach_sub.Scalar, sub) +register_foreach_pointwise(aten._foreach_neg.default, neg) +register_foreach_pointwise(aten._foreach_abs.default, abs) +register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.List, pow) +register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) +foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +register_foreach_pointwise(aten._foreach_div.Tensor, div) +foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) +register_foreach_pointwise(aten._foreach_sqrt, sqrt) +register_foreach_pointwise(aten._foreach_rsqrt, rsqrt) +register_foreach_pointwise(aten._foreach_maximum.List, maximum) +register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum) +register_foreach_pointwise(aten._foreach_minimum.List, minimum) +register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum) +register_foreach_pointwise(aten._foreach_clamp_min.List, maximum) +register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum) +register_foreach_pointwise(aten._foreach_clamp_max.List, minimum) +register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) +register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) +register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_copy, copy) + + +# these are only encountered as outputs of the graph +# reinplacing epilogue copies improves compile time +# by removing extra buffers sent to the scheduler. 
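+# register_foreach_inplace lowers an in-place foreach op by running its out-of-place lowering and then mutating each input to alias the corresponding result (mutate_to with unsafe_alias=True)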
+def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op): + inplaceable_foreach_ops[outplace_aten_op] = aten_op + inplace_foreach_ops.add(aten_op) + + def fn(*args, **kwargs): + results = outplace_op(*args, **kwargs) + mut_results = [] + for arg, result in zip(args[0], results): + mut_results.append(mutate_to(arg, result, unsafe_alias=True)) + + return mut_results + + _register_foreach_lowering(aten_op, fn) + + +register_foreach_inplace( + aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list +) +register_foreach_inplace( + aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar +) +register_foreach_inplace( + aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list +) +register_foreach_inplace( + aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar +) +register_foreach_inplace( + aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list +) +register_foreach_inplace( + aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar +) + + +def register_inplace(aten_op, outplace_op): + @register_lowering(aten_op, type_promotion_kind=None) + def fn(*args, **kwargs): + result = outplace_op(*args, **kwargs) + result = to_dtype(result, args[0].get_dtype()) + return mutate_to(args[0], result) + + return fn + + +register_inplace(aten.add_, add) +register_inplace(aten.bitwise_and_, bitwise_and) +register_inplace(aten.bitwise_left_shift_, bitwise_left_shift) +register_inplace(aten.bitwise_not_, bitwise_not) +register_inplace(aten.bitwise_or_, bitwise_or) +register_inplace(aten.bitwise_right_shift_, bitwise_right_shift) +register_inplace(aten.bitwise_xor_, bitwise_xor) +register_inplace(aten.mul_, mul) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) +register_inplace(aten.logical_and_, logical_and) +register_inplace(aten.logical_not_, logical_not) +register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + return None + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + # Note [Can val be an int?] + # ~~~~~~~~~~~~~~~~~~~~~~~~~ + # In principle, someone could construct an FX graph where + # a call to size/stride has a val that is a plain int (not + # SymInt). However, we will maintain the invariant that + # this is not possible: if you are constructing an FX graph + # where there is a call to size/stride that returns an + # int, but you KNOW that int must always be a constant, + # then you do not need trace that call at all (and just + # constant propagate the integer as is.) + assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + # See Note [Can val be an int?] 
+ assert isinstance(val, torch.SymInt) + return val.node.expr + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) # type: ignore[arg-type] + + +@register_lowering(torch.sym_sum) +def sym_sum(args): + return sympy.Add(*args) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +@register_lowering(torch.ops.aten.set_.source_Tensor) +def set__source_tensor(self, source_tensor): + self.realize() + source_tensor.realize() + return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) + + +if hasattr(torch.ops.fsdp, "copy_"): + + @register_lowering(torch.ops.fsdp.copy_.default) + def fsdp_copy_(dst, src): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@register_lowering(torch.ops.aten.resize) +def resize(x, size, *, memory_format=None): + assert isinstance(x, TensorBox) + assert isinstance(size, (list, tuple)) + + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + raise RuntimeError(f"unsupported memory format: {memory_format}") + + if memory_format == torch.channels_last: + assert len(size) == 4 + if memory_format == torch.channels_last_3d: + assert len(size) == 5 + + old_numel = x.get_numel() + dtype = x.get_dtype() + device = x.get_device_or_error() + + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + + if ( + torch.are_deterministic_algorithms_enabled() + and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ): + if is_float_dtype(dtype): + uninitalized_val = float("nan") + elif is_integer_dtype(dtype): + uninitalized_val = torch.iinfo(dtype).max + else: + uninitalized_val = True + else: + # using zero as that is what empty does + uninitalized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitalized_val, dtype=dtype, device=device) + + x_flat = as_strided( + x, + [ + old_numel, + ], + [ + 1, + ], + ) + flat_loader = x_flat.make_loader() + out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) + out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() + + def inner_fn(idx): + flat_index = out_indexer(idx) + flat_index_expr = ops.index_expr(flat_index, torch.int64) + limit = ops.index_expr(old_numel, torch.int64) + mask = ops.lt(flat_index_expr, limit) + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitalized_val) + + out = Pointwise.create( + device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) + ) + return out + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_( + *, + kernel_idx, + constant_args_idx, + grid, + tma_descriptor_metadata, + kwargs, +): + from 
torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + ir.UserDefinedTritonKernel( + kernel_idx=kernel_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kernel_args={**kwargs, **constant_args}, + ) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None) +def cond(pred, true_fn, false_fn, operands): + if any(isinstance(x, IRNode) and is_triton(x) for x in [pred, *operands]): + msg = "control flow operator: torch.cond." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None) +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): + if any( + isinstance(x, IRNode) and is_triton(x) + for x in carried_inputs + additional_inputs + ): + msg = "control flow operator: torch.while_loop." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + def _map_output(out: Any): + if isinstance(out, TensorBox): + return out + elif isinstance(out, ir.StorageBox): + return TensorBox(out) + elif isinstance(out, ir.MultiOutput): + return TensorBox.create(out) + else: + raise RuntimeError(f"NYI unsupported output type: {type(out)}") + + result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + return list(map(_map_output, result)) + + +@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) +def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): + result = ir.InvokeSubgraph.create(subgraph_fn, *operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) +def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None): + output = None + quant_options = V.graph.current_node.meta.get("quant_options", None) + assert quant_options is not None + + for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): + if node.op == "placeholder": + V.graph.env[node] = operands[i] + continue + # todo getattr + elif node.op == "output": + args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + + for v in itertools.chain(args, kwargs.values()): + v.realize() + + if quant_options.codegen_low_precision: + V.graph.low_precision_codegen_ops.add(v.get_operation_name()) + + V.graph.invoke_quant_ops.add(v.get_operation_name()) + + output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) + else: + V.graph.env[node] = V.graph.run_node(node) + + return output + + +@register_lowering(associative_scan_op, type_promotion_kind=None) +def associative_scan( + combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor] +): + from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph + + if len(additional_inputs) > 0: + raise RuntimeError( + "Unable to generate code for associative_scan op, because there are lifted arguments" + ) + + subgraph_inputs = [ + InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) + for x in itertools.chain(xs, xs) + ] + lowered_combine_fn = lower_pointwise_subgraph(combine_fn, 
subgraph_inputs) # type: ignore[var-annotated] + + def wrapped_combine_fn(lhs, rhs): + return lowered_combine_fn( + *pytree.tree_leaves(lhs), + *pytree.tree_leaves(rhs), + ) + + kwargs = _make_scan_inner(xs[0], axis=0, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in xs) + kwargs["inner_fns"] = tuple(x.make_loader() for x in xs) + result = ir.Scan.create( + combine_fn=wrapped_combine_fn, + can_fallback_to_aten=False, + **kwargs, + ) + if result[0] is None: + raise RuntimeError("Unable to generate code for associative_scan op") + return result + + +@register_lowering(torch.ops.prims._sink_tokens.default) +def _sink_tokens(tokens): + return None + + +@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) +def with_effects(token, op, *args, **kwargs): + result = ir.EffectfulKernel.create(op, *args, **kwargs) + + from torch._higher_order_ops.effects import get_effect_key + + effect_type = get_effect_key(op, args, kwargs) + assert effect_type is not None + effectful_kernel = V.graph.effectful_ops[effect_type] + + if result is None: + return (effectful_kernel,) + + result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result) + # See [NOTE: with_effects return type] + # Only return `result` if it is a tuple, not list. + if not isinstance(result, tuple): + return (effectful_kernel, result) + else: + return (effectful_kernel, *result) + + +from .comm_lowering import register_comm_lowerings + + +register_comm_lowerings() + + +@register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None) +def prepare_softmax_online(x, dim): + """ + Lowering inductor_prims.prepare_softmax_online to compute max/sum in one pass if no split is needed. + """ + kwargs = _make_reduction_inner( + x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None + ) + + reduction_ranges = kwargs["reduction_ranges"] + rnumel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + hint, num_split = ir.Reduction.num_splits( + **kwargs, + reduction_type="online_softmax_reduce", # type: ignore[arg-type] + reduction_numel=rnumel, + ) + + if ( + num_split == 1 + and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold + ): + max_tensor, sum_tensor = OnlineSoftmaxReduction.create( + input_node=x, num_output=2, reduction_hint=hint, **kwargs + ) + return max_tensor, sum_tensor + else: + # Note: [Split online_softmax_reduce] + # We don't split reduction for online_softmax_reduce for now. + # On one hand, supporting split reduction makes things complex since + # the split out reuctions requires 2 inputs rather than one. + # On the other hand, during training the online_softmax_reduce should + # usually don't requires a split due to large batch size + # (more specifically batch size times sequence length). + # We should support split reduction if we find legit use cases to + # motivate the work. + # + # TODO: does inference need split online_softmax_reduce? + + warnings.warn( + textwrap.dedent( + """ + Online softmax is disabled on the fly since Inductor decides to + split the reduction. Cut an issue to PyTorch if this is an + important use case and you want to speed it up with online + softmax. + """ + ) + ) + amax = reduce_amax(x, dim, keepdims=True) + exp = lowerings[aten.exp](sub(x, amax)) + xsum = sum_(exp, dim, keepdims=True) + return amax, xsum + + +# populate lowerings defined in kernel/* +from . import kernel + + +import_submodule(kernel) + +from . 
import quantized_lowerings + + +quantized_lowerings.register_quantized_ops() +quantized_lowerings.register_woq_mm_ops() + +from . import mkldnn_lowerings + + +mkldnn_lowerings.register_onednn_fusion_ops() + +from . import jagged_lowerings + + +jagged_lowerings.register_jagged_ops() + + +@contextlib.contextmanager +def force_fallback(op: torch._ops.OpOverload): + """ + A context manager to force fallback an op. Used in unit test + for FallbackKernel. + """ + assert isinstance(op, torch._ops.OpOverload), ( + "Only OpOverload to make the clean up easier" + ) + old_handler = lowerings.get(op) + try: + register_lowering(op)(fallback_handler(op)) + yield + finally: + if old_handler: + lowerings[op] = old_handler + else: + lowerings.pop(op) diff --git a/phivenv/Lib/site-packages/torch/_inductor/memory.py b/phivenv/Lib/site-packages/torch/_inductor/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..3a14e2840ae08871aaa6c83a35d0a18a9fcae299 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/memory.py @@ -0,0 +1,697 @@ +from __future__ import annotations + +import collections +import dataclasses +import heapq +import logging +from typing import Callable, TYPE_CHECKING, TypedDict, Union + +from torch._utils_internal import signpost_event +from torch.utils._ordered_set import OrderedSet + +from .ir import MultiOutputLayout, NoneLayout +from .utils import get_dtype_size, is_wait +from .virtualized import V + + +if TYPE_CHECKING: + from .dependencies import Dep + from .scheduler import BaseSchedulerNode, SchedulerBuffer + + +torch_log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class PeakMemoryResult: + order: list[BaseSchedulerNode] + peak_memory: int + method: str + + +@dataclasses.dataclass +class MemoryPlanningInfoForBuffer: + size_alloc: int = 0 + size_free: int = 0 + succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + + +@dataclasses.dataclass +class MemoryPlanningInfoForNode: + index: int = 0 + size: int = 0 + pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = ( + dataclasses.field(default_factory=OrderedSet) + ) + pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + + +@dataclasses.dataclass +class FreeableInputBuffer: + name: str + mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( + default_factory=MemoryPlanningInfoForBuffer + ) + + def get_name(self) -> str: + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + +def get_freeable_input_buf( + nodes: list[BaseSchedulerNode], + graph_inputs: OrderedSet[str], +) -> dict[str, FreeableInputBuffer]: + """ + Create and keep track of all input buffers that can be freed during the program + + Returns: + A dictionary containing all freeble input buffers, keyed by their names. + """ + + # this function is copied from torch/_inductor/scheduler.py + # TODO: would be nice to remove the try/except block for both places + def _dep_size_hint(dep: Dep) -> int: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. 
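+            # in that case the dep keeps the default size hint of 0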
+ pass + return res + + # get freeable input buffers' successor nodes and their sizes + # note that different deps can have the same name, so we use name as keys + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + dep_name_to_size: dict[str, int] = dict() + for node in nodes: + for dep in node.read_writes.reads: + if dep.name in graph_inputs and not dep.name.startswith( + ("primals_", "arg", "fwd_rng_state", "bwd_rng_state") + ): + dep_name_to_succ_nodes[dep.name].add(node) + dep_name_to_size[dep.name] = _dep_size_hint(dep) + + # create FreeableInputBuffer objects and add them to the returned dictionary + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict() + for dep_name, succ_nodes in dep_name_to_succ_nodes.items(): + name_to_freeable_input_buf[dep_name] = FreeableInputBuffer( + dep_name, + MemoryPlanningInfoForBuffer( + size_free=dep_name_to_size[dep_name], succ_nodes=succ_nodes + ), + ) + return name_to_freeable_input_buf + + +def compute_size_for_scheduler_buffer( + name_to_buf: dict[str, SchedulerBuffer], +) -> dict[str, tuple[int, int]]: + """ + Compute the size of each scheduler buffer, including (1) memory allocated when + it is created and (2) memory deallocated when it is freed. + + We specially handle the case of MultiOutputLayout. + Consider the following case: + buf0 = some_ops_with_multi_outputs(...) + buf1 = buf0[0] # assume 10 bytes + buf2 = buf0[1] # assume 20 bytes + In such cases, + buf0: at creation, 30 bytes allocated, when deleted, 0 bytes freed + buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed + buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + + Returns: + A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). + """ + from .ir import MultiOutput + from .scheduler import OutputNode + + sched_buf_to_size: dict[str, tuple[int, int]] = dict() + + def _compute_and_update_buf_size( + sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False + ) -> int: + if isinstance(sched_buf.node.layout, NoneLayout): + _size = 0 + # for a wait tensor op, its schedulerBuffer NoneLayout layout. 
However, + # the schedulerBuffer is treated as a mutation of the collective output + # so it needs to inherit the size of the collectives + if ( + sched_buf.defining_op + and is_wait(sched_buf.defining_op.node) + and sched_buf.get_mutations() + ): + mutated_buf_name = sched_buf.get_mutations()[0] + _size = ( + sched_buf_to_size[mutated_buf_name][1] + if mutated_buf_name in sched_buf_to_size + else 0 + ) + sched_buf_to_size[sched_buf.get_name()] = (_size, _size) + return _size + elif isinstance(sched_buf.node.layout, MultiOutputLayout): + size_alloc = 0 + for user in sched_buf.users: + if isinstance(user.node, OutputNode): + continue + for buf in user.node.get_outputs(): + if isinstance(buf.node, MultiOutput): + size_alloc += _compute_and_update_buf_size(buf, True) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else size_alloc, + 0, + ) + return size_alloc + else: + buf_size = V.graph.sizevars.size_hint( + sched_buf.node.get_numel(), fallback=0 + ) * get_dtype_size(sched_buf.node.get_dtype()) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else buf_size, + buf_size, + ) + return buf_size + + for sched_buf in name_to_buf.values(): + # skip if sched_buf is already processed as an user of another SchedulerBuffer + # whose layout is of the type MultiOutputLayout + if sched_buf.get_name() not in sched_buf_to_size: + _compute_and_update_buf_size(sched_buf) + + return sched_buf_to_size + + +def assign_memory_planning_info_for_scheduler_buffers( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], +) -> None: + """ + For each SchedulerBuffer, assign its size info and successor nodes. + A buffer's successor nodes determines when a buffer can be freed. + """ + # get buffer sizes + sched_buf_to_size = compute_size_for_scheduler_buffer(name_to_buf) + + # get buffer's successor nodes + # note that different deps can have the same name, so we use name as keys + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + for node in nodes: + for dep in node.unmet_dependencies: + dep_name_to_succ_nodes[dep.name].add(node) + + # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer + # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) + for buf_name in name_to_buf.keys(): + name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( + size_alloc=sched_buf_to_size[buf_name][0], + size_free=sched_buf_to_size[buf_name][1], + succ_nodes=dep_name_to_succ_nodes[buf_name], + ) + + +def assign_memory_planning_info_for_scheduler_nodes( + nodes: list[BaseSchedulerNode], + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], +) -> None: + """ + Assign to each scheduler node its predecessor and successor nodes. 
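+    A node's predecessor nodes determine when the node becomes schedulable;
+    its successor nodes are the union of the successor nodes of the buffers
+    it produces.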
+ """ + from .scheduler import SchedulerBuffer + + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]() + for dep in node.read_writes.reads: + if dep.name in name_to_buf and dep in node.unmet_dependencies: + pred_buffers.add(name_to_buf[dep.name]) + elif dep.name in name_to_freeable_input_buf: + pred_buffers.add(name_to_freeable_input_buf[dep.name]) + pred_nodes = OrderedSet( + name_to_fused_node[pred_buffer.defining_op_name()] + for pred_buffer in pred_buffers + if (isinstance(pred_buffer, SchedulerBuffer)) + ) + succ_nodes = OrderedSet( + succ_node + for buffer in node.get_outputs() + for succ_node in buffer.mpi_buffer.succ_nodes + ) + node.mpi_node = MemoryPlanningInfoForNode( + index=index, + size=size_alloc, + pred_buffers=pred_buffers, + pred_nodes=pred_nodes, + succ_nodes=succ_nodes, + ) + + +def estimate_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[int, list[int]]: + """ + Given a list of nodes in their execution order, estimate the peak memory, by + keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. + + Returns: + int: peak memory + List[int]: memory usage at each node (or each step). + """ + + # map each scheduler buffer to its size, start step, and end step + @dataclasses.dataclass + class BufferInfo: + buffer: Union[SchedulerBuffer, FreeableInputBuffer] + size_alloc: int + size_free: int + start_step: int + end_step: int + + # get the execution step of each node, this will be used to determine + # the end_step of buffers + node_to_step: dict[BaseSchedulerNode, int] = { + node: step for step, node in enumerate(nodes) + } + + # get buffers' size and liveliness information + buf_info_list: list[BufferInfo] = [] + # 1. for freeable input buffers + for buf_name, input_buf in name_to_freeable_input_buf.items(): + end_step = ( + len(nodes) - 1 + if buf_name in graph_outputs + else max( + node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes + ) + ) + buf_info_list.append( + BufferInfo( + input_buf, + input_buf.mpi_buffer.size_free, + input_buf.mpi_buffer.size_free, + 0, + end_step, + ) + ) + + # 2. for scheduler buffers + for step, node in enumerate(nodes): + for sched_buf in node.get_outputs(): + # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and + # to be only used by its defining op (e.g., due to fusion when all consumers of + # the buffer are fused with its defining op). In such cases, end_step is step. 
+ end_step = ( + len(nodes) - 1 + if sched_buf.get_name() in graph_outputs + else max( + [ + node_to_step[succ_node] + for succ_node in sched_buf.mpi_buffer.succ_nodes + ], + default=step, + ) + ) + buf_info_list.append( + BufferInfo( + sched_buf, + sched_buf.mpi_buffer.size_alloc, + sched_buf.mpi_buffer.size_free, + step, + end_step, + ) + ) + + # incremental memory changes at each step + memory = [0 for _ in range(len(nodes) + 1)] + + # for each buffer, update memory when created and when freed + for buf_info in buf_info_list: + memory[buf_info.start_step] += buf_info.size_alloc + memory[buf_info.end_step + 1] -= buf_info.size_free + + # get peak memory by compute the cumulative memories + max_memory = 0 + cur_memory = 0 + memories_at_nodes = [] + for t in range(len(nodes) + 1): + cur_memory += memory[t] + memories_at_nodes.append(cur_memory) + max_memory = max(max_memory, cur_memory) + + return (max_memory, memories_at_nodes) + + +def topological_sort_lpmf( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + name_to_buf: dict[str, SchedulerBuffer], + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First". + + The idea is from this paper: + Buffer memory optimization for video codec application modeled in Simulink + https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF + + The algorithm maintain the max memory so far. + At every iteration, for each scheduleable node, it computes: + - how much memory needs to be allocated for the output buffers of this node; + - how much memory can be freed as a result of executing this node. + This gives us two values for each node: + (1) mem1: memory during the execution of the node; + (2) mem2: memory after executing the node, after some input buffers are freed. + The greedy approach select as follows: + (i) if there are nodes whose mem1 values are below the max memory so far, + then pick the node with the lowest mem2 value; + (ii) otherwise, pick the one with the lowest mem1 value. 
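+    For example, suppose the max memory so far is 100 and two schedulable
+    nodes have (mem1, mem2) of (90, 70) and (95, 40): both mem1 values are
+    below 100, so case (i) applies and the second node is picked for its
+    lower mem2. If the candidates were (110, 70) and (120, 40) instead,
+    case (ii) applies and the first node is picked for its lower mem1.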
+ """ + + class NodeInfo(TypedDict): + indegree: int + memory_to_free: int + + class BufferInfo(TypedDict): + outdegree: int + + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() + buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict() + + # compute nodes' number of unmet dependencies (for schedulability) + # initialize the list of nodes ready to be scheduled + nodes_to_schedule: OrderedSet[BaseSchedulerNode] = OrderedSet() + for node in nodes: + node_info[node] = { + "indegree": len(node.mpi_node.pred_nodes), + "memory_to_free": 0, + } + if node_info[node]["indegree"] == 0: + nodes_to_schedule.add(node) + + # compute buffers' number of unmet successors (used to decide when to free) + for buf in list(name_to_buf.values()) + list(name_to_freeable_input_buf.values()): + buf_info[buf] = { + "outdegree": len(buf.mpi_buffer.succ_nodes) + + (1 if buf.get_name() in graph_outputs else 0) + } + + # initialize memory estimations + live_memory = sum( + input_buf.mpi_buffer.size_free + for input_buf in name_to_freeable_input_buf.values() + ) + + # this is the total output memory, which is a lower bound for peak memory + # we do not include the memory of non freeable input buffers + output_memory = 0 + for buf_name in graph_outputs: + if buf_name in name_to_buf: + output_memory += name_to_buf[buf_name].mpi_buffer.size_free + elif buf_name in name_to_freeable_input_buf: + output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free + max_memory = max(live_memory, output_memory) + + # compute the amount of memory that is allocated when a node is scheduled + # and the amount of memory that can be freed when a node is scheduled + for node in nodes: + # 1. if a buffer read by this node is last used by this node + for buf in node.mpi_node.pred_buffers: + if buf_info[buf]["outdegree"] == 1: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free + # 2. 
if a buffer written by this node is used internally and not used later + for buf in node.get_outputs(): + if buf_info[buf]["outdegree"] == 0: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free + + # schedule nodes one at a time + schedule: list[BaseSchedulerNode] = [] + num_iters: int = 0 + while num_iters < len(nodes) and nodes_to_schedule: + # select a node to schedule: + selected_node = min( + nodes_to_schedule, + key=lambda node: ( + max(live_memory + node.mpi_node.size, max_memory), + node.mpi_node.size - node_info[node]["memory_to_free"], + node.mpi_node.index, + ), + ) + nodes_to_schedule.remove(selected_node) + schedule.append(selected_node) + num_iters += 1 + + # update memory usage + live_memory += selected_node.mpi_node.size + max_memory = max(max_memory, live_memory) + live_memory -= node_info[selected_node]["memory_to_free"] + + # update successor nodes and nodes_to_schedule + for succ_node in selected_node.mpi_node.succ_nodes: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: + nodes_to_schedule.add(succ_node) + + # update predecessor nodes + for buf in selected_node.mpi_node.pred_buffers: + assert buf_info[buf]["outdegree"] > 0 + buf_info[buf]["outdegree"] -= 1 + if buf_info[buf]["outdegree"] == 1: + for succ_node in buf.mpi_buffer.succ_nodes: + node_info[succ_node]["memory_to_free"] += buf.mpi_buffer.size_free + + if num_iters > len(nodes): + raise RuntimeError("Failed to schedule, while loop ran too long for lpmf") + + return schedule + + +def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + A BFS topological sort that selects nodes whose dependencies are executed the + earliest. This follows a FIFO idea. Specifically, at every iteration, for each node + that is schedulable, we gather the order in which its predecessor nodes are executed, + and this sorted list of execution orders of predecessor nodes defines the priority. + We select the node whose predecessors nodes are executed the earliest. The FIFO + idea aims to reduce the liveness duration of buffers created. 
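+    For example, a schedulable node whose predecessors ran at steps [0, 2] is
+    scheduled before one whose predecessors ran at steps [0, 3], since the
+    sorted lists of execution orders are compared lexicographically (ties are
+    broken by the node's original index).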
+ """ + + class NodeInfo(TypedDict): + indegree: int + order: int + + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() + + @dataclasses.dataclass + class NodeWithPriority: + priority: list[int] + node: BaseSchedulerNode + + def __lt__(self, other: NodeWithPriority) -> bool: + if self.priority == other.priority: + return self.node.mpi_node.index < other.node.mpi_node.index + return self.priority < other.priority + + def _node_priority(node: BaseSchedulerNode) -> list[int]: + # priority is the order in which predecessor nodes are executed + assert node_info[node]["indegree"] == 0 + exec_orders = sorted( + OrderedSet( + node_info[pred_node]["order"] for pred_node in node.mpi_node.pred_nodes + ) + ) + return exec_orders + + # compute nodes' number of unmet dependencies (for schedulability) + # initialize the list of nodes ready to be scheduled + nodes_to_schedule: list[NodeWithPriority] = [] + for node in nodes: + node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1} + if node_info[node]["indegree"] == 0: + heapq.heappush( + nodes_to_schedule, NodeWithPriority(_node_priority(node), node) + ) + + # schedule nodes one at a time + schedule: list[BaseSchedulerNode] = [] + num_iters: int = 0 + while num_iters < len(nodes) and nodes_to_schedule: + # select a node to schedule + selected_node = heapq.heappop(nodes_to_schedule).node + node_info[selected_node]["order"] = len(schedule) + schedule.append(selected_node) + num_iters += 1 + + # update successor nodes and nodes_to_schedule + for succ_node in selected_node.mpi_node.succ_nodes: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: + heapq.heappush( + nodes_to_schedule, + NodeWithPriority(_node_priority(succ_node), succ_node), + ) + + if num_iters > len(nodes): + raise RuntimeError("Failed to schedule, while loop ran too long for bfs") + + return schedule + + +def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + This is a DFS topological sort. The setup is similar to `topological_sort_schedule` + in scheduler.py. The difference is the order nodes are visited in the outer loop. + In `topological_sort_schedule`, nodes are visited in their original order. + In this function, nodes are visited based on their priority -- for each node, we + compute the total memory of all buffers it reads from or writes to, and we visit + the nodes in ascending order of this priority. 
+ """ + seen: OrderedSet[BaseSchedulerNode] = OrderedSet() + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] + size_with_reads: dict[BaseSchedulerNode, int] = dict() + + def visit(n: BaseSchedulerNode) -> None: + if n not in seen: + seen.add(n) + dep_nodes = [ + name_to_node[dep.name] + for dep in n.unmet_dependencies + if dep.name in name_to_node + ] + for node in sorted( + dep_nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index) + ): + visit(node) + result.append(n) + + for node in nodes: + for name in node.get_buffer_names(): + name_to_node[name] = node + + for node in nodes: + size_with_reads[node] = node.mpi_node.size + sum( + pred_buf.mpi_buffer.size_free for pred_buf in node.mpi_node.pred_buffers + ) + for node in sorted(nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index)): + visit(node) + + return result + + +def prepare_planning_info( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + graph_inputs: OrderedSet[str], + graph_outputs: OrderedSet[str], +) -> tuple[int, dict[str, FreeableInputBuffer]]: + """ + Prepare planning info. As nodes are scheduled one at a time, these help + keep track of when a buffer can be freed, and when a node can be scheduled + + Returns: + int: peak memory estimation + dict[str, FreeableInputBuffer]: name to freeable input buffer + """ + name_to_freeable_input_buf = get_freeable_input_buf(nodes, graph_inputs) + assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf) + assign_memory_planning_info_for_scheduler_nodes( + nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf + ) + + # the default + estimated_peak_memory, _ = estimate_peak_memory( + nodes, name_to_freeable_input_buf, graph_outputs + ) + + return estimated_peak_memory, name_to_freeable_input_buf + + +def reorder_for_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + graph_inputs: OrderedSet[str], + graph_outputs: OrderedSet[str], + methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006 + topological_sort_lpmf, + topological_sort_bfs, + topological_sort_dfs, + ], +) -> list[BaseSchedulerNode]: + """ + Try a few heuristics based topological sort algorithms, and pick the one whose + resulting topological order has the lowest peak memory estimation. 
+ """ + + torch_log.info("Reordering for peak memory -- %d nodes", len(nodes)) + + estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info( + nodes, + name_to_buf, + name_to_fused_node, + graph_inputs, + graph_outputs, + ) + + # keep track of the peak memory estimates of different methods + peak_memory_diff_methods: list[PeakMemoryResult] = [] + peak_memory_diff_methods.append( + PeakMemoryResult(nodes, estimated_peak_memory, "baseline") + ) + torch_log.info("Baseline peak memory: %d", estimated_peak_memory) + + # other methods + for method in methods: + try: + if method == topological_sort_lpmf: + order = method( + nodes, name_to_freeable_input_buf, name_to_buf, graph_outputs + ) + else: + order = method(nodes) + assert len(order) == len(nodes) + peak_memory, _ = estimate_peak_memory( + order, name_to_freeable_input_buf, graph_outputs + ) + peak_memory_diff_methods.append( + PeakMemoryResult(order, peak_memory, method.__name__) + ) + torch_log.info("%s peak memory: %d", method.__name__, peak_memory) + except Exception as e: + torch_log.error("Failed to reorder for %s: %s", method.__name__, e) + + signpost_event( + category="inductor", + name="memory", + parameters={ + "orm": {elem.method: elem.peak_memory for elem in peak_memory_diff_methods}, + }, + ) + + # get the optimal one + best_result = min(peak_memory_diff_methods, key=lambda x: x.peak_memory) + + return best_result.order diff --git a/phivenv/Lib/site-packages/torch/_inductor/metrics.py b/phivenv/Lib/site-packages/torch/_inductor/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4de7d4042f47d66f39144b8c82df36c59cccb3c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/metrics.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import csv +import dataclasses +import inspect +import os +import re +from dataclasses import dataclass +from functools import lru_cache +from typing import Callable, cast, Optional, TYPE_CHECKING, Union + +from torch._inductor import config +from torch._inductor.utils import get_benchmark_name +from torch.utils._ordered_set import OrderedSet + + +# Prevent circular import +if TYPE_CHECKING: + from torch._inductor.scheduler import BaseSchedulerNode + +# counter for tracking how many kernels have been generated +generated_kernel_count = 0 +generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem: list[ + tuple[ + BaseSchedulerNode, + int, + ] +] = [] +node_runtimes: list[tuple[BaseSchedulerNode, float]] = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 + +# counters for tracking to_dtype inserted +cpp_to_dtype_count = 0 + + +@dataclasses.dataclass +class CppOuterLoopFusedCount: + inner_kernel_number: int + local_buffer_number: int = 0 + + +# The length counts the number of outer loop fusions. +cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = [] + +num_comprehensive_padding = 0 +num_matches_for_scatter_upon_const_tensor = 0 + +num_loop_reordering = 0 + +# counter for parallel reduction. 
+parallel_reduction_count = 0 + + +# reset all counters +def reset() -> None: + global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + global cpp_to_dtype_count + global cpp_outer_loop_fused_inner_counts + global num_comprehensive_padding + global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering + global parallel_reduction_count + + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + node_runtimes.clear() + ir_nodes_pre_fusion = 0 + cpp_to_dtype_count = 0 + cpp_outer_loop_fused_inner_counts.clear() + num_comprehensive_padding = 0 + num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 + parallel_reduction_count = 0 + + +@dataclass +class CachedMetricsDeltas: + """ + The subset of metrics we want update across cache hits, e.g., the + FxGraphCache. + """ + + generated_kernel_count: int + generated_cpp_vec_kernel_count: int + ir_nodes_pre_fusion: int + cpp_to_dtype_count: int + num_bytes_accessed: int + num_matches_for_scatter_upon_const_tensor: int + + +def get_metric_fields() -> list[str]: + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] + + +class CachedMetricsHelper: + """ + A helper class to help calculate and apply counter deltas for those + metrics we want to save with cache entries (e.g., FxGraphCache) and + apply on a cache hit. + """ + + def __init__(self) -> None: + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] + + def get_deltas(self) -> CachedMetricsDeltas: + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + return CachedMetricsDeltas(**delta_metrics) + + @staticmethod + def apply_deltas(delta: CachedMetricsDeltas) -> None: + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) + + +REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {} + + +@dataclass +class MetricTable: + table_name: str + column_names: list[str] + + num_rows_added: int = 0 + + def add_row( + self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]] + ) -> None: + if self.table_name not in enabled_metric_tables(): + return + + row_dict = row_fn() + assert len(self.column_names) == len(row_dict), ( + f"{len(self.column_names)} v.s. {len(row_dict)}" + ) + assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), ( + f"{OrderedSet(self.column_names)} v.s. 
{OrderedSet(row_dict.keys())}" + ) + + bn = get_benchmark_name() + # assert bn is not None + row = [bn] + [row_dict[column_name] for column_name in self.column_names] + assert all(isinstance(i, str) for i in row) + self._write_row(cast(list[str], row)) + + def output_filename(self) -> str: + return f"metric_table_{self.table_name}.csv" + + def write_header(self) -> None: + filename = self.output_filename() + with open(filename, "w") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(["model_name"] + self.column_names) + + def _write_row(self, row: list[str]) -> None: + filename = self.output_filename() + if self.num_rows_added == 0 and not os.path.exists(filename): + self.write_header() + + self.num_rows_added += 1 + + for idx, orig_val in enumerate(row): + if isinstance(orig_val, float): + new_val = f"{orig_val:.6f}" + elif orig_val is None: + new_val = "" + else: + new_val = orig_val + row[idx] = new_val + + with open(filename, "a") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(row) + + @staticmethod + def register_table(name: str, column_names: list[str]) -> None: + table = MetricTable(name, column_names) + REGISTERED_METRIC_TABLES[name] = table + + +MetricTable.register_table( + "slow_fusion", + [ + "kernel1_path", + "kernel1_latency", + "kernel2_path", + "kernel2_latency", + "fused_kernel_path", + "fused_kernel_latency", + "slow_down_ratio", + ], +) + +# track the fusion statistics for each graph +MetricTable.register_table( + "graph_stats", + [ + "graph_id", + "num_nodes_before_fusion", + "num_nodes_after_fusion", + ], +) + +# track the perf difference between persistent reduction and non-persistent +# reductions +MetricTable.register_table( + "persistent_red_perf", + [ + "kernel0_path", + "kernel1_path", + "kernel2_path", + "kernel3_path", + "kernel0_latency", + "kernel1_latency", + "kernel2_latency", + "kernel3_latency", + "size_hints", + "reduction_hint", + ], +) + +# Log the fusion failures due to indexing mismatch +MetricTable.register_table( + "fusion_failure_due_to_indexing_mismatch", + [ + "pre_grad_graph_id", + "post_grad_graph_id", + "node1_name", + "node2_name", + "node1_debug_str", + "node2_debug_str", + "common_buffer_names", + "failure_reason", + ], +) + +# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint +MetricTable.register_table( + "kernel_metadata", + [ + "kernel_name", + "kernel_path", + "kernel_category", # pointwise/reduction/foreach etc. + "size_hints", + "reduction_hint", + "line_of_code", + "num_load", + "num_store", + "num_for_loop", + "num_atomic_add", + "num_args", + # xyz numel can be different to size_hints since size_hints are rounded + # up to the nearest power of 2. + # Inductor kernel will burn in the xyz numel in kernel code for static + # shape kernels. + # Logging them will be helpful to find unaligned shape for reduction + "xnumel", + "ynumel", + "rnumel", + "kernel_args_num_gb", + ], +) + + +def _parse_kernel_fn_code(kernel_module_code: str) -> str: + """ + The kernel_module_code is the python module that contains kernel function code. 
+ kernel function is the proper triton kernel function annotated with + @triton.jit + """ + from .codecache import PyCodeCache + from .wrapper_benchmark import get_triton_kernel + + mod = PyCodeCache.load(kernel_module_code) + kernel = get_triton_kernel(mod) + # kernel is a CachingAutotune; kernel.fn is the JITFunction; + # kernel.fn.fn is the function being decorate by triton.jit + return inspect.getsource(kernel.fn.fn) + + +def _parse_kernel_line_of_code(proper_kernel_fn_code: str) -> int: + """ + Return the line of code for the kernel excluding the decorators. + """ + return len(proper_kernel_fn_code.splitlines()) + + +def _parse_size_hints(kernel_module_code: str, kernel_category: str) -> Optional[str]: + if kernel_category == "foreach": + # foreach kernel does not have size_hints + return None + m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) + assert m, "size_hints missing!" + return m.group(1) + + +def _parse_reduction_hint( + kernel_category: str, kernel_module_code: str +) -> Optional[str]: + if kernel_category not in ("reduction", "persistent_reduction"): + return None + m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) + assert m, "reduction_hint not found in kernel source code!" + return m.group(1) + + +def _count_pattern(proper_kernel_fn_code: str, pattern: str) -> int: + return proper_kernel_fn_code.count(pattern) + + +def _count_args(proper_kernel_fn_code: str) -> int: + def_line = proper_kernel_fn_code.splitlines()[0] + assert def_line.startswith("def ") + start_idx = def_line.index("(") + end_idx = def_line.index("):") + decl_csv = def_line[start_idx + 1 : end_idx] + comps = decl_csv.split(",") + return len(comps) + + +def _parse_proper_kernel_fn_code(kernel_fn_code: str) -> str: + """ + Skip decorators. + """ + start_pos = kernel_fn_code.index("def ") + return kernel_fn_code[start_pos:] + + +def _parse_numel(proper_kernel_fn_code: str, numel_arg_name: str) -> Optional[int]: + m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) + if m: + return int(m.group(1)) + else: + return None + + +def _parse_kernel_args_num_gb( + kernel_fn_code: str, kernel_category: str +) -> Optional[float]: + """ + inductor meta looks like: + inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, + """ + m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) + if m: + return float(m.group(1)) + else: + """ + There are a few cases that kernel_num_gdb field can be missing: + 1. the field will be missing if config.benchmark_kernel and + config.profile_bandwidth are false + 2. even if config.benchmark_kernel or config.profile_bandwidth is true. + foreach kernel does not have kernel_num_gb field in the metadata + """ + return None + + +def log_kernel_metadata( + kernel_name: str, kernel_path: str, kernel_module_code: str +) -> None: + """ + An utility to log kernel metadata. We may parse metadata from kernel source code here. + + It's fine to parse the generated kernel code here since the logging is + disabled by default. It would hurt compilation time. 
+ """ + from .wrapper_benchmark import get_kernel_category_by_source_code + + kernel_category = get_kernel_category_by_source_code(kernel_module_code) + reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) + size_hints = _parse_size_hints(kernel_module_code, kernel_category) + kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) + + proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) + + # the line of code excluding the decortors + kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) + + get_metric_table("kernel_metadata").add_row( + lambda: { + "kernel_name": kernel_name, + "kernel_path": kernel_path, + "kernel_category": kernel_category, + "size_hints": size_hints, + "reduction_hint": reduction_hint, + "line_of_code": kernel_line_of_code, + "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), + "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), + "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), + "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), + "num_args": _count_args(proper_kernel_fn_code), + "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), + "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), + "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), + "kernel_args_num_gb": _parse_kernel_args_num_gb( + kernel_fn_code, kernel_category + ), + } + ) + + +def purge_old_log_files() -> None: + """ + Purge the old log file at the beginning when the benchmark script runs. + Should do it in the parent process rather than the child processes running + each individual model. + """ + for name, table in REGISTERED_METRIC_TABLES.items(): + if name in enabled_metric_tables(): + filename = table.output_filename() + if os.path.exists(filename): + os.unlink(filename) + + table.write_header() + + +def enabled_metric_tables() -> OrderedSet[str]: + return enabled_metric_tables_impl(config.enabled_metric_tables) + + +@lru_cache +def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]: + enabled: OrderedSet[str] = OrderedSet() + for name in config_str.split(","): + name = name.strip() + if not name: + continue + assert name in REGISTERED_METRIC_TABLES, ( + f"Metric table name {name} is not registered" + ) + enabled.add(name) + return enabled + + +def is_metric_table_enabled(name: str) -> bool: + return name in enabled_metric_tables() + + +def get_metric_table(name: str) -> MetricTable: + assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" + return REGISTERED_METRIC_TABLES[name] diff --git a/phivenv/Lib/site-packages/torch/_inductor/mkldnn_ir.py b/phivenv/Lib/site-packages/torch/_inductor/mkldnn_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..9d439ce4846bb241a4ba57736857436f6bc47050 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1316 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Optional + +import sympy + +import torch +from torch._prims_common import make_channels_last_strides_for +from torch.utils._ordered_set import OrderedSet + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + get_device_type, + ir_node_to_tensor, + is_contiguous_storage_and_layout, + Layout, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + MutationOutput, + NoneLayout, + TensorBox, +) +from .utils import convert_shape_to_inductor, pad_listlike, SUPPORTED_MKLDNN_DEVICES +from .virtualized import V + + +def 
_prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: Sequence[int], + stride: Sequence[int], + dilation: Sequence[int], + groups: int, + transposed: bool = False, + output_padding: Optional[Sequence[int]] = None, + quantize_args: Optional[list["TensorBox"]] = None, + other: Optional["TensorBox"] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU/XPU device since conv post-op fusion kernel is only + supported on CPU/XPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_output_size + def _conv_output_size(input_size, weight_size, padding, stride, dilation=None): + has_dilation = dilation is not None + dim = len(input_size) + output_size = [] + output_size.append(input_size[0]) + output_size.append(weight_size[0]) + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[ + d - 2 + ] + 1 + output_size.append(output_size_d) + return output_size + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] 
+ def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + weight_size.extend(prepacked_weight_size[d] for d in range(2, dim)) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, (int, sympy.core.numbers.Integer)) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + x_shape = list(x_fake.shape) + weight_shape = list(weight_fake.shape) + if len(x_shape) != len(weight_shape): + assert len(x_shape) == 3 and len(weight_shape) == 4 + weight_shape.pop(2) + output_size = _conv_output_size( + x_shape, + weight_shape, + padding, + stride, + dilation, + ) + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes or if is xpu. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. + dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if ( + dynamic_shapes or get_device_type(x) == "xpu" + ) and is_contiguous_storage_and_layout(x): + output_stride = FlexibleLayout.contiguous_strides(output_size) + # Currently we don't support channel last for the situation that stride of input's batch dim is 0, + # eg. input_size = (1, 1280, 64, 64), but input_stride=(0, 1, 81920, 1280). + # So we use NCHW hear instead. 
+ # Different with cpu, cpu conv always use channels_last for convolution when weight is prepacked, + # but xpu does not do the prepack, so the problem exposed here is only for xpu. + # TODO support channels_last for such zero stride input. + elif get_device_type(x) == "xpu" and x.get_stride()[0] == 0: + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert get_device_type(x) == get_device_type(weight) + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES + inputs = [x] + + if quantize_args is not None: + x_scale, x_zero_point, w_scale, w_zero_point = quantize_args + x_scale.realize() + x_zero_point.realize() + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point] + else: + inputs += [weight] + + if other is not None: + other = cls.require_stride_order(other, req_stride_order) + assert isinstance(other, TensorBox) + inputs += [other] + + kernel_layout = FixedLayout( + x.get_device_or_error(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order, other + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + quantize_args: Optional[list["TensorBox"]] = None, + other: Optional["TensorBox"] = None, + binary_sum: bool = False, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. 
+ # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert get_device_type(x) == get_device_type(weight) + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES + inputs = [x] + + if quantize_args is not None: + x_scale, x_zero_point, w_scale, w_zero_point = quantize_args + x_scale.realize() + x_zero_point.realize() + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point] + else: + inputs += [weight] + + if other is not None: + if binary_sum: + other = cls.require_stride_order(other, req_stride_order) + inputs = inputs + [other] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: list[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order, other + + +def _create_output_node(packed): + output_ir = MultiOutput( + packed.get_layout(), + packed, + [], + ) + packed.layout = MultiOutputLayout(device=packed.get_device()) + packed.outputs = [output_ir] + return output_ir + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.default, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + attr, + scalars: Optional[list[Any]], + algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + packed = ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.binary, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary", + ) + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: 
Optional[str], + unary_scalars: Optional[list[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + _, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ) -> None: + # Due to constrain of op.call, other (Tensor&) should be at input[0] + self.device_type = get_device_type(inputs[0]) + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary_", + ) + + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=inputs[0].get_device()), inputs[0], self), + MutationOutput(NoneLayout(device=inputs[1].get_device()), inputs[1], self), + ] + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[list[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + _, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(device=inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_transpose_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + output_padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups_: int, + attr, + scalars: Optional[list[Any]], + algorithm, 
+ ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + packed = ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv_pointwise.default, + cpp_kernel_name="aoti_torch_cpu__qconv_pointwise_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + [x_scale, x_zero_point, w_scale, w_zero_point], + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. 
+ kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum, b] + - const_args = [stride, padding, dilation, groups, o_scale, o_zp, + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum] + - const_args [b, stride, padding, dilation, groups, o_scale, o_zp, + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 8 + self.idx_for_inplace_sum = 6 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.binary, + cpp_kernel_name=("aoti_torch_cpu__qconv2d_pointwise_binary_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + qaccum: "TensorBox", + bias: "TensorBox", + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + accum_scale, + accum_zero_point, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + _kernel_layout, + req_stride_order, + qaccum, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + [x_scale, x_zero_point, w_scale, w_zero_point], + qaccum, + ) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zero_point, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert binary_attr == "sum", ( + "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + ) + + V.graph.mark_buffer_mutated(qaccum.get_name()) + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(device=qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + # Return accum since it has been inplace changed. 
+ return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkl._mkl_linear.default, + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.default, + cpp_kernel_name="aoti_torch_cpu__linear_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, _ic = x.get_size() + oc, _ic = w.get_size() + output_size = list(m) + [oc] + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + packed = LinearUnary( + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=output_size, + ), + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.binary, + cpp_kernel_name="aoti_torch_cpu__linear_pointwise_binary", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + @classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, _ic = x.get_size() + oc, _ic = w.get_size() + output_size = list(m) + [oc] + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + packed = LinearBinary( + layout=FixedLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=output_size, + ), + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + 
has_bias=True, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=(torch.ops.onednn.qlinear_pointwise.tensor), + cpp_kernel_name=("aoti_torch_cpu__qlinear_pointwise_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + [x_scale, x_zero_point, w_scale, w_zero_point], + ) + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + ) -> None: + """ + if bias is not None + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2, bias] + - const_args is: [o_scale, o_zp, + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2] + - const_args is: [bias, o_scale, o_zp, + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.idx_for_inplace_sum = 6 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=(torch.ops.onednn.qlinear_pointwise.binary_tensor), + cpp_kernel_name="aoti_torch_cpu__qlinear_pointwise_binary_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self) -> Sequence[str]: + binary_post_op = self.constant_args[-5] + if binary_post_op == "sum": + return [self.inputs[self.idx_for_inplace_sum].get_name()] + else: + return [] + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + other: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + 
unary_post_op_args, + unary_post_op_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + other, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + [x_scale, x_zero_point, w_scale, w_zero_point], + other, + binary_post_op == "sum", + ) + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + V.graph.mark_buffer_mutated(other.get_name()) + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(device=other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + # Return other since it has been inplace changed. + return packed.inputs[packed.idx_for_inplace_sum] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.aten.mkldnn_rnn_layer.default, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: list[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(device=x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + # C shim call requires all the outputs to be passed in, and thus the last + # dummy return value is added. 
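+ # The four outputs are the layer output y, the final hidden state hy, the
+ # final cell state cy, and the [1]-shaped dummy mentioned above.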
+ output_sizes = [output_shape, hy_shape, cy_shape, [1]] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + [1], + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + packed.outputs = output_ir + + return output_ir + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + return super().codegen(wrapper) + + +# Add this IR so that we can include shim_cpu.h for cpp_wrapper +class WeightInt4PackMatmul(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + inputs = [x, w, qGroupSize, qScalesAndZeros] + constant_args = () + """ + assert len(inputs) == 4 + assert len(constant_args) == 0 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=(torch.ops.quantized.int4mm_packed_weight_cpu.default), + cpp_kernel_name=("aoti_torch_cpu__weight_int4pack_mm_cpu_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + w: "TensorBox", + qGroupSize: "TensorBox", + qScalesAndZeros: "TensorBox", + ): + inputs = [x, w, qGroupSize, qScalesAndZeros] + *m, _ = x.get_size() + n, _ = w.get_size() + output_size = list(m) + [n] + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), # type: ignore[arg-type] + x.get_dtype(), + output_size, + output_stride, + ) + return WeightInt4PackMatmul( + layout=kernel_layout, + inputs=inputs, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/mkldnn_lowerings.py b/phivenv/Lib/site-packages/torch/_inductor/mkldnn_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..a32d021485382a19f45ec5c02fe22623bca8281b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/mkldnn_lowerings.py @@ -0,0 +1,1347 @@ +# mypy: allow-untyped-defs +import functools +from typing import Optional + +import torch +import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args + +from . 
import config, ir +from .codegen.cpp_gemm_template import CppGemmTemplate +from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .ir import TensorBox +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, + view, +) +from .select_algorithm import ( + autotune_select_algorithm, + ChoiceCaller, + ExternKernelChoice, +) +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template +from .virtualized import ops, OpsValue, V + + +def create_int8_compensation( + W_tensor: torch.Tensor, + packed_weight: ir.TensorBox, + x_scale: ir.TensorBox, + x_zp: ir.TensorBox, + w_scale: ir.TensorBox, +) -> tuple[bool, ir.TensorBox, Optional[ir.TensorBox]]: + use_int8_fast_compensation_path = False + weight_compens = None + x_w_scale = None + if all( + isinstance(item, ir.TensorBox) + and item.get_name() in V.graph.constants + and hasattr(item.data, "data") + and isinstance(item.data.data, ir.ConstantBuffer) + for item in [x_scale, x_zp, w_scale] + ): + use_int8_fast_compensation_path = True + x_w_scale_tensor = ( + V.graph.constants[x_scale.get_name()] + * V.graph.constants[w_scale.get_name()] + ) + x_w_scale = V.graph.add_tensor_constant( + x_w_scale_tensor, + name=packed_weight.get_name() + "_x_w_compens", + ) + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + x_zp_tensor = V.graph.constants[x_zp.get_name()] + weight_compens_tensor = weight_compens_tensor * x_w_scale_tensor * x_zp_tensor + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + else: + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + return ( + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) + + +def codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path: bool, + input: OpsValue, + _weight_compo: OpsValue, + _x_scale: Optional[OpsValue], + _x_zp: Optional[OpsValue], + _w_scale: Optional[OpsValue], + _x_w_scale: Optional[OpsValue], +) -> OpsValue: + if use_int8_fast_compensation_path: + temp = ops.sub( + ops.mul( + input, + _x_w_scale, + ), + _weight_compo, + ) + else: + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + # NOTE: We will apply compensation even if the x_zp is 0 for int8 quantization. + # That's because when torch.compile is invoked for dynamic quantization, + # x might coincidentally have such values that x_zp might be zero despite + # asymmetric quantization. + # Besides, if x_zp is dummy for int8 x, or if x is statically quantized, + # we'd still perform that redundant compute to avoid making the code messy + # because we discovered that redundant computation of compensation did not + # lead to performance degradation with the input shapes tested. 
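+ # In full, this branch computes the dequantized accumulation (uint8/int8
+ # activation with zero point x_zp, symmetrically quantized int8 weight, w_zp == 0):
+ #   temp[n] = x_scale * w_scale[n] * sum_k (q_x[k] - x_zp) * q_w[k, n]
+ #           = x_scale * w_scale[n] * acc_s32[n]
+ #             - x_scale * w_scale[n] * x_zp * sum_k q_w[k, n]
+ # where acc_s32 ("input" here) is the raw int32 GEMM output and sum_k q_w[k, n]
+ # is the precomputed per-output-channel weight compensation; the fast path above
+ # folds the constant scale / zero-point products into tensors ahead of time.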
+ temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compo, + ), + ) + return temp + + +def grouped_gemm_lowering( + x: TensorBox, + w: list[TensorBox], + b: list[TensorBox], + attr=None, + scalars=None, + algorithm=None, + layout=None, +): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + num_gemm = len(w) + + assert config.max_autotune or config.max_autotune_gemm + b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b] + + choices: list[ChoiceCaller] = [] + *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout) + + kwargs = dict( + has_bias=[bias is not None for bias in b], + trans_w=True, + epilogue_creator=None, + act_mapping=dict.fromkeys(range(num_gemm), x), + ) + + input_nodes = [x, *w] + input_nodes.extend([bias for bias in b if bias is not None]) + + CppGroupedGemmTemplate.add_choices( + choices, + layout, + input_nodes, + **kwargs, # type: ignore[arg-type] + ) + + assert len(choices) != 0 + result = autotune_select_algorithm( + "grouped_gemm", + choices, + input_nodes, + layout, + ) + template_buf = result.data.data + return_bufs = [ + ir.MultiOutput(layout, template_buf, [(list, gemm_idx)]) + for gemm_idx in range(num_gemm) + ] + template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device()) + template_buf.outputs = return_bufs + return_tensors = [ + ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm) + ] + if len(x_size) > 2: + for gemm_idx in range(num_gemm): + return_tensors[gemm_idx] = view( + return_tensors[gemm_idx], + (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]), + ) + return return_tensors + + +grouped_gemm_lowering._inductor_lowering_function = True # type: ignore[attr-defined] + + +def register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + from . 
import mkldnn_ir + + aten_mkldnn_linear_unary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearUnary.create, + ) + aten_mkldnn_linear_binary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise.binary, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearBinary.create, + ) + aten_mkldnn_qlinear_unary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create, + ) + aten_mkldnn_qlinear_binary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise.binary, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create, + ) + cpu_needs_realized_inputs = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, + w: TensorBox, + b: TensorBox, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) + if use_cpp_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr( + buf, attr, scalars=scalars, algorithm=algorithm + ) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=None if attr == "none" else epilogue_creator, + ) + if b is not None: + kwargs["input_indices"] = [2, 0, 1] # type: 
ignore[assignment] + CppGemmTemplate.add_choices( + choices, + layout, + [x, w] if b is None else [x, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_unary.bind( + [x, w] if b is None else [x, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_unary", + choices, + [x, w] if b is None else [x, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary( + x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + y_size = y.get_size() + if len(y_size) > 2: + y = view(y, [-1, y_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w, y = mm_args( + x, transposed_w, y, layout=layout + ) + if use_cpp_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr(buf, attr, other=y) + + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=epilogue_creator, + ) + kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1] + CppGemmTemplate.add_choices( + choices, + layout, + [x, y, w] if b is None else [x, y, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_binary.bind( + [x, y, w] if b is None else [x, y, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 2: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_binary", + choices, + [x, y, w] if b is None else [x, y, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: list[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + mkldnn_ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + 
batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + # To align with qlinear where x_scale and x_zp are converted to Tensor + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + return TensorBox.create( + mkldnn_ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary_tensor, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + accum: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + # To align with qlinear where x_scale and x_zp are converted to Tensor + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype conversion here. + accum = to_dtype(accum, output_dtype) + return TensorBox.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + layout=None, + ): + assert packed_weight.get_dtype() is torch.int8, ( + "Only int8 weights are supported by oneDNN qlinear." 
+ ) + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if all(dim == 1 for dim in x_scale.get_size()): + # Corner-case discovered with LLaMA series. + # If all outer dims of x_scale are 1, make it a 0D tensor. + # Otherwise, epilogue creator will run into indexing issues. + x_scale = view(x_scale, []) + assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D" + + if x_zp is None: + # If x_zp is None, x is int8 quantized per-tensor and its scale is not reshaped, + # then the codegened code would segfault if we don't create a tensor for x_zp. + # It's safe to do so since x is a symmetrically quantized int8 tensor. + # Moreover, oneDNN qlinear API doesn't accept None value for zp + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + assert x_zp.get_numel() == 1, "x_zp is incompatible with oneDNN qlinear" + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + if w_zp is None: + # If w_zp is None, then it's a dummy tensor created to denote the + # absence of a zero point, and thus w is int8 symmetrically quantized. + # Moreover, oneDNN qlinear API doesn't accept None value for zp + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + # W_zp might be a ConstantBuffer with int64, convert it to int32 + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + + bias_dtype = None if bias is None else bias.get_dtype() + choices: list[ChoiceCaller] = [] + + if config.max_autotune or config.max_autotune_gemm: + *_, layout, x, packed_weight = mm_args( + x, packed_weight, layout=layout, out_dtype=output_dtype + ) + + if ( + # GEMM template currently only supports symmetrically quantized weights + isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) + ) and use_cpp_gemm_template(layout, x, packed_weight): + W_tensor = V.graph.constants[packed_weight.get_name()].to_dense() + + ( + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) = create_int8_compensation( + W_tensor, + packed_weight, + x_scale, + x_zp, + w_scale, + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + ] + input_loader = input_buffer.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_w_scale_loader = None + if use_int8_fast_compensation_path: + assert x_w_scale is not None + x_w_scale_loader = 
x_w_scale.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + + _x_scale = None + _x_zp = None + _w_scale = None + if not use_int8_fast_compensation_path: + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compo = weight_compens_loader(weight_compens_index) + _x_w_scale = None + if use_int8_fast_compensation_path: + assert x_w_scale_loader is not None + _x_w_scale = x_w_scale_loader(weight_compens_index) + # Step 1: Compute s8s8->s32 or u8s8->s32 GEMM & then apply compensation + temp = codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path, + input, + _weight_compo, + _x_scale, + _x_zp, + _w_scale, + _x_w_scale, + ) + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 & s8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 3: Doing the unary post op fusion + if attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, attr, scalars=scalars, algorithm=algorithm + ) + + # Step 4: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype in [torch.uint8, torch.int8]: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + if output_dtype == torch.uint8: + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + else: + qmin, qmax = _create_constants( + -128, 127, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + assert x.get_dtype() in [torch.uint8, torch.int8] + CppGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + input_indices=[0, 3, 1, 2, 4, 5] + if bias is None + else [6, 0, 3, 1, 2, 4, 5], + ) + if len(choices) == 0 or 
use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + post_op_name=attr, + post_op_args=scalars, + post_op_algorithm=algorithm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_unary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], # packed weight + 4: lambda x: V.graph.constants[x.get_name()], # weight scale + 5: lambda x: V.graph.constants[x.get_name()], # weight zp + 6: lambda x: V.graph.constants[x.get_name()], # bias + } + if isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_scale), + ir.ConstantBuffer, + ): + # x is statically quantized + input_gen_fns[1] = lambda x: V.graph.constants[x.get_name()] + if isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ): + input_gen_fns[2] = lambda x: V.graph.constants[x.get_name()] + + result = autotune_select_algorithm( + "qlinear_unary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + x2: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + layout=None, + ): + x_size = x.get_size() + x2_size = x2.get_size() + assert len(x_size) == len(x2_size) + if len(x_size) > 2 and binary_attr == "add": + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + x2 = view(x2, [-1, x2_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) == float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if all(dim == 1 for dim in x_scale.get_size()): + # Corner-case discovered with LLaMA series. + # If all outer dims of x_scale are 1, make it a 0D tensor. + # Otherwise, epilogue creator will run into indexing issues. 
+ x_scale = view(x_scale, []) + assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D" + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) == int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype conversion here. + x2 = to_dtype(x2, output_dtype) + else: + assert x2.get_dtype() == output_dtype, ( + "dtype of accum for qlinear post op sum should be the same as output" + ) + x2_dtype = x2.get_dtype() + bias_dtype = bias.get_dtype() if bias is not None else None + choices: list[ChoiceCaller] = [] + if ( + config.max_autotune or config.max_autotune_gemm + ) and binary_attr == "add": # Support inplace sum fusion + *_, layout, x, packed_weight, x2 = mm_args( + x, packed_weight, x2, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()] + W_tensor = W_tensor.to_dense() + ( + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) = create_int8_compensation( + W_tensor, + packed_weight, + x_scale, + x_zp, + w_scale, + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + ] + + input_loader = input_buffer.make_loader() + x2_loader = x2.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_w_scale_loader = None + if use_int8_fast_compensation_path: + assert x_w_scale is not None + x_w_scale_loader = x_w_scale.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + _x2 = 
x2_loader(index) + _x_scale = None + _x_zp = None + _w_scale = None + weight_compens_index = (index[-1],) + if not use_int8_fast_compensation_path: + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + # MicroKernel Output is with int32: cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + _weight_compo = weight_compens_loader(weight_compens_index) + _x_w_scale = None + if use_int8_fast_compensation_path: + assert x_w_scale_loader is not None + _x_w_scale = x_w_scale_loader(weight_compens_index) + # Step 1: Doing compensation to cvt fp32 + temp = codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path, + input, + _weight_compo, + _x_scale, + _x_zp, + _w_scale, + _x_w_scale, + ) + # Step 2: add Bias if applicable + if bias is not None: + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + # Step 3: Binary add + nonlocal x2_dtype + assert x2_dtype in [torch.float32, torch.bfloat16] + if x2_dtype == torch.bfloat16: + _x2 = ops.to_dtype(_x2, torch.float32) + temp = ops.add(temp, _x2) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 4: Unary post op if has + if unary_attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, + unary_attr, + scalars=unary_scalars, + algorithm=unary_algorithmm, + ) + + # Step 5: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype in [torch.uint8, torch.int8]: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + if output_dtype == torch.uint8: + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + else: + qmin, qmax = _create_constants( + -128, 127, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=torch.uint8, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + CppGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + # Reorder bias and x2 + input_indices=[0, 3, 1, 2, 4, 5, 6] + if bias is None + else [7, 0, 3, 1, 2, 4, 5, 6], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + 
other_scale=x2_scale, + other_zp=x2_zp, + binary_post_op=binary_attr, + binary_alpha=alpha, + unary_post_op=unary_attr, + unary_post_op_args=unary_scalars, + unary_post_op_algorithm=unary_algorithmm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_binary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + } + if bias is not None: + input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias + result = autotune_select_algorithm( + "qlinear_binary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2 and binary_attr == "add": + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, + ) + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: Optional[TensorBox], + batch_size, + *, + layout=None, + ): + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_gemm_template(layout, x, transposed_w): + CppGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ) + + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants + # packed_w is a mkldnn tensor which we can't generate directly + # so we use the weights from the original tensor in autotune. 
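+ # input_gen_fns maps an input's position in the node list to a callable that
+ # supplies the concrete tensor used when benchmarking autotune candidates,
+ # so indices 1 and 2 (packed_w, orig_w) are fed the frozen graph constants.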
+ input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + result: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) + else: + pass diff --git a/phivenv/Lib/site-packages/torch/_inductor/mock_cache.py b/phivenv/Lib/site-packages/torch/_inductor/mock_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..9955047b666a99dbccab219707ae46588a4f0b3a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/mock_cache.py @@ -0,0 +1,273 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import contextlib +import dataclasses +import sys +import threading +from typing import Any, Callable, Optional, TYPE_CHECKING +from typing_extensions import override, Self +from unittest.mock import patch + +from torch._inductor import config +from torch._inductor.remote_cache import RemoteCacheBackend + + +if TYPE_CHECKING: + from types import TracebackType + + +@dataclasses.dataclass +class Stats: + num_put: int = 0 + num_get_hit: int = 0 + num_get_miss: int = 0 + + def __iadd__(self, other: Stats) -> Self: + self.num_put += other.num_put + self.num_get_hit += other.num_get_hit + self.num_get_miss += other.num_get_miss + return self + + def reset(self) -> None: + self.num_put = 0 + self.num_get_hit = 0 + self.num_get_miss = 0 + + def __str__(self) -> str: + return "".join( + ( + f"puts: {self.num_put}, ", + f"misses: {self.num_get_miss}, ", + f"hits: {self.num_get_hit}, ", + ) + ) + + def __eq__(self, other: object) -> bool: + # Dataclass's default __eq__ checks that the types are the same so can't + # be used with _GlobalItemStats. + return ( + isinstance(other, (Stats, _GlobalItemStats)) + and self.num_put == other.num_put + and self.num_get_hit == other.num_get_hit + and self.num_get_miss == other.num_get_miss + ) + + +class _GlobalItemStats(Stats): + cache: dict[str, object] + + def __init__(self) -> None: + super().__init__() + self.cache = {} + + def reset(self) -> None: + super().reset() + self.cache = {} + + +# The cache states are thread-local so if we're running multiple tests at once +# they won't cross contaminate. However - it needs to be "global" because we +# allow code to create new cache clients which refer to the same cache (because +# it's a remote cache). 
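+ # Illustrative usage sketch (the test body below is hypothetical, not part of
+ # this module): wrap the code under test in PatchCaches so each cache backend
+ # is replaced by MockBackend, then assert on the counters collected in
+ # `global_stats`, e.g.
+ #
+ #     with PatchCaches():
+ #         run_compile()   # first run:  expect a miss, then a put
+ #         run_compile()   # second run: expect a hit
+ #     assert global_stats.fx_graph == Stats(num_put=1, num_get_hit=1, num_get_miss=1)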
+ + +class _GlobalStats(threading.local): + def __init__(self) -> None: + self.autotune_local = _GlobalItemStats() + self.autotune_remote = _GlobalItemStats() + self.bundled_autotune = _GlobalItemStats() + self.fx_graph = _GlobalItemStats() + self.triton = _GlobalItemStats() + self.aot_autograd = _GlobalItemStats() + self.dynamo_pgo = _GlobalItemStats() + + def reset(self) -> None: + self.autotune_local.reset() + self.autotune_remote.reset() + self.bundled_autotune.reset() + self.fx_graph.reset() + self.triton.reset() + self.aot_autograd.reset() + self.dynamo_pgo.reset() + + def get_stat(self, name: str) -> _GlobalItemStats: + return getattr(self, name) + + def report(self): + subs = ( + ("autotune_local", self.autotune_local), + ("autotune_remote", self.autotune_remote), + ("bundled_autotune", self.bundled_autotune), + ("fx_graph", self.fx_graph), + ("triton", self.triton), + ("aot_autograd", self.aot_autograd), + ("dynamo_pgo", self.dynamo_pgo), + ) + + print("Cache Stats:", file=sys.stderr) + for name, sub in subs: + print(f" {name}: {sub}", file=sys.stderr) + + print("Cache Entries:", file=sys.stderr) + for name, sub in subs: + if sub.cache: + print(f" {name}:", file=sys.stderr) + for k, v in sorted(sub.cache.items()): + v = repr(v) + if len(v) > 100: + v = v[:100] + "..." + print(f" {k!r}: {v}", file=sys.stderr) + + +global_stats = _GlobalStats() + + +class MockBackend(RemoteCacheBackend[Any]): + def __init__(self, name: str) -> None: + self._name = name + + @staticmethod + def with_name(name: str) -> Callable[[], MockBackend]: + def wrapper() -> MockBackend: + return MockBackend(name) + + return wrapper + + @override + def _get(self, key: str) -> Optional[Any]: + stat = global_stats.get_stat(self._name) + if key in stat.cache: + stat += Stats(num_get_hit=1) + return stat.cache.get(key) + else: + stat += Stats(num_get_miss=1) + return None + + @override + def _put(self, key: str, data: Any) -> None: + stat = global_stats.get_stat(self._name) + stat += Stats(num_put=1) + stat.cache[key] = data + + +# List of configs for each cache +_CACHE_CONFIG_EN = ( + "fx_graph_cache", + "fx_graph_remote_cache", + "autotune_local_cache", + "autotune_remote_cache", + "bundled_autotune_remote_cache", +) + + +class PatchCaches(contextlib.AbstractContextManager): + @classmethod + def setUp(cls): + # If this test is using PatchCaches then disable all the caches by + # default, letting the tests turn them on explicitly. This is because + # tests using PatchCaches will often want to check stats explicitly. 
+ cls._savedCacheState = {} + for name in _CACHE_CONFIG_EN: + if hasattr(config, name): + cls._savedCacheState[name] = getattr(config, name) + setattr(config, name, False) + + @classmethod + def tearDown(cls): + # Restore cache defaults + for name in _CACHE_CONFIG_EN: + delattr(config, name) + if name in cls._savedCacheState: + setattr(config, name, cls._savedCacheState[name]) + + def __init__(self) -> None: + self._stack = contextlib.ExitStack() + + def __enter__(self) -> Self: + global_stats.reset() + self._stack.__enter__() + + ctx = patch( + "torch._inductor.runtime.autotune_cache.LocalAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_local"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_remote"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteAOTAutogradCache.backend_override_cls", + MockBackend.with_name("aot_autograd"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + + if config.is_fbcode(): + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_remote"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls", + MockBackend.with_name("triton"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAOTAutogradCache.backend_override_cls", + MockBackend.with_name("aot_autograd"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self._stack.__exit__(exc_type, exc_value, traceback) diff --git a/phivenv/Lib/site-packages/torch/_inductor/ops_handler.py b/phivenv/Lib/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..7c46a8e26132f4a6537bde234c55c2c3482b4b72 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,1147 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import itertools +import re +import warnings +from io import StringIO +from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union +from unittest.mock import patch + +import sympy + +import torch +import torch.utils._pytree as 
pytree + +from ..utils._ordered_set import OrderedSet +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add"]] +ReductionType = Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "xor_sum", +] + + +def _arg_str(a: object) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# See OpDecompositions for superclass that desugars operations like reciprocal/square. +class OpsHandler(Generic[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. + + Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides), + which means you will not get type errors for those methods. We have tests in + test/inductor/test_op_completeness.py which check that all operators are implemented after + all the metaprogramming has run. + """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + raise NotImplementedError + + def load_seed(self, name: str, offset: T) -> T: + """Computes inductor_prims.lookup_seed.""" + raise NotImplementedError + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + raise NotImplementedError + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + raise NotImplementedError + + def randint64(self, seed: T, offset: T, low: T, high: T) -> T: + """Computes inductor_prims.randint. offset has dtype int32.""" + raise NotImplementedError + + def masked(self, mask: T, body: Callable[[], T], other: T) -> T: + """ + Computes body, but only perform loads/stores if the boolean mask + evaluates to true. For example, you would use this if you needed to + perform an indirect load that may not be valid on some elements; + without masking, invalid accesses can cause IMAs. When mask is true, + the result is the result of body; otherwise it is other. Here, `other` + needs to be a constant. + + Contrast this with ops.where, which can multiplex between two values + that have been unconditionally computed. 
+ """ + raise NotImplementedError + + def where(self, condition: T, input: T, other: T) -> T: + """ + Computes torch.where: when condition is true, return input; otherwise return other. + """ + raise NotImplementedError + + def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: + """ + Converts a sympy expression into a scalar of type dtype. expr is typically + an indexing expression, thus the name; however, it can also be used in + non-indexing situations. + """ + raise NotImplementedError + + def to_dtype( + self, + x: T, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> T: + """ + Convert x to dtype. src_dtype can be optionally set to specify what the original + dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). + """ + raise NotImplementedError + + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + raise NotImplementedError + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: + """ + Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) + src_dtype must be the original type of x. + """ + raise NotImplementedError + + def identity(self, x: T) -> T: + """ + Returns x as is. This is used to trigger CSE. + """ + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operations are only available in a "kernel" context. Check + # torch._inductor.codegen.common.CSEProxy for their typical implementation + # in op handler (routing to their respective implementations in the kernel + # handler) + # + # Importantly, inside a kernel, indexing and mask variables are available + # in scope, which are typically used by sympy.Expr indexing. + + def indirect_indexing( + self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True + ) -> sympy.Expr: + """ + Convert an integral x into a sympy.Expr that can be subsequently used in + indexing computation. 'size' represents an upper bound on what valid + indexes can be; when 'check' is True, we check that the x is in bounds. + + NB: This is typically mandatory to implement for any analysis, because you + MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). + """ + raise NotImplementedError + + def load(self, name: str, index: sympy.Expr) -> T: + """ + Load from the memory location 'name', offset by some indexing expression 'index'. 
+ """ + raise NotImplementedError + + def store( + self, + name: str, + index: sympy.Expr, + value: T, + mode: StoreMode = None, + ) -> None: + """ + Store 'value' to the memory location 'name' offset by 'expr'. If + specified, 'mode' can require the store to be an atomic addition. + """ + raise NotImplementedError + + # TODO: Better explain how the "collective" semantics of these ops; + # remember that the input value is a scalar, you can't reduce on it in the + # traditional sense! + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: T, + ) -> Union[T, tuple[T, ...]]: + """ + Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype', + using 'dtype' as the accumulation dtype for the reduction. The result + is an intermediate computation which should be stored to the final + location using 'ops.store_reduction'. + + Valid reduction types are . For Welford reduction types, this + function returns multiple outputs; consult reduction_num_outputs to + determine the amount in metaprogramming applications. + """ + raise NotImplementedError + + # TODO: in practice, this seems to actually return None, but not returning + # a T makes common __getattr__ idioms not type correctly. Figure out if + # this should be returning something. + def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None: + """ + Store the fully accumulated result of 'reduction' to the memory + location 'name' offset by 'expr'. + """ + raise NotImplementedError + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[[tuple[T, ...], tuple[T, ...]], tuple[T, ...]], + values: tuple[T, ...], + ) -> tuple[T, ...]: + """ + Perform an associative scan on 'value'. + """ + # TODO: Improve the description with some pseudocode + raise NotImplementedError + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[T, ...], + stable: bool, + descending: bool, + ) -> tuple[T, ...]: + """ + Sort values along the reduction dimension. + """ + raise NotImplementedError + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + # See [Note: Inductor bucketize op] + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The following ops have semantics that correspond exactly to the torch + # operation with the same corresponding name. 
+ + def abs(self, x0: T) -> T: + raise NotImplementedError + + def exp(self, x0: T) -> T: + raise NotImplementedError + + def exp2(self, x0: T) -> T: + raise NotImplementedError + + def expm1(self, x0: T) -> T: + raise NotImplementedError + + def sqrt(self, x0: T) -> T: + raise NotImplementedError + + def relu(self, x0: T) -> T: + raise NotImplementedError + + def minimum(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def maximum(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def cos(self, x0: T) -> T: + raise NotImplementedError + + def sin(self, x0: T) -> T: + raise NotImplementedError + + def lgamma(self, x0: T) -> T: + raise NotImplementedError + + def erf(self, x0: T) -> T: + raise NotImplementedError + + def cosh(self, x0: T) -> T: + raise NotImplementedError + + def sinh(self, x0: T) -> T: + raise NotImplementedError + + def acos(self, x0: T) -> T: + raise NotImplementedError + + def acosh(self, x0: T) -> T: + raise NotImplementedError + + def asin(self, x0: T) -> T: + raise NotImplementedError + + def asinh(self, x0: T) -> T: + raise NotImplementedError + + def atan2(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def atan(self, x0: T) -> T: + raise NotImplementedError + + def atanh(self, x0: T) -> T: + raise NotImplementedError + + def copysign(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def erfc(self, x0: T) -> T: + raise NotImplementedError + + def erfinv(self, x0: T) -> T: + raise NotImplementedError + + def frexp(self, x0: T): + raise NotImplementedError + + def hypot(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def log10(self, x0: T) -> T: + raise NotImplementedError + + def log2(self, x0: T) -> T: + raise NotImplementedError + + def nextafter(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def logical_and(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def logical_not(self, x0: T) -> T: + raise NotImplementedError + + def logical_or(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def logical_xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_and(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_not(self, x0: T) -> T: + raise NotImplementedError + + def bitwise_or(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_left_shift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_right_shift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def rsqrt(self, x0: T) -> T: + raise NotImplementedError + + def log1p(self, x0: T) -> T: + raise NotImplementedError + + def tan(self, x0: T) -> T: + raise NotImplementedError + + def tanh(self, x0: T) -> T: + raise NotImplementedError + + def sigmoid(self, x0: T) -> T: + raise NotImplementedError + + def signbit(self, x0: T) -> T: + raise NotImplementedError + + def fmod(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def log(self, x0: T) -> T: + raise NotImplementedError + + def isinf(self, x0: T) -> T: + raise NotImplementedError + + def isnan(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties + def round(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def floor(self, x0: T) -> T: + raise NotImplementedError + + def sign(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def 
trunc(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def ceil(self, x0: T) -> T: + raise NotImplementedError + + def neg(self, x0: T) -> T: + raise NotImplementedError + + def reciprocal(self, x0: T) -> T: + raise NotImplementedError + + def eq(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def ne(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def lt(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def gt(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def le(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def ge(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def add(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def sub(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def mul(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def pow(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def and_(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def or_(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # These are metaprogrammed by MockHandler._init_cls + def lshift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def rshift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These are "special" operators. These only exist if the target + # language actually supports the operator. Keep this in sync with + # pointwise_overrides_data. + + def airy_ai(self, x: T) -> T: + raise NotImplementedError + + def bessel_j0(self, x: T) -> T: + raise NotImplementedError + + def bessel_j1(self, x: T) -> T: + raise NotImplementedError + + def bessel_y0(self, x: T) -> T: + raise NotImplementedError + + def bessel_y1(self, x: T) -> T: + raise NotImplementedError + + def digamma(self, x: T) -> T: + raise NotImplementedError + + def erfcx(self, x: T) -> T: + raise NotImplementedError + + def fma(self, x: T, y: T, z: T) -> T: + raise NotImplementedError + + def igamma(self, x: T, y: T) -> T: + raise NotImplementedError + + def igammac(self, x: T, y: T) -> T: + raise NotImplementedError + + def gammainc(self, x: T, y: T) -> T: + raise NotImplementedError + + def gammaincc(self, x: T, y: T) -> T: + raise NotImplementedError + + def i0(self, x: T) -> T: + raise NotImplementedError + + def i0e(self, x: T) -> T: + raise NotImplementedError + + def i1(self, x: T) -> T: + raise NotImplementedError + + def i1e(self, x: T) -> T: + raise NotImplementedError + + def log_ndtr(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_i0(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_i1(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_k0(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_k1(self, x: T) -> T: + raise NotImplementedError + + def ndtr(self, x: T) -> T: + raise NotImplementedError + + def ndtri(self, x: T) -> T: + raise NotImplementedError + + def polygamma(self, x: T, y: T) -> T: + raise NotImplementedError + + def scaled_modified_bessel_k0(self, x: T) -> T: + raise NotImplementedError + + def scaled_modified_bessel_k1(self, x: T) -> T: + raise NotImplementedError + + def spherical_bessel_j0(self, x: T) -> T: + raise NotImplementedError + + def zeta(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_t(self, x: T, y: T) -> T: + raise 
NotImplementedError + + def chebyshev_polynomial_u(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_v(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_w(self, x: T, y: T) -> T: + raise NotImplementedError + + def legendre_polynomial_p(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: + raise NotImplementedError + + def hermite_polynomial_h(self, x: T, y: T) -> T: + raise NotImplementedError + + def hermite_polynomial_he(self, x: T, y: T) -> T: + raise NotImplementedError + + def laguerre_polynomial_l(self, x: T, y: T) -> T: + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operators are a bit special, because they are conventionally + # natively supported in both Python and C, but the semantics differ so + # care must be taken + + def truncdiv(self, x0: T, x1: T) -> T: + """C-style trunc division between integers only. Computes the true + division of two numbers and rounds the result to zero. + """ + raise NotImplementedError + + def floordiv(self, x0: T, x1: T) -> T: + """Python-style floor division between integers only. Computes the + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. + """ + raise NotImplementedError + + def truediv(self, x0: T, x1: T) -> T: + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + raise NotImplementedError + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. 
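+
+        For example (illustrative), int_truediv of the int64 values 1 and 3
+        should produce the float 0.333..., matching Python's 1 / 3, rather
+        than the 0 that C-style integer division would give.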
+ """ + raise NotImplementedError + + def mod(self, x0: T, x1: T) -> T: + """C-style modulus, take sign from LHS (x0).""" + raise NotImplementedError + + def remainder(self, x0: T, x1: T) -> T: + """Python-style modulus, take sign from RHS (x1).""" + raise NotImplementedError + + def square(self, x0: T) -> T: + raise NotImplementedError + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError + + # halide-only + def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: + raise NotImplementedError + + # triton-only + def inline_asm_elementwise( + self, + *inputs: T, + asm: str, + constraints: Optional[str] = None, + dtype: torch.dtype = torch.float32, + is_pure: bool = True, + pack: int = 1, + ) -> T: + raise NotImplementedError + + def output(self, *args: T) -> None: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError + + def placeholder(self, index: int) -> T: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError + + +_ignore_op_re = re.compile(r"_.*|paren").fullmatch + + +def list_ops(cls: type[Any]): + return OrderedSet([x for x in dir(cls) if not _ignore_op_re(x)]) + + +OP_NAMES = list_ops(OpsHandler) + + +class DefaultHandler(OpsHandler[Any]): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """ + Default implementation for all ops. Override in a subclass to + provide generic op behavior. + + Args: + name: name of the op, see OpHandler.{name} + args: positional args passed to the op + kwargs: keyword args passed to the op + + Returns: + return value of the op + + """ + raise NotImplementedError + + def __getattr__(self, name: str) -> Any: + def fallback(*args: Any, **kwargs: Any) -> Any: + return self._default(name, args, kwargs) + + # would like to remove this function entirely, but it's used in MTIA backend + warnings.warn(f"undefined OpHandler.{name}, please add missing op schema") + return fallback + + @staticmethod + def _call_default(target: str): + def call_default(self, *args, **kwargs): + return self._default(target, args, kwargs) + + call_default.__name__ = target + return call_default + + @classmethod + def _init_cls(cls): + """ + Here we codegen many functions of the form: + + def add(self, a, b): + return self._default('add', (a, b), {}) + + and install them in cls. This is the same as _call_default above, + but is about 1.2x faster since CPython varargs parsing is slow. 
+ """ + code = StringIO() + for target in OP_NAMES: + sig = inspect.signature(getattr(OpsHandler, target)) + if all( + p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is inspect.Parameter.empty + for p in sig.parameters.values() + ): + self_arg, *args = sig.parameters.keys() + assert self_arg == "self" + code.write( + f""" + def {target}(self, {", ".join(args)}): + return self._default({target!r}, ({", ".join(args)}, ), {{}}) + """.strip() + ) + code.write("\n\n") + else: + # slower fallback for ops with default or variadic arguments + setattr(cls, target, cls._call_default(target)) + + ctx: dict[str, Any] = {} + exec(code.getvalue(), ctx) + for target, impl in ctx.items(): + if target in OP_NAMES: + setattr(cls, target, impl) + + +DefaultHandler._init_cls() + + +class NoopHandler(DefaultHandler): + name = "NoopHandler" + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return None + + @staticmethod + def masked(mask, body, other) -> None: + return None + + @staticmethod + def frexp(x) -> tuple[None, None]: + return (None, None) + + @staticmethod + def scan(dtypes, combine_fn, values) -> tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def sort(dtypes, values, stable, descending) -> tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: + return sympy.S.Zero + + +class BasicMathOpsMixin: + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def floordiv(a, b): + return f"{a} // {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def mod(a, b): + # careful, depending on target semantics varies + return f"{a} % {b}" + + @staticmethod + def pow(a, b): + return f"{a} ** {b}" + + @staticmethod + def lshift(a, b): + return f"{a} << {b}" + + @staticmethod + def rshift(a, b): + return f"{a} >> {b}" + + @staticmethod + def and_(a, b): + return f"{a} & {b}" + + @staticmethod + def or_(a, b): + return f"{a} | {b}" + + @staticmethod + def xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def eq(a, b): + return f"{a} == {b}" + + @staticmethod + def ne(a, b): + return f"{a} != {b}" + + @staticmethod + def lt(a, b): + return f"{a} < {b}" + + @staticmethod + def gt(a, b): + return f"{a} > {b}" + + @staticmethod + def le(a, b): + return f"{a} <= {b}" + + @staticmethod + def ge(a, b): + return f"{a} >= {b}" + + @staticmethod + def neg(a): + return f"-{a}" + + +class MockHandler(BasicMathOpsMixin, DefaultHandler): + name = "MockHandler" + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + fargs = [*map(_arg_str, args)] + for k, v in kwargs.items(): + fargs.append(f"{k}={_arg_str(v)}") + return f"ops.{name}({', '.join(fargs)})" + + @staticmethod + def masked(mask, body, other) -> str: + return f"ops.masked({mask}, {body()}, {other})" + + @staticmethod + def frexp(x): + return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") + + @staticmethod + def scan(dtypes, combine_fn, values): + return tuple( + f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def sort(dtypes, values, stable, descending): + return tuple( + f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def indirect_indexing(index_var, size, 
check=True, wrap_neg=True) -> sympy.Symbol: + return sympy_index_symbol(str(index_var)) + + +class KernelFormatterHandler(DefaultHandler): + def __init__(self, parent_handler: OpsHandler[Any]): + self.parent_handler = parent_handler + self._output = IndentedBuffer(1) + self.var_counter = itertools.count() + + @staticmethod + def ir_to_string(ir_fn, index, rindex=None) -> str: + from .ir import FlexibleLayout + from .virtualized import V + + args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter._output.indent(-1): + formatter._output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter._output.writeline(f"{lhs} = {name}") + + with ( + V.set_ops_handler(formatter), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + result = ir_fn(*args) + return formatter.getvalue(result) + + def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol: + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def _write(self, line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self._output.writeline(f"{varname} = {line}") + return varname + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return pytree.tree_map( + self._write, getattr(self.parent_handler, name)(*args, **kwargs) + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[str, tuple[str, ...]], + ) -> Union[str, tuple[str, ...]]: + line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) + num_values = reduction_num_outputs(reduction_type) + varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] + self._output.writeline(f"{','.join(varnames)} = {line}") + return tuple(varnames) if num_values > 1 else varnames[0] + + def getvalue(self, result): + self._output.writeline(f"return {result}") + return self._output.getvalue() + + +class WrapperHandler(DefaultHandler): + def __init__(self, inner: OpsHandler[Any]): + self._inner = inner + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return getattr(self._inner, name)(*args, **kwargs) + + +class AddParenHandler(WrapperHandler): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + val = getattr(self._inner, name)(*args, **kwargs) + if not val or isinstance(val, (sympy.Expr, tuple, list)): + return val + return f"({val})" + + +class OpCountResult(NamedTuple): + num_ops: int + used_ops: OrderedSet[str] + read_buffers: list[str] + nontrivial_read_count: int + + +class OpCounterCSE(DefaultHandler): + """Shim to count how many ops are used""" + + def __init__(self, inner: OpsHandler[Any]): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names: dict[str, str] = {} + self._used_ops: OrderedSet[str] = OrderedSet() + self._read_names: list[str] = [] + self._nontrivial_read_count = 0 + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + self._used_ops.add(name) + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) + + def _update_count(self, val): + varname = self.var_names.get(val) + if not varname: + varname = 
f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + + def indirect_indexing(self, *args, **kwargs): + self._used_ops.add("indirect_indexing") + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def load(self, name: str, index: sympy.Expr) -> str: + val = self.parent_handler.load(name, index) + if val not in self.var_names: + self._used_ops.add("load") + self._read_names.append(name) + if not isinstance(index, (sympy.Integer, int)): + self._nontrivial_read_count += 1 + return self._update_count(val) + + def load_seed(self, name: str, offset: T): + val = self.parent_handler.load_seed(name, offset) + if val not in self.var_names: + self._used_ops.add("load_seed") + self._read_names.append(name) + return self._update_count(val) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + val = self.parent_handler.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + if val not in self.var_names: + self._used_ops.add("bucketize") + self._read_names.append(boundaries[0]) + if sorter is not None: + self._read_names.append(sorter[0]) + return self._update_count(val) + + def getvalue(self): + return OpCountResult( + self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count + ) + + +class ExtractConstantsHandler(NoopHandler): + def __init__(self, device: Optional[torch.device]): + self.device = device + + def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant: + from torch._inductor import ir + + return ir.Constant( + value=value, dtype=dtype, device=self.device or torch.get_default_device() + ) + + +class SimpleCSEHandler(WrapperHandler): + """Wraps the underlying handler with a CSE pass + + NOTE: Compared to codegen level CSE this is simplified as it + doesn't support stores which require load cache invalidation. 
+ """ + + def __init__(self, inner: Any): + super().__init__(inner) + self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {} + self.mock = MockHandler() + + def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: + return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] + + def store(self, *args, **kwargs) -> None: + raise NotImplementedError("store not implemented") + + def store_reduction(self, *args, **kwargs) -> None: + raise NotImplementedError("store not implemented") + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: + return val + + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val diff --git a/phivenv/Lib/site-packages/torch/_inductor/optimize_indexing.py b/phivenv/Lib/site-packages/torch/_inductor/optimize_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..6c159b9b73e290ddc61d47bed2c79b70ddc32fea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/optimize_indexing.py @@ -0,0 +1,126 @@ +import math +from typing import Any + +import sympy + +import torch +from torch.utils._sympy.value_ranges import ValueRanges + +from .loop_body import LoopBody +from .utils import dominated_nodes + + +def val_expressable_in_32_bits(val: Any) -> bool: + if getattr(val, "is_Boolean", False): + return True + + if isinstance(val, sympy.Expr): + assert val.is_number + if val.is_Integer or val.is_Boolean: + val = int(val) + else: + val = float(val) + + # bound within mantissa + if isinstance(val, float): + return val <= (2**24) and val >= -(2**24) + + if isinstance(val, int): + iinfo = torch.iinfo(torch.int32) + return val <= iinfo.max and val >= iinfo.min + + raise TypeError(f"Unexpected value {val}") + + +def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool: + return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( + range.upper + ) + + +def try_to_reduce_precision( + node: Any, + bounds: dict[Any, Any], + indirect_vars: list[Any], + indices: dict[Any, sympy.Expr], + replacement_vals: dict[Any, ValueRanges[sympy.Expr]], +) -> None: + # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, + # then it's precision is set for that chain of uses, and we don't need to consider those + # dominated values + def skip_filter(node: Any) -> bool: + return node.target == "to_dtype" and node.args[2] in ( + torch.int32, + torch.float32, + torch.float64, + ) + + # TODO - there are dominated uses whose dtype does not depend on whether + # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to + # int32 without changing the output precision of the node. this case hasn't shown up + for dominated in dominated_nodes([node], skip_filter): + if dominated.target in ["store", "output"]: + continue + + if isinstance(dominated.target, str) and "set_indirect" in dominated.target: + idx = int(dominated.target[len("set_indirect") :]) + indirect_var = indirect_vars[idx] + + # We check that we can compute all the indices it's involved in with int32 + for index, expr in indices.items(): + if indirect_var in expr.free_symbols: + index_val = replacement_vals[index] + + if math.isinf(index_val.lower) or math.isinf(index_val.upper): + return + + # all indices are integers, so make sure that we + # use the bounds of integers instead of floats. 
+ # TODO - not sure if we should be doing int/float casts while tracing, + # might interfere with sympy. + + index_val_int = ValueRanges[sympy.Expr]( + int(index_val.lower), int(index_val.upper) + ) + if not range_expressable_in_32_bits(index_val_int): + return + + if not range_expressable_in_32_bits(bounds[dominated]): + return + + args = list(node.args) + args[2] = torch.int32 + node.args = tuple(args) + + +def indexing_dtype_strength_reduction(loop_body: LoopBody) -> None: + """ + Performs Value Range Analysis on LoopBody's fx graph to reduce precision of + intermediaries from int64 to int32 + """ + bv = loop_body.bounds() + + int64_dtype_nodes = [ + node + for node in loop_body.get_nodes() + if ( + node.target == "to_dtype" + and node.args[2] == torch.int64 + and node not in bv.unbounded_vars + ) + ] + if not int64_dtype_nodes: + return + + bounds = bv.get_bounds() + + # TODO - if dominated node of one to_dtype is not expressible in int32, + # we should short circuit another to_dtype node if that node also dominates + for node in int64_dtype_nodes: + try_to_reduce_precision( + node, + bounds, + loop_body.indirect_vars, + loop_body.indexing_exprs, + bv.replacement_vals, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/output_code.py b/phivenv/Lib/site-packages/torch/_inductor/output_code.py new file mode 100644 index 0000000000000000000000000000000000000000..f35cb2c2205cbebd7fb1038f620738d7c0a727ca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/output_code.py @@ -0,0 +1,758 @@ +""" +This provides an abstract class which parametrizes over an "output code" concept +for Inductor. Intuitively, this represents the compiled callable which Inductor +produces which you can call to get optimized code. However, this callable +has some other capabilities: + +- It is serializable, so you can save/load this product from disk without + having to do compilation again. + +- (When using remote cache) it is addressable, so you can save just a key + which you can use to load this product from remote cache later. + +This class is abstract because we have several different implementations of +serialized format: + +- Python wrapper (the default) + +- AOTInductor (this produces ABI stable binaries which work across PyTorch + versions) + +""" + +from __future__ import annotations + +import dataclasses +import logging +import os +from functools import partial +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters, get_runtime_metrics_context +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + CudagraphMetadata, + get_partition_cudagraph_metadata, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param +from torch._inductor.utils import ( + align_inputs_from_check_idxs, + BoxedBool, + GraphPartitionMap, + InputType, + output_node, + set_tracing_context_output_strides, +) +from torch.utils._ordered_set import OrderedSet + +from . 
import config +from .runtime.autotune_cache import AutotuneCacheBundler + + +if TYPE_CHECKING: + from collections import Counter + from collections.abc import Sequence + + from torch._inductor import metrics + from torch._inductor.graph import GraphLowering + from torch._library.fake_class_registry import FakeScriptObject + from torch.export.pt2_archive._package_weights import Weights + + from .compile_fx import _CompileFxKwargs + from .triton_bundler import TritonBundle + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OutputCode: + # TODO: Remove underscores here + + # None if the output is not remote cacheable + _fx_graph_cache_key: Optional[str] = dataclasses.field(default=None, init=False) + _fx_graph_cache_debug_lines: Optional[list[str]] = dataclasses.field( + default=None, init=False + ) + + # How long it took to compile this OutputCode, end to end + _time_taken_ns: Optional[int] = dataclasses.field(default=None, init=False) + + def __call__(self, inputs: Sequence[Any]) -> Any: + raise NotImplementedError(type(self)) + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + raise NotImplementedError(type(self)) + + # TODO: Get rid of this + def set_triton_bundle(self, triton_bundle: Any) -> None: + raise NotImplementedError(type(self)) + + +_StrideExprStr: TypeAlias = str + + +# copy_ fails when trying to write to tensors with memory overlap, +# for expanded dimensions (a dimension which used to have size 1 -> ?) +# we can select one element from that dimension and write to it +# to achieve writing to all values of that dimension of the input tensor +def get_expanded_dims(t: torch.Tensor) -> list[int]: + if not isinstance(t, torch.Tensor): + return None + return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] + + +def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor: + for expanded_dim in expanded_dims: + t = torch.ops.aten.slice(t, expanded_dim, 0, 1) + return t + + +def complex_memory_overlap(t: torch.Tensor) -> bool: + if config.always_complex_memory_overlap_TESTING_ONLY: + return True + + # if torch._debug_has_internal_overlap thinks this tensor potentially has + # memory overlap internally, let's dig deeper to find out whether it's true. + # + # Call squeeze() so that dimension with size 1 does not cause false positive. 
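+    #
+    # The check below sorts dimensions by stride and flags overlap when some
+    # stride is smaller than the product of the previous (stride * size), i.e.
+    # when the dimensions cannot be laid out without aliasing.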
+    t = index_expanded_dims(t, get_expanded_dims(t)).squeeze()
+    if torch._debug_has_internal_overlap(t) != 0:
+        strides = t.stride()
+        sizes = t.shape
+        indices = list(range(len(strides)))
+        indices = [x for _, x in sorted(zip(strides, indices))]
+        for i in range(len(strides)):
+            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
+            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
+            if strides[indices[i]] < prev_stride * prev_size:
+                return True
+    return False
+
+
+def maybe_handle_backward_generation(
+    compiled_graph: CompiledFxGraph,
+    boxed_forward_device_index: Optional[BoxedDeviceIndex],
+) -> None:
+    assert compiled_graph.current_callable is not None
+    is_backward = compiled_graph.fx_kwargs["is_backward"]
+
+    # See [Backward Generation Handling]
+    # If we cudagraph'd the forward and set the device, we need to let the
+    # cudagraph manager know we are running the backward, even if we will not
+    # run it in cudagraphs.
+    if is_backward and config.triton.cudagraph_trees:
+        assert boxed_forward_device_index is not None
+        assert boxed_forward_device_index.value is not None
+        compiled_graph_callable = compiled_graph.current_callable
+
+        manager = torch._inductor.cudagraph_trees.get_manager(
+            boxed_forward_device_index.value, create_if_none_exists=False
+        )
+        # should already exist from forward
+        assert manager is not None
+
+        def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
+            manager.set_to_running_backward()  # type: ignore[union-attr]
+            return compiled_graph_callable(new_inputs)
+
+        compiled_graph.current_callable = compiled_artifact
+
+
+def prepare_cudagraph_post_compile(
+    compiled_graph: CompiledFxGraph,
+    example_inputs: Sequence[InputType],
+    boxed_forward_device_index: Optional[BoxedDeviceIndex],
+) -> None:
+    if not config.triton.cudagraph_trees:
+        # Force specialize all inputs so that CUDA graphs will work
+        for t in example_inputs:
+            if isinstance(t, torch.SymInt):
+                int(t)  # guard
+
+    is_inference = compiled_graph.fx_kwargs["is_inference"]
+    is_backward = compiled_graph.fx_kwargs["is_backward"]
+    if boxed_forward_device_index is not None and not is_inference and not is_backward:
+        boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
+
+
+def cudagraph_post_compile(
+    example_inputs: Sequence[InputType],
+    compiled_graph: CompiledFxGraph,
+    cudagraphs: BoxedBool,
+    constants: dict[str, torch.Tensor],
+    boxed_forward_device_index: Optional[BoxedDeviceIndex],
+) -> None:
+    """
+    Checks for any reasons not to run cudagraphs and then
+    runs it on compiled_graph.
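+    If cudagraph fail reasons were recorded at lowering time, cudagraphs are
+    disabled instead and only backward-generation handling is applied.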
+ Mutates the `compiled_graph.current_callable` and `cudagraphs` + """ + assert compiled_graph.current_callable is not None + assert compiled_graph.cudagraph_info is not None + cached_info = compiled_graph.cudagraph_info + cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + + if not cudagraph_fail_reasons: + fx_kwargs = compiled_graph.fx_kwargs + static_input_idxs = fx_kwargs["static_input_idxs"] + + placeholders = cached_info.placeholders + stack_traces = cached_info.stack_traces + + prepare_cudagraph_post_compile( + compiled_graph, example_inputs, boxed_forward_device_index + ) + + from .compile_fx import cudagraphify + + current_callable = compiled_graph.current_callable + assert current_callable is not None + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs or (), + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(constants.values()), + placeholders=placeholders, + mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), + ) + + else: + BoxedBool.disable(cudagraphs) + maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + log_cudagraph_skip_and_bump_counter( + compiled_graph.disabled_cudagraphs_reason + ) + else: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {cudagraph_fail_reasons}" + ) + + +def cudagraph_partition_post_compile( + example_inputs: Sequence[InputType], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, + constants: dict[str, torch.Tensor], + boxed_forward_device_index: Optional[BoxedDeviceIndex], +) -> None: + """ + Cudagraphify each partition functions, which first prepares the necessary + metadata and then applies the cudagraphify function to each partition. + + Assuming all partition functions are cudagraphified and share the same order + as `compiled_graph.partition_maps`. See [Note: Graph Partition Map for CUDAGraph]. 
+ """ + assert compiled_graph.cudagraph_info is not None + cudagraph_fail_reasons = compiled_graph.cudagraph_info.cudagraph_fail_reasons + + if ( + cudagraph_fail_reasons + or compiled_graph.partition_maps is None + or len(compiled_graph.partition_maps) == 0 + ): + # cudagraphify is not called if there are no partitions + BoxedBool.disable(cudagraphs) + maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) + return + + from .compile_fx import cudagraphify + + assert compiled_graph.current_callable is not None + assert compiled_graph.recursively_apply_fns is not None + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ()) + mutated_input_idxs = compiled_graph.mutated_input_idxs + device_index = next(iter(compiled_graph.device_idxs)) + + graph_metadata = CudagraphMetadata( + compiled_graph.cudagraph_info.placeholders, + static_input_idxs, + mutated_input_idxs, + compiled_graph.cudagraph_info.stack_traces, + constants, + ) + + prepare_cudagraph_post_compile( + compiled_graph, example_inputs, boxed_forward_device_index + ) + + # cudagraphify each partition function, assuming every graph partition function + # is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into + # `call` function and not included in partition functions. + cudagraphify_fns = [] + for partition_map in compiled_graph.partition_maps: + partition_metadata = get_partition_cudagraph_metadata( + partition_map, + graph_metadata, + ) + + cudagraphify_fn = partial( + cudagraphify, + static_input_idxs=tuple(partition_metadata.static_input_idxs), + device_index=device_index, + stack_traces=partition_metadata.stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(partition_metadata.constants.values()), + placeholders=partition_metadata.placeholders, + mutated_input_idxs=tuple(partition_metadata.mutated_input_idxs), + ) + cudagraphify_fns.append(cudagraphify_fn) + + compiled_graph.recursively_apply_fns(cudagraphify_fns) + + +def maybe_realign_inputs( + ran_cudagraphs: BoxedBool, + compiled_graph: CompiledFxGraph, + inputs_to_check: Sequence[int], + mutated_inputs_idxs: OrderedSet[int], +) -> None: + """ + Realigns input strides from inputs_to_check if + we didn't end up running cudagraphs. Mutates + `compiled_graph.current_callable` if cudagraphs + was run. Otherwise, does nothing. + """ + if not ran_cudagraphs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check, mutated_inputs_idxs + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable + + +class CompiledFxGraphConstants: + """Wrapper class that unwraps constants from a compiled fx graph. This + version of the class only supports directly grabbing the saved constants off of + a CompiledFxGraph. + + With freezing, FxGraphCache doesn't store the constants of the input + GraphModule it gets from AOTAutograd. Instead, it saves just the **names** + of those constants, and grabs the constant values directly from the graph module + passed in at runtime. + + Thing is, we don't always *have* the graph module available at runtime, hence + the existence of this class and its CompiledFxGraphConstantsWithGm counterpart. + + To support freezing, FXGraphCache gets passed a CompiledFxGraphConstantsWithGm during + post compile. 
Otherwise, CompiledFxGraphConstants supports the basic case of loading + the value of constants directly off of the original saved object. + """ + + def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]: + assert g.constants is not None + return g.constants + + +class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants): + """ + This version of CompiledFxGraphConstants, instead of grabbing constants + directly saved on CompiledFxGraphs, will just grab their names. Then, it takes + a second GraphModule to grab the corresponding constant values out of. + + This is necessary for supporting freezing in FxGraphCache. + """ + + def __init__(self, gm: torch.fx.GraphModule) -> None: + self.gm = gm + + def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]: + frozen_params = { + name: getattr(self.gm, orig_name) + for name, orig_name in g.frozen_param_names.items() + } + constants = g.constants or {} + return {**constants, **frozen_params} + + +@dataclasses.dataclass +class CompiledFxGraph(OutputCode): + """ + Class holding a compiled FX graph. This is the object serialized on disk + to support FxGraph caching. + """ + + current_callable: Optional[Callable[..., Any]] + recursively_apply_fns: Optional[Callable[..., Any]] + compiled_fn_runner: Optional[Any] + cache_key: str + source_code: str = dataclasses.field(repr=False) # Do not display source_code + runnable_graph_str: str = dataclasses.field(repr=False) # Do not display graph + inductor_post_grad_graph_str: str = dataclasses.field( + repr=False + ) # Do not display graph + cache_linemap: Optional[list[tuple[int, str]]] + device_types: OrderedSet[str] + device_idxs: OrderedSet[int] + mutated_inputs: OrderedSet[str] + mutated_input_idxs: OrderedSet[int] + constants: Optional[dict[str, torch.Tensor]] + frozen_param_names: dict[str, str] + torchbind_constants: dict[str, torch._C.ScriptObject | FakeScriptObject] + output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]] + disabled_cudagraphs_reason: Optional[str] + metrics_deltas: metrics.CachedMetricsDeltas + counter_deltas: Counter[str] + # This is a string representation of an expression we serialize + # with the object so the guards can be evaluated in a different + # context in order to verify the validity of serving a cached + # fx graph. 
The expression must be generated by: + # ShapeEnv.produce_guards_expression() + guards_expr: Optional[str] + + cudagraph_info: Optional[CudagraphCachedInfo] + partition_maps: Optional[list[GraphPartitionMap]] + fx_kwargs: _CompileFxKwargs + inputs_to_check: Sequence[int] + + _boxed_call: Optional[bool] = None + _triton_bundle: Optional[TritonBundle] = None + + def __init__( + self, + current_callable: Optional[Callable[..., Any]], + graph: GraphLowering, + gm: torch.fx.GraphModule, + output_strides: list[Optional[tuple[_StrideExprStr, ...]]], + disabled_cudagraphs_reason: Optional[str], + metrics_deltas: metrics.CachedMetricsDeltas, + counter_deltas: Counter[str], + cudagraphs: BoxedBool, + example_inputs: Sequence[InputType], + static_input_idxs: Sequence[int], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + runnable_graph_str: str, + inductor_post_grad_graph_str: str, + compiled_fn_runner: Optional[Any] = None, + ) -> None: + self.current_callable = current_callable + self.compiled_fn_runner = compiled_fn_runner + self.recursively_apply_fns = ( + compiled_fn_runner.recursively_apply_fns + if compiled_fn_runner is not None + else None + ) + self.cache_key = graph.cache_key + if graph.cache_path: + with open(graph.cache_path) as f: + self.source_code = f.read() + self.runnable_graph_str = runnable_graph_str + self.inductor_post_grad_graph_str = inductor_post_grad_graph_str + self.cache_linemap = graph.cache_linemap + # TODO - ordered set + self.device_types = OrderedSet(graph.device_types) + self.device_idxs = OrderedSet(graph.device_idxs) + self.mutated_inputs = OrderedSet(graph.mutated_inputs) + self.mutated_input_idxs = OrderedSet(graph.mutated_input_idxs) + + # We store the constant attributes in the cache entry and re-attach them + # to the module created in PyCodeCache.load_by_key_path. In the case that + # the graph has frozen parameters, we save the mapping from the attribute + # names in the GraphLowering to the original name of the attribute in the + # GraphModule. When we create the module from the cache entry, we then + # look up the constants from the current GraphModule. This scheme allows + # us to support caching with freezing. 
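+        #
+        # For example (names are hypothetical), a frozen attribute that the
+        # GraphLowering calls "_frozen_param0" is cached only as a mapping to
+        # its original GraphModule name (say "layer1_weight"); the tensor
+        # value itself is re-fetched from the GraphModule on a cache hit.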
+ if not has_frozen_params(gm): + self.constants = graph.constants + self.frozen_param_names = {} + else: + self.constants = {} + self.frozen_param_names = {} + for k, v in graph.constants.items(): + if is_frozen_param(v): + self.frozen_param_names[k] = graph.allocated_constant_name[k] + else: + self.constants[k] = v + + self.torchbind_constants = graph.torchbind_constants + self.output_strides = output_strides + self.disabled_cudagraphs_reason = disabled_cudagraphs_reason + self.metrics_deltas = metrics_deltas + self.counter_deltas = counter_deltas + self.guards_expr = None + self.cudagraph_info = None + self.partition_maps = graph.partition_maps + self.fx_kwargs = {} + self.inputs_to_check = () + + cudagraph_info = None + if cudagraphs: + # check cudagraph disabling reasons from inductor lowering + if self.disabled_cudagraphs_reason: + if "cuda" in self.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + complex_memory_overlap_inputs = any( + complex_memory_overlap(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ) + + if not config.triton.cudagraph_support_input_mutation: + # Skip supports for cudagraph-managed tensors + from torch._inductor.cudagraph_utils import ( + check_for_mutation_ignore_cuda_graph_managed_tensor, + ) + + has_mutation_str = ( + check_for_mutation_ignore_cuda_graph_managed_tensor( + gm, + self.mutated_inputs, + self.mutated_input_idxs, + static_input_idxs, + ) + ) + has_mutation = has_mutation_str is not None + + if has_mutation: + self.disabled_cudagraphs_reason = has_mutation_str + else: + # Check mutation later to support cudagraph-managed tensors + has_mutation = None + + cudagraph_tests = [ + (not has_mutation, "mutated inputs"), + (not complex_memory_overlap_inputs, "complex memory overlap"), + ( + all( + isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator)) + for t in example_inputs + ), + "non-Tensor inputs", + ), + ] + output = output_node(gm) + # output args are tuple of first argument + assert len(output.args) == 1 + stack_traces = [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] # type: ignore[union-attr] + ] + cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b] + placeholders = tuple(get_placeholder_info(gm.graph)) + cudagraph_info = CudagraphCachedInfo( + placeholders, stack_traces, cudagraph_fail_reasons + ) + + self.cudagraph_info = cudagraph_info + self.inputs_to_check = inputs_to_check + self.fx_kwargs = fx_kwargs + + # aot autograd needs to know to pass in inputs as a list + self._boxed_call = True + + def __del__(self) -> None: + if self.compiled_fn_runner is not None: + # For torch._inductor.config.graph_partition = True, + # self.compiled_fn_runner.partitions hold cudagraphified functions + # which prevents deallocation. When CompiledFxGraph is deleted, + # self.compiled_fn_runner will not be called in the future so we + # should also delete these partitions. 
+ del self.compiled_fn_runner.partitions + + def __call__(self, inputs: Sequence[Any]) -> Any: + assert self.current_callable is not None + try: + return self.current_callable(inputs) + finally: + get_runtime_metrics_context().finish() + AutotuneCacheBundler.end_compile() + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + """ + Run a set of post processing steps after loading from the cache. These involve: + - Setting the tracing context output strides + - Running cudagraphs if enabled + - Realigning inputs + + This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. + The results of this function are *not* saved in the cache itself. + """ + set_tracing_context_output_strides(example_inputs, self) + assert graph_kwargs["cudagraphs"] is not None + assert graph_kwargs["is_backward"] is not None + is_backward = graph_kwargs["is_backward"] + cudagraphs: BoxedBool = graph_kwargs["cudagraphs"] + if cudagraphs: + # It's possible that cudagraphs is enabled, but was disabled + # during a previous compilation we're loading from the cache. + # If so, we need to disable it on this new process too. + if self.disabled_cudagraphs_reason: + if "cuda" in self.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + if is_backward: + assert "boxed_forward_device_index" in graph_kwargs + boxed_forward_device_index = graph_kwargs[ + "boxed_forward_device_index" + ] + else: + # On the forward we don't know whether or not + # boxed_foward_device_index is set yet + boxed_forward_device_index = graph_kwargs.get( + "boxed_forward_device_index", None + ) + + if config.graph_partition: + # with graph_partition=True, we skip some cudagraph checks if it's supported + # with partition. So we have to use cudagraph_partition_post_compile. + cudagraph_partition_post_compile( + example_inputs, + self, + cudagraphs, + constants.unwrap(self), + boxed_forward_device_index, + ) + else: + cudagraph_post_compile( + example_inputs, + self, + cudagraphs, + constants.unwrap(self), + boxed_forward_device_index, + ) + inputs_to_check = self.inputs_to_check + # cudagraphs could have been disabled from the earlier conditions + # so we still need to realign inputs if that happens + maybe_realign_inputs( + cudagraphs, + self, + inputs_to_check, + self.mutated_input_idxs, + ) + + def set_triton_bundle(self, triton_bundle: Any) -> None: + self._triton_bundle = triton_bundle + + def prepare_for_serialization(self) -> None: + # We can't really serialize callables that may be C++/Triton/etc., + # so we serialize their PyCodeCache disk cache location instead. + # TODO: This could be better if we're ever able to serialize compiled + # models to disk. + self.current_callable = None + self.recursively_apply_fns = None + self.compiled_fn_runner = None + + def write_to_disk(self) -> str: + from torch._dynamo.utils import counters + from torch._inductor.codecache import get_path, write_atomic + + # See _save_graph(); we don't store the callable in the cache entry so + # recreate it here from the PyCodeCache disk cache. 
+ artifact_path = get_path(self.cache_key, "py")[2] + code = self.source_code + if not os.path.exists(artifact_path): + counters["inductor"]["fxgraph_lookup_write_file"] += 1 + write_atomic(artifact_path, code, make_dirs=True) + return artifact_path + + def after_deserialization(self, constants: CompiledFxGraphConstants) -> str: + from torch._dynamo.utils import dynamo_timed + from torch._inductor.codecache import PyCodeCache + + artifact_path = self.write_to_disk() + + try: + with dynamo_timed( + "PyCodeCache.load_by_key_path", + log_pt2_compile_event=True, + ): + code_cache = PyCodeCache.load_by_key_path( + self.cache_key, + artifact_path, + self.cache_linemap, + constants.unwrap(self), + ) + self.current_callable = code_cache.call + self.recursively_apply_fns = getattr( + code_cache, "recursively_apply_fns", None + ) + self.compiled_fn_runner = getattr(code_cache, "runner", None) + except OSError: + log.error("Failed to load artifact: %s", artifact_path) + raise + + return artifact_path + + +@dataclasses.dataclass +class CompiledAOTI(OutputCode): + """ + Class holding an AOTInductor compiled so. + """ + + filename: Union[str, list[Union[str, Weights]]] + + def __call__(self, inputs: Sequence[Any]) -> Any: + raise NotImplementedError("NYI") + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + pass + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass + + +@dataclasses.dataclass +class MockFXGraphCacheOutput(OutputCode): + gm: Any = None + + def __post_init__(self) -> None: + self._boxed_call = True + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + pass + + def __call__(self, inputs: Sequence[Any]) -> Any: + return self.gm(inputs) + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass diff --git a/phivenv/Lib/site-packages/torch/_inductor/pattern_matcher.py b/phivenv/Lib/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..9447b98f3488d3849e8fdedb95f395a0b15c6a2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,2268 @@ +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. + +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. 
To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. + +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import importlib +import inspect +import itertools +import logging +import operator +import os +import re +import textwrap +import typing +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Collection, Generator, Iterable, Mapping, Sequence +from pathlib import Path +from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union +from typing_extensions import Self, TypeIs + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from torch._prims_common import is_integer_dtype +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import statically_known_true +from torch.fx.graph_module import _get_attr +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.utils._ordered_set import OrderedSet + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensor, FakeTensorMode +from ..fx import Transformer +from . import config +from .decomposition import select_decomp_table +from .lowering import fallback_node_due_to_unsupported_type + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +Constant = Any +NodeOrConstant = Union[Constant, torch.fx.Node] + + +class SearchFn(Protocol): + __name__: str + + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class ReplaceFn(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + +class TraceFn(Protocol): + def __call__( + self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any + ) -> torch.fx.GraphModule: ... + + +T = TypeVar("T") + +# What's a better name for this? +FnsType = Union[torch.fx.node.Target, str] + + +class Multiple: + def __init__(self) -> None: + # Ensure we're really a singleton. + assert "MULTIPLE" not in globals() or self is MULTIPLE + + +# Sentinel indicating multiple quantities can be matched +MULTIPLE = Multiple() + + +def _transfer_meta( + new_meta: dict[str, Any], old_node: torch.fx.Node, pass_name: str = "" +) -> None: + from torch.fx.traceback import NodeSource, NodeSourceAction + + # transfer metadata after pattern matching occurs. + # skip "val" and "tensor_meta" because this info is too specific; it's unlikely + # to remain accurate after pattern matching has occurred. + if config.trace.enabled: + # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. 
+ new_from_node = new_meta.get("from_node", []).copy() + new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) + new_meta.update( + (k, v) + for k, v in old_node.meta.items() + if k in torch.fx.proxy._COPY_META_FIELDS + ) + new_meta["from_node"] = new_from_node + else: + new_meta.update( + (k, v) + for k, v in old_node.meta.items() + if k in torch.fx.proxy._COPY_META_FIELDS + ) + + +class Match: + """ + Represents a successfully matched pattern. + + The `Match` object is returned to represent a successfully matched + pattern. Included in the Match are the pattern that was matched, the graph + nodes matched, and any args that were used during the matching. + + The args and kwargs are specific to the type of pattern that was matched and + provide hints about what was matched. + """ + + pattern: PatternExpr + args: list[Any] + kwargs: dict[str, Any] + nodes: list[torch.fx.Node] + targets: dict[_TargetExpr, torch.fx.node.Target] + ctx: MatchContext + replacement_graph: Optional[torch.fx.GraphModule] + + def __init__( + self, + ctx: MatchContext, + pattern: PatternExpr, + args: Optional[Sequence[Any]] = None, + kwargs: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__() + self.pattern = pattern + # The input nodes that must be passed in to the result + self.args = list(args or []) + self.kwargs = kwargs or {} + # The nodes matched in this expression + self.nodes = [] + # Mapping CallFunction to the node.target + self.targets = {} + self.ctx = ctx + self.replacement_graph = None + + @property + def graph(self) -> torch.fx.Graph: + return self.ctx.graph + + def extend(self, other: Match) -> None: + if self.kwargs: + for key in OrderedSet(self.kwargs.keys()) & OrderedSet(other.kwargs.keys()): + if self.kwargs[key] != other.kwargs[key]: + raise FailedMatch("kwarg mismatch: {}", key) + self.args.extend(other.args) + self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self) -> str: + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self) -> None: + graph = self.graph + for n in reversed(self.nodes): + if not n._erased and not n.users: + graph.erase_node(n) + + def output_nodes(self) -> list[Optional[torch.fx.Node]]: + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_functional_passes: bool = True, + ) -> None: + """Replace with a graph generated by tracing the replacement_fn. + + Args: + run_functional_passes (bool). If we should run passes that + assume functional IR (like DCE, remove_noop_ops), on the + replacement graph. 
+ + """ + from torch._inductor.virtualized import NullHandler, V + + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) + + def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool: + if len(nodes) != 1: + return False + node = nodes[0] + if "eager_input_vals" not in node.meta: + return False + return node.target in OrderedSet( + [ + torch.ops.higher_order.triton_kernel_wrapper_functional, + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, + ] + ) + + with context: + if trace_fn is None: + trace_fn = functools.partial( + fwd_only, run_functional_passes=run_functional_passes + ) + + if should_propagate_eager_input_vals(self.nodes): + # Our strategy is: + # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata) + # 2) trace out the graph with vals (which have the accurate Inductor metadata) + # 3) Propagate the eager_input_vals from the first graph to the second. + # 4) Use the second graph as the replacement graph. + + # Construct a map of node -> FakeTensor val in eager_input_vals + node_to_val = {} + + fake_args, fake_kwargs = self.nodes[0].meta["eager_input_vals"] + fake_kwargs = {**fake_kwargs} + match_args, match_kwargs = tuple(self.args), self.kwargs + + def record(node: torch.fx.Node, val: Any) -> None: + if isinstance(node, torch.fx.Node): + node_to_val[node] = val + + torch.utils._pytree.tree_map( + record, (match_args, match_kwargs), (fake_args, fake_kwargs) + ) + # map args to their FakeTensor val in eager_input_vals + example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg]) + + # first graph + graph_with_eager_vals = trace_fn(replacement_fn, example_vals) + + # second graph + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(graph_with_eager_vals, example_vals) + + # propagate metadata from first graph to second + # NB: This assertion might not be true in general, but it is true for + # the two use cases we have + # (triton_kernel_wrapper_functional, auto_functionalized) + assert len(graph_with_eager_vals.graph.nodes) == len( + replacement.graph.nodes + ) + for old_node, new_node in zip( + graph_with_eager_vals.graph.nodes, replacement.graph.nodes + ): + if "eager_input_vals" in old_node.meta: + new_node.meta["eager_input_vals"] = old_node.meta[ + "eager_input_vals" + ] + + else: + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(replacement_fn, example_vals) + if len(self.nodes) == 1: + for n in replacement.graph.nodes: + _transfer_meta( + new_meta=n.meta, + old_node=self.nodes[0], + pass_name="replace_by_example", + ) + + ReplacementPatternEntry.replace_with_graph( + self, + self.ctx.graph, + replacement, + args, + ) + + +class FailedMatch(RuntimeError): + """ + Represents a unsuccessful match. + + The `FailedMatch` object is returned to represent a failure to match a + pattern. + """ + + format_string: str + + def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None: + self.format_string = format_string + # We want to construct error messages lazily instead of eagerly, as + # constructing them eagerly can significantly worsen compile times. + if len(format_string) > 200: + raise RuntimeError( + f"Format string too long - use lazy construction of strings instead. 
Format string is\n {format_string}" + ) + self.args = args + self.kwargs = kwargs + + def __str__(self) -> str: + return self.format_string.format(*self.args, **self.kwargs) + + def __bool__(self) -> bool: + return False + + +MatchResult = Union[Match, FailedMatch] + + +def is_match(m: MatchResult) -> TypeIs[Match]: + """ + TypeIs cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeIs. + """ + return bool(m) + + +class MatchContext: + """ + Internal state needed while running PatternExpr._match(). + """ + + outputs: list[Optional[PatternExpr]] + pattern_to_node: dict[PatternExpr, Optional[torch.fx.Node]] + graph: torch.fx.Graph + exclusive_node_set: list[NodeOrConstant] + + def __init__( + self, + outputs: list[Optional[PatternExpr]], + pattern_to_node: Optional[dict[PatternExpr, torch.fx.Node]] = None, + *, + graph: torch.fx.Graph, + ) -> None: + self.outputs = outputs + self.pattern_to_node = {} if pattern_to_node is None else dict(pattern_to_node) + self.graph = graph + self.exclusive_node_set = [] + + def match(self, pattern: PatternExpr, node: NodeOrConstant) -> MatchResult: + """wrapper to check reused nodes in patterns""" + if pattern in self.pattern_to_node: + if self.pattern_to_node[pattern] == node: + return Match(self, pattern) # already checked this node + else: + return FailedMatch("repeated pattern differs") + m = pattern._match(node, self) + assert pattern not in self.pattern_to_node + self.pattern_to_node[pattern] = node if m else None + return m + + def filter_multi_user_patterns(self) -> dict[PatternExpr, torch.fx.Node]: + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr(ABC): + """ + Base class for types of patterns. + """ + + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ... + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def find_anchor_nodes( + self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + def pattern_eq(self, other: Any) -> bool: + """ + Compare two `PatternExpr`s and return true if they are the + same. Note this is NOT matching a pattern - it is comparing the pattern + structures (for debugging). + """ + return isinstance(other, self.__class__) + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything + + def __repr__(self) -> str: + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. 
+ """ + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"KeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class ExclusiveKeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. + """ + + name: str + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + + def __repr__(self) -> str: + return f"ExclusiveKeywordArg({self.name!r})" + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + if node in ctx.exclusive_node_set: + return FailedMatch("exclusive arg appears twice") + + ctx.exclusive_node_set.append(node) + return Match(ctx, self, kwargs={self.name: node}) # matches anything + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.name == other.name + + +class _TargetExpr(PatternExpr): + """ + Base class for filtering match by node.target + """ + + fns: list[FnsType] + fns_set: OrderedSet[FnsType] + + def __init__( + self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1 + ) -> None: + super().__init__() + fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) + for fn in fns: + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend(getattr(fn, overload) for overload in fn.overloads()) + + self.fns = fns + self.fns_set = OrderedSet(fns) + self.users = users + + @property + @abstractmethod + def op(self) -> str: ... + + def fns_repr(self) -> str: + first_repr = self.fns[0] + if not isinstance(first_repr, str): + first_repr = first_repr.__name__ + + if len(self.fns) > 1: + return f"[{first_repr}, ...]" + elif self.fns[0] is getattr(torch, first_repr, None): + return f"torch.{first_repr}" + elif self.fns[0] is getattr(operator, first_repr, None): + return f"operator.{first_repr}" + elif isinstance(self.fns[0], torch._ops.OpOverload): + return str(self.fns[0]) + else: + return first_repr + + def __repr__(self) -> str: + if self.users is MULTIPLE: + comma_users = ", MULTIPLE" + elif self.users != 1: + comma_users = f", {self.users})" + else: + comma_users = "" + return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})" + + def has_multiple_users(self) -> bool: + return isinstance(self.users, Multiple) or self.users > 1 + + def find_anchor_nodes( + self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + raise NotImplementedError + + def _match_fns(self, node: torch.fx.Node) -> bool: + return ( + isinstance(node, torch.fx.Node) + and node.op == self.op + and extract_target(node) in self.fns_set + ) + + def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool: + return ( + self in ctx.outputs + or self.users is MULTIPLE + or len(node.users) == self.users + ) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.op == other.op + and self.fns == other.fns + and self.users == other.users + ) + + +_SimpleSpec = tuple[Any, ...] 
+ + +class _TargetArgsExpr(_TargetExpr): + """ + Base class for filtering match by node.{target,args,kwargs} + """ + + def __init__( + self, + fns: Union[torch.fx.node.Target, str, Sequence[Any]], + *args: Any, + _users: Union[int, Multiple] = 1, + **kwargs: Any, + ) -> None: + super().__init__(fns, _users) + self.args = tuple(args) + self.kwargs = dict(kwargs) + if any( + isinstance(x, (dict, list, tuple)) + for x in itertools.chain(args, kwargs.values()) + ): + self.flatten = self.pytree_flatten + else: + self.flatten = self.simple_flatten + self.flat_args_kwargs = self.flatten(self.args, self.kwargs) + + @staticmethod + def simple_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + values = (*args, *kwargs.values()) + spec = (len(args), *kwargs.keys()) + return values, spec + + @staticmethod + def pytree_flatten( + args: Sequence[Any], kwargs: Mapping[Any, Any] + ) -> tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: + type_mapping: dict[type, type] = { + immutable_list: tuple, + list: tuple, + immutable_dict: dict, + } + + def convert_type(x: Any) -> Any: + cls = type(x) + convert_fn = type_mapping.get(cls) + if convert_fn is not None: + return pytree.tree_map( + convert_type, + convert_fn(x), + is_leaf=lambda x: type(x) in type_mapping, + ) + return x + + normalized_args_tree = pytree.tree_map( + convert_type, + (args, kwargs), + is_leaf=lambda x: type(x) in type_mapping, + ) + flat, spec = pytree.tree_flatten(normalized_args_tree) + return flat, spec + + def __repr__(self) -> str: + args = [ + self.fns_repr(), + *map(repr, self.args), + *[f"{k}={v}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + return f"{self.__class__.__name__}({', '.join(args)})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [ + self.fns_repr(), + *(pp.pretty_print(x) for x in self.args), + *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()], + ] + if self.users is MULTIPLE: + args.append("_users=MULTIPLE") + elif self.users != 1: + args.append(f"_users={self.users}") + + joiner_str = ", " + return f"{self.__class__.__name__}({joiner_str.join(args)})" + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node) or len(node.args) != len(self.args): + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users {}", self) + + _args = node.args + _kwargs = node.kwargs + if len(_kwargs) < len(self.kwargs): + from torch.fx.operator_schemas import normalize_function + + assert callable(node.target) + normalized_args_and_kwargs = normalize_function( + node.target, node.args, node.kwargs + ) + + if normalized_args_and_kwargs is None: + return FailedMatch("function_mismatch: node={}, pattern={}", node, self) + else: + _args, _kwargs = normalized_args_and_kwargs + if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs): + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + else: + return FailedMatch( + "function_mismatch: node={}, pattern={}", node, self + ) + else: + _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs} + + node_items, node_spec = self.flatten(_args, _kwargs) + self_items, self_spec = self.flat_args_kwargs + if node_spec != self_spec: + return FailedMatch("args_structure {} {}", node_spec, self_spec) + 
assert len(node_items) == len(self_items) + + m = Match(ctx, self) + for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + if isinstance(pattern, PatternExpr): + child_match = ctx.match(pattern, child_node) + if not is_match(child_match): + return child_match + m.extend(child_match) + elif isinstance(child_node, torch.fx.Node) or child_node != pattern: + return FailedMatch( + "constant_args: {} {!r}!={pattern!r}", node, child_node + ) + m.nodes.append(node) + m.targets[self] = node.target + return m + + def find_anchor_nodes( + self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + """ + This is used when we are matching a pattern with multiple outputs. + There is a partial match (stored in ctx) and we want to walk + this pattern to find a connection to an already-matched node. + + Yields candidate nodes that `self._match` might like. + """ + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + return + + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.flat_args_kwargs[1] == other.flat_args_kwargs[1] + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0]) + ) + ) + + +class CallFunction(_TargetArgsExpr): + """ + Matches a call_function node in the FX graphs: `fns[i](*args, **kwargs)` + """ + + op = "call_function" + + +class CallMethod(_TargetArgsExpr): + """ + Matches a call_method node in the FX graphs: `fns[i].method(*args, **kwargs)` + """ + + op = "call_method" + + +class CallModule(_TargetArgsExpr): + """ + Matches a call_module node in the FX graphs: `module(*args, **kwargs)` + """ + + op = "call_module" + + +class _TargetExprVarArgs(_TargetExpr): + """ + Matches a call_function node with any arguments which are passed into the pattern + """ + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + if not self._match_fns(node): + return FailedMatch("function_mismatch") + + if not self._match_users(node, ctx): + return FailedMatch("multiple_users") + + m = Match(ctx, self) + m.nodes.append(node) + m.targets[self] = node.target + m.args.extend(node.args) + m.kwargs.update(node.kwargs) + return m + + +class CallFunctionVarArgs(_TargetExprVarArgs): + op = "call_function" + + +class CallMethodVarArgs(_TargetExprVarArgs): + op = "call_method" + + +class CallModuleVarArgs(_TargetExprVarArgs): + op = "call_module" + + +class ListOf(PatternExpr): + """ + Matches a repeated pattern + """ + + def __init__(self, pattern: PatternExpr, partial: bool = False) -> None: + super().__init__() + assert isinstance(pattern, PatternExpr) + self.pattern = pattern + self.partial = partial + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.pattern})" + + def _match(self, node: list[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] + if not isinstance(node, (list, tuple)) or len(node) == 0: + return FailedMatch("non_list") + m = Match(ctx, self) + # Propagating patterns with multiple users will ensure we don't revisit + # the 
same nodes + pattern_to_node = ctx.filter_multi_user_patterns() + matched = False + for i, child_node in enumerate(node): + child_ctx = MatchContext( + ctx.outputs, pattern_to_node, graph=child_node.graph + ) + child_match = child_ctx.match(self.pattern, child_node) + pattern_to_node = child_ctx.filter_multi_user_patterns() + if not is_match(child_match): + if not self.partial: + return FailedMatch("list[{}]: {}", i, child_match) + continue + matched = True + m.extend(child_match.bundle()) + if not matched: + return FailedMatch("list: no_match") + return m.bundle() + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and self.pattern.pattern_eq(other.pattern) + and self.partial == other.partial + ) + + +class MultiOutputPattern(PatternExpr): + outputs: list[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: + super().__init__() + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) + self.op = outputs[0].op + + @property + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): + return m + + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not is_match(child_match): + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: + prior = dict(ctx.pattern_to_node) + m: MatchResult = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, OrderedSet()): + m = ctx.match(pattern, node) + if is_match(m): + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and len(self.outputs) == len(other.outputs) + and all( + a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b + for a, b in zip(self.outputs, other.outputs) + ) + ) + + +class RepeatedExpr(PatternExpr): + """ + Checks for a repeated pattern. 
Useful for repeated operations after a node such as `split` or `unbind` + """ + + def __init__(self, inner_pattern: _TargetExpr) -> None: + super().__init__() + self.inner_pattern = inner_pattern + self.op = inner_pattern.op + + @property + def fns(self) -> Sequence[FnsType]: + return self.inner_pattern.fns + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + m = ctx.match(self.inner_pattern, node) + if not is_match(m): + return m + ctx.pattern_to_node.pop( + self.inner_pattern, + ) + # Check all anchor nodes match the pattern + for anchor_node in self.inner_pattern.find_anchor_nodes(ctx, OrderedSet()): + anchor_m = MatchContext([self], graph=node.graph).match( + self.inner_pattern, anchor_node + ) + if not is_match(anchor_m): + return anchor_m + m.extend(anchor_m) + return m + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return super().pattern_eq(other) and self.inner_pattern.pattern_eq( + other.inner_pattern + ) + + +class PatternPrettyPrinter: + """ + Serializes Patterns to executable python. + XXX: currently only used and tested for fuse attention patterns. May not cover + all patterns. + """ + + def __init__(self) -> None: + self.namespace = torch.fx.graph._Namespace() + self.memoized_objs_names: dict[PatternExpr, str] = {} + self.memoized_objs_pp: dict[PatternExpr, str] = {} + + @staticmethod + @functools.cache + def run(obj: PatternExpr, output_name: str = "output") -> str: + """ + Serializes obj to python code with obj written out to `output_name` + """ + + pp = PatternPrettyPrinter() + assert hasattr(obj, "pretty_print") + out_str = obj.pretty_print(pp=pp) + + output = [ + f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}" + for key in pp.memoized_objs_names + ] + + output.append(f"{output_name} = {out_str}") + + return "\n".join(output) + + def pretty_print(self, obj: Any) -> str: + if isinstance(obj, _TargetArgsExpr): + if memoized_name := self.memoized_objs_names.get(obj): + return memoized_name + else: + return self.memoize(obj) + if hasattr(obj, "pretty_print"): + return obj.pretty_print(self) + + return repr(obj) + + def memoize(self, obj: _TargetArgsExpr) -> str: + obj_str = obj.pretty_print(self) + obj_name = obj.fns_repr() + for prefix in ("aten.", "torch.", "prims."): + obj_name = obj_name.replace(prefix, "") + + tmp_name = self.namespace.create_name(obj_name, None) + self.memoized_objs_names[obj] = tmp_name + self.memoized_objs_pp[obj] = obj_str + return tmp_name + + +class _PassDictsType(Protocol): + def __getitem__( + self, k: tuple[str, torch.fx.node.Target] + ) -> list[PatternEntry]: ... 
+ + +@dataclasses.dataclass +class PatternEntry: + pattern: PatternExpr + extra_check: Callable[[Match], bool] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + raise NotImplementedError + + def register( + self, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + target: Union[torch.fx.node.Target, None] = None, + prepend: bool = False, + ) -> None: + if target is None: + assert hasattr(self.pattern, "fns") + for fn in self.pattern.fns: + self.register(pass_dicts, fn, prepend=prepend) + elif isinstance(pass_dicts, (dict, PatternMatcherPass)): + assert hasattr(self.pattern, "op") + if prepend: + pass_dicts[(self.pattern.op, target)].insert(0, self) + else: + pass_dicts[(self.pattern.op, target)].append(self) + else: + pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts) + for x in pass_dicts: + self.register(x, target, prepend=prepend) + + +@dataclasses.dataclass +class LoweringPatternEntry(PatternEntry): + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) + with graph.inserting_before(node): + replacement = graph.call_function(handler, tuple(match.args), match.kwargs) + replacement.meta.update(node.meta) + node.replace_all_uses_with(replacement) + assert match.nodes[-1] is node + match.erase_nodes() + + +@dataclasses.dataclass +class GraphPatternEntry(PatternEntry): + """ + A pattern that runs a function on the FX graph + """ + + handler: Callable[..., Any] + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + with graph.inserting_before(node): + self.handler(match, *match.args, **match.kwargs) + + +@dataclasses.dataclass +class ReplacementPatternEntry(PatternEntry): + normalize_args: Callable[..., list[Any]] + + @staticmethod + def replace_with_graph( + match: Match, + graph: torch.fx.Graph, + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node: torch.fx.Node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + if node.op == "call_function": + assert callable(target) + result = graph.call_function(target, args, kwargs) + _transfer_meta( + new_meta=result.meta, + old_node=node, + pass_name="Interpreter_Replacer", + ) + # This function copy-pastes the replacement graph into + # the graph. If the replacement graph had any eager_input_vals, + # or val/tensor_meta, we propagate those over. + if "eager_input_vals" in node.meta: + result.meta["eager_input_vals"] = node.meta["eager_input_vals"] + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + if node.op == "get_attr": + # If the replacement graph contains a HOP, the subgraphs of the HOP are "get_attr" nodes. + # We need to fetch the subgraph of the HOP then register the subgraph to the replaced graph's root. 
+ from torch._higher_order_ops.utils import ( + unique_graph_name_with_root, + ) + + sub_gm = super().get_attr(target, args, kwargs) + if not isinstance(sub_gm, torch.fx.GraphModule): + raise NotImplementedError( + f"NYI: replacement_graph.{target} is not a graph module. Got {sub_gm}." + ) + + assert graph.owning_module is not None + _, graph_name = unique_graph_name_with_root( + graph.owning_module, str(target) + ) + graph.owning_module.register_module(graph_name, sub_gm) + return graph.get_attr(graph_name) + + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=operator.itemgetter(0))[1] + + def percolate_tags( + node: torch.fx.Node, + tag_name: str, + tag_value: str, + input_stops: OrderedSet[torch.fx.Node], + ) -> None: + queue = [node] + visited = OrderedSet[torch.fx.Node]() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta[tag_name] = tag_value + queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + assert isinstance(replacement_graph, torch.fx.GraphModule) + replacement = Replacer(replacement_graph).run(*args) + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node: torch.fx.Node) -> Any: + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with(None) # type: ignore[arg-type] + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). + # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + for tag_name in ["recompute", "ac_graph_id"]: + if tag_name in old.meta: + percolate_tags( + new, tag_name, old.meta[tag_name], OrderedSet(args) + ) + + old.replace_all_uses_with(new) + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... 
+ old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError("can't handle") + replace(user, new[idx]) + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes() + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None + self.replace_with_graph( + match, + graph, + match.replacement_graph, + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match: Match) -> bool: + return True + + +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def check_and_add_duplicate_pattern( + pattern: PatternExpr, + graph: Optional[torch.fx.Graph], + seen_patterns: dict[str, list[Optional[str]]], + skip_duplicates: bool = False, +) -> bool: + """ + Check if a pattern is a duplicate. Because we ignore certain types in searching, but not + in matching, use the graph to distinguish equivalent search patterns. + + Returns True if a duplicate is found and `skip_duplicates=True` is passed in. Errors if + `skip_duplicates` is False and a duplicate is found. + """ + + pattern_repr = PatternPrettyPrinter.run(pattern) + equiv_pattern_reprs = seen_patterns.get(pattern_repr) + if not equiv_pattern_reprs: + seen_patterns[pattern_repr].append(str(graph) if graph else None) + return False + + if graph is None: + if skip_duplicates: + return True + torch._check( + False, + lambda: f"Duplicate pattern: {pattern_repr} with no graph", + ) + + new_graph_str = str(graph) + for graph_str in equiv_pattern_reprs: + if not new_graph_str == graph_str: + continue + if skip_duplicates: + return True + torch._check( + False, + lambda: f"Duplicate pattern: {pattern_repr} with duplicated match graph {graph_str} ", + ) + equiv_pattern_reprs.append(new_graph_str) + return False + + +def register_replacement( + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, + skip_duplicates: bool = False, +) -> bool: + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. + + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + def check_fn(match: Match) -> bool: + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. 
Perhaps one " + f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + + sym_args: list[torch.SymInt] = [] + with torch._dynamo.utils.detect_fake_mode(args): + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + statically_known_true(v != a) for a in sym_args + ): + sym_args.append(v) + + # If we were given a pre-traced pattern then use that instead of + # retracing. Note that this means the pattern has to be independent + # of its args. + specific_pattern = search_fn_pattern + + if not specific_pattern: + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new: Any) -> Any: + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) + + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user program + match.replacement_graph = trace_fn(replace_fn, args) + if len(match.nodes) == 1: + for n in match.replacement_graph.graph.nodes: + _transfer_meta( + new_meta=n.meta, + old_node=match.nodes[0], + pass_name="replacement", + ) + return True + return False + + def normalize_args(**kwargs: Any) -> list[Any]: + args = [kwargs.pop(name) for name in argnames_static] + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with 
functorch_config.patch(functionalize_rng_ops=False): + requires_grad: list[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern, gm = gen_pattern_and_search_gm( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + gm = None + + for pattern_matcher_pass in ( + pass_dicts if isinstance(pass_dicts, Sequence) else [pass_dicts] + ): + if isinstance(pattern_matcher_pass, PatternMatcherPass): + if check_and_add_duplicate_pattern( + pattern, + gm.graph if gm else None, + pattern_matcher_pass.seen_patterns, + skip_duplicates=skip_duplicates, + ): + return False + + pattern = ReplacementPatternEntry( + pattern=pattern, + extra_check=check_fn, + normalize_args=normalize_args, + ) + pattern.register(pass_dicts) + return pattern.pattern # type: ignore[return-value] + + +_serialized_patterns: OrderedSet[str] = OrderedSet() + + +def _serialize_pattern( + unique_name: str, + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, int]], None], +) -> PatternExpr: + def get_file_template() -> str: + auto_generated_msg = textwrap.dedent( + """\ + # This is an auto-generated file. Please do not modify it by hand. + # To re-generate, run: + # cd ~/pytorch && python torchgen/fuse/gen_patterns.py + """ + ) + + file_template = textwrap.dedent( + """\ + # mypy: ignore-errors + + # noqa: F401, E501 + {msg} + import torch + import torch._inductor + import operator + + aten = torch.ops.aten + prims = torch.ops.prims + + """ + ).format(msg=auto_generated_msg) + + pattern_matcher_imports = [] + for name in dir(torch._inductor.pattern_matcher): + attr = getattr(torch._inductor.pattern_matcher, name) + try: + if isinstance(attr, type) and issubclass( + attr, (PatternExpr, _TargetExpr) + ): + pattern_matcher_imports.append(name) + except TypeError: + pass + + formatted_imports = ",\n ".join(pattern_matcher_imports) + formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n" + return f"{file_template}{formatted_imports}" + + if not SERIALIZED_PATTERN_PATH.is_dir(): + raise RuntimeError( + f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}" + ) + + pattern_name = search_fn.__name__ + + from torch._functorch import config as functorch_config + + with functorch_config.patch(functionalize_rng_ops=False): + pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround) + + serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name) + if pattern_name not in _serialized_patterns: + write_mode = "w" + _serialized_patterns.add(pattern_name) + else: + write_mode = "a" + + file_template = get_file_template() + + with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f: + if write_mode == "w": + f.write(file_template) + else: + f.write("\n\n") + f.write(serialized_pattern) + f.write("\n") + + return pattern + + +SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patterns" + +# This is the set of serialized patterns that we've registered. Used by +# test_serialized_patterns_up_to_date() to ensure the patterns are up +# to date. 
+_known_precompiled_patterns: list[ + tuple[ + Any, + Iterable[Any], + Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + Any, + PatternExpr, + ] +] = [] + + +def gen_register_replacement( + unique_name: str, + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, +) -> None: + # Make sure the example_inputs is materialized. + example_inputs = tuple(example_inputs) + + if "PYTORCH_GEN_PATTERNS" in os.environ: + pat = _serialize_pattern( + unique_name, search_fn, example_inputs, trace_fn, scalar_workaround + ) + else: + pattern_name = search_fn.__name__ + m = importlib.import_module( + f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}" + ) + if not m or not hasattr(m, unique_name): + log.warning( + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", + unique_name, + ) + pat = getattr(m, unique_name) + + for arg in pytree.tree_iter(example_inputs): + if isinstance(arg, FakeTensor) and arg.constant is not None: + # This can be a problem - small fake tensors (e.g. `tensor(2)`) will + # hold onto their original constant value - and by stashing it here + # will cause a memory leak if the constant value is on GPU. + # Since this is just an optimization we can clear it out. + arg.constant = None + + _known_precompiled_patterns.append( + (search_fn, example_inputs, trace_fn, scalar_workaround, pat) + ) + register_replacement( + search_fn, + replace_fn, + example_inputs, + trace_fn, + pass_dicts, + extra_check, + scalar_workaround, + exclusive_arg_names, + search_fn_pattern=pat, + skip_duplicates=skip_duplicates, + ) + + +@functorch_config.patch(functionalize_rng_ops=False) # type: ignore[misc] +def gen_pattern_and_search_gm( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> tuple[PatternExpr, torch.fx.GraphModule]: + argnames = [*inspect.signature(search_fn).parameters.keys()] + + if scalar_workaround is None: + scalar_workaround = {} + flat_inputs = [] + input_idx = 0 # Positional arguments index + + for argname in argnames: + if argname in scalar_workaround: + flat_inputs.append(scalar_workaround[argname]) + else: + flat_inputs.append(example_inputs[input_idx]) + input_idx += 1 + + search_gm = trace_fn(search_fn, flat_inputs) + return ( + fx_to_pattern( + search_gm, + ignore_types=(int, float, list, torch.device, torch.dtype), + argnames=argnames, + scalar_workaround=scalar_workaround, + exclusive_arg_names=exclusive_arg_names, + ), + search_gm, + ) + + +def gen_pattern( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + return gen_pattern_and_search_gm( + search_fn, example_inputs, trace_fn, scalar_workaround, exclusive_arg_names + )[0] + + +def register_lowering_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register an aten to inductor IR replacement pattern. 
The decorated + function is saved and then called a lowering time allowing direct + pattern to inductor IR conversion. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + LoweringPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + handler._inductor_lowering_function = True # type: ignore[attr-defined] + return handler + + return decorator + + +def register_graph_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register a pattern that runs a function on the FX graph, allowing + custom transformation code. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc +# doesn't match: __rshift__, etc +_mutation_op_re = re.compile(r"(? bool: + if op.namespace != "inductor": + return False + + # TODO - fix schema + # Dont add any more ! + return op in ( + torch.ops.inductor.accumulate_grad_.default, + torch.ops.inductor.resize_storage_bytes_.default, + ) + + +def is_mutation_op(node: torch.fx.Node) -> bool: + if isinstance( + node.target, torch._ops.OpOverload + ) and not fixme_incorrect_inductor_schema_op(node.target): + return node.target._schema.is_mutable + elif isinstance( + node.target, torch._higher_order_ops.auto_functionalize.AutoFunctionalized + ): + return False + if node.op == "call_function": + assert callable(node.target) + if _mutation_op_re.search(node.target.__name__): + return True + elif node.op == "call_method": + assert isinstance(node.target, str) + if _mutation_op_re.search(node.target): + return True + return node.kwargs.get("out") is not None + + +def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool: + assert "mutation_region_id" in a.meta + assert "mutation_region_id" in b.meta + return a.meta["mutation_region_id"] == b.meta["mutation_region_id"] + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.Graph) -> None: + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +class PatternMatcherPass: + def __init__( + self, + pass_name: Optional[str] = None, + ) -> None: + super().__init__() + self.patterns: defaultdict[ + tuple[str, torch.fx.node.Target], list[PatternEntry] + ] = defaultdict(list) + self.pass_name = pass_name + + # For a particular generated pattern repr, store all of the str 
representations + # of the graph used to generate them. Because we ignore certain patterns + # in searching, but not in matching, use the graph to distinguish if two equivalent + # searches are actually different. + self.seen_patterns: dict[str, list[Optional[str]]] = defaultdict(list) + + def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: + return self.patterns[item] + + def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int: + if not self.patterns: + return 0 + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) + if should_compute_mutation_region_ids(graph): + compute_mutation_region_ids(graph) + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + nodes = [] + has_call_module = False + for op, target in self.patterns: + if op == "call_module": + has_call_module = True + else: + nodes.append(graph.find_nodes(op=op, target=target, sort=False)) + if has_call_module: + nodes.append(graph.find_nodes(op="call_module", sort=False)) + pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" + assert isinstance(gm, torch.fx.GraphModule) + with GraphTransformObserver(gm, pass_name): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue + + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. + # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + is_match(m) + and len( + OrderedSet(map(get_mutation_region_id_partial, m.nodes)) + ) + != 1 + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self) -> None: + self.patterns.clear() + + +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError + + +def fx_to_pattern( + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. 
+ """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg( + x: T, ignore_types_override: Optional[Sequence[type[Any]]] = None + ) -> Union[T, KeywordArg, Ignored]: + current_ignore_types = ( + ignore_types_override if ignore_types_override is not None else ignore_types + ) + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in current_ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + call_method = _not_implemented + call_module = _not_implemented + get_attr = _not_implemented + + def placeholder( + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], + ) -> Union[ExclusiveKeywordArg, KeywordArg]: + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + def call_function( + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], + ) -> PatternExpr: + process_arg_fn = process_arg + # Indexing is critical for matching getitem nodes, so we can't ignore int args here + if target == operator.getitem: + + def process_arg_fn_impl( + x: T, + ignore_types_override: Optional[Sequence[type[Any]]] = tuple( + t for t in ignore_types if t is not int + ), + ) -> Union[T, KeywordArg, Ignored]: + return process_arg(x, ignore_types_override) + + process_arg_fn = process_arg_fn_impl + + args, kwargs = pytree.tree_map(process_arg_fn, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] 
+ args = [process_arg_fn(a) for a in args] + kwargs = {k: process_arg_fn(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n: torch.fx.Node) -> Any: + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + args = n.args[0] + assert isinstance(args, Collection) + assert len(rv) == len(args) + for r, arg in zip(rv, args): + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + assert isinstance(gm, torch.fx.GraphModule) + pattern = Converter(gm).run() + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only( + fn: Callable[..., Any], + args: Sequence[Any], + *, + run_functional_passes: bool = True, + get_decomp_fn: Optional[Callable[..., Any]] = None, +) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + decompositions = ( + get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() + ) + gm = make_fx(fn, decompositions, tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + if run_functional_passes: + remove_noop_ops(gm.graph) + gm.graph.eliminate_dead_code() + + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: + """Build a normalized training graph, for use with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, handler=pointless_view, extra_check=_return_true + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> list[torch.fx.node.Argument]: + args: list[torch.fx.node.Argument] = [] + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph) -> None: + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = OrderedSet[torch.fx.Node]() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. 
+ cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: + """Wrapper around lazy init functions in fx_passes/""" + + @functools.cache + @functools.wraps(fn) + def lazy_init() -> Any: + counters_ref = counters["inductor"].copy() + + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters["inductor"] = counters_ref + + return result + + return lazy_init + + +def config_flag(name: str) -> Callable[[Match], Any]: + """Function for extra_check to put pass behind a flag""" + + def flag_check(match: Match) -> Any: + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + return new_node + + return CopyGraph(input_graph).transform() + + +# TODO: remove in follow up diff, used internally +_seen_patterns: OrderedSet[str] = OrderedSet() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +) -> Any: + if len(node.args) > arg_number: + return node.args[arg_number] + elif kwarg_name is None: + return None + else: + return node.kwargs.get(kwarg_name) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + assert isinstance(node.target, str) + return _get_attr(node.graph.owning_module, node.target).__class__ + return node.target diff --git a/phivenv/Lib/site-packages/torch/_inductor/quantized_lowerings.py b/phivenv/Lib/site-packages/torch/_inductor/quantized_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..9a894d3af0cc6f7444020baf5e1fc97939e80c79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/quantized_lowerings.py @@ -0,0 +1,168 @@ +import logging +from typing import Any + +import torch +from torch._inductor.kernel.mm_common import mm_args + +from . 
import config, lowering +from .codegen.cpp_gemm_template import CppGemmTemplate, CppWoqInt4GemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .lowering import expand, register_lowering +from .mkldnn_ir import WeightInt4PackMatmul +from .select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, +) +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template +from .virtualized import V + + +log = logging.getLogger(__name__) + +aten__weight_int8pack_mm = ExternKernelChoice( + torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False +) + +aten__weight_int4pack_mm_cpu = ExternKernelChoice( + torch.ops.quantized.int4mm_packed_weight_cpu, + "at::native::_weight_int4pack_mm_cpu_tensor", + has_out_variant=False, + kernel_creator=WeightInt4PackMatmul.create, +) + +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +aten = torch.ops.aten + + +def register_quantized_ops() -> None: + lowering.add_needs_realized_inputs( + [ + quantized.max_pool2d, + _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16, + _quantized.wrapped_fbgemm_linear_fp16_weight, + ] + ) + lowering.make_fallback(quantized.max_pool2d) + lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) + lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight) + + +def register_woq_mm_ops() -> None: + @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) # type: ignore[misc] + def int8pack_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + *, + layout: Any = None, + ) -> Any: + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.int8 + ) + aten_layout = layout + + # options to tune from + choices = ( + [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)] + if use_aten_gemm_kernels() + else [] + ) + + # scale is applied as an epilogue, and the scale tensor is expanded (with a view op) + # for broadcasting, as it's 1D. 
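+        # For example, a per-output-channel scale of shape [N] is expanded to the
+        # [M, N] output shape of the GEMM so the epilogue multiply broadcasts.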
+ def _mul_epilogue(buf: torch.Tensor) -> Any: + return create_epilogue_with_attr( + buf, "mul", other=realize_inputs(expand(scale, layout.size)) + ) + + if use_cpp_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True): + CppGemmTemplate.add_choices( + choices, + aten_layout, + [mat1, mat2, scale], + trans_w=True, + epilogue_creator=_mul_epilogue, # type: ignore[arg-type] + ) + + return autotune_select_algorithm( + "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout + ) + + @register_lowering(aten._weight_int4pack_mm_for_cpu, type_promotion_kind=None) # type: ignore[misc] + def int4pack_mm_cpu( + input: torch.Tensor, + weight: torch.Tensor, + qGroupSize: int, + qScaleAndZeros: torch.Tensor, + *, + layout: Any = None, + ) -> Any: + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, use_4x2_dim=True, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.uint8 + ) + group_size = V.graph.add_tensor_constant( + torch.tensor(qGroupSize, dtype=torch.int64), name=None + ) + aten_layout = layout + + # options to tune from + choices = ( + [ + aten__weight_int4pack_mm_cpu.bind( + (mat1, mat2, group_size, qScaleAndZeros), aten_layout + ) + ] + if use_aten_gemm_kernels() + else [] + ) + if ( + (config.max_autotune or config.max_autotune_gemm) + and use_cpp_gemm_template( + aten_layout, + mat1, + mat2, + mat2_transposed=True, + is_woq_int4=True, + q_group_size=qGroupSize, + ) + and mat2.get_layout().is_contiguous() + ): + CppWoqInt4GemmTemplate[qGroupSize].add_choices( + choices, + aten_layout, + [mat1, mat2, group_size, qScaleAndZeros], + ) + + # define functions to generate example inputs for weight and group size + # otherwise, autotuner generates example inputs of all zeros for them + def get_example_weight(x: torch._inductor.ir.IRNode) -> torch.Tensor: + assert x.get_layout().is_contiguous() + shape = x.get_size() + device = x.get_device() + return torch.randint(0, 255, shape, dtype=torch.uint8, device=device) + + input_gen_fns = { + 1: get_example_weight, # packed weight + 2: lambda x: V.graph.constants[x.get_name()], # group size + } + + return autotune_select_algorithm( + "_weight_int4pack_mm_for_cpu", + choices, + [mat1, mat2, group_size, qScaleAndZeros], + aten_layout, + input_gen_fns=input_gen_fns, + ) + + lowering.make_fallback(aten._dyn_quant_matmul_4bit) + lowering.make_fallback(aten._dyn_quant_pack_4bit_weight) diff --git a/phivenv/Lib/site-packages/torch/_inductor/remote_cache.py b/phivenv/Lib/site-packages/torch/_inductor/remote_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6b26b51e664222bde4f354db3071ba49e56b9da4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/remote_cache.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import atexit +import collections +import dataclasses +import functools +import json +import logging +import os +import sys +import typing +from abc import abstractmethod +from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing_extensions import override, TypeAlias + +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config +from torch.monitor import _WaitCounter + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +log = logging.getLogger(__name__) + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + + Sample: TypeAlias = Sample_ 
+else: + Sample: TypeAlias = type[object] # type: ignore[misc,no-redef] + + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +remote_fx_cache_get_timed = functools.partial( + dynamo_timed, + "FbRemoteFxGraphCache.get", + phase_name="remote_fx_graph_cache_get", + log_pt2_compile_event=False, + dynamo_compile_column_us="remote_fx_graph_cache_get_time_us", + log_waitcounter=True, +) +remote_fx_cache_put_timed = functools.partial( + dynamo_timed, + "FbRemoteFxGraphCache.put", + phase_name="remote_fx_graph_cache_put", + log_pt2_compile_event=False, + dynamo_compile_column_us="remote_fx_graph_cache_put_time_us", + log_waitcounter=True, +) + + +class RemoteCacheBackend(Generic[_T]): + """ + A backend implementation for accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. + """ + + def __init__(self) -> None: + self._name = f"backend:{type(self).__name__}" + + @abstractmethod + def _get(self, key: str) -> Optional[_T]: + pass + + @abstractmethod + def _put(self, key: str, data: _T) -> None: + pass + + def get(self, key: str) -> Optional[_T]: + try: + value = self._get(key) + cache_stats.get(self._name, value) + except Exception: + cache_stats.exception(self._name) + raise + return value + + def put(self, key: str, data: _T) -> None: + try: + self._put(key, data) + cache_stats.put(self._name) + except Exception: + cache_stats.exception(self._name) + raise + + +# Serde that encodes from _T to _U and decodes from _U to _T. +class RemoteCacheSerde(Generic[_T, _U]): + @abstractmethod + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +# This class is the top of a RemoteCache. A RemoteCache is fundamentally made of +# three parts: +# +# 1. The controller (this class). +# 2. A serializer/deserializer (instance of RemoteCacheSerde). +# 3. A backend (instance of RemoteCacheBackend). +# +# To write (`put`), the RemoteCache takes data, uses the RemoteCacheSerde to +# convert it for the backend and passes it to the backend. +# +# Conversely when reading (`get`), the RemoteCache takes data from the backend, +# uses the RemoteCacheSerde to convert it and returns it. +# +# The RemoteCacheBackend is generic on _U - which is the type of data the +# backend can directly cache (usually `bytes`). +# +# The RemoteCacheSerde is responsible for converting between _T (the type of +# data the RemoteCache accepts in `put` and returns in `get`) and _U. +# +# When instantiating a RemoteCache you should override, not directly create a +# RemoteCache. The reason is that when logging cache use (`TORCH_LOGS=cache`) we +# use the concrete type of the RemoteCache as the reported cache. See +# RemoteFxGraphCache below as an example. +class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing to mock out the backend on a class-by-class basis. 
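+        # For example, a test might set `SomeCache.backend_override_cls = FakeBackend`
+        # (illustrative names) before constructing the cache to avoid touching Redis.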
+ if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + self.serde = serde + + # See if the cache contains `key`. Returns `None` if the value is not + # present in the cache. + def get(self, key: str) -> Optional[_T]: + with _WaitCounter("pytorch.remote_cache.get").guard(): + sample = self._create_sample() + try: + result = self._get(key, sample) + cache_stats.get(type(self).__name__, result) + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) + return result + + # Add `value` to the cache with the key `key`. Note that `None` is not a + # valid value even if _T supports it (because you can't tell the difference + # between `None` and a missing cache entry). + def put(self, key: str, value: _T) -> None: + with _WaitCounter("pytorch.remote_cache.put").guard(): + assert value is not None + sample = self._create_sample() + try: + self._put(key, value, sample) + cache_stats.put(type(self).__name__) + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) + + # Used to convert data from the cache into structured data. + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] + return self.serde.decode(data) # type: ignore[arg-type] + + # Used to convert structured data into data for the cache. + def _encode(self, value: _T, sample: Optional[Sample]) -> object: # returns _U + return self.serde.encode(value) + + # Get structured data from the cache. + # Separate from `get` so that it can be overridden. + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self._backend_get(key): + return self._decode(data, sample) + return None + + # Get unstructured data from the cache. + # Separate from `get` so that it can be overridden. + # Returns _U - but we aren't actually generic on _U + def _backend_get(self, key: str) -> object: + return self.backend.get(key) + + # Put structured data into the cache. + # Separate from `put` so that it can be overridden. + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self._backend_put(key, data) + + # Put unstructured data into the cache. + # Separate from `put` so that it can be overridden. + # Takes data: _U - but we aren't actually generic on _U + def _backend_put(self, key: str, data: object) -> None: + self.backend.put(key, data) + + # Create a logging Sample - used with internal loggers to monitor cache + # effectiveness. + def _create_sample(self) -> Optional[Sample]: + return None + + # Write the logging Sample to the logger. + def _log_sample(self, sample: Optional[Sample]) -> None: + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): + """ + A Redis implementation of a remote/distributed cache. + """ + + _redis: Optional[redis.Redis] = None + + def __init__(self, cache_id: str) -> None: + super().__init__() + if not redis: + raise RuntimeError("redis not available but required for remote cache") + + if "TORCHINDUCTOR_REDIS_URL" in os.environ: + self._redis = redis.Redis.from_url(os.environ["TORCHINDUCTOR_REDIS_URL"]) + else: + self._redis = redis.Redis( + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), + ) + + @override + def _get(self, key: str) -> Optional[bytes]: + if not self._redis: + # Either redis wasn't found or we already had some trouble... 
+ return None + + try: + value = self._redis.get(key) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + + # In theory redis.get() can return an Awaitable as well... + assert value is None or isinstance(value, bytes) + return value + + @override + def _put(self, key: str, data: bytes) -> None: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + self._redis.set(key, data) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, cache_id: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... + backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(cache_id) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + version = 1 # consistency between various types of keys + self._key_fmt = f"pt2:{cache_id}::{{key}}:c{version}" + + def _get_key(self, key: str) -> str: + return self._key_fmt.format(key=key) + + @override + def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: + key = self._get_key(key) + return super()._get(key, sample) + + @override + def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None: + key = self._get_key(key) + super()._put(key, value, sample) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteBundledAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass + + +class RemoteAOTAutogradCache(RedisRemoteCache): + pass + + +class RemoteDynamoPGOCache(RedisRemoteCache): + pass + + +def create_cache( + key: str, + is_fbcode: bool, + fb_cache_cls: str, + oss_cache_cls: str, +) -> Optional[RemoteCache[JsonDataTy]]: + try: + if is_fbcode: + import torch._inductor.fb.remote_cache + + cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls) + return cache_cls(key) + else: + this_module = sys.modules[__name__] + + cache_cls = getattr(this_module, oss_cache_cls) + return cache_cls(key) + + except Exception: + log.warning("Unable to create a remote cache", exc_info=True) + return None + + +# Some simple stat capture +@dataclasses.dataclass +class _CacheStat: + miss: int = 0 + hit: int = 0 + put: int = 0 + exception: int = 0 + + def __str__(self) -> str: + return f"{{hit: {self.hit}, miss: {self.miss}, put: {self.put}, exception: {self.exception}}}" + + +class _CacheStats: + _stats: dict[str, _CacheStat] + + def __init__(self) -> None: + self._stats = collections.defaultdict(_CacheStat) + + def miss(self, name: str, count: int = 1) -> None: + self._stats[name].miss += count + + def hit(self, name: str, count: int = 1) -> None: + self._stats[name].hit += count + + def get(self, name: str, value: Optional[object]) -> None: + if value is None: + self.miss(name) + else: + self.hit(name) + + def put(self, name: str, count: int = 1) -> None: + self._stats[name].put += count + + def exception(self, name: str, count: int = 1) -> None: + self._stats[name].exception += count + + +cache_stats = _CacheStats() + + +@atexit.register +def dump_cache_stats() -> None: + if not 
log.isEnabledFor(logging.INFO): + return + + import io + + out = io.StringIO() + + if not cache_stats._stats: + print(" None", file=out) + else: + print(file=out) + for k, v in sorted(cache_stats._stats.items()): + print(f" {k}: {v}", file=out) + + log.info("Cache Metrics:%s", out.getvalue()) diff --git a/phivenv/Lib/site-packages/torch/_inductor/scheduler.py b/phivenv/Lib/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..eca1330577aa2e35103f56cd0311d0547df00a03 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,5026 @@ +from __future__ import annotations + +import collections +import dataclasses +import functools +import inspect +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +import traceback +import typing +from collections import Counter, defaultdict +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union + + +if TYPE_CHECKING: + from collections.abc import Sequence + from types import ModuleType + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codecache import LambdaFuture, PyCodeCache +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_symbols +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._triton import has_triton + +from . import comms, config, dependencies, ir, metrics +from .analyze_preserves_zero_mask import can_codegen_without_upcasts +from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .exc import GPUTooOldForTriton, TritonMissing +from .fx_utils import count_flops_fx, countable_fx +from .ir import ( + get_device_type, + GraphPartitionSignature, + MultiOutput, + MultiOutputLayout, + NoneLayout, +) +from .loop_body import LoopBody +from .memory import MemoryPlanningInfoForBuffer, MemoryPlanningInfoForNode +from .runtime.runtime_utils import green_text, red_text +from .sizevars import SimplifyIndexing +from .utils import ( + cache_on_self, + cmp, + device_need_guard, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + GraphPartitionMap, + IndentedBuffer, + is_collective, + is_cudagraph_unsafe_op, + is_gpu, + is_multi_outputs_template, + is_output_of_multi_outputs_template, + is_wait, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") + +PartitionType = list["BaseSchedulerNode"] + + +@dataclasses.dataclass +class SchedulerBuffer: + scheduler: Scheduler + node: ir.Buffer + defining_op: Optional[BaseSchedulerNode] + users: list[NodeUser] = dataclasses.field(default_factory=list) + mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( + default_factory=MemoryPlanningInfoForBuffer + ) + + def defining_op_name(self) -> str: + op = self.defining_op + assert op is not None + return op.get_name() + + def __hash__(self) -> int: + return hash(self.node.name) + + def debug_str(self) -> str: + result = IndentedBuffer() + name = 
self.get_name() + result.writeline(f"{name}: {type(self.node).__name__}") + result.writeline(f"{name}.layout = {self.node.layout}") + if self.get_aliases(): + result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}") + if self.get_mutations(): + result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}") + + if len(self.users) <= 1: + result.writeline(f"{name}.users = {self.users}") + else: + result.writeline(f"{name}.users = [") + with result.indent(1): + for user in self.users: + result.writeline(f"{user},") + result.writeline("]") + return result.getrawvalue() + + def get_name(self) -> str: + return self.node.get_name() + + def allocate(self) -> None: + assert self.node is not None + if not self.node.should_allocate(): + return + + if ( + self.node.get_inputs_that_alias_output() + or self.node.get_mutation_names() + or isinstance(self.node.get_output_spec(), ir.CommBufferLayout) + ): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + input_buffer: Union[ir.DonatedBuffer, ir.Buffer] + input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()] + if input_buffer_name in self.scheduler.name_to_donated_buffer: + input_buffer = self.scheduler.name_to_donated_buffer[ + input_buffer_name + ].node + else: + input_buffer = self.scheduler.name_to_buf[input_buffer_name].node + V.graph.wrapper_code.codegen_inplace_reuse( + input_buffer, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self) -> bool: + # There's no real allocated buffer, no need to free it + assert self.node is not None + if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template( + self.node + ): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def set_users(self, users: list[NodeUser]) -> None: + # deduplicate + result: dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + result[id(use.node)] = use + self.users = list(result.values()) + + def get_aliases(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_inputs_that_alias_output() + + def get_mutations(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_mutation_names() + + def get_device(self) -> Optional[torch.device]: + return self.node.get_output_spec().get_device() + + +@dataclasses.dataclass +class SchedulerDonatedBuffer(SchedulerBuffer): + defining_op: Optional[BaseSchedulerNode] = None + + +class BaseSchedulerNode: + group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]] + read_writes: dependencies.ReadWrites + unmet_dependencies: OrderedSet[Dep] + # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. 
+ min_order: int + max_order: int + mpi_node: MemoryPlanningInfoForNode + + def __init__(self, scheduler: Scheduler) -> None: + self.scheduler: Scheduler = scheduler + self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = ( + lambda *args, **kwargs: [] + ) + + def _init_from_node(self, node: ir.Operation) -> None: + self.node: Optional[ir.Operation] = node + self.ancestors: OrderedSet[str] = OrderedSet() + self.last_usage = OrderedSet[ + str + ]() # buffers that won't be used after this kernel + self.written = False + self.outputs: list[SchedulerBuffer] = [ + SchedulerBuffer( + scheduler=self.scheduler, + node=output, + defining_op=self, + ) + for output in node.get_outputs() + ] + self.outputs_by_name: dict[str, SchedulerBuffer] = { + buf.get_name(): buf for buf in self.outputs + } + + def __repr__(self) -> str: + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + def debug_str_extra(self) -> str: + return "" + + def _debug_str_for_device(self) -> list[str]: + return self.debug_device_str(self) + + def debug_str_short(self) -> str: + maybe_data = getattr(self.node, "data", None) + data_str = "" + if isinstance(maybe_data, torch._inductor.ir.Pointwise): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_size()], shorten=False, multiline=False + ) + elif isinstance(maybe_data, torch._inductor.ir.Reduction): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()], + shorten=False, + multiline=False, + ) + return f"{self}{data_str}" + + def log_details(self) -> None: + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + return + + def update_mutated_names(self, renames: dict[str, str]) -> None: + self.set_read_writes(self.read_writes.rename(renames)) + + def add_fake_dep(self, dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(dep)) + + def has_aliasing_or_mutation(self) -> bool: + return any( + buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs() + ) + + def set_read_writes(self, rw: dependencies.ReadWrites) -> None: + self.read_writes = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] + ) -> None: + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers) + self.last_usage = used_buffers - future_used_buffers + + def mark_run(self) -> None: + for buf in self.outputs: + buf.allocate() + + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet( + dep.name + for dep in 
itertools.chain(self.read_writes.reads, self.read_writes.writes) + ) + + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + used_names: OrderedSet[str] = OrderedSet() + + deps = [ + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ] + while len(deps) > 0: + dep = deps.pop() + used_names.add(dep) + if V.graph.name_to_buffer.get(dep): + deps.extend( + alias + for alias in V.graph.name_to_buffer[ + dep + ].get_inputs_that_alias_output() + if alias not in used_names + ) + return used_names + + def prune_deps(self) -> None: + self.unmet_dependencies = OrderedSet( + dep + for dep in self.unmet_dependencies + if dep.name not in self.scheduler.available_buffer_names + ) + + def prune_weak_deps(self) -> None: + # Prune weak dependencies on operations that have been removed + def should_prune(dep: Dep) -> bool: + if not isinstance(dep, WeakDep): + return False + op_name = self.scheduler.name_to_buf[dep.name].defining_op_name() + return op_name in V.graph.removed_operations + + to_remove = OrderedSet( + dep for dep in self.read_writes.reads if should_prune(dep) + ) + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps( + self, name_to_fused_node: dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + def get_name(self) -> str: + assert self.node is not None + return self.node.get_operation_name() + + def get_first_name(self) -> str: + return self.get_name() + + @cache_on_self + def get_operation_names(self) -> OrderedSet[str]: + return OrderedSet(node.get_name() for node in self.get_nodes()) + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet(out.get_name() for out in self.outputs) + + @cache_on_self + def can_codegen_in_low_precision(self) -> bool: + return all( + isinstance(n, SchedulerNode) + and can_codegen_without_upcasts(n, disallow_fp32_ops=True) + for n in self.get_nodes() + ) + + @cache_on_self + def can_codegen_without_upcasts(self) -> bool: + return all( + isinstance(n, SchedulerNode) and can_codegen_without_upcasts(n) + for n in self.get_nodes() + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return [self] + + def get_outputs(self) -> Sequence[SchedulerBuffer]: + return self.outputs + + def get_output(self, buf_name: str) -> SchedulerBuffer: + return self.outputs_by_name[buf_name] + + def get_device(self) -> Optional[torch.device]: + assert self.node is not None + return self.node.get_device() + + def is_cpu(self) -> bool: + device = self.get_device() + return device is not None and device.type == "cpu" + + def is_gpu(self) -> bool: + device = self.get_device() + return device is not None and is_gpu(device.type) + + def is_reduction(self) -> bool: + return False + + def is_split_scan(self) -> bool: + return False + + def is_template(self) -> bool: + return False + + def is_extern(self) -> bool: + return False + + def is_foreach(self) -> bool: + return False + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + return False + + def has_side_effects(self) -> bool: + return False + + def decide_inplace_update(self) -> None: + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. 
+ """ + from .codegen.wrapper import can_match_buffer_size + + if not ( + isinstance(self, SchedulerNode) + and config.inplace_buffers + and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS) + and ( + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + # hacky check for if V.kernel is a real kernel or NullHandler + and hasattr(V.kernel, "args") + ): + return + + # NOTE remove V.graph.removed_operations once deps issue is fixed + inconsequential_nodes = ( + self.ancestors + | V.graph.removed_operations + | self.scheduler.completed_operations + ) + + def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: + # Inside of NodeUser, we track that the read and write are equivalent + # before deciding if the use can be inplace. + # But if that use is fused into a larger kernel, we need to check equivalence + # of other accesses in fused scheduler node as well. + fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self) + buf_name = buf_to_be_inplaced.get_name() + # Dedup read/writes with equivalent indices + # TODO - would be nice if we could just cache accesses on ReadWrites, + # and enforce variant that this class & members are functional.. + deps: OrderedSet[Dep] = OrderedSet() + for user in buf_to_be_inplaced.users: + user_node = user.node + if not isinstance(user_node, BaseSchedulerNode): + continue + + if ( + user_node.get_first_name() + not in buf_to_be_inplaced.scheduler.name_to_fused_node + or buf_to_be_inplaced.scheduler.get_fused_node(user_node) + is not fused_node + ): + continue + + deps |= ( + o + for o in user_node.read_writes.reads_and_writes() + if o.name == buf_name + ) + if len(deps) > 1: + return False + + return True + + for buf in self.get_outputs(): + buf_node = buf.node + assert buf_node is not None + if ( + not buf_node.should_allocate() + or buf_node.get_inputs_that_alias_output() + or buf_node.get_mutation_names() + or buf.get_name() in V.graph.removed_buffers + ): + continue + + for read in self.read_writes.reads: + input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]] + if read.name in self.scheduler.name_to_donated_buffer: + input_buf = self.scheduler.name_to_donated_buffer[read.name] + else: + input_buf = self.scheduler.name_to_buf.get(read.name) + + if ( + input_buf + and V.graph.wrapper_code.can_reuse(input_buf, self) + and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) + ): + assert input_buf.users is not None + remaining_uses = [ + x + for x in input_buf.users + if x.node.get_name() not in inconsequential_nodes + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and input_buf.node is not None + and not isinstance( + input_buf.node.get_output_spec(), + ( + ir.NoneLayout, + ir.MultiOutputLayout, + ir.MutationLayoutSHOULDREMOVE, + ), + ) + and not ( + input_buf.defining_op + and isinstance( + input_buf.defining_op.node, + (ir.FallbackKernel, ir.MultiOutput), + ) + and len(input_buf.node.get_inputs_that_alias_output()) > 0 + ) + and can_match_buffer_size(input_buf.node, buf.node) + and single_index_in_fused_node(input_buf) + ): + # if there isn't a triton kernel, then we don't need to call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) 
+ V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name()) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.simd.SIMDKernel + ): + V.kernel.mutations.add(input_buf.get_name()) + V.kernel.mutations.add(buf.get_name()) + + V.kernel.inplace_update_buffers[buf.get_name()] = ( + input_buf.get_name() + ) + break + + def codegen_originating_info( + self, buffer: IndentedBuffer, only_once: bool = True + ) -> None: + if not config.comment_origin: + return + + if only_once and self.written: + return + assert self.node is not None + origins = self.node.get_origins() + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? + out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.split("|")[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + buffer.writelines(out_lines) + self.written = True + + @cache_on_self + def get_read_write_buffers_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=True, include_writes=True + ) + + @cache_on_self + def get_read_buffer_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=True, include_writes=False + ) + + @cache_on_self + def get_write_buffer_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=False, include_writes=True + ) + + def get_read_write_buffers_sizes_impl( + self, include_reads: bool, include_writes: bool + ) -> int: + return sum( + self.get_read_write_buffer_accesses( + include_reads=include_reads, include_writes=include_writes + ).values(), + start=0, + ) + + def get_read_write_buffer_accesses( + self, include_reads: bool, include_writes: bool + ) -> dict[str, int]: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. The buffer size + + Returns memory accesses per buffer. 
+ """ + if isinstance(self, NopKernelSchedulerNode): + return {} + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + # todo: Calculate this - it's kinda annoying. + return {} + if ( + isinstance(self, ExternKernelSchedulerNode) + and isinstance(self.node, ir.FallbackKernel) + and self.node.op_overload + is torch._prims.rng_prims.graphsafe_run_with_rng_state + ): + return {} + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + if isinstance(self, SchedulerNode): + node_numel = try_size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]), + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + + if include_reads: + for dep in self.read_writes.reads: + buf_accesses[dep.name].append(dep) + + if include_writes: + for dep in self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = ( + OrderedSet(dep.name for dep in self.read_writes.reads) + if include_reads + else OrderedSet() + ) + writes = ( + OrderedSet(dep.name for dep in self.read_writes.writes) + if include_writes + else OrderedSet() + ) + + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: + users = self.scheduler.name_to_buf[buf].users + buf_uses = OrderedSet(user.node for user in users) + return len(buf_uses - OrderedSet(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = OrderedSet( + dep for dep in writes if not is_materialized(dep, self.snodes) + ) + writes = writes - removed_buffers + reads = reads - removed_buffers + + buf_byte_accesses: dict[str, int] = {} + + for buf_name in reads | writes: + buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) + buf: Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + continue + + def get_buf_bytes( + buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]], + ) -> int: + if not buf: + return 0 + + if isinstance(buf, ir.TorchBindObject): + return buf.get_buf_bytes() + elif isinstance(buf.layout, MultiOutputLayout): + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + users = self.scheduler.name_to_buf[buf.get_name()].users + tot = 0 + for user in users: + assert isinstance(user.node, BaseSchedulerNode) + if isinstance(user.node.node, MultiOutput): + for sched_buf in user.node.get_outputs(): + tot += get_buf_bytes(sched_buf.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... 
+ # TODO: Figure out what's going on + return 0 + return tot + elif isinstance(buf.layout, ir.NoneLayout): + return sum( + get_buf_bytes(V.graph.get_buffer(mut_name)) + for mut_name in buf.get_mutation_names() + ) + else: + buf_elems = try_size_hint(sympy_product(buf.get_size())) + return get_dtype_size(buf.get_dtype()) * min( + buf_accessed_elems, buf_elems + ) + + buf_bytes = get_buf_bytes(buf) + if buf_name not in buf_byte_accesses: + buf_byte_accesses[buf_name] = buf_bytes + else: + buf_byte_accesses[buf_name] += buf_bytes + + return buf_byte_accesses + + @cache_on_self + def estimate_flops(self) -> int | None: + if self.node is None: + return None + fx_node = self.node.get_origin_node() + if fx_node is None: + return None + if not countable_fx(fx_node): + return None + + flops = count_flops_fx(fx_node) + + resolved_flops = V.graph.sizevars.size_hints((flops,), fallback=0)[0] + counters["inductor"]["flop_count"] += resolved_flops + return resolved_flops + + @cache_on_self + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) + """ + buf = self.get_nodes()[0].get_outputs()[0] + layout = buf.node.get_output_spec() + if not is_gpu(get_device_type(layout)): + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + assert isinstance(self.node, ir.IRNode) + try: + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) + return 0 + except TypeError as e: + # this happens when the collective is not of type ir._CollectiveKernel + log.info(e) + return 0 + + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. + return 0 + + dtype = buf.node.maybe_get_dtype() + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + # If cudaGetDeviceProperties returns 0 for gpu_memory_bandwidth or gpu_flops + # there is a chance to continue execution successfully. Otherwise, it would fail with + # ZeroDivisionError below. 
+ if gpu_memory_bandwidth <= 0: + raise AssertionError( + f"gpu_memory_bandwidth cannot be <= 0, but got {gpu_memory_bandwidth}" + ) + if gpu_flops <= 0: + raise AssertionError(f"gpu_flops cannot be <= 0, but got {gpu_flops}") + except Exception: + return 0 + + flops_est = self.estimate_flops() + + if flops_est == 0 or flops_est is None: + # no flops estimate, so fall back to memory estimate + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_bytes = self.get_read_write_buffers_sizes() + counted_bytes = 0 if counted_bytes is None else counted_bytes + compute_time = (factor * flops_est / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return None + + def get_template_node_or_throw(self) -> ir.TemplateBuffer: + template = self.get_template_node() + assert template is not None + return template + + @staticmethod + def get_prologue_template_epilogue( + nodes: list[BaseSchedulerNode], + ) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]: + """ + For the list of nodes, get the prologue, template, and epilogue + """ + template_index = next(i for i, n in enumerate(nodes) if n.is_template()) + + prologue = nodes[:template_index] + template_node = nodes[template_index] + epilogue = nodes[template_index + 1 :] + return prologue, template_node, epilogue + + +class WhyNoFuse: + # TODO when we drop support for Python < 3.10, we can use + # @dataclass(slots=True) instead of manually specifying __slots__. + __slots__ = ["name1", "name2", "reason", "args"] + reason: str + args: tuple[Any, ...] + + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: + self.name1 = node1.get_name() + self.name2 = node2.get_name() + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.name1} with {self.name2}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, (OrderedSet, set)): # noqa: set_linter + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' ' * 4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = OrderedSet([dep]) + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op_name() + name_to_dep_count[name_to_fused_node[op_name].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op_name() + is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0 + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[op_name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = OrderedSet( + dep for dep in node.unmet_dependencies if should_prune(dep) + ) + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self) -> bool: + return True + + def has_side_effects(self) -> bool: + assert self.node is not None + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + +class NopKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + +class SchedulerNode(BaseSchedulerNode): + _sizes: tuple[Sequence[sympy.Expr], ...] + _body: LoopBody + + def __init__( + self, + scheduler: Scheduler, + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + device = self.node.get_device_or_error() + group_fn = self.scheduler.get_backend(device).group_fn + self.group = (device, group_fn(self._sizes)) + + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. 
+ should_normalize = not config.loop_ordering_after_fusion or not is_gpu( + device.type + ) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=should_normalize + ) + ) + + def recompute_size_and_body( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + def refresh_dependencies( + self, normalize: bool, need_clear_tiling_cache: bool + ) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. + fake_deps: OrderedSet[Dep] = OrderedSet( + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + ) + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ).with_read(fake_deps) + ) + + self.pointwise_read_writes.clear_cache(self) + + if need_clear_tiling_cache: + from .codegen.simd import SIMDScheduling + + # TODO(shunting) if this cause compilation time increase when + # enabling LOAF by default, try just clearing the specific cache + # entry by using a customized cache implementation rather than + # lru_cache. + SIMDScheduling.candidate_tilings.cache_clear() + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) + + def merge_loops(self) -> None: + self._body = self._body.merge_loops() + self._sizes = self._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + # + # Merge loops does not affect the tiling decision. So we + # don't need clear the tiling cache. 
+ self.refresh_dependencies(normalize=True, need_clear_tiling_cache=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + for dep in self.read_writes.reads_and_writes(): + if not isinstance(dep, WeakDep): + buf_name = dep.name + buf = V.graph.get_buffer(buf_name) + if not isinstance(buf, ir.TorchBindObject): + lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") + if isinstance(self._body, LoopBody): + lines.append(f"class {name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + + assert self.node is not None + lines.extend(self._debug_str_for_device()) + + return "\n".join(lines) + + def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: + return self._sizes + + def is_reduction(self) -> bool: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) + return bool(self.node.get_reduction_type()) + + def is_split_scan(self) -> bool: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self) -> bool: + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return self.node if isinstance(self.node, ir.TemplateBuffer) else None + + def run(self, *index_vars: Sequence[sympy.Expr]) -> None: + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def ranges_from_index_vars( + self, index_vars: Sequence[Sequence[sympy.Expr]] + ) -> dict[sympy.Expr, sympy.Expr]: + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with ( + V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)), + V.kernel.set_current_node(self), + ): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + def pointwise_or_reduction_read_writes( + self, pointwise: bool = True + ) -> dependencies.ReadWrites: + """ + Get the memory dependencies in either the pointwise or the reduction axes. + """ + keep_sizes, ignore_sizes = self._sizes if pointwise else reversed(self._sizes) + return dependencies.extract_read_writes( + self._body, keep_sizes, hidden_args=[[sympy.S.Zero] * len(ignore_sizes)] + ) + + @cache_on_self + def pointwise_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the non-reduction axes. 
+ """ + return self.pointwise_or_reduction_read_writes(pointwise=True) + + @cache_on_self + def reduction_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the reduction axes. + """ + return self.pointwise_or_reduction_read_writes(pointwise=False) + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + if self.is_template(): + return False + if any(out.get_aliases() for out in self.get_outputs()): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> OrderedSet[str]: + buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() + if isinstance(self._body, LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + +def refresh_group_node_dependencies( + group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], +) -> None: + snodes = group_snode.snodes + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + +def init_group_node( + group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], + scheduler: Scheduler, + snodes: list[BaseSchedulerNode], +) -> None: + assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) + group_snode.snodes = snodes + group_snode.scheduler = scheduler + group_snode.node = None + group_snode.ancestors = OrderedSet.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + refresh_group_node_dependencies(group_snode) + + group_snode.min_order = min(x.min_order for x in group_snode.snodes) + group_snode.max_order = max(x.max_order for x in group_snode.snodes) + group_snode.outputs_by_name = { + buf.get_name(): buf for buf in group_snode.get_outputs() + } + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. 
+ """ + + snodes: list[BaseSchedulerNode] + + @classmethod + def fuse( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + if node1.is_template() and isinstance(node2, ExternKernelSchedulerNode): + # Fuse multi outputs template and its outputs + # * Node1 has memorydep of MultiOutput in reads + # * Node2 has StarDep of MultiOutput in writes + # Rewrite the Node2' StarDep to MemoryDep, because calculate score_fusion_memory + # of the template node and its epilogue requires the same type of dependencies + assert isinstance(node2.node, MultiOutput) + assert len(node2.read_writes.writes) == 1 + assert isinstance(next(iter(node2.read_writes.writes)), StarDep) + name = next(iter(node2.read_writes.writes)).name + template_nodes = [node for node in node1.get_nodes() if node.is_template()] + assert len(template_nodes) == 1 + template_node = template_nodes[0] + assert len(template_node.read_writes.writes) == 1 + write = next(iter(template_node.read_writes.writes)) + assert isinstance(write, MemoryDep) + node2.read_writes.writes = OrderedSet( + [ + MemoryDep( + name, write.index, write.var_names, write.size, write.mode + ), + ] + ) + else: + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) + return cls(node1.scheduler, nodes) + + @cache_on_self + def estimate_flops(self) -> int | None: + # don't increment counters in fused methods so we don't double count + fps = list( + filter( + None, + ( + node.estimate_flops() + for node in self.get_nodes() + if node.is_template() or node.is_extern() + ), + ) + ) + if len(fps) == 0: + return None + ret = sum(fps) + return ret + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and tuple(self_sizes) != tuple(snode._sizes[0]): + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) + + refresh_group_node_dependencies(self) + + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + self.users: list[NodeUser] = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> 
list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + node = self.snodes[0].node + if node is not None: + lines.extend(self._debug_str_for_device()) + + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def debug_str_short(self) -> str: + snodes_str = [node.debug_str_short() for node in self.snodes] + return f"{self}, snodes: {snodes_str}" + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] + ) -> None: + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: OrderedSet[str] = OrderedSet() + for node in reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) + + @cache_on_self + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes]) + + @cache_on_self + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union( + *[x.used_or_aliased_buffer_names() for x in self.snodes] + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + def __repr__(self) -> str: + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self) -> bool: + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_split_scan(self) -> bool: + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self) -> bool: + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + for node in self.snodes: + if node.is_template(): + return node.get_template_node() + return None + + def get_device(self) -> torch.device: + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self) -> bool: + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: dict[str, str]) -> None: + raise NotImplementedError + + def add_fake_dep(self, name: Dep) -> None: + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({node_typestr}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """ + 
This is a schedular node that consists of a set of scheduler nodes that + has no data dependencies among them and can be executed in parallel. + """ + + def get_consumer_subnode_for( + self, producer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + for buf in producer.get_outputs(): + if buf.get_name() in self.read_to_node: + return self.read_to_node[buf.get_name()] + + return None + + def get_producer_subnode_for( + self, consumer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + producers = OrderedSet[BaseSchedulerNode]() + for rd in consumer.read_writes.reads: + if rd.name not in self.scheduler.name_to_buf: + continue + + node_name = self.scheduler.name_to_buf[rd.name].defining_op_name() + if node_name in self.name_to_node: + producers.add(self.name_to_node[node_name]) + + # Don't permit fusion if there are multiple subnodes + # that this consumer reads from + if len(producers) == 1: + return next(iter(producers)) + else: + return None + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + if producer.is_reduction(): + why( + "candidate producer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + if consumer.is_reduction(): + why( + "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse( + cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode + ) -> ForeachKernelSchedulerNode: + assert producer.is_foreach() or consumer.is_foreach() + if producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + use_custom_partition_algo = producer.use_custom_partition_algo + enable_autotune = producer.enable_autotune + else: + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + use_custom_partition_algo = consumer.use_custom_partition_algo + enable_autotune = consumer.enable_autotune + prev_node_1 = None + prev_node_2 = None + fused_nodes: list[BaseSchedulerNode] + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + fused_nodes = [ + FusedSchedulerNode.fuse(l, r) + for l, r in 
zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + else: + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node" + ) + + return cls( + producer.scheduler, + fused_nodes, + use_custom_partition_algo=use_custom_partition_algo, + prev_node_1=prev_node_1, + prev_node_2=prev_node_2, + enable_autotune=enable_autotune, + ) + + def __init__( + self, + scheduler: Scheduler, + snodes: list[BaseSchedulerNode], + use_custom_partition_algo: bool, + prev_node_1: Optional[BaseSchedulerNode] = None, + prev_node_2: Optional[BaseSchedulerNode] = None, + enable_autotune: bool = False, + ) -> None: + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, snodes) + + for node in snodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_operation_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = snodes + self.node = None + self.users: list[NodeUser] = [] + + self.set_read_writes( + dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_buffer_names() + ) + - self.read_writes.writes + ) + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + if prev_node_1.is_foreach(): + assert isinstance(prev_node_1, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_1, prev_node_2 + else: + assert isinstance(prev_node_2, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_2, prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_operation_names(): + self.name_to_node[name] = other_node + + self.outputs_by_name: dict[str, SchedulerBuffer] = { + k: v for snode in self.snodes for k, v in snode.outputs_by_name.items() + } + + self.use_custom_partition_algo = use_custom_partition_algo + device = snodes[0].get_device() + assert device + self.group = (device, ((sympy.Expr("combo_kernel"),),)) + self.origins = OrderedSet[torch.fx.Node]() + self.enable_autotune = enable_autotune + + @classmethod + def combinable_nodes( + cls, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] + if extern: + log.debug( + "ComboKernels: %d 
external nodes are filtered %s", + len(extern), + [node.node.get_origins() for node in extern if node.node is not None], + ) + filtered_nodes = [ + x + for x in nodes + if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode)) + ] + foreach_nodes = [ + x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode) + ] + if foreach_nodes: + log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes)) + filtered_nodes = [ + x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode) + ] + template_nodes = [x for x in filtered_nodes if x.is_template()] + if template_nodes: + log.debug( + "ComboKernels: %d template nodes are filtered: %s", + len(template_nodes), + template_nodes, + ) + filtered_nodes = [x for x in filtered_nodes if x not in template_nodes] + return filtered_nodes + + @staticmethod + def _default_group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> list[list[BaseSchedulerNode]]: + """ + Returns a list of lists of nodes that are to be grouped together. + """ + sorted_nodes = scheduler._topological_sort_nodes() + grouped_nodes = [] + max_num_nodes = 8 + for nodes in sorted_nodes: + grouped_nodes.extend( + [ + nodes[i : i + max_num_nodes] + for i in range(0, len(nodes), max_num_nodes) + ] + ) + + return grouped_nodes + + group_algorithm_for_combo_kernels: Callable[ + [Scheduler], list[list[BaseSchedulerNode]] + ] = _default_group_nodes_for_combo_kernels + + @staticmethod + def set_group_algorithm_for_combo_kernels( + custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]], + ) -> None: + ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( + custom_group_algorithm + ) + + @staticmethod + def group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> list[list[BaseSchedulerNode]]: + return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) + + def mark_run(self) -> None: + raise NotImplementedError + + def codegen(self) -> None: + raise NotImplementedError + + def is_foreach(self) -> bool: + return True + + def get_subkernel_nodes(self) -> list[BaseSchedulerNode]: + """Returns a list of nodes which comprise the combo kernel. + These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + """Returns all nodes contained in this kernel, unpacking fused nodes + into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self) -> str: + return self.snodes[0].get_first_name() + + def prune_redundant_deps( + self, name_to_fused_node: dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +class GroupedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be *grouped* together (it does not allow another node to be scheduled + in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes). + The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes. + Fusion will still happen among the nodes within each GroupedSchedulerNode. + At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. 
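+ Unlike a FusedSchedulerNode, the grouping is temporary: unpack() first runs
+ fusion among the member nodes and then dissolves the group back into regular
+ (possibly fused) nodes (see Scheduler.process_grouped_nodes).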
+ """ + + snodes: list[BaseSchedulerNode] + + @classmethod + def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode: + scheduler = snodes[0].scheduler + assert all(node.scheduler is scheduler for node in snodes) + grouped_snode = cls(scheduler, snodes) + for snode in snodes: + scheduler.name_to_fused_node[snode.get_name()] = grouped_snode + scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode + return grouped_snode + + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + + def unpack(self) -> list[BaseSchedulerNode]: + """ + Do fusion among nodes within this GroupedSchedulerNode, + and then unpack this GroupedSchedulerNode into regular nodes. + """ + for snode in self.snodes: + self.scheduler.name_to_fused_node[snode.get_name()] = snode + del self.scheduler.name_to_fused_node[self.get_name()] + return self.scheduler.fuse_nodes(self.snodes) + + def add_fake_dep(self, fake_dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(fake_dep)) + self.unmet_dependencies.add(fake_dep) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + @cache_on_self + def estimate_flops(self) -> int | None: + # don't increment counters in fused methods so we don't double count + fps = list( + filter( + None, + ( + node.estimate_flops() + for node in self.get_nodes() + if node.is_template() or node.is_extern() + ), + ) + ) + if len(fps) == 0: + return None + ret = sum(fps) + return ret + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + # GroupedSchedulerNode cannot be fused with another node + return False + + +def pick_loop_order( + stride_lengths: list[list[int]], + sizes: Sequence[sympy.Expr], + priority_idx: tuple[int, ...] = (), +) -> list[int]: + """ + A heuristic to decide loop iteration orders. This has not been well + tuned and may be something we should autotune. 
+ """ + + @functools.cmp_to_key + def index_cmp(a: int, b: int) -> int: + if sizes[a] == 1 or sizes[b] == 1: + # 1-sizes don't matter, just move them to the end + return cmp(sizes[a] == 1, sizes[b] == 1) + + # Take abs, otherwise flipped dimensions are treated as smaller + # strides than contiguous dims + stride_len_a = [abs(sl[a]) for sl in stride_lengths] + stride_len_b = [abs(sl[b]) for sl in stride_lengths] + + # equivalent to + # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() + a_first = sum( + sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + b_first = sum( + sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + if a_first > b_first: + return -1 + if b_first > a_first: + return 1 + + # otherwise contiguous + return cmp(b, a) + + order = list(reversed(range(len(stride_lengths[0])))) + if len(priority_idx) > 0: + # if we have priority node, only use that node's order + stride_lengths = [stride_lengths[pi] for pi in priority_idx] + if config.pick_loop_orders: + order.sort(key=index_cmp) + return order + + +@dataclasses.dataclass +class NodeUser: + node: Union[BaseSchedulerNode, OutputNode] + can_inplace: bool = False + + # A weak user must be scheduled after a given node, but doesn't actually + # use the result + is_weak: bool = False + + def __hash__(self) -> int: + return hash((self.node.get_name(), self.can_inplace, self.is_weak)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, NodeUser) + and self.get_name() == other.get_name() + and self.can_inplace == other.can_inplace + and self.is_weak == other.is_weak + ) + + def get_name(self) -> str: + return self.node.get_name() + + def merge(self, other: NodeUser) -> NodeUser: + assert self.node is other.node + return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +class Scheduler: + """ + A Scheduler is a graph of BaseSchedulerNodes. It is responsible for + optimizations such as fusion, reorder, and graph partition. 
+ """ + + __dep_size_hint_cache: dict[Dep, int] + + def __init__(self, nodes: list[ir.Operation]) -> None: + with dynamo_timed("Scheduler.__init__"): + self._init(nodes) + + def _init(self, nodes: list[ir.Operation]) -> None: + super().__init__() + self.__dep_size_hint_cache = {} + V.graph.scheduler = self + self.backends: dict[torch.device, BaseScheduling] = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + self._graph_partition_counter = itertools.count() + + self.completed_operations: OrderedSet[str] = OrderedSet() + self.available_buffer_names = OrderedSet( + [ + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + *V.graph.torchbind_constants.keys(), + ] + ) + + self.nodes = [self.create_scheduler_node(n) for n in nodes] + self.update_zero_dim_cpu_tensor() + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = ( + self.get_donated_buffers() + ) + self.name_to_node: dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_buf: dict[str, SchedulerBuffer] = { + buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() + } + self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy() + + # mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name: dict[str, str] = {} + + # We handle mutation by renaming modified versions of the same + # buffer in the dependency graph to prevent cycles. + # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames: dict[str, str] = {} + + # Must run first to correctly set dependencies, before all other passes that rely on + # reading from .read_writes.reads or .unmet_dependencies + self.nodes = comms.decide_global_ordering_of_comms( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + ) + + self.compute_dependencies() + self.nodes = self.topological_sort_schedule(self.nodes) + self.dead_node_elimination() + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.compute_ancestors() + + metrics.ir_nodes_pre_fusion += len(self.nodes) + from torch._inductor.debug import log_ir_post_fusion, log_ir_pre_fusion + + log_ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.create_foreach_nodes() + self.nodes = self.topological_sort_schedule(self.nodes) + self.logged_slow_fusion = OrderedSet[tuple[str, str]]() + if config._pre_fusion_custom_pass is not None: + self.nodes = config._pre_fusion_custom_pass(self.nodes) + self.nodes = self.fuse_nodes(self.nodes) + if config._post_fusion_custom_pass is not None: + self.nodes = config._post_fusion_custom_pass(self.nodes) + self.merge_loops() + self.finalize_multi_template_buffers() + if config.combo_kernels: + self.create_combo_kernel_nodes(num_ck_nodes=None) + + # Peak memory pass and overlap pass must run last, otherwise + # other reordering passes could undo their effects. 
+ if config.reorder_for_peak_memory: + from .memory import reorder_for_peak_memory + + self.nodes = reorder_for_peak_memory( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + OrderedSet(V.graph.graph_inputs.keys()), + OrderedSet(V.graph.get_output_names()), + ) + if config.reorder_for_compute_comm_overlap: + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + self.process_grouped_nodes() + + if torch._inductor.config.graph_partition: + self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) + self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) + + self.compute_last_usage() + log_ir_post_fusion(self.nodes) + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.buffer_names_to_free: OrderedSet[str] = OrderedSet() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index: dict[torch.fx.Node, int] = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + ) + + def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]: + name_to_donated_buf = {} + for name in V.graph.graph_inputs_original: + if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer): + name_to_donated_buf[name] = SchedulerDonatedBuffer( + self, + V.graph.graph_inputs_original[name], + defining_op=None, + ) + return name_to_donated_buf + + @property + def current_device(self) -> Optional[torch.device]: + return V.graph.current_device + + @current_device.setter + def current_device(self, device: Optional[torch.device]) -> None: + V.graph.current_device = device + + def debug_draw_graph(self) -> None: + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label: str) -> None: + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: + assert node.get_origins() is not None, ( + "All nodes passed to scheduling must have an origin" + ) + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise NotImplementedError(node) + + def create_foreach_nodes(self) -> None: + removed_node_names: OrderedSet[str] = OrderedSet() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) + snodes = [self.name_to_node[name] for name in names] + + enable_autotune = config.combo_kernels_autotune > 1 + fe_node = ForeachKernelSchedulerNode( + self, + snodes, + use_custom_partition_algo=False, + enable_autotune=enable_autotune, + ) + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + list(fe_nodes) + + def 
compute_dependencies(self) -> None: + """ + Create dependency edges between nodes, handling aliasing and + mutation properly. + """ + + T = TypeVar("T") + + class DedupList(Generic[T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. + Normally one could use a OrderedSet/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. + """ + + def __init__( + self, + items: Optional[list[T]] = None, + membership: Optional[OrderedSet[T]] = None, + ) -> None: + self.items = items or [] + self.membership = membership or OrderedSet() + + def append(self, node_user: T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: DedupList[T]) -> DedupList[T]: + new_membership = OrderedSet.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node in self.nodes: + for buf1 in node.get_outputs(): + buf1_name = buf1.get_name() + for buf2_name in buf1.get_aliases(): + if buf1_name in name_to_users and buf2_name in name_to_users: + # merge the two + list1 = name_to_users[buf1_name] + list2 = name_to_users[buf2_name] + combined = list1 + list2 + for key in name_to_users.keys(): + if ( + name_to_users[key] is list1 + or name_to_users[key] is list2 + ): + name_to_users[key] = combined + elif buf1_name in name_to_users: + name_to_users[buf2_name] = name_to_users[buf1_name] + else: + name_to_users[buf1_name] = name_to_users[buf2_name] + + def rename(n: str) -> str: + if n in self.mutation_renames: + return rename(self.mutation_renames[n]) + return n + + def add_user( + used_by_name: str, + user_node: Union[BaseSchedulerNode, OutputNode], + can_inplace: bool = False, + is_weak: bool = False, + ) -> None: + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {} + + # NB: None means that the dependency is on an input. Don't actually + # generate a dependency because if we do, Inductor will start trying + # to free the unbacked int but that's pointless + for name, val in V.graph.graph_inputs.items(): + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + elif isinstance(val, ir.TensorBox): + # We also need to add symbols from input size as well because + # AOTI doesn't lift the unbacked symints to inputs + sym_size = [s for s in val.get_size() if isinstance(s, sympy.Expr)] + for s in sym_size: + for fs in s.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + assert node.node is not None + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. 
There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. + if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node.get_name() + + unbacked_symbol_uses = sorted( + node.node.get_free_symbol_uses(unbacked_only=True), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node}" + ) + if (r := unbacked_symbol_to_origin_node[s]) is not None: + for buf in self.name_to_node[r].get_outputs(): + node.add_fake_dep(StarDep(buf.get_name())) + + if ( + len(node.read_writes.writes) == 1 + and (dep := next(iter(node.read_writes.writes))) + and isinstance(dep, MemoryDep) + ): + node_mode = dep.mode + else: + node_mode = None + + # Handle output mutations + for buf in node.get_outputs(): + # a node will mutate either 0 or 1 buffers + assert len(buf.get_mutations()) <= 1 + for alt_name in buf.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_fake_dep(StarDep(alt_name, mode=node_mode)) + for user in name_to_users[alt_name].items: + if user.get_name() == node.get_name(): + continue + + assert isinstance(user.node, BaseSchedulerNode) + for other_name in user.node.get_buffer_names(): + # this node must run after all prior readers + other_name = rename(other_name) + node.add_fake_dep( + WeakDep(other_name, mutating_buf=buf.get_name()) + ) + add_user(other_name, node, is_weak=True) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + if not isinstance(read, WeakDep): + add_user(read.name, node, node.can_inplace(read)) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for buf in node.get_outputs(): + for alt_name in buf.get_mutations(): + self.mutation_renames[rename(alt_name)] = buf.get_name() + self.mutation_renames[alt_name] = buf.get_name() + self.mutation_real_name[buf.get_name()] = ( + self.mutation_real_name.get(alt_name, alt_name) + ) + + # make sure outputs aren't dead-code-eliminated + for buf_name in V.graph.get_output_names(): + log.debug("scheduling output %s", buf_name) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure unbacked symints aren't dead-code-eliminated + for out in V.graph.graph_outputs: + for s in out.get_free_symbol_uses(unbacked_only=True): + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + ) + if r := unbacked_symbol_to_origin_node[s]: + for buf_name in self.name_to_node[r].get_buffer_names(): + log.debug( + "scheduling output %s for unbacked symint %s", buf_name, s + ) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + elif name in V.graph.constants: + # In AOTI, module parameters and buffers are not lifted as graph inputs + add_user(name, OutputNode(StarDep(name))) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users information onto the nodes + for node in self.nodes: + for buf in node.get_outputs(): + 
buf.set_users(name_to_users[buf.get_name()].items) + + for name in self.name_to_donated_buffer: + self.name_to_donated_buffer[name].set_users(name_to_users[name].items) + + def dead_node_elimination(self) -> None: + """ + Remove any nodes without users + """ + # self.nodes is in topological order, so by iterating in reverse order + # we have visited (and potentially removed) all users before visiting a + # given node. + updated_nodes = [] + for node in reversed(self.nodes): + + def can_eliminate_user(user: NodeUser) -> bool: + return user.is_weak or user.get_name() in V.graph.removed_operations + + active_buffers = False + for buf in node.get_outputs(): + can_eliminate = all(can_eliminate_user(u) for u in buf.users) + if can_eliminate: + log.debug("removed dead buffer: %s", buf.get_name()) + V.graph.removed_buffers.add(buf.get_name()) + else: + active_buffers = True + + can_eliminate = not node.has_side_effects() and not active_buffers + + if not can_eliminate: + updated_nodes.append(node) + else: + # dead code + log.debug("removed dead operation: %s", node.get_name()) + V.graph.removed_operations.add(node.get_name()) + for read in node.read_writes.reads: + if read.name in self.name_to_buf: + users = self.name_to_buf[read.name].users + self.name_to_buf[read.name].users = [ + u for u in users if u.node.get_name() != node.get_name() + ] + self.nodes = list(reversed(updated_nodes)) + + # Prune any WeakDeps no longer needed + for node in self.nodes: + node.prune_weak_deps() + + def topological_sort_schedule( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + """ + Ensure nodes is in topologically sorted order + """ + seen = OrderedSet[BaseSchedulerNode]() + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] + + def visit(n: BaseSchedulerNode) -> None: + if n not in seen: + seen.add(n) + for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): + # We only care about doing toposort within `nodes` + if dep.name not in name_to_node: + continue + visit(name_to_node[dep.name]) + result.append(n) + + for node in nodes: + for name in node.get_buffer_names(): + name_to_node[name] = node + for node in nodes: + visit(node) + return result + + def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]: + unmet_deps: OrderedSet[str] = OrderedSet() + if isinstance( + snode, + ( + SchedulerNode, + ExternKernelSchedulerNode, + NopKernelSchedulerNode, + FusedSchedulerNode, + ), + ): + for dep in snode.unmet_dependencies: + unmet_deps.add(dep.name) + else: + raise RuntimeError( + f"get_unmet_dep_nodes is not implemented for {type(snode)}." + ) + unmet_dep_ops = (self.name_to_buf[dep].defining_op_name() for dep in unmet_deps) + return list(OrderedSet(self.name_to_fused_node[n] for n in unmet_dep_ops)) + + def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]: + """ + Sort nodes by their topological order, return a list of node lists. + """ + order = [] + nodes = dict.fromkeys(self.nodes, 0) + children: dict[Any, Any] = {} + for node in self.nodes: + deps = self._get_unmet_dep_nodes(node) + nodes[node] = len(deps) + for dep in deps: + c = children.get(dep, []) + c.append(node) + children[dep] = c + + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + while zero_deg_nodes: + order.append(zero_deg_nodes) + for n in zero_deg_nodes: + for user in children.get(n, []): + nodes[user] -= 1 + nodes.pop(n) + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + assert not nodes, "Topological sort failed!" 
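+ # Each entry of `order` is one dependency "level": nodes within a level
+ # have no unmet dependencies on one another, which is what makes them
+ # candidates for being grouped into a single combo kernel.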
+ return order
+
+ def compute_ancestors(self) -> None:
+ """
+ Populate each node.ancestors
+ """
+ # note self.nodes is topologically sorted
+ name_to_ancestors: dict[str, OrderedSet[str]] = {}
+ for node in self.nodes:
+ ancestors: OrderedSet[str] = OrderedSet()
+ for dep in node.unmet_dependencies:
+ dep_node_name = self.name_to_buf[dep.name].defining_op_name()
+ ancestors.add(dep_node_name)
+ ancestors |= name_to_ancestors[dep_node_name]
+ name_to_ancestors[node.get_name()] = ancestors
+ node.ancestors = ancestors
+
+ for order, node in enumerate(self.nodes):
+ node.min_order = order
+ node.max_order = order
+
+ def merge_loops(self) -> None:
+ for node in self.nodes:
+ if not config.loop_ordering_after_fusion:
+ continue
+
+ # Even for CPU, if we are using the halide backend, we still need
+ # the merge-loops step below
+ if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or (
+ not node.is_gpu() and config.cpu_backend != "halide"
+ ):
+ continue
+ for snode in node.get_nodes():
+ # merge loops for the scheduler node
+ if not isinstance(snode, SchedulerNode) or snode.is_template():
+ continue
+
+ snode.merge_loops()
+
+ # Note that for the CPU backend, merging loops will change
+ # snode.group. It's fine for the Triton backend.
+ # But if we simply update snode.group like this:
+ # group_fn = self.get_backend(snode.node.get_device()).group_fn
+ # snode.group = (snode.node.get_device(), group_fn(snode._sizes))
+ # there is still an issue because different snodes in a
+ # FusedSchedulerNode can end up with different merged loops.
+ # Skip the CPU backend for now.
+
+ def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
+ """
+ Combine eligible nodes into FusedSchedulerNodes.
+ """
+ with dynamo_timed(
+ "Scheduler.fused_nodes", log_pt2_compile_event=True, log_waitcounter=True
+ ):
+ for i in range(10):
+ old_len = len(nodes)
+ fusion_log.debug(
+ "===== attempting fusion (%d/10): %d nodes =====",
+ i + 1,
+ old_len,
+ )
+ nodes = self.fuse_nodes_once(nodes)
+ new_len = len(nodes)
+ fusion_log.debug(
+ "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
+ i + 1,
+ old_len,
+ new_len,
+ )
+ if new_len == old_len or new_len == 1:
+ fusion_log.debug(
+ "===== fusion complete (%d iterations) =====", i + 1
+ )
+ break
+ return nodes
+
+ def process_grouped_nodes(self) -> None:
+ """
+ Unpack GroupedSchedulerNode into regular nodes.
+ """
+ new_nodes: list[BaseSchedulerNode] = []
+ for node in self.nodes:
+ new_nodes.extend(
+ node.unpack() if isinstance(node, GroupedSchedulerNode) else [node]
+ )
+ self.nodes = new_nodes
+
+ def benchmark_fused_nodes(
+ self, nodes: Sequence[BaseSchedulerNode]
+ ) -> tuple[float, str]:
+ """
+ Benchmark the fused list of nodes and return the execution time
+ in milliseconds on randomly generated inputs.
+ """
+ assert len(nodes) > 0
+ device = nodes[0].get_device()
+ self.current_device = device
+ backend = self.get_backend(device)
+ with dynamo_timed(
+ "benchmark_fused_nodes",
+ log_pt2_compile_event=True,
+ dynamo_compile_column_us="compile_time_autotune_time_us",
+ ):
+ return backend.benchmark_fused_nodes(nodes)
+
+ def generate_kernel_code_from_nodes(
+ self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool
+ ) -> str:
+ """
+ Generate the kernel source code for the given list of nodes;
+ when benchmark_kernel is True the code is emitted in a benchmarkable form.
+ """ + assert len(nodes) > 0 + device = nodes[0].get_device() + self.current_device = device + backend = self.get_backend(device) + with dynamo_timed("benchmark_fused_nodes"): + return backend.generate_kernel_code_from_nodes(nodes, benchmark_kernel) + + def benchmark_codegened_module( + self, module: ModuleType, device: torch.device + ) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + self.current_device = device + backend = self.get_backend(device) + with dynamo_timed("benchmark_fused_nodes"): + return backend.benchmark_codegened_module(module) + + def finalize_multi_template_buffers(self) -> None: + """ + Finalize a backing choice for MultiTemplateBuffers which did not already have a + choice finalized through fusion. In the case of an extern choice, this will result + in replacing the SchedulerNode. + + If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice + will force completion of compilation and benchmarking. + """ + + def replace_operation_buffer( + orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer + ) -> None: + replaced_buf_name = new_node.get_name() + orig_buf_name = orig_node.get_name() + assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str) + + replaced_op_name = new_node.get_operation_name() + orig_op_name = orig_node.get_operation_name() + assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str) + + del V.graph.name_to_buffer[replaced_buf_name] + new_node.name = orig_buf_name + + del V.graph.name_to_op[replaced_op_name] + new_node.operation_name = orig_op_name + + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node + V.graph.name_to_buffer[orig_buf_name] = new_node + + orig = V.graph.operations.index(orig_node) + V.graph.operations.remove(new_node) + V.graph.operations[orig] = new_node + V.graph.name_to_op[orig_op_name] = new_node + + for i, node in enumerate(self.nodes): + if isinstance(node, SchedulerNode) and isinstance( + node.node, ir.MultiTemplateBuffer + ): + multi_node = node.node + if not config.test_configs.force_extern_kernel_in_multi_template: + min_node_unfused, _ = multi_node.get_min_choice() + else: + min_node_unfused = next( + ( + timing + for timing in multi_node.choice_timings + if isinstance( + timing, + torch._inductor.select_algorithm.ExternKernelCaller, + ) + ), + ) + + if isinstance( + min_node_unfused, + torch._inductor.ir.TritonTemplateCallerBase, + ): + node.node.finalize_as_triton_caller(min_node_unfused) + continue + + out_tensorbox = min_node_unfused.output_node() + out_storage = out_tensorbox.data + assert isinstance(out_storage, ir.StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, ir.OperationBuffer) + + out_buffer.layout = multi_node.layout + replace_operation_buffer(multi_node, out_buffer) + new_scheduler_node = self.create_scheduler_node(out_buffer) + + self.nodes[i] = new_scheduler_node + self.name_to_node[node.get_name()] = new_scheduler_node + self.name_to_fused_node[node.get_name()] = new_scheduler_node + + # We need to reflect the mutation renames that were recorded in the original node + mutation_renames = {} + for dep in itertools.chain( + node.read_writes.reads, node.unmet_dependencies + ): + if real_name := self.mutation_real_name.get(dep.name, None): + mutation_renames[real_name] = dep.name + + def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]: + return 
OrderedSet(dep.rename(mutation_renames) for dep in deps)
+
+ new_scheduler_node.unmet_dependencies = rename_deps(
+ new_scheduler_node.unmet_dependencies
+ )
+ new_scheduler_node.read_writes.reads = rename_deps(
+ new_scheduler_node.read_writes.reads
+ )
+
+ for new_out, old_out in zip(
+ new_scheduler_node.get_outputs(), node.get_outputs()
+ ):
+ self.name_to_buf[old_out.get_name()] = new_out
+ new_out.users = old_out.users
+
+ new_scheduler_node.min_order = node.min_order
+ new_scheduler_node.max_order = node.max_order
+ new_scheduler_node.last_usage = node.last_usage
+
+ def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
+ return any(
+ hasattr(n.node, "data")
+ and n.node is not None
+ and hasattr(n.node.data, "scatter_mode")
+ and n.node.data.scatter_mode == "atomic_add"
+ for n in node_list
+ )
+
+ def speedup_by_fusion(
+ self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
+ ) -> Union[bool, Callable[[], bool]]:
+ """
+ If config.benchmark_fusion is False, always return True.
+ Otherwise, return True (possibly via a deferred callable) only if the
+ fusion brings a speedup.
+ """
+
+ is_multi_template = any(
+ n.is_template()
+ and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
+ for n in (node1, node2)
+ )
+ if not config.benchmark_fusion and not is_multi_template:
+ return True
+
+ if (
+ node1.is_template()
+ and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
+ or node1.is_foreach()
+ or node2.is_foreach()
+ ):
+ # TODO support benchmarking epilogue fusion
+ return True
+
+ node_list_1 = node1.get_nodes()
+ device = node_list_1[0].get_device()
+ assert device
+
+ # don't support benchmark fusion for CPU right now.
+ if device.type == "cpu":
+ return True
+
+ node_list_2 = node2.get_nodes()
+ node_list_fused = list(itertools.chain(node_list_1, node_list_2))
+
+ # We cannot accurately benchmark kernels that use atomic_add
+ # due to how we generate random integer inputs, so skip
+ # benchmarking them and just allow the fusion.
+ if self._any_atomic_add(node_list_fused): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + device = node_list_fused[0].get_device() + assert device is not None + + def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_buffer_names(), + node2.get_buffer_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_buffer_names(), + node2.get_buffer_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + async_compile = torch._inductor.async_compile.AsyncCompile() + + def compile_kernel( + nodes: Sequence[BaseSchedulerNode], + ) -> tuple[Optional[LambdaFuture], ModuleType]: + src_code = self.generate_kernel_code_from_nodes( + nodes, benchmark_kernel=True + ) + mod = PyCodeCache.load(src_code) + if not async_compile.use_process_pool(): + fut = None + else: + fut = async_compile.triton(kernel_name="triton_", source_code=src_code) + assert isinstance(fut, LambdaFuture) + + return (fut, mod) + + if is_multi_template and any( + n.get_template_node() is not None for n in (node1, node2) + ): + epilogue_fusion = node1.get_template_node() is not None + multi_node = ( + node1.get_template_node() + if epilogue_fusion + else node2.get_template_node() + ) + assert isinstance(multi_node, ir.MultiTemplateBuffer) + + # Eagerly compile and benchmark non-template nodes + choice_timings = multi_node.choice_timings + _, ms1 = multi_node.get_min_choice() + ms2, path2 = ( + self.benchmark_fused_nodes(node_list_2) + if epilogue_fusion + else self.benchmark_fused_nodes(node_list_1) + ) + + # Start compiling choices in parallel + future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] + triton_choices = 0 + for choice, unfused_time in sorted( + choice_timings.items(), key=operator.itemgetter(1) + ): + if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): + continue + + # For prologue fusion we check if the underlying template of the choice + # supports all allowed prologue inputs. If not, we skip this choice in + # the fusion benchmark. + # TODO: Remove this check after all Triton templates support prologue fusion. + # Currently, persistent+TMA Triton template does not due to the TMA-based loads. 
+ if ( + not epilogue_fusion + and hasattr(choice, "allowed_prologue_inps") + and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps + ): + continue + + if unfused_time >= ms1 + ms2: + break + + triton_choices += 1 + if triton_choices > config.max_epilogue_benchmarked_choices: + break + + with multi_node.swap_as_triton_caller(choice): + future_choices.append((choice, *compile_kernel(node_list_fused))) + + if len(future_choices) == 0: + return False + + def benchmark_when_ready() -> bool: + min_ms_fused = float("inf") + ms_fused_choice = None + + new_timings = {} + # Benchmark each choice after compilation completes + for choice, future, mod_fused in future_choices: + try: + if future is not None: + future.result() + + # Ideally we would more narrowly catch Exceptions here but + # triton will unpredictably error with valid prologue fusions + except Exception as e: + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug( + "Exception in compiling %s: %s", + "prologue" if not epilogue_fusion else "epilogue", + str(e), + ) + continue + with multi_node.swap_as_triton_caller(choice): + ms_fused, path = self.benchmark_codegened_module( + mod_fused, device + ) + new_timings[choice] = ms_fused + if ms_fused < min_ms_fused: + min_ms_fused = ms_fused + ms_fused_choice = choice + + log_fusion(min_ms_fused, ms1, ms2) + + if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: + multi_node.finalize_as_triton_caller(ms_fused_choice) + multi_node._choice_timings = new_timings + return True + else: + return False + + return benchmark_when_ready + + else: + # Start parallel compilation for all three kernels + future_and_mod_l1 = compile_kernel(node_list_1) + future_and_mod_l2 = compile_kernel(node_list_2) + future_and_mod_l1_fused = compile_kernel(node_list_fused) + + def benchmark_when_ready() -> bool: + from torch._inductor.runtime.triton_heuristics import ( + NoTritonConfigsError, + ) + + try: + # Wait for all compilations to complete + for fut in ( + future_and_mod_l1[0], + future_and_mod_l2[0], + future_and_mod_l1_fused[0], + ): + if fut is not None: + fut.result() + + ms1, path1 = self.benchmark_codegened_module( + future_and_mod_l1[1], device + ) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + + ms2, path2 = self.benchmark_codegened_module( + future_and_mod_l2[1], device + ) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + + ms_fused, path_fused = self.benchmark_codegened_module( + future_and_mod_l1_fused[1], device + ) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + + log_fusion(ms_fused, ms1, ms2) + + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + + return ms_fused < ms1 + ms2 + + except NoTritonConfigsError: + return False + + except CompilationError as e: + if "Loop-carried variable" in str(e): + return True + raise + + return benchmark_when_ready + + def get_fused_node(self, node: BaseSchedulerNode) -> BaseSchedulerNode: + "Look up the node in Scheduler name_to_fused_node" + return self.name_to_fused_node[node.get_first_name()] + + def 
fuse_nodes_once( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. + + This relies on two key functions to control the logic: + - self.can_fuse(): checks if a fusion is legal + - self.score_fusion(): assigns priority to a given fusion + """ + fused_nodes = OrderedSet(nodes) + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug("fuse_nodes_once, candidates:") + for node in fused_nodes: + fusion_log.debug(" %s", node.debug_str_short()) + + # These are potential fusions which we are async compiling, + # and which we will benchmark profitability of. + pending_fusions: dict[ + BaseSchedulerNode, + tuple[Callable[[], bool], BaseSchedulerNode, BaseSchedulerNode], + ] = {} + + def fuse_two_nodes( + node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> BaseSchedulerNode: + fusion_log.debug("fusing %s with %s", node1.get_name(), node2.get_name()) + + device = node1.get_device() + assert node2.get_device() == device + node3 = self.get_backend(device).fuse(node1, node2) + fused_nodes.remove(node1) + fused_nodes.remove(node2) + fused_nodes.add(node3) + self.name_to_fused_node.update( + {n.get_name(): node3 for n in node3.get_nodes()} + ) + return node3 + + def resolve_pending_fusions( + node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> None: + while ( + self.get_fused_node(node1) in pending_fusions + or self.get_fused_node(node2) in pending_fusions + ): + pending_fusion = pending_fusions.get( + self.get_fused_node(node1), + pending_fusions.get(self.get_fused_node(node2), None), + ) + assert pending_fusion is not None + + is_speedup, node_key1, node_key2 = pending_fusion + pending_fusions.pop(node_key1, None) + pending_fusions.pop(node_key2, None) + + assert self.get_fused_node(node_key1) is node_key1 + assert self.get_fused_node(node_key2) is node_key2 + + if not is_speedup() or self.will_fusion_create_cycle(node1, node2): + continue + + fuse_two_nodes(node_key1, node_key2) + + for node1, node2 in self.get_possible_fusions(nodes): + # if either node is in a pending fusion, resolve it. + # since we iterate on potential fusions based on profitability + # the first potential fusion should take precedence. 
+ resolve_pending_fusions(node1, node2) + node1 = self.get_fused_node(node1) + node2 = self.get_fused_node(node2) + + if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( + node1, node2 + ): + speedup = self.speedup_by_fusion(node1, node2) + if callable(speedup): + pending_fusions[node1] = (speedup, node1, node2) + pending_fusions[node2] = (speedup, node1, node2) + continue + + if not speedup: + continue + + fuse_two_nodes(node1, node2) + + seen_pair_speedup_fn: OrderedSet[Callable[[], bool]] = OrderedSet() + for is_speedup_fn, node_key1, node_key2 in pending_fusions.values(): + if is_speedup_fn in seen_pair_speedup_fn: + continue + + seen_pair_speedup_fn.add(is_speedup_fn) + + assert self.get_fused_node(node_key1) is node_key1 + assert self.get_fused_node(node_key2) is node_key2 + + if is_speedup_fn() and not self.will_fusion_create_cycle( + node_key1, node_key2 + ): + fuse_two_nodes(node_key1, node_key2) + + nodes = sorted(fused_nodes, key=lambda x: x.min_order) + nodes = self.topological_sort_schedule(nodes) + self.prune_redundant_deps(nodes) + return nodes + + def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None: + """ + Groups parallel nodes + """ + fused_nodes = OrderedSet(self.nodes) + count = 0 + num_nodes_orig = len(self.nodes) + log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes) + for num, node_list in enumerate( + ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self) + ): + node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list) + if len(node_list) < 2: + continue + if num_ck_nodes is not None and count > num_ck_nodes: + break + if not self.speedup_by_combo_kernel(node_list): + log.debug("ComboKernels: Not speeding up %d-th group", num) + continue + count += 1 + enable_autotune = config.combo_kernels_autotune > 0 + group_snode = ForeachKernelSchedulerNode( + node_list[0].scheduler, + node_list, + use_custom_partition_algo=True, + enable_autotune=enable_autotune, + ) + log.info( + "ComboKernels: Combining %d nodes for %d-th group", + len(node_list), + num, + ) + for node in node_list: + fused_nodes.remove(node) + fused_nodes.add(group_snode) + self.name_to_fused_node.update( + {n.get_name(): group_snode for n in group_snode.get_nodes()} + ) + self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) + self.nodes = self.topological_sort_schedule(self.nodes) + log.info( + "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodes", + count, + num_nodes_orig, + len(self.nodes), + ) + self.prune_redundant_deps(self.nodes) + + def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None: + for node in nodes: + node.prune_redundant_deps(self.name_to_fused_node) + + def get_possible_fusions( + self, nodes: list[BaseSchedulerNode] + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + """ + Helper to find all legal fusion opportunities, sorted by self.score_fusion() + """ + possible_fusions = [] + seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]() + + def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None: + for node1_index, node1 in enumerate(nodes): + for node2 in nodes[ + node1_index + 1 : node1_index + + 1 + + config.max_fusion_buffer_group_pairwise_attempts + ]: + key = (node1, node2) + if key in seen: + continue + seen.add(key) + + if self.can_fuse(node1, node2): + possible_fusions.append(key) + elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( + node2, node1 + ): + # foreach fusions and epilogue fusions are order dependent + 
possible_fusions.append((node2, node1)) + + buffer_names_grouping = collections.defaultdict(list) + for node in nodes: + if self.unfusable_node(node): + continue + for buf in node.used_buffer_names(): + buffer_names_grouping[buf].append(node) + for node_grouping in buffer_names_grouping.values(): + check_all_pairs(node_grouping) + + if config.aggressive_fusion: + group_grouping = collections.defaultdict(list) + for node in nodes: + group = getattr(node, "group", None) + if group: + group_grouping[group].append(node) + for node_grouping in group_grouping.values(): + check_all_pairs(node_grouping) + + possible_fusions = self.get_possible_fusions_with_highest_priority( + possible_fusions + ) + possible_fusions.sort(key=self.score_fusion_key, reverse=True) + fusion_log.debug("found %d possible fusions", len(possible_fusions)) + return possible_fusions + + def will_fusion_create_cycle( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Finds whether there's a path from node1 to node2 (or vice-versa) + caused indirectly by other fusions. + """ + # since we are just returning boolean here, use slightly faster, unordered set + visited = OrderedSet[FusedSchedulerNode]() + + def found_path(node: BaseSchedulerNode) -> bool: + # only fused nodes can introduce new ancestors. + if isinstance(node, FusedSchedulerNode) and node not in visited: + visited.add(node) + if node.get_operation_names().issubset(combined_ancestors): + # All fusion outputs are in ancestors of node1 and node2, thus + # cannot introduce new path: + # + # 1. if output is neither descendent of node1 or node2, the + # output cannot introduce a path + # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be + # on path(node1->node2), hence it cannot be ancestor of node2 + # 3. due to [acyclic]: if WLOG output is descendent of node1, it cannot be + # ancestor of node1 + return False + else: + # continue DFS of new ancestors introduced by the fusion + return bool(combined_names & node.ancestors) or any( + found_path(self.name_to_fused_node[n]) + for n in node.ancestors - combined_ancestors + ) + return False + + # as above - use slightly faster, unordered set + combined_names = ( + node1.get_operation_names()._dict.keys() + | node2.get_operation_names()._dict.keys() + ) + combined_ancestors = ( + node1.ancestors._dict.keys() | node2.ancestors._dict.keys() + ) - combined_names + cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) + if cycle: + WhyNoFuse(node1, node2)("will create cycle") + return cycle + + def can_fusion_increase_peak_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Return true if fusing the two nodes can potentially increasing peak memory. + + The implementation is more like a heuristic since we don't really know if we are at peak + or not when trying to fuse these two nodes. The order of nodes may change later which makes the + peak memory estimation hard. + + Here is how we decide the LOWER BOUND of extra memory allocation if we fuse these 2 nodes: + 1. find all buffers read by each node with a single user. These buffers are supposed to + be reused if we don't fuses these 2 nodes + 2. find the intersection of these buffers for the two node and sum the total buffer size. + If we don't fuse these two nodes, we can at lease avoid this much memory allocation. + Note that the extra memory allocation is not necessarily causing peak memory increase. + This is just a heuristic. 
+ + We return true only if the saving for fusion can not trade off the extra memory allocation. + """ + + from .codegen.wrapper import buffer_reuse_key + + def _find_single_user_inputs( + node: BaseSchedulerNode, + ) -> list[ir.Buffer]: + output = [] + for rd in node.read_writes.reads: + buf = self.name_to_buf.get(rd.name) + if buf and len(buf.users) == 1 and buf.node.has_tensor_output(): + output.append(buf.node) + return output + + # Check inputs that can be potentially reused + lhs_dep_nodes = _find_single_user_inputs(node1) + rhs_dep_nodes = _find_single_user_inputs(node2) + + lhs_reuse_keys = OrderedSet(buffer_reuse_key(buf) for buf in lhs_dep_nodes) + rhs_reuse_keys = OrderedSet(buffer_reuse_key(buf) for buf in rhs_dep_nodes) + + common_reuse_keys = lhs_reuse_keys.intersection(rhs_reuse_keys) + + memory_overhead = 0 + for key in common_reuse_keys: + try: + memory_overhead += int(key[2]) + except ValueError: + # not an integer. Fallback is to fuse + return False + + bw_saving = self.score_fusion_memory(node1, node2) + + # The factor 32 here is quite arbitrary. + if V.graph.sizevars.statically_known_gt(memory_overhead, 32 * bw_saving): + return True + return False + + def are_long_distant_nodes( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + This function prevents fusion for nodes that can increase memory + footprint. This problem is more common in horizontal fusion, where nodes + that are far apart in the original order get fused, lengthening the live + intervals of tensors. This is very evident in models with activation + checkpointing, where the recomputed nodes from different checkpointed + regions get fused and significantly increase the memory footprint. + + The current attempt is a quick, possibly hacky, heuristic to prevent the + fusion of nodes that are far away in the original order. + + A better but difficult to implement heurisitic would be to use live + intervals of the buffers, find region of peak pressure in the original + program and prevent fusion that crosses that peak region. We might need + special care or good approximation in this implementation, as fusion of + node changes live intervals, and re-computing live intervals and peak + memory after each fusion can introduce large compilation overhead. + """ + proximity_score = max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return proximity_score > 64 + + def decide_fusion_fail_reason( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + common_buf_names: Union[tuple[str], OrderedSet[str]], + ) -> str: + """ + Try to decide reasons why fusion fail due to no shared memory even though + there are common buffers. + """ + reasons = {} + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + for buf_name in common_buf_names: + buf = V.graph.get_buffer(buf_name) + lhs_dep = node1_name2dep[buf_name] + rhs_dep = node2_name2dep[buf_name] + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + reasons[buf_name] = ( + f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + ) + continue + + if lhs_dep.get_numel() != rhs_dep.get_numel(): + reasons[buf_name] = ( + f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + ) + continue + + # same numel but different MemoryDep.size. 
Should be broadcasting + if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size): + reasons[buf_name] = "broadcast" + continue + + lhs_off = lhs_dep.get_offset() + rhs_off = rhs_dep.get_offset() + if lhs_off != rhs_off: + # One example is in transformer, we use a concatenated linear layer + # to project Q/K/V and then split the result. The 3 splits will + # point to the same buffer with different offsets. + reasons[buf_name] = f"different offset: {lhs_off} v.s. {rhs_off}" + continue + + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}" + continue + + # Add more rules here + layout_str = "" + if not isinstance(buf, ir.TorchBindObject): + layout_str = f"Layout: {buf.layout}" + reasons[buf_name] = ( + f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}" + ) + + return str(reasons) + + def shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatible with node1 if that's more efficient. + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.is_cpu() for n in [node1, node2] + ): + return 0 + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return 0 + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return 0 + + # Pick the largest buffer to guide the loop reordering + _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0)) + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + return 0 + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. + # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + if lhs_dep.normalize() == rhs_dep.normalize(): + return self.dep_size_hint(lhs_dep) + return 0 + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. %s", + node1.get_name(), + node2.get_name(), + ) + + return self.score_fusion_memory(node1, node2) + + def unfusable_node(self, node: BaseSchedulerNode) -> bool: + """ + Is this node unfusable under any conditions. 
+ """ + return ( + isinstance(node, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node.is_template() + and not is_output_of_multi_outputs_template(node.node) + ) + + def check_prologue_fusion_heuristics_fusable( + self, + prologue_node: BaseSchedulerNode, + template_node: BaseSchedulerNode, + why: WhyNoFuse, + ) -> bool: + """ + Heuristics to avoid benchmarking predictably slow prologue fusions + """ + # user opt into more aggressive prologue fusion, dont use heuristics + if prologue_node.get_operation_names() <= V.graph.invoke_quant_ops: + return True + + read_bytes = prologue_node.get_read_buffer_sizes() + write_bytes = prologue_node.get_write_buffer_sizes() + + # Initially, only do fusions which will result in fewer memory accesses inside of the template to avoid + # potential bad cache behavior and shared memory use. + # we also want to avoid benchmarking reliably unprofitable fusions like downcasts from fp32 -> fp16 inside kernel. + # allowing gathers by allowing increasing write_bytes by small factor + # TODO - make configurable per input, for instance, bias can fuse fp32 -> fp16 profitably + + BYTES_THRESHOLD_MULTIPLIER = 1.1 + if read_bytes > (write_bytes * BYTES_THRESHOLD_MULTIPLIER): + why("prologue fusion will not increase amount of bytes read in kernel") + return False + + # we want to avoid attempting to fuse predictably unprofitable prologues + # such as increasing the unaligned reads or writes. + # TODO - would be nice to generalize this, however, we would need more explicit + # knowledge of memory access patterns in the TritonTemplate in order to know + # the stride order to check alignment. + origins = tuple( + e.target + for n in prologue_node.get_nodes() + if n.node is not None + for e in n.node.get_origins() + if e.op == "call_function" + ) + if origins == (torch.ops.aten.constant_pad_nd.default,): + why( + "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads" + ) + return False + + def low_prec_fp(dtype: torch.dtype) -> bool: + return dtype.itemsize <= 2 and dtype.is_floating_point + + if ( + low_prec_fp(template_node.get_template_node_or_throw().dtype) + and not prologue_node.can_codegen_in_low_precision() + ): + why( + "prologue fusion that must be upcast to fp32 not profitable for low precision templates" + ) + return False + + return True + + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + """ + Determine if it is possible to combine node1 and node2 into a + single fused node. 
+ """ + + if node1 is node2: + return False + + why = WhyNoFuse(node1, node2) + + if node1.is_template() and self.get_backend( + node1.get_device() + ).can_fuse_multi_outputs_template(node1, node2): + return True + + if isinstance(node1, GroupedSchedulerNode) or isinstance( + node2, GroupedSchedulerNode + ): + why("grouped node must not be fused with other nodes") + return False + if ( + isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node1.is_template() + ): + why("node1 is extern or nop") + return False + if ( + isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node2.is_template() + ): + why("node2 is extern or nop") + return False + + if node2.get_operation_names() & node1.ancestors: + why("node1 must go before node2") + return False + + if node2.is_template(): + if not config.prologue_fusion: + why("prologue fusion turned off") + return False + + if node1.is_reduction() or node1.is_template(): + why("prologue fusion only supported for pointwise nodes") + return False + + template = node2.get_template_node_or_throw() + if not isinstance(template, ir.TritonTemplateBuffer): + why("prologue fusion only supported for TritonTemplates") + return False + + allowed_prologue_inps = template.get_allowed_prologue_inps() + + unsupported_prologue_args = ( + OrderedSet(inp.get_name() for inp in template.inputs) + - allowed_prologue_inps + ) + + if node1.get_buffer_names() & unsupported_prologue_args: + why("prologue fusion not implemented for kernel for these inputs") + return False + + if node1.has_aliasing_or_mutation() or node1.has_aliasing_or_mutation(): + why("template prologue can only fuse functional pointwise nodes") + return False + + prologue_nodes = node1.get_nodes() + for node in prologue_nodes[:-1]: + node_outs = node.get_outputs() + for out in node_outs: + if not all(user.node in prologue_nodes for user in out.users): + why("template prologue can only fuse nodes with a single use") + return False + + template_snodes = ( + [node2] + if not isinstance(node2, FusedSchedulerNode) + else [n for n in node2.snodes if n.is_template()] + ) + assert len(template_snodes) == 1 + template_snode = template_snodes[0] + + if not ( + len(prologue_nodes[-1].outputs) == 1 + and len(prologue_nodes[-1].outputs[0].users) == 1 + and prologue_nodes[-1].outputs[0].users[0].node is template_snode + ): + why( + "template prologue can only fuse nodes with a single use into template" + ) + return False + + if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why): + return False + + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + if (node1.get_buffer_names() & V.graph.no_fuse_buffer_names) or ( + node2.get_buffer_names() & V.graph.no_fuse_buffer_names + ): + why("fusion for buffer explicit disabled") + return False + + device = node1.get_device() + device2 = node2.get_device() + if device != device2: + why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + shared_data_score = self.score_fusion_memory(node1, node2) + if ( + shared_data_score < config.score_fusion_memory_threshold + and config.loop_ordering_after_fusion + ): + shared_data_score = self.shared_data_after_reordering_loop(node1, node2) + + if loop_ordering_log.isEnabledFor(logging.DEBUG): + loop_ordering_log.debug( + "%s and %s has %s shared data", + node1.get_name(), + node2.get_name(), + shared_data_score, + ) 
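+        # Descriptive note on the control flow below: V.choices applies the global
+        # fusion heuristics to the shared-data score first, and only then do the
+        # device backend's vertical/horizontal checks make the final call.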
+ + if not V.choices.can_fuse(self, node1, node2, shared_data_score): + return False + + if node1.get_operation_names() & node2.ancestors: + # node2 depends on node1 outputs + return ( + self.can_fuse_vertical(node1, node2) + and V.choices.can_fuse_vertical(self, node1, node2, shared_data_score) + and self.get_backend(device).can_fuse_vertical(node1, node2) + ) + else: # nodes don't depend on each other, but may have common reads + return V.choices.can_fuse_horizontal( + self, node1, node2, shared_data_score + ) and self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + """ + node1_buf_names = node1.get_buffer_names() + why = WhyNoFuse(node1, node2) + remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list) + + for dep in node2.unmet_dependencies: + name = self.mutation_renames.get(dep.name, dep.name) + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + continue + remaining_deps_by_name[name].append(dep) + + for cd in node1.read_writes.writes: + if not isinstance(cd, MemoryDep): + continue + remaining = remaining_deps_by_name.get( + self.mutation_renames.get(cd.name, cd.name) + ) + if remaining: + for rd in remaining: + if self.fusable_read_and_write(rd, cd): + remaining.remove(rd) + + remaining_deps = OrderedSet( + dep.name + for dep in itertools.chain.from_iterable(remaining_deps_by_name.values()) + ) + + if remaining_deps & node1_buf_names: + # MemoryDeps didn't match and read different locations of the same buffer. + # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + + node1_op_names = node1.get_operation_names() + for name in remaining_deps: + op_name = self.name_to_buf[name].defining_op_name() + if node1_op_names & self.name_to_fused_node[op_name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + return True + + def fusable_weak_dep( + self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if weak_dep.name not in node1.get_buffer_names(): + return False + + # A weak dep can be fused if and only if the fused operation acts inplace + # on the buffer being mutated. i.e. 
the same index is being read then mutated + mutating_writes = [ + write + for write in node2.read_writes.writes + if write.name == weak_dep.mutating_buf + ] + if len(mutating_writes) != 1: + return False + write = mutating_writes[0] + assert isinstance(write, MemoryDep) + + if free_symbol_is_type(write.index, SymT.TMP): + return False + + real_name = self.mutation_real_name[weak_dep.mutating_buf] + relevant_reads = [ + read for read in node1.read_writes.reads if read.name == real_name + ] + return all( + isinstance(read, MemoryDep) + and not free_symbol_is_type(read.index, SymT.TMP) + and read.index == write.index + and read.size == write.size + for read in relevant_reads + ) + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: + if isinstance(read, MemoryDep): + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False + + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. + read = read.normalize() + write = write.normalize() + + return ( + read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + elif isinstance(read, StarDep): + read_name = self.mutation_renames.get(read.name, read.name) + write_name = self.mutation_renames.get(write.name, write.name) + if ( + read.mode == write.mode + and write.mode is not None + and read_name == write_name + ): + return True + return False + + def dep_size_hint(self, dep: Dep) -> int: + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res + + def score_fusion_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + The first term in our fusion score that estimates number of saved + memory operations. 
+ """ + node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes) + node2_dep_len = len(node1.read_writes.reads) + len(node2.read_writes.writes) + + # optimization: iter over smaller set + if min(node1_dep_len, node2_dep_len) * 4 < max(node1_dep_len, node2_dep_len): + if node1_dep_len > node2_dep_len: + tmp = node1 + node1 = node2 + node2 = tmp + + deps = [ + dep + for dep in node1.read_writes.reads | node1.read_writes.writes + if dep in node2.read_writes.reads or dep in node2.read_writes.writes + ] + + return sum(self.dep_size_hint(dep) for dep in deps) + + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + return sum(self.dep_size_hint(dep) for dep in common_memory_deps) + + def get_possible_fusions_with_highest_priority( + self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + # Group the possible fusions based on their priority from the backend. + # Only return the group of possible fusions with highest priority. + if len(possible_fusions) == 0: + return possible_fusions + possible_fusions_group_by_priority: dict[ + int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]] + ] = {} + + for node1, node2 in possible_fusions: + assert node1.get_device() == node2.get_device() + device = node1.get_device() + fusion_pair_priority = int( + self.get_backend(device).get_fusion_pair_priority(node1, node2) + ) + if fusion_pair_priority not in possible_fusions_group_by_priority: + possible_fusions_group_by_priority[fusion_pair_priority] = [ + (node1, node2), + ] + else: + possible_fusions_group_by_priority[fusion_pair_priority].append( + (node1, node2) + ) + # return the possible fusions with highest priority + possible_fusions_with_highest_priority = min( + possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) + )[1] + assert len(possible_fusions_with_highest_priority) > 0 + return possible_fusions_with_highest_priority + + def score_fusion_key( + self, nodes: tuple[BaseSchedulerNode, BaseSchedulerNode] + ) -> Any: + """ + Shim for list.sort(key=...) 
+ """ + return V.choices.score_fusion(self, *nodes) + + def compute_last_usage(self) -> None: + """ + Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) + """ + + future_used_buffers = OrderedSet(V.graph.get_output_names()) + + for node in reversed(self.nodes): + node.set_last_usage(future_used_buffers, self.mutation_real_name) + future_used_buffers.update(node.last_usage) + + def free_buffers(self) -> None: + """Free any buffers that are no longer needed""" + for name in sorted( + self.buffer_names_to_free + - V.graph.removed_buffers + - V.graph.wrapper_code.freed # type: ignore[has-type] + ): + if name in self.name_to_buf: + buf = self.name_to_buf[name] + if buf.can_free(): + V.graph.wrapper_code.codegen_free(buf.node) + elif name in V.graph.graph_inputs: + inp = V.graph.graph_inputs[name] + if isinstance(inp, ir.TorchBindObject): + V.graph.wrapper_code.codegen_free(inp) + elif isinstance(inp, ir.GeneratorState): + continue + else: + storage = inp.data + assert ( + isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + ) + V.graph.wrapper_code.codegen_free(storage.data) + + self.buffer_names_to_free.clear() + + def flush(self) -> None: + for backend in self.backends.values(): + backend.flush() + self.free_buffers() + + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: + assert isinstance(scheduler_node, ExternKernelSchedulerNode) + # 'decide_inplace_update' stores the inplace update decisions in + # the current kernel from where 'allocate' retrieve those decisions. + # We have to make sure there is a non-NULL kernel handler to store + # those inplace update decisions. + counters["inductor"]["extern_calls"] += 1 + with V.set_kernel_handler(Kernel(increase_kernel_count=False)): + scheduler_node.decide_inplace_update() + scheduler_node.mark_run() + node = scheduler_node.node + assert isinstance(node, ir.ExternKernel), f"{type(node)=}" + node.codegen(V.graph.wrapper_code) + self.free_buffers() + + def create_backend(self, device: torch.device) -> BaseScheduling: + assert not is_gpu(device.type) or device.index is not None, ( + f"{device} should have been normalized in lowering" + ) + V.graph.add_device_info(device) + + device_scheduling = get_scheduling_for_device(device.type) + if device_scheduling is None: + raise RuntimeError(f"Unsupported device type: {device.type}") + + if not has_triton(): + if ( + device.type == "cuda" + and (device_props := torch.cuda.get_device_properties(device)).major < 7 + ): + raise GPUTooOldForTriton(device_props, inspect.currentframe()) + elif is_gpu(device.type) and not device.type == "mps": + raise TritonMissing(inspect.currentframe()) + + return device_scheduling(self) + + def get_backend(self, device: Optional[torch.device]) -> BaseScheduling: + assert device is not None + if device not in self.backends: + self.backends[device] = self.create_backend(device) + return self.backends[device] + + def enter_context(self, node: BaseSchedulerNode) -> None: + def get_order(n: torch.fx.Node) -> int: + if n not in self.origin_to_index: + self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.origin_to_index[n] + + # Use a dict to have ordering + origins = { + (get_order(e), e): None + for n in node.get_nodes() + if n.node is not None + for e in n.node.get_origins() + } + origins = list(origins.keys()) + if origins: + _, last = max(origins, key=operator.itemgetter(0)) + V.graph.wrapper_code.enter_context(last) + + def can_buffer_be_removed_through_fusion( + 
self, name: str, fused_node_names: OrderedSet[str] + ) -> bool: + try: + users = self.name_to_buf[name].users + except KeyError: + return False + return ( + all(user.is_weak or user.get_name() in fused_node_names for user in users) + and name not in self.mutation_renames + and name not in self.mutation_real_name + ) + + def should_partition(self, node: BaseSchedulerNode) -> bool: + """Return True if we should partition the inductor graph on this node""" + if isinstance(node, FusedSchedulerNode): + return any(self.should_partition(snode) for snode in node.snodes) + + if not node.is_gpu(): + return True + + if node.node is None: + return True + + if isinstance(node.node, ir.DeviceCopy): + return True + + if isinstance(node.node, ir.Conditional): + return True + + if getattr(node.node, "unbacked_bindings", None): + return True + + if is_cudagraph_unsafe_op(node.node): + return True + + return False + + def get_name_to_nodes( + self, + ) -> dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]]: + """ + Return a mapping from name strings to the corresponding graph inputs or + base scheduler node outputs. + """ + name_to_node: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]] = {} + name_to_node.update(V.graph.graph_inputs) + + for node in self.nodes: + for name, scheduler_buffer in node.outputs_by_name.items(): + name_to_node[name] = scheduler_buffer.node + + return name_to_node + + def compute_graph_partition_maps( + self, + signatures: list[GraphPartitionSignature], + ) -> None: + """ + computes a mapping from partition input/output indices to graph input/output + indices for each partition. + """ + name_to_graph_input_index = { + name: idx for idx, name in enumerate(V.graph.graph_inputs) + } + name_to_graph_output_index = { + name: idx for idx, name in enumerate(V.graph.get_output_names()) + } + + V.graph.partition_maps = [] + for partition_id, signature in enumerate(signatures): + if signature.skip_cudagraph: + # Note: [Graph Partition Map for CUDAGraph] + # number of partition map should be the same as the number of generated + # partition functions. This assumption will be used when cudagraphify + # each partition function. + continue + + input_mapping = [] + for name in signature.input_nodes: + input_mapping.append(name_to_graph_input_index.get(name)) + + output_mapping = [] + for node in signature.output_nodes: + output_mapping.append(name_to_graph_output_index.get(node.get_name())) + + V.graph.partition_maps.append( + GraphPartitionMap( + partition_id, + input_mapping, + output_mapping, + signature.constant_names, + ) + ) + + def get_graph_partition_symbol_inputs( + self, + partition: PartitionType, + input_nodes: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]], + ) -> OrderedSet[sympy.Symbol]: + """ + Returns all symbol inputs which are required to be in scope to successfully + perform codegen for this graph partition, including: + - free symbols used in partition nodes + - free symbols in partition input/node shapes, strides, and offsets. This is needed + for recording cudagraphs for tensors with dynamic shapes. 
+ """ + + def get_layout_symints(node: ir.IRNode) -> OrderedSet[sympy.Symbol]: + free_symbol_uses: OrderedSet[sympy.Symbol] = OrderedSet() + layout = node.maybe_get_layout() + if isinstance(layout, ir.Layout): + free_symbol_uses.update( + free_symbols(layout.size) + | free_symbols(layout.stride) + | free_symbols(layout.offset) + ) + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + # symint may be used as index in layout.target + free_symbol_uses.update(get_layout_symints(layout.target)) + else: + assert layout is None, ( + f"Expect layout to be None but found layout={layout}" + ) + return free_symbol_uses + + def get_scheduler_node_symbol_uses( + node: BaseSchedulerNode, + ) -> OrderedSet[sympy.Symbol]: + """ + Gets symbols used in node. + """ + if isinstance(node, FusedSchedulerNode): + return OrderedSet().union( + *(get_scheduler_node_symbol_uses(snode) for snode in node.snodes) + ) + assert node.node is not None + free_symbol_uses = node.node.get_free_symbol_uses() + free_symbol_uses.update( + *(get_layout_symints(ir_node) for ir_node in node.node.get_outputs()) + ) + return free_symbol_uses + + def get_input_node_symbols( + node: Union[ir.IRNode, sympy.Expr, ir.TorchBindObject], + ) -> OrderedSet[sympy.Symbol]: + """ + Gets symbols used in input node shapes, strides, and offsets. + """ + if isinstance(node, ir.TorchBindObject): + # TorchBindObject does not involve dynamic shapes yet + return OrderedSet() + elif isinstance(node, ir.IRNode): + return get_layout_symints(node) + else: + # node cannot be sympy.Expr since node comes from read_writes and + # read_writes does not contain sympy.Expr + raise NotImplementedError(f"Unsupported input node type: {type(node)}") + + def filter_symbols( + symbols: OrderedSet[sympy.Symbol], + ) -> OrderedSet[sympy.Symbol]: + """ + Filters a set of symbols that are required for codegen. Skip symbols + that are always internal to kernels, such as SymT.TMP, SymT.INDEX, + and SymT.R0_INDEX. + """ + return OrderedSet( + s + for s in symbols + if symbol_is_type( + s, + ( + SymT.SIZE, + SymT.FLOAT, + SymT.UNBACKED_INT, + SymT.UNBACKED_FLOAT, + ), + ) + ) + + candidate_symbols: OrderedSet[sympy.Symbol] = OrderedSet().union( + *(get_scheduler_node_symbol_uses(node) for node in partition) + ) + candidate_symbols.union( + *(get_input_node_symbols(node) for _, node in input_nodes.items()) + ) + + candidate_symbols = filter_symbols(candidate_symbols) + + res: OrderedSet[sympy.Symbol] = OrderedSet() + for s in candidate_symbols: + symplified_s = V.graph.sizevars.simplify(s) + # use free_symbols only when s is simplified to an Integer or expr + res.update(symplified_s.free_symbols) + + return OrderedSet(sorted(res, key=operator.attrgetter("name"))) + + def get_graph_partition_signature( + self, partitions: list[PartitionType], skip_cudagraphs: list[bool] + ) -> list[GraphPartitionSignature]: + """ + Gets signature for each graph partition, including input nodes, output nodes, and + whether deallocating an input within graph partition. + """ + signatures = [] + + unmet_output_names = OrderedSet(V.graph.get_output_names()) + name_to_node = self.get_name_to_nodes() + + def is_none_layout(buf_name: str) -> bool: + """ + Checks if buf_name is NoneLayout. Buffers with NoneLayout is not allocated + so graph partition should not take it as inputs or outputs. 
+ """ + buf = self.name_to_buf.get(buf_name, None) + + if buf is None: + return False + + if isinstance(buf.node.layout, NoneLayout): + if isinstance(buf.node, ir.MutationOutput) and ( + real_name := self.mutation_real_name.get(buf_name, None) + ): + return is_none_layout(real_name) + + return True + + return False + + for partition, skip_cudagraph in zip( + reversed(partitions), reversed(skip_cudagraphs) + ): + output_names: OrderedSet[str] = OrderedSet() + + for node in partition: + output_names.update(node.outputs_by_name.keys()) + + returned_output_names = output_names.intersection(unmet_output_names) + + # all reads/writes are partition inputs except those generated + # within the partition and tensor constants + read_writes = dependencies.ReadWrites.merge_list( + [node.read_writes for node in partition] + ) + + # WeakDep is fake dependency on unused buffer. It should not appear + # in partition_input_names for inputs that are actually read or written. + partition_input_names = ( + OrderedSet( + [ + x.name + for x in read_writes.reads | read_writes.writes + if not is_none_layout(x.name) + ] + ) + - output_names + ) + + partition_input_names = OrderedSet( + self.mutation_real_name.get(name, name) + for name in partition_input_names + ) + + buffer_names_to_free: OrderedSet[str] = OrderedSet() + for node in partition: + buffer_names_to_free.update(node.last_usage) + + input_nodes = { + name: name_to_node[name] + for name in partition_input_names + if name in name_to_node + } + input_deallocation = { + name: True if name in buffer_names_to_free else False + for name in partition_input_names + if name in name_to_node + } + + # if an input tensor is not freed in the partition function, it should + # also be returned as an output. This brings benefits to cudagraph + # since the returned output tensor is a cudagraph managed tensor with + # a static tensor address. + extra_output_names = [ + name + for name in partition_input_names + if name in name_to_node and name not in buffer_names_to_free + ] + + returned_output_names.update(extra_output_names) + + returned_output_names = OrderedSet( + self.mutation_real_name.get(name, name) + for name in returned_output_names + ) + + output_nodes = [ + name_to_node[name] + for name in returned_output_names + if not is_none_layout(name) + ] + + constant_names = [ + name for name in partition_input_names if name in V.graph.constants + ] + + symbol_inputs = self.get_graph_partition_symbol_inputs( + partition, input_nodes + ) + + partition_signature = GraphPartitionSignature( + symbol_inputs, + input_nodes, + output_nodes, + input_deallocation, + skip_cudagraph, + constant_names, + ) + + signatures.append(partition_signature) + + unmet_output_names = partition_input_names.union( + unmet_output_names - returned_output_names + ) + + return signatures[::-1] + + def clean_removed_buffer_from_partition_signatures( + self, signature: GraphPartitionSignature + ) -> GraphPartitionSignature: + """ + Updates the partition signature by removing buffers specified in + V.graph.removed_buffers. 
See [Note: Removed Graph Partition Arguments] + """ + input_nodes = { + name: buffer + for name, buffer in signature.input_nodes.items() + if name not in V.graph.removed_buffers + } + input_deallocation = { + name: val + for name, val in signature.input_deallocation.items() + if name not in V.graph.removed_buffers + } + output_nodes = [ + node + for node in signature.output_nodes + if node.maybe_get_name() not in V.graph.removed_buffers + ] + constant_names = [ + name + for name in signature.constant_names + if name not in V.graph.removed_buffers + ] + return GraphPartitionSignature( + signature.symbol_inputs, + input_nodes, + output_nodes, + input_deallocation, + signature.skip_cudagraph, + constant_names, + ) + + def reorder_for_minimizing_partition( + self, + nodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + """ + Reorder nodes to minimize the number of partitions via a bfs + topological sort. This is the optimal reordering such that the + number of partitions cannot be reduced further. This may be + sub-optimal for other metrics such as peak memory. This does not + change relative orders of two cudagraphable nodes, nor the + relative order of two non_cudagraphable nodes. + """ + import heapq + + node_to_indegree: dict[BaseSchedulerNode, int] = dict() + cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = [] + non_cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = [] + node_to_index = {node: idx for idx, node in enumerate(nodes)} + + def insert_pending_nodes(node: BaseSchedulerNode) -> None: + node_with_index = (node_to_index[node], node) + if self.should_partition(node): + heapq.heappush(non_cudagraphable_nodes, node_with_index) + else: + heapq.heappush(cudagraphable_nodes, node_with_index) + + def update_indegree(node: BaseSchedulerNode) -> None: + for succ_node in node.mpi_node.succ_nodes: + assert node_to_indegree[succ_node] > 0 + node_to_indegree[succ_node] -= 1 + if node_to_indegree[succ_node] == 0: + insert_pending_nodes(succ_node) + + for node in nodes: + node_to_indegree[node] = len(node.mpi_node.pred_nodes) + if node_to_indegree[node] == 0: + insert_pending_nodes(node) + + schedule: list[BaseSchedulerNode] = [] + num_iters: int = 0 + while num_iters < len(nodes) and ( + non_cudagraphable_nodes or cudagraphable_nodes + ): + while non_cudagraphable_nodes: + _, node = heapq.heappop(non_cudagraphable_nodes) + schedule.append(node) + update_indegree(node) + + while cudagraphable_nodes: + _, node = heapq.heappop(cudagraphable_nodes) + schedule.append(node) + update_indegree(node) + + num_iters += 1 + + if num_iters > len(nodes): + raise RuntimeError( + """ + Failed to schedule, while loop ran too long when + reordering for minimizing the num of partitions + """ + ) + + return schedule + + def maybe_reorder_for_minimizing_partition( + self, + nodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + """ + Reorder nodes to minimize the number of partitions if this only slightly + increase peak memory. 
+ """ + from .memory import estimate_peak_memory, prepare_planning_info + + graph_outputs = OrderedSet(V.graph.get_output_names()) + + default_peak_memory, name_to_freeable_input_buf = prepare_planning_info( + nodes, + self.name_to_buf, + self.name_to_fused_node, + OrderedSet(V.graph.graph_inputs.keys()), + graph_outputs, + ) + + reordered_nodes = self.reorder_for_minimizing_partition(nodes) + reorder_peak_memory, _ = estimate_peak_memory( + reordered_nodes, name_to_freeable_input_buf, graph_outputs + ) + + # 1.1 here means 10% extra peak memory budget which is quite arbitrary + if reorder_peak_memory < default_peak_memory * 1.1: + return reordered_nodes + + return nodes + + def reorder_for_partition_with_simple_dependency( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + """ + Reorder a node if it should be partitioned and has simple dependency: + 1. move a partitioned node to the front if it has no dependency + 2. move a partitioned node to the back if it is only used by OutputNode + 3. otherwise do not reorder + """ + + front: list[BaseSchedulerNode] = [] + middle: list[BaseSchedulerNode] = [] + back: list[BaseSchedulerNode] = [] + + def only_output_user(node: BaseSchedulerNode) -> bool: + for buf in node.get_outputs(): + for use in buf.users: + if not isinstance(use.node, OutputNode): + return False + return True + + for node in nodes: + should_partition = self.should_partition(node) + if should_partition and len(node.unmet_dependencies) == 0: + front.append(node) + elif should_partition and only_output_user(node): + back.append(node) + else: + middle.append(node) + + return front + middle + back + + def graph_partition( + self, + ) -> tuple[list[PartitionType], list[GraphPartitionSignature]]: + """ + Given a list of BaseSchedulerNodes, split into a list of + graph partitions and compute partition input/output signatures. + """ + partitions: list[PartitionType] = [] + skip_cudagraph = True + cur_partition: PartitionType = [] + skip_cudagraphs = [] + for node in self.nodes: + should_partition = self.should_partition(node) + if cur_partition and skip_cudagraph != should_partition: + partitions.append(cur_partition) + skip_cudagraphs.append(skip_cudagraph) + cur_partition = [] + + skip_cudagraph = should_partition + cur_partition.append(node) + + if cur_partition: + partitions.append(cur_partition) + skip_cudagraphs.append(skip_cudagraph) + + signatures = self.get_graph_partition_signature( + partitions=partitions, skip_cudagraphs=skip_cudagraphs + ) + self.compute_graph_partition_maps(signatures) + + return partitions, signatures + + def codegen(self) -> None: + with dynamo_timed("Scheduler.codegen"): + return ( + self._codegen_partitions() + if torch._inductor.config.graph_partition + else self._codegen(self.nodes) + ) + + def _codegen_partition_wrapper( + self, + partition: PartitionType, + signature: GraphPartitionSignature, + ) -> None: + """Codegen a partition given its inputs/outputs""" + from .codegen.wrapper import SubgraphPythonWrapperCodegen + + parent_wrapper_code = V.graph.wrapper_code + graph_partition_id = next(self._graph_partition_counter) + + with V.graph.set_current_wrapper_code(): + V.graph.init_wrapper_code( + is_subgraph=True, + subgraph_name=f"partition_{graph_partition_id}", + parent_wrapper_code=parent_wrapper_code, + partition_signatures=signature, + ) + self._codegen(partition) + + # Note: [Removed Graph Partition Arguments] + # Graph partition relies on node.read_writes to analyze the partition + # inputs and outputs. 
However, during codegen, we may decide some buffers + # are internal to a kernel (e.g., triton kernel) such that these buffers + # are never actually defined. This information is collected during codegen + # and recorded in V.graph.removed_buffers. So we cleanup signature and write + # prefix (i.e., generating call function and return outputs) after we have + # codegen the partition. + assert isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen) + signature = self.clean_removed_buffer_from_partition_signatures(signature) + V.graph.wrapper_code.partition_signatures = signature + V.graph.wrapper_code.write_prefix() + + partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference) + + V.graph.wrapper_code.define_subgraph_launcher_fn(partition_code.value) + + V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature) + V.graph.wrapper_code.allocated.update( # type: ignore[has-type] + [node.get_name() for node in signature.output_nodes] + ) + + def _codegen_partitions(self) -> None: + """ + Split nodes into partitions and codegen each partition into separate functions. + This allows further applying different optimizations (e.g., cudagraph) to + each function. + """ + partitions, signatures = self.graph_partition() + + for partition, signature in zip(partitions, signatures): + assert len(partition) >= 1, ( + f"Each partition must have at least one node but found {len(partition)}" + ) + + if signature.skip_cudagraph: + self._codegen(partition) + else: + self._codegen_partition_wrapper(partition, signature) + + num_partitions = next(self._graph_partition_counter) + V.graph.wrapper_code.set_all_partition_names(num_partitions) + + # See [Note: Graph Partition Map for CUDAGraph] + if num_partitions > 0: + assert V.graph.partition_maps is not None + assert num_partitions == len(V.graph.partition_maps), ( + f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}" + ) + + def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: + if config.check_stack_no_cycles_TESTING_ONLY: + import torch._dynamo.convert_frame + + stack = traceback.extract_stack() + seen: OrderedSet[tuple[str, int | None]] = OrderedSet() + for frame in reversed(stack): + # This is where maybe_cprofile is + if ( + frame.name == "_compile_inner" + and frame.filename == torch._dynamo.convert_frame.__file__ + ): + break + key = (frame.filename, frame.lineno) + assert key not in seen, ( + f"Duplicate stack frame {frame.filename}:{frame.lineno}; " + "did you add a decorator to one of the functions in this stack " + "trace? If so, try using a context manager instead." 
+ ) + seen.add(key) + + self.current_device = None + for node in nodes: + if log.isEnabledFor(logging.DEBUG): + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) + + self.enter_context(node) + + if device := node.get_device(): + if ( + device != self.current_device + or node.is_extern() + or node.is_template() + ): + self.flush() + if device != self.current_device: + if self.current_device and device_need_guard( + self.current_device.type + ): + V.graph.wrapper_code.codegen_device_guard_exit() + self.current_device = device + if device_need_guard(device.type): + assert device.index is not None, "device should have an index" + V.graph.wrapper_code.codegen_device_guard_enter(device.index) + + self.buffer_names_to_free.update(node.last_usage) + + if node.is_template(): + prologue, template_node, epilogue = node.get_prologue_template_epilogue( + list(node.get_nodes()) + ) + self.get_backend(device).codegen_template( + template_node, epilogue, prologue + ) + elif node.is_extern(): + node = typing.cast(ExternKernelSchedulerNode, node) + self.codegen_extern_call(node) + elif node.is_foreach(): + node = typing.cast(ForeachKernelSchedulerNode, node) + backend_ = self.get_backend(device) + from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + from .codegen.simd import SIMDScheduling + + if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + backend = backend_ + else: + raise AssertionError(f"{type(self)=}") + backend.codegen_combo_kernel(node) + elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): + self.get_backend(device).codegen_node(node) + else: + assert isinstance(node, NopKernelSchedulerNode) + node.mark_run() + + if config.triton.debug_sync_kernel: + self.get_backend(device).codegen_sync() + + self.available_buffer_names.update(node.get_buffer_names()) + self.completed_operations.update(node.get_operation_names()) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if ( + device is not None + and device.type != "meta" + and self.get_backend(device).ready_to_flush() + ): + self.flush() + + if self.current_device and device_need_guard(self.current_device.type): + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. + V.graph.wrapper_code.codegen_device_guard_exit() + + self.flush() + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + device = node_list[0].get_device() + V.graph.scheduler = self + self.current_device = device + assert device is not None + backend = self.get_backend(device) + return backend.benchmark_combo_kernel(node_list) + + def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. + """ + if not config.benchmark_combo_kernel: + return True + + subkernel_nodes = nodes + device = subkernel_nodes[0].get_device() + + # don't support benchmark fusion for CPU right now. 
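+        # Returning True in that case simply allows the combo-kernel fusion to
+        # proceed without benchmarking it.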
+ if device is None or device.type == "cpu": + return True + + from triton.compiler.errors import CompilationError + + ms1, path1_list = 0.0, [] + for i, snode in enumerate(subkernel_nodes): + node_list = snode.get_nodes() + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + if self._any_atomic_add(node_list): + fusion_log.debug( + "ComboKernel: benchmarking may not accurate due to atomic_add" + ) + + try: + ms, path = self.benchmark_fused_nodes(node_list) + if math.isinf(ms): + fusion_log.debug( + "ComboKernel benchmark: register spilling of %d-th subkernel", + i, + ) + return False + except CompilationError as e: + # workaround triton issue: https://github.com/triton-lang/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + ms1 += ms + path1_list.append(path) + + try: + ms2, ms2_clone, _path2_list = self.benchmark_combo_kernel(subkernel_nodes) + except CompilationError as e: + # workaround triton issue: https://github.com/triton-lang/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + + # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking. + small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3 + if fusion_log.isEnabledFor(logging.DEBUG): + if ms1 > ms2 or small_kernel: + fusion_log.debug( + "can fuse (benchmark): fusing causes %sx speedup", + green_text(f"{ms1 / ms2:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing causes %sx slowdown", + red_text(f"{ms1 / ms2:.3f}"), + ) + # ms1 returned by benchmark_fused_nodes discounted clone time + return ms2 - ms2_clone < ms1 or small_kernel + + def get_buffer_layout(self, buf_name: str) -> ir.Layout: + buf = self.name_to_buf[buf_name] + assert buf.node is not None + return buf.node.get_layout() + + def update_zero_dim_cpu_tensor(self) -> None: + for node in self.nodes: + if node.is_gpu(): + for read in node.read_writes.reads: + buffer = V.graph.name_to_buffer.get(read.name) + if ( + buffer + and get_device_type(buffer) == "cpu" + and not isinstance( + buffer.layout, (NoneLayout, MultiOutputLayout) + ) + and buffer.get_size() == [] + ): + V.graph.zero_dim_cpu_tensor_list.add(read.name) + + +class BaseScheduling: + def __init__(self, scheduler: Optional[Scheduler]): + super().__init__() + self.scheduler = scheduler + + def free_buffers_in_scheduler(self) -> None: + if self.scheduler: + self.scheduler.free_buffers() + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + """Return a set of .codegen.common.BackendFeature()""" + return OrderedSet() + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be vertically fused or not. + """ + raise NotImplementedError + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be horizontally fused or not. 
+ """ + raise NotImplementedError + + def can_fuse_multi_outputs_template( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + A Multi-Output Template (referenced in #144012) is a template node + with MultiOutputLayout, and its output buffers are instances of MultiOutput. + In this context, we verify whether node1 represents the Multi-Output Template + and node2 corresponds to one of its outputs. If so, we further check if + backend supports this fusion. + """ + return False + + def fuse( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + """ + Fuse two nodes + """ + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[sympy.Expr]] + ) -> tuple[tuple[sympy.Expr, ...], ...]: + """ + Process the iteration sizes in case a transformation needs to be applied. + """ + raise NotImplementedError + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + """ + Given a template node, generate a kernel. + + This function is only available for triton now. If the third-party backend behaves as a sub-class + of TritonScheduling, it can override it or reuse it. + """ + raise NotImplementedError + + def generate_kernel_code_from_nodes( + self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool + ) -> str: + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError + + def codegen_sync(self) -> None: + """ + Generate synchronization code for the kernel. This method depends on the hardware characteristics. + """ + raise NotImplementedError + + def ready_to_flush(self) -> bool: + """ + Check whether the backend is requesting the scheduler to flush the generated kernel. + If not supported, please return False. + """ + return False + + def flush(self) -> None: + """ + Flush the generated kernel and python wrapper code to the source code file. + """ + raise NotImplementedError + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + def benchmark_codegened_module(self, module: ModuleType) -> tuple[float, str]: + """ + Benchmark a compiled module and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + def get_fusion_pair_priority( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Return an unsigned integer which represents the priority of this fusion pair. + The smaller is with higher priority. + """ + return 0 + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + """ + Benchmark the list of nodes to combine and return the execution time + and memory copy time in milliseconds on randomly generated inputs. 
+ """ + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/torch/_inductor/script.ld b/phivenv/Lib/site-packages/torch/_inductor/script.ld new file mode 100644 index 0000000000000000000000000000000000000000..af9ea0e2509ea0363f78309323c5945c78e8d659 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/script.ld @@ -0,0 +1,8 @@ +SECTIONS { + /* By default, in LLD 16, .lrodata is placed immediately after .rodata. + * However, .lrodata can be very large in our compiled models, which leads to + * relocation out-of-range errors for relative relocations. So we place it + * after other the sections that are referenced from .text using relative + * relocations. This is the default behavior in GNU ld. */ + .lrodata : { *(.lrodata) } + } INSERT AFTER .bss; diff --git a/phivenv/Lib/site-packages/torch/_inductor/select_algorithm.py b/phivenv/Lib/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..13c178dc2825d02b16040d27e600ee9ce0955b25 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,3181 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import inspect +import itertools +import json +import logging +import math +import operator +import os +import re +import sys +import textwrap +import time +from collections.abc import Sequence +from concurrent.futures import as_completed, ThreadPoolExecutor +from io import StringIO +from types import ModuleType +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union +from typing_extensions import Self +from unittest.mock import patch + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state +from torch._inductor.utils import clear_on_fresh_cache +from torch.utils._filelock import FileLock +from torch.utils._ordered_set import OrderedSet + +from ..utils._sympy.functions import CeilDiv +from . 
import config, ir +from .autotune_process import ( + TensorMeta, + TritonBenchmarkRequest, + TritonCPUBenchmarkRequest, + TritonGPUBenchmarkRequest, +) +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import ( + CSEVariable, + IndentedBuffer, + KernelTemplate, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from .codegen.simd_kernel_features import SIMDKernelFeatures +from .codegen.subgraph import SubgraphChoiceCaller +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TritonKernel, + TritonScheduling, +) +from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta +from .codegen.wrapper import pexpr +from .exc import CUDACompileError +from .ir import ChoiceCaller, PrimitiveInfoType +from .ops_handler import StoreMode +from .runtime.benchmarking import benchmarker +from .runtime.hints import DeviceProperties +from .runtime.triton_compat import HAS_WARP_SPEC +from .runtime.triton_heuristics import FixedGrid +from .utils import ( + ceildiv, + do_bench_using_profiling, + FakeIndentedBuffer, + get_dtype_size, + is_gpu, + Placeholder, + restore_stdout_stderr, + sympy_dot, + sympy_index_symbol, + sympy_product, + triton_type, + triton_type_to_torch, + unique, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: dict[str, Any] = {} +PRINT_AUTOTUNE = True +DEBUG = False + + +if TYPE_CHECKING: + import concurrent + + from torch._inductor.codegen.simd import IterationRangesRoot + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +@dataclasses.dataclass +class BenchmarkTensors: + """Represents a set of inputs and outputs for autotuning with a template""" + + input_tensors: list[torch.Tensor] + output_tensor: Optional[torch.Tensor] + + def unpack(self): + return self.input_tensors, self.output_tensor + + +@dataclasses.dataclass +class AutotuneArgs: + """During autotuning, we need to pass the same inputs to all choices. + Note: + Since we typically have a mix of external choices and triton choices, we create + two lists of inputs for the same underlying buffers: + - External inputs (for aten kernels): Include offset for sliced tensors + - Triton inputs: Use base pointer for sliced tensors, without offset + """ + + triton: BenchmarkTensors + extern: BenchmarkTensors + expected: Optional[torch.Tensor] = None + + def get_benchmark_tensors(self, extern=False) -> BenchmarkTensors: + """Returns the inputs and output tensors for a given choice.""" + bench_tensors = self.extern if extern else self.triton + return bench_tensors + + @classmethod + def from_choice_args( + cls, + example_inputs: list[torch.Tensor], + example_inputs_extern: list[torch.Tensor], + out: torch.Tensor, + out_extern: torch.Tensor, + expected: Optional[torch.Tensor] = None, + ) -> Self: + """Factory method to create AutotuneInputs from separate inputs/outputs""" + return cls( + triton=BenchmarkTensors(example_inputs, out), + extern=BenchmarkTensors(example_inputs_extern, out_extern), + expected=expected, + ) + + def verify(self, **kwargs): + """Verify the correctness of the benchmarking results""" + + torch.testing.assert_close(self.extern.output_tensor, self.expected, **kwargs) + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. 
+ """ + + def __init__(self, code, replacement_hooks) -> None: + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize_hook(self, hook_key: str, strict=True) -> None: + if hook_key not in self.replacement_hooks: + if strict: + raise RuntimeError( + f"{hook_key} not registered in self.replacement_hooks" + ) + else: + return + assert self.replacement_hooks[hook_key] is not None, ( + "hook_key can only be called once" + ) + self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = None + + def finalize_all(self) -> str: + for key, fn in self.replacement_hooks.items(): + self.code = self.code.replace(key, fn()) + return self.code + + +# This is used to store info needed for lowering each subgraph in triton +# templates + + +@dataclasses.dataclass() +class SubgraphInfo: + body: IndentedBuffer + template_mask: Optional[str] = None + template_out: Optional[str] = None + compute: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + indexing_code: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + stores: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] + + # only copied over if not None + range_trees: Optional[list["IterationRangesRoot"]] = None + numels = None # type: ignore[var-annotated] + + def __post_init__(self): + self.only_copy_if_non_none_fields = ("range_trees", "numels") + + def to_dict(self): + return { + field.name: getattr(self, field.name) for field in dataclasses.fields(self) + } + + +class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] + """Handles placeholder substitutions during subgraph processing.""" + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: dict[str, Any], + mask: Optional[str], + ): + super().__init__(V.ops) + self.name = f"PlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed input.""" + if name not in self.fixed_inputs: + index_str = self._process_indexing(index) + var = self._add_kernel_input(name) + var_dtype = V.graph.get_buffer(name).dtype + line = f"tl.load({var} + {index_str})" + + if ( + var_dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + var_dtype = torch.float32 + + out = self.kernel.cse.generate(self.kernel.compute, line, dtype=var_dtype) + return out + + return self.kernel.cse.generate( + self.kernel.compute, f"({self.fixed_inputs[name]})", dtype=torch.float32 + ) + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> str: + """Currently only supports stores for atomic adds coming from scatter nodes + This is used by flex_attention's backwards grad for captured buffers, see + zeros_and_scatter lowering + """ + assert self.mask is not None, ( + "Mask is required for inner stores in modifications" + ) + assert mode == "atomic_add", "Only atomic_add is supported for inner stores" + + buf_name = self._add_kernel_input(name) + index_str = self._process_indexing(index) + 
index_str = f"tl.broadcast_to({index_str}, {value}.shape)" + store = f"tl.atomic_add({buf_name} + {index_str}, {value}, {self.mask}, sem='relaxed')" + return store + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + return self.kernel.args.input(name) + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + return self.kernel.kexpr(self.kernel.rename_indexing(index)) + + +# Function name, followed by args and kwargs. +RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]] + + +class TritonTemplateKernel(TritonKernel): + def __init__( + self, + kernel_name, + input_nodes, + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + num_consumer_groups=0, + num_buffers_warp_spec=0, + use_jit=False, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs: Optional[list[ir.ComputedBuffer]] = None, + workspace_arg: Optional[WorkspaceArg] = None, + prologue_loads_all_inputs=False, + ) -> None: + numel = sympy_product(output_node.get_size()) + super().__init__( + { + "x": numel, + "r0_": sympy.S.One, + }, + features=SIMDKernelFeatures([], numel), + ) + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} # type: ignore[var-annotated] + self.defines = defines + self.kernel_name = kernel_name + self.use_jit = use_jit + self.num_stages = num_stages + self.num_warps = num_warps + self.num_consumer_groups = num_consumer_groups + self.num_buffers_warp_spec = num_buffers_warp_spec + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + self.epilogue_fn = epilogue_fn + self.render_hooks = {} # type: ignore[var-annotated] + self.triton_meta: Optional[dict[str, object]] = None + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[list[ir.ComputedBuffer]] = subgraphs + + # Some templates use extra global memory as a workspace + self.workspace_arg = workspace_arg + if workspace_arg is not None: + self.args.workspace_args.append(workspace_arg) + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: dict[str, SubgraphInfo] = {} + + # input buffers which we are allowed to prologue fuse into + self.prologue_supported_inputs: OrderedSet[str] = OrderedSet() + + # input buffers which we are fusing into + self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() + # input buffers which we are fusing into, which preserve a zero mask + self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() + + # The following attributes are all used for triton kernel codegen. 
+ # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + # NB: the names here must match the fields in SubgraphInfo + self.body: IndentedBuffer = FakeIndentedBuffer() + self.compute: IndentedBuffer = FakeIndentedBuffer() + self.indexing_code: IndentedBuffer = FakeIndentedBuffer() + self.loads: IndentedBuffer = FakeIndentedBuffer() + self.stores: IndentedBuffer = FakeIndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out: Optional[str] = None + self.ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] + + # When caching is enabled, the generated code is not dependent on the input nodes names, or + # symbolic sizes names. + # However, some of the variables returned by generate_and_load that are computed during the + # triton template expansions (code generation) are dependent on those. + # In order to cache the code generation and avoid redoing it for similar inputs that varies only by + # input names or symbol names, we do a record and replay method. + # During template expansions we record all function calls that change input_dependent_preserved_state + # and replay them on a cache hit to regenerate them. + self.cached_replay_events: Optional[RecordedEventsType] = None + + # Update each time an input is marked frozen, used to replay the freezing of inputs on a cache hit. + self.frozen_layouts_cnt = 0 + + # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel + # by adding all inputs. + self.prologue_loads_all_inputs = prologue_loads_all_inputs + + def input_dependent_preserved_state(self) -> str: + # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit. + # (never accessed). + return repr( + [ + self.args.input_buffers, + self.args.sizevars, + self.args.workspace_args, + self.prologue_supported_inputs, + self.frozen_layouts_cnt, + ] + ) + + def record_input_dependent_tracked_event(self) -> Callable[..., Any]: + def decorator(fn) -> Callable[..., Any]: + def wrapper(*args, **kwargs) -> Any: + pre_state = self.input_dependent_preserved_state() + result = fn(*args, **kwargs) + post_state = self.input_dependent_preserved_state() + if pre_state != post_state: + assert self.cached_replay_events is not None + self.cached_replay_events.append((fn.__name__, [*args], {**kwargs})) + return result + + return wrapper + + return decorator + + def replay_cached_events(self, events: RecordedEventsType) -> None: + for f, args, kwargs in events: + getattr(self, f)(*args, **kwargs) + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + assert all( + hasattr(self, field.name) for field in dataclasses.fields(SubgraphInfo) + ) + old_state = { + key.name: getattr(self, key.name) + for key in dataclasses.fields(SubgraphInfo) + } + + assert body_name in self.subgraph_bodies, body_name + + subgraph = self.subgraph_bodies[body_name] + for key, value in subgraph.to_dict().items(): + if value is None and key in subgraph.only_copy_if_non_none_fields: + continue + setattr(self, key, value) + + context = ( + contextlib.nullcontext + if not self.ops_handler + else lambda: V.set_ops_handler(self.ops_handler(V.get_ops_handler())) + ) + with context(): # type: ignore[operator] + yield + self.subgraph_bodies[body_name] = SubgraphInfo( + **{ + key.name: getattr(self, key.name) + for key in dataclasses.fields(SubgraphInfo) + } + ) + for key, value in old_state.items(): + setattr(self, key, value) + + @contextlib.contextmanager + def 
create_subgraph_body(self, body_name: str): + assert body_name not in self.subgraph_bodies + self.subgraph_bodies[body_name] = SubgraphInfo( + IndentedBuffer(), + None, + None, + ) + with self.set_subgraph_body(body_name): + yield + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. + """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size()) + numel = functools.reduce(operator.mul, size, 1) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature, _ = self.args.python_argdefs() + triton_meta: dict[str, Any] = { + "signature": signature_to_meta( + signature, + size_dtype=self.index_dtype, + argdefs=argdefs, + is_template=True, + ), + "device": DeviceProperties.create(self.output_node.get_device()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None) + waves_per_eu = self.meta.get("waves_per_eu", None) + kpack = self.meta.get("kpack", None) + if matrix_instr_nonkdim: + triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim + if waves_per_eu: + triton_meta["waves_per_eu"] = waves_per_eu + if kpack: + triton_meta["kpack"] = kpack + + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + **TritonKernel.inductor_meta_common(), + **FixedGrid.setup_grid_as_args(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + template_args = f""" + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + """ + + if HAS_WARP_SPEC: + template_args += f""" + num_consumer_groups={self.num_consumer_groups}, + num_buffers_warp_spec={self.num_buffers_warp_spec}, + """ + + return f""" + @triton_heuristics.template( + {template_args} + ) + @triton.jit + """ + + def gen_argdefs(self): + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + return f"{', '.join(x.full_name() for x in arg_defs)}" + + self.render_hooks[""] = hook + return "" + + def gen_defines(self): + return self.defines + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. 
+ """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. + for name in argnames: + input_node = self.named_input_nodes[name] + if self.prologue_loads_all_inputs: + self.prologue_supported_inputs.add(input_node.get_name()) + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline( + f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):" + ) + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + return code.getvalue() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def size(self, name: str, index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index=None): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. 
+ """ + if name is None: + val = self.output_node.get_stride() + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_stride() + + if isinstance(index, int): + return texpr(self.rename_indexing(val[index])) + return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + + def _get_subgraph(self, subgraph_number: int): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len(self.subgraphs), ( + f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + ) + assert self.body.getvalue() == "", ( + "Body should be clear before adding a modification" + ) + return self.subgraphs[subgraph_number] + + def _handle_scatter_graph(self, scatter_graph): + """Handle processing for a single scatter graph. + + Args: + scatter_graph: The scatter graph to process + """ + assert isinstance(scatter_graph, ir.ComputedBuffer), ( + f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + ) + + def contiguous_strides(x): + # We always create a fresh contiguous grad for scattering into + return sum( + x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) + ) + + return scatter_graph.data.store_output( # type: ignore[attr-defined] + scatter_graph.name, contiguous_strides, [] + ) + + def modification( + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, + ) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for + + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs + output_name (Optional[str]): The name of the output variable to store the result in + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + """ + num = 0 + out = None + scatters = [] + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapper( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_ops_handler(modification_handler): + assert isinstance(subgraph, (ir.ComputedBuffer, list)), ( + f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + ) + # Handle scatter stores + if isinstance(subgraph, list): + for scatter_graph in subgraph: + scatters.append(self._handle_scatter_graph(scatter_graph)) + elif isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()(()) + else: + out = subgraph.data.inner_fn(()) + + self.codegen_body() + if output_name is not None: + assert isinstance(output_name, str) + assert out is not None + self.body.writeline(f"{output_name} = {out.value}") + else: + assert out is None + for scatter in scatters: + self.body.writeline(str(scatter)) + + body_val = self.body.getvalue() + self.cse.invalidate(OrderedSet()) + return body_val + + def load_input( + self, + input_name: str, + output_name: str, + indices: Union[list[Any], tuple[Any]], + mask: Optional[str] = None, + other: Optional[Union[float, int]] = 0.0, + indent_width: int = 4, + ): + """Loads an input and applies any necessary preprocessing or masking. + + Args: + input_name (str): The name of the input to load. + indices (Union[List, Tuple]): The index for each dimension of the input. 
+ val (str): The name of the variable to store the loaded value. + mask (Optional[str]): An optional mask to use for the load operation. + other (Optional[Union[float, int]]): The value to use for masked elements. Default is 0.0. + indent_width (int): The number of spaces to use for indentation. + """ + + input_node = self.named_input_nodes[input_name] + if not self.prologue_loads_all_inputs: + self.prologue_supported_inputs.add(input_node.get_name()) + + tilings = (sympy_product(input_node.get_size()), sympy.Integer(1)) + groups = { + "x": tilings[0], + "r0_": tilings[1], + } + + range_trees = self.construct_range_trees( + pid_cache=None, + inside_reduction=False, + is_reduction=False, + numels=groups, + no_x_dim=False, + ) + load_code = None + + with self.create_subgraph_body(f""): + assert isinstance(indices, (list, tuple)) + assert isinstance(output_name, str) + assert isinstance(mask, (str, type(None))) + self.range_trees = range_trees + self.numels = {k: V.graph.sizevars.simplify(v) for k, v in groups.items()} + indices = list(map(OpOverrides.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + + lengths = [V.graph.sizevars.simplify(s) for s in input_node.get_size()] + assert len(indices) == len(lengths) + + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + + # TODO (from reviewers as well) + # in codegen_template, + # prologue_node.codegen(kernel.split_and_set_ranges(prologue_node.get_ranges())) + # the ranges need to reflect the group of the prologue input or it will error + # not sure if there is any difference between original range_tree_entry in + # and new one from correct lengths/groups... both actually seem to work + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + + xindex_range_root = self.range_trees[0].lookup( + sympy.Integer(1), sympy_product(lengths) + ) + xindex_range_root.set_name("xindex") + + # Note - ["None" override_mask] + # MM Templates work by taking out of bounds index values and wrapping them around to 0 + # so that no mask is required on the load: offs_a_m = `rm % M` + # We should to override the mask to be "None" instead of inheriting the mask that would + # have been loaded otherwise. + # We are using "None" for clarity in output code, but + # we could alternatively emit `xmask = tl.full([xindex.shape], True, tl.int1)` + self.template_mask = mask if mask is not None else "None" + self.template_out = "xindex" + self.template_indices = indices + self.named_input_nodes[input_name].data.freeze_layout() + self.cse.invalidate(OrderedSet()) + + template_mask = self.template_mask + + class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined] + name = "StoreOutputSubstitution" + + def store( + self, + name: str, + index: sympy.Expr, + value: "CSEVariable", + mode: "StoreMode" = None, + ): + V.kernel.store_buffer_names.add(name) + V.kernel.cse.store_cache[name] = value + if name in V.kernel.prologue_fused_inputs: + # We load masked out values with 0, then apply a prologue. + # The masked out values may not necessariliy be 0 any more + # so we need to reapply the mask. 
+ value_dtype = value.dtype + value_str = str(value) + if template_mask != "None" and ( + name not in V.kernel.prologue_fused_inputs_preserve_zero + or other != 0 + ): + value_str = ( + f"tl.where({template_mask}, {value_str}, {other})" + ) + + if value_dtype != V.graph.get_buffer(name).dtype: + value_str = f"{value_str}.to({triton_type(V.graph.get_buffer(name).dtype)})" + + # TODO: we should have intermediary var shapes + V.kernel.compute.writeline( + f"{output_name} = {value_str}.broadcast_to(xindex.shape)" + ) + + self.ops_handler = StoreOutputSubstitution + + input_node = self.named_input_nodes[input_name] + output_index = input_node.make_indexer()(index_symbols) + + # in def_kernel above we define the inputs with the storage offset adjusted + # creating the load in input_node.make_indexer() will also adjust by storage offset + # so subtract here to not double increment + if not V.graph.sizevars.statically_known_equals( + input_node.layout.offset, 0 + ): + output_index = output_index - self.rename_indexing( + input_node.get_layout().offset + ) + + output_index = self.rename_indexing(output_index) + + if output_index == contiguous_index: + output_index_str = "xindex" + else: + out_indexing = self.indexing( + output_index, + copy_shape=self.template_out, + override_mask=self.template_mask, + ) + from .codegen.triton import IndexingOptions + + assert isinstance(out_indexing, IndexingOptions) + output_index_str = ( + f"({out_indexing.index_str}).broadcast_to(xindex.shape)" + ) + + # Generate load code + load_code = f"{output_name} = tl.load({input_name} + ({output_index_str})" + + if mask: + load_code += f", mask={mask}, other={other})" + else: + load_code += ")" + + hook_key = f"" + + def hook(): + with self.set_subgraph_body(hook_key): + self.cse.invalidate(OrderedSet()) + self.codegen_body() + self.cse.invalidate(OrderedSet()) + if input_node.get_name() not in self.prologue_fused_inputs: + assert load_code is not None + self.body.writeline(load_code) + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + assert hook_key not in self.render_hooks + self.render_hooks[hook_key] = hook + return hook_key + + def store_output( + self, + indices: Union[list[Any], tuple[Any]], + val: str, + mask: Optional[str] = None, + indent_width: int = 4, + ): + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. 
+ """ + with self.create_subgraph_body(""): + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, (str, type(None))) + assert self.template_mask is None + indices = list(map(OpOverrides.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + lengths = [ + V.graph.sizevars.simplify(s) for s in self.output_node.get_size() + ] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup(sympy.S.One, sympy_product(lengths)).set_name( + "xindex" + ) + self.template_mask = mask + self.template_out = val + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()(index_symbols) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex", integer=True) + + acc_dtype = ( + triton_type_to_torch(self.meta["ACC_TYPE"]) + if "ACC_TYPE" in self.meta + else torch.float32 + ) + epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_args.append(input_node.make_loader()(index_symbols)) + # We update frozen_layouts_cnt in order to replay this function on a cache hit. + self.frozen_layouts_cnt += 1 + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + ) + self.codegen_body() + + def hook(): + # more stuff might have been added since the codegen_body above + self.codegen_body() + self.cse.invalidate(OrderedSet()) + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render(self, template, kwargs, record_input_dependent_tracked_event=False): + if record_input_dependent_tracked_event: + self.cached_replay_events = [] + + template_env = { + fn.__name__: self.record_input_dependent_tracked_event()(fn) + if record_input_dependent_tracked_event + else fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.load_input, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + ] + } + return PartialRender( + template.render(**template_env, **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. 
+ """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.named_input_nodes[name].get_stride() + indices = list(map(OpOverrides.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask}, other=0.0)" + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. + """ + return super().indexing( + index, + dense_indexing=False, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out, + override_mask=self.template_mask, + block_ptr=block_ptr, + ) + + def codegen_range_tree(self): + pass # ignore default codegen + + def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + + grid_args = () + if isinstance(self.grid_fn, SymbolicGridFn): + grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta) + elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes): + grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta) + else: + assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn" + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + fn_name = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}" + call_args.append( + f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})" + ) + arg_types.append(None) + assert len(grid_args) in (0, 3), "grid_fn should return 3 values" + call_args.extend(grid_args) + arg_types.extend(map(type, grid_args)) + + if self.workspace_arg is not None: + wrapper.generate_workspace_allocation(self.workspace_arg) + wrapper.generate_kernel_call( + name, + call_args, + arg_types=arg_types, + triton_meta=self.triton_meta, + triton=True, + ) + if self.workspace_arg is not None: + wrapper.generate_workspace_deallocation(self.workspace_arg) + + def kernel_benchmark_extra_args(self) -> list[str]: + return [ + str(x) + for x in self.grid_fn( + *V.graph.sizevars.size_hints(self.call_sizes), self.meta + ) + ] + + +@functools.cache +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class GenerateAndLoadResult(NamedTuple): + """ + Return type of TritonTemplate.generate_and_load. + """ + + mod: ModuleType + extra: str + input_call_args: tuple[str, ...] + prologue_supported_inputs: OrderedSet[str] + kernel_args_sizevars_keys: tuple[sympy.Expr] + kernel_options: dict[str, Any] + + +class GeneratedCodeCacheEntry(NamedTuple): + code: str + extra: str + events: list[Any] + + +class GeneratedCodeCache: + """ + Cache for generated code. The cache key is a string representation of the input nodes, + number of stages, number of warps, and call sizes. The cache value is a tuple of the + generated code, extra code, and events. 
+ """ + + def __init__(self, *args, **kwargs): + self._cache: dict[str, GeneratedCodeCacheEntry] = {} + + def cache_clear(self) -> None: + self._cache.clear() + + def __repr__(self): + return repr(self._cache) + + def make_key( + self, + input_nodes: tuple[ir.IRNode], + num_stages: int, + num_warps: int, + call_sizes: list[sympy.core.symbol.Symbol], + prefix_args: int, + suffix_args: int, + epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], + subgraphs: Optional[list[ir.Buffer]], # has to be none to cache + workspace_arg: Optional[WorkspaceArg], # has to be none to cache + layout: ir.Layout, + num_consumer_groups: int, + num_buffers_warp_spec: int, + kwargs: dict[str, Any], + ) -> Optional[str]: + def layout_key(layout: ir.Layout) -> str: + assert not isinstance(layout, ir.FlexibleLayout) + return repr( + [ + layout.size, + layout.stride, + layout.dtype, + layout.device, + layout.offset, + ] + ) + + def has_flexible_layout() -> bool: + if isinstance(layout, ir.FlexibleLayout): + return True + + for input in input_nodes: + if isinstance(input.get_layout(), ir.FlexibleLayout): + return True + return False + + if epilogue_fn is identity: + assert epilogue_fn_hash is None + epilogue_fn_hash = "identity" + + # we do not cache under those conditions right now. + if ( + has_flexible_layout() + or subgraphs is not None + or workspace_arg is not None + or epilogue_fn_hash is None + ): + return None + + return repr( + { + "input_nodes": [ + layout_key(input.get_layout()) for input in input_nodes + ], + "num_stages": num_stages, + "num_warps": num_warps, + "prefix_args": prefix_args, + "suffix_args": suffix_args, + "call_sizes": call_sizes, + "layout": layout_key(layout), + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + "epilogue_fn_hash": epilogue_fn_hash, + "kwargs": kwargs, + } + ) + + def get_entry(self, cache_key: Optional[str]) -> Optional[GeneratedCodeCacheEntry]: + if cache_key is None: + return None + + entry = self._cache.get(cache_key, None) + if entry is None: + torch._dynamo.utils.counters["inductor"]["generated_module_cache_miss"] += 1 + else: + torch._dynamo.utils.counters["inductor"]["generated_module_cache_hit"] += 1 + return entry + + def put_entry( + self, + cache_key: Optional[str], + code: str, + extra: str, + events: list[Any], + ) -> None: + if cache_key is None: + return + entry = GeneratedCodeCacheEntry(code, extra, events) + self._cache.update({cache_key: entry}) + + +class TritonTemplate(KernelTemplate): + """ + A Triton template is a template that can be used to generate a Triton kernel. + """ + + # Allow subclasses to override the kernel type + kernel_type: type[Any] = TritonTemplateKernel + index_counter = itertools.count() + all_templates: dict[str, "TritonTemplate"] = {} + + def __init__( + self, + name: str, + grid: Any, + source: str, + debug=False, + cache_codegen_enabled_for_template=False, + prologue_loads_all_inputs=False, + ) -> None: + super().__init__(name) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + TritonTemplate.all_templates[name] = self + self.debug = debug + self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template + self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache() + clear_on_fresh_cache(self._generated_code_cache) + # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel + # by adding all inputs. 
+ self.prologue_loads_all_inputs = prologue_loads_all_inputs + + # When this flag is on, we ensure that the cached results and the generated result if cache + # was not used are the same. + test_cache = False + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choices.append(self.generate(generate_with_caching=True, **kwargs)) + return None + except NotImplementedError as e: + log.info( + "Cannot Append Choice: %s. KernelTemplate type is %s", + e, + type(self), + stack_info=log.getEffectiveLevel() < logging.INFO, + ) + return e + + # NOTE: MAKE SURE THAT ANY ARGUMENT ADDED TO THIS FUNCTION IS PROPERLY HANDLED IN _generated_code_cache.make_key. + def generate_and_load( + self, + input_nodes: tuple[ir.IRNode], + num_stages: int, + num_warps: int, + call_sizes: list[sympy.core.symbol.Symbol], + prefix_args: int, + suffix_args: int, + epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], + subgraphs: Optional[list[ir.Buffer]], + workspace_arg: Optional[WorkspaceArg], + num_consumer_groups: int, + num_buffers_warp_spec: int, + layout: ir.Layout, + kwargs: dict[str, Any], + generate_with_caching, + ) -> Optional[GenerateAndLoadResult]: + """Generate the python code and load it into the current process""" + caching_enabled = ( + generate_with_caching + and torch._inductor.config.enable_caching_generated_triton_templates + ) + + cache_key = None + if caching_enabled: + cache_key = self._generated_code_cache.make_key( + input_nodes, + num_stages, + num_warps, + call_sizes, + prefix_args, + suffix_args, + epilogue_fn, + epilogue_fn_hash, + subgraphs, + workspace_arg, + layout, + num_consumer_groups, + num_buffers_warp_spec, + kwargs, + ) + + assert self.template, "requires jinja2" + defines = StringIO() + + for name, val in kwargs.items(): + defines.write(f"{name} : tl.constexpr = {val}\n") + defines = defines.getvalue() + + fake_out = ir.Buffer(name="buf_out", layout=layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) + + kernel_options = { + "input_nodes": input_nodes, + "defines": defines, + "num_stages": num_stages, + "num_warps": num_warps, + "grid_fn": self.grid, + "meta": kwargs, + "call_sizes": call_sizes, + "prefix_args": prefix_args, + "suffix_args": suffix_args, + "epilogue_fn": epilogue_fn, + "subgraphs": subgraphs, + "prologue_loads_all_inputs": self.prologue_loads_all_inputs, + } + + if HAS_WARP_SPEC: + kernel_options.update( + { + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + } + ) + + def make_kernel(): + return self.kernel_type( + kernel_name=kernel_name, + output_node=fake_out, + workspace_arg=workspace_arg, + use_jit=False, + **kernel_options, + ) + + def generate_code(kernel) -> Optional[tuple[str, str]]: + def make_extra() -> str: + extra_parts = [ + f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys()) + ] + + extra_parts.extend( + [ + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + if 
HAS_WARP_SPEC: + extra_parts.extend( + [ + f"num_consumer_groups={num_consumer_groups}", + f"num_buffers_warp_spec={num_buffers_warp_spec}", + ] + ) + extra = "-".join(extra_parts) + "-" + return extra + + try: + template = kernel.render(self.template, kwargs, caching_enabled) + with kernel.set_subgraph_body(""): + code = template.finalize_all() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + + extra = make_extra() + return code, extra + + def maybe_test_cache(code: str, extra: str, kernel): + if self.test_cache or self.debug: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)), + V.graph.set_current_device(layout.device), + make_kernel() as kernel_test, + ): + result2 = generate_code(kernel_test) + assert result2 is not None + code_test, extra_test = result2 + assert ( + code == code_test + and extra == extra_test + and kernel.args.input_buffers == kernel_test.args.input_buffers + and kernel.prologue_supported_inputs + == kernel_test.prologue_supported_inputs + and kernel.args.sizevars == kernel_test.args.sizevars + ), "Generated code cache results in wrong output" + + # Generate code, extra. + code: Optional[str] = None + extra: Optional[str] = None + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)), + V.graph.set_current_device(layout.device), + make_kernel() as kernel, + ): + cache_entry = self._generated_code_cache.get_entry(cache_key) + cache_hit = False + + if cache_entry is not None: + code, extra, events = cache_entry + kernel.replay_cached_events(events) + cache_hit = True + + else: + result = generate_code(kernel) + if result is None: # happens at ZeroDivisionError: + return None + code, extra = result + self._generated_code_cache.put_entry( + cache_key, code, extra, kernel.cached_replay_events + ) + + assert code is not None and extra is not None + + mod = PyCodeCache.load(code, extra) + + input_call_args = tuple(kernel.args.input_buffers.keys()) + prologue_supported_inputs = kernel.prologue_supported_inputs.copy() + kernel_args_sizevars_keys = tuple(kernel.args.sizevars.keys()) + + if cache_hit: + maybe_test_cache(code, extra, kernel) + + return GenerateAndLoadResult( + mod, + extra, + input_call_args, + prologue_supported_inputs, + kernel_args_sizevars_keys, + kernel_options, + ) + + def generate( # type: ignore[override] + self, + input_nodes: tuple[ir.IRNode], + layout: ir.Layout, + num_stages: int, + num_warps: int, + num_consumer_groups: int = 0, + num_buffers_warp_spec: int = 0, + prefix_args: int = 0, + suffix_args: int = 0, + epilogue_fn: Optional[Callable[..., Any]] = identity, + epilogue_fn_hash: Optional[str] = None, + subgraphs: Optional[list[ir.Buffer]] = None, + mutated_inputs: Optional[list[ir.IRNode]] = None, + call_sizes: Optional[list[sympy.core.symbol.Symbol]] = None, + workspace_arg: Optional[WorkspaceArg] = None, + generate_with_caching=False, + **kwargs, + ): + """This function generates a TritonTemplateCaller + + Args: + input_nodes: List of input nodes + layout: Output layout + num_stages: Number of stages for triton launch + num_warps: Number of warps for triton launch + prefix_args: Number of input nodes to be passed as arguments + suffix_args: Number of input nodes to be passed as arguments + epilogue_fn: Optional epilogue function to be called on the output + subgraphs: Optional subgraphs to be passed as arguments, these will be inlined + into the triton template string + mutated_inputs: Optional list 
of input nodes that are mutated by the kernel, this is helpful + if you need to return multiple outputs. You can pass them as inputs and mark them as + being mutated by the kernel. + """ + # HACK: Triton currently breaks if TF32 floats are requested, but the CUDA + # capability doesn't support them. This is a bug in Triton, but for now we'll + # patch around it here. See https://github.com/triton-lang/triton/issues/3011 + # for one example issue with this problem. + if torch.cuda.is_available() and not torch.cuda.is_tf32_supported(): + kwargs["ALLOW_TF32"] = "False" + + if call_sizes is None: + call_sizes = layout.size + + result = self.generate_and_load( + input_nodes, + num_stages, + num_warps, + call_sizes, + prefix_args, + suffix_args, + epilogue_fn, + epilogue_fn_hash, + subgraphs, + workspace_arg, + num_consumer_groups, + num_buffers_warp_spec, + layout, + kwargs, + generate_with_caching and self._cache_codegen_enabled_for_template, + ) + + # May happen as result of dev by 0. + if result is None: + return None + + # We expect the input_buffer order to be [*input_nodes, *captured_buffers] + expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) + assert ( + result.input_call_args[: len(expected_input_args)] == expected_input_args + ), ( + result.input_call_args, + expected_input_args, + ) + + full_input_nodes = tuple( + [V.graph.get_buffer(k) for k in result.input_call_args] + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, result.kernel_args_sizevars_keys), + fallback=config.unbacked_symint_fallback, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + workspace_args = [] + if workspace_arg is not None: + # Create workspace tensor + workspace_size = workspace_arg.count + workspace_tensor = torch.empty_strided( + (workspace_size,), + (1,), + dtype=torch.uint8, + device=layout.device.type, + ) + + # Handle zero initialization if needed + if workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + workspace_tensor.zero_() + + workspace_args.append(workspace_tensor) + + options = result.kernel_options + + def make_kernel_render(out_node): + assert result is not None + kernel = self.kernel_type( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + workspace_arg=workspace_arg, + use_jit=False, + **options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert result.mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + call_sizes, + fallback=config.unbacked_symint_fallback, + ), + kwargs, + ) + bmreq_cls: type[TritonBenchmarkRequest] + if layout.device.type == "cpu": + bmreq_cls = TritonCPUBenchmarkRequest + else: + bmreq_cls = TritonGPUBenchmarkRequest + bmreq = bmreq_cls( + module_path=result.mod.__file__, + module_cache_key=result.mod.key, + kernel_name=f"triton_{self.name}", + extra_args=[*extra_args, *workspace_args, *grid], + num_stages=num_stages, + num_warps=num_warps, + num_consumer_groups=num_consumer_groups, + num_buffers_warp_spec=num_buffers_warp_spec, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + waves_per_eu=kwargs.get("waves_per_eu", 0), + kpack=kwargs.get("kpack", 2), + input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes), # type: ignore[arg-type] + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + full_input_nodes, + layout, + make_kernel_render, + result.extra.strip("-").replace("-", ", 
"), + bmreq, + log_info={ + "tile_shape": str( + ( + kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "GROUP_M": kwargs.get("GROUP_M", -1), + "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), + "acc_type": str(kwargs.get("ACC_TYPE", None)), + "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0), + "waves_per_eu": kwargs.get("waves_per_eu", 0), + "kpack": kwargs.get("kpack", 2), + }, + mutated_inputs=mutated_inputs, + workspace_arg=workspace_arg, + allowed_prologue_inps=result.prologue_supported_inputs, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + kernel_creator=None, + ) -> None: + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.cache # noqa: B019 + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: + parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + +class TritonTemplateCaller(ir.TritonTemplateCallerBase): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + description, + bmreq, + log_info: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ] = None, + mutated_inputs=None, + workspace_arg: Optional[WorkspaceArg] = None, + allowed_prologue_inps: Optional[OrderedSet[str]] = None, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.make_kernel_render = make_kernel_render + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + self.mutated_inputs = mutated_inputs + self.workspace_arg = workspace_arg + self.allowed_prologue_inps = ( + allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet() + ) + + def benchmark(self, *args, out): + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def precompile(self): + assert self.bmreq is not None + self.bmreq.precompile() + + def __str__(self) -> str: + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.description})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + 
def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + mutated_inputs=self.mutated_inputs, + allowed_prologue_inps=self.allowed_prologue_inps, + ) + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + def get_make_kernel_render(self): + return self.make_kernel_render + + def autoheuristic_id(self): + type_name = "triton" + info = self.info_dict() + # TODO(AlnisM): Does tile_shape always exist? + tile = info["tile_shape"] + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + num_stages = info["num_stages"] + num_warps = info["num_warps"] + return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}" + + +class ExternKernelCaller(ChoiceCaller): + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ) -> None: + super().__init__(choice.name, input_nodes, layout, description="") + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def __str__(self) -> str: + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + if out.numel() == 0: + # no need to run the kerrnel of do benchmarking + return 0.0 + if self.has_out_variant: + return super().benchmark(*args, out=out) + else: + algo = self.to_callable() + out_new = algo(*args) + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: algo(*args)) + return benchmarker.benchmark(algo, args, {}) + + def to_callable(self): + fn = self.choice.to_callable() + if self.kwargs: + return functools.partial(fn, **self.kwargs) + return fn + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if self.choice.use_fallback_kernel: + assert self.choice.op_overload is not None, ( + "Please provide an op_overload to use ir.FallbackKernel" + ) + inner = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "extern", + "kernel_call_name": self.choice.call_name(), + } + + def autoheuristic_id(self): + return f"extern_{self.choice.name}" + + +@functools.cache +def get_mm_log_filename() -> 
Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor) -> None: + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + + +class DataProcessorTemplateWrapper: + """ + A wrapper class for a kernel template. + + This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to + preprocess and postprocess data before and after using the wrapped template. A typical + usage is to reorder or filter the input nodes in order to match the expected input of other + kernel choices like an ATen kernel. A more complicated usage is to prepack the weights. + See the example from :mod:`cpp_gemm_template` for more details.
+ """ + + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ) -> None: + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None: + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class NoValidChoicesError(RuntimeError): + pass + + +@functools.cache +def get_num_workers() -> int: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + + # Divide the number of CPUs by the number of GPUs for distributed workloads + if ( + config.is_fbcode() + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + cpu_count = cpu_count // torch.cuda.device_count() + + return cpu_count + + +def create_inputs_key(input_nodes) -> str: + return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes]) + + +def create_precompile_key( + name: str, inputs_key: str, choices: list[ChoiceCaller] +) -> str: + return ":".join( + [ + name, + inputs_key, + torch.get_float32_matmul_precision(), + ] + + [choice.kernel_hash_key() for choice in choices] + ) + + +# Args to FeedbackFunctions +# timings: mapping from choices to the benchmark time +# name: name of the op +# input_nodes: list of input ir.py Nodes +# choices: list of choices +# profiled time: Callable that returns a dict mapping from choices to the profiled time +FeedbackFunction = Callable[ + [ + dict[ChoiceCaller, float], + str, + list[Any], + list[ChoiceCaller], + Callable[[], dict[ChoiceCaller, float]], + ], + None, +] + + +class AlgorithmSelectorCache(PersistentCache): + """ + A persistent cache for algorithm selection results used in autotuning of GEMMs + and convolutions. + + This classes includes precompilation and benchmarking of the kernels. + + The cache is keyed by input characteristics (sizes, strides, dtypes, etc.) but + doesn't depend on the output layout. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # the autotuning will get occur in the scheduler, so there is + # no guarantee that the first lowering for a given key will also be the + # first to benchmark it. 
share a single precompilation function for all lowerings + # of a particular key + self.precompile_cache: dict[str, Callable[[], None]] = {} + # list of callbacks that are called after benchmarking + self.feedback_saver_fns: list[FeedbackFunction] = [] + # cache for prescreening results to ensure deterministic candidate selection + self.prescreening_cache: dict[str, OrderedSet[str]] = {} + + clear_on_fresh_cache(self) + + def cache_clear(self) -> None: + self.precompile_cache.clear() + self.prescreening_cache.clear() + + def __call__( + self, + name, + choices: list[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + return_multi_template=False, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # Templates selected with input_gen_fns require specific input data to avoid IMA + # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection + # TODO(jgong5): support multi-template on CPU + if input_gen_fns is not None or layout.device.type == "cpu": + return_multi_template = False + + # TODO - assert that we have not mutating kernels here + + if config.test_configs.autotune_choice_name_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_name_regex, + c.name, + ) + ] + if config.test_configs.autotune_choice_desc_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_desc_regex, + c.description, + ) + ] + + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + + if len(choices) == 0: + backend_config = ( + "max_autotune_gemm_backends" + if name != "convolution" + else "max_autotune_conv_backends" + ) + raise NoValidChoicesError( + f"No choices to select, please consider adding ATEN into {backend_config} " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + @functools.cache + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + inputs_key = create_inputs_key(input_nodes) + + def autotune(choices): + log.debug("Starting autotuning") + + with dynamo_timed( + f"{name}_template_autotuning", + log_pt2_compile_event=True, + dynamo_compile_column_us="compile_time_autotune_time_us", + metadata={ + "autotune_strides": ", ".join( + [str(n.get_stride()) for n in input_nodes] + ), + "autotune_dtypes": ", ".join( + [str(n.get_dtype()) for n in input_nodes] + ), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join( + [str(n.get_layout().offset) for n in input_nodes] + ), + }, + ): + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + # Initialize the suprocess pool so it will warmup early. 
+ torch._inductor.autotune_process.get_tuning_process_pool() + + def do_autotuning(choices, precompile_fn): + precompile_start_ts = time.time() + with dynamo_timed( + f"{name}_template_precompiling", + log_pt2_compile_event=True, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + precompile_fn() + precompile_elapse = time.time() - precompile_start_ts + log.debug("Precompilation elapsed time: %.02fs", precompile_elapse) + + candidates = self.prescreen_choices( + choices, name, inputs_key, self.prescreening_cache + ) + prescreening_elapse: Optional[float] = None + if candidates: + prescreening_start_ts = time.time() + timings = self.lookup( + candidates, + name, + inputs_key, + autotune, + ) + choices = self.prune_choices_postscreen( + choices, timings, name, inputs_key, self.prescreening_cache + ) + prescreening_elapse = time.time() - prescreening_start_ts + log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + inputs_key, + autotune, + ) + + autotune_elapse = time.time() - autotune_start_ts + log.debug("Autotuning elapsed time: %.02fs", autotune_elapse) + + if timings and all( + not math.isfinite(timing) for timing in timings.values() + ): + raise NoValidChoicesError + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results( + name, + input_nodes, + timings, + autotune_elapse, + precompile_elapse, + prescreening_elapse, + ) + + def profiler_bench_function(): + # we're not running through the normal caching autotuner method here because we want to avoid returning + # the cached value. + # Avoid benchmarking in a separate process because it's not easy to signal to the TuningProcess that we + # should use the profiler. + with config.patch( + profile_bandwidth_with_do_bench_using_profiling=True, + autotune_in_subproc=False, + ): + return self.make_benchmark_fn( + choices, input_nodes, layout, input_gen_fns + )(choices) + + for feedback_fn in self.feedback_saver_fns: + # re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk. + feedback_fn( + timings, + name, + input_nodes, + choices, + profiler_bench_function, + ) + + return timings + + precompile_fn = self.make_precompile_fn( + choices, + name, + inputs_key, + precompilation_timeout_seconds=precompilation_timeout_seconds, + ) + + if return_multi_template and (config.max_autotune or config.max_autotune_gemm): + + def get_timings(): + timings = do_autotuning(choices, precompile_fn) + min_extern_choice = float("inf") + for choice, timing in timings.items(): + if isinstance(choice, ExternKernelCaller): + min_extern_choice = min(min_extern_choice, timing) + + timings = { + choice: time + for choice, time in timings.items() + if ( + time <= min_extern_choice + or not isinstance(choice, ExternKernelCaller) + ) + } + + return timings + + # We take the union of allowed prologue inputs from all choices, + # and, within benchmark fusion, don't allow prologue fusion for + # choices which dont support the whole union. 
+ allowed_prologue_inps: OrderedSet[str] = OrderedSet() + for c in choices: + if isinstance(c, TritonTemplateCaller): + allowed_prologue_inps |= c.allowed_prologue_inps + + return torch._inductor.ir.TensorBox.create( + torch._inductor.ir.MultiTemplateBuffer( + layout, + input_nodes, + get_timings, + choices, + allowed_prologue_inps, + ) + ) + + timings = do_autotuning(choices, precompile_fn) + + # if timings is empty, we really have no choice but to return a semi-random + # choice. returning the first `ExternKernelCaller` is probably the safest bet + # in this case, since it will generally be the ATen kernel. if there are no + # `ExternKernelCaller`s to return, then returning the 0th kernel is our next + # best option (ideally we'd fail whenever there is no ATen kernel to fallback + # to, but that's not trivial to figure out) + if timings == {}: + for choice in choices: + if isinstance(choice, ExternKernelCaller): + node = choice.output_node() + log.debug( + "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", + node, + ) + return node + node = choices[0].output_node() + log.debug( + "Autotuning returned empty timings, falling back to first choice: %s", + node, + ) + return node + + # if we got any timings at all, pick the best of those + choice = min(timings, key=timings.__getitem__) + node = choice.output_node() + log.debug("Autotuning selected choice: %s", node) + return node + + def make_precompile_fn( + self, + choices, + name: str, + inputs_key: str, + precompilation_timeout_seconds: Optional[int] = 60 * 60, + ) -> Callable[[], None]: + """ + Returns a function that precompiles the given choices. + """ + log.debug("Starting precompilation") + + def no_op(*args, **kwargs): + return + + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + log.debug("Precompilation timeout is None or <= 0, returning no_op") + return no_op + + num_workers = min(get_num_workers(), len(choices)) + + if num_workers <= 0: + return no_op + + # https://github.com/python/cpython/issues/106905 + if ( + sys.version_info.major == 3 + and sys.version_info.minor == 11 + and sys.version_info.micro <= 8 + ): + return no_op + + # check local and global cache before precompiling + timings = self.lookup( + choices, + name, + inputs_key, + benchmark=None, + ) + + if timings and len(timings) == len(choices): + # compilation in precompile stage is much cheaper than that in + # autotuning stage + log.debug("Found all %d timings in cache, returning no_op", len(timings)) + return no_op + + if config.search_autotune_cache and not ( + config.max_autotune or config.max_autotune_gemm + ): + return no_op + + precompile_key = create_precompile_key(name, inputs_key, choices) + if precompile_func := self.precompile_cache.get(precompile_key): + log.debug("Precompile function found in cache, returning it") + return precompile_func + + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + + # In rare circumstances, because python threads inherit global state, + # thread pool executor can race and leave stdout/stderr in a state + # different than the original values. we explicitly restore the state + # here to avoid this issue. 
+ + def precompile_with_captured_stdout(choice) -> tuple[None, int]: + log.debug("Precompiling choice with captured stdout: %s", choice) + start_ns = time.time_ns() + with restore_stdout_stderr(): + choice.precompile() + elapsed_ns = time.time_ns() - start_ns + # Return tuple as triton async compile (_worker_compile_triton) + # returns tuple[CachingAutotuner, int] + return None, elapsed_ns // 1000 + + def on_complete(future): + if not future.exception(): + _, precompile_elapsed_us = future.result() + elapsed_seconds = precompile_elapsed_us / 1e6 + elapsed_times[future] = elapsed_seconds + log.debug( + "Precompilation complete for future: %s, elapsed time: %.02fs", + future, + elapsed_seconds, + ) + + executor = ThreadPoolExecutor(max_workers=num_workers) + async_compile = torch._inductor.async_compile.AsyncCompile() + + futures: dict[concurrent.futures.Future[Any], ChoiceCaller] = {} + elapsed_times: dict[concurrent.futures.Future[Any], float] = {} + + # Some choices only differ in runtime arguments, so we + # skip a choice if it has the same hash as a previously seen choice + seen_choices: OrderedSet[str] = OrderedSet() + for c in choices: + # Skip choices which we have already issued a precompile + if c.kernel_hash_key() in seen_choices: + log.debug("Skipping already seen choice: %s", c) + continue + else: + seen_choices.add(c.kernel_hash_key()) + + if hasattr(c, "precompile"): + triton_cuda_choice = isinstance(c, TritonTemplateCaller) and isinstance( + c.bmreq, TritonGPUBenchmarkRequest + ) + if triton_cuda_choice and async_compile.use_process_pool(): + with open(c.bmreq.module_path) as file: + source_code = file.read() + future = async_compile.triton( + kernel_name=c.bmreq.kernel_name, source_code=source_code + ).future + log.debug("Submitted triton async compile for choice: %s", c) + else: + future = executor.submit(precompile_with_captured_stdout, c) + log.debug("Submitted precompile for choice: %s", c) + + future.add_done_callback(on_complete) + futures[future] = c + + @functools.cache + @restore_stdout_stderr() + def wait_on_futures(): + log.debug("Waiting on futures") + counters["inductor"]["select_algorithm_precompile"] += 1 + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + from torch._inductor.codegen.cuda.cuda_kernel import ( + CUDATemplateCaller, + ) + + if isinstance(e, CUDACompileError) and isinstance( + futures[future], CUDATemplateCaller + ): + log.debug( + "Exception %s for benchmark choice %s", + e, + futures[future], + exc_info=True, + ) + else: + log.error( + "Exception %s for benchmark choice %s", e, futures[future] + ) + else: + counters["inductor"]["select_algorithm_num_precompiles"] += 1 + log.info( + "Precompiling benchmark choice %s took %.02fs", + futures.get(future), + elapsed_times.get(future), + ) + + executor.shutdown(wait=True) + + self.precompile_cache[precompile_key] = wait_on_futures + + return wait_on_futures + + @classmethod + def get_inputs( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + ) -> AutotuneArgs: + """ + Factory method to create AutotuneArgs from a list of ChoiceCallers. 
+ """ + if input_gen_fns is None: + input_gen_fns = {} + + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + ( + unique_example_inputs[input_node.get_name()] + if unique_example_inputs[input_node.get_name()].is_mkldnn + else torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + ) + for input_node in input_nodes + ] + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + expected = None + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + return AutotuneArgs.from_choice_args( + example_inputs, + example_inputs_extern, + out, + out_extern, + expected, + ) + + @classmethod + def benchmark_choice( + cls, choice: ChoiceCaller, autotune_args: AutotuneArgs + ) -> float: + is_extern = isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller)) + benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern) + inpts, output = benchmark_tensors.unpack() + output.zero_() + result = choice.benchmark(*inpts, out=output) + device_type = next( + (tensor.device.type for tensor in inpts if is_gpu(tensor.device.type)), + "cuda", + ) + device_interface = get_interface_for_device(device_type) + if device_interface.is_available(): + device_interface.synchronize() # shake out any CUDA errors + + if VERIFY and autotune_args.expected is not None: + autotune_args.verify(**VERIFY) + return result + + @classmethod + def benchmark_choices( + cls, + choices: Sequence[ChoiceCaller], + autotune_args: AutotuneArgs, + ) -> dict[ChoiceCaller, float]: + timings = {} + for choice in choices: + try: + timing = cls.benchmark_choice(choice, autotune_args) + except CUDACompileError as e: + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + + if not isinstance(choice, CUDATemplateCaller): + log.error( + "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice.", + e, + ) + timing = float("inf") + except NotImplementedError as e: + log.warning("Not yet implemented: %s", e) + timing = float("inf") + except RuntimeError as e: + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + + if isinstance(choice, CUDATemplateCaller): + log.debug( + "Runtime error during autotuning: \n%s. \nIgnoring this choice.", + msg, + exc_info=True, + ) + else: + log.error( + "Runtime error during autotuning: \n%s. 
\nIgnoring this choice.", + msg, + ) + timing = float("inf") + except AssertionError as e: + raise AssertionError( # noqa: B904 + f"Incorrect result from choice {choice}\n\n{e}" + ) + except Exception as e: + try: + from triton.runtime.autotuner import OutOfResources + + if isinstance(e, OutOfResources): + log.warning(e) + timing = float("inf") + else: + raise e + except ImportError: + raise e from None + + timings[choice] = timing + + return timings + + @classmethod + def benchmark_in_current_process( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + ) -> dict[ChoiceCaller, float]: + inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns) + return cls.benchmark_choices(choices, inputs) + + @classmethod + def benchmark_in_sub_process( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + ): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. + extern = [c for c in choices if isinstance(c, ExternKernelCaller)] + triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] + + timings = cls.benchmark_in_current_process( + extern, input_nodes, layout, input_gen_fns + ) + timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type] + return timings + + @classmethod + def make_benchmark_fn( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + ): + if DEBUG: + print(f"{len(choices)} tuning requests:") + + if config.autotune_in_subproc: + return functools.partial( + cls.benchmark_in_sub_process, + input_nodes=input_nodes, + layout=layout, + input_gen_fns=input_gen_fns, + ) + else: + return functools.partial( + cls.benchmark_in_current_process, + input_nodes=input_nodes, + layout=layout, + input_gen_fns=input_gen_fns, + ) + + @staticmethod + def prescreen_choices( + choices: list[ChoiceCaller], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], + ) -> list[ChoiceCaller]: + """ + Figure out what choices need to be prescreened before autotuning with runtime + params. + + Prescreening is a process of reducing the number of autotuning for choices with + runtime params via a two stage autotuning process. First, we fix a set of runtime + params (here we use swizzle=2) and run autotuning to get a set of candidates. + Then, we run autotuning again with the candidates and the full set of runtime + params. + + Since have the concept of runtime params, we need to differentiate between + choice's hash_key and choice's kernel_hash_key. The former includes information + like runtime params, while the latter does not. prescreen_cache, if exists, stores + the set of hash_key that should win the prescreening. + + Right now, only CUTLASS choices have runtime params. 
+ """ + # Create a cache key for prescreening results + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached prescreening results (prescreen_winners) + if prescreen_key in prescreen_cache: + prescreen_winners = [ + choice + for choice in choices + if choice.hash_key() in prescreen_cache[prescreen_key] + ] + return prescreen_winners + + # prescreen cutlass + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + candidates = [] + if ( + config.cuda.cutlass_prescreening + and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1 + ): + candidates.extend( + [ + c + for c in choices + if isinstance(c, CUDATemplateCaller) + # hardcoded to only look at swizzle=2 + if c.info_dict().get("swizzle") == "2" + ] + ) + + # skip prescreening if the number of candidates is too small + if len(candidates) < 10: + return [] + + return candidates # type: ignore[return-value] + + @staticmethod + def prune_choices_postscreen( + choices: list[ChoiceCaller], + candidate_timings: dict[ChoiceCaller, float], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], + ) -> list[ChoiceCaller]: + """ + Prune the choices after prescreening. + """ + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached postscreen results + if prescreen_key in prescreen_cache: + # candidate_timings are from choices that have won prescreening already + winner_kernel_hashes = [ + candidate.kernel_hash_key() for candidate in candidate_timings + ] + + pruned_choices = [ + choice + for choice in choices + if not isinstance(choice, CUDATemplateCaller) + or choice.kernel_hash_key() in winner_kernel_hashes + ] + return pruned_choices + + log.debug("Before pruning using prescreening timings, %d choices", len(choices)) + sorted_candidates = sorted( + candidate_timings.keys(), key=lambda choice: candidate_timings[choice] + ) + + # Print prescreening timings + if ( + candidate_timings + and PRINT_AUTOTUNE + and config.autotune_num_choices_displayed != 0 + ): + n = config.autotune_num_choices_displayed + top_k = sorted_candidates[:n] + best = top_k[0] + best_time = candidate_timings[best] + + lines = ["PRESCREENING CANDIDATE TIMINGS"] + for choice in top_k: + result = candidate_timings[choice] + if result: + lines.append( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {choice.description}" + ) + else: + lines.append( + f" {choice.name} {result:.4f} ms " + ) + + log.info("\n".join(lines)) + num_to_keep = max(int(math.sqrt(len(choices)) / 4), 8) + + # prune choices based on prescreening timings + candidates_to_prune = OrderedSet( + candidate.kernel_hash_key() for candidate in sorted_candidates[num_to_keep:] + ) + winner_hashes: OrderedSet[str] = OrderedSet() + for candidate in sorted_candidates[:num_to_keep]: + if candidate_timings[candidate] == float("inf"): + candidates_to_prune.add(candidate.kernel_hash_key()) + else: + winner_hashes.add(candidate.hash_key()) + if isinstance(candidate, CUDATemplateCaller): + candidate.bmreq.ensure_dll_loaded() + + pruned_choices = [ + choice + for choice in choices + if choice.kernel_hash_key() not in candidates_to_prune # type: ignore[attr-defined] + ] + + # Cache the hash_key of winners of prescreening + prescreen_cache[prescreen_key] = winner_hashes + + log.debug( + "After pruning using prescreening timings, %d choices", len(pruned_choices) + ) + return pruned_choices + + @staticmethod + def log_results( + name: str, + input_nodes: list[ir.IRNode], + timings: dict[ChoiceCaller, 
float], + elapse: float, + precompile_elapse: float, + prescreening_elapse: Optional[float] = None, + ): + V.debug.log_autotuning_results( + name, input_nodes, timings, elapse, precompile_elapse + ) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ), + ) + ) + for n in input_nodes + ] + ) + + strides = ", ".join([str(n.get_stride()) for n in input_nodes]) + dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes]) + if config.autotune_num_choices_displayed == 0: + return + # when autotune_num_choices_displayed is None, [:None] means all + n = config.autotune_num_choices_displayed + top_k = sorted(timings, key=timings.__getitem__)[:n] + + best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = { + str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] + } + + append_to_log(mm_filename, out_dict) + + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + sys.stderr.write(f"strides: {strides}\n") + sys.stderr.write(f"dtypes: {dtypes}\n") + + for choice in top_k: + result = timings[choice] + if result: + kernel_description = choice.description + sys.stderr.write( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_description}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + prescreening_msg = ( + f" and {prescreening_elapse:.4f} seconds prescreening" + if prescreening_elapse is not None + else "" + ) + sys.stderr.write( + f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" + f" seconds precompiling for {len(timings)} choices" + + prescreening_msg + + "\n" + ) + + @staticmethod + def benchmark_example_value(node): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer(name="fake", layout=node) + # triton templates want the base tensor. + if isinstance(node, ir.BaseView): + node = node.unwrap_view() + + # Inplace padding may reinterpret a tensor to a larger tensor if the + # stride is large enough. The V.graph.get_allocation_size takes this into account. 
+ # So we need to call as_strided at the end to 'view' the tensor with the correct + # sizes/strides + return AlgorithmSelectorCache.generate_example_value( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + node.get_device(), + node.get_dtype(), + node.layout.offset, + V.graph.sizevars.size_hints( + V.graph.get_allocation_size(node), + fallback=config.unbacked_symint_fallback, + ), + ) + + @staticmethod + def generate_example_value( + size, stride, device, dtype, extra_size, allocation_size=None + ): + # preserve rng states to avoid the rand_strided call below changing + # the rng states for the real model code. + with preserve_rng_state(): + if allocation_size is None or allocation_size == size: + return rand_strided( + size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ) + else: + return rand_strided( + allocation_size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ).as_strided(size, stride) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. + """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + def add_feedback_saver(self, fn: FeedbackFunction): + self.feedback_saver_fns.append(fn) + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + + if "return_multi_template" not in kwargs: + kwargs["return_multi_template"] = ( + torch._inductor.config.benchmark_epilogue_fusion + ) + + if "precompilation_timeout_seconds" not in kwargs: + kwargs["precompilation_timeout_seconds"] = config.precompilation_timeout_seconds + + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) + + +def add_feedback_saver( + fn: FeedbackFunction, +): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +class SymbolicGridFn: + """ + Wrapper around a grid function that allows either int or sympy inputs.
+ + @SymbolicGridFn + def grid(x, meta, *, cdiv): + return cdiv(x, meta["BLOCK_X"]) + """ + + def __init__(self, fn: Callable[..., tuple[Any, Any, Any]]): + self.fn = fn + self.kwargs_int = {} + self.kwargs_sym = {} + params = inspect.signature(fn).parameters + for name, fn_sym, fn_int in [ + ("cdiv", CeilDiv, ceildiv), + ("min", sympy.Min, min), + ("max", sympy.Max, max), + ]: + if name in params: + self.kwargs_int[name] = fn_int + self.kwargs_sym[name] = fn_sym + + def __call__(self, *args, **kwargs) -> tuple[int, int, int]: + return self.fn(*args, **kwargs, **self.kwargs_int) + + def sympy_call(self, *args, **kwargs): + return self.fn(*args, **kwargs, **self.kwargs_sym) + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . import lowering # noqa: F401 diff --git a/phivenv/Lib/site-packages/torch/_inductor/sizevars.py b/phivenv/Lib/site-packages/torch/_inductor/sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..32cfdbd56ebcb65deea9a3983762f6044650dc4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/sizevars.py @@ -0,0 +1,976 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections.abc import Iterable, Sequence +from typing import Any, Callable, cast, Optional, Union + +import sympy +from sympy import Expr + +from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols, ShapeEnv +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges + +from .runtime.runtime_utils import is_power_of_2 +from .utils import ( + has_free_symbols, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_subs, + VarRanges, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + + +def statically_known_true( + shape_env: ShapeEnv, + expr: Union[sympy.Basic, bool], + axioms: Optional[tuple[sympy.Expr]] = None, + var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges[Any]]]] = None, +) -> bool: + if expr in (True, False): + return bool(expr) + + try: + simplified = shape_env._maybe_evaluate_static( + expr, + axioms=axioms, + var_to_range=var_to_range, + ) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr, exc_info=True) + + return False + + +# This class is a little awkward, because ShapeEnv is doing most of the heavy +# lifting and in some cases we should be directly passing through to ShapeEnv, +# but there is some extra inductor logic that needs to be handled here +class SizeVarAllocator: + def __init__(self, shape_env=None) -> None: + super().__init__() + if shape_env is None: + shape_env = ShapeEnv() + self.shape_env = shape_env + self.var_to_val = self.shape_env.var_to_val + self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements + self.unbacked_replacements: Optional[dict[Expr, Expr]] = None + # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. + # The basic idea is if we have some complicated sympy expression + # f(s0), we may choose to precompute it on the host and then replace + # all occurrences of that sympy expression with ps0, so that when we + # codegen we simply reference ps0 directly without repeating + # f(s0). 
Unlike regular size variables, ps variables cannot be + # guarded upon; so if we are asked to guard on a Sympy expression + # which potentially could have already had a precomputed replacement + # on it, we are obligated to invert the precomputed replacements + # (inv_precomputed_replacements). + self.precomputed_replacements: dict[Expr, sympy.Symbol] = {} + self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {} + self.stride_vars = self.make_stride_vars_cache() + self.simplify_with_ranges = self.make_simplify_with_ranges_cache() + self._simplify_loops = self.make_simplify_loops_cache() + + def simplify(self, expr: Expr): + return sympy.expand(expr).xreplace(self.replacements) + + def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]: + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: dict[tuple[Any, ...], Expr] = {} + replacement_count = len(self.replacements) + + def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr: + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (expr, *var_ranges.items()) + result = cache.get(key, None) + if result is None: + result = self._simplify_with_ranges(expr, var_ranges) + cache[key] = result + if result != expr: + cache[(result, *var_ranges.items())] = result + return result + + return simplify_with_ranges + + def make_simplify_loops_cache(self): + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: dict[tuple[Any, ...], Any] = {} + replacement_count = len(self.replacements) + + def simplify_loops(index_vars, sizes, index_formulas): + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (*index_vars, *sizes, *index_formulas) + result = cache.get(key, None) + if result is None: + result = self._simplify_loops_impl(index_vars, sizes, index_formulas) + cache[key] = result + return result + + return simplify_loops + + def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr: + """ + Simplify indexing expression with knowledge of the ranges of + iteration variables. 
+ """ + + expr = join_dimensions(self.simplify(expr)) + original_expr = expr + + var_to_range = dict(self.shape_env.var_to_range) + var_to_range.update( + { + k: ValueRanges( + 0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity() + ) + for k, v in var_ranges.items() + } + ) + for var in expr.free_symbols: + if var not in var_to_range: + var_to_range[var] = ValueRanges(0, IntInfinity()) + + var_to_range_tuple = cast( + tuple[tuple[sympy.Symbol, ValueRanges[sympy.Expr]]], + tuple(var_to_range.items()), + ) + + axioms = [] + for var, upper_bound in var_ranges.items(): + axioms.append(0 <= var) + axioms.append(var < upper_bound) + axioms = tuple(axioms) + self.shape_env.get_axioms() + + def statically_known(expr): + evaluated = self.shape_env._maybe_evaluate_static( + expr, + axioms=axioms, + var_to_range=var_to_range_tuple, + ) + return bool(evaluated) + + def remove_zero_terms(base, divisor): + """Symbols smaller than the divisor are zero""" + if not statically_known(base >= 0): + return base + + for v in base.free_symbols: + if v in var_ranges: + # var smaller than divisor can be removed + # if the rest is guaranteed to be multiple of divisor + rest = sympy.Wild("_rest", exclude=[v]) + m = base.match(v + rest) + if m and v not in m[rest].free_symbols: + gcd = sympy.gcd(m[rest], divisor) + if gcd == divisor: + if statically_known(v < divisor): + base = m[rest] + return base + + def visit_indexing_div(base, divisor): + return FloorDiv(remove_zero_terms(base, divisor), divisor) + + def visit_modular_indexing(base, divisor, modulus): + base = remove_zero_terms(base, divisor) + + can_remove_mod = statically_known(base >= 0) and statically_known( + base < modulus * divisor + ) + + if can_remove_mod: + return FloorDiv(base, divisor) + return ModularIndexing(base, divisor, modulus) + + if expr.has(ModularIndexing): + expr = expr.replace( + ModularIndexing( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), + ), + visit_modular_indexing, + ) + + if expr.has(FloorDiv): + expr = expr.replace( + FloorDiv( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + ), + visit_indexing_div, + ) + + if expr != original_expr: + return self._simplify_with_ranges(expr, var_ranges) + return expr + + def _simplify_loops_impl( + self, index_vars: list[sympy.Symbol], sizes, index_formulas + ): + """ + Try to remove as many axis from loop iterations as possible, by: + 1) removing size==1 dimensions + 2) fuse contiguous dimensions into a single loop + If channel_last = True, we will prevent the last dim fused with other dims + """ + sizes = list(map(self.simplify, sizes)) + + strides = [ + # index_formulas may contain boolean expressions (e.g. s0 < 10), + # for which "strides" don't make sense so we ignore them here. + # NOTE: These expressions may still block merging dims in the sound + # substitution test performed in can_merge_dims. 
+ ( + self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + ) + for x in index_formulas + ] + assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) + + for i in range(len(sizes)): + if sizes[i] == 1: + # remove dim + sizes[i] = None + + def can_merge_dims(a, b): + for k in range(len(strides)): + if self.simplify(strides[k][a] * sizes[a]) == self.simplify( + strides[k][b] + ): + # approximate test passed, try sound version + va = index_vars[a] + vb = index_vars[b] + m1 = sympy_index_symbol("_merge_tester1") + m2 = sympy_index_symbol("_merge_tester2") + # NOTE: can't sub vb=0 here in case va * vb appears in the expression, + # in which case both expr1 and expr2 would be zero! + expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2}) + expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)}) + if self.simplify(expr1) == self.simplify(expr2): + continue + return False + return True + + changed = True + while changed: + changed = False + for i, j in itertools.product( + reversed(range(len(sizes))), reversed(range(len(sizes))) + ): + if i == j or sizes[i] is None or sizes[j] is None: + continue + if can_merge_dims(i, j): + changed = True + sizes[i] = sizes[i] * sizes[j] + sizes[j] = None + + def reindex(index): + it = list(reversed(index)) + new_index = [] + for size in sizes: + if size is None: + new_index.append(sympy.S.Zero) + else: + new_index.append(it.pop()) + assert not it + return new_index + + def prune(index): + assert len(index) == len(sizes) + return [i for i, s in zip(index, sizes) if s is not None] + + return [x for x in sizes if x is not None], reindex, prune + + # Note - [On Statically Known] + # The statically_known_* family of functions below NEVER guard, they could return True if the + # asked questions can be answered without guarding otherwise they return False. + # Those are similar to statically_known_true in symbolic_shapes but operate on sympy + # expressions instead of symnodes. + def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool: + """ + Returns true if an expression is always true (symbolically or via guards), + false otherwise. Never add guards, or throw data dependent errors. + """ + return statically_known_true(self.shape_env, expr) + + def statically_known_equals( + self, left: Union[Expr, int], right: Union[Expr, int] + ) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. + """ + return len(left) == len(right) and all( + self.statically_known_equals(l, r) for l, r in zip(left, right) + ) + + def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.statically_known_true(expr) + + def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. 
+ """ + expr = left >= right + return self.statically_known_true(expr) + + def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. + """ + expr = left < right + return self.statically_known_true(expr) + + def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. + """ + expr = left > right + return self.statically_known_true(expr) + + def statically_known_multiple_of( + self, numerator: Expr, denominator: Union[Expr, int] + ) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + # The reason we skip unbacked here is that we want to avoid the cost of trying to eval this symbolically. + if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols( + denominator + ): + return False + expr = sympy.Eq(numerator % denominator, 0) + return self.statically_known_true(expr) # type: ignore[arg-type] + + def statically_known_power_of_2(self, expr: Expr) -> bool: + """ + Returns a bool indicating if x is known to be a power of 2. + """ + return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) + + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. + + def guard_equals(self, left: Expr, right: Expr) -> Expr: + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + + expr = sympy.Eq(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return left + + assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_equals") + return left + + def guard_leq(self, left: Expr, right: Expr) -> None: + return self.guard_lt(left, right + 1) + + def guard_lt(self, left: Expr, right: Expr) -> None: + expr = sympy.Lt(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return + + assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_lt") + + def guarded_order(self, seq): + """ + Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. + """ + seq = [*map(self.remove_precomputed_replacements, seq)] + seq = [ + (self.size_hint_or_throw(var), orig_idx, var) + for orig_idx, var in enumerate(seq) + ] + seq.sort() + order = [-1] * len(seq) + last_var = None + for new_index, (_, orig_index, var) in enumerate(seq): + order[orig_index] = new_index + if last_var is not None: + self.guard_leq(last_var, var) + last_var = var + return order + + # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy + # expressions instead of symnodes. see Note [guard_or_]. 
+ + def guard_or_false(self, left): + return self.evaluate_expr(left, fallback_value=False) + + def guard_or_true(self, left): + return self.evaluate_expr(left, fallback_value=True) + + # The evaluate functions evaluate some symbolic sympy expression + # (NB: not necessarily an Expr) and return what the concrete result + # is, guarding on the expression being that result + + # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b) + # as this will ensure that you actually have a sympy'ified expression, + # and will prevent you from incorrectly writing evaluate_expr(a == b) + # which does the wrong thing if a or b is a sympy expression + def evaluate_expr( + self, + left: Union[Expr, sympy.logic.boolalg.Boolean], + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + ) -> bool: + assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) + return self.shape_env.evaluate_expr( + sympy.sympify(left), + size_oblivious=size_oblivious, + fallback_value=fallback_value, + ) + + def evaluate_min(self, left: Expr, right: Expr) -> Expr: + """return the smaller of left and right, and guard on that choice""" + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + try: + lv = self.size_hint_or_throw(left) + rv = self.size_hint_or_throw(right) + except TypeError: # unbacked symints + if left == right or self.statically_known_leq(left, right): + return left + if self.statically_known_leq(right, left): + return right + gcd = sympy.gcd(left, right) + if left == gcd: # handle `min(10*u0, u0)` etc + return left + if right == gcd: + return right + raise TypeError( + f"evaluate_min({left}, {right}) with unbacked symints" + ) from None + if lv <= rv: + self.guard_leq(left, right) + return left + else: + self.guard_leq(right, left) + return right + + def evaluate_max(self, left: Expr, right: Expr) -> Expr: + """return the larger of left and right, and guard on that choice""" + # Always choose the opposite of eval min for consistency + # This means min(a, b) and max(a, b) produce the same guards + min_val = self.evaluate_min(left, right) + return right if min_val is left else left + + def evaluate_static_shape(self, left: Union[Expr, int]) -> int: + if isinstance(left, int): + return left + right = self.size_hint_or_throw(left) + self.guard_equals(left, sympy.Integer(right)) + return int(right) + + def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]: + return [self.evaluate_static_shape(x) for x in left] + + def remove_precomputed_replacements(self, expr: Expr) -> Expr: + if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + return expr + + def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]: + if isinstance(expr, int): + return expr + # Substitute all hints into expr, but leave unbacked symints alone + expr = self.simplify(expr) + if not isinstance(expr, Expr): + assert isinstance(expr, int) + return expr + free_symbols = expr.free_symbols + if not free_symbols: + try: + return int(expr) # type: ignore[return-value] + except TypeError: + return expr # inf/nan/I + expr = self.remove_precomputed_replacements(expr) + return sympy_subs(expr, self.var_to_val) + + def size_hint( + self, expr: Union[Expr, int], 
*, fallback: Optional[int] = None + ) -> int: + out = self.symbolic_hint(expr) + if not isinstance(out, (int, sympy.Integer)) and fallback is not None: + # Use the provided heuristic fallback hint + unbacked_sym_vrs = { + s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols + } + if all(vr is not None for vr in unbacked_sym_vrs.values()): + hint_vr = bound_sympy(out, unbacked_sym_vrs) # type: ignore[arg-type] + if isinstance(hint_vr.lower, (int, sympy.Integer)): + fallback = max(fallback, int(hint_vr.lower)) + if isinstance(hint_vr.upper, (int, sympy.Integer)): + fallback = min(fallback, int(hint_vr.upper)) + return fallback + + try: + return int(out) + except Exception: + log.debug("failed on: %s", out) + raise + + def size_hint_or_throw(self, expr: Union[Expr, int]) -> int: + out = self.symbolic_hint(expr) + try: + return int(out) + except Exception: + log.debug("failed on: %s", out, exc_info=True) + raise + + def size_hints( + self, + exprs: Iterable[Union[Expr, int]], + *, + fallback: Optional[int] = None, + ) -> tuple[int, ...]: + return tuple(self.size_hint(x, fallback=fallback) for x in exprs) + + def _lru_cache(self, fn, maxsize=None): + """ + Wrapper around functools.lru_cache that clears when replacements + has been invalidated. + """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> list[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + ) -> list[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.S.Zero for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.S.Zero + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.S.Zero) + else: + # TODO(jansel): should we use sympy.diff here? + strides.append( + sympy_subs(index_dim, {v: sympy.S.One}) + - sympy_subs(index_dim, {v: sympy.S.Zero}) + ) + return strides + + def _get_unbacked_replacements(self) -> dict[Expr, Expr]: + """ + This helps with covering unbacked symint cases where you may have two + expressions: s0 + u0 and u1. And s0 + u0 is known to be equal to u1 + via deferred_runtime_asserts. + + For example in atomically_apply_size_hint, it must return the same size + hint for both s0 + u0 and u1, but it first needs to know they are equal. + Then it can substitute s0 + u0 for u1. 
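+
+        Sketch of the mapping built for that example (which expression ends up as
+        the key depends on sympy's canonical ordering of the two sides):
+
+            {s0 + u0: u1}   or   {u1: s0 + u0}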
+ """ + if self.unbacked_replacements is not None: + return self.unbacked_replacements + + self.unbacked_replacements = {} + for assertions in self.shape_env.deferred_runtime_asserts.values(): + for assertion in assertions: + if not isinstance(assertion.expr, sympy.Equality): + continue + + lhs, rhs = assertion.expr.lhs, assertion.expr.rhs + l2r = lhs.compare(rhs) == 1 # see sympy.Basic.compare + src = lhs if l2r else rhs + dst = rhs if l2r else lhs + + existing_replacement = self.unbacked_replacements.get(src, None) + if existing_replacement and isinstance( + existing_replacement, sympy.Symbol + ): + # Prefer to keep replacements with symbols. + continue + self.unbacked_replacements[src] = dst + return self.unbacked_replacements + + @functools.lru_cache # noqa: B019 + def _sub_unbacked_exprs(self, expr: Expr) -> Expr: + # it's fine to cache this fn since self is a singleton + replacements = self._get_unbacked_replacements() + while True: + new_expr = expr.subs(replacements) + if new_expr == expr: + return new_expr + expr = sympy.factor(new_expr) + + def atomically_apply_size_hint( + self, expr: Union[Expr, int], *, fallback: Optional[int] = None + ) -> Union[Expr, int]: + if isinstance(expr, (int, sympy.Integer)): + return int(expr) + + if has_free_unbacked_symbols(expr): + # Make sure to substitute with the factored version + # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 + expr = self._sub_unbacked_exprs(sympy.factor(expr)) + + # For multiple expressions that depend on an unbacked symint, + # we want to compute them consistently for a size hint we have chosen. + # So, recursively compute expressions via size hints of contained symbols. + # For example: u1 * u2 - 10 ==> fallback * fallback - 10 + assert isinstance(expr, Expr), type(expr) + free_symbols = expr.free_symbols + size_dict = { + symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback) + for symbol in free_symbols + } + return expr.subs(size_dict) + + def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> list[int]: + for v in index.free_symbols: + if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint_or_throw(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol_with_prefix( + SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements) + ) + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet(self.var_to_val.keys()) - 
OrderedSet(self.replacements.keys()) + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModularIndexing(x, 1, b), if + 1. x is a non-negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should contain a nested ModularIndexing + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should contain a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplified. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' cannot. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! + + Return False if this optimization cannot be applied; + Return the new expression and the denominator otherwise. + The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, term like '2*s1*x1' has 3 child nodes. + # - An integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correctness of the transformation + # for non-negative integers.
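+                # Worked example of why the rewrite preserves the value: for an
+                # index x1 * b + x2 // 2 with non-negative x1, x2, the expanded form
+                # is (2 * b * x1 + x2) // 2, and because 2 * b * x1 is exactly
+                # divisible by 2 this equals b * x1 + x2 // 2, the original.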
+ if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # can not handle multi FloorDiv yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.S.Zero + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + + +def join_dimensions(expr: Expr) -> Expr: + if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): + return expr # fast exit path + return _join_dimensions_cached(expr) + + +@functools.lru_cache(256) +def _join_dimensions_cached(expr: Expr) -> Expr: + """ + ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4) + becomes + ModularIndexing(i0, 1, 128) + ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32) + becomes i0 + + + This type of pattern can come from view operations + """ + assert isinstance(expr, sympy.Add) + + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to + simplify ModularIndexing/FloorDiv. 
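+
+    For example (illustrative), when the variable ranges guarantee 0 <= x < 8,
+    an index containing ModularIndexing(x, 1, 8) can be rewritten to use x
+    directly and FloorDiv(x, 8) can be rewritten to 0.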
+ """ + + def __init__(self, inner, var_ranges: VarRanges) -> None: + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[[Expr], Expr] = ( + lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + ) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) + + def check_bounds(self, index, size, lower, upper): + return self._inner.check_bounds(self._simplify(index), size, lower, upper) diff --git a/phivenv/Lib/site-packages/torch/_inductor/standalone_compile.py b/phivenv/Lib/site-packages/torch/_inductor/standalone_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f811b04ba3b5eaccffe438a3347cb88580c32a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/standalone_compile.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import copy +import logging +import os +import pickle +import shutil +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING + +import torch.fx +from torch._dynamo.utils import dynamo_timed +from torch._inductor.cudagraph_utils import BoxedDeviceIndex +from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir +from torch._inductor.utils import BoxedBool, InputType +from torch._subclasses import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +from . import config + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.compiler._cache import CacheInfo + from torch.fx import GraphModule + + +log = logging.getLogger(__name__) + + +class CompiledArtifact: + """ + CompiledArtifact class represents the precompiled inductor artifact that + can be invoked in order to avoid repeated compilation. + + CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs) + to create a fresh CompiledArtifact from a GraphModule and example inputs. + + Later this CompiledArtifact can be saved to disk, either as a binary or unpacked + into the provided folder via the CompiledArtifact.save function. + + CompiledArtifact.load provides a way to create a CompiledArtifact from the + binary or unpacked data. + + Finally, the CompiledArtifact can be invoked via the __call__ method + to execute the precompiled artifact. 
+ """ + + _compiled_fn: Callable[..., Any] + _artifacts: Optional[tuple[bytes, CacheInfo]] + + def __init__( + self, + compiled_fn: Callable[..., Any], + artifacts: Optional[tuple[bytes, CacheInfo]], + ): + self._compiled_fn = compiled_fn + self._artifacts = artifacts + + def __call__(self, *args: Any) -> Any: + return self._compiled_fn(*args) + + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: + with dynamo_timed("CompiledArtifact.save"): + if self._artifacts is None: + raise RuntimeError( + "CompiledArtifact.save failed to save since there's no artifact to save" + ) + artifact_bytes, cache_info = self._artifacts + assert len(cache_info.aot_autograd_artifacts) == 1, cache_info + key = cache_info.aot_autograd_artifacts[0] + + if format == "binary": + # can't assert that it is a file since it might not exist yet + assert not os.path.isdir(path) + + from torch.utils._appending_byte_serializer import BytesWriter + + from .codecache import torch_key + + writer = BytesWriter() + writer.write_bytes(torch_key()) + writer.write_str(key) + writer.write_bytes(artifact_bytes) + with open(path, "wb") as file: + file.write(writer.to_bytes()) + else: + assert format == "unpacked" + if os.path.exists(path): + assert os.path.isdir(path) + shutil.rmtree(path, ignore_errors=True) + + from .codecache import FxGraphCache + + with temporary_cache_dir(path): + # This function unpacks the cache artifacts to disk + loaded_cache_info = torch.compiler.load_cache_artifacts( + artifact_bytes + ) + assert loaded_cache_info is not None + # Now write all the output_code artifacts to disk so that + # they can be inspected and modified + for key in loaded_cache_info.inductor_artifacts: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + assert os.path.exists(subdir) + for path in sorted(os.listdir(subdir)): + with open(os.path.join(subdir, path), "rb") as f: + graph = pickle.load(f) + output_file = graph.write_to_disk() + log.info("Output code written to: %s", output_file) + + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + with dynamo_timed("CompiledArtifact.load"): + if format == "binary": + # can't assert that it is a file since it might not exist yet + assert not os.path.isdir(path) + with open(path, "rb") as file: + artifacts = file.read() + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + reader = BytesReader(artifacts) + assert reader.read_bytes() == torch_key() + key = reader.read_str() + artifact_bytes = reader.read_bytes() + assert reader.is_finished() + + torch.compiler.load_cache_artifacts(artifact_bytes) + + cache_dir_ctx: AbstractContextManager[None] = nullcontext() + else: + assert format == "unpacked" + assert os.path.isdir(path) + autograd_cache_dir = os.path.join(path, "aotautograd") + assert os.path.isdir(autograd_cache_dir) + files = list(os.listdir(autograd_cache_dir)) + assert len(files) == 1 + key = files[0] + cache_dir_ctx = temporary_cache_dir(path) + + with ( + cache_dir_ctx, + config.patch(unsafe_skip_cache_dynamic_shape_guards=True), + ): + with torch._functorch.config.patch(strict_autograd_cache=True): + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache, + ) + + entry = AOTAutogradCache._lookup( + key, + local=True, + remote=False, + args=[], + cache_info={}, + aot_config=None, + ) + + assert entry is not None + + from .compile_fx import _CompileFxKwargs + + fx_config = _CompileFxKwargs( + 
cudagraphs=BoxedBool(False), + boxed_forward_device_index=BoxedDeviceIndex(0), + ) + + context = torch._guards.TracingContext( + FakeTensorMode(shape_env=ShapeEnv()) + ) + with torch._guards.tracing(context): + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, fx_config + ) + return CompiledArtifact(lambda *args: compiled_fn(list(args)), None) + + +def standalone_compile( + gm: GraphModule, + example_inputs: Sequence[InputType], + *, + dynamic_shapes: Any, + options: Any, +) -> CompiledArtifact: + from torch.compiler._cache import CacheArtifactManager + + from .compile_fx import compile_fx + + ignore_shape_env = False + if dynamic_shapes == "from_example_inputs": + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + # tells compile_fx to ignore the shape_envs on the ambient context + # and the graph_module. + ignore_shape_env = True + elif dynamic_shapes == "from_tracing_context": + # Reuse fake_mode from the TracingContext. + # NB: The TracingContext only exists if we're currently in a torch.compile backend. + context = torch._guards.TracingContext.get() + fake_mode = context.fake_mode + elif dynamic_shapes == "from_graph": + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode. + # The graph passed to standalone_compile must be an Inductor-approved graph, + # which means that there is at least one Tensor output and the output node + # contains a flat list of Tensors. + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + assert len(last_node.args) == 1 + for node in last_node.args[0]: + if "example_value" in node.meta: + maybe_tensor = node.meta["example_value"] + if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor): + fake_mode = maybe_tensor.fake_mode + else: + raise ValueError( + f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}." + ) + + context = torch._guards.TracingContext(fake_mode) + with ( + torch._guards.tracing(context), + CacheArtifactManager.with_fresh_cache(), + config.patch("triton.autotune_at_compile_time", True), + ): + # compile_fx can mutate gm + gm = copy.deepcopy(gm) + compiled_fn = compile_fx( + gm, example_inputs, ignore_shape_env=ignore_shape_env, **options + ) + assert callable(compiled_fn) + + artifacts = torch.compiler.save_cache_artifacts() + if artifacts is None: + log.warning( + "standalone_compile artifact generation failed, cannot save. " + "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem" + ) + + return CompiledArtifact(compiled_fn, artifacts) diff --git a/phivenv/Lib/site-packages/torch/_inductor/subgraph_lowering.py b/phivenv/Lib/site-packages/torch/_inductor/subgraph_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..fe830b233132f5363cc40d064bc6e6055a09581e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/subgraph_lowering.py @@ -0,0 +1,209 @@ +"""Utilities for lowering subgraphs used by higher order operators""" + +import functools +import operator +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch.utils._ordered_set import OrderedSet + +from . 
import ir +from .exc import SubgraphLoweringException +from .ops_handler import SimpleCSEHandler +from .virtualized import ops, V, WrapperHandler + + +T = TypeVar("T") +_P = ParamSpec("_P") + +OpOverload = torch._ops.OpOverload +LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]] +TargetType = Union[Callable[..., Any], str] + + +class PointwiseSubgraphLowering(torch.fx.Interpreter): + """ + Lowers a pointwise subgraph to a single set of buffers with a separate + lowering object. Errors if buffers are created unexpectedly + """ + + graph_outputs: Optional[list[ir.IRNode]] + root_graph: torch._inductor.graph.GraphLowering + _current_op: Optional[TargetType] + # For backwards of buffer_grads with scatters we allow mutations + allowed_mutations: Optional[OrderedSet[OpOverload]] + additional_lowerings: Optional[LoweringDict] + buffers: list[ir.Buffer] + mutated_buffers: OrderedSet[str] + + def __init__( + self, + gm: torch.fx.GraphModule, + root_graph_lowering: torch._inductor.graph.GraphLowering, + allowed_mutations: Optional[OrderedSet[OpOverload]] = None, + additional_lowerings: Optional[LoweringDict] = None, + ) -> None: + super().__init__(gm) + self.graph_outputs = None + self.root_graph = root_graph_lowering + self.allowed_mutations = allowed_mutations + self.additional_lowerings = additional_lowerings + self._current_op = None + + # Used to track buffers created during lowering + self.mutated_buffers = OrderedSet() + self.buffers = [] + + @contextmanager + def _op_context(self, op: TargetType) -> Generator[None, None, None]: + """Set which op is being processed in call function to know if we can mutate buffers""" + previous = self._current_op + self._current_op = op + try: + yield + finally: + self._current_op = previous + + def _approved_mutator(self) -> bool: + return ( + self.allowed_mutations is not None + and self._current_op in self.allowed_mutations + ) + + def mark_buffer_mutated(self, name: str) -> None: + if self._approved_mutator(): + self.mutated_buffers.add(name) + else: + raise SubgraphLoweringException( + f"Buffer mutation detected during lowering of {self._current_op}. " + "Buffer mutations are only allowed in approved mutation ops. " + "This is an error in the lowering of the subgraph, please file a bug report." + ) + + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + if self._approved_mutator(): + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) + return name + else: + raise SubgraphLoweringException( + "Buffers cannot be created while lowering a pointwise subgraph. " + "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " + "but it could also be a bug. Please file a bug report if you think this should be supportable." 
+ ) + + def __getattr__(self, name: str) -> Any: + return getattr(self.root_graph, name) + + def call_function( + self, + target: TargetType, + args: Any, + kwargs: dict[str, Any], + ) -> Any: + from .lowering import lowerings + + with self._op_context(target): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + # These takes precedence over the main lowerings + if self.additional_lowerings is not None: + if target in self.additional_lowerings: + assert isinstance(target, OpOverload) + return self.additional_lowerings[target](*args, **kwargs) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) + + return lowerings[target](*args, **kwargs) + + def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override] + assert len(args) == 1 + self.graph_outputs = args[0] + + +@dataclass +class InputDescriptor: + dtype: torch.dtype + device: torch.device + + +class TracingOpsHandler(WrapperHandler): + def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: + parent = tracer.create_proxy("placeholder", "ops", (), {}) + super().__init__(parent) + self.tracer = tracer + + self.placeholders = [ + self.tracer.create_proxy("placeholder", f"input{i}", (), {}) + for i in range(num_inputs) + ] + + def placeholder(self, idx: int) -> torch.fx.Proxy: + return self.placeholders[idx] + + def output(self, *args: tuple[object]) -> None: + self.tracer.create_node( + "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} + ) + + +def lower_pointwise_subgraph( + subgraph: ir.Subgraph, inputs: list[InputDescriptor] +) -> Callable[_P, Any]: + # Lower subgraph to ir.Pointwise nodes + def fake_inner_fn( + loop_idx: int, input_idx: int + ) -> Union[ir.Expr, ir.TensorBox, None]: + return ops.placeholder(input_idx) + + graph_inputs = [ + ir.Pointwise.create( + device=desc.device, + dtype=desc.dtype, + inner_fn=functools.partial(fake_inner_fn, input_idx=i), + ranges=[], + ) + for i, desc in enumerate(inputs) + ] + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*graph_inputs) + + # Combine multiple pointwise computations into a single graph module + # Do this by tracing through each individually and doing CSE + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs))) + assert pw_subgraph.graph_outputs is not None + + with V.set_ops_handler(trace_ops): + output_irs = [] + + for out_var in pw_subgraph.graph_outputs: + assert isinstance(out_var, ir.TensorBox), type(out_var) + assert out_var.get_size() == [] + assert isinstance(out_var.data, ir.StorageBox) + assert isinstance(out_var.data.data, ir.Pointwise) + + idx = () + ir_out = out_var.data.data.inner_fn(idx) + + output_irs.append(ir_out) + + ops.output(*output_irs) + + lowered_gm = torch.fx.GraphModule({}, tracer.graph) + + def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any: + return lowered_gm(V.get_ops_handler(), *args, **kwargs) + + return inner_fn diff --git a/phivenv/Lib/site-packages/torch/_inductor/template_heuristics.py b/phivenv/Lib/site-packages/torch/_inductor/template_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..68b94a5214694771922cb2a04675eab3dd7e5f59 
--- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/template_heuristics.py @@ -0,0 +1,1180 @@ +from __future__ import annotations + +import dataclasses +import itertools +import math +from functools import partial +from threading import Lock +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.utils._ordered_set import OrderedSet + +from . import config +from .utils import get_backend_num_stages +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator + + from triton import Config as TritonConfig + + +# Gemm Configs +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = 8 + + +ConvConfig = BaseConfig + + +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward, backward and flex decode will use this + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): + """ + Thread-safe implementation of single to be used in the config heuristic subclasses + to ensure heavy __init__ calls are not repeatedly run + """ + + _instances: dict[type[Any], Any] = {} + _lock: Lock = Lock() + + def __call__( + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any + ) -> BaseConfigHeuristic: + with cls._lock: + if cls not in cls._instances: + instance = super().__call__() + cls._instances[cls] = instance + return cls._instances[cls] + + +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): + """ + Base class for mm_configs, device specific triton kernels config inherit from here + """ + + def __init__(self) -> None: + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform. 
The configs are as follows: + # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + for group_m in [8] + ] + + # these are only used in tuned_mm when AutoHeuristic is enabled + # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned + # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 + # which saves compilation time (since less configs are autotuned) and potentially increase performance + # because the learned heuristic might predict a config that is not part mm_configs + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), + ] + + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), + ] + + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + 
GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), + ] + + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), + ] + + # TODO: Unify with other gemm patterns, mm_plus_mm currently follows + # slightly different pattern than rest + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), + ] + + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + 
ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), + ] + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, num_stages, num_warps) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": num_warps, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def _scale_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + scale: float, + has_int8_tensor: bool, + exclude: Callable[[int, int, int], bool], + ) -> list[BaseConfig]: + """ + Scales and filters matrix multiplication configs based on input size. 
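+
+        For instance, with a hypothetical scale of 0.5 a GemmConfig(128, 128, 64, ...)
+        is shrunk towards (64, 64, 32, ...); each block size is then clamped from
+        above by the power-of-2-rounded problem size and from below by the minimum
+        block size, and the config is dropped entirely if `exclude` rejects it.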
+ """ + from .runtime.runtime_utils import next_power_of_2 + + min_block_size = 16 + min_block_size_k = 32 if has_int8_tensor else 16 + + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size_k, + ) + + scaled_configs = [] + for c in configs: + scaled_config = dataclasses.replace( + c, + block_m=max(min(int(c.block_m * scale), m), min_block_size), + block_n=max(min(int(c.block_n * scale), n), min_block_size), + block_k=max(min(int(c.block_k * scale), k), min_block_size_k), + ) + + if not exclude( + scaled_config.block_m, scaled_config.block_n, scaled_config.block_k + ): + scaled_configs.append(scaled_config) + + return scaled_configs + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + import torch + + pruned_configs = [] + for gemm_config in configs: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + sm_available = props.shared_memory_per_block_optin # type: ignore[attr-defined] + NUM_REG = 255 + + acc_regs = math.ceil( + gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32) + ) + + shared_mem_accum = dtype_size * ( + gemm_config.block_m * gemm_config.block_k + + gemm_config.block_n * gemm_config.block_k + ) + + # Will use more shared memory than available + if shared_mem_accum * gemm_config.num_stages > sm_available: + continue + # Lower bound for register spillage, if exceeds the kernel will certainly spill + elif acc_regs > NUM_REG: + continue + + pruned_configs.append(gemm_config) + + return pruned_configs + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: int = 1, + exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, + dtype_size: int = 0, + ) -> Generator[TritonConfig, None, None]: + scaled_configs = self._scale_mm_configs( + m, n, k, configs, scale, has_int8_tensor, exclude + ) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + assert dtype_size > 0, "dtype_size must be provided for exhaustive search" + scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) + return self._finalize_mm_configs(scaled_configs) + + def triton_config( + self, num_stages: int, num_warps: int, **kwargs: Any + ) -> TritonConfig: + from triton import Config as TritonConfig # type: ignore[attr-defined] + + return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps) + + def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.mm_configs) + + def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs) + + def 
get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + return partial(self.preprocess_mm_configs, configs=mm_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + persistent_mm_configs = ( + self.exhaustive_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.persistent_mm_configs + ) + + # num_warps=2 not safe for TMA + persistent_mm_configs = [ + config for config in persistent_mm_configs if config.num_warps != 2 + ] + return partial(self.preprocess_mm_configs, configs=persistent_mm_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + return partial( + self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs + ) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.conv_configs) + + # Flex attn helpers + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class CPUConfigHeuristic(BaseConfigHeuristic): + pass + + +class CUDAConfigHeuristic(BaseConfigHeuristic): + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. 
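+
+    The flex-attention defaults are further specialized per architecture below via
+    (dtype, head_dim)-keyed tables for sm80 (A100) and sm90 (H100) class devices.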
+ """ + + def __init__(self) -> None: + super().__init__() + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + if capability >= (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexConfig(16, 16, 1, 4) + elif head_dim <= 256 and capability >= (9, 0): # H100 + if head_dim == 64: + default_config = FlexConfig(64, 64, 3, 4) + elif head_dim == 128: + default_config = FlexConfig(64, 128, 3, 8) + else: + default_config = FlexConfig(64, 64, 2, 4) + elif capability >= (8, 0): # A100 + if head_dim == 64: + default_config = FlexConfig(32, 128, 3, 4) + elif head_dim == 128: + # SM86/89 have smaller shared memory sizes + num_stages = 3 if capability[1] == 0 else 2 + default_config = FlexConfig(64, 64, num_stages, 4) + else: + default_config = FlexConfig(64, 64, 2, 4) + else: # modest hardware or extremely large head_dim + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + 
flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability >= (9, 0): # sm_90+ + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.default_num_stages = get_backend_num_stages() + + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + 
ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m, + matrix_instr_nonkdim, + waves_per_eu, + kpack, + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, self.default_num_stages] + for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] + ] + + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, num_stages, num_warps, mfma, wpeu) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + def _filter_configs( + self, configs: list[BaseConfig], new_num_stages: int + ) -> list[BaseConfig]: + # TODO: _filter_configs can be removed once backend specific configs are added + # for all methods + for c in configs: + c.num_stages = self.default_num_stages + return configs + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. 
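+
+        In addition to de-duplication, this drops configs whose block_m / block_n
+        are not multiples of matrix_instr_nonkdim, and fills in the AMD-specific
+        kernel arguments (matrix_instr_nonkdim, waves_per_eu, kpack) with defaults
+        when a config does not set them.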
+ """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.extra_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.int8_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + filtered_configs = self._filter_configs(mm_configs, self.default_num_stages) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1) + return partial(self._finalize_mm_configs, configs=filtered_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, 
None]]: + filtered_configs = self._filter_configs( + self.conv_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexConfig(16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexConfig(64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexConfig(64, 128, 1, 8) + else: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class XPUConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for XPU specific overrides. + """ diff --git a/phivenv/Lib/site-packages/torch/_inductor/test_case.py b/phivenv/Lib/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..6742523ad3fb4a51b2317586f6eee7145b442308 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,48 @@ +import contextlib +import os +from typing import Union + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) +from torch._functorch import config as functorch_config +from torch._inductor import config +from torch._inductor.utils import fresh_cache + + +def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. 
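+ Test modules typically subclass this TestCase and call run_tests() at module + scope, mirroring torch._dynamo.test_case.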
+ """ + + def setUp(self) -> None: + super().setUp() + self._inductor_test_stack = contextlib.ExitStack() + self._inductor_test_stack.enter_context( + functorch_config.patch( + { + "enable_autograd_cache": True, + } + ) + ) + + if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: + self._inductor_test_stack.enter_context( + config.patch({"fx_graph_cache": True}) + ) + + if ( + os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" + and os.environ.get("TORCH_COMPILE_DEBUG") != "1" + ): + self._inductor_test_stack.enter_context(fresh_cache()) + + def tearDown(self) -> None: + super().tearDown() + self._inductor_test_stack.close() diff --git a/phivenv/Lib/site-packages/torch/_inductor/test_operators.py b/phivenv/Lib/site-packages/torch/_inductor/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..02ae3868bfc81cca777a622eb318b11b902f50e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/test_operators.py @@ -0,0 +1,29 @@ +from typing import Any + +import torch.library +from torch import Tensor +from torch.autograd import Function + + +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) + + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + class Realize(Function): + @staticmethod + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) + + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/phivenv/Lib/site-packages/torch/_inductor/tiling_utils.py b/phivenv/Lib/site-packages/torch/_inductor/tiling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40c141908221f876f7b5520f613f3f5a6e376084 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/tiling_utils.py @@ -0,0 +1,764 @@ +import dataclasses +import functools +import itertools +import sys +from collections import Counter, defaultdict +from collections.abc import Iterable, Iterator +from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.dependencies import index_vars_no_squeeze +from torch._inductor.utils import sympy_product, sympy_subs +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import Identity +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .virtualized import V + + +T = TypeVar("T") +U = TypeVar("U") + + +Split = tuple[sympy.Expr, ...] +VarsAndRanges = tuple[list[sympy.Symbol], list[sympy.Expr]] + + +loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling") +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + + +if TYPE_CHECKING: + from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode + + +def solve_for_zero(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Given an expr with a single free symbol, solve for a constant relation that would make + this expression 0. 
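+ For instance, for ModularIndexing(x, 1, 64) (i.e. x % 64) the solve below + yields x == 64, a constant relation under which the expression is zero.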
+ """ + if expr.is_constant(): + return None + elif isinstance(expr, FloorDiv): + return None + + assert len(expr.free_symbols) == 1 + free_symbol = next(iter(expr.free_symbols)) + if isinstance(expr, ModularIndexing): + out = try_solve(sympy.Eq(expr.args[0], expr.args[2]), free_symbol) + else: + out = try_solve(sympy.Eq(expr, 0), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + +def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Giving an expr with a single free symbol, try to find a tiling that would + make the expression coalesced with respect to that symbol. + + Tiling an expression `x` by `y` means that the expression will now be indexed + by both the original (x) and by (x * y). So we are looking for a + multiplicative factor that will make ((x + 1) * y) - (x * y) == 1. + + To simplify things for sympy, we'll try just x * y == 1, check x(1) and x(0). + """ + + if len(expr.free_symbols) == 0: + return None + + free_symbol = next(iter(expr.free_symbols)) + + def _solve_simple_expr(expr: sympy.Expr) -> Optional[sympy.Expr]: + assert not expr.has(ModularIndexing) and not expr.has(FloorDiv) + if len(expr.free_symbols) != 1: + return None + + out = try_solve(sympy.Eq(expr, 1), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + # Sympy solving is very limited with ModularIndexing and FloorDiv, + # but good otherwise. + if not expr.has(ModularIndexing) and not expr.has(FloorDiv): + return _solve_simple_expr(expr) + + required_values = [] + eq_1_expressions = [] + + # very piecemeal solution if ModularIndexing or FloorDiv involved. + # Look for terms we'll try to make 0, and then other terms we'll try to make 1. + # Expand as needed. + for arg in sympy.Add.make_args(expr): + # Try to make mul terms 0 + if isinstance(arg, sympy.Mul): + seen = False + # TODO - only need one of these to be solvable to zero + # + for mul_arg in arg.args: + out = solve_for_zero(mul_arg) + if out is None: + continue + + assert out.is_constant() + seen = True + required_values.append(out) + + if not seen: + return None + else: + eq_1_expressions.append(arg) + + if not eq_1_expressions: + return None + + eq_1_expr = sum(eq_1_expressions) + + def indexing_div_rep( + x: sympy.Expr, + y: sympy.Expr, + z: Optional[sympy.Expr] = None, + ) -> sympy.Expr: + return x / y + + # For the purposes of tiling/coalesced access, approximate ModularIndexing and FloorDiv + # then check later + eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace( + FloorDiv, indexing_div_rep + ) + + out = _solve_simple_expr(eq_1_expr_simplified) + # since we approximated FloorDiv/ModularIndexing, double check here + if not out or not (sympy_subs(eq_1_expr, {free_symbol: out})) == 1: + return None + + required_values.append(out) + + if len(OrderedSet(required_values)) == 1: + return required_values[0] + + return None + + +def find_coalesced_var( + index: sympy.Expr, var_ranges: dict[sympy.Expr, int] +) -> Optional[sympy.Expr]: + """ + Try to find the symbol which coalesces this index + """ + top_level_terms = sympy.Add.make_args(index) + for v in var_ranges: + if v in top_level_terms: + return v + + # Approximate analysis by evaluating at 1 and 0 + variables: dict[sympy.Symbol, int] = {} + for v in index.free_symbols: + if v in var_ranges: + variables[v] = 0 + else: + variables[v] = get_hint(v) + + zero_index = sympy_subs(index, variables) + for v in var_ranges.keys(): + variables[v] = 1 + try: + new_val = sympy_subs(index, 
variables) + except ZeroDivisionError: + loop_tiling_log.info("zero division error %s %s", index, variables) + continue + if new_val - zero_index == 1: + variables[v] = 2 + # in some more complex expressions, 0->1 will be coalesced, + # but not 1->2 + if (sympy_subs(index, variables) - new_val) == 1: + return v + variables[v] = 0 + + return None + + +@dataclasses.dataclass(frozen=True) +class FusedNormalizedReadsWrites: + """ + Normalized reads and writes for nodes in the same FusedSchedulerNode. + """ + + index_vars: OrderedSet[sympy.Symbol] + reduce_vars: OrderedSet[sympy.Symbol] + reads: dict[sympy.Expr, OrderedSet[str]] + writes: dict[sympy.Expr, OrderedSet[str]] + var_ranges: dict[sympy.Symbol, int] + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[True], +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: ... + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[False] = False, +) -> tuple[VarsAndRanges, VarsAndRanges]: ... + + +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: bool = False, +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: + if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator] + i = len(n._body.sizes[0]) - 1 + prod = 1 + while i >= 0: + prod *= n._body.sizes[0][i] + if prod == red_numel: + break + i -= 1 + + if i >= 0: + pw_splits = n._body.sizes[0][0:i] + iter_vars = n._body.iter_vars[0:i] + + red_splits = n._body.sizes[0][i:] + red_vars = n._body.iter_vars[i:] + return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value] + + if none_if_not_divisible: + return None + else: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + +class NodeSplitGetter: + """ + Finds a Pointwise, Reduction Split that compatible with all nodes in a SchedulerNode. 
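+ For instance, a fused node with pointwise numel 36 may be split as (6, 6) when + every member node's ranges admit that factorization; otherwise the fallback is + the flat ((pointwise_numel,), (red_numel,)) split.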
+ """ + + def __init__( + self, + node: Union["FusedSchedulerNode", "SchedulerNode"], + ): + self.node = node + self.pointwise_numel: sympy.Expr = node.group[1][0] + self.red_numel: sympy.Expr = node.group[1][1] + + self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet) + + self.reduction_split: Split = () + self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet() + + fused_group = node.group[1] + for n in reversed(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + # if we can't split the pw ranges into a (pw, red) split, + # dont add as a split option, but do make sure we check that this size + # is splittable + maybe_splits = get_pw_red_splits( + n, self.pointwise_numel, self.red_numel, none_if_not_divisible=True + ) + if maybe_splits is None: + self.all_node_sizes.add(n._body.sizes) + continue + + (_, n_pw_splits), (_, n_red_splits) = maybe_splits + + # fill in reduction size + n_pw_splits, n_red_splits = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + fused_group, (n_pw_splits, n_red_splits), self.red_numel + ) + ) + + self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits)) + + # initially, we are just going to do a single reduction split since + # reduction tiling is off by default. even if we miss a reduction split, + # we can recover it in the split var analysis. + # TODO: an earlier version for this code tried to iteratively try the maximum number + # of split vars, by iterating over both pointwise and reduction. but not worth + # the complexity yet. + + if n_red_splits != (): + self.reduction_split = (sympy_product(n_red_splits),) + + n_size = (tuple(n_pw_splits), tuple(n_red_splits)) + self.all_node_sizes.add(n_size) + + self.seen_pw_splits: OrderedSet[Split] = OrderedSet() + + def get_node_splits(self) -> tuple[Split, Split]: + """ + Get a compatible pointwise, reduction split of the node + """ + + if len(self.all_node_sizes) == 1: + return next(iter(self.all_node_sizes)) + + max_pw_split = max(self.pw_split_options.keys()) + for pw_split_len in range(max_pw_split, 0, -1): + for pw_split in self.pw_split_options[pw_split_len]: + if out := self.try_split(pw_split, self.reduction_split): + return out + + # combine dims for next round + for pw_split in self.pw_split_options[pw_split_len]: + for i in range(len(pw_split) - 1): + new_split = tuple( + pw_split[0:i] + + (sympy_product(pw_split[i : i + 2]),) + + pw_split[i + 2 :] + ) + self.pw_split_options[len(new_split)].add(new_split) + + # if for whatever reason we couldn't split above, return default split + return ((self.pointwise_numel,), (self.red_numel,)) + + def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]: + """ + See if this split is compatible, and potentially returning a longer split + than the input. + """ + + from torch._inductor.codegen.simd import CantSplit, SIMDKernel + + if pw in self.seen_pw_splits: + return None + self.seen_pw_splits.add(pw) + + for n_pw, n_red in self.all_node_sizes: + try: + groups = pw + red + lengths = (n_pw, n_red) + splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths) + except CantSplit: + return None + + assert len(getters) == 2 + pw_group_splits = splits[: len(pw)] + # if we had to divide a variable into two to do this split, + # then lets try the larger, induced split. + # e.g. 
splitting (12, 2) into (2, 12) will split the first var into: + # (2, 6) and produce an overall split of (2, 6, 2) + flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits)) + if flattened_pw_splits != pw: + if out := self.try_split(flattened_pw_splits, red): + return out + + return pw, red + + +if sys.version_info >= (3, 10): + # On Python 3.10+ we can use zip(strict=True) + zip_equal = functools.partial(zip, strict=True) +else: + # Fallback for older versions + def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]: + """ + Zip two iterables, raising ValueError if their lengths differ. + """ + if len(it1) != len(it2): + raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}") + return zip(it1, it2) + + +def apply_var_mapping( + iter_vars: list[sympy.Symbol], + red_vars: list[sympy.Symbol], + norm_pw_vars: list[sympy.Symbol], + norm_red_vars: list[sympy.Symbol], + new_ranges: list[list[sympy.Expr]], + return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]], +) -> dict[sympy.Symbol, sympy.Expr]: + """Maps original variables to expressions using normalized variables.""" + + # the output of split_iteration_range is a new_ranges, return_getters_groups + # new_ranges is a flattened list of ranges corresponding to the new pw and red vars + # for example, taking in pw vars of range (6, 6) to normalized range [36], + # new_ranges would be [[6, 6]] + # There is a return_getter callable for each input iter_var and red_vars. + # if you flatten out all of the ranges, and create a variable for each index, + # then applying the flattening vars to the callables in return_getters_groups + # gives you the mapping from input vars -> flattened vars. + # From there, we can compute the output, normalized variables. 
+ # For instance [6, 6] corresponding to flat vars v0, v1 will be + # v0 + 6 * v1 + + # Create flattened iteration variables + num_vars = sum(len(s) for s in new_ranges) + flat_vars = sympy.symbols(f"v_0:{num_vars}") + count = 0 + + if len(iter_vars) == 0 and len(red_vars) == 0: + return {} + + assert len(new_ranges) == len(norm_pw_vars + norm_red_vars) + apply_groups = [] + for group in return_getters_groups: + apply_groups.append([g(flat_vars) for g in group]) + + iter_vars_to_flat_vars = {} + for i, (group, var_group) in enumerate( + zip_equal(apply_groups, ((iter_vars, red_vars))) + ): + # if the node has sizes (p0, 1) and the fused node is (p0, r0) + # the reduction var gets filled in for split_iteration_range + if len(group) != len(var_group): + assert i == 1 + assert len(var_group) == 0 + continue + + iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)}) + + count = 0 + flat_vars_to_new_vars = {} + for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars): + range_vars = [] + for i in range(len(new_range)): + range_vars.append(flat_vars[count]) + count += 1 + + prod = 1 + for i in range(len(new_range) - 1, -1, -1): + flat_vars_to_new_vars[range_vars[i]] = new_var * prod + prod = new_range[i] * prod + + return { + k: sympy_subs(v, flat_vars_to_new_vars) + for k, v in iter_vars_to_flat_vars.items() + } + + +def extract_normalized_read_writes( + node: Union["FusedSchedulerNode", "SchedulerNode"], +) -> Optional[FusedNormalizedReadsWrites]: + """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node.""" + reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + all_output_names = node.get_buffer_names() + op_names = node.get_operation_names() + outputs: OrderedSet[str] = OrderedSet() + removed_buffers: OrderedSet[str] = OrderedSet() + for buf_name in all_output_names: + if V.graph.scheduler.can_buffer_be_removed_through_fusion(buf_name, op_names): + removed_buffers.add(buf_name) + else: + outputs.add(buf_name) + + inputs = OrderedSet( + dep.name for dep in node.read_writes.reads if dep.name not in removed_buffers + ) + + pointwise_numel: sympy.Expr = node.group[1][0] + red_numel: sympy.Expr = node.group[1][1] + + # TODO - a few dynamic shapes issues to resolve + if any( + (isinstance(var, sympy.Expr) and not var.is_constant()) + for var in (pointwise_numel, red_numel) + ): + return None + + pw_splits, red_splits = NodeSplitGetter(node).get_node_splits() + + # lets use different prefix (`n`) to distinguish + (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze( + pw_splits, red_splits, prefix="n" + ) + node = node + + for n in list(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + body = n._body + + # TODO - not handled well. indirect loads will not be coalesced, + # need to account for that in analysis. 
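+ # (indirect vars come from data-dependent indexing, e.g. a[b[i]]; their + # offsets cannot be reasoned about symbolically, so bail out for now)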
+ if body.indirect_vars: + return None + + n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + # TODO - will the names for all the inputs/outputs accurately + # reflect mutation, or do I need to remap with mutation_real_name + for inp in inputs: + for expr in body.get_all_read_expr(inp): + n_reads[expr].add(inp) + + for out in outputs: + for expr in body.get_all_write_expr(out): + n_writes[expr].add(out) + + if not n_reads and not n_writes: + continue + + (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits( + n, pointwise_numel, red_numel + ) + + groups = pw_splits + red_splits + lengths = (n_pw_splits, (n_red_splits)) + lengths = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + groups, lengths, red_numel + ) + ) + new_ranges, return_getters_groups = ( + torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges( + groups, lengths + ) + ) + var_map = apply_var_mapping( + iter_vars, + red_vars, + norm_pw_vars, + norm_red_vars, + new_ranges, + return_getters_groups, + ) + + # We create Identity sympy.Functions to prevent expansion to int64, + # unwrap for tiling analysis. + def remove_identity(expr: sympy.Expr) -> sympy.Expr: + return expr.replace(Identity, lambda x: x) + + n_reads_new = { + sympy_subs(remove_identity(read), var_map): v for read, v in n_reads.items() + } + n_writes_new = { + sympy_subs(remove_identity(write), var_map): v + for write, v in n_writes.items() + } + + for expr, buf_names in n_reads_new.items(): + reads[expr] |= buf_names + + for expr, buf_names in n_writes_new.items(): + writes[expr] |= buf_names + + reads = { + V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items() + } + writes = { + V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items() + } + + fused_out = FusedNormalizedReadsWrites( + norm_pw_vars, # type: ignore[arg-type] + norm_red_vars, # type: ignore[arg-type] + reads, + writes, + ranges, + ) + loop_tiling_log.info("Normalized Fused reads: %s", fused_out) + return fused_out + + +def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int: + """ + Score addr according to its approximate size + """ + + # TODO - deduplicate with candidate_tilings + var_sizes = [] + for v in addr.free_symbols: + v_size = var_ranges.get(v, None) + # TODO - reason about indirect vars + if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None: + var_sizes.append(v_size) + from .virtualized import V + + return V.graph.sizevars.atomically_apply_size_hint( + sympy_product(var_sizes), fallback=config.unbacked_symint_fallback + ) + + +def get_hint(v: Union[sympy.Expr, int]) -> int: + if isinstance(v, int): + return v + else: + return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback) + + +@dataclasses.dataclass(frozen=True) +class VarTiling: + """ + Tiling of a var by `tiling_factor` that yields additional coalesced mem accesses by `benefit_score` + """ + + var: sympy.Symbol + tiling_factor: int + score: int + + +@dataclasses.dataclass(frozen=True) +class CoalesceVarAnalysis: + # Var -> Memory Score - not strictly the amount of memory + # because we multiply writes x2 + # TODO: separate into dataclass that olds mem, dtype, is_write + coalesced_by_var: dict[sympy.Expr, int] + + norm_read_writes: FusedNormalizedReadsWrites + + suggested_split: Optional[VarTiling] = None + + +def analyze_memory_coalescing( + fused_node: Union["FusedSchedulerNode", 
"SchedulerNode"], +) -> Optional[CoalesceVarAnalysis]: + """ + Find variables that coalesce the reads and writes and score the total size. + + If uncoalesced memory expressions are found, look for additionally tiling of variables + which will coalesce memory accesses. + + For instance - for the following expression: + + (32*p0) // 2048 + + Tiling p0 by 64 will make this expression coalesced. + """ + + norm_read_writes = extract_normalized_read_writes(fused_node) + + if norm_read_writes is None: + return None + + reads = norm_read_writes.reads + writes = norm_read_writes.writes + var_ranges = norm_read_writes.var_ranges + + coalesced_by_var: dict[sympy.Symbol, int] = Counter() + uncoalesced_addrs: dict[sympy.Expr, int] = Counter() + + for is_read, (memory_expr, buf_names) in itertools.chain( + ((True, item) for item in reads.items()), + ((False, item) for item in writes.items()), + ): + # skip memory deps with indirect vars - todo: better handling + indirect_expr = bool( + memory_expr.free_symbols - norm_read_writes.var_ranges.keys() + ) + + if indirect_expr: + continue + + size = get_score(memory_expr, var_ranges) + if size == 0: + continue + + maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges) + + byte_multipler = 0 + for buf_name in buf_names: + if buf := V.graph.try_get_buffer(buf_name): + byte_multipler += buf.dtype.itemsize + + # coalesced writes more important + byte_multipler *= 1 if is_read else 2 + + if maybe_coalesced_var: + coalesced_by_var[maybe_coalesced_var] += size * byte_multipler + else: + uncoalesced_addrs[memory_expr] += size * byte_multipler + + if not uncoalesced_addrs: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + # map from var -> tiling -> total_score + tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter) + + for uncoalesced_expr, addr_score in uncoalesced_addrs.items(): + expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0) + for v in uncoalesced_expr.free_symbols: + # skip non iter/reduce var variables + if v not in var_ranges: + continue + # skip small addrs + if addr_score == 0: + continue + del expr_subs[v] + single_var_expr = sympy_subs(uncoalesced_expr, expr_subs) + expr_subs[v] = 0 + tiling_factor = solve_for_tiling(single_var_expr) + if ( + tiling_factor is None + or not tiling_factor.is_constant() + or not tiling_factor.is_integer + ): + continue + + tiling_factor = int(tiling_factor) + if not V.graph.sizevars.statically_known_lt(tiling_factor, var_ranges[v]): + continue + + # TODO - if a var is in the middle, such as [n0, n1, n2] + # n1 can can be split beyond range + + MIN_TILING_BLOCK = 8 + if not all( + V.graph.sizevars.statically_known_lt(MIN_TILING_BLOCK, block) + for block in (tiling_factor, var_ranges[v] // tiling_factor) + ): + continue + + tiling_scores[v][tiling_factor] += addr_score + + if len(tiling_scores) == 0: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + best_tiling: Optional[tuple[sympy.Expr, int]] = None + best_tiling_score = 0 + + for var, tiling_counter in tiling_scores.items(): + for tile, tile_score in tiling_counter.items(): + if tile_score > best_tiling_score: + best_tiling = (var, tile) + best_tiling_score = tile_score + + if best_tiling is None: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes + ) + + # TODO - for strictly pointwise fusions, + # we can consider just swizzling the var if the var we are going to tile 
+ # does not coalesce a significant portion of global reads + # TODO - could also prefer index var splits to reduction, better tested + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + norm_read_writes=norm_read_writes, + suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score), + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/triton_bundler.py b/phivenv/Lib/site-packages/torch/_inductor/triton_bundler.py new file mode 100644 index 0000000000000000000000000000000000000000..9679537f99d984b15e9b8bf95ec673561f456f14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/triton_bundler.py @@ -0,0 +1,409 @@ +import copy +import dataclasses +import logging +import os +import shutil +import uuid +from pathlib import Path +from typing import Optional + +from torch._dynamo.utils import counters, dynamo_timed, set_feature_use +from torch._utils_internal import justknobs_check +from torch.utils._filelock import FileLock + +from .runtime.runtime_utils import triton_cache_dir +from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class TritonBundleEntry: + """ + When we have compiled a triton kernel, we take note of that kernel by + its triton generated hash, its device, and where this kernel is located. + This is the minimum information we can use to later retrieve this kernel + from file system. + """ + + kernel_hash: str + device: int + directory: str + + +@dataclasses.dataclass(frozen=True) +class TritonKernelArtifact: + """ + Artifact for an individual kernel converted to bytes. + Bytes could be a cubin, json, ttir, or ttgir. + """ + + filename: str + payload: bytes = dataclasses.field(repr=False) # Do not display binary + + +@dataclasses.dataclass(frozen=True) +class StaticallyLaunchedAutotuner: + """ + Represents a statically compiled CachingAutotuner object that we can + save directly in the cache. A CachingAutotuner is made up of a list of + StaticTritonCompileResults, each of which uses the cubin from a TritonKernelArtifact. + + Statically saved here have their cubin files saved by a corresponding TritonBundleEntry. + """ + + cache_key: str + kernel_name: str + kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821 + + +@dataclasses.dataclass(frozen=True) +class TritonKernelArtifacts: + """ + Collection of artifacts for a particular kernel. + """ + + kernel_hash: str + device: int + artifacts: list[TritonKernelArtifact] + + +@dataclasses.dataclass(frozen=True) +class TritonBundlerMetadata: + """ + Metadata used for instrumentation + """ + + cached_kernel_names: list[str] + statically_launched_kernel_names: list[str] + + +@dataclasses.dataclass(frozen=True) +class TritonBundle: + """ + Serializable bundle to save into FXGraphCache + """ + + kernel_artifacts: list[TritonKernelArtifacts] + static_autotuners: list[StaticallyLaunchedAutotuner] + + +class TritonBundler: + """ + Lightweight Triton Kernel bundler that notes each time we compile a triton + kernel. When collect is called, converts all the previously noted kernels and + their artifacts into a structured bytes blob, and later when write is called + it writes this structured blob back to file system. 
+ + Intended Life cycle: + - TritonBundler.begin_compile is called when we start compiling in Inductor + - TritonBundler.put is called each time a Triton Kernel is compiled + - TritonBundler.collect is called when a cache entry is being generated + - TritonBundler.end_compile is called to indicate bundling is completed, + collect will execute this function as well. + - TritonBundler.read_and_emit is called when a cache entry is read + """ + + _entries: Optional[list[TritonBundleEntry]] = None + _static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None + + # __grp__kernel_name.json contains metadata with source code paths + # we use this as sentinel value for search and replace + _REPLACE_BYTES: bytes = b"[REPLACE]" + + @staticmethod + def is_enabled() -> bool: + from torch._inductor import config + + if config.force_disable_caches: + return False + + if (b := config.bundle_triton_into_fx_graph_cache) is not None: + return b + + if not config.is_fbcode(): + return False + + return justknobs_check( + "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2" + ) + + @classmethod + def begin_compile(cls) -> None: + """ + Initializes the TritonBundler. + The current TritonBundler bundle is finalized by TritonBundler.collect. + """ + if not TritonBundler.is_enabled(): + return + log.debug("TritonBundler.begin_compile is called") + assert cls._entries is None + cls._entries = [] + cls._static_autotuners = [] + + @classmethod + def end_compile(cls) -> None: + """ + Finalizes the TritonBundler. If collect is not yet called, it + discards the current bundle. + """ + log.debug("TritonBundler.end_compile is called") + cls._entries = None + cls._static_autotuners = None + + @classmethod + def put(cls, kernel_hash: str, device: int) -> None: + """ + Lazily observes that we have seen a Triton kernel compilation. Remembers + it for when collect is later called. 
+ """ + if (entries := cls._entries) is not None: + entries.append( + TritonBundleEntry(kernel_hash, device, triton_cache_dir(device)) + ) + + @classmethod + def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821 + from torch._inductor import config + + assert config.use_static_cuda_launcher + if (entries := cls._static_autotuners) is not None: + # Clear a bunch of unpicklable values and make a copy to save + # for FXGraphCache + old_values = kernel.prepare_for_pickle() + new_kernel = copy.deepcopy(kernel) + new_kernel.prepare_for_caching() + new_kernel._reload_kernel = None + + entries.append( + StaticallyLaunchedAutotuner( + key, + new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"), + new_kernel, + ) + ) + # Put the values back since we need it to use now + ( + kernel.fn.fn, + kernel.fn.__globals__, + kernel.fn.used_global_vals, + kernel.fn.repr, + kernel.launchers, + ) = old_values + + @classmethod + def collect_static_autotuners( + cls, + ) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]: + if not cls._static_autotuners: + return [], [] + else: + log.info( + "Saving %d statically launchable CachingAutotuners", + len(cls._static_autotuners), + ) + static_autotuner_names = [i.kernel_name for i in cls._static_autotuners] + counters["inductor"]["triton_bundler_save_static_autotuner"] += 1 + return cls._static_autotuners, static_autotuner_names + + @classmethod + def load_autotuners( + cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] + ) -> list[str]: + """ + Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels + cache. + """ + if not static_autotuners: + return [] + + from torch._inductor.async_compile import CompiledTritonKernels + from torch._inductor.codecache import StaticAutotunerFuture + + log.info("Loading %d statically launchable autotuners", len(static_autotuners)) + kernel_names = [] + with dynamo_timed("TritonBundler.load_cached_static_autotuners"): + for result in static_autotuners: + try: + # Make sure the cubin path exists and is valid + for compile_result in result.kernel.compile_results: + compile_result.reload_cubin_path() + except RuntimeError as e: + log.warning( + "Failed to reload cubin file statically launchable autotuner %s: %s", + result.kernel_name, + e, + ) + continue + # We make a future instead of returning the kernel here so that + # kernels that are not statically launchable (i.e. cache miss) + # can launch a worker without waiting on the blocking step of + # StaticAutotunerFuture.result(). + CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture( + result.kernel + ) + counters["inductor"]["triton_bundler_load_static_autotuner"] += 1 + kernel_names.append(result.kernel_name) + return kernel_names + + @classmethod + def collect( + cls, + ) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]: + """ + This is the main function called when a cache write happens. This function + converts all the previously remembered kernels into bundled format so that + it can be written into a cache entry. + This function also finalizes the current bundle. 
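+ Returns the serializable TritonBundle together with TritonBundlerMetadata + listing the bundled kernel names, or an empty bundle and None when bundling + is disabled.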
+ """ + from torch._inductor import config + + if not TritonBundler.is_enabled(): + cls.end_compile() + set_feature_use("triton_bundling", False) + return TritonBundle([], []), None + set_feature_use("triton_bundling", True) + + with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True): + entries = cls._entries + if entries is not None: + result: list[TritonKernelArtifacts] = [] + kernel_names: list[str] = [] + for entry in entries: + artifacts: list[TritonKernelArtifact] = [] + path = os.path.join(entry.directory, entry.kernel_hash) + if not os.path.exists(path): + continue + for filename in os.listdir(path): + filepath = os.path.join(path, filename) + try: + assert os.path.isfile(filepath) + with open(filepath, "rb") as file: + payload = file.read() + if filepath.endswith(".json"): + # Make sure there's no sentinel value + if TritonBundler._REPLACE_BYTES in payload: + log.warning( + "Bundle contains illegal %s, payload: %s", + TritonBundler._REPLACE_BYTES, + payload, + ) + raise AssertionError( + "Bundle contains illegal bytes" + ) + # Remove the path from payload + payload = payload.replace( + str.encode(path), TritonBundler._REPLACE_BYTES + ) + artifacts.append( + TritonKernelArtifact(filename, payload) + ) + counters["inductor"]["triton_bundler_save_kernel"] += 1 + except Exception: + log.debug("failed to collect triton kernel", exc_info=True) + extension = os.path.splitext(filename)[1] + if extension in GPU_KERNEL_BIN_EXTS.values(): + # Each kernel has bunch of files like .cubin(for cuda), .spv(for xpu), .json, .ttir + # Just append one of them without the extension + kernel_names.append(Path(filename).stem) + if artifacts: + result.append( + TritonKernelArtifacts( + entry.kernel_hash, + entry.device, + artifacts, + ) + ) + if config.use_static_cuda_launcher: + static_autotuners, static_kernel_names = ( + cls.collect_static_autotuners() + ) + else: + static_autotuners = [] + static_kernel_names = [] + cls.end_compile() + return TritonBundle(result, static_autotuners), TritonBundlerMetadata( + kernel_names, static_kernel_names + ) + return TritonBundle([], []), None + + @staticmethod + def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]: + """ + This is the main function called when a cache read happens. This function + converts the bundled format back into individual files and writes them + to the filesystem. + + NOTE: When we are writing to the filesystem, we assume exclusive access + to the target directory. + This means that if the target folder already exists and is non-empty, + we bail out. + Exclusive access means that no other process should be writing to + or reading from the target directory. 
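+ Returns TritonBundlerMetadata with the names of the emitted kernels, or None + when bundling is disabled.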
+ """ + from torch._inductor import config + + if not TritonBundler.is_enabled(): + return None + + with dynamo_timed( + key="TritonBundler.read_and_emit", log_pt2_compile_event=True + ): + kernel_names: list[str] = [] + + for artifacts in bundle.kernel_artifacts: + basedir = triton_cache_dir(artifacts.device) + directory = os.path.join(basedir, artifacts.kernel_hash) + + if os.path.exists(directory) and len(os.listdir(directory)) != 0: + # If directory already exists, we bail out and leave + # local disk to take care of caching + log.debug( + "Bailing out TritonBundler.read_and_emit, %s is non empty", + directory, + ) + continue + + Path(basedir).mkdir(parents=True, exist_ok=True) + + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + tmp_dir = os.path.join(basedir, f"tmp.{rnd_id}") + os.makedirs(tmp_dir) + + for artifact in artifacts.artifacts: + filepath = os.path.join(tmp_dir, artifact.filename) + with open(filepath, "wb") as file: + payload = artifact.payload + if artifact.filename.endswith(".json"): + payload = payload.replace( + TritonBundler._REPLACE_BYTES, str.encode(directory) + ) + file.write(payload) + counters["inductor"]["triton_bundler_read_and_emit_kernel"] += 1 + extension = os.path.splitext(artifact.filename)[1] + if extension in GPU_KERNEL_BIN_EXTS.values(): + # Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir + # Just append one of them without the extension + kernel_names.append(Path(artifact.filename).stem) + + if _IS_WINDOWS: + with FileLock(directory + ".lock"): + if os.path.exists(directory): + shutil.rmtree(directory) + os.replace(tmp_dir, directory) + else: + # Atomic on POSIX systems + try: + os.replace(tmp_dir, directory) + except OSError: + log.warning("Directory %s is not empty - skipping!", tmp_dir) + + if config.use_static_cuda_launcher: + static_kernel_names = TritonBundler.load_autotuners( + bundle.static_autotuners + ) + else: + static_kernel_names = [] + return TritonBundlerMetadata(kernel_names, static_kernel_names) diff --git a/phivenv/Lib/site-packages/torch/_inductor/utils.py b/phivenv/Lib/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2687258cca4120d412c46e50cf89f71154c4bca6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/utils.py @@ -0,0 +1,3204 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import importlib +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import re +import shutil +import statistics +import sys +import tempfile +import textwrap +import time +import unittest +from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet +from datetime import datetime +from io import StringIO +from typing import ( + Any, + Callable, + cast, + Generic, + Literal, + NamedTuple, + Optional, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import ( + Concatenate, + dataclass_transform, + ParamSpec, + Self, + TypeAlias, + TypeGuard, +) +from unittest import mock + +import sympy + +import torch +from torch._inductor.runtime.hints import DeviceProperties +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map_only + + +OPTIMUS_EXCLUDE_POST_GRAD = [ + "activation_quantization_aten_pass", +] + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence, ValuesView + + from torch import 
SymBool, SymFloat, SymInt + from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND + from torch.fx import GraphModule + from torch.fx.experimental.symbolic_shapes import ShapeEnv + from torch.fx.node import Node + + from .codegen.common import WorkspaceArg + from .codegen.wrapper import PythonWrapperCodegen + from .graph import GraphLowering + from .ir import ( + Buffer, + ExternKernel, + ExternKernelOut, + IRNode, + Layout, + Operation, + ReinterpretView, + ) + from .output_code import CompiledFxGraph + from .scheduler import BaseSchedulerNode, SchedulerBuffer + + +GPU_TYPES = ["cuda", "mps", "xpu"] +T = TypeVar("T") + + +# defines here before import torch._dynamo is for avoiding circular import +# when get_gpu_type is imported from dynamo +@functools.cache +def get_gpu_type() -> str: + avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] + assert len(avail_gpus) <= 1 + gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop() + return gpu_type + + +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import detect_fake_mode +from torch.autograd import DeviceType +from torch.autograd.profiler_util import EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) +from torch.utils._sympy.symbol import make_symbol, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from . import config +from .runtime.runtime_utils import ceildiv as runtime_ceildiv + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + +_T = TypeVar("_T") +VarRanges = dict[sympy.Expr, sympy.Expr] +InputType = Optional[Union[torch.Tensor, int, torch.SymInt]] + +GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"} + +GPU_ALIGN_BYTES = 16 +ALIGNMENT = 16 + +TMA_ALIGNMENT = 16 +TMA_DESCRIPTOR_SIZE = 128 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes: int) -> int: + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr) -> bool: + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]: + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +@dataclasses.dataclass(frozen=True) +class GraphPartitionMap: + """ + Mapping from the partition info (e.g., input/output) to the graph info + """ + + # a unique id of graph partition + id: int + + # map partition input/output indices to graph input/output indices. None indicates + # a partition input/output is not a graph input/output. + input_index_mapping: list[Optional[int]] + output_index_mapping: list[Optional[int]] + + # name of constants read/written by the graph partition + constant_names: list[str] + + +def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float: + """ + Returns benchmark results by examining torch profiler events. 
+ This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. + """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + torch.cuda.synchronize() + for i in range(n_repeat): + cache.zero_() + start_event[i].record() + with torch.cuda.nvtx.range("RunCudaModule"): + fn() + end_event[i].record() + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + ) + + res = torch.mean(times).item() + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and "fused_abs_max_0" in event.name + ] + ) + if filtered_events: + res -= ( + statistics.mean(event.device_time_total for event in filtered_events) + / 1000.0 + ) + + log.debug("profiling results: %s ms", res) + return res + + +def do_bench_using_profiling( + fn: Callable[[], Any], warmup: int = 25, rep: int = 100 +) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. 
+ """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + torch.cuda.synchronize() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for i in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. " + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +@functools.cache +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + except RuntimeError as e: + assert "torchvision::nms does not exist" in str(e) + return False + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type not in ("cpu", "meta") and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr: + return functools.reduce(operator.mul, it, sympy.S.One) + + +def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr: + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(number, sympy.Expr) or isinstance(denom, sympy.Expr): + return CeilDiv(sympy.sympify(number), sympy.sympify(denom)) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # 
--amp --only YituTechConvBert --dynamic-shapes + assert isinstance(number, int) and isinstance(denom, int), ( + f"{number}: {type(number)}, {denom}: {type(denom)}" + ) + return runtime_ceildiv(number, denom) + + +def _type_of(key: Optional[torch.dtype]) -> str: + # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + # TODO: remove when support is added in triton + # https://github.com/triton-lang/triton/issues/6054 + "float8_e8m0fnu": "u8", + "float4_e2m1fn_x2": "u8", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + tys.update({v: v for v in list(tys.values())}) + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]], +) -> list[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. + """ + return [sympy.sympify(i) for i in lst] + + +def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]: + """ + Like convert_shape_to_symint, but operates on a single expression. + """ + from .virtualized import V + + return ( + i + if isinstance(i, int) + else ( + int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + ) + ) + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]], +) -> list[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). 
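+ For instance, [2, 3] is returned as plain ints, while a non-constant sympy + expression becomes a SymInt created from the graph's shape env.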
+ """ + return [convert_to_symint(i) for i in lst] + + +def is_view(op: torch._ops.OpOverload) -> bool: + """ + Does this op overload have aliasing + """ + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use( + use: Node, + is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False, +) -> bool: + """ + Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn` + + Uses in views ops will follow the views uses + """ + + if not use.op == "call_function": + return False + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + target = cast(torch._ops.OpOverload, use.target) + if target is operator.getitem or is_view(target): + return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users) + + return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) + + +def gen_gm_and_inputs( + target: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[GraphModule, list[torch.Tensor]]: + g = torch.fx.Graph() + graph_args: list[torch.Tensor] = [] + + def add_tensor_arg(arg: torch.Tensor) -> Node: + graph_args.append(arg) + return g.placeholder(f"arg{len(graph_args)}") + + node = g.call_function( + target, *tree_map_only(torch.Tensor, add_tensor_arg, (args, kwargs)) + ) + if ( + len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) # type: ignore[assignment] + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, graph_args + + +def synchronize(device: str = "cuda") -> None: + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], + example_inputs: Sequence[Any], + times: int = 1, + device: str = "cuda", +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + model: Callable[..., Any], + example_inputs: Sequence[Any] = (), + times: int = 10, + repeat: int = 10, + baseline: float = 1.0, + device: str = "cuda", +) -> float: + timings = torch.tensor( + [timed(model, example_inputs, times, device) for _ in range(repeat)] + ) + took = torch.median(timings) / times + print(f"{took / baseline:.6f}") + return took.item() + + +def precompute_method(obj: Any, method: str) -> None: + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: list[str]) -> None: + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a: int, b: int) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]: + if isinstance(x, int): + return [x] * size + if len(x) == 1: + return type(x)([x[0]]) * size # type: ignore[call-arg, operator, return-value] + return x + + +# Used to ensure that iterating over a set is deterministic +def tuple_sorted(x: tuple[_T, ...]) -> list[_T]: + if len(x) == 0: + return [] + + def sort_func(elem: _T) -> str: + if isinstance(elem, str): + return elem + + from 
.scheduler import BaseSchedulerNode + + assert isinstance(elem, BaseSchedulerNode) + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = ParamSpec("P") +RV = TypeVar("RV", covariant=True) + + +class CachedMethod(Protocol, Generic[P, RV]): + @staticmethod + def clear_cache(cache: Any) -> None: ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ... + + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + name = fn.__name__ + key = f"__{name}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def {name}_cache_on_self(self): + try: + return self.{key} + except AttributeError: + pass + rv = fn(self) + object.__setattr__(self, "{key}", rv) + return rv + """.lstrip(), + ctx, + ) + wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"]) + + def clear_cache(self: Any) -> None: + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def aggregate_origins( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], +) -> OrderedSet[Node]: + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + OrderedSet(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return OrderedSet() + + +def get_fused_kernel_name( + node_schedule: Sequence[BaseSchedulerNode], + descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"], +) -> str: + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) + sources = [ + origin.meta["original_aten"]._overloadpacket.__name__ + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(OrderedSet(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function" and "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + if isinstance(source_fn[1], str): + sources.append(source_fn[1]) + else: + sources.append(source_fn[1].__name__) + sources = sorted(OrderedSet(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + sources = sources + return "_".join(["fused"] + sources) + + +def get_kernel_metadata( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + wrapper: PythonWrapperCodegen, +) -> tuple[str, str]: + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + + # Attempt to sort `inductor_nodes` topologically. Note that the case + # where `inductor_nodes` contains nodes from multiple graph instances + # is not supported. An example of this is conditional statements. 
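+    # For orientation, the strings assembled below look roughly like the sketch
+    # here (node names are hypothetical; `wrapper.comment` is the target
+    # language's comment prefix, e.g. "#"):
+    #   # Topologically Sorted Source Nodes: [add, relu], Original ATen: [aten.add, aten.relu]
+    #   # Source node to ATen node mapping:
+    #   #   add => add_1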
+ single_graph = None + if len(inductor_nodes): + unique_graphs = OrderedSet(n.graph for n in inductor_nodes) + if len(unique_graphs) == 1: + single_graph = inductor_nodes[0].graph + # create a map of idx -> node and cache it + if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"): + node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)} + single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map # type: ignore[attr-defined] + inductor_nodes.sort( + key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n] # type: ignore[attr-defined] + ) + + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + key = str(node.meta["original_aten"]._overloadpacket) + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0].name + from_node_dict[key].append(node.name) + sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted" + metadata = ( + f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], " + f"Original ATen: [{', '.join(original_aten_dict.keys())}]" + ) + + # trace back to original node here + detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + + # print the aot_autograd graph fragment + if single_graph is not None: + detailed_metadata.append(f"{wrapper.comment} Graph fragment:") + for n in inductor_nodes: + # TODO(future): maybe refactor torch/fx/graph.py to make it easy to + # generate python code for graph fragments + detailed_metadata.append(f"{wrapper.comment} {n.format_node()}") + + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], + skip_filter: Optional[Callable[[Any], bool]] = None, +) -> OrderedSet[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within initial_queue""" + initial_queue = list(initial_queue) + dominated_set = OrderedSet(initial_queue) + + while initial_queue: + node = initial_queue.pop() + for user in node.users: + if skip_filter and skip_filter(user): + continue + if user not in dominated_set: + dominated_set.add(user) + initial_queue.append(user) + + return dominated_set + + +def gather_origins( + args: Sequence[IRNode], kwargs: dict[str, IRNode] +) -> OrderedSet[IRNode]: + import itertools + + from . import ir + + def is_unrealized_node(n: IRNode) -> bool: + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return OrderedSet(itertools.chain(*arg_origins, *kwarg_origins)) + + +def sympy_str(expr: sympy.Expr) -> str: + """ + Normal sympy str is very slow, this is a lot faster. The result are + somewhat worse, as it doesn't do as much simplification. So don't + use this for final codegen. + """ + + def is_neg_lead(expr: sympy.Expr) -> bool: + return ( + isinstance(expr, sympy.Mul) and len(expr.args) == 2 and expr.args[0] == -1 + ) + + def sympy_str_add(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Add): + # Special case 'a - b'. 
Note that 'a - b - c' will still appear as + # 'a + -1 * b + -1 * c'. + if len(expr.args) == 2 and is_neg_lead(expr.args[1]): + return f"{sympy_str_mul(expr.args[0])} - {sympy_str_mul(expr.args[1].args[1])}" + else: + return " + ".join(map(sympy_str_mul, expr.args)) + else: + return sympy_str_mul(expr) + + def sympy_str_mul(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Mul): + if is_neg_lead(expr): + # Special case '-a'. Note that 'a * -b' will still appear as + # '-1 * a * b'. + return f"-{sympy_str_atom(expr.args[1])}" + else: + return " * ".join(map(sympy_str_atom, expr.args)) + else: + return sympy_str_atom(expr) + + def sympy_str_atom(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Symbol): + return expr.name + elif isinstance(expr, (sympy.Add, sympy.Mul)): + return f"({sympy_str_add(expr)})" + elif isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): + return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" + else: + return str(expr) + + return sympy_str_add(expr) + + +def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]: + from .virtualized import V + + # If this expression does not come from an FX node, we compute its bounds + if ( + config.compute_all_bounds + and (fx_node := getattr(V.interpreter, "current_node", None)) + and fx_node.target != "index_expr" + ): + return bound_sympy(index) + else: + return ValueRanges.unknown() + + +def prefix_is_reduction(prefix: str) -> bool: + return prefix[0] == "r" + + +def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert prefix != SymT.SIZE + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return make_symbol(prefix, idx, integer=True, nonnegative=True) + + +def generate_assert(check: bool) -> bool: + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + +def sympy_index_symbol(name: str) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) + + +def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. 
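+
+    Example (illustrative; ``x`` is a hypothetical symbol)::
+
+        x = sympy.Symbol("x", integer=True, nonnegative=True)
+        sympy_subs(x + 1, {x: "y"})
+        # -> y + 1, where the new symbol "y" inherits x's integer/nonnegative
+        # assumptions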
+ """ + + def to_symbol( + replaced: sympy.Expr, replacement: Union[sympy.Expr, str] + ) -> sympy.Symbol: + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +def get_first_incompatible_cudagraph_node( + gm: torch.fx.GraphModule, +) -> Optional[torch.fx.Node]: + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + forbidden_set = OrderedSet( + [ + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + ] + ) + if torch.are_deterministic_algorithms_enabled(): + forbidden_set.update( + ( + "aten._unsafe_index_put.default", + "aten._unsafe_masked_index_put_accumulate.default", + "aten.index_put.default", + "aten.index_put_.default", + "aten.scatter.src", + "aten.scatter.reduce", + "aten.scatter.value_reduce", + "aten.scatter_add_", + "aten.scatter_add.default", + "aten.scatter_reduce.two", + "aten.scatter_reduce_.two", + "aten.scatter_reduce.two_out", + ) + ) + + for node in gm.graph.nodes: + if str(node.target) in forbidden_set: + return node + + if ( + not torch._inductor.config.graph_partition + and isinstance(node.target, torch._ops.OpOverload) + and torch._C.Tag.cudagraph_unsafe in node.target.tags + ): + # skip cudagraph if a cudagraph_unsafe op is detected. + # graph_partition helps by splitting on this cudagraph_unsafe + # op and cudagraphifying the subgraphs. 
+ return node + + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return node + + return None + + +def output_node(gm: torch.fx.GraphModule) -> Node: + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + input_devices: OrderedSet[torch.device] = OrderedSet( + node.meta["val"].device + for node in placeholder_nodes + if isinstance(node.meta.get("val"), torch.Tensor) + ) + + out_arg = output_node(gm).args[0] # type: ignore[union-attr] + out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,) + out_devices: OrderedSet[torch.device] = OrderedSet( + arg.meta["val"].device + for arg in out_args + if isinstance(arg, torch.fx.Node) + and isinstance(arg.meta.get("val"), torch.Tensor) + ) + return input_devices | out_devices + + +import gc + + +def unload_xpu_triton_pyds() -> None: + # unload __triton_launcher.pyd + for module_name in list(sys.modules.keys()): + if not module_name.startswith("torch._inductor.runtime.compile_tasks."): + continue + m = sys.modules[module_name] + for attr_name in m.__dict__.keys(): + if attr_name.startswith("triton_"): + kernel = getattr(m, attr_name) + if isinstance( + kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner + ): + for result in kernel.compile_results: + if isinstance( + result, + torch._inductor.runtime.triton_heuristics.TritonCompileResult, + ): + result.kernel.run.mod.__del__() + del sys.modules[module_name] + + # unload spirv_utils.pyd + if "triton.runtime.driver" in sys.modules: + mod = sys.modules["triton.runtime.driver"] + del type(mod.driver.active.utils).instance + del mod.driver.active.utils + + gc.collect() + + +_registered_caches: list[Any] = [] + + +def clear_on_fresh_cache(obj: Any) -> Any: + """ + Use this decorator to register any caches that should be cache_clear'd + with fresh_cache(). + """ + if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): + raise AttributeError(f"{obj} does not have a cache_clear method") + + _registered_caches.append(obj) + return obj + + +def clear_caches() -> None: + """ + Clear all registered caches. + """ + for obj in _registered_caches: + obj.cache_clear() + + +@contextlib.contextmanager +def fresh_cache( + cache_entries: Optional[dict[str, Any]] = None, + dir: Optional[str] = None, + delete: bool = True, +) -> Iterator[None]: + """ + Contextmanager that provides a clean tmp cachedir for pt2 caches. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. 
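+
+    Example (illustrative; ``fn`` and ``x`` stand in for a compiled callable and
+    its input, and Triton entries only appear on a GPU/Triton build)::
+
+        entries: dict[str, int] = {}
+        with fresh_cache(cache_entries=entries):
+            torch.compile(fn)(x)
+        # `entries` now maps filenames from the temporary Triton cache dir to
+        # their sizes in bytes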
+ """ + clear_caches() + + inductor_cache_dir = tempfile.mkdtemp(dir=dir) + try: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + log.debug("Using inductor cache dir %s", inductor_cache_dir) + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + if delete: + if is_windows() and torch.xpu.is_available(): + unload_xpu_triton_pyds() + + shutil.rmtree( + inductor_cache_dir, + # Let's not fail if we can't clean up the temp dir. Also note that for + # Windows, we can't delete the loaded modules because the module binaries + # are open. + onerror=lambda func, path, exc_info: log.warning( + "Failed to remove temporary cache dir at %s", + inductor_cache_dir, + exc_info=exc_info, + ), + ) + except Exception: + log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) + raise + finally: + clear_caches() + + +# Deprecated functions -- only keeping them for BC reasons +clear_on_fresh_inductor_cache = clear_on_fresh_cache +clear_inductor_caches = clear_caches +fresh_inductor_cache = fresh_cache + + +def argsort(seq: Sequence[Any]) -> list[int]: + # preserve original order for equal strides + getter = seq.__getitem__ + a_r = range(len(seq)) + return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 + + +def argsort_sym( + shape_env: ShapeEnv, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]] +) -> list[int]: + def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int: + a_idx, a_val = a + b_idx, b_val = b + + def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool: + if isinstance(expr, bool): + return expr + return shape_env.evaluate_expr(expr, size_oblivious=True) + + if evaluate(a_val < b_val): + return -1 + if evaluate(a_val > b_val): + return 1 + # If strides are the same, prefer the original order. + # (this matches argsort's algorithm). + # For strides = [2048, 2048, 16, 1], this is + # [3, 2, 1, 0]. 
+ if a_idx < b_idx: + return 1 + if a_idx > b_idx: + return -1 + return 0 + + # Strategy: convert all symints to sympy.Expr, then use a custom comparator + exprs = [ + (idx, s.node.expr if isinstance(s, torch.SymInt) else s) + for idx, s in enumerate(seq) + ] + exprs = sorted(exprs, key=functools.cmp_to_key(cmp)) + result = [idx for idx, _ in exprs] + return result + + +@functools.lru_cache(8) +def get_dtype_size(dtype: torch.dtype) -> int: + # TODO: Investigate why uint64 tensor creation causes overflow error: + # Workaround for RuntimeError in memory size calculation, but underlying cause unclear + if dtype == torch.uint64: + return 8 + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +@dataclasses.dataclass +class ValueWithLineMap: + value: str + line_map: list[tuple[int, LineContext]] + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent: int = 0) -> None: + self._lines: list[Union[DeferredLineBase, LineContext, str]] = [] + self._indent = initial_indent + + @contextlib.contextmanager + def set_tabwidth(self, tabwidth: int) -> Iterator[None]: + prev = self.tabwidth + try: + self.tabwidth = tabwidth + yield + finally: + self.tabwidth = prev + + def getvaluewithlinemap(self) -> ValueWithLineMap: + buf = StringIO() + p = 1 + linemap: list[tuple[int, LineContext]] = [] + for li in self._lines: + if isinstance(li, DeferredLineBase): + line = li() + if line is None: + continue + elif isinstance(li, LineContext): + linemap.append((p, li.context)) + continue + else: + line = li + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + line.count("\n") + return ValueWithLineMap(buf.getvalue(), linemap) + + def getvalue(self) -> str: + return self.getvaluewithlinemap().value + + def getrawvalue(self) -> str: + buf = StringIO() + for li in self._lines: + if isinstance(li, DeferredLineBase): + line = li() + if line is None: + continue + elif isinstance(li, LineContext): + continue + else: + line = li + assert isinstance(line, str) + # backslash implies line continuation + if line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self) -> None: + self._lines.clear() + + def __bool__(self) -> bool: + return bool(self._lines) + + def prefix(self) -> str: + return " " * (self._indent * self.tabwidth) + + def newline(self) -> None: + self.writeline("\n") + + def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: + if isinstance(line, LineContext): + self._lines.append(line) + elif isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines( + self, lines: Sequence[Union[LineContext, DeferredLineBase, str]] + ) -> None: + for line in lines: + self.writeline(line) + + def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset: int = 1) -> None: + self._indent += offset + + def do_unindent(self, offset: int = 1) -> None: + self._indent -= offset + + def splice( + self, other_code: Union[IndentedBuffer, str], strip: bool = False + ) -> None: + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + for line in other_code._lines: + if not isinstance(line, 
LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = other_code.rstrip() + for s in other_code.split("\n"): + self.writeline(s) + + def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: + res = IndentedBuffer(initial_indent=self._indent) + res._lines = [func(line) for line in self._lines] + return res + + def __repr__(self) -> str: + return f"{type(self)}({self.getvalue()})" + + def __add__(self, other: Self) -> IndentedBuffer: + assert self._indent == other._indent + res = IndentedBuffer(initial_indent=self._indent) + # TODO(rec): or should this be self.__class__(initial_indent=self._indent)? + res.writelines(self._lines) + res.writelines(other._lines) + return res + + +class FakeIndentedBuffer(IndentedBuffer): + def __init__(self) -> None: + super().__init__() + + def __getattribute__(self, name: str) -> Any: + if name == "__class__": # Allow access to the class attribute + return object.__getattribute__(self, name) + raise RuntimeError( + f"Tried to call self.{name} on FakeIndentedBuffer. This buffer" + "is currently used on TritonTemplateKernel to prevent actual" + "writes to the body without explicitly specifying the body with" + "`TritonTemplateKernel.set_subgraph_body(name)`" + ) + + +@contextlib.contextmanager +def restore_stdout_stderr() -> Iterator[None]: + initial_stdout, initial_stderr = sys.stdout, sys.stderr + try: + yield + finally: + sys.stdout, sys.stderr = initial_stdout, initial_stderr + + +class DeferredLineBase: + """A line that can be 'unwritten' at a later time""" + + def __init__(self, line: str): + if not line.strip(): + line = "" + self.line = line + + def __call__(self) -> Union[str, None]: + """Returns either self.line or None to indicate the line has been 'unwritten'""" + raise NotImplementedError + + def _new_line(self, line: str) -> Self: + """Returns a new deferred line with the same condition""" + raise NotImplementedError + + def with_prefix(self, prefix: str) -> Self: + return self._new_line(f"{prefix}{self.line}") + + def lstrip(self) -> Self: + return self._new_line(self.line.lstrip()) + + def __getitem__(self, index: Union[int, slice]) -> Self: + return self._new_line(self.line[index]) + + def __bool__(self) -> bool: + return bool(self.line) + + def __len__(self) -> int: + return len(self.line) + + +class DelayReplaceLine(DeferredLineBase): + """At end of codegen call `line.replace(key, value_fn())`""" + + def __init__(self, key: str, value_fn: Callable[[], str], line: str): + super().__init__(line) + self.key = key + self.value_fn = value_fn + + def __call__(self) -> str: + return self.line.replace(self.key, self.value_fn()) + + def _new_line(self, line: str) -> DelayReplaceLine: + return DelayReplaceLine(self.key, self.value_fn, line) + + +@functools.cache +def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: + if isinstance(index_or_device, torch.device): + device = index_or_device + else: + device = torch.device(get_gpu_type(), index_or_device) + + prop = DeviceProperties.create(device) + + # SM logic is not relevant to ROCm gpus + # Arbitrarily skipping the older models + if torch.version.hip: + assert prop.major is not None + if prop.major < 9 or prop.major == 10: 
+ log.warning("GPU arch does not support max_autotune_gemm mode usage") + return False + return True + + min_sms = 16 if device.type == "xpu" else 68 # 3080 + avail_sms = prop.multi_processor_count + if avail_sms < min_sms: + log.warning( + "Not enough SMs to use max_autotune_gemm mode", + extra={"min_sms": min_sms, "avail_sms": avail_sms}, + ) + return False + return True + + +@functools.lru_cache +def get_max_num_sms() -> int: + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +def get_num_sms() -> int: + """Handle experimental carveout if set otherwise return hardware SM count""" + # TODO we need to properly guard on this global + carveout = torch._C._get_sm_carveout_experimental() + return get_max_num_sms() - (carveout if carveout is not None else 0) + + +def get_tma_workspace_arg( + num_tma_descriptors: int, + device: torch.device, + num_programs: Optional[int] = None, +) -> WorkspaceArg: + """Builds and returns a WorkspaceArg for the device side TMA workspace buffer.""" + from .codegen.common import WorkspaceArg, WorkspaceZeroMode + + if num_programs is None: + num_programs = get_num_sms() + zero_mode = WorkspaceZeroMode.from_bool(False) + size = num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE + return WorkspaceArg( + count=size, + zero_mode=zero_mode, + device=device, + outer_name=WorkspaceArg.unique_name(), + ) + + +def _use_template_for_gpu( + layout: Layout, allowed_layout_dtypes: list[torch.dtype] +) -> bool: + if layout.dtype not in allowed_layout_dtypes: + log.debug( + "Not using template since dtype %s is not in allowed layout dtypes %s", + layout.dtype, + allowed_layout_dtypes, + ) + return ( + is_gpu(layout.device.type) + and layout.dtype in allowed_layout_dtypes + and is_big_gpu(layout.device) + ) + + +def _use_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") + ] + + +def _use_conv_autotune_backend(backend: str) -> bool: + return backend.upper() in [ + x.strip() for x in config.max_autotune_conv_backends.upper().split(",") + ] + + +def use_triton_template( + layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False +) -> bool: + from .codegen.common import BackendFeature, has_backend_feature + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + if enable_int32: + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] + if enable_float8: + layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2]) + return ( + ( + ( + is_gpu(layout.device.type) + and _use_template_for_gpu(layout, layout_dtypes) + ) + or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) + ) + and (config.max_autotune or config.max_autotune_gemm) + and _use_autotune_backend("TRITON") + and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) + ) + + +def use_triton_tma_template(*matrices: IRNode) -> bool: + from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device + + from .virtualized import V + + def _is_tma_compatible(x: IRNode) -> bool: + if len(x.get_size()) != 2: + return False + + dtype = x.get_dtype() + if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): + return False + + layout = x.get_layout() + transposed = layout.is_transposed() + if not (layout.is_contiguous() or transposed): + return False + + inner_dim = layout.size[1] + if transposed: + inner_dim = layout.size[0] + + if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + 
inner_dim, 32 + ): + return False + + inner_bytes = inner_dim * dtype.itemsize + return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + + if has_triton_stable_tma_api() and config.cpp_wrapper: + # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047) + return False + + return ( + config.triton.enable_persistent_tma_matmul + and has_triton_tma_device() + and all(_is_tma_compatible(m) for m in matrices) + ) + + +def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + return False + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + # output dtype + # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952 + layout_dtypes = [torch.float16, torch.bfloat16, torch.int32] + res = ( + _use_template_for_gpu(layout, layout_dtypes) + and (config.max_autotune or config.max_autotune_gemm) + and _use_autotune_backend("CUTLASS") + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." + ) + return False + return res + + +def _use_cutlass_for_op(op_name: str) -> bool: + """Check if CUTLASS should be used for the given operation.""" + enabled_ops = config.cuda.cutlass_enabled_ops.upper() + if enabled_ops == "ALL": + return True + return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] + + +decompose_k_threshold = 32 + +# To limit compile time +k_splits_limit = 5 + +# Hand-tuned +default_k_splits = [16, 32, 64, 128, 256] + +_IntLike: TypeAlias = Union[int, sympy.Expr] + + +def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: + from torch._inductor.virtualized import V + + return ( + V.graph.sizevars.statically_known_true( + sympy.And( + sympy.Ge(k, decompose_k_threshold * m), + sympy.Ge(k, decompose_k_threshold * n), + ) + ) + and not V.graph.aot_mode # TODO: Support AOTI for decomposeK + and not V.graph.cpp_wrapper + and not config.disable_decompose_k + ) + + +@functools.cache +def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: + # If k is a sympy expression, we can't do any splitting + if isinstance(k, sympy.Expr) and not k.is_number: + return default_k_splits + + if (isinstance(m, sympy.Expr) and not m.is_number) or ( + isinstance(n, sympy.Expr) and not n.is_number + ): + max_k_split = 256 + else: + max_k_split = min(k // m, k // n) + + min_k_split = 2 + # Get all divisors of k, k has to be divisible by kPart + divisors = sympy.divisors(k) + + divisors = [ + divisor + for divisor in divisors + if divisor <= max_k_split and divisor >= min_k_split + ] + + pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], [] + + for d in divisors: + kPart = k // d + + # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128 + if kPart < 128: + continue + + # Power of 2 divisors are best performing, conform to hardware + if (kPart & kPart - 1) == 0 and kPart >= 128: + pow_of_2_divisors.append(d) + # Else check if creates a multiple of 32 + elif kPart % 32 == 0: + mul_of_32_divisors.append(d) + # otherwise, take the smallest values + else: + rest_of_splits.append(d) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return 
pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # If the # of power of 2 divisors are greater than k_splits_limit, return all + # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) + # should never be a massive amount + if len(pow_of_2_divisors) >= k_splits_limit: + return pow_of_2_divisors + else: + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] + + +@functools.cache +def _rocm_native_device_arch_name(device: str) -> str: + return torch.cuda.get_device_properties(device).gcnArchName + + +@functools.cache +def try_import_ck_lib() -> tuple[ + Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any] +]: + try: + import ck4inductor # type: ignore[import] + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library, + gen_ops_preselected, + ) + from ck4inductor.universal_gemm.op import ( # type: ignore[import] + CKGemmOperation, + ) + + package_dirname = os.path.dirname(ck4inductor.__file__) + except ImportError: + + def gen_ops_library() -> list[Any]: + return [] + + def gen_ops_preselected() -> list[Any]: + return [] + + class CKGemmOperation: # type: ignore[no-redef] + pass + + package_dirname = None + return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation + + +def use_ck_template(layout: Layout) -> bool: + # config knobs check 1 + if not (config.max_autotune or config.max_autotune_gemm): + return False + # platform check + if not torch.version.hip: + return False + # tensors must be on GPU + if not layout.device.type == "cuda": + return False + # hardware check + # if config arch list is not specified, get the native arch from the device properties + native_arch = _rocm_native_device_arch_name(layout.device) + requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or { + native_arch.split(":")[0]: native_arch + } + requested_supported_archs = [ + requested_archs[k] + for k in requested_archs.keys() & config.rocm.ck_supported_arch + ] + if not requested_supported_archs: + return False + # supported input dtypes + if layout.dtype not in [torch.float16, torch.bfloat16, torch.float32]: + return False + + ck_package_dirname, _, _, _ = try_import_ck_lib() + + if not ck_package_dirname: + log.warning("Please pip install Composable Kernel package") + return False + + if config.is_fbcode(): + config.rocm.ck_dir = ck_package_dirname + + if not config.rocm.ck_dir: + log.warning("Please set TORCHINDUCTOR_CK_DIR env variable") + return False + + if ck_package_dirname != config.rocm.ck_dir: + log.warning("Invalid path to CK library") + return False + + return True + + +def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + return ( + _use_autotune_backend("CK") + and use_ck_template(layout) + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + +def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + return ( + _use_autotune_backend("CKTILE") + and use_ck_template(layout) + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + +def use_ck_conv_template(layout: Layout) -> bool: + return _use_conv_autotune_backend("CK") and use_ck_template(layout) + + +def _use_template_for_cpu(layout: Layout) -> bool: + return ( + config.max_autotune or config.max_autotune_gemm + ) and layout.device.type == "cpu" + + +def use_cpp_bmm_template( + 
layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode +) -> bool: + from .ir import Layout + + assert isinstance(mat1.layout, Layout) + + return ( + use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) + and mat1.layout.is_contiguous() + ) + + +def use_cpp_gemm_template( + layout: Layout, + mat1: IRNode, + mat2: IRNode, + mat2_transposed: bool = False, + require_constant_mat2: bool = True, + is_woq_int4: bool = False, + q_group_size: Optional[int] = None, +) -> bool: + from . import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype + from .kernel.mm_common import mm_args + + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + + if not config.cpp.weight_prepack: + return False + + int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8] + layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8] + m, n, k, layout, mat1, mat2 = mm_args( + mat1, + mat2, + out_dtype=layout.dtype if int8_gemm else None, + mat2_transposed=mat2_transposed, + use_4x2_dim=is_woq_int4, + ) + + # TODO(jgong5): support dynamic shapes for n or k + if has_free_symbols((n, k)): + return False + + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + + output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype()) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=mat1.get_dtype(), + input2_dtype=mat2.get_dtype(), + output_dtype=output_dtype, + num_threads=parallel_num_threads(), + use_ref=not is_woq_int4, + q_group_size=q_group_size, + ) + + def is_last_dim_stride1(x: IRNode) -> bool: + x.freeze_layout() + return x.get_stride()[-1] == 1 + + return ( + layout.dtype in layout_dtypes + and micro_gemm is not None + and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input + and isinstance(mat2, ir.StorageBox) + and (mat2.is_module_buffer() or not require_constant_mat2) + ) + + +def use_aten_gemm_kernels() -> bool: + return not ( + config.max_autotune or config.max_autotune_gemm + ) or _use_autotune_backend("ATEN") + + +class DebugDirManager: + counter = itertools.count(0) + prev_debug_name: str + + def __init__(self) -> None: + self.id = next(DebugDirManager.counter) + + def __enter__(self) -> None: + self.prev_debug_name = torch._dynamo.config.debug_dir_root + self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" + torch._dynamo.config.debug_dir_root = self.new_name + + def __exit__(self, *args: Any) -> None: + shutil.rmtree(self.new_name) + torch._dynamo.config.debug_dir_root = self.prev_debug_name + + +def run_and_get_code( + fn: Callable[P, _T], + *args: P.args, + **kwargs: P.kwargs, +) -> tuple[_T, list[str]]: + from .graph import GraphLowering + + source_codes: list[str] = [] + + def save_output_code(code: str) -> None: + source_codes.append(code) + + with mock.patch.object(GraphLowering, "save_output_code", save_output_code): + torch._dynamo.reset() + result = fn(*args, **kwargs) + return result, source_codes + + +def run_and_get_kernels( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> tuple[_T, list[str]]: + result, source_codes = run_and_get_code(fn, *args, **kwargs) + kernels = [] + for code in source_codes: + kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL)) + return result, kernels + + +def run_fw_bw_and_get_code(fn: Callable[..., Any]) -> tuple[Any, list[str]]: + def run_with_backward() -> Any: + result = fn() + result.sum().backward() + return result + 
+ return run_and_get_code(run_with_backward) + + +def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]: + """Get the inductor-generated code, but skip any actual compilation or running.""" + from .graph import GraphLowering + + source_codes: list[str] = [] + + def save_output_code(code: str) -> None: + source_codes.append(code) + + def patched_compile_to_module(self: GraphLowering) -> Any: + class DummyModule: + """This is empty to replace the generated triton module""" + + def __init__(self) -> None: + pass + + def call(self, *args: Any, **kwargs: Any) -> None: + # Don't do anything when called + pass + + wrapper_code, kernel_code = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + # Skip all the actual compiling. + nonlocal save_output_code + save_output_code(wrapper_code.value) + if kernel_code: + save_output_code(kernel_code.value) + + return DummyModule() + + with ( + mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ), + mock.patch.object(GraphLowering, "save_output_code", save_output_code), + ): + torch._dynamo.reset() + # Note the return here is None + _ = fn(*args, **kwargs) + + return source_codes + + +def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str: + source_codes = get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) + return source_codes[0] + + +def run_and_get_triton_code( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> str: + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) + return source_codes[0] + + +def run_and_get_graph_lowering( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> tuple[Any, list[GraphLowering]]: + from torch._inductor.graph import GraphLowering + from torch._inductor.output_code import CompiledFxGraph + + real_init = CompiledFxGraph.__init__ + graph_lowerings = [] + + def fake_init(*args: Any, **kwargs: Any) -> None: + real_init(*args, **kwargs) + graph = args[2] + assert isinstance(graph, GraphLowering) + graph_lowerings.append(graph) + + with mock.patch.object(CompiledFxGraph, "__init__", fake_init): + result = fn(*args, **kwargs) + + return result, graph_lowerings + + +@contextlib.contextmanager +def override_lowering( + aten_op: Callable[..., Any], override_fn: Callable[..., Any] +) -> Iterator[None]: + """ + Override the lowering of aten_op with override_fn. + The first argument of override_fn is the original lowering fn. + """ + from torch._inductor import lowering + + orig_fn = lowering.lowerings[aten_op] + try: + lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) + yield + finally: + lowering.lowerings[aten_op] = orig_fn + + +def add_scheduler_init_hook( + pre_fn: Callable[..., Any], post_fn: Optional[Callable[..., Any]] = None +) -> Any: + """ + Add hook functions to be called at the beginning and end of Scheduler.__init__. + Used for unit tests. 
+    """
+ """ + from torch._inductor.scheduler import Scheduler + + orig_fn = Scheduler.__init__ + + def wrapper(scheduler: Any, nodes: Any) -> Any: + pre_fn(scheduler, nodes) + out = orig_fn(scheduler, nodes) + if post_fn: + post_fn(scheduler, nodes) + return out + + return unittest.mock.patch.object(Scheduler, "__init__", wrapper) + + +def developer_warning(msg: str) -> None: + """ + Warnings that will be actionable for PyTorch developers, but not + end users. Allows us to easily disable them in stable releases but + keep them on for nightly builds. + """ + if config.developer_warnings: + log.warning(msg) + else: + log.info(msg) + + +def get_benchmark_name() -> Optional[str]: + """ + An experimental API used only when config.benchmark_kernel is true. + + The benchmark name is only available at codegen time. So we can not + directly call it in benchmark_all_kernels which is run after codegen. + + The function assumes the argument after --only is the benchmark name. + It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc + scripts, this function may return None. + + There are 2 flavors of --only argument we need handle: + 1. --only model_name + 2. --only=model_name + """ + try: + idx = sys.argv.index("--only") + if ( + idx + 1 < len(sys.argv) + and len(sys.argv[idx + 1]) > 0 + and sys.argv[idx + 1][0] != "-" + ): + return sys.argv[idx + 1] + except ValueError: + pass + + for arg in sys.argv: + if arg.startswith("--only="): + return arg[len("--only=") :] + + return None + + +def is_ones(items: Sequence[Any]) -> bool: + return all(x == 1 for x in items) + + +def is_zeros(items: Sequence[Any]) -> bool: + return all(x == 0 for x in items) + + +def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool: + return all( + item.device == torch.device("cpu") + for item in inputs + if isinstance(item, torch.Tensor) + ) + + +def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: + assert isinstance(val, sympy.Expr), ( + "only support sympy.Expr as input to get_sympy_Expr_dtype" + ) + if val.is_integer: # type: ignore[attr-defined] + return torch.int64 + else: + return torch.float64 + + +@contextlib.contextmanager +def maybe_profile(should_profile: bool, *args: Any, **kwargs: Any) -> Iterator[Any]: + if should_profile: + with torch.profiler.profile(*args, **kwargs) as p: + yield p + else: + yield + + +def parallel_num_threads() -> int: + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + +@functools.cache +def get_backend_num_stages() -> int: + from .runtime.triton_helpers import get_backend_options + + options = get_backend_options() + return options.get("num_stages", 2 if torch.version.hip else 3) + + +@functools.cache +def get_device_tflops(dtype: torch.dtype) -> int: + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + + assert dtype in (torch.float16, torch.bfloat16, torch.float32) + + if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): + # Triton API change in https://github.com/triton-lang/triton/pull/2293 + from torch._utils_internal import max_clock_rate + + sm_clock = max_clock_rate() + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype, sm_clock) + + if torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32, sm_clock) + else: + return get_max_simd_tflops(torch.float32, sm_clock) + else: + if dtype in (torch.float16, torch.bfloat16): + return get_max_tensorcore_tflops(dtype) + + if 
torch.backends.cuda.matmul.allow_tf32: + return get_max_tensorcore_tflops(torch.float32) + else: + return get_max_simd_tflops(torch.float32) + + +@functools.cache +def get_gpu_dram_gbps() -> int: + from triton.testing import get_dram_gbps + + return get_dram_gbps() + + +def get_gpu_shared_memory() -> int: + from triton.runtime import driver + + return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) + + +def is_welford_reduction(reduction_type: str) -> bool: + return reduction_type.startswith("welford") + + +def reduction_num_outputs(reduction_type: str) -> int: + if is_welford_reduction(reduction_type): + return 3 + elif reduction_type == "online_softmax_reduce": + return 2 + else: + return 1 + + +def is_linux() -> bool: + return platform.system() == "Linux" + + +def is_windows() -> bool: + return sys.platform == "win32" + + +def has_free_symbols(itr: Iterable[Any]) -> bool: + return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) + + +def is_dynamic(*args: Any) -> bool: + from . import ir + + for t in args: + if isinstance( + t, (ir.TensorBox, ir.StorageBox, ir.BaseView, ir.ComputedBuffer, ir.Buffer) + ): + if has_free_symbols(t.maybe_get_size() or ()) or has_free_symbols( + t.maybe_get_stride() or () + ): + return True + elif not isinstance(t, ir.IRNode): + continue + else: + raise TypeError(f"unexpected type for is_dynamic {type(t)}") + + return False + + +# Placeholder strings used in triton codegen. +class Placeholder(enum.Enum): + # The placeholder for the actual name of a triton kernel. + # e.g. for "def triton_" it would be "triton_" + KERNEL_NAME = "KERNEL_NAME" + + # The descriptive name of the triton kernel; when unique_kernel_names = False, this + # placeholder will be replaced with a string with more information. + DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" + + +def pass_execution_and_save( + func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str +) -> None: + from .pattern_matcher import stable_topological_sort + + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + ) as f: + before_io = io.StringIO() + after_io = io.StringIO() + ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) + print(f"Before:\n{gm.graph}", file=f) + print(gm.graph, file=before_io) + start_time = datetime.now() + with GraphTransformObserver(gm, msg): + func(gm.graph) + time_elapsed = datetime.now() - start_time + # recompile graph + stable_topological_sort(gm.graph) + gm.graph.lint() + gm.recompile() + + print(f"After:\n{gm.graph}", file=f) + print(gm.graph, file=after_io) + t = before_io.getvalue() == after_io.getvalue() + log.info( + "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", + msg, + f.name, + t, + time_elapsed, + ) + + +def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool: + """ + Check if input buffer is a multi-outputs template buffer + """ + from . import ir + + return isinstance(input_buf, ir.CppTemplateBuffer) and isinstance( + input_buf.layout, ir.MultiOutputLayout + ) + + +def is_output_of_multi_outputs_template( + input_buf: Optional[Union[Buffer, Operation]], +) -> bool: + """ + Check if input buffer is a output of multi-outputs template buffer + """ + from . 
import ir + + return ( + isinstance(input_buf, ir.MultiOutput) + and len(input_buf.inputs) == 1 + and is_multi_outputs_template(input_buf.inputs[0]) + ) + + +def is_collective( + node: Optional[Union[Node, Operation]], + op: Optional[torch._ops.OperatorBase] = None, +) -> bool: + if node is None: + return False + + from . import ir + + return ( + type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) + ) or ( + # TODO: this is a temporary solution to ensure that we can identify torchrec's + # communication ops. But in order to allow better communication and computation + # overlap, torchrec's communication ops should be not used. + type(node) == ir.FallbackKernel + and ( + # NOTE: the `hasattr()` check is to bypass errors such as the following: + # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single' + ( + hasattr(torch.ops.torchrec, "all_to_all_single") + and node.op_overload == torch.ops.torchrec.all_to_all_single.default + ) + or ( + hasattr(torch.ops.torchrec, "all_gather_into_tensor") + and node.op_overload + == torch.ops.torchrec.all_gather_into_tensor.default + ) + or ( + hasattr(torch.ops.torchrec, "reduce_scatter_tensor") + and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default + ) + ) + ) + + +def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool: + from . import ir + + return type(node) == ir._WaitKernel + + +def contains_collective(snode: BaseSchedulerNode) -> bool: + from torch._inductor.scheduler import GroupedSchedulerNode + + if isinstance(snode, GroupedSchedulerNode): + return any(contains_collective(x) for x in snode.snodes) + + return is_collective(snode.node) + + +def contains_wait(snode: BaseSchedulerNode) -> bool: + from torch._inductor.scheduler import GroupedSchedulerNode + + if isinstance(snode, GroupedSchedulerNode): + return any(contains_wait(x) for x in snode.snodes) + else: + return is_wait(snode.node) + + +def is_fallback_op( + node: Optional[Operation], + op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]], +) -> bool: + from . 
import ir + + if isinstance(op, torch._ops.OpOverload): + op = [op] + return isinstance(node, ir.FallbackKernel) and node.op_overload in op + + +def buf_name_to_fused_snode( + buf_name: str, name_to_buf: dict[str, Any], name_to_fused_node: dict[str, Any] +) -> Any: + return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()] + + +def find_recursive_deps_of_node( + snode: BaseSchedulerNode, + collected_node_set: MutableSet[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + criteria_cb: Callable[[Any], bool] = lambda snode: False, +) -> None: + if criteria_cb(snode): + return + collected_node_set.add(snode) + for dep in snode.unmet_dependencies: + defining_op_for_dep = buf_name_to_fused_snode( + dep.name, name_to_buf, name_to_fused_node + ) + if defining_op_for_dep in collected_node_set: + continue + find_recursive_deps_of_node( + defining_op_for_dep, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def find_recursive_users_of_node( + snode: BaseSchedulerNode, + collected_node_set: MutableSet[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + criteria_cb: Callable[[Any], bool] = lambda snode: False, +) -> None: + if criteria_cb(snode): + return + collected_node_set.add(snode) + for o in snode.get_outputs(): + for user in o.users: + assert user.node is not None + if user.node.get_name() == "OUTPUT": + continue + if user.node.get_name() not in name_to_fused_node: + continue + user_op = name_to_fused_node[user.node.get_name()] + if user_op in collected_node_set: + continue + find_recursive_users_of_node( + user_op, + collected_node_set, + name_to_buf, + name_to_fused_node, + criteria_cb=criteria_cb, + ) + + +def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int) -> int: + "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" + num_rng_seed_offset_inputs = ( + 2 if torch._functorch.config.functionalize_rng_ops else 0 + ) + # AOT won't lift any parameters if we're inlining NN Modules + # however desugaring subclasses will still add arguments + # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502 + return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs + + +def count_tangents(fx_g: torch.fx.GraphModule) -> int: + """ + Infers which inputs are static for a backwards graph + """ + + def is_saved_tensor(x: Node) -> bool: + return ( + "tangents" not in x.name + and "bwd_seed" not in x.name + and "bwd_base_offset" not in x.name + and "bwd_rng_state" not in x.name + ) + + arg_count = 0 + static_arg_idxs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + if is_saved_tensor(n): + static_arg_idxs.append(arg_count) + arg_count += 1 + + assert static_arg_idxs == list(range(len(static_arg_idxs))) + return len(static_arg_idxs) + + +@dataclasses.dataclass +class BoxedBool: + value: bool + + def __bool__(self) -> bool: + return self.value + + @staticmethod + def disable(obj: Any) -> Union[BoxedBool, bool]: + if isinstance(obj, BoxedBool): + obj.value = False + return obj + return False + + +@contextlib.contextmanager +def collect_defined_kernels(kernel_list: list[str]) -> Iterator[None]: + from .codegen.wrapper import PythonWrapperCodegen + + orig_define_kernel = PythonWrapperCodegen.define_kernel + + def define_kernel( + self: PythonWrapperCodegen, + kernel_name: str, + kernel_code: str, 
+ metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ) -> Any: + kernel_list.append(kernel_code) + return orig_define_kernel( + self, kernel_name, kernel_code, metadata, gpu, cpp_definition + ) + + with mock.patch.object(PythonWrapperCodegen, "define_kernel", define_kernel): + yield + + +def get_cloned_parameter_buffer_name(name: str) -> str: + return name + "__original__" + + +def is_gpu(device: Optional[str]) -> bool: + return device in GPU_TYPES + + +def device_need_guard(device: str) -> bool: + return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now + + +def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool: + # tl.atomic add has bfloat16 support in fbcode + # but not in OSS https://github.com/pytorch/pytorch/issues/97016 + # we will fallback until the code is upstreamed to OSS + if ( + config.is_fbcode() + and dtype == torch.bfloat16 + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and config.bfloat16_atomic_adds_enabled + ): + return False + else: + return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16]) + + +def use_scatter_fallback( + op_overload: torch._ops.OpOverload, + reduction_type: Optional[str], + self_dtype: torch.dtype, + src_dtype: torch.dtype, + src_device_type: str, + src_is_tensor: bool, +) -> bool: + if ( + op_overload.overloadpacket + in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce) + and reduction_type is None + ): + return False + + reduce_ty = ( + "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum" + ) + + return ( + reduction_type not in (None, reduce_ty) + or ( + src_is_tensor + and is_gpu(src_device_type) + and needs_fallback_due_to_atomic_add_limitations(src_dtype) + ) + or ( + op_overload.overloadpacket == torch.ops.aten.scatter_reduce_ + and reduction_type == "sum" + and src_is_tensor + and src_device_type == "cpu" + and config.cpp.fallback_scatter_reduce_sum + and (config.cpp.dynamic_threads or parallel_num_threads() != 1) + ) + or (reduction_type == reduce_ty and self_dtype in (torch.bool, torch.int64)) + or torch.are_deterministic_algorithms_enabled() + ) + + +def dump_node_schedule(node_schedule: Sequence[BaseSchedulerNode]) -> None: + """ + An API that can be used in pdb to dump a node_schedule. + Right mainly dump the read/write dependencies but can add more as needed. + """ + from torch._inductor.codegen.simd import DisableReduction, EnableReduction + from torch._inductor.scheduler import SchedulerNode + + print(f"Node schedule with {len(node_schedule)} nodes") + for idx, node in enumerate(node_schedule): + print(f" {idx:3}:") + if node is EnableReduction: + print("enable reduction") + elif node is DisableReduction: + print("disable reduction") + elif isinstance(node, SchedulerNode): + is_red = node.is_reduction() + print(f"{'red' if is_red else 'pw'} scheduler node") + if is_red: + assert node.node is not None + print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined] + print("ReadDep:") + for dep in node.read_writes.reads: + print(dep) + print("WriteDep:") + for dep in node.read_writes.writes: + print(dep) + else: + raise RuntimeError(f"Unrecognized node type: {type(node)}") + + +def tensor_is_aligned(tensor: torch.Tensor) -> bool: + # See Note: [Input Alignment handling in Inductor] + # Right now, we don't try to guard on the alignment of the storage offset. 
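+    # (A nonzero storage_offset usually comes from a view such as a slice, so a
+    # view's first element can be misaligned even when the underlying storage is
+    # aligned; hence the check below looks at the offset scaled by the dtype size,
+    # modulo GPU_ALIGN_BYTES.)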
+ # When this comment was written, non-symbolic storage_offsets are not guarded on + # but symbolic storage_offsets are. For consistency, we suppress guard creation + # upon performing this check: that ensures that we don't add recompiles when we + # add this logic. + from torch.fx.experimental.symbolic_shapes import statically_known_true + + return statically_known_true( + (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0 + ) + + +def should_assume_input_aligned(example_input: torch.Tensor) -> bool: + # See Note: [Input Alignment handling in Inductor] + + # right now, we only care about alignment for cuda tensors. + if not is_gpu(example_input.device.type): + return False + return config.assume_aligned_inputs or tensor_is_aligned(example_input) + + +def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]: + # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards() + # If it's not available, return a nullcontext. + + # If we're dealing with cudagraphs, we might not have a tracing_context + tracing_context = torch._guards.TracingContext.try_get() + if not tracing_context: + return contextlib.nullcontext() + + # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode + shape_env = tracing_context.fake_mode.shape_env + if not shape_env: + return contextlib.nullcontext() + + return shape_env.suppress_guards() + + +def run_and_get_cpp_code( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> tuple[_T, str]: + # We use the patch context manager instead of using it as a decorator. + # In this way, we can ensure that the attribute is patched and unpatched correctly + # even if this run_and_get_cpp_code function is called multiple times. + with unittest.mock.patch.object(config, "debug", True): + torch._dynamo.reset() + import io + import logging + + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + from torch._inductor.codecache import output_code_log + + output_code_log.addHandler(ch) + prev_level = output_code_log.level + output_code_log.setLevel(logging.DEBUG) + result = fn(*args, **kwargs) + s = log_capture_string.getvalue() + output_code_log.setLevel(prev_level) + output_code_log.removeHandler(ch) + return result, s + + +def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]: + fake_mode = detect_fake_mode(inputs) + + # TODO(voz): It would be nice to enable this assert, but there are lots of tests that + # pass in real inputs for now. + # if len(inputs) > 0: + # assert fake_mode is not None, breakpoint() + + if fake_mode is not None: + return fake_mode.shape_env + + # When there are no tensor inputs, get shape_env from the first SymInt. + for input in inputs: + if isinstance(input, torch.SymInt): + return input.node.shape_env + + # TODO(voz): Should we always have one anyway? + return None + + +def align_inputs_from_check_idxs( + model: Callable[[list[InputType]], _T], + inputs_to_check: Sequence[int], + mutated_input_idxs: OrderedSet[int], +) -> Callable[[list[InputType]], _T]: + if len(inputs_to_check) == 0: + return model + + def run(new_inputs: list[InputType]) -> Any: + old_tensors, new_tensors = copy_misaligned_inputs( + new_inputs, inputs_to_check, mutated_input_idxs + ) + out = model(new_inputs) + + # If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the + # original tensor. 
+ if len(old_tensors): + torch._foreach_copy_(old_tensors, new_tensors) + + return out + + return run + + +def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor: + if 0 in x.size(): + # Short-circuits if the shape has no elements + needed_size = 0 + else: + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1 + ) + buffer = torch.as_strided(x, (needed_size,), (1,)).clone() + return torch.as_strided(buffer, x.size(), x.stride()) + + +def copy_misaligned_inputs( + new_inputs: list[InputType], + check_inputs_idxs: Sequence[int], + return_pair_idxs: Optional[OrderedSet[int]] = None, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """ + Clones misaligned tensors which we inferred were aligned. Returns a tuple of [old_tensors], [new_tensors] for every + cloned tensor which is in `return_pair_idxs`. + """ + + old_tensors: list[torch.Tensor] = [] + new_tensors: list[torch.Tensor] = [] + + # hoist above loop because this is on the hot path + ret_pair_defined = return_pair_idxs is not None + for i in check_inputs_idxs: + _inp = new_inputs[i] + assert isinstance(_inp, torch.Tensor), ( + f"Expected tensors only, but got: {type(_inp)}" + ) + if _inp.data_ptr() % ALIGNMENT: + new_inputs[i] = clone_preserve_strides(_inp) + + if ret_pair_defined and i in return_pair_idxs: # type: ignore[operator] + old_tensors.append(_inp) + new_tensors.append(new_inputs[i]) # type: ignore[arg-type] + + return old_tensors, new_tensors + + +def remove_unaligned_input_idxs( + inputs: Sequence[InputType], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + """ + We require all inputs to be aligned, so introduce a copy for any + that aren't. + """ + aligned_static_input_idxs = [] + for idx in static_input_idxs: + input = inputs[idx] + if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0: + aligned_static_input_idxs.append(idx) + if len(aligned_static_input_idxs) != len(static_input_idxs): + return aligned_static_input_idxs + return static_input_idxs + + +def expr_fits_within_32bit(e: sympy.Expr) -> bool: + from .virtualized import V + + int_max = torch.iinfo(torch.int32).max + size_hint = V.graph.sizevars.size_hint + has_hint = V.graph.sizevars.shape_env.has_hint + + # Allow for unhinted e as long as we can still statically prove + # (e.g., via ValueRanges) that it is still in bounds + if V.graph.sizevars.statically_known_true(e <= int_max): + return True + # Otherwise, the hint MUST exist and be in range + return has_hint(e) and size_hint(e) <= int_max + + +def set_tracing_context_output_strides( + example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph +) -> None: + # Return the output strides to the caller via TracingContext + context = torch._guards.TracingContext.try_get() + if context is not None and context.output_strides is not None: + assert len(context.output_strides) == 0 + shape_env = shape_env_from_inputs(example_inputs) + assert compiled_graph.output_strides is not None + for exprs in compiled_graph.output_strides: + if exprs is None: + context.output_strides.append(None) + else: + fakify_first_call = False + if ctx := torch._guards.TracingContext.try_get(): + fakify_first_call = ctx.fakify_first_call + + def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]: + if shape_env is None: + return int(e) + if fakify_first_call: + return shape_env.deserialize_symexpr(e) + return shape_env.evaluate_symexpr(e) + + context.output_strides.append( + tuple(map_expr(e) for e in exprs) # type: ignore[misc] + ) + + +def 
should_use_remote_fx_graph_cache() -> bool: + if config.fx_graph_remote_cache is not None: + return config.fx_graph_remote_cache + if not config.is_fbcode(): + return False + + if torch._utils_internal.is_fb_unit_test(): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:fx_graph_memcache_version" + ) + + +def normalize_name(name: str) -> str: + return re.sub(r"[^a-zA-Z0-9_]", "_", name) + + +# correct cases where Triton types names don't match PyTorch +_triton_type_mapping = { + "tl.bool": "tl.int1", + "tl.float8_e4m3fn": "tl.float8e4nv", + "tl.float8_e5m2": "tl.float8e5", + "tl.float8_e4m3fnuz": "tl.float8e4b8", + "tl.float8_e5m2fnuz": "tl.float8e5b16", + # TODO: remove when support is added in triton + # https://github.com/triton-lang/triton/issues/6054 + "tl.float8_e8m0fnu": "tl.uint8", + "tl.float4_e2m1fn_x2": "tl.uint8", +} +_torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()} + + +_triton_type_re = re.compile(r"^.*[.]") + + +def triton_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type""" + triton_type_name = _triton_type_re.sub("tl.", str(dtype)) + return _triton_type_mapping.get(triton_type_name, triton_type_name) + + +def triton_type_to_torch(dtype: str) -> torch.dtype: + adjusted_type = _torch_triton_mapping.get(dtype, dtype) + type_name = adjusted_type.replace("tl.", "") + out_dtype = getattr(torch, type_name) + assert isinstance(out_dtype, torch.dtype) + return out_dtype + + +def is_same_tensor(data: torch.Tensor, value: torch.Tensor) -> bool: + return ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ) + + +def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool: + return ( + data.is_mkldnn + and data.size() == value.size() + and data.dtype == value.dtype + and data.device == value.device + and torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value) + ) + + +@functools.cache +def boolean_ops() -> tuple[str, ...]: + return ( + "isinf", + "isnan", + "logical_not", + "logical_and", + "signbit", + "and_", + "le", + "lt", + "ge", + "gt", + "eq", + "ne", + "or_", # TODO should remove this op + "xor", + ) + + +@dataclasses.dataclass +class OpDtypeRule: + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND + override_return_dtype: Optional[torch.dtype] + + +op_dtype_propagation_rules: dict[str, OpDtypeRule] = {} + + +def register_op_dtype_propagation_rules( + name: str, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, + override_return_dtype: Optional[torch.dtype], +) -> None: + op_dtype_propagation_rules[name] = OpDtypeRule( + type_promotion_kind, override_return_dtype + ) + + +op_requires_libdevice_fp64: OrderedSet[str] = OrderedSet() + + +def register_op_requires_libdevice_fp64(name: str) -> None: + op_requires_libdevice_fp64.add(name) + + +def get_current_backend() -> str: + from torch._inductor.virtualized import V + + device_str = V.graph.get_current_device_or_throw().type + if device_str == "cpu": + return config.cpu_backend + elif device_str == "mps": + return "mps" + else: + return config.cuda_backend + + +def upcast_compute_type(dtype: torch.dtype) -> torch.dtype: + 
"""Maybe upcast [b]float16 to float32""" + if ( + dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + and get_current_backend() == "triton" + ): + return torch.float32 + return dtype + + +KeyType = TypeVar("KeyType") +ValType = TypeVar("ValType") + + +class ScopedDict(MutableMapping[KeyType, ValType]): + """ + A dictionary-like object that allows for scoped updates. It maintains + an original dictionary and a set of new items that can override + the original items within the scope. The original dictionary is + unmodified. + """ + + def __init__(self, original_dict: Mapping[KeyType, ValType]): + self.original_dict = original_dict + self.new_items: dict[KeyType, ValType] = {} + + def __getitem__(self, key: KeyType) -> ValType: + if key in self.new_items: + return self.new_items[key] + return self.original_dict[key] + + def __setitem__(self, key: KeyType, value: ValType) -> None: + self.new_items[key] = value + + def __contains__(self, key: object) -> bool: + return key in self.new_items or key in self.original_dict + + def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]: # type: ignore[override] + if key in self.new_items: + return self.new_items[key] + return self.original_dict.get(key, default) + + def __len__(self) -> int: + n = len(self.original_dict) + for k in self.new_items: + if k not in self.original_dict: + n += 1 + return n + + def __iter__(self) -> Iterator[KeyType]: + yield from self.original_dict + for k in self.new_items: + if k not in self.original_dict: + yield k + + def __bool__(self) -> bool: + return bool(self.original_dict or self.new_items) + + def __delitem__(self, key: KeyType) -> None: + raise NotImplementedError + + +@dataclass_transform(frozen_default=True) +def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any: + def wrap(cls: _T) -> _T: + if sys.version_info >= (3, 10): + return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload] + else: + # Polyfill for python=3.9. 
kw_only simply introduces an extra check + # that only kwargs are used (and is not available on 3.9) + return dataclasses.dataclass(cls, frozen=frozen) + + if cls is None: + return wrap + return wrap(cls) + + +def get_donated_idxs() -> Optional[list[int]]: + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and tracing_context.fw_metadata: + return tracing_context.fw_metadata.bw_donated_idxs + return None + + +def set_kernel_post_grad_provenance_tracing( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], + kernel_name: str, + is_extern: bool = False, +) -> None: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction + from .ir import ExternKernelOut + from .virtualized import V + + if is_extern: + assert isinstance(node_schedule, ExternKernelOut) + curr_node_info = ( + V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + ) + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + else: + assert isinstance(node_schedule, list) + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + curr_node_info.extend( + origin.name + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + + +class TritonAttrsDescriptorVersion(enum.Enum): + V0_NO_TRITON = 0 + V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor + V2_BACKENDS = 2 # triton.backends.compiler.AttrsDescriptor + V3_BACKENDS_TUPLE = ( + 3 # triton.backends.compiler.AttrsDescriptor, but with tuple support + ) + V4_DICT = 4 # a raw dict + + +@functools.cache +def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion: + if importlib.util.find_spec("triton") is None: + return TritonAttrsDescriptorVersion.V0_NO_TRITON + + import triton.backends.compiler + import triton.compiler.compiler + + if hasattr(triton.backends.compiler, "AttrsDescriptor"): + # Triton 3.2.0 + # AttrsDescriptor was moved from triton.compiler.compiler to triton.backends.compiler. + # AttrsDescriptor and its serialization format were also changed. + + # TODO: implement V3_BACKENDS_TUPLE + # On Dec 9, 2024, tuple support (triton #5220) was implemented and breaks handling. + # We don't have a way to detect this (and haven't implemented this version) + return TritonAttrsDescriptorVersion.V2_BACKENDS + elif hasattr(triton.compiler.compiler, "AttrsDescriptor"): + # Triton 3.0.0 + return TritonAttrsDescriptorVersion.V1_COMPILER + else: + # After Jan 1, 2025 + # AttrsDescriptor was removed and replaced with a raw dict. + return TritonAttrsDescriptorVersion.V4_DICT + + +def triton_version_uses_attrs_dict() -> bool: + return get_triton_attrs_descriptor_version() == TritonAttrsDescriptorVersion.V4_DICT + + +def is_cudagraph_unsafe_op(node: Operation) -> bool: + """ + Returns True if the node is an op that is not cudagraphable. + Usually only custom ops have this tag. + """ + from . 
import ir + + if not isinstance(node, ir.FallbackKernel): + return False + + if ( + isinstance(node.op_overload, torch._ops.OpOverload) + and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags + ): + return True + + return False + + +def get_ld_library_path() -> str: + path = os.environ.get("LD_LIBRARY_PATH", "") + if config.is_fbcode(): + from libfb.py.parutil import get_runtime_path + + runtime_path = get_runtime_path() + if runtime_path: + lib_path = os.path.join(runtime_path, "runtime", "lib") + path = os.pathsep.join([lib_path, path]) if path else lib_path + + return path + + +def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: + from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen + + return ( + isinstance(wrapper, SubgraphPythonWrapperCodegen) + and wrapper.partition_signatures is not None + ) + + +def dtype_from_size(size: int) -> torch.dtype: + from .virtualized import V + + if V.graph.sizevars.statically_known_lt( + size, 2**31 + ) and V.graph.sizevars.statically_known_geq(size, -(2**31)): + return torch.int32 + else: + return torch.int64 + + +SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu") + + +def is_mkldnn_bf16_supported(device_type: str) -> bool: + """ + Returns True if the device supports MKL-DNN BF16. + """ + if device_type == "cpu": + return torch.ops.mkldnn._is_mkldnn_bf16_supported() + elif "xpu" in device_type: + # match "xpu", "xpu:0", "xpu:1", etc. + return True + return False + + +def is_mkldnn_fp16_supported(device_type: str) -> bool: + """ + Returns True if the device supports MKL-DNN FP16. + """ + if device_type == "cpu": + return torch.ops.mkldnn._is_mkldnn_fp16_supported() + elif "xpu" in device_type: + # match "xpu", "xpu:0", "xpu:1", etc. + return True + return False diff --git a/phivenv/Lib/site-packages/torch/_inductor/virtualized.py b/phivenv/Lib/site-packages/torch/_inductor/virtualized.py new file mode 100644 index 0000000000000000000000000000000000000000..783f2b4916c4d79d06883d87849a37dbc034ad18 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/virtualized.py @@ -0,0 +1,411 @@ +# mypy: allow-untyped-defs +""" +This file provides a number of "global" variables/handlers that are actually +thread local and dynamically scoped, with Inductor patching them to various +implementations depending on the situation. + +These handlers are interacted with in a fairly stylized way. Typically, +we will import V from this module:: + + from .virtualized import V + +Various handlers are accessible as attributes on this module; for example, +you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with +a number. + +There are a few distinct usage patterns for virtualized global variables: + +1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``. + Use ``V.set_current_node`` to change what the current node is while we're + executing some region of code, so code inside that region can query ``V.current_node`` + to find out what it is. This is often more convenient than manually threading + the current node as an argument through all call stacks. + +2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a + given ``compile_fx`` invocation, these typically don't change, but they are + associated with some internal state so they cannot just be global functions. + We install these objects at the beginning of compilation and then you can + conveniently access them without having to pass them around. + +3. Alternate define-by-run interpretations. 
Examples: ``V.ops``, ``V.kernel``. + A commonly used IR in Inductor is define-by-run: instead of maintaining + explicit syntax data structures, we instead represent loop bodies as + callable functions, which internally invoke operations defined on + ``V.ops``. To perform semantic analysis, print or code generate these + operations, we dynamically patch ``V.ops`` with an alternate handler with + the intended semantics and then run the callable function. For example, to + extract out a traditional (FX) graph representation of the define-by-run + IR, simply install a handler that records each ``ops`` call to a graph. + + TODO: Define a parent class / protocol that defines all of the operations + V.ops is expected to support. + +It is typically an error to access a virtualized global without having installed +an appropriate handler (you will get a NullHandler), although in some cases we +provide a default implementation. + +One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is +ubiquitous enough to have its own top level variable, so you will typically see +``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not +equivalent; the former interface supports arithmetic overloads like ``x + y`` +instead of forcing ``ops.add(x, y)``, so it should be preferred. + +Some operators are seemingly unused, but they are implicitly used by ops_wrapper. +In particular, we typically have an operator for every basic pointwise PyTorch operation +supported. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager, contextmanager +from threading import local +from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union + +from torch.utils._ordered_set import OrderedSet + +from .ops_handler import ( # noqa: F401 + DefaultHandler, + KernelFormatterHandler, + MockHandler, + OpsHandler, + ReductionType, + StoreMode, + WrapperHandler, +) + + +if TYPE_CHECKING: + import torch + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.cpp_utils import LocalBufferContext + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.loop_body import InterpreterShim + from torch._subclasses import FakeTensorMode + +threadlocal = local() + +T = TypeVar("T") + + +class NullHandler: + """ + Sentinel indicating that a global variable is unset ala None. Typically, + attempting to access the global variable before it's set is an error, but with + NullHandler it won't fail until you try to access an attribute on it. + """ + + +# If a virtualized value is set to _PoisonedVirtual then any attempt to get the +# value will result an an exception being raised. This is useful if we want to +# trap uninitialized reads of virtualized globals - for example when compiling +# in a subprocess we don't want the child reading globals that weren't copied +# from the parent. +_PoisonedVirtual = object() + + +class Virtualized(Generic[T]): + """ + Implements a global variable that redirects via thread local variable + (NB: construct this class to create the global variable; this is not + a singleton class!) + + This allows us to swap in different op implementations in codegen. + + NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is + the default value of the variable), we sometimes use these variables to + store other things, like booleans. 
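+
+    A handler is normally installed only for a dynamic scope, via the context
+    manager returned by ``_set_handler`` (exposed through the ``set_*`` helpers
+    on ``V``), e.g.::
+
+        with V.set_ops_handler(MockHandler()):
+            ...  # code in this block sees MockHandler through V.ops
+
+    (Illustrative sketch; ``MockHandler`` is the default ``ops`` handler imported
+    from ``ops_handler`` above.)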
+ """ + + def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]): + self._vname = vname + self._key: str = f"__torchinductor_{vname}" + self._default = default + + def _set_handler(self, value: T) -> AbstractContextManager[None]: + prior = self._get_handler(False) + setattr(threadlocal, self._key, value) + + @contextmanager + def ctx(): + try: + yield + finally: + self._set_handler(prior) + + return ctx() + + def _get_handler(self, check_poisoned: bool = True) -> T: + try: + value = getattr(threadlocal, self._key) + if check_poisoned and value is _PoisonedVirtual: + raise RuntimeError( + f"Attempt to use poisoned virtualized value '{self._vname}'." + ) + return value + except AttributeError: + # TODO: To be honest, I feel we probably should just error in this + # case, instead of making a null handler that will probably error + # when you getattr on it + return self._default() # type: ignore[return-value] + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_handler(), name) + + +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. + """ + + def __init__(self): + super().__init__() + self.removed_buffers = OrderedSet[Any]() + self.inplaced_to_remove = OrderedSet[Any]() + self.index_dtype = "tl.int64" + + def get_index_dtype_as_torch_dtype(self): + import torch + + if self.index_dtype == "tl.int64": + return torch.int64 + elif self.index_dtype == "tl.int32": + return torch.int32 + else: + raise ValueError(f"Unknown dtype: {self.index_dtype}") + + +_ops: Virtualized[OpsHandler[Any]] = Virtualized( + "ops", cast(type[OpsHandler[Any]], MockHandler) +) +_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) +_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) +_kernel: Virtualized[NullKernelHandler] = Virtualized( + "kernel", NullKernelHandler +) # TODO: improve type +_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) +_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) +_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) +_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) +_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( + "local_buffer_context", NullHandler +) + + +def _choices_default(): + """ + Lazy init the global choices handler + + We virtualize InductorChoices to allow changing inductor heuristics from out of tree. + """ + from torch._inductor.choices import InductorChoices + + rv = InductorChoices() + setattr(threadlocal, _choices._key, rv) + return rv + + +_choices: Virtualized[InductorChoices] = Virtualized("choices", _choices_default) + + +class OpsValue: + """The return type of most ops calls. + + This exists so we can overload magic methods, and write mathematical + expressions much more fluently. 
So instead of + + ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1) + + we can write + + (_Ap2 * x - _Ap3) * x * x + _1 + + """ + + value: Any + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"OpsValue({self.value!r})" + + def __add__(self, other): + return ops.add(self, other) + + def __mul__(self, other): + return ops.mul(self, other) + + def __sub__(self, other): + return ops.sub(self, other) + + def __neg__(self): + return ops.neg(self) + + def __truediv__(self, other): + return ops.truediv(self, other) + + def __floordiv__(self, other): + return ops.floordiv(self, other) + + def __mod__(self, other): + return ops.mod(self, other) + + def __pow__(self, other): + return ops.pow(self, other) + + def __lt__(self, other): + return ops.lt(self, other) + + def __le__(self, other): + return ops.le(self, other) + + def __eq__(self, other): + return ops.eq(self, other) + + def __ne__(self, other): + return ops.ne(self, other) + + def __gt__(self, other): + return ops.gt(self, other) + + def __ge__(self, other): + return ops.ge(self, other) + + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __or__(self, other): + return ops.bitwise_or(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) + + def __invert__(self): + return ops.bitwise_not(self) + + def __rshfit__(self, n): + return ops.bitwise_right_shift(self, n) + + def __lshift__(self, n): + return ops.bitwise_left_shift(self, n) + + +class OpsWrapper(DefaultHandler): + """This wraps any returned IR values into an `OpsValue` instance, so that we + can overload the magic methods for writing mathematical expressions fluently. + """ + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) + + @staticmethod + def _unwrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsWrapper._unwrap(v) for v in x) + if isinstance(x, OpsValue): + return x.value + return x + + @staticmethod + def _wrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsValue(v) for v in x) + return OpsValue(x) + + @staticmethod + def indirect_indexing(index, size, check=True, wrap_neg=True): + # Returns a sympy value, not IR value + index = OpsWrapper._unwrap(index) + return _ops.indirect_indexing(index, size, check, wrap_neg) + + +ops: OpsHandler[Any] = OpsWrapper() + + +class _V: + MockHandler = MockHandler + KernelFormatterHandler = KernelFormatterHandler + WrapperHandler = WrapperHandler + + set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = ( + _ops._set_handler + ) + get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler + get_real_inputs: Callable[[], Any] = _real_inputs._get_handler + set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler + get_fake_mode: Callable[[], Any] = _fake_mode._get_handler + set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler + set_debug_handler: Callable[[Any], Any] = _debug._set_handler + set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler + set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler + get_aot_compilation: 
Callable[[], Any] = _aot_compilation._get_handler + set_current_node: Callable[[Any], Any] = _current_node._set_handler + get_current_node: Callable[[], Any] = _current_node._get_handler + set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler + get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler + set_choices_handler: Callable[[Any], Any] = _choices._set_handler + + @property + def ops(self) -> OpsHandler[Any]: + """The operator handler specific to the current codegen task""" + return _ops._get_handler() + + @property + def graph(self) -> GraphLowering: + """The graph currently being generated""" + return _graph._get_handler() + + @property + def real_inputs(self): + """non-fake example inputs""" + return _real_inputs._get_handler() + + @property + def fake_mode(self): + """The graph currently being generated""" + return _fake_mode._get_handler() + + @property + def kernel(self): + """The kernel currently being generated""" + return _kernel._get_handler() + + @property + def debug(self): + return _debug._get_handler() + + @property + def interpreter(self): + return _interpreter._get_handler() + + @property + def aot_compilation(self): + return _aot_compilation._get_handler() is True + + @property + def current_node(self): + return _current_node._get_handler() + + @property + def local_buffer_context(self): + return _local_buffer_context._get_handler() + + @property + def choices(self) -> InductorChoices: + return _choices._get_handler() + + +V = _V() diff --git a/phivenv/Lib/site-packages/torch/_inductor/wrapper_benchmark.py b/phivenv/Lib/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..61c3feccbada8d05eee0841131b18627b11a93f3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,495 @@ +import argparse +import dataclasses +import datetime +import tempfile +from collections import defaultdict +from types import ModuleType +from typing import Any, Optional, Protocol + +import torch +from torch.autograd import DeviceType +from torch.utils._ordered_set import OrderedSet + +from .runtime.benchmarking import benchmarker +from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes + + +class BenchmarkCallableType(Protocol): + def __call__(self, times: int, repeat: int) -> float: ... + + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code: str) -> str: + """ + Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod: ModuleType) -> str: + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. 
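+    If zero or more than one of these category names is found in the module,
+    "unknown" is returned.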
+ """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod: ModuleType): # type: ignore[no-untyped-def] + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels( + benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]] +) -> None: + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. + """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_mod in PyCodeCache.modules: + kernel_key = kernel_mod.key + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str( + ms: float, + n_regs: Optional[Any], + n_spills: Optional[Any], + shared: Optional[Any], + prefix: str = "", + ) -> str: + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) + assert len(triton_kernel.launchers) == 1, ( + "Autotuner should have selected the best config" + ) + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( + "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclasses.dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. 
+ count: float + + +def parse_profile_event_list( + benchmark_name: str, + event_list: torch.autograd.profiler_util.EventList, + wall_time_ms: float, + nruns: int, + device_name: str, +) -> None: + def get_self_device_time( + ev: torch.autograd.profiler_util.EventList, + ) -> float: + """ + ev.self_device_time_total is in microsecond. Convert to millisecond. + """ + return ev.self_device_time_total / 1000 / nruns # type: ignore[attr-defined] + + all_events: dict[str, list[ProfileEvent]] = defaultdict(list) + + def add_event( + ev: torch.autograd.profiler_util.EventList, + category: str, + ) -> None: + profile_ev = ProfileEvent( + category=category, + key=ev.key, # type: ignore[attr-defined] + self_device_time_ms=get_self_device_time(ev), + count=ev.count / nruns, # type: ignore[operator] # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category: str, profile_events: list[ProfileEvent]) -> float: + if not device_name: + return 0.0 + + from tabulate import tabulate + + profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_device_time_ms + percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, + headers=[ + "Kernel", + f"Self {device_name.upper()} TIME (ms)", + "Count", + "Percent", + ], + ) + ) + return total_time + + def report() -> None: + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), ( + f"{list(all_events.keys())}" + ) + + per_category_wall_time = {} + total_device_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_device_ms += _time + + device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%" + if device_name: + print( + f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}" + ) + else: + print("No device detected") + + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! 
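+        # An example of the emitted line (made-up values; see the column list below):
+        #   Output for tabulate: some_benchmark, 40.00%, 30.00%, 10.00%, 0.00%, 0.00%, 80.00%, 12.345ms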
+ # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, device_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +def perf_profile( + wall_time_ms: float, + times: int, + repeat: int, + benchmark_name: str, + benchmark_compiled_module_fn: BenchmarkCallableType, +) -> None: + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_device_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device + ) + + +def ncu_analyzer( + benchmark_name: str, + benchmark_compiled_module_fn: BenchmarkCallableType, + args: argparse.Namespace, +) -> None: + import inspect + import os + import subprocess + + kernel_regex = args.ncu_kernel_regex + metrics = args.ncu_metrics + + module_file = inspect.getfile(benchmark_compiled_module_fn) + module_dir = os.path.dirname(module_file) + module_name = os.path.splitext(os.path.basename(module_file))[0] + + ncu_dir = tempfile.gettempdir() + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ncu_output = os.path.join(ncu_dir, f"ncu_output_{timestamp}.ncu-rep") + python_cmd = ( + f"""import sys; sys.path.insert(0, '{module_dir}'); """ + f"""from {module_name} import benchmark_compiled_module; """ + """benchmark_compiled_module(times=1, repeat=1)""" + ) + + ncu_cmd = [ + "ncu", + "--target-processes", + "all", + "--replay-mode", + "kernel", + "--kernel-name-base", + "function", + "--print-units", + "base", + "--import-source", + "yes", + "--force-overwrite", + "--export", + ncu_output, + ] + + if kernel_regex: + ncu_cmd.extend(["--kernel-name", f"regex:{kernel_regex}"]) + + if metrics: + ncu_cmd.extend(["--metrics", metrics]) + else: + ncu_cmd.extend(["--set", "full"]) + + ncu_cmd.extend( + [ + "python", + "-c", + python_cmd, + ] + ) + + try: + subprocess.run(ncu_cmd, check=True) + print(f"\nNCU profiling results for benchmark {benchmark_name}:") + print(f"NCU report has been written to {ncu_output}") + + except subprocess.CalledProcessError as e: + print(f"NCU profiling failed with error: {e}") + return + + +def collect_memory_snapshot( + benchmark_compiled_module_fn: BenchmarkCallableType, +) -> None: + assert torch.cuda.is_available() + + torch.cuda.memory._record_memory_history(max_entries=100000) + benchmark_compiled_module_fn(times=10, repeat=1) # run 10 times + snapshot_path = f"{tempfile.gettempdir()}/memory_snapshot.pickle" + torch.cuda.memory._dump_snapshot(snapshot_path) + torch.cuda.memory._record_memory_history(enabled=None) + print(f"The collect memory snapshot has been written to {snapshot_path}") + + +# With AOTAutograd cache, we directly call the compiled module. 
So prevent +# Dynamo from reentering +@torch.compiler.disable # type: ignore[misc] +def compiled_module_main( + benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType +) -> None: + """ + This is the function called in __main__ block of a compiled module. + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + parser.add_argument( + "--cuda-memory-snapshot", + action="store_true", + help=""" + Whether to collect CUDA memory snapshot. Refer to + "https://pytorch.org/blog/understanding-gpu-memory-1/ + for details about how to visualize the collected snapshot + """, + ) + parser.add_argument( + "--ncu", + action="store_true", + help="Whether to run ncu analysis", + ) + parser.add_argument( + "--ncu-kernel-regex", + type=str, + default=None, + help=( + "Filter kernels profiled by NCU using a regex (e.g., '^triton_.*'). " + "Maps to '--kernel-name regex:'. " + "If None, NCU will profile all kernels." + ), + ) + parser.add_argument( + "--ncu-metrics", + type=str, + default=None, + help=( + "Comma-separated list of NCU metrics to collect (e.g., 'dram__bytes.sum.per_second'). " + "If None, NCU will use '--set full'." + ), + ) + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = 10 + repeat = 10 + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if torch.cuda.is_available(): + peak_mem = torch.cuda.max_memory_allocated() + print(f"Peak GPU memory usage {peak_mem / 1e6:.3f} MB") + + if torch.cuda.is_available() and args.cuda_memory_snapshot: + collect_memory_snapshot(benchmark_compiled_module_fn) + + if args.profile: + perf_profile( + wall_time_ms, + times, + repeat, + benchmark_name, + benchmark_compiled_module_fn, + ) + if args.ncu: + ncu_analyzer( + benchmark_name, + benchmark_compiled_module_fn, + args=args, + ) diff --git a/phivenv/Lib/site-packages/torch/lib/fmt.lib b/phivenv/Lib/site-packages/torch/lib/fmt.lib new file mode 100644 index 0000000000000000000000000000000000000000..667793a1ede0a051a860ee3db6bb4b1b888dc395 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/fmt.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f03ed71886f4b7a54c18a50b94092704f686bfabf1e93c6299c2576609dde32c +size 3382620 diff --git a/phivenv/Lib/site-packages/torch/lib/libiomp5md.dll b/phivenv/Lib/site-packages/torch/lib/libiomp5md.dll new file mode 100644 index 0000000000000000000000000000000000000000..5526ba6b9695b3d9d7d000d197a9efb120902d9b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/lib/libiomp5md.dll @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9d66ed25f1a0ea725fa3a41b22cfd5d182c19dbe4771d9c90ca02ad7466f6a1 +size 1613680 diff --git a/phivenv/Lib/site-packages/torch/lib/libittnotify.lib b/phivenv/Lib/site-packages/torch/lib/libittnotify.lib new file mode 100644 index 0000000000000000000000000000000000000000..14c4b305d0e9df4f07b29316fa4d54b4ff9bd1fd --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/lib/libittnotify.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e9e4855fd8fafa1d8c869364713ad38a93dbb09611470a2d6b04e41dbe5f6f +size 591548 diff --git a/phivenv/Lib/site-packages/torch/testing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb5201aa68c20c71e51b6fe41af18e1608b13dc7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/__pycache__/_comparison.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/__pycache__/_comparison.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ceab207b67f8c5aed6fcffe3f8336b9e54a85ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/__pycache__/_comparison.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/__pycache__/_creation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/__pycache__/_creation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f3b45ea7da0dec7261fe1f4ac4d0aa4467107be Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/__pycache__/_creation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e1a697b86d537714e9541b3a6a47c22296a50c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64445653c4cc022321855a3715e8c4dc83ae4809 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f965ca53104d7bfcb35b61ba72003c5868168758 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autocast_test_lists.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a47808c9249940a38dd4b69458fd6c838769307 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/autograd_function_db.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6563dbb6b25763c5c0ecca02e579ec89cc1950fb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/check_kernel_launches.cpython-39.pyc differ 
diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b23af90d80cd10c4dd8b97b5d0cbbb90d24a048d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_cuda.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..094d7328cd92d39af8f00c553aa363ac9be85a64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_device_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aa5ffe422ee6404d8feff02f6e7ece2bbc536fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dist_composable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb802e286f385cc6d44936b04bed410eca6391a6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..762adbd54cbed06400118022dc7e19a407220fbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_dtype.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39e189deba673ac2eeb30400821af9901bb10987 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_fsdp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..118f63f0fff9510647c3e2e58b45c85df8cfb850 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_jit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5613cda9e47af5d280833c35f9a28183e1eabed4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mkldnn.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_modules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..375146c032f8406f85a6496e27764a38475ec879 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_modules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb7977657b2a175860ab6750836206b9e51835b3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_mps.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4538258c7c8cd2ce850510f2ae2d6f68fb70c418 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_optimizers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce6c9eb0805b60833db60d693c971dcd849f5f88 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_pruning.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4da244dae7fb55fe9f1041017783f169a650c96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_quantized.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ab06ac1217ff8a23159b29debfa8aff675ae793 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/common_subclass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14a6bd3812bfe6f2935c968d5f9a28242527d9a3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/composite_compliance.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc306499158907138b8630301de3740ce47b7dc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_op_db.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a66a80c20f2de194d5f016454eccfebaeee74d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/custom_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94d36c01136be93c1848bee2ce0aabfbd3c1ab64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dist_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ddf5cd51a0cb464862388223273f009dc2bfb8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac4d7170f3ae37328685a00b0c0ef22c04882bea Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ce5900e8cecb3399bd43b9c14b140101572771d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91542822e40eb7be3d3eec8b9257d90a9155895d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/fake_config_module3.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0af15f3098668705eec16d6da5583e76b2271108 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hop_db.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb21a956d53840cc304eb1e15737df11c92c6b30 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/hypothesis_utils.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ad7faa03c8594deb46dad133e201b046ddeed2d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/inductor_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67d92b6f2e188cb286c35fc6f47cd577d16bde9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1043df53013089dd86f526955f00f5e55b6eae24 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/jit_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0dd03310644000d71af1b9da3faadd1cc2c4cf41 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11c9427eba6041177f96d0a77e4d30342c9b4df4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/logging_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83cbd2d7e031e6611a9ed5c37593c786fa1e8596 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e9a080ded36d76cb46af61294d5b5f2d8d63ec3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/static_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e6ef8ae87425b84d5c373d74d0f5c200bddf64f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/subclasses.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75acaee001ed395d511fe191e572798be4f1eab3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/torchbind_impls.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92e78a1bdbe123ab97cd264022a7dbacd685febf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/triton_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6f3b60ded14f653369f5199f3f6c92e735b0b0d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/__pycache__/two_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__init__.py b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ff0448a99775ca910765bd4958e7e1239e3f0ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..680bad908a54268766c55b41de59edfea2c48b82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/future_div.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-39.pyc b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32df94a7790055d8889f837ed668d21a94dd224 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/testing/_internal/test_module/no_future_div.py b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/no_future_div.py new file mode 100644 index 0000000000000000000000000000000000000000..5c48be0d09e2af1a169b83ec301acfee54574bf9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/testing/_internal/test_module/no_future_div.py @@ -0,0 +1,11 @@ +# mypy: ignore-errors + +import torch # noqa: F401 + + +def div_int_nofuture(): + return 1 / 2 + + +def div_float_nofuture(): + return 3.14 / 0.125 diff --git a/phivenv/Lib/site-packages/torch/utils/__init__.py 
b/phivenv/Lib/site-packages/torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ddd5fd76d1425f7b0bbacb2b31e350afd6cd616 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/__init__.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs + +import copyreg +import os.path as _osp +import weakref + +import torch +from torch.utils import ( + backcompat as backcompat, + collect_env as collect_env, + data as data, + deterministic as deterministic, + hooks as hooks, +) +from torch.utils.backend_registration import ( + generate_methods_for_privateuse1_backend, + rename_privateuse1_backend, +) +from torch.utils.cpp_backtrace import get_cpp_backtrace +from torch.utils.throughput_benchmark import ThroughputBenchmark + + +def set_module(obj, mod): + """ + Set the module attribute on a python object for a given object for nicer printing + """ + if not isinstance(mod, str): + raise TypeError("The mod argument should be a string") + obj.__module__ = mod + + +if torch._running_with_deploy(): + # not valid inside torch_deploy interpreter, no paths exists for frozen modules + cmake_prefix_path = None +else: + cmake_prefix_path = _osp.join( + _osp.dirname(_osp.dirname(__file__)), "share", "cmake" + ) + + +def swap_tensors(t1, t2): + """ + This function swaps the content of the two Tensor objects. + At a high level, this will make t1 have the content of t2 while preserving + its identity. + + This will not work if t1 and t2 have different slots. + """ + # Ensure there are no weakrefs + if weakref.getweakrefs(t1): + raise RuntimeError("Cannot swap t1 because it has weakref associated with it") + if weakref.getweakrefs(t2): + raise RuntimeError("Cannot swap t2 because it has weakref associated with it") + t1_slots = set(copyreg._slotnames(t1.__class__)) # type: ignore[attr-defined] + t2_slots = set(copyreg._slotnames(t2.__class__)) # type: ignore[attr-defined] + if t1_slots != t2_slots: + raise RuntimeError("Cannot swap t1 and t2 if they have different slots") + + def swap_attr(name): + tmp = getattr(t1, name) + setattr(t1, name, (getattr(t2, name))) + setattr(t2, name, tmp) + + def error_pre_hook(grad_outputs): + raise RuntimeError( + "Trying to execute AccumulateGrad node that was poisoned by swap_tensors " + "this can happen when you try to run backward on a tensor that was swapped. " + "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` " + "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) " + "between running forward and backward. To resolve this, please only change the " + "device/dtype before running forward (or after both forward and backward)." + ) + + def check_use_count(t, name="t1"): + use_count = t._use_count() + error_str = ( + f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} " + f"make sure you are not holding references to the tensor in other places." 
+ ) + if use_count > 1: + if use_count == 2 and t.is_leaf: + accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node + # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge + if t._use_count() == 2: + accum_grad_node.register_prehook(error_pre_hook) + else: + raise RuntimeError(error_str) + else: + raise RuntimeError(error_str) + + check_use_count(t1, "t1") + check_use_count(t2, "t2") + + # Swap the types + # Note that this will fail if there are mismatched slots + swap_attr("__class__") + + # Swap the dynamic attributes + swap_attr("__dict__") + + # Swap the slots + for slot in t1_slots: + if hasattr(t1, slot) and hasattr(t2, slot): + swap_attr(slot) + elif hasattr(t1, slot): + setattr(t2, slot, (getattr(t1, slot))) + delattr(t1, slot) + elif hasattr(t2, slot): + setattr(t1, slot, (getattr(t2, slot))) + delattr(t2, slot) + + # Swap the at::Tensor they point to + torch._C._swap_tensor_impl(t1, t2) diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bff468f720885306f587c093cf34cc21976dc032 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1672e15492735ac328981cafa9c3e7be245a806 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_appending_byte_serializer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_backport_slots.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_backport_slots.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f2453a3f66e8c258a0518f6cee688279a9bc240 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_backport_slots.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_config_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_config_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c023123e0fb113b5b8b44445bbff4bbd6d932cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_config_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_content_store.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_content_store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4567678ef444f13e5ecf325e80a6397c7c352d7b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_content_store.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_contextlib.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_contextlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80a7ac6dde67e29d2f1bcd7c779ecd0bea252bad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_contextlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-39.pyc new 
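Editor's note: the torch/utils/__init__.py hunk above defines swap_tensors, which exchanges the contents of two tensor objects (class, __dict__, slots, and the underlying TensorImpl) while each Python object keeps its identity. A minimal usage sketch, with made-up shapes and variable names, assuming the function behaves as written in the hunk:

import torch
from torch.utils import swap_tensors

t1 = torch.zeros(2)
t2 = torch.ones(3)
id_before = id(t1)

# Neither tensor may have weakrefs, and both must share the same slots;
# plain freshly created tensors like these satisfy both preconditions.
swap_tensors(t1, t2)

assert id(t1) == id_before   # same Python object ...
assert t1.shape == (3,)      # ... now backed by t2's storage and metadata
assert t2.shape == (2,)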
file mode 100644 index 0000000000000000000000000000000000000000..cfeb5953a9a90b02400790c569ceae4d497f334c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_embed_headers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db18d464ef598a9fee2e2698e27d8746153ddd59 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cpp_extension_versioner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51a99c59bd0641c297825770d0c0c80dd29d9c8c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_cxx_pytree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_device.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_device.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f446a88f74da5872ea29d540dbe4256cb3319b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_device.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41d6832d5fcefa71affd453e6629a7708b97dbe3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_dtype_abbrs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_exposed_in.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_exposed_in.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22c72cac2e0dd323e67eb97b0ca7a579e7bd282 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_exposed_in.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_filelock.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_filelock.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf3d35b0a7887e733b97429fca9cfca469197298 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_filelock.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e8ca450c40206653325b76862cbd587694ebe3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_foreach_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_freeze.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_freeze.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01f51ed12abcc21033807baeb5ec786ea0eeb06e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_freeze.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_functools.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_functools.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..5eddbf6dd978b86c39d4c1bc1579622272d16ae2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_functools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a6357883865865a6d81ab5d5008513eaa76314a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_get_clean_triton.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_helion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_helion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21403918e1d97d54e79faadcda534d904a5b9fb1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_helion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_import_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_import_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494b3173263d4bbfff8a9f349acd2f3bb5fb3e18 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_import_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_mode_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_mode_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1954d8de657e1535d05c165468b28962b362dc39 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_mode_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_ordered_set.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_ordered_set.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e0c9b24e4fd6bd1c45b852ba4dde62f3bdbb462 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_ordered_set.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..167338a9e1d8fea2d4034f2852b4240f05b0995a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_python_dispatch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_pytree.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_pytree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c189c7dd9955e3cd0e493cba9d977ba255ebfecb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_pytree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_stats.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_stats.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e040887411af79b993192c26ff9a4965d8b8827 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_stats.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_thunk.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_thunk.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..90651d1177ddeac0222db218cf95e19530a3b874 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_thunk.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_traceback.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_traceback.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04e4817318f29f5cf685e2f4b1971f2409cf1a1c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_traceback.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_triton.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_triton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6102ba79387c3f9784febf22ff92ce19b2ab866d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_triton.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_typing_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_typing_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0e6efb2fcd6a22f506faf3b6438fb3e2c114d37 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_typing_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/_zip.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/_zip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5bfaec22058d0ee5620f5a2a4d1425f8a0a30e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/_zip.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/backend_registration.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/backend_registration.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bfc12bdbecdbad573a7b2903bcd60c7c574f80a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/backend_registration.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c269105a67aa6c3f446b2e7934a6693ef497597 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/bundled_inputs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/checkpoint.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/checkpoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b5d3224249ceecd1d6c6fccf1d238b91158f26 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/checkpoint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/collect_env.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/collect_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3234d1fe78ec20057f3ff8057ef6d9edad5bd28c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/collect_env.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1ffa9266309d2049e4df177739d6b89582907adb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_backtrace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_extension.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_extension.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ebb251a8119285611249016a2fcb6beefcabe82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/cpp_extension.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/deterministic.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/deterministic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a498f58b9cedf90a91a2b01ee187deb56e0e7168 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/deterministic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/dlpack.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/dlpack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c747e08f8f68aca2bfa3743a0d6fd2609da80196 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/dlpack.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/file_baton.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/file_baton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f83f327561686a3d67cbf9ea70578e2e3d856966 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/file_baton.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/flop_counter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/flop_counter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3875fcd9f08d61c3deeb7f82480bf25c89be9b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/flop_counter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a58e1169c25b3f025afc56a97e78d91f093a857 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/mkldnn.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/mkldnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6f5dac055fdda779ae23585eb5c43f2c9687abc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/mkldnn.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d4330c1f334b85b4c428cc160cff56d860999c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/mobile_optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/model_zoo.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/model_zoo.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..322d5c6cfe0b41d5d98c432d4581116692b71134 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/model_zoo.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/module_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/module_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18d5602290d6b3c8c29204d511fe28d1f9198925 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/module_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/show_pickle.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/show_pickle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c3b1e34564e66e971d57ff4bfadd68f2fbe24a6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/show_pickle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6ce4c06919b068e4f89947bf1a94b26333ebc89 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/throughput_benchmark.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/__pycache__/weak.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/__pycache__/weak.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1007118e14e2636efac439dd36567640f9a76196 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/__pycache__/weak.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_appending_byte_serializer.py b/phivenv/Lib/site-packages/torch/utils/_appending_byte_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..861d2ee70e9e07c623d85496848de22f55762153 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_appending_byte_serializer.py @@ -0,0 +1,128 @@ +import base64 +import zlib +from collections.abc import Iterable +from typing import Callable, Generic, TypeVar + + +T = TypeVar("T") + +_ENCODING_VERSION: int = 1 + +__all__ = ["AppendingByteSerializer"] + + +####################################### +# Helper classes +####################################### + +CHECKSUM_DIGEST_SIZE = 4 + + +class BytesWriter: + def __init__(self) -> None: + # Reserve CHECKSUM_DIGEST_SIZE bytes for checksum + self._data = bytearray(CHECKSUM_DIGEST_SIZE) + + def write_uint64(self, i: int) -> None: + self._data.extend(i.to_bytes(8, byteorder="big", signed=False)) + + def write_str(self, s: str) -> None: + payload = base64.b64encode(s.encode("utf-8")) + self.write_bytes(payload) + + def write_bytes(self, b: bytes) -> None: + self.write_uint64(len(b)) + self._data.extend(b) + + def to_bytes(self) -> bytes: + digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes( + 4, byteorder="big", signed=False + ) + assert len(digest) == CHECKSUM_DIGEST_SIZE + self._data[0:CHECKSUM_DIGEST_SIZE] = digest + return bytes(self._data) + + +class BytesReader: + def __init__(self, data: bytes) -> None: + # Check for data corruption + assert len(data) >= CHECKSUM_DIGEST_SIZE + digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes( + 4, byteorder="big", signed=False + ) + assert len(digest) == CHECKSUM_DIGEST_SIZE + if data[0:CHECKSUM_DIGEST_SIZE] != digest: + raise 
RuntimeError( + "Bytes object is corrupted, checksum does not match. " + f"Expected: {data[0:CHECKSUM_DIGEST_SIZE]!r}, Got: {digest!r}" + ) + + self._data = data + self._i = CHECKSUM_DIGEST_SIZE + + def is_finished(self) -> bool: + return len(self._data) == self._i + + def read_uint64(self) -> int: + result = int.from_bytes( + self._data[self._i : self._i + 8], byteorder="big", signed=False + ) + self._i += 8 + return result + + def read_str(self) -> str: + return base64.b64decode(self.read_bytes()).decode("utf-8") + + def read_bytes(self) -> bytes: + size = self.read_uint64() + result = self._data[self._i : self._i + size] + self._i += size + return result + + +####################################### +# AppendingByteSerializer +####################################### + + +class AppendingByteSerializer(Generic[T]): + """ + Provides efficient serialization and deserialization of list of bytes + Note that this does not provide any guarantees around byte order + """ + + _serialize_fn: Callable[[BytesWriter, T], None] + _writer: BytesWriter + + def __init__( + self, + *, + serialize_fn: Callable[[BytesWriter, T], None], + ) -> None: + self._serialize_fn = serialize_fn + self.clear() + + def clear(self) -> None: + self._writer = BytesWriter() + # First 8-bytes are for version + self._writer.write_uint64(_ENCODING_VERSION) + + def append(self, data: T) -> None: + self._serialize_fn(self._writer, data) + + def extend(self, elems: Iterable[T]) -> None: + for elem in elems: + self.append(elem) + + def to_bytes(self) -> bytes: + return self._writer.to_bytes() + + @staticmethod + def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]: + reader = BytesReader(data) + assert reader.read_uint64() == _ENCODING_VERSION + + result: list[T] = [] + while not reader.is_finished(): + result.append(deserialize_fn(reader)) + return result diff --git a/phivenv/Lib/site-packages/torch/utils/_backport_slots.py b/phivenv/Lib/site-packages/torch/utils/_backport_slots.py new file mode 100644 index 0000000000000000000000000000000000000000..c84ac634aa8d35694527b85c7f35dd4dfa2b105d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_backport_slots.py @@ -0,0 +1,116 @@ +# This code is backported from python 3.10 dataclasses. Once 3.10 becomes the +# minimum supported we should use dataclass(slots=True) instead. + +from __future__ import annotations + +import dataclasses +import itertools +from typing import TYPE_CHECKING, TypeVar + + +if TYPE_CHECKING: + from collections.abc import Generator + + from _typeshed import DataclassInstance + + +__all__ = ["dataclass_slots"] + +_T = TypeVar("_T", bound="DataclassInstance") + + +def dataclass_slots(cls: type[_T]) -> type[DataclassInstance]: + assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses." + + def _get_slots(cls: type[DataclassInstance]) -> Generator[str, None, None]: + slots = cls.__dict__.get("__slots__") + # `__dictoffset__` and `__weakrefoffset__` can tell us whether + # the base type has dict/weakref slots, in a way that works correctly + # for both Python classes and C extension types. Extension types + # don't use `__slots__` for slot creation + if slots is None: + slots = [] + if getattr(cls, "__weakrefoffset__", -1) != 0: + slots.append("__weakref__") + if getattr(cls, "__dictrefoffset__", -1) != 0: + slots.append("__dict__") + yield from slots + elif isinstance(slots, str): + yield slots + # Slots may be any iterable, but we cannot handle an iterator + # because it will already be (partially) consumed. 
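Editor's note: the _appending_byte_serializer.py hunk above builds a small checksummed, length-prefixed byte format (a 4-byte CRC32, then an encoding-version word, then the records). A round-trip sketch follows; the string record type and the helper names write_record/read_record are chosen only for illustration:

from torch.utils._appending_byte_serializer import (
    AppendingByteSerializer,
    BytesReader,
    BytesWriter,
)

def write_record(writer: BytesWriter, value: str) -> None:
    writer.write_str(value)

def read_record(reader: BytesReader) -> str:
    return reader.read_str()

ser = AppendingByteSerializer(serialize_fn=write_record)
ser.extend(["alpha", "beta"])
ser.append("gamma")
blob = ser.to_bytes()  # checksum + version + length-prefixed payloads

assert AppendingByteSerializer.to_list(blob, deserialize_fn=read_record) == ["alpha", "beta", "gamma"]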
+ elif not hasattr(cls, "__next__"): + yield from slots + else: + raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") + + def _add_slots( + cls: type[DataclassInstance], is_frozen: bool, weakref_slot: bool + ) -> type[DataclassInstance]: + # Need to create a new class, since we can't set __slots__ + # after a class has been created. + + # Make sure __slots__ isn't already set. + if "__slots__" in cls.__dict__: + raise TypeError(f"{cls.__name__} already specifies __slots__") + + # Create a new dict for our new class. + cls_dict = dict(cls.__dict__) + field_names = tuple(f.name for f in dataclasses.fields(cls)) + # Make sure slots don't overlap with those in base classes. + inherited_slots = set( + itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1])) + ) + # The slots for our class. Remove slots from our base classes. Add + # '__weakref__' if weakref_slot was given, unless it is already present. + cls_dict["__slots__"] = tuple( + itertools.filterfalse( + inherited_slots.__contains__, + itertools.chain( + # gh-93521: '__weakref__' also needs to be filtered out if + # already present in inherited_slots + field_names, + ("__weakref__",) if weakref_slot else (), + ), + ), + ) + + for field_name in field_names: + # Remove our attributes, if present. They'll still be + # available in _MARKER. + cls_dict.pop(field_name, None) + + # Remove __dict__ itself. + cls_dict.pop("__dict__", None) + + # Clear existing `__weakref__` descriptor, it belongs to a previous type: + cls_dict.pop("__weakref__", None) # gh-102069 + + # And finally create the class. + qualname = getattr(cls, "__qualname__", None) + cls = type(cls.__name__, cls.__bases__, cls_dict) + if qualname is not None: + cls.__qualname__ = qualname + + def _dataclass_getstate(self: _T) -> object: + fields = dataclasses.fields(self) + return [getattr(self, f.name) for f in fields] + + def _dataclass_setstate(self: _T, state: list[object]) -> None: + fields = dataclasses.fields(self) + for field, value in zip(fields, state): + # use setattr because dataclass may be frozen + object.__setattr__(self, field.name, value) + + if is_frozen: + # Need this for pickling frozen classes with slots. 
+ if "__getstate__" not in cls_dict: + cls.__getstate__ = _dataclass_getstate # type: ignore[method-assign, assignment] + if "__setstate__" not in cls_dict: + cls.__setstate__ = _dataclass_setstate # type: ignore[attr-defined] + + return cls + + params = getattr(cls, dataclasses._PARAMS) # type: ignore[attr-defined] + weakref_slot = getattr(params, "weakref_slot", False) + return _add_slots(cls, params.frozen, weakref_slot) diff --git a/phivenv/Lib/site-packages/torch/utils/_config_module.py b/phivenv/Lib/site-packages/torch/utils/_config_module.py new file mode 100644 index 0000000000000000000000000000000000000000..67fd2126843c1acbd36735d1ad60a7f4fa362455 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_config_module.py @@ -0,0 +1,803 @@ +import contextlib +import copy +import hashlib +import importlib +import inspect +import io +import os +import pickle +import sys +import tokenize +import unittest +from dataclasses import dataclass +from types import FunctionType, ModuleType +from typing import ( + Any, + Callable, + Generic, + NoReturn, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import deprecated +from unittest import mock + +from torch._utils_internal import justknobs_check + + +# Types saved/loaded in configs +CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict) + + +# Duplicated, because mypy needs these types statically +T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict]) + + +_UNSET_SENTINEL = object() + + +@dataclass +class _Config(Generic[T]): + """Represents a config with richer behaviour than just a default value. + :: + i.e. + foo = Config(justknob="//foo:bar", default=False) + install_config_module(...) + + This configs must be installed with install_config_module to be used + + Precedence Order: + alias: If set, the directly use the value of the alias. + env_name_force: If set, this environment variable has precedence over + everything after this. + If multiple env variables are given, the precendence order is from + left to right. + user_override: If a user sets a value (i.e. foo.bar=True), that + has precedence over everything after this. + env_name_default: If set, this environment variable will override everything + after this. + If multiple env variables are given, the precendence order is from + left to right. + justknob: If this pytorch installation supports justknobs, that will + override defaults, but will not override the user_override precendence. + default: This value is the lowest precendance, and will be used if nothing is + set. + + Environment Variables: + These are interpreted to be either "0" or "1" to represent true and false. + + Arguments: + justknob: the name of the feature / JK. In OSS this is unused. + default: is the value to default this knob to in OSS. + alias: The alias config to read instead. + env_name_force: The environment variable, or list of, to read that is a FORCE + environment variable. I.e. it overrides everything except for alias. + env_name_default: The environment variable, or list of, to read that changes the + default behaviour. I.e. user overrides take preference. 
+ """ + + default: Union[T, object] + justknob: Optional[str] = None + env_name_default: Optional[list[str]] = None + env_name_force: Optional[list[str]] = None + alias: Optional[str] = None + + def __init__( + self, + default: Union[T, object] = _UNSET_SENTINEL, + justknob: Optional[str] = None, + env_name_default: Optional[Union[str, list[str]]] = None, + env_name_force: Optional[Union[str, list[str]]] = None, + value_type: Optional[type] = None, + alias: Optional[str] = None, + ): + # python 3.9 does not support kw_only on the dataclass :(. + self.default = default + self.justknob = justknob + self.env_name_default = _Config.string_or_list_of_string_to_list( + env_name_default + ) + self.env_name_force = _Config.string_or_list_of_string_to_list(env_name_force) + self.value_type = value_type + self.alias = alias + if self.alias is not None: + assert ( + default is _UNSET_SENTINEL + and justknob is None + and env_name_default is None + and env_name_force is None + ), "if alias is set, none of {default, justknob and env var} can be set" + + @staticmethod + def string_or_list_of_string_to_list( + val: Optional[Union[str, list[str]]] + ) -> Optional[list[str]]: + if val is None: + return None + if isinstance(val, str): + return [val] + assert isinstance(val, list) + return val + + +# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this, +# so in order to allow for this dynamic behavior to work correctly with +# typechecking we are going to lie to the typechecker that Config[T] returns +# a T. +if TYPE_CHECKING: + + def Config( + default: Union[T, object] = _UNSET_SENTINEL, + justknob: Optional[str] = None, + env_name_default: Optional[Union[str, list[str]]] = None, + env_name_force: Optional[Union[str, list[str]]] = None, + value_type: Optional[type] = None, + alias: Optional[str] = None, + ) -> T: + ... + +else: + + def Config( + default: Union[T, object] = _UNSET_SENTINEL, + justknob: Optional[str] = None, + env_name_default: Optional[Union[str, list[str]]] = None, + env_name_force: Optional[Union[str, list[str]]] = None, + value_type: Optional[type] = None, + alias: Optional[str] = None, + ) -> _Config[T]: + return _Config( + default, justknob, env_name_default, env_name_force, value_type, alias + ) + + +def _read_env_variable(name: str) -> Optional[Union[bool, str]]: + value = os.environ.get(name) + if value == "1": + return True + if value == "0": + return False + return value + + +def install_config_module(module: ModuleType) -> None: + """ + Converts a module-level config into a `ConfigModule()`. + + See _config_typing.pyi for instructions on how to get the converted module to typecheck. 
+ """ + + class ConfigModuleInstance(ConfigModule): + # __annotations__ is written to by Sphinx autodoc + _bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"}) + + def visit( + source: Union[ModuleType, type], + dest: Union[ModuleType, SubConfigProxy], + prefix: str, + ) -> None: + """Walk the module structure and move everything to module._config""" + if sys.version_info[:2] < (3, 10): + type_hints = getattr(source, "__annotations__", {}) + else: + type_hints = inspect.get_annotations(source) + for key, value in list(source.__dict__.items()): + if ( + key.startswith("__") + or isinstance(value, (ModuleType, FunctionType)) + or (hasattr(value, "__module__") and value.__module__ == "typing") + # Handle from torch.utils._config_module import Config + or (isinstance(value, type) and issubclass(value, _Config)) + ): + continue + + name = f"{prefix}{key}" + annotated_type = type_hints.get(key, None) + if isinstance(value, CONFIG_TYPES): + config[name] = _ConfigEntry( + _Config(default=value, value_type=annotated_type) + ) + if dest is module: + delattr(module, key) + elif isinstance(value, _Config): + if annotated_type is not None and value.value_type is None: + value.value_type = annotated_type + + config[name] = _ConfigEntry(value) + + if dest is module: + delattr(module, key) + elif isinstance(value, type): + assert value.__module__ == module.__name__ + # a subconfig with `class Blah:` syntax + proxy = SubConfigProxy(module, f"{name}.") + visit(value, proxy, f"{name}.") + if dest is module: + setattr(dest, key, proxy) + else: + dest.__dict__[key] = proxy + else: + raise AssertionError(f"Unhandled config {key}={value} ({type(value)})") + + config: dict[str, _ConfigEntry] = {} + + compile_ignored_keys = get_assignments_with_compile_ignored_comments(module) + + visit(module, module, "") + module._config = config # type: ignore[attr-defined] + module._compile_ignored_keys = compile_ignored_keys # type: ignore[attr-defined] + module.__class__ = ConfigModuleInstance + module._is_dirty = True # type: ignore[attr-defined] + module._hash_digest = None # type: ignore[attr-defined] + + +COMPILE_IGNORED_MARKER = "@compile_ignored" + + +# Gets all the keys (i.e. assignments) with a @compile_ignored comment +def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str]: + source_code = inspect.getsource(module) + assignments = set() + + # Tokenize the source code to retrieve comments + tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline) + current_comment = "", -1 + prev_name = "" + + for token in tokens: + if token.type == tokenize.COMMENT: + prev_name = "" + maybe_current = token.string.strip() + if COMPILE_IGNORED_MARKER in maybe_current: + assert current_comment == ( + "", + -1, + ), f"unconsumed {COMPILE_IGNORED_MARKER}" + current_comment = maybe_current, token.start[0] + elif token.type == tokenize.NAME: + # Only accept the first name token, to handle if you have + # something like foo: Bar = ... 
+ if not prev_name: + prev_name = token.string + elif token.type == tokenize.OP and token.string == "=": + # Check if the current assignment follows a comment + # with COMPILE_IGNORED_MARKER + if ( + COMPILE_IGNORED_MARKER in current_comment[0] + and current_comment[1] == token.start[0] - 1 + ): + assignments.add(prev_name) + current_comment = "", -1 # reset + prev_name = "" + assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}" + return assignments + + +@dataclass +class _ConfigEntry: + # The default value specified in the configuration + default: Any + # The type of the configuration value + value_type: type + # The value specified by the user when they overrode the configuration + # _UNSET_SENTINEL indicates the value is not set. + user_override: Any = _UNSET_SENTINEL + # The justknob to check for this config + justknob: Optional[str] = None + # environment variables are read at install time + env_value_force: Any = _UNSET_SENTINEL + env_value_default: Any = _UNSET_SENTINEL + # Used to work arounds bad assumptions in unittest.mock.patch + # The code to blame is + # https://github.com/python/cpython/blob/94a7a4e22fb8f567090514785c69e65298acca42/Lib/unittest/mock.py#L1637 + # Essentially, mock.patch requires, that if __dict__ isn't accessible + # (which it isn't), that after delattr is called on the object, the + # object must throw when hasattr is called. Otherwise, it doesn't call + # setattr again. + # Technically we'll have an intermediate state of hiding the config while + # mock.patch is unpatching itself, but it calls setattr after the delete + # call so the final state is correct. It's just very unintuitive. + # upstream bug - python/cpython#126886 + hide: bool = False + alias: Optional[str] = None + + def __init__(self, config: _Config): + self.default = config.default + self.value_type = ( + config.value_type if config.value_type is not None else type(self.default) + ) + self.justknob = config.justknob + self.alias = config.alias + if config.env_name_default is not None: + for val in config.env_name_default: + if (env_value := _read_env_variable(val)) is not None: + self.env_value_default = env_value + break + if config.env_name_force is not None: + for val in config.env_name_force: + if (env_value := _read_env_variable(val)) is not None: + self.env_value_force = env_value + break + + # Ensure justknobs and envvars are allowlisted types + if self.justknob is not None and self.default is not None: + assert isinstance( + self.default, bool + ), f"justknobs only support booleans, {self.default} is not a boolean" + if self.value_type is not None and ( + config.env_name_default is not None or config.env_name_force is not None + ): + assert self.value_type in ( + bool, + str, + Optional[bool], + Optional[str], + ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" + + +class ConfigModule(ModuleType): + # NOTE: This should be kept in sync with _config_typing.pyi. + + # The actual configuration settings. E.g., torch._dynamo.config.debug + # would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs + # maps as "triton.cudagraphs". 
See discussion on the class for meaning of various sub items + _config: dict[str, _ConfigEntry] + _bypass_keys: set[str] + _compile_ignored_keys: set[str] + _is_dirty: bool + _hash_digest: Optional[bytes] + + def __init__(self) -> None: + raise NotImplementedError( + f"use {__name__}.install_config_module(sys.modules[__name__])" + ) + + def __setattr__(self, name: str, value: object) -> None: + if name in self._bypass_keys: + super().__setattr__(name, value) + elif name not in self._config: + raise AttributeError(f"{self.__name__}.{name} does not exist") + elif self._config[name].alias is not None: + self._set_alias_val(self._config[name], value) + else: + self._config[name].user_override = value + self._is_dirty = True + self._config[name].hide = False + + def __getattr__(self, name: str) -> Any: + try: + config = self._config[name] + + if config.hide: + raise AttributeError(f"{self.__name__}.{name} does not exist") + + alias_val = self._get_alias_val(config) + if alias_val is not _UNSET_SENTINEL: + return alias_val + + if config.env_value_force is not _UNSET_SENTINEL: + return config.env_value_force + + if config.user_override is not _UNSET_SENTINEL: + return config.user_override + + if config.env_value_default is not _UNSET_SENTINEL: + return config.env_value_default + + if config.justknob is not None: + # JK only supports bools and ints + return justknobs_check(name=config.justknob, default=config.default) + + # Note that reference types can still be modified, so we + # copy them to user_overrides in case the user overrides + # them + if isinstance(config.default, (list, set, dict)): + config.user_override = copy.deepcopy(config.default) + return config.user_override + return config.default + + except KeyError as e: + # make hasattr() work properly + raise AttributeError(f"{self.__name__}.{name} does not exist") from e + + def __delattr__(self, name: str) -> None: + self._is_dirty = True + # must support delete because unittest.mock.patch deletes + # then recreate things + self._config[name].user_override = _UNSET_SENTINEL + self._config[name].hide = True + + def _get_alias_module_and_name( + self, entry: _ConfigEntry + ) -> Optional[tuple[ModuleType, str]]: + alias = entry.alias + if alias is None: + return None + module_name, constant_name = alias.rsplit(".", 1) + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise AttributeError("config alias {alias} does not exist") from e + return module, constant_name + + def _get_alias_val(self, entry: _ConfigEntry) -> Any: + data = self._get_alias_module_and_name(entry) + if data is None: + return _UNSET_SENTINEL + module, constant_name = data + constant_value = getattr(module, constant_name) + return constant_value + + def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None: + data = self._get_alias_module_and_name(entry) + assert data is not None + module, constant_name = data + setattr(module, constant_name, val) + + def _is_default(self, name: str) -> bool: + """ + Returns true if the config is at its default value. + configs overridden by the env are not considered default. + """ + config_val = self._config[name] + # The config is not overridden by the user, and the env_value_default + # is different from the default value (meaning user has set the env to + # change the default value). 
+ not_set_env_default = ( + config_val.env_value_default is _UNSET_SENTINEL + or config_val.env_value_default == config_val.default + ) + not_set_env_force = ( + config_val.env_value_force is _UNSET_SENTINEL + or config_val.env_value_force == config_val.default + ) + + unset = config_val.user_override is _UNSET_SENTINEL + # Handle reference types specially to avoid spammy warnings + if isinstance(config_val.default, (list, set, dict)): + unset = unset or config_val.user_override == config_val.default + return unset and not_set_env_default and not_set_env_force + + def _get_dict( + self, + ignored_keys: Optional[list[str]] = None, + ignored_prefixes: Optional[list[str]] = None, + skip_default: bool = False, + ) -> dict[str, Any]: + """Export a dictionary of current configuration keys and values. + + This function is design to provide a single point which handles + accessing config options and exporting them into a dictionary. + This is used by a number of different user facing export methods + which all have slightly different semantics re: how and what to + skip. + If a config is aliased, it skips this config. + + Arguments: + ignored_keys are keys that should not be exported. + ignored_prefixes are prefixes that if a key matches should + not be exported + skip_default does two things. One if a key has not been modified + it skips it. + """ + config: dict[str, Any] = {} + for key in self._config: + if ignored_keys and key in ignored_keys: + continue + if ignored_prefixes: + if any(key.startswith(prefix) for prefix in ignored_prefixes): + continue + if skip_default and self._is_default(key): + continue + if self._config[key].alias is not None: + continue + config[key] = copy.deepcopy(getattr(self, key)) + + return config + + def get_type(self, config_name: str) -> type: + return self._config[config_name].value_type + + def save_config(self) -> bytes: + """Convert config to a pickled blob""" + ignored_keys = getattr(self, "_save_config_ignore", []) + return pickle.dumps( + self._get_dict(ignored_keys=ignored_keys), + protocol=2, + ) + + def save_config_portable( + self, *, ignore_private_configs: bool = True + ) -> dict[str, Any]: + """Convert config to portable format""" + prefixes = [] + if ignore_private_configs: + prefixes.append("_") + prefixes.extend(getattr(self, "_cache_config_ignore_prefix", [])) + return self._get_dict(ignored_prefixes=prefixes) + + def codegen_config(self) -> str: + """Convert config to Python statements that replicate current config. + This does NOT include config settings that are at default values. + """ + + # additional imports required + imports = set() + + def get_module_name(func: Callable, add_dot: bool) -> str: + module_name = func.__module__ + if module_name == "builtins": + module_name = "" + if add_dot and module_name != "": + module_name += "." + return module_name + + def add_import(func: Callable) -> None: + module_name = get_module_name(func, False) + if module_name: + imports.add(module_name) + + def list_of_callables_to_string(v: Union[list, set]) -> list[str]: + return [f"{get_module_name(item, True)}{item.__name__}" for item in v] + + def importable_callable(v: Any) -> bool: + # functools.partial has no attributes below but is a callable + return callable(v) and hasattr(v, "__module__") and hasattr(v, "__name__") + + def get_config_line(mod, k, v) -> str: # type: ignore[no-untyped-def] + """ + Return a string version of the config line. + Handle v when v is a callable, or a list/dict of callables. 
Add import statements for callables if necessary. + We assume that the value of a single config won't be a mix of callables and non-callables. + + Example output: + import logging + import _warnings + torch._dynamo.config.reorderable_logging_functions = { _warnings.warn, logging.warn, print } + """ + if importable_callable(v): + add_import(v) + return f"{mod}.{k} = {get_module_name(v, True)}{v.__name__}" + elif isinstance(v, (list, set)) and all( + importable_callable(item) for item in v + ): + for item in v: + add_import(item) + v_list = list_of_callables_to_string(v) + if isinstance(v, list): + return f"{mod}.{k} = {v_list}" + else: + return f"{mod}.{k} = {{ {', '.join(v_list)} }}" + else: + return f"{mod}.{k} = {v!r}" + + lines = [] + mod = self.__name__ + for k, v in self._get_dict( + ignored_keys=getattr(self, "_save_config_ignore", []), skip_default=True + ).items(): + lines.append(get_config_line(mod, k, v)) + for import_name in imports: + lines.insert(0, f"import {import_name}") + return "\n".join(lines) + + def get_hash(self) -> bytes: + """Hashes the configs that are not compile_ignored""" + if self._is_dirty or self._hash_digest is None: + dict_to_hash = self._get_dict(ignored_keys=list(self._compile_ignored_keys)) + string_to_hash = repr(sorted(dict_to_hash.items())) + self._hash_digest = hashlib.md5( + string_to_hash.encode("utf-8"), usedforsecurity=False + ).digest() + self._is_dirty = False + return self._hash_digest + + @deprecated( + "`config.to_dict()` has been deprecated. It no longer changes the underlying config." + " use `config.get_config_copy()` instead if you just want a copy of the config, or " + "config.load_config if you need mutable access", + category=FutureWarning, + ) + def to_dict(self) -> dict[str, Any]: + return self.get_config_copy() + + @deprecated( + "`config.shallow_copy_dict()` has been deprecated. It no longer changes the underlying config." + " use `config.get_config_copy()` instead if you just want a copy of the config, or " + "config.load_config if you need mutable access", + category=FutureWarning, + ) + def shallow_copy_dict(self) -> dict[str, Any]: + return self.get_config_copy() + + def load_config(self, maybe_pickled_config: Union[bytes, dict[str, Any]]) -> None: + """Restore from a prior call to save_config() or shallow_copy_dict()""" + if not isinstance(maybe_pickled_config, dict): + config = pickle.loads(maybe_pickled_config) + else: + config = maybe_pickled_config + for k, v in config.items(): + if k in self._config: + setattr(self, k, v) + else: + from torch._dynamo.utils import warn_once + + warn_once(f"key {k} with value {v} is not understood by this config") + + def get_config_copy(self) -> dict[str, Any]: + return self._get_dict() + + def patch( + self, + arg1: Optional[Union[str, dict[str, Any]]] = None, + arg2: Any = None, + **kwargs: dict[str, Any], + ) -> "ContextDecorator": + """ + Decorator and/or context manager to make temporary changes to a config. + + As a decorator: + + @config.patch("name", val) + @config.patch(name1=val1, name2=val2) + @config.patch({"name1": val1, "name2", val2}) + def foo(...): + ... + + As a context manager: + + with config.patch("name", val): + ... 
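+
+        In every form the prior values are restored when the context exits or
+        the decorated function returns, even if an exception was raised.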
+ """ + changes: dict[str, Any] + if arg1 is not None: + if arg2 is not None: + assert isinstance(arg1, str) + # patch("key", True) syntax + changes = {arg1: arg2} + else: + assert isinstance(arg1, dict) + # patch({"key": True}) syntax + changes = arg1 + assert not kwargs + else: + # patch(key=True) syntax + changes = kwargs + assert arg2 is None + assert isinstance(changes, dict), f"expected `dict` got {type(changes)}" + prior: dict[str, Any] = {} + config = self + + class ConfigPatch(ContextDecorator): + def __init__(self) -> None: + self.changes = changes + + def __enter__(self) -> None: + assert not prior + for key in self.changes.keys(): + # KeyError on invalid entry + prior[key] = config.__getattr__(key) + for k, v in self.changes.items(): + config.__setattr__(k, v) + + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def] + for k, v in prior.items(): + config.__setattr__(k, v) + prior.clear() + + return ConfigPatch() + + def _make_closure_patcher(self, **changes: dict[str, Any]) -> Any: + """ + A lower-overhead version of patch() for things on the critical path. + + Usage: + + # do this off the critical path + change_fn = config.make_closure_patcher(foo=True) + + ... + + revert = change_fn() + try: + ... + finally: + revert() + + """ + config = self._config + + def change() -> Callable[[], None]: + prior = {k: config[k].user_override for k in changes} + for k, v in changes.items(): + self._config[k].user_override = v + + def revert() -> None: + for k, v in prior.items(): + self._config[k].user_override = v + + return revert + + return change + + +class ContextDecorator(contextlib.ContextDecorator): + """ + Same as contextlib.ContextDecorator, but with support for + `unittest.TestCase` + """ + + def __enter__(self) -> None: + raise NotImplementedError("NYI") + + def __exit__(self, exc_type, exc_val, exc_tb) -> NoReturn: # type: ignore[no-untyped-def] + raise NotImplementedError("NYI") + + def __call__(self, func: Callable[[Any], Any]) -> Any: + if isinstance(func, type) and issubclass(func, unittest.TestCase): + + class _TestCase(func): # type: ignore[valid-type, misc] + @classmethod + def setUpClass(cls) -> None: + self.__enter__() + try: + super().setUpClass() + except Exception: + self.__exit__(None, None, None) + raise + + @classmethod + def tearDownClass(cls) -> None: + try: + super().tearDownClass() + finally: + self.__exit__(None, None, None) + + _TestCase.__name__ = func.__name__ + _TestCase.__qualname__ = func.__qualname__ + _TestCase.__module__ = func.__module__ + + return _TestCase + + return super().__call__(func) + + +class SubConfigProxy: + """ + Shim to redirect to main config. 
+ `config.triton.cudagraphs` maps to _config["triton.cudagraphs"] + """ + + def __init__(self, config: object, prefix: str): + # `super().__setattr__` to bypass custom `__setattr__` + super().__setattr__("_config", config) + super().__setattr__("_prefix", prefix) + + def __setattr__(self, name: str, value: object) -> None: + return self._config.__setattr__(self._prefix + name, value) + + def __getattr__(self, name: str) -> Any: + return self._config.__getattr__(self._prefix + name) + + def __delattr__(self, name: str) -> None: + return self._config.__delattr__(self._prefix + name) + + +def patch_object(obj: object, name: str, value: object) -> object: + """ + Workaround `mock.patch.object` issue with ConfigModule + """ + if isinstance(obj, ConfigModule): + return obj.patch(name, value) + return mock.patch.object(obj, name, value) + + +def get_tristate_env(name: str, default: Any = None) -> Optional[bool]: + value = os.environ.get(name) + if value == "1": + return True + if value == "0": + return False + return default diff --git a/phivenv/Lib/site-packages/torch/utils/_config_typing.pyi b/phivenv/Lib/site-packages/torch/utils/_config_typing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c008b111df1fe6037f458d24ccf52b4f5f68e40d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_config_typing.pyi @@ -0,0 +1,34 @@ +# mypy: allow-untyped-defs +from typing import Any, TYPE_CHECKING + +""" +This was semi-automatically generated by running + + stubgen torch.utils._config_module.py + +And then manually extracting the methods of ConfigModule and converting them into top-level functions. + +This file should be imported into any file that uses install_config_module like so: + + if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + from torch.utils._config_module import install_config_module + + # adds patch, save_config, etc + install_config_module(sys.modules[__name__]) + +Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur. +""" + +assert TYPE_CHECKING, "Do not use at runtime" + +def save_config() -> bytes: ... +def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ... +def codegen_config() -> str: ... +def get_hash() -> bytes: ... +def to_dict() -> dict[str, Any]: ... +def shallow_copy_dict() -> dict[str, Any]: ... +def load_config(config: bytes | dict[str, Any]) -> None: ... +def get_config_copy() -> dict[str, Any]: ... +def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ... diff --git a/phivenv/Lib/site-packages/torch/utils/_content_store.py b/phivenv/Lib/site-packages/torch/utils/_content_store.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f9e14c5bcd7d8145ab7bf232f8673890f0bb55 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_content_store.py @@ -0,0 +1,239 @@ +# mypy: allow-untyped-defs +# This module provides a FAST (on GPU) content addressable store for storages +# (and tensors on top of them) with VERY WEAK portability guarantees (e.g., +# don't expect CPU/CUDA to address to the same hash, don't expect it to be +# portable across devices) that is NOT cryptographically secure. In return, +# we are able to hash 40G of tensor data on GPU in less than a second, +# compared to running SHA-1 in CPU which would a minute or so. 
The primary +# use case is for efficiently snapshotting intermediate tensor data for +# offline debugging, but it's been put in this module in case you think of +# another use case for it. The hash function could be replaced with a +# straight reimplementation of SHA-1, which would give us much stronger +# portability guarantees. +# +# WARNING: THERE IS NO BC/FC GUARANTEE FOR THIS FORMAT! If you need to format +# shift the result, consider packing it into a single torch.save object +# with traditional view sharing. +# +# Because of the weak portability guarantees, you can only write to the +# content store from a single process; we don't provide any capability +# of "reopening" a content store to add more things to it. But we don't +# assume that you can keep all of the tensors you want to add to the store +# in memory at once, because you probably can't! Nor do we assume that +# you know a priori whether or not two storages can be deduplicated or not. +# +# Note: only storages are content-addressed; tensors are name addressed +# +# Note: our padding strategy means that [1, 0] and [1] int16 tensors would +# map to the same (padded) storage. We think this will be immaterial for most +# users. + +import ctypes +import functools +import hashlib +import os.path +import struct +from collections import defaultdict +from typing import Optional + +import torch +import torch._prims as prims +import torch._utils +import torch.nn.functional as F +from torch.multiprocessing.reductions import StorageWeakRef + + +def lazy_compile(**compile_kwargs): + """Lazily wrap a function with torch.compile on the first call + + This avoids eagerly importing dynamo. + """ + + def decorate_fn(fn): + @functools.wraps(fn) + def compile_hook(*args, **kwargs): + compiled_fn = torch.compile(fn, **compile_kwargs) + globals()[fn.__name__] = functools.wraps(fn)(compiled_fn) + return compiled_fn(*args, **kwargs) + + return compile_hook + + return decorate_fn + + +# Use of torch.compile is mandatory for (1) good memory usage +# and (2) xor_sum implementation. This is our first instance of +# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities +# it would be good to apply it here. +@lazy_compile(dynamic=True) +def hash_storage_kernel(x): + # The randint calls are carefully written to hit things we + # have lowerings for in inductor. Lack of unsigned 32-bit integer + # is a pain. + a = torch.randint( + -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32 + ).abs() + a = ((a % (2**31 - 1)) + 1).long() + b = ( + torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32) + .abs() + .long() + ) + # This is a standard shift-multiply universal hash family + # plus xor sum hash, using Philox to generate random numbers. + # Our Philox RNG is not deterministic across devices so + # don't use this for stable hashing. + # + # This assumes fixed length so you're also obligated to bucket + # by the length of tensor as well + return prims.xor_sum((a * x + b).int(), [0]) + + +# Returns a hex digest of the data in the storage. Guaranteed to be +# SHA-1 if stable_hash=True, otherwise it will consistent for a single +# process run but not necessarily across processes. 
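+#
+# A minimal usage sketch (the tensor below is illustrative):
+#
+#   s = torch.randn(1024).untyped_storage()
+#   hash_storage(s, stable_hash=True)  # SHA-1 hex digest, portable
+#   hash_storage(s)                    # fast path; consistent only within
+#                                      # a single process run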
+def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str: + import torch._dynamo + from torch._dynamo.utils import is_compile_supported + + device_type = storage.device.type + if stable_hash or not is_compile_supported(device_type): + cpu_storage = storage.cpu() + # TODO: make storage support buffer protocol so this isn't + # necessary + buf = (ctypes.c_byte * cpu_storage.nbytes()).from_address( + cpu_storage.data_ptr() + ) + sha1 = hashlib.sha1(usedforsecurity=False) + sha1.update(buf) + return sha1.hexdigest() + + # TODO: factor this into a random utility + if device_type == "cpu": + generator = torch._C.default_generator + elif device_type == "cuda": + generator = torch.cuda.default_generators[storage.device.index] + elif device_type == "mps": + generator = torch.mps._get_default_mps_generator() + elif device_type == "xpu": + generator = torch.xpu.default_generators[storage.device.index] + else: + raise AssertionError(f"unhandled device type {device_type}") + state = generator.get_state() + try: + generator.manual_seed(0) + x = torch.empty(0, dtype=torch.uint8, device=storage.device).set_(storage) # type: ignore[call-overload] + # The dtype-casting view cannot be compiled, and so the + # padding/reshaping also needs to be done externally even + # though it could be profitably fused + pad = -x.numel() % 4 + if pad > 0: + x = F.pad(x, (0, pad), "constant", 0) + x = x.view(torch.int32) + # We run the 32-bit hash five times with differing parameters to + # reduce chance of collision + ITER = 5 + cs = [hash_storage_kernel(x).item() for _ in range(ITER)] + return struct.pack(">" + "i" * ITER, *cs).hex() + finally: + generator.set_state(state) + + +class ContentStoreWriter: + # Structure: + # storages/ + # 00/ + # 0000..00 + # tensors/ + # name + def __init__(self, loc: str, stable_hash: bool = False) -> None: + self.loc: str = loc + self.seen_storage_hashes: set[str] = set() + self.stable_hash = stable_hash + + # TODO: offer some sort of non-blocking API to speed things up + def write_storage(self, storage: torch.UntypedStorage) -> str: + h = hash_storage(storage, stable_hash=self.stable_hash) + if h in self.seen_storage_hashes: + return h + # TODO: consider not using torch.save for this; we don't actually + # need any metadata for the storage + subfolder = os.path.join(self.loc, "storages") + os.makedirs(subfolder, exist_ok=True) + target = os.path.join(subfolder, h) + if os.path.exists(target): + return h + torch.save(storage, target) + self.seen_storage_hashes.add(h) + return h + + def compute_tensor_metadata(self, t: torch.Tensor, h=None): + if h is None: + h = hash_storage(t.untyped_storage(), stable_hash=self.stable_hash) + return ( + t.dtype, + h, + t.storage_offset(), + tuple(t.shape), + t.stride(), + torch._utils.get_tensor_metadata(t), + ) + + def write_tensor(self, name: str, t: torch.Tensor) -> None: + storage = t.untyped_storage() + h = self.write_storage(storage) + # TODO: Support more advanced snapshotting of requires_grad/grad/etc + d, f = os.path.split(name) + payload = self.compute_tensor_metadata(t, h=h) + subfolder = os.path.join(self.loc, "tensors", d) + os.makedirs(subfolder, exist_ok=True) + torch.save(payload, os.path.join(subfolder, f)) + + +class ContentStoreReader: + def __init__(self, loc: str, *, cache=True) -> None: + self.loc = loc + self.storage_cache: Optional[ + dict[Optional[torch.device], dict[str, StorageWeakRef]] + ] = None + if cache: + self.storage_cache = defaultdict(dict) + + def read_storage(self, h: str, *, device=None) 
-> torch.UntypedStorage: + if device is not None: + device = torch.device(device) + ws = ( + self.storage_cache[device].get(h) + if self.storage_cache is not None + else None + ) + s: Optional[torch.UntypedStorage] + if ws is not None: + s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata) + if s is not None: + return s + s = torch.load( + os.path.join(self.loc, "storages", h), + weights_only=True, + map_location=device, + )._untyped_storage + assert s is not None + if self.storage_cache is not None: + self.storage_cache[device][h] = StorageWeakRef(s) + return s + + def read_tensor_metadata(self, name: str): + fn = os.path.join(self.loc, "tensors", name) + if not os.path.exists(fn): + raise FileNotFoundError(fn) + return torch.load(fn, weights_only=True) + + def read_tensor(self, name: str, *, device=None) -> torch.Tensor: + dtype, h, storage_offset, size, stride, metadata = self.read_tensor_metadata( + name + ) + storage = self.read_storage(h, device=device) + t = torch.tensor([], dtype=dtype, device=storage.device) + t.set_(storage, storage_offset, size, stride) + torch._utils.set_tensor_metadata(t, metadata) + return t diff --git a/phivenv/Lib/site-packages/torch/utils/_contextlib.py b/phivenv/Lib/site-packages/torch/utils/_contextlib.py new file mode 100644 index 0000000000000000000000000000000000000000..23e162910f8b1f0aef71c482f52179228798318e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_contextlib.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +# Extra utilities for working with context managers that should have been +# in the standard library but are not + +import functools +import inspect +import sys +import warnings +from typing import Any, Callable, cast, TypeVar + + +# Used for annotating the decorator usage of _DecoratorContextManager (e.g., +# 'no_grad' and 'enable_grad'). +# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) + + +def _wrap_generator(ctx_factory, func): + """ + Wrap each generator invocation with the context manager factory. + + The input should be a function that returns a context manager, + not a context manager itself, to handle one-shot context managers. + """ + + @functools.wraps(func) + def generator_context(*args, **kwargs): + gen = func(*args, **kwargs) + + # Generators are suspended and unsuspended at `yield`, hence we + # make sure the grad mode is properly set every time the execution + # flow returns into the wrapped generator and restored when it + # returns through our `yield` to our caller (see PR #49017). 
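+        # Hypothetical illustration: for a generator function decorated with
+        # `@torch.no_grad()`, each `gen.send(...)` below is wrapped in
+        # `ctx_factory()`, so code between two `yield`s always runs with grad
+        # disabled regardless of the mode the caller resumed us from.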
+ try: + # Issuing `None` to a generator fires it up + with ctx_factory(): + response = gen.send(None) + + while True: + try: + # Forward the response to our caller and get its next request + request = yield response + + except GeneratorExit: + # Inform the still active generator about its imminent closure + with ctx_factory(): + gen.close() + raise + + except BaseException: + # Propagate the exception thrown at us by the caller + with ctx_factory(): + response = gen.throw(*sys.exc_info()) + + else: + # Pass the last request to the generator and get its response + with ctx_factory(): + response = gen.send(request) + + # We let the exceptions raised above by the generator's `.throw` or + # `.send` methods bubble up to our caller, except for StopIteration + except StopIteration as e: + # The generator informed us that it is done: take whatever its + # returned value (if any) was and indicate that we're done too + # by returning it (see docs for python's return-statement). + return e.value + + return generator_context + + +def context_decorator(ctx, func): + """ + Like contextlib.ContextDecorator. + + But with the following differences: + 1. Is done by wrapping, rather than inheritance, so it works with context + managers that are implemented from C and thus cannot easily inherit from + Python classes + 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743) + 3. Errors out if you try to wrap a class, because it is ambiguous whether + or not you intended to wrap only the constructor + + The input argument can either be a context manager (in which case it must + be a multi-shot context manager that can be directly invoked multiple times) + or a callable that produces a context manager. + """ + assert not (callable(ctx) and hasattr(ctx, "__enter__")), ( + f"Passed in {ctx} is both callable and also a valid context manager " + "(has __enter__), making it ambiguous which interface to use. If you " + "intended to pass a context manager factory, rewrite your call as " + "context_decorator(lambda: ctx()); if you intended to pass a context " + "manager directly, rewrite your call as context_decorator(lambda: ctx)" + ) + + if not callable(ctx): + + def ctx_factory(): + return ctx + + else: + ctx_factory = ctx + + if inspect.isclass(func): + raise RuntimeError( + "Cannot decorate classes; it is ambiguous whether or not only the " + "constructor or all methods should have the context manager applied; " + "additionally, decorating a class at definition-site will prevent " + "use of the identifier as a conventional type. " + "To specify which methods to decorate, decorate each of them " + "individually." + ) + + if inspect.isgeneratorfunction(func): + return _wrap_generator(ctx_factory, func) + + @functools.wraps(func) + def decorate_context(*args, **kwargs): + with ctx_factory(): + return func(*args, **kwargs) + + return decorate_context + + +class _DecoratorContextManager: + """Allow a context manager to be used as a decorator.""" + + def __call__(self, orig_func: F) -> F: + if inspect.isclass(orig_func): + warnings.warn( + "Decorating classes is deprecated and will be disabled in " + "future versions. You should only decorate functions or methods. 
" + "To preserve the current behavior of class decoration, you can " + "directly decorate the `__init__` method and nothing else.", + FutureWarning, + stacklevel=2, + ) + func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs)) + else: + func = orig_func + + return cast(F, context_decorator(self.clone, func)) + + def __enter__(self) -> None: + raise NotImplementedError + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + raise NotImplementedError + + def clone(self): + # override this method if your children class takes __init__ parameters + return self.__class__() + + +class _NoParamDecoratorContextManager(_DecoratorContextManager): + """Allow a context manager to be used as a decorator without parentheses.""" + + def __new__(cls, orig_func=None): + if orig_func is None: + return super().__new__(cls) + return cls()(orig_func) diff --git a/phivenv/Lib/site-packages/torch/utils/_cpp_embed_headers.py b/phivenv/Lib/site-packages/torch/utils/_cpp_embed_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..d1448c39399f6730727e9dfc2affc28af4a87357 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_cpp_embed_headers.py @@ -0,0 +1,58 @@ +from collections.abc import Sequence +from pathlib import Path +from re import match as _match +from typing import Optional, Union + + +def read_file(fname: Union[Path, str]) -> list[str]: + with open(fname, encoding="utf-8") as f: + return f.readlines() + + +def _embed_headers( + content: list[str], include_dirs: list[Path], processed_files: set[str] +) -> str: + for line_idx, cur_line in enumerate(content): + # Eliminate warning: `#pragma once in main file` + if cur_line.startswith("#pragma once"): + content[line_idx] = "" + continue + m = _match('^\\s*#include\\s*[<"]([^>"]+)[>"]', cur_line) + if m is None: + continue + for include_dir in include_dirs: + path = include_dir / m[1] + if not path.exists(): + continue + if str(path) in processed_files: + content[line_idx] = "" + continue + processed_files.add(str(path)) + content[line_idx] = _embed_headers( + read_file(path), include_dirs, processed_files + ) + break + return "".join(content) + + +def embed_headers( + fname: str, include_dirs: Optional[Union[Sequence[str], Sequence[Path], str]] = None +) -> str: + if include_dirs is None: + base_dir = Path(__file__).parent.parent.parent + include_dirs = [base_dir, base_dir / "aten" / "src"] + elif isinstance(include_dirs, str): + include_dirs = [Path(include_dirs)] + else: + include_dirs = [Path(x) for x in include_dirs] + + return _embed_headers(read_file(fname), include_dirs, {fname}) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage:\n {sys.argv[0]} filename") + sys.exit(1) + print(embed_headers(sys.argv[1])) diff --git a/phivenv/Lib/site-packages/torch/utils/_cpp_extension_versioner.py b/phivenv/Lib/site-packages/torch/utils/_cpp_extension_versioner.py new file mode 100644 index 0000000000000000000000000000000000000000..13663f7195d4468ca0f07770bb8453397dcb9597 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_cpp_extension_versioner.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import collections + + +Entry = collections.namedtuple("Entry", "version, hash") + + +def update_hash(seed, value): + # Good old boost::hash_combine + # https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html + return seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2)) + + +def hash_source_files(hash_value, source_files): + for filename 
in source_files: + with open(filename, "rb") as file: + hash_value = update_hash(hash_value, file.read()) + return hash_value + + +def hash_build_arguments(hash_value, build_arguments): + for group in build_arguments: + if group: + for argument in group: + hash_value = update_hash(hash_value, argument) + return hash_value + + +class ExtensionVersioner: + def __init__(self): + self.entries = {} + + def get_version(self, name): + entry = self.entries.get(name) + return None if entry is None else entry.version + + def bump_version_if_changed( + self, + name, + source_files, + build_arguments, + build_directory, + with_cuda, + with_sycl, + is_python_module, + is_standalone, + ): + hash_value = 0 + hash_value = hash_source_files(hash_value, source_files) + hash_value = hash_build_arguments(hash_value, build_arguments) + hash_value = update_hash(hash_value, build_directory) + hash_value = update_hash(hash_value, with_cuda) + hash_value = update_hash(hash_value, with_sycl) + hash_value = update_hash(hash_value, is_python_module) + hash_value = update_hash(hash_value, is_standalone) + + entry = self.entries.get(name) + if entry is None: + self.entries[name] = entry = Entry(0, hash_value) + elif hash_value != entry.hash: + self.entries[name] = entry = Entry(entry.version + 1, hash_value) + + return entry.version diff --git a/phivenv/Lib/site-packages/torch/utils/_cxx_pytree.py b/phivenv/Lib/site-packages/torch/utils/_cxx_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5e3c37cea84e9f349e1237065c8cf5cc148fa7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_cxx_pytree.py @@ -0,0 +1,1111 @@ +""" +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. +""" + +import functools +import sys +import types +from collections.abc import Iterable +from typing import Any, Callable, Optional, overload, TypeVar, Union +from typing_extensions import deprecated, TypeIs + +import torch.utils._pytree as python_pytree +from torch.torch_version import TorchVersion as _TorchVersion +from torch.utils._pytree import ( + is_namedtuple as is_namedtuple, + is_namedtuple_class as is_namedtuple_class, + is_namedtuple_instance as is_namedtuple_instance, + is_structseq as is_structseq, + is_structseq_class as is_structseq_class, + is_structseq_instance as is_structseq_instance, + KeyEntry as KeyEntry, +) + + +# Do not try to import `optree` package if the static version check already fails. +if not python_pytree._cxx_pytree_dynamo_traceable: + raise ImportError( + f"{__name__} depends on `optree>={python_pytree._optree_minimum_version}`, " + "which is an optional dependency of PyTorch. 
" + "To use it, please upgrade your optree package via " + "`python3 -m pip install --upgrade optree`" + ) + + +import optree +from optree import PyTreeSpec as TreeSpec # direct import for type annotations + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_is_leaf", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", +] + + +# In-tree installation may have VCS-based versioning. Update the previous static version. +python_pytree._optree_version = _TorchVersion(optree.__version__) # type: ignore[attr-defined] + +__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch") +__TORCH_DICT_SESSION.__enter__() # enable globally and permanently + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +KeyPath = tuple[KeyEntry, ...] +FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] + + +def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc: + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + return func(*reversed(args), **kwargs) + + return wrapped + + +def register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls (type): A Python type to treat as an internal pytree node. + flatten_fn (callable): A function to be used during flattening, taking an instance of + ``cls`` and returning a pair, with (1) an iterable for the children to be flattened + recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be + passed to the ``unflatten_fn``. + unflatten_fn (callable): A function taking two arguments: the auxiliary data that was + returned by ``flatten_fn`` and stored in the treespec, and the unflattened children. + The function should return an instance of ``cls``. + serialized_type_name (str, optional): A keyword argument used to specify the fully + qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. 
+ from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. This is used for json deserialization, which is being used in + :mod:`torch.export` right now. + + Example:: + + >>> # xdoctest: +SKIP + >>> # Registry a Python type with lambda functions + >>> register_pytree_node( + ... set, + ... lambda s: (sorted(s), None, None), + ... lambda children, _: set(children), + ... ) + """ + if flatten_with_keys_fn is not None: + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + python_pytree._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +@deprecated( + "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. " + "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.", + category=FutureWarning, +) +def _register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """Register a container-like type as pytree node for the C++ pytree only. + + The ``namespace`` argument is used to avoid collisions that occur when different libraries + register the same Python type with different behaviors. It is recommended to add a unique prefix + to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify + the same class in different namespaces for different use cases. + + .. warning:: + For safety reasons, a ``namespace`` must be specified while registering a custom type. It is + used to isolate the behavior of flattening and unflattening a pytree node type. This is to + prevent accidental collisions between different libraries that may register the same type. + + Args: + cls (type): A Python type to treat as an internal pytree node. + flatten_fn (callable): A function to be used during flattening, taking an instance of + ``cls`` and returning a pair, with (1) an iterable for the children to be flattened + recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be + passed to the ``unflatten_fn``. + unflatten_fn (callable): A function taking two arguments: the auxiliary data that was + returned by ``flatten_fn`` and stored in the treespec, and the unflattened children. + The function should return an instance of ``cls``. + serialized_type_name (str, optional): A keyword argument used to specify the fully + qualified name used when serializing the tree spec. + to_dumpable_context (callable, optional): An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable representation. This is + used for json serialization, which is being used in :mod:`torch.export` right now. + from_dumpable_context (callable, optional): An optional keyword argument to custom specify + how to convert the custom json dumpable representation of the context back to the + original context. 
This is used for json deserialization, which is being used in + :mod:`torch.export` right now. + """ + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _private_register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the C++ pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support + # PyStructSequence types + if not optree.is_structseq_class(cls): + optree.register_pytree_node( + cls, + flatten_fn, + _reverse_args(unflatten_fn), + namespace="torch", + ) + + +def _is_pytreespec_instance(obj: Any, /) -> TypeIs[TreeSpec]: + return isinstance(obj, TreeSpec) + + +def tree_is_leaf( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + """Check if a pytree is a leaf. + + >>> tree_is_leaf(1) + True + >>> tree_is_leaf(None) + True + >>> tree_is_leaf([1, 2, 3]) + False + >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) + True + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + False + + Args: + tree (pytree): A pytree to check if it is a leaf node. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A boolean indicating if the pytree is a leaf node. + """ + return optree.tree_is_leaf( + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_flatten( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> tuple[list[Any], TreeSpec]: + """Flatten a pytree. + + See also :func:`tree_unflatten`. + + The flattening order (i.e., the order of elements in the output list) is deterministic, + corresponding to a left-to-right depth-first tree traversal. + + >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} + >>> tree_flatten(tree) + ([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) + >>> tree_flatten(1) + ([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) + >>> tree_flatten(None) + ([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) + >>> from collections import OrderedDict + >>> tree = OrderedDict([("b", (2, [3, 4])), ("a", 1), ("c", None), ("d", 5)]) + >>> tree_flatten(tree) + ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch')) + + Args: + tree (pytree): A pytree to flatten. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. 
If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the + second element is a treespec representing the structure of the pytree. + """ + return optree.tree_flatten( # type: ignore[return-value] + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Reconstruct a pytree from the treespec and the leaves. + + The inverse of :func:`tree_flatten`. + + >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} + >>> leaves, treespec = tree_flatten(tree) + >>> tree == tree_unflatten(leaves, treespec) + True + + Args: + leaves (iterable): The list of leaves to use for reconstruction. The list must match the + number of leaves of the treespec. + treespec (TreeSpec): The treespec to reconstruct. + + Returns: + The reconstructed pytree, containing the ``leaves`` placed in the structure described by + ``treespec``. + """ + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] + + +def tree_iter( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree. + + See also :func:`tree_flatten`. + + >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} + >>> list(tree_iter(tree)) + [2, 3, 4, 1, None, 5] + >>> list(tree_iter(1)) + [1] + >>> list(tree_iter(None)) + [None] + + Args: + tree (pytree): A pytree to flatten. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + An iterator over the leaf values. + """ + return optree.tree_iter( + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_leaves( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> list[Any]: + """Get the leaves of a pytree. + + See also :func:`tree_flatten`. + + >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} + >>> tree_leaves(tree) + [2, 3, 4, 1, None, 5] + >>> tree_leaves(1) + [1] + >>> tree_leaves(None) + [None] + + Args: + tree (pytree): A pytree to flatten. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A list of leaf values. 
+ """ + return optree.tree_leaves( + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_structure( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + """Get the treespec for a pytree. + + See also :func:`tree_flatten`. + + >>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5} + >>> tree_structure(tree) + PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch') + >>> tree_structure(1) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') + >>> tree_structure(None) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') + + Args: + tree (pytree): A pytree to flatten. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A treespec object representing the structure of the pytree. + """ + return optree.tree_structure( # type: ignore[return-value] + tree, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + """ + return optree.tree_map( + func, + tree, + *rests, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +def tree_map_( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. 
+ + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ + return optree.tree_map_( + func, + tree, + *rests, + is_leaf=is_leaf, + none_is_leaf=True, + namespace="torch", + ) + + +Type2 = tuple[type[T], type[S]] +Type3 = tuple[type[T], type[S], type[U]] +if sys.version_info >= (3, 10): + TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] +else: + TypeAny = Union[type[Any], tuple[type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], / +) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... 
+ + You can also directly use 'tree_map_only' + """ + if isinstance(type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) + and isinstance(type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(type_or_types_or_pred): + pred = type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + @functools.wraps(func) + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + type_or_types_or_pred: type[T], + /, + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Type2[T, S], + /, + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Type3[T, S, U], + /, + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: TypeAny, + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Callable[[Any], bool], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + type_or_types_or_pred: type[T], + /, + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Type2[T, S], + /, + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Type3[T, S, U], + /, + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: TypeAny, + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Callable[[Any], bool], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... 
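+# Illustrative example (values are hypothetical): apply a function only to
+# leaves of a given type, leaving all other leaves untouched:
+#
+#   tree_map_only(int, lambda x: x + 1, {"a": 1, "b": ("s", 2.5)})
+#   # -> {"a": 2, "b": ("s", 2.5)}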
+ + +def tree_map_only_( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + type_or_types: type[T], + /, + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + type_or_types: Type2[T, S], + /, + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + type_or_types: Type3[T, S, U], + /, + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_all_only( + type_or_types: TypeAny, + /, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, type_or_types)) + + +@overload +def tree_any_only( + type_or_types: type[T], + /, + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + type_or_types: Type2[T, S], + /, + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + type_or_types: Type3[T, S, U], + /, + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_any_only( + type_or_types: TypeAny, + /, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, type_or_types)) + + +def broadcast_prefix( + prefix_tree: PyTree, + full_tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> list[Any]: + """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``. + + If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be + constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**. + + This function returns a list of leaves with the same size as ``full_tree``. The leaves are + replicated from ``prefix_tree``. The number of replicas is determined by the corresponding + subtree in ``full_tree``. + + >>> broadcast_prefix(1, [1, 2, 3]) + [1, 1, 1] + >>> broadcast_prefix([1, 2, 3], [1, 2, 3]) + [1, 2, 3] + >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4]) + Traceback (most recent call last): + ... + ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4]. + >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)]) + [1, 2, 3, 3] + >>> broadcast_prefix([1, 2, 3], [1, 2, {"a": 3, "b": 4, "c": (None, 5)}]) + [1, 2, 3, 3, 3, 3] + + Args: + prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``. 
+ full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``. + """ + result: list[Any] = [] + + def add_leaves(x: Any, subtree: PyTree) -> None: + subtreespec = tree_structure(subtree, is_leaf=is_leaf) + result.extend([x] * subtreespec.num_leaves) + + tree_map_( + add_leaves, + prefix_tree, + full_tree, + is_leaf=is_leaf, + ) + return result + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[list[Any]]: + assert _is_pytreespec_instance(treespec) + full_tree = tree_unflatten([0] * treespec.num_leaves, treespec) + try: + return broadcast_prefix(tree, full_tree, is_leaf=is_leaf) + except ValueError: + return None + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + """Serialize a treespec to a JSON string.""" + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"treespec_dumps(treespec): Expected `treespec` to be instance of " + f"PyTreeSpec but got item of type {type(treespec)}." + ) + + dummy_tree = tree_unflatten([0] * treespec.num_leaves, treespec) + orig_treespec = python_pytree.tree_structure(dummy_tree) + return python_pytree.treespec_dumps(orig_treespec, protocol=protocol) + + +@functools.lru_cache +def treespec_loads(serialized: str) -> TreeSpec: + """Deserialize a treespec from a JSON string.""" + orig_treespec = python_pytree.treespec_loads(serialized) + dummy_tree = python_pytree.tree_unflatten( + [0] * orig_treespec.num_leaves, + orig_treespec, + ) + treespec = tree_structure(dummy_tree) + return treespec + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc] + def __instancecheck__(self, instance: object) -> bool: + return _is_pytreespec_instance(instance) and instance.is_leaf() + + +class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): + def __new__(cls) -> "LeafSpec": + return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value] + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. 
If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + """ + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> list[tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. + """ + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. 
+ """ + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.") + + +with python_pytree._NODE_REGISTRY_LOCK: + python_pytree._cxx_pytree_imported = True + args, kwargs = (), {} # type: ignore[var-annotated] + for args, kwargs in python_pytree._cxx_pytree_pending_imports: + _private_register_pytree_node(*args, **kwargs) + python_pytree._cxx_pytree_pending_imports.clear() + del args, kwargs diff --git a/phivenv/Lib/site-packages/torch/utils/_device.py b/phivenv/Lib/site-packages/torch/utils/_device.py new file mode 100644 index 0000000000000000000000000000000000000000..88d9f2a4ee8d7ed5693a77f4b7bdfe7cdd552fd2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_device.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +import functools +from typing import Optional + +import torch +from torch._C import _len_torch_function_stack +from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode +from torch.utils._contextlib import context_decorator + + +CURRENT_DEVICE: Optional[torch.device] = None + + +@functools.lru_cache(1) +def _device_constructors(): + return { + # standard ones + torch.empty, + torch.empty_permuted, + torch.empty_strided, + torch.empty_quantized, + torch.ones, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch.eye, + torch.fft.fftfreq, + torch.fft.rfftfreq, + torch.full, + torch.hamming_window, + torch.hann_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.nested.nested_tensor, + # This function doesn't actually take a device argument + # torch.normal, + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.range, + torch.sparse_coo_tensor, + torch.sparse_compressed_tensor, + torch.sparse_csr_tensor, + torch.sparse_csc_tensor, + torch.sparse_bsr_tensor, + torch.sparse_bsc_tensor, + torch.tril_indices, + torch.triu_indices, + torch.zeros, + torch.asarray, + # weird ones + torch.tensor, + torch.as_tensor, + torch.scalar_tensor, + } + + +# NB: This is directly called from C++ in torch/csrc/Device.cpp +class DeviceContext(TorchFunctionMode): + def __init__(self, device): + self.device = torch.device(device) + + def __enter__(self): + global CURRENT_DEVICE + self.old_device = CURRENT_DEVICE + CURRENT_DEVICE = self.device + # We need to put the device at the bottom of the stack + # If we set default device within a function mode context + # exiting that context mode will pop the device function mode off + # of the stack incorrectly + cur_stack = [_pop_mode() for _ in range(_len_torch_function_stack())] + + _push_mode(self) + + for mode in reversed(cur_stack): + _push_mode(mode) + + def __exit__(self, exc_type, exc_val, exc_tb): + global CURRENT_DEVICE + CURRENT_DEVICE = self.old_device + cur_stack = [] + # Invariant: there should only be one DeviceContext on the stack at any time + # (At the bottom), pop all mdoes until we hit the bottom, assert it's a DeviceContext + # or else someone else has popped it! 
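+        # Pop everything that sits above the bottom of the stack; per the
+        # invariant none of these should be a DeviceContext, and they are all
+        # re-pushed (in the original order) further below.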
+ for _ in range(_len_torch_function_stack() - 1): + mode = _pop_mode() + assert not isinstance(mode, DeviceContext) + cur_stack.append(mode) + + if _len_torch_function_stack() > 0: + mode = _pop_mode() + assert isinstance(mode, DeviceContext) + + for mode in reversed(cur_stack): + _push_mode(mode) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if func in _device_constructors() and kwargs.get("device") is None: + kwargs["device"] = self.device + return func(*args, **kwargs) + + +# NB: This is directly called from C++ in torch/csrc/Device.cpp +def device_decorator(device, func): + return context_decorator(lambda: device, func) + + +def set_device(device): + """ + Set the default device inside of the wrapped function by decorating it with this function. + + If you would like to use this as a context manager, use device as a + context manager directly, e.g., ``with torch.device(device)``. + """ + return lambda func: device_decorator(torch.device(device), func) diff --git a/phivenv/Lib/site-packages/torch/utils/_dtype_abbrs.py b/phivenv/Lib/site-packages/torch/utils/_dtype_abbrs.py new file mode 100644 index 0000000000000000000000000000000000000000..563581ec4ac44797f615e3d61d7b9c7dfb3f18bb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_dtype_abbrs.py @@ -0,0 +1,30 @@ +import torch + + +# Used for testing and logging +dtype_abbrs = { + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.float8_e4m3fn: "f8e4m3fn", + torch.float8_e5m2: "f8e5m2", + torch.float8_e4m3fnuz: "f8e4m3fnuz", + torch.float8_e5m2fnuz: "f8e5m2fnuz", + torch.float8_e8m0fnu: "f8e8m0fnu", + torch.float4_e2m1fn_x2: "f4e2m1fnx2", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", + torch.uint16: "u16", + torch.uint32: "u32", + torch.uint64: "u64", + torch.bits16: "b16", + torch.bits1x8: "b1x8", +} diff --git a/phivenv/Lib/site-packages/torch/utils/_exposed_in.py b/phivenv/Lib/site-packages/torch/utils/_exposed_in.py new file mode 100644 index 0000000000000000000000000000000000000000..9dec51ead7b520201634bb54e8355b56c2901154 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_exposed_in.py @@ -0,0 +1,20 @@ +from typing import Callable, TypeVar + + +F = TypeVar("F") + + +# Allows one to expose an API in a private submodule publicly as per the definition +# in PyTorch's public api policy. +# +# It is a temporary solution while we figure out if it should be the long-term solution +# or if we should amend PyTorch's public api policy. The concern is that this approach +# may not be very robust because it's not clear what __module__ is used for. +# However, both numpy and jax overwrite the __module__ attribute of their APIs +# without problem, so it seems fine. 
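+#
+# A minimal usage sketch (``torch.mylib`` and ``my_api`` are made-up names used
+# purely for illustration):
+#
+#     @exposed_in("torch.mylib")
+#     def my_api(x):
+#         return x
+#
+# After decoration, ``my_api.__module__`` is "torch.mylib", so the function is
+# reported as belonging to that public module rather than to the private file
+# that defines it.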
+def exposed_in(module: str) -> Callable[[F], F]: + def wrapper(fn: F) -> F: + fn.__module__ = module + return fn + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/utils/_filelock.py b/phivenv/Lib/site-packages/torch/utils/_filelock.py new file mode 100644 index 0000000000000000000000000000000000000000..8de9fdefb0e1a368621d240b83ca05006c095470 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_filelock.py @@ -0,0 +1,42 @@ +from types import TracebackType +from typing import Optional +from typing_extensions import Self + +from filelock import FileLock as base_FileLock + +from torch.monitor import _WaitCounter + + +class FileLock(base_FileLock): + """ + This behaves like a normal file lock. + + However, it adds waitcounters for acquiring and releasing the filelock + as well as for the critical region within it. + + pytorch.filelock.enter - While we're acquiring the filelock. + pytorch.filelock.region - While we're holding the filelock and doing work. + pytorch.filelock.exit - While we're releasing the filelock. + """ + + def __enter__(self) -> Self: + self.region_counter = _WaitCounter("pytorch.filelock.region").guard() + with _WaitCounter("pytorch.filelock.enter").guard(): + result = super().__enter__() + self.region_counter.__enter__() + return result + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.region_counter.__exit__() + with _WaitCounter("pytorch.filelock.exit").guard(): + # Returns nothing per + # https://github.com/tox-dev/filelock/blob/57f488ff8fdc2193572efe102408fb63cfefe4e4/src/filelock/_api.py#L379 + super().__exit__(exc_type, exc_value, traceback) + # Returns nothing per + # https://github.com/pytorch/pytorch/blob/0f6bfc58a2cfb7a5c052bea618ab62becaf5c912/torch/csrc/monitor/python_init.cpp#L315 + return None diff --git a/phivenv/Lib/site-packages/torch/utils/_foreach_utils.py b/phivenv/Lib/site-packages/torch/utils/_foreach_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06da0d3273e359f7b2f7ad16ed1a12729054be87 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_foreach_utils.py @@ -0,0 +1,60 @@ +from typing import Optional +from typing_extensions import TypeAlias + +import torch +from torch import Tensor +from torch.autograd.grad_mode import no_grad + + +def _get_foreach_kernels_supported_devices() -> list[str]: + r"""Return the device type list that supports foreach kernels.""" + return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] + + +def _get_fused_kernels_supported_devices() -> list[str]: + r"""Return the device type list that supports fused kernels in optimizer.""" + return [ + "mps", + "cuda", + "xpu", + "hpu", + "cpu", + torch._C._get_privateuse1_backend_name(), + ] + + +TensorListList: TypeAlias = list[list[Optional[Tensor]]] +Indices: TypeAlias = list[int] +_foreach_supported_types = [torch.Tensor] + + +# This util function splits tensors into groups by device and dtype, which is useful before sending +# tensors off to a foreach implementation, which requires tensors to be on one device and dtype. +# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified: +# - tensorlists CAN be None +# - all tensors in the first specified list cannot be None +# - given an index i, all specified tensorlist[i]s match in dtype and device +# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry. 
+# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out. +# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the +# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation +# may be necessary. Check out torch/optim/sgd.py for an example. +@no_grad() +def _group_tensors_by_device_and_dtype( + tensorlistlist: TensorListList, + with_indices: bool = False, +) -> dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]: + return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) + + +def _device_has_foreach_support(device: torch.device) -> bool: + return ( + device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) + and not torch.jit.is_scripting() + ) + + +def _has_foreach_support(tensors: list[Tensor], device: torch.device) -> bool: + return _device_has_foreach_support(device) and all( + t is None or type(t) in _foreach_supported_types for t in tensors + ) diff --git a/phivenv/Lib/site-packages/torch/utils/_freeze.py b/phivenv/Lib/site-packages/torch/utils/_freeze.py new file mode 100644 index 0000000000000000000000000000000000000000..f9811c8b9f47031f29cbc9f973b9fb58e6acfebe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_freeze.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +""" +Freeze Python packages. + + + + +Freezing makes it possible to ship arbitrary Python modules as part of a C++ +library. The Python source of the module is compiled to bytecode and written +to `.c` files, to be imported by Python's built-in FrozenImporter. + +In a normal Python installation, FrozenImporter is only used to bootstrap the +initialization of the import machinery. Python's importers are defined in +Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be +retrieved before any importers are available. Freezing the module bytecode +resolves this circular dependency. + +This script will freeze the Python standard library. It produces two things: +- Bytecode files: A set of `.c` that define C variables containing Python bytecode. +- Main file: A `main.c` file listing all of these modules in the right form to be + consumed by FrozenImporter. + +The library that wishes to these modules make them available to the local +Python instance by extending `PyImport_FrozenModules` appropriately (see +https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules). +""" + +import argparse +import functools +import itertools +import marshal +import os +import types +from dataclasses import dataclass +from pathlib import Path + + +PATH_MARKER = "" +MAIN_INCLUDES = """#include + +""" + +MAIN_PREFIX_TEMPLATE = """ +// Compiled standard library modules. These should be appended to the existing +// `PyImport_FrozenModules` that ships with CPython. +struct _frozen {}[] = {{ +""" + +FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules") + +MAIN_SUFFIX = """\ + {0, 0, 0} /* sentinel */ +}; +""" + +# Exclude some standard library modules to: +# 1. Slim down the final frozen lib. +# 2. Remove functionality we don't want to support. +DENY_LIST = [ + # Interface to unix databases + "dbm", + # ncurses bindings (terminal interfaces) + "curses", + # Tcl/Tk GUI + "tkinter", + "tkinter", + # Tests for the standard library + "test", + "tests", + "idle_test", + "__phello__.foo.py", + # importlib frozen modules. These are already baked into CPython. 
+ "_bootstrap.py", + "_bootstrap_external.py", +] + +NUM_BYTECODE_FILES = 5 + + +def indent_msg(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + args[0].indent += 1 + ret = fn(*args, **kwargs) + args[0].indent -= 1 + return ret + + return wrapper + + +@dataclass +class FrozenModule: + # The fully qualified module name, e.g. 'foo.bar.baz' + module_name: str + # The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz' + c_name: str + # The size of the C variable. Negative if this module is a package. + size: int + # The frozen bytecode + bytecode: bytes + + +class Freezer: + def __init__(self, verbose: bool): + self.frozen_modules: list[FrozenModule] = [] + self.indent: int = 0 + self.verbose: bool = verbose + + def msg(self, path: Path, code: str): + if not self.verbose: + return + # P: package dir + # F: python file + # S: skipped (not a package dir) + # X: skipped (deny-listed) + # N: skipped (not a python file) + print(" " * self.indent, end="") + print(f"{code} {path}") + + def write_bytecode(self, install_root): + """ + Write the `.c` files containing the frozen bytecode. + + Shared frozen modules evenly across the files. + """ + bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)] + bytecode_files = [ + open(os.path.join(install_root, name), "w") for name in bytecode_file_names + ] + it = itertools.cycle(bytecode_files) + for m in self.frozen_modules: + self.write_frozen(m, next(it)) + + for f in bytecode_files: + f.close() + + def write_main(self, install_root, oss, symbol_name): + """Write the `main.c` file containing a table enumerating all the frozen modules.""" + with open(os.path.join(install_root, "main.c"), "w") as outfp: + outfp.write(MAIN_INCLUDES) + for m in self.frozen_modules: + outfp.write(f"extern unsigned char {m.c_name}[];\n") + + outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name)) + for m in self.frozen_modules: + outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n') + outfp.write(MAIN_SUFFIX) + if oss: + outfp.write(FAKE_PREFIX) + outfp.write(MAIN_SUFFIX) + + def write_frozen(self, m: FrozenModule, outfp): + """Write a single frozen module's bytecode out to a C variable.""" + outfp.write(f"unsigned char {m.c_name}[] = {{") + for i in range(0, len(m.bytecode), 16): + outfp.write("\n\t") + for c in bytes(m.bytecode[i : i + 16]): + outfp.write(f"{c:d},") + outfp.write("\n};\n") + + def compile_path(self, path: Path, top_package_path: Path): + """Entry point for compiling a Path object.""" + if path.is_dir(): + self.compile_package(path, top_package_path) + else: + self.compile_file(path, top_package_path) + + @indent_msg + def compile_package(self, path: Path, top_package_path: Path): + """Compile all the files within a Python package dir.""" + assert path.is_dir() + if path.name in DENY_LIST: + self.msg(path, "X") + return + + # Python packages are directories that have __init__.py in them. + is_package_dir = any(child.name == "__init__.py" for child in path.iterdir()) + if not is_package_dir: + self.msg(path, "S") + return + + self.msg(path, "P") + # Recursively compile all children in this dir + for child in path.iterdir(): + self.compile_path(child, top_package_path) + + def get_module_qualname(self, file_path: Path, top_package_path: Path) -> list[str]: + # `path` looks like 'Lib/foo/bar/baz.py' + + # chop off 'Lib/' to get something that represents a Python module hierarchy. + # e.g. 
'foo/bar/baz.py', which maps to 'foo.bar.baz' + normalized_path = file_path.relative_to(top_package_path.parent) + + if normalized_path.name == "__init__.py": + # Special handling for `__init__.py`. In this case, this file + # specifies that the containing directory should be treated as a package. + # For 'foo/bar/baz/__init__.py': + # - The module name is 'baz' + module_basename = normalized_path.parent.name + # - The parent is foo.bar (need to shave off the 'baz') + module_parent = normalized_path.parent.parent.parts + else: + module_basename = normalized_path.stem + module_parent = normalized_path.parent.parts + return list(module_parent) + [module_basename] + + def compile_string(self, file_content: str) -> types.CodeType: + # instead of passing in the real build time path to 'compile', we + # pass in a marker instead. This prevents the build time path being + # leaked to runtime. That path may not be available at runtime. + # Setting the path to a mark make sure it's a hard error rather + # than a flaky error when inspect module tries to retrieve python source + # code during torchscripting. + path_marker = PATH_MARKER + return compile(file_content, path_marker, "exec") + + @indent_msg + def compile_file(self, path: Path, top_package_path: Path): + """ + Compile a Python source file to frozen bytecode. + + Append the result to `self.frozen_modules`. + """ + assert path.is_file() + if path.suffix != ".py": + self.msg(path, "N") + return + + if path.name in DENY_LIST: + self.msg(path, "X") + return + + self.msg(path, "F") + module_qualname = self.get_module_qualname(path, top_package_path) + module_mangled_name = "__".join(module_qualname) + c_name = "M_" + module_mangled_name + + with open(path) as src_file: + co = self.compile_string(src_file.read()) + + bytecode = marshal.dumps(co) + size = len(bytecode) + if path.name == "__init__.py": + # Python packages are signified by negative size. 
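+            # (CPython's FrozenImporter reads the sign of the ``size`` field in
+            # ``struct _frozen``: a negative value marks the entry as a package.)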
+ size = -size + self.frozen_modules.append( + FrozenModule(".".join(module_qualname), c_name, size, bytecode) + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Compile py source") + parser.add_argument("paths", nargs="*", help="Paths to freeze.") + parser.add_argument("--verbose", action="store_true", help="Print debug logs") + parser.add_argument( + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--oss", + action="store_true", + help="If it's OSS build, add a fake _PyImport_FrozenModules", + ) + parser.add_argument( + "--symbol-name", + "--symbol_name", + help="The name of the frozen module array symbol to generate", + default="_PyImport_FrozenModules_torch", + ) + + args = parser.parse_args() + + f = Freezer(args.verbose) + + for p in args.paths: + path = Path(p) + if path.is_dir() and not Path.exists(path / "__init__.py"): + # this 'top level path p' is a standard directory containing modules, + # not a module itself + # each 'mod' could be a dir containing __init__.py or .py file + # NB: sorted to make sure this is deterministic + for mod in sorted(path.glob("*")): + f.compile_path(mod, mod) + else: + f.compile_path(path, path) + + f.write_bytecode(args.install_dir) + f.write_main(args.install_dir, args.oss, args.symbol_name) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/phivenv/Lib/site-packages/torch/utils/_functools.py b/phivenv/Lib/site-packages/torch/utils/_functools.py new file mode 100644 index 0000000000000000000000000000000000000000..6632dfb3cc36999d7af48091ce625f1562a7fb0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_functools.py @@ -0,0 +1,44 @@ +import functools +from typing import Callable, TypeVar +from typing_extensions import Concatenate, ParamSpec + + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_C = TypeVar("_C") + +# Sentinel used to indicate that cache lookup failed. +_cache_sentinel = object() + + +def cache_method( + f: Callable[Concatenate[_C, _P], _T] +) -> Callable[Concatenate[_C, _P], _T]: + """ + Like `@functools.cache` but for methods. + + `@functools.cache` (and similarly `@functools.lru_cache`) shouldn't be used + on methods because it caches `self`, keeping it alive + forever. `@cache_method` ignores `self` so won't keep `self` alive (assuming + no cycles with `self` in the parameters). + + Footgun warning: This decorator completely ignores self's properties so only + use it when you know that self is frozen or won't change in a meaningful + way (such as the wrapped function being pure). 
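+
+    A hypothetical usage sketch (``Node`` and ``children`` are made-up names;
+    note that the wrapped method may only be called with positional arguments,
+    since keyword arguments are rejected by an assert)::
+
+        class Node:
+            @cache_method
+            def children(self, depth):
+                ...  # expensive, pure computation; cached on ``self``, keyed by ``depth``
+
+    The cache lives on the instance itself (here as ``_cache_method_children``),
+    so it is collected together with the instance instead of keeping it alive.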
+ """ + cache_name = "_cache_method_" + f.__name__ + + @functools.wraps(f) + def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T: + assert not kwargs + if not (cache := getattr(self, cache_name, None)): + cache = {} + setattr(self, cache_name, cache) + cached_value = cache.get(args, _cache_sentinel) + if cached_value is not _cache_sentinel: + return cached_value + value = f(self, *args, **kwargs) + cache[args] = value + return value + + return wrap diff --git a/phivenv/Lib/site-packages/torch/utils/_get_clean_triton.py b/phivenv/Lib/site-packages/torch/utils/_get_clean_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..1b83d4d80ca21e546b017d58b27be2e0892372d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_get_clean_triton.py @@ -0,0 +1,199 @@ +# mypy: allow-untyped-defs +import argparse +import os +import re +import subprocess +from pathlib import Path + + +def remove_triton_function_declaration(source_code: str) -> str: + remove_head = re.sub(r"(\n.+\s\'\'\'\n)", "\n", source_code) + remove_tail = re.sub(r"(\'\'\'\,.+)", "\n", remove_head) + return remove_tail + + +def remove_async_compile(source_code: str) -> str: + remove_top_level = str.replace(source_code, "async_compile = AsyncCompile()", "") + remove_compile = str.replace(remove_top_level, "async_compile.wait(globals())", "") + remove_del = str.replace(remove_compile, "del async_compile", "") + return remove_del + + +def rename_kernels(source_code: str) -> str: + pattern = r"(\w+)\s*=\s*async_compile\.triton\('triton_',\s" + triton_kernel_decl = "def triton_" + matches = [ + (match.end(), match.group(1)) + for match in re.finditer(pattern, source_code, re.DOTALL) + ] + + # Starting from the last match to avoid issues with shifting indices after replacements + for end_index, captured_string in reversed(matches): + # Find the index of the next "B" after the current match + index_of_B = source_code.find(triton_kernel_decl, end_index) + if index_of_B != -1: + # Replace the triton_kernel_decl with the captured string + source_code = ( + source_code[:index_of_B] + + f"def {captured_string}" + + source_code[index_of_B + len(triton_kernel_decl) :] + ) + else: + # If triton_kernel_decl is not found after the current match, continue to the next + continue + + return source_code + + +def merge_params(original_params: list[str], new_params: list[str]) -> list[str]: + for idx in range(len(new_params)): + if new_params[idx] == "T": + new_params[idx] = original_params[idx] + return new_params + + +def add_launch_params( + original: str, kernel_to_params: dict[str, tuple[str, str]] +) -> str: + # Regex to match the function call in the original string + pattern = r"(\w+)\.run\((.*)\)" + + def replace(match) -> str: + # Extract parts from the regex match + func_name = match.group(1) + params = match.group(2) + new_params, grid = kernel_to_params[func_name] + new_params = merge_params(params.split(", "), new_params.split(", ")) + + # Format the new function call + new_string = f"{func_name}[{grid}]({', '.join(new_params)})" + return new_string + + transformed = re.sub(pattern, replace, original) + + remove_inductor_wrappers = re.sub( + r"@triton_heuristics[^@]*@triton.jit", + r"@triton.jit", + transformed, + flags=re.DOTALL, + ) + + return remove_inductor_wrappers + + +def process_file( + input_filename: str, output_filename: str, auto_generate_params: bool = True +) -> str: + with open(input_filename) as file: + source_code = file.read() + + transformed_code = source_code + if "def triton_(" in 
source_code: + raise RuntimeError( + "Need to run original Pytorch code generating kernels with TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1" + ) + # transformed_code = rename_kernels(transformed_code) + transformed_code = remove_triton_function_declaration(transformed_code) + transformed_code = remove_async_compile(transformed_code) + + launch_params_filename = f"{input_filename}.launch_params" + + # Auto-generate launch_params if they don't exist and auto_generate_params is True + if not os.path.exists(launch_params_filename) and auto_generate_params: + print(f"Launch params file {launch_params_filename} not found. Generating...") + try: + # Set environment variable and run the input file + env = os.environ.copy() + env["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1" + + result = subprocess.run( + ["python", input_filename], + env=env, + capture_output=True, + text=True, + cwd=os.path.dirname(input_filename) or ".", + ) + + if result.returncode != 0: + print(f"Error running {input_filename}:") + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + raise RuntimeError( + f"Failed to generate launch params. Command failed with return code {result.returncode}" + ) + + print(f"Successfully generated {launch_params_filename}") + + except Exception as e: + raise RuntimeError( + f"Failed to generate launch params by running {input_filename}: {str(e)}" + ) from e + + if not os.path.exists(launch_params_filename): + raise RuntimeError( + f"Missing {launch_params_filename}. Run `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1 python {input_filename}` first." + ) + + with open(launch_params_filename) as f: + launch_params_meta = f.readlines() + + split_params = [i.split("|") for i in launch_params_meta] + kernel_args_grid = {a.strip(): (b.strip(), c.strip()) for a, b, c in split_params} + transformed_code = add_launch_params(transformed_code, kernel_args_grid) + + with open(output_filename, "w") as file: + file.write(transformed_code) + print(f"Successfully generated {output_filename}") + return transformed_code + + +def get_clean_triton( + input_path: Path, + output_path: Path = Path("triton_only_repro.py"), + auto_generate_params: bool = True, +): + """Run experiments and output results to file + + Args: + input_path (Optional[Path]): Path to inductor generated output codede + output_path (Optional[Path]): Path to write out the new python file + auto_generate_params (bool): Whether to automatically generate launch_params if missing + """ + return process_file(str(input_path), str(output_path), auto_generate_params) + + +if __name__ == "__main__": + """Sample usage: + # Running sweep + python _get_clean_triton.py output_code.py + + # To disable auto-generation of launch params: + python _get_clean_triton.py output_code.py --no-auto-generate + """ + parser = argparse.ArgumentParser( + description="Clean Inductor generated code to remove Inductor dependencies" + ) + + # Add the arguments + parser.add_argument( + "input_path", type=Path, help="Path to inductor generated output code" + ) + parser.add_argument( + "--output_path", + type=Path, + default=Path("triton_only_repro.py"), + help="Path to write out the clean triton output", + ) + parser.add_argument( + "--no-auto-generate", + action="store_true", + help="Disable automatic generation of launch_params file", + ) + + # Parse the arguments + args = parser.parse_args() + + # Call the function with parsed arguments + result = get_clean_triton( + args.input_path, args.output_path, not args.no_auto_generate + ) diff --git 
a/phivenv/Lib/site-packages/torch/utils/_helion.py b/phivenv/Lib/site-packages/torch/utils/_helion.py new file mode 100644 index 0000000000000000000000000000000000000000..c303146e815e4dc882ad166b25a596cca6f0e516 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_helion.py @@ -0,0 +1,17 @@ +import functools + +from torch.utils._triton import has_triton + + +@functools.cache +def has_helion_package() -> bool: + try: + import helion # type: ignore[import-untyped, import-not-found] # noqa: F401 + except ImportError: + return False + return True + + +@functools.cache +def has_helion() -> bool: + return has_helion_package() and has_triton() diff --git a/phivenv/Lib/site-packages/torch/utils/_import_utils.py b/phivenv/Lib/site-packages/torch/utils/_import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..783752d8af838dc3aee715c3a08d9db6e518b260 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_import_utils.py @@ -0,0 +1,44 @@ +import functools +import importlib.util +from types import ModuleType +from typing import Optional + +import torch + + +def _check_module_exists(name: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. It avoids third party libraries breaking assumptions of some of + our tests, e.g., setting multiprocessing start method when imported + (see librosa/#747, torchvision/#544). + """ + try: + spec = importlib.util.find_spec(name) + return spec is not None + except ImportError: + return False + + +@functools.lru_cache +def dill_available() -> bool: + return ( + _check_module_exists("dill") + # dill fails to import under torchdeploy + and not torch._running_with_deploy() + ) + + +@functools.lru_cache +def import_dill() -> Optional[ModuleType]: + if not dill_available(): + return None + + import dill + + # XXX: By default, dill writes the Pickler dispatch table to inject its + # own logic there. This globally affects the behavior of the standard library + # pickler for any user who transitively depends on this module! + # Undo this extension to avoid altering the behavior of the pickler globally. + dill.extend(use_dill=False) + return dill diff --git a/phivenv/Lib/site-packages/torch/utils/_mode_utils.py b/phivenv/Lib/site-packages/torch/utils/_mode_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..864f0d6d10efc789b3c37258b97428db42035b4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_mode_utils.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs +from typing import TypeVar + +import torch + + +T = TypeVar("T") + + +# returns if all are the same mode +def all_same_mode(modes): + return all(tuple(mode == modes[0] for mode in modes)) + + +no_dispatch = torch._C._DisableTorchDispatch diff --git a/phivenv/Lib/site-packages/torch/utils/_ordered_set.py b/phivenv/Lib/site-packages/torch/utils/_ordered_set.py new file mode 100644 index 0000000000000000000000000000000000000000..815193e7d5cf6228eeb2268242e644252d40d58a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_ordered_set.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from collections.abc import ( + Iterable, + Iterator, + MutableSet, + Reversible, + Set as AbstractSet, +) +from typing import Any, cast, Optional, TypeVar + + +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) + +__all__ = ["OrderedSet"] + + +class OrderedSet(MutableSet[T], Reversible[T]): + """ + Insertion ordered set, similar to OrderedDict. 
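+
+    A minimal usage sketch (illustrative values only)::
+
+        s = OrderedSet([3, 1, 2])
+        s.add(1)       # already present; its position is unchanged
+        s.discard(9)   # absent elements are ignored
+        list(s)        # [3, 1, 2] -- iteration follows insertion order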
+ """ + + __slots__ = ("_dict",) + + def __init__(self, iterable: Optional[Iterable[T]] = None): + self._dict = dict.fromkeys(iterable, None) if iterable is not None else {} + + @staticmethod + def _from_dict(dict_inp: dict[T, None]) -> OrderedSet[T]: + s: OrderedSet[T] = OrderedSet() + s._dict = dict_inp + return s + + # + # Required overridden abstract methods + # + def __contains__(self, elem: object) -> bool: + return elem in self._dict + + def __iter__(self) -> Iterator[T]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._dict) + + def add(self, elem: T) -> None: + self._dict[elem] = None + + def discard(self, elem: T) -> None: + self._dict.pop(elem, None) + + def clear(self) -> None: + # overridden because MutableSet impl is slow + self._dict.clear() + + # Unimplemented set() methods in _collections_abc.MutableSet + + @classmethod + def _wrap_iter_in_set(cls, other: Any) -> Any: + """ + Wrap non-Set Iterables in OrderedSets + + Some of the magic methods are more strict on input types than + the public apis, so we need to wrap inputs in sets. + """ + + if not isinstance(other, AbstractSet) and isinstance(other, Iterable): + return cls(other) + else: + return other + + def pop(self) -> T: + if not self: + raise KeyError("pop from an empty set") + return self._dict.popitem()[0] + + def copy(self) -> OrderedSet[T]: + return OrderedSet._from_dict(self._dict.copy()) + + def difference(self, *others: Iterable[T]) -> OrderedSet[T]: + res = self.copy() + res.difference_update(*others) + return res + + def difference_update(self, *others: Iterable[T]) -> None: + for other in others: + self -= other # type: ignore[arg-type] + + def update(self, *others: Iterable[T]) -> None: + for other in others: + self |= other + + def intersection(self, *others: Iterable[T]) -> OrderedSet[T]: + res = self.copy() + for other in others: + if other is not self: + res &= other # type: ignore[arg-type] + return res + + def intersection_update(self, *others: Iterable[T]) -> None: + for other in others: + self &= other # type: ignore[arg-type] + + def issubset(self, other: Iterable[T]) -> bool: + return self <= self._wrap_iter_in_set(other) + + def issuperset(self, other: Iterable[T]) -> bool: + return self >= self._wrap_iter_in_set(other) + + def symmetric_difference(self, other: Iterable[T]) -> OrderedSet[T]: + return self ^ other # type: ignore[operator] + + def symmetric_difference_update(self, other: Iterable[T]) -> None: + self ^= other # type: ignore[arg-type] + + def union(self, *others: Iterable[T]) -> OrderedSet[T]: + res = self.copy() + for other in others: + if other is self: + continue + res |= other + return res + + # Specify here for correct type inference, otherwise would + # return AbstractSet[T] + def __sub__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: + # following cpython set impl optimization + if isinstance(other, OrderedSet) and (len(self) * 4) > len(other): + out = self.copy() + out -= other + return out + return cast(OrderedSet[T], super().__sub__(other)) + + def __ior__(self, other: Iterable[T]) -> OrderedSet[T]: # type: ignore[misc, override] # noqa: PYI034 + if isinstance(other, OrderedSet): + self._dict.update(other._dict) + return self + return super().__ior__(other) # type: ignore[arg-type] + + def __eq__(self, other: object) -> bool: + if isinstance(other, OrderedSet): + return self._dict == other._dict + return super().__eq__(other) + + def __ne__(self, other: object) -> bool: + 
if isinstance(other, OrderedSet): + return self._dict != other._dict + return super().__ne__(other) + + def __or__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: + return cast(OrderedSet[T], super().__or__(other)) + + def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: + # MutableSet impl will iterate over other, iter over smaller of two sets + if isinstance(other, OrderedSet) and len(self) < len(other): + return other & self + return cast(OrderedSet[T], super().__and__(other)) + + def __xor__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: + return cast(OrderedSet[T], super().__xor__(other)) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({list(self)})" + + def __getstate__(self) -> list[T]: + return list(self._dict.keys()) + + def __setstate__(self, state: list[T]) -> None: + self._dict = dict.fromkeys(state, None) + + def __reduce__(self) -> tuple[type[OrderedSet[T]], tuple[list[T]]]: + return (OrderedSet, (list(self),)) diff --git a/phivenv/Lib/site-packages/torch/utils/_python_dispatch.py b/phivenv/Lib/site-packages/torch/utils/_python_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..7cba97b21d3516aaa65863cf662d08f1b8cd305a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_python_dispatch.py @@ -0,0 +1,721 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from collections import deque +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional, overload, Protocol, Union +from typing_extensions import TypeIs + +import torch +import torchgen +import torchgen.model +from torch._C import ( + _get_dispatch_stack_at, + _len_torch_dispatch_stack, + _pop_torch_dispatch_stack, + _push_on_torch_dispatch_stack, + DispatchKey, +) + + +# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: +# - We need a better user-facing api for _DisableTorchDispatch that +# is able to selectively disable __torch_dispatch__ of a particular class. +# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) +# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) + +_is_in_torch_dispatch_mode = False +_is_in_non_infra_torch_dispatch_mode = False + + +def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool: + return ( + _is_in_torch_dispatch_mode + if include_infra_modes + else _is_in_non_infra_torch_dispatch_mode + ) + + +class TorchDispatchMode: + """ + A ``TorchDispatchMode`` allows you to override the meaning of all + ``__torch_dispatch__`` overrideable functions within a dynamic scope, + without having to actually create a tensor subclass or manually + monkey-patch functions in the PyTorch API. Some common situations + where you should use a mode: + + * You want to override the meaning of factory functions, or other + functions that do not otherwise take a tensor as an argument + (these cannot be overridden with tensor subclasses). + + * You want to override the behavior of all functions without needing + to wrap your inputs in tensor subclasses; e.g., if you are just + interested in logging intermediate computations. + + * You want to control the order of execution of various tensor + subclasses explicitly, rather than implicitly via the return of + ``NotImplemented``. + + Independent subclasses of :class:`TorchDispatchMode` are compositional: + modes can be pushed onto a stack using ``with MyMode():``. 
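+
+    As a sketch of what such a subclass can look like (the ``LoggingMode`` below
+    is illustrative only, not part of this module)::
+
+        class LoggingMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                kwargs = kwargs or {}
+                print(f"calling {func}")
+                return func(*args, **kwargs)
+
+        with LoggingMode():
+            torch.ones(2) + 1  # each underlying aten op is intercepted and logged
+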
+ When you call functions in the PyTorch API inside your + ``__torch_dispatch__`` implementation, by default, they will forward on to + the next mode on the mode stack. If you want recursively call back into + your current ``__torch_dispatch__`` implementation, either explicitly + invoke ``self.__torch_dispatch__(...)``, or use the context manager + ``__torch_dispatch__(self)`` to make PyTorch + API self-referential (beware of infinite loops, in this case!) + """ + + def __init__(self, _dispatch_key=None): + if _dispatch_key is not None: + assert isinstance(_dispatch_key, torch._C.DispatchKey) + self.__dict__["_dispatch_key"] = _dispatch_key + + self.old_dispatch_mode_flags: deque[bool] = deque() + self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() + + def _lazy_init_old_dispatch_mode_flags(self): + if not hasattr(self, "old_dispatch_mode_flags"): + self.old_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef] + + if not hasattr(self, "old_non_infra_dispatch_mode_flags"): + self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef] + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + raise NotImplementedError + + def __enter__(self): + global _is_in_torch_dispatch_mode + global _is_in_non_infra_torch_dispatch_mode + # Previously, there wasn't any state in this class' constructor + # super calls were added to existing modes, but for any new modes + # this will replicate the previous behavior of not strictly needing + # to call super().__init__() + self._lazy_init_old_dispatch_mode_flags() + self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode) + _is_in_torch_dispatch_mode = True + self.old_non_infra_dispatch_mode_flags.append( + _is_in_non_infra_torch_dispatch_mode + ) + _is_in_non_infra_torch_dispatch_mode = ( + _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode() + ) + _push_mode(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None) + if mb_dk_or_mode_key is None: + # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch) + # We should probably revisit this. 
+ mb_dk_or_mode_key = self.__dict__.get("_mode_key", None) + global _is_in_torch_dispatch_mode + _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop() + global _is_in_non_infra_torch_dispatch_mode + _is_in_non_infra_torch_dispatch_mode = ( + self.old_non_infra_dispatch_mode_flags.pop() + ) + _pop_mode(mb_dk_or_mode_key) + + @classmethod + def push(cls, *args, **kwargs): + warnings.warn( + "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`" + ) + instance = cls(*args, **kwargs) + return instance + + @classmethod + def is_infra_mode(cls): + return False + + +def _get_current_dispatch_mode(): + stack_len = _len_torch_dispatch_stack() + # Return a user mode on the stack if there are any + if stack_len > 0: + return _get_dispatch_stack_at(stack_len - 1) + return None + + +def _detect_infra_mode(key): + assert key in [ + torch._C._TorchDispatchModeKey.FUNCTIONAL, + torch._C._TorchDispatchModeKey.PROXY, + ] + from torch._ops import _get_dispatch_mode_pre_dispatch + + pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key) + post_dispatch_mode = torch._C._get_dispatch_mode(key) + + assert (pre_dispatch_mode is None) or (post_dispatch_mode is None) + + if pre_dispatch_mode is None: + return post_dispatch_mode + + return pre_dispatch_mode + + +def _unset_infra_mode(key): + from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch + + pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key) + post_dispatch_mode = torch._C._get_dispatch_mode(key) + if pre_dispatch_mode and post_dispatch_mode: + raise AssertionError( + "Can't have active infra mode on both pre and post dispatch mode stack" + ) + + if pre_dispatch_mode: + mode = unset_mode_pre_dispatch(key) + return mode + if post_dispatch_mode: + return torch._C._unset_dispatch_mode(key) + + +def _disable_infra_mode(key): + assert key in ( + torch._C._TorchDispatchModeKey.FUNCTIONAL, + torch._C._TorchDispatchModeKey.PROXY, + ) + mode_unset = _unset_infra_mode(key) + try: + yield mode_unset + finally: + if mode_unset is not None: + _push_mode(mode_unset) + + +def _get_current_dispatch_mode_stack(): + stack_len = _len_torch_dispatch_stack() + return [_get_dispatch_stack_at(i) for i in range(stack_len)] + + +def _push_mode(mode: TorchDispatchMode): + k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None + assert k is None or k == torch._C.DispatchKey.PreDispatch + if k is None: + _push_on_torch_dispatch_stack(mode) + return + + from torch._ops import _set_mode_pre_dispatch, get_cached_ops + + # See Note [Not Caching Per-Dispatch-Key Mode Handlers] + # Clear the cache of every op that has been used so far, for this particular key. 
+ ks = torch._C._functionality_to_backend_keys(k) + for op in get_cached_ops(): + for key in ks: + op._uncache_dispatch(key) + _set_mode_pre_dispatch(mode) + + +def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None): + if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined] + from torch._ops import _pop_mode_from_pre_dispatch + + return _pop_mode_from_pre_dispatch() + + if k is None or isinstance(k, torch._C._TorchDispatchModeKey): + return _pop_torch_dispatch_stack(k) + + +@contextlib.contextmanager +def _pop_mode_temporarily(k: Optional[DispatchKey] = None): + old = _pop_mode(k) + try: + yield old + finally: + _push_mode(old) + + +@contextlib.contextmanager +def _disable_current_modes(): + from torch._ops import ( + _len_torch_dispatch_stack_pre_dispatch, + _pop_mode_from_pre_dispatch, + ) + from torch._subclasses.functional_tensor import FunctionalTensorMode + from torch._subclasses.schema_check_mode import SchemaCheckMode + from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode + + mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch() + old_pre_dispatch_modes = [ + _pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch) + ] + + has_proxy_mode_in_pre_dispatch = False + has_functional_mode_in_pre_dispatch = False + has_schema_check_mode_in_pre_dispatch = False + + for i in old_pre_dispatch_modes: + if isinstance(i, ProxyTorchDispatchMode): + has_proxy_mode_in_pre_dispatch = True + if isinstance(i, FunctionalTensorMode): + has_functional_mode_in_pre_dispatch = True + if isinstance(i, SchemaCheckMode): + has_schema_check_mode_in_pre_dispatch = True + + mode_len = _len_torch_dispatch_stack() + old_modes = [_pop_mode() for _ in range(mode_len)] + + for old in old_modes: + if ( + isinstance(old, FunctionalTensorMode) + and has_functional_mode_in_pre_dispatch + ): + raise AssertionError( + "Can't have FunctionalMode available both in PreDispatch and Python Key" + ) + if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch: + raise AssertionError( + "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key" + ) + if isinstance(old, SchemaCheckMode) and has_schema_check_mode_in_pre_dispatch: + raise AssertionError( + "Can't have SchemaCheckMode available both in PreDispatch and Python Key" + ) + + # Manually disable proxy and fake modes, if any are active + try: + yield old_pre_dispatch_modes + old_modes + finally: + for mode in reversed(old_modes): + _push_mode(mode) + for mode in reversed(old_pre_dispatch_modes): + _push_mode(mode) + + +class BaseTorchDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return func(*args, **kwargs) + + +# Subtypes which have __tensor_flatten__ and __tensor_unflatten__. +class TensorWithFlatten(Protocol): + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: + ... + + @staticmethod + def __tensor_unflatten__( + inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int + ) -> torch.Tensor: + ... + + # It would be really nice to be able to say that the return of + # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, + # TensorWithFlatten] - but that doesn't exist. + + shape: torch._C.Size + + @overload + def stride(self, dim: None = None) -> tuple[int, ...]: + ... + + @overload + def stride(self, dim: int) -> int: + ... + + @overload + def size(self, dim: None = None) -> tuple[int, ...]: + ... 
+ + @overload + def size(self, dim: int) -> int: + ... + + def storage_offset(self) -> int: + ... + + def dim(self) -> int: + ... + + @overload + def to( + self, + dtype: torch.types._dtype, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, + ) -> torch.Tensor: + ... + + @overload + def to( + self, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype: Optional[torch.types._dtype] = None, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, + ) -> torch.Tensor: + ... + + @overload + def to( + self, + other: torch.Tensor, + non_blocking: bool = False, + copy: bool = False, + *, + memory_format: Optional[torch.memory_format] = None, + ) -> torch.Tensor: + ... + + +def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: + """ + Returns whether or not a tensor subclass that implements __torch_dispatch__ + is 'traceable' with torch.compile. + In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2, + It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__. + It is also expected to obey some restrictions around traceability and aliasing: + * The subclass's __torch_dispatch__() implementation should desugar into pytorch + dispatcher operations that can be traced into a graph. + * The subclass should use return_and_correct_aliasing(). This is needed today to make + sure that torch.compile does the right thing in a few cases around input mutation + and output aliasing. + + Expected magic method signatures: + attrs, ctx = t.__tensor_flatten__() + attrs: list of attribute name strings for inner tensors + ctx: dict containing any other subclass-specific metadata needed for unflattening + + t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride) + inner_tensors: dict mapping attribute name -> tensor for each inner tensor + ctx: dict with subclass metadata in the form that __tensor_flatten__() produces + outer_size: expected (possibly symbolic) size that the returned subclass + instance should have. Note that this arg is useful for certain subclasses + that require the shape info to be constructed. In most cases, this arg can be + safely ignored. + outer_stride: expected (possibly symbolic) stride that the returned subclass + instance should have. Note that this arg is useful for certain subclasses + that require the stride info to be constructed. In most cases, this arg can be + safely ignored. + """ + is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor + return ( + is_subclass + and hasattr(t, "__tensor_flatten__") + and hasattr(t, "__tensor_unflatten__") + ) + + +def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]: + """Same as above, but takes a type argument instead of an instance.""" + return ( + issubclass(t, torch.Tensor) + and t != torch.Tensor + and hasattr(t, "__tensor_flatten__") + and hasattr(t, "__tensor_unflatten__") + ) + + +def transform_subclass(t, callback, outer_size=None, outer_stride=None): + """ + Given a traceable, wrapper tensor subclass ``t`` that implements + ``__torch_dispatch__`` and holds some inner tensors, + and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``, + `transform_subclass` will construct a fresh instance of the wrapper tensor subclass. 
+ It will do so by grabbing each inner tensor attribute from the wrapper, + passing them into ``callback`` to get a transformed tensor, + and putting each transformed tensor into the fresh tensor subclass instance. + + Note: this function will not handle ensuring that the fresh subclass + gets the same (autograd, and aliasing) metadata as the original tensor. + This is generally handled in other subsystems like AOTAutograd. + """ + outer_size = outer_size if outer_size is not None else t.size() + outer_stride = outer_stride if outer_stride is not None else t.stride() + + attrs, ctx = t.__tensor_flatten__() + transformed_tensors_dict = {} + for attr in attrs: + transformed_tensors_dict[attr] = callback(attr, getattr(t, attr)) + sub = type(t).__tensor_unflatten__( + transformed_tensors_dict, ctx, outer_size, outer_stride + ) + + # NB: Purposefully guard here to simplify the inner / outer symbols. + # Using sym_eq() for symbolic comparison can result in an expression that's too + # difficult to guard on, so we use == here. + assert sub.shape == outer_size, ( + f"Expected return value from {type(t)}__tensor_unflatten__() to have " + f"shape equal to {outer_size}, but got: {sub.shape}" + ) + assert sub.stride() == outer_stride, ( + f"Expected return value from {type(t)}__tensor_unflatten__() to have " + f"stride equal to {outer_stride}, but got: {sub.stride()}" + ) + + return sub + + +def _correct_storage_aliasing(func, schema_info, args, outs): + """ + Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema), + and the inputs/outputs to the OpOverload, + this function checks to see if func is a view operator + (by checking if any of the outputs in the op's schema + are immutable aliases of inputs). + If so, this function manually aliases the storage of the output tensor + with its corresponding input tensor alias. + It does this by unsafely overwriting the storage field of the output tensor + to be the same storage as the input. + """ + assert isinstance(func, torch._ops.OpOverload) + assert isinstance(args, tuple) + assert isinstance(outs, (list, tuple)) + + def alias_non_inplace_storage(arg, ret): + # This is hopefully a reasonable assert: + # subclasses that rely on this API for output aliasing + # should always return wrapper tensor subclasses for us to manually alias. + # in theory if a subclass that needs this API wants to sometimes return + # plain tensors, we could remove the assert and just not perform the aliasing, + # but it seems safer to learn more about this case first. + if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret): + ret_list = ret if isinstance(ret, list) else [ret] + for r in ret_list: + assert type(arg) == type( + r + ), f"""Called {str(func)} with input of type {type(arg)} +and output of type {type(ret)}. But expected types to match.""" + # Need to call a non-dispatcher helper, because we explicitly do **not** + # want our subclass to intercept the set_() call. + # instead, our subclass should directly have its storage swapped out. + # we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change. + # Why? + # The purpose of this API is *not* to change the size/strides of our output- we assume it's already correct. + # We just want to "fix up" the storage aliasing, without modifying or output's metadata. 
+ # Example: out = inp.expand(inp.shape[0], inp.shape[0]) + # This requires swapping the storage of out to be the same as inp, + # but we do *not* want it to change the sizes/strides that were compute for out. + + if isinstance(ret, list): + for r in ret: + torch._functionalize_unsafe_set(r, arg) + else: + assert isinstance(ret, torch.Tensor), f"type: {type(ret)}" + torch._functionalize_unsafe_set(ret, arg) + + def is_read_only_alias_match(arg, ret): + shared_aliases = arg.alias_set & ret.alias_set + return len(shared_aliases) > 0 and not arg.is_write + + num_args = len(func._schema.arguments) + num_returns = len(func._schema.returns) + for arg_idx in range(num_args): + for return_idx in range(num_returns): + if is_read_only_alias_match( + schema_info.args[arg_idx], schema_info.outs[return_idx] + ): + alias_non_inplace_storage(args[arg_idx], outs[return_idx]) + + +# This abstracts over the fact that in return_and_correct_aliasing, +# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy), +# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested). +@dataclass +class AliasInfo: + alias_set: set[str] + is_write: bool + name: Optional[str] + + +@dataclass +class SchemaInfo: + args: list[AliasInfo] + outs: list[AliasInfo] + + +# Can't import torch._ops.OpOverload due to circular reference +parsed_schema_map: dict[Any, SchemaInfo] = {} + + +# Given an OpOverload, returns schema information on it. +# This is cached for efficiency, since it can involve running torchgen +def get_alias_info(func) -> SchemaInfo: + if func in parsed_schema_map: + return parsed_schema_map[func] + # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations + # properly for some ops that output tensorlists) + if func.namespace == "aten": + torchgen_schema_str = str(func._schema) + assert torchgen_schema_str.startswith("aten::") + # remove the aten:: namespace, which is added by the torchscript parser, + # and torchgen doesn't know how to handle + torchgen_schema_str = torchgen_schema_str[6:] + import re + + # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1], + # which torchgen chokes on. 
+ torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str) + torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str) + # for aten::rot90 / aten:fft_* + torchgen_schema_str = re.sub( + r"=\[(-?[0-9]+), (-?[0-9]+)\]", r"=[\1,\2]", torchgen_schema_str + ) + torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str) + arg_schemas = [ + AliasInfo( + alias_set=( + set() if a.annotation is None else set(a.annotation.alias_set) + ), + is_write=a.annotation is not None and a.annotation.is_write, + name=a.name, + ) + for a in torchgen_schema.arguments.flat_all + ] + out_schemas = [ + AliasInfo( + alias_set=( + set() if a.annotation is None else set(a.annotation.alias_set) + ), + is_write=a.annotation is not None and a.annotation.is_write, + name=a.name, + ) + for a in torchgen_schema.returns + ] + else: + # For non-aten ops, torchgen is untested so we rely on torchscript schema parsing + arg_schemas = [ + AliasInfo( + alias_set=( + set() if a.alias_info is None else set(a.alias_info.before_set) + ), + is_write=a.alias_info is not None and a.alias_info.is_write, + name=a.name, + ) + for a in func._schema.arguments + ] + out_schemas = [ + AliasInfo( + alias_set=( + set() if a.alias_info is None else set(a.alias_info.before_set) + ), + is_write=a.alias_info is not None and a.alias_info.is_write, + name=a.name, + ) + for a in func._schema.returns + ] + schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas) + parsed_schema_map[func] = schema_info + return schema_info + + +def return_and_correct_aliasing(func, args, kwargs, out): + """ + This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses + that would like to work with torch.compile. It ensures that the subclass + properly implements the aliasing behavior of every op, + which is needed for correctness in AOTAutograd. + This function will handle: + + * When we see a view op, we will alias the storages of any + input and output tensor subclasses + + * When we see an inplace or out= op, we will directly + return the corresponding input tensor, instead of returning + a (potentially) fresh output tensor. + """ + + # Caching here because torchgen parsing is definitely not fast, and this function is called + # once for every op in the graph during functionalization. + schema_info = get_alias_info(func) + + def get_write_alias(x): + if len(x.alias_set) == 0: + return None + alias_set = list(x.alias_set) + # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing + assert len(alias_set) == 1 + if x.is_write: + return alias_set[0] + return None + + def get_arg_from_alias(output_alias, schema_info, args, kwargs): + new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs + ) + + arg_indices = [ + i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set + ] + # For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments. + assert len(arg_indices) == 1 + idx = arg_indices[0] + arg_info = schema_info.args[idx] + if arg_info.name is not None and arg_info.name in new_kwargs: + return new_kwargs[arg_info.name] + return new_args[idx] + + # Fix up the storages of any outs so that they point to the same storage as the input, + # if func is a view op. 
+ _correct_storage_aliasing( + func, schema_info, args, (out,) if not isinstance(out, tuple) else out + ) + + # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's + # metadata is set correctly. + if torch.Tag.inplace_view in func.tags: + # no_dispatch() to make sure that we secretly change the metadata on the wrapper, + # but don't end up dispatching the op anywhere else. + mutated_args = [ + x + for i, x in enumerate(args) + if get_write_alias(schema_info.args[i]) is not None + ] + # Assumption: we have a very small number of inplace_view ops that follow a strict schema: + # there is only a single argument that gets its metadata mutated. + assert len(mutated_args) == 1 + # This check exists because we generally *do* want to update the metadata of any wrapper subclasses, + # but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor. + # so we don't actually need to update the metadata (and attempting to do so causes errors) + from torch._subclasses.functional_tensor import FunctionalTensor + + if not isinstance(mutated_args[0], FunctionalTensor): + with torch.utils._mode_utils.no_dispatch(): + # See Note: [Fake Tensor Dispatch Keys] + # we're borrowing the way it modifies dispatch key TLS. + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + torch._C._set_meta_in_tls_dispatch_include(True) + try: + func(*args, **kwargs) + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + + # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()). + + # simple case: none of our outputs have mutable aliases, so we can return the output as-is + if not any(get_write_alias(r) is not None for r in schema_info.outs): + return out + + # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" + if not all(get_write_alias(r) is not None for r in schema_info.outs): + raise RuntimeError("Unsupported schema: " + str(func._schema)) + + if len(func._schema.returns) == 1: + return get_arg_from_alias( + get_write_alias(schema_info.outs[0]), schema_info, args, kwargs + ) + + # In the multi-return case, all aten ops return a tuple / list, so cast accordingly. + outs_to_return = type(out)( + [ + ( + get_arg_from_alias( + get_write_alias(schema_info.outs[i]), schema_info, args, kwargs + ) + if get_write_alias(r) is not None + else o + ) + for ((i, r), o) in zip(enumerate(schema_info.outs), out) + ] + ) + return outs_to_return diff --git a/phivenv/Lib/site-packages/torch/utils/_pytree.py b/phivenv/Lib/site-packages/torch/utils/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..78dd7b4b73a7a57d74e72d1448b166c45ca9a1fc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_pytree.py @@ -0,0 +1,2089 @@ +""" +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. 
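+
+For example, a dict of lists can be flattened into its leaves and rebuilt
+with the same structure (an illustrative sketch using this module's own
+``tree_flatten`` / ``tree_unflatten`` helpers):
+
+    leaves, spec = tree_flatten({"a": [1, 2], "b": 3})  # leaves == [1, 2, 3]
+    doubled = tree_unflatten([x * 2 for x in leaves], spec)
+    # doubled == {"a": [2, 4], "b": 6}, with the same nesting as the input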
+ +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import functools +import importlib +import importlib.metadata +import json +import sys +import threading +import types +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum +from typing import ( + Any, + Callable, + cast, + ClassVar, + Final, + Generic, + NoReturn, + Optional, + overload, + Protocol, + TypeVar, + Union, +) +from typing_extensions import deprecated, NamedTuple, Self + +from torch.torch_version import TorchVersion as _TorchVersion + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_is_leaf", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + + +class KeyEntry(Protocol): + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + + def get(self, parent: Any) -> Any: + ... + + +class EnumEncoder(json.JSONEncoder): + def default(self, obj: object) -> Union[str, dict[str, Any]]: + if isinstance(obj, Enum): + return { + "__enum__": True, + "fqn": f"{obj.__class__.__module__}:{obj.__class__.__qualname__}", + "name": obj.name, + } + return cast(str, super().default(obj)) + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", list[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]] +KeyPath = tuple[KeyEntry, ...] +FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +# - flatten_with_keys_fn, which is a callable that takes a +# pytree and returns a list of (keypath, value) pairs and a context. 
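+#
+# As an illustrative sketch (``Pair`` is a hypothetical two-field container,
+# not a type defined in this module), the two callables could look like:
+#
+#     def pair_flatten(pair):            # returns (list of values, context)
+#         return [pair.first, pair.second], None
+#
+#     def pair_unflatten(values, context):
+#         first, second = values
+#         return Pair(first, second)
+#
+# Registering them via ``register_pytree_node(Pair, pair_flatten, pair_unflatten)``
+# lets ``tree_flatten`` / ``tree_map`` recurse into ``Pair`` instances.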
+class NodeDef(NamedTuple): + type: type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] + + +_NODE_REGISTRY_LOCK = threading.RLock() +SUPPORTED_NODES: dict[type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: dict[str, type[Any]] = {} + +# NB: we try really hard to not import _cxx_pytree (which depends on optree) +# as much as possible. This is for isolation: a user who is not using C++ pytree +# shouldn't pay for it, and it helps makes things like cpython upgrades easier. +_optree_minimum_version = _TorchVersion("0.13.0") +try: + _optree_version = importlib.metadata.version("optree") +except importlib.metadata.PackageNotFoundError: + # No optree package found + _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False + _optree_version = _TorchVersion("0.0.0a0") +else: + _optree_version = _TorchVersion(_optree_version) + if _optree_version < _optree_minimum_version: + # optree package less than our required minimum version. + # Pretend the optree package doesn't exist. + # NB: We will raise ImportError if the user directly tries to + # `import torch.utils._cxx_pytree` (look in that file for the check). + _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = False + else: + _cxx_pytree_dynamo_traceable = _cxx_pytree_exists = True + +_cxx_pytree_imported = False +_cxx_pytree_pending_imports: list[Any] = [] + + +def register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node. + + Note: + :func:`register_dataclass` is a simpler way of registering a container-like + type as a pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. 
This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + if not _cxx_pytree_exists: + return + + if _cxx_pytree_imported: + from . import _cxx_pytree as cxx + + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + else: + args = (cls, flatten_fn, unflatten_fn) + kwargs = { + "serialized_type_name": serialized_type_name, + "to_dumpable_context": to_dumpable_context, + "from_dumpable_context": from_dumpable_context, + } + _cxx_pytree_pending_imports.append((args, kwargs)) + + +def register_dataclass( + cls: type[Any], + *, + field_names: Optional[list[str]] = None, + drop_field_names: Optional[list[str]] = None, + serialized_type_name: Optional[str] = None, +) -> None: + """ + Registers a type that has the semantics of a ``dataclasses.dataclass`` type + as a pytree node. + + This is a simpler API than :func:`register_pytree_node` for registering + a dataclass or a custom class with the semantics of a dataclass. + + Args: + cls: The python type to register. The class must have the semantics of a + dataclass; in particular, it must be constructed by passing the fields + in. + field_names (Optional[List[str]]): A list of field names that correspond + to the **non-constant data** in this class. This list must contain + all the fields that are used to initialize the class. This argument + is optional if ``cls`` is a dataclass, in which case the fields will + be taken from ``dataclasses.fields()``. + drop_field_names (Optional[List[str]]): A list of field names that + should not be included in the pytree. + serialized_type_name: A keyword argument used to specify the fully + qualified name used when serializing the tree spec. This is only + needed for serializing the treespec in torch.export. + + Example: + + >>> from torch import Tensor + >>> from dataclasses import dataclass + >>> import torch.utils._pytree as pytree + >>> + >>> @dataclass + >>> class Point: + >>> x: Tensor + >>> y: Tensor + >>> + >>> pytree.register_dataclass(Point) + >>> + >>> point = Point(torch.tensor(0), torch.tensor(1)) + >>> point = pytree.tree_map(lambda x: x + 1, point) + >>> assert torch.allclose(point.x, torch.tensor(1)) + >>> assert torch.allclose(point.y, torch.tensor(2)) + + """ + drop_field_names = drop_field_names or [] + + if not dataclasses.is_dataclass(cls): + if field_names is None: + raise ValueError( + "field_names must be specified with a list of all fields used to " + f"initialize {cls}, as it is not a dataclass." 
+ ) + elif field_names is None: + field_names = [f.name for f in dataclasses.fields(cls) if f.init] + else: + dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init} + dataclass_init_fields.difference_update(drop_field_names) + + if dataclass_init_fields != set(field_names): + error_msg = "field_names does not include all dataclass fields.\n" + + if missing := dataclass_init_fields - set(field_names): + error_msg += ( + f"Missing fields in `field_names`: {missing}. If you want " + "to include these fields in the pytree, please add them " + "to `field_names`, otherwise please add them to " + "`drop_field_names`.\n" + ) + + if unexpected := set(field_names) - dataclass_init_fields: + error_msg += ( + f"Unexpected fields in `field_names`: {unexpected}. " + "Please remove these fields, or add them to `drop_field_names`.\n" + ) + + raise ValueError(error_msg) + + def _flatten_fn(obj: Any) -> tuple[list[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for name in field_names: + val = getattr(obj, name) + if val is not None: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + + def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + _private_register_pytree_node( + cls, + _flatten_fn, + _unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=_flatten_fn_with_keys, + ) + + +CONSTANT_NODES: set[type] = set() + + +def register_constant(cls: type[Any]) -> None: + """Registers a type as a pytree node with no leaves. + + In a :func:`torch.compile` region, if instances of these types get passed to + :func:`torch._dynamo.nonstrict_trace`-ed function, they treated as a + constant (sometimes referred to as "static"): + + 1. if the instance object existed before the :func:`torch.compile` region, + we _assume_ no mutation will happen to it inside the :func:`torch.compile` + region, require that it has non-default `__eq__` and `__hash__` methods, and + we guard on the instance based on its `__eq__` method, i.e., if a new + instance fails to match any instances from the previous compilations, + :func:`torch.compile` will recompile the function using the new instance. + + 2. else if the instance object is created inside the :func:`torch.compile` + region, we currently don't support using it in a + :func:`torch._dynamo.nonstrict_trace`-ed function. + + In general, if your class holds Tensors or dynamic int/float/bool (values that + may change from run-to-run of a function being compiled), then you probably + do not want to register it as a constant. + + Otherwise if you want to pass instance of a class to a + :func:`torch._dynamo.nonstrict_trace`-ed function, but you either can't use + :func:`register_pytree_node` on the class, or the class is "constant" enough + that you don't want to bother using :func:`register_pytree_node`, you should + consider using this function. + + Args: + cls: the type to register as a constant. This type must be hashable. 
+ + Example: + + >>> from dataclasses import dataclass + >>> import torch.utils._pytree as pytree + >>> + >>> @dataclass(frozen=True) + >>> class Config: + >>> norm: str + >>> + >>> pytree.register_constant(Config) + >>> + >>> config = Config("l2") + >>> values, spec = pytree.tree_flatten(config) + >>> assert len(values) == 0 + + """ + if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap] + raise TypeError( + "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation." + ) + + # Class with a custom `__eq__` without `__hash__` won't inherit the default + # `__hash__` from object; see https://stackoverflow.com/a/1608907. + if cls.__hash__ is None: # type: ignore[comparison-overlap] + raise TypeError( + "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation." + ) + + def _flatten(x): # type: ignore[no-untyped-def] + return [], ConstantNode(x) + + def _unflatten(_, context): # type: ignore[no-untyped-def] + return context.value + + def _flatten_with_keys(x): # type: ignore[no-untyped-def] + return [], ConstantNode(x) + + with _NODE_REGISTRY_LOCK: + _private_register_pytree_node( + cls, + _flatten, + _unflatten, + flatten_with_keys_fn=_flatten_with_keys, + ) + CONSTANT_NODES.add(cls) + + +def is_constant_class(cls: type[Any]) -> bool: + return isinstance(cls, type) and cls in CONSTANT_NODES + + +@dataclasses.dataclass(frozen=True) +class ConstantNode: + value: Any + + +def _is_constant_holder(spec: "TreeSpec") -> bool: + """Checks if the spec is from a pytree registered with register_constant""" + return isinstance(spec.context, ConstantNode) + + +def _retrieve_constant(spec: "TreeSpec") -> Any: + """Given a spec from a pytree registered with register_constant, retrieves the constant""" + assert _is_constant_holder(spec) + return tree_unflatten([], spec) + + +def _register_namedtuple( + cls: type[Any], + *, + serialized_type_name: str, +) -> None: + """ + Registers a namedtuple as a valid pytree node. By default namedtuples are + valid pytree nodes, but they are not serializable. This API provides the + argument `serialized_type_name` which allows these namedtuples to be + serialized. + + Args: + cls: the dataclass type to register + serialized_type_name: The serialized name for the dataclass. This is + required if you want to serialize the pytree TreeSpec containing this + namedtuple. + """ + _private_register_pytree_node( + cls, + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name=serialized_type_name, + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, + ) + + +@deprecated( + "`torch.utils._pytree._register_pytree_node` is deprecated. " + "Please use `torch.utils._pytree.register_pytree_node` instead.", + category=FutureWarning, +) +def _register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. 
+ + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + """ + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "`to_str_fn` and `maybe_from_str_fn` is deprecated. " + "Please use `to_dumpable_context` and `from_dumpable_context` instead.", + FutureWarning, + stacklevel=2, + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + +def _deregister_pytree_node( + cls: type[Any], +) -> None: + """This is an internal function that is used to deregister a pytree node type + for the Python pytree only. This should be only used inside PyTorch. + """ + with _NODE_REGISTRY_LOCK: + del SUPPORTED_NODES[cls] + node_def = SUPPORTED_SERIALIZED_TYPES[cls] + del SERIALIZED_TYPE_TO_PYTHON_TYPE[node_def.serialized_type_name] + del SUPPORTED_SERIALIZED_TYPES[cls] + CONSTANT_NODES.discard(cls) + + +def _private_register_pytree_node( + cls: type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """This is an internal function that is used to register a pytree node type + for the Python pytree only. End-users should use :func:`register_pytree_node` + instead. + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " + "Overwriting the previous registration.", + ) + + node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " + "be None or registered." 
+ ) + + if serialized_type_name is None: + serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +@dataclasses.dataclass(frozen=True) +class SequenceKey(Generic[T]): + idx: int + + def __str__(self) -> str: + return f"[{self.idx!r}]" + + def get(self, sequence: Sequence[T]) -> T: + return sequence[self.idx] + + +K = TypeVar("K", bound=Hashable) + + +@dataclasses.dataclass(frozen=True) +class MappingKey(Generic[K, T]): + key: K + + def __str__(self) -> str: + return f"[{self.key!r}]" + + def get(self, mapping: Mapping[K, T]) -> T: + return mapping[self.key] + + +@dataclasses.dataclass(frozen=True) +class GetAttrKey: + name: str + + def __str__(self) -> str: + return f".{self.name}" + + def get(self, obj: Any) -> Any: + return getattr(obj, self.name) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of namedtuple or a subclass of namedtuple.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_namedtuple_class(cls) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_class(cls: type) -> bool: + """Return whether the class is a subclass of namedtuple.""" + return ( + isinstance(cls, type) + and issubclass(cls, tuple) + and isinstance(getattr(cls, "_fields", None), tuple) + and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined] + and callable(getattr(cls, "_make", None)) + and callable(getattr(cls, "_asdict", None)) + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_instance(obj: object) -> bool: + """Return whether the object is an instance of namedtuple.""" + return is_namedtuple_class(type(obj)) + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +class structseq(tuple[_T_co, ...]): + """A generic type stub for CPython's ``PyStructSequence`` type.""" + + __slots__: ClassVar[tuple[()]] = () + + n_fields: Final[int] # type: ignore[misc] + n_sequence_fields: Final[int] # type: ignore[misc] + n_unnamed_fields: Final[int] # type: ignore[misc] + + def __init_subclass__(cls) -> NoReturn: + """Prohibit subclassing.""" + raise TypeError("type 'structseq' is not an acceptable base type") + + def __new__( + cls: type[Self], + sequence: Iterable[_T_co], + dict: dict[str, Any] = ..., + ) -> Self: + raise NotImplementedError + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_structseq_class(cls) + + +# Set if the type allows subclassing (see CPython's Include/object.h) +Py_TPFLAGS_BASETYPE: int = 1 << 10 + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_class(cls: type) -> bool: + """Return whether the class is a class of PyStructSequence.""" + return ( + isinstance(cls, type) + # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)` + and cls.__bases__ == (tuple,) + # Check PyStructSequence members + and 
isinstance(getattr(cls, "n_fields", None), int) + and isinstance(getattr(cls, "n_sequence_fields", None), int) + and isinstance(getattr(cls, "n_unnamed_fields", None), int) + # Check the type does not allow subclassing + and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_instance(obj: object) -> bool: + """Return whether the object is an instance of PyStructSequence.""" + return is_structseq_class(type(obj)) + + +def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: + return list(d), None + + +def _tuple_flatten_with_keys( + d: tuple[T, ...] +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _tuple_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _tuple_unflatten(values: Iterable[T], context: Context) -> tuple[T, ...]: + return tuple(values) + + +def _list_flatten(d: list[T]) -> tuple[list[T], Context]: + return d, None + + +def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _list_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _list_unflatten(values: Iterable[T], context: Context) -> list[T]: + return list(values) + + +def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_flatten_with_keys( + d: dict[Any, T] +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _dict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]: + return dict(zip(context, values)) + + +def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]: + return list(d), type(d) + + +def _namedtuple_flatten_with_keys( + d: NamedTuple, +) -> tuple[list[tuple[KeyEntry, Any]], Context]: + values, context = _namedtuple_flatten(d) + return ( + [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], + context, + ) + + +def _namedtuple_unflatten(values: Iterable[T], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + if context not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "didn't register a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context] + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "couldn't find a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + return serialized_type_name + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} " + "because we couldn't find a serializated name." 
+ ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context] + return typ + + +def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_flatten_with_keys( + d: OrderedDict[Any, T] +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _ordereddict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _ordereddict_unflatten( + values: Iterable[T], + context: Context, +) -> OrderedDict[Any, T]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_flatten_with_keys( + d: defaultdict[Any, T] +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _defaultdict_flatten(d) + _, dict_context = context + return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context + + +def _defaultdict_unflatten( + values: Iterable[T], + context: Context, +) -> defaultdict[Any, T]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(d: deque[T]) -> tuple[list[T], Context]: + return list(d), d.maxlen + + +def _deque_flatten_with_keys( + d: deque[T], +) -> tuple[list[tuple[KeyEntry, T]], Context]: + values, context = _deque_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", + flatten_with_keys_fn=_tuple_flatten_with_keys, +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + 
flatten_with_keys_fn=_namedtuple_flatten_with_keys, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", + flatten_with_keys_fn=_ordereddict_flatten_with_keys, +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, + flatten_with_keys_fn=_defaultdict_flatten_with_keys, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", + flatten_with_keys_fn=_deque_flatten_with_keys, +) + + +STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) +BUILTIN_TYPES: frozenset[type] = frozenset( + { + tuple, + list, + dict, + namedtuple, # type: ignore[arg-type] + OrderedDict, + defaultdict, + deque, + }, +) + + +@deprecated( + "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. " + "Please use torch.utils._pytree.is_namedtuple_instance instead.", + category=FutureWarning, +) +def _is_namedtuple_instance(tree: Any) -> bool: + return is_namedtuple_instance(tree) + + +def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + + +# A leaf is defined as anything that is not a Node. +def tree_is_leaf( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + """Check if a pytree is a leaf. + + >>> tree_is_leaf(1) + True + >>> tree_is_leaf(None) + True + >>> tree_is_leaf([1, 2, 3]) + False + >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) + True + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + False + """ + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + +@deprecated( + "torch.utils._pytree._is_leaf is private and will be removed in a future release. " + "Please use torch.utils._pytree.tree_is_leaf instead.", + category=FutureWarning, +) +def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: + return tree_is_leaf(tree, is_leaf=is_leaf) + + +# A TreeSpec represents the structure of a pytree. 
It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) +class TreeSpec: + type: Any + context: Context + children_specs: list["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) + num_leaves = sum(spec.num_leaves for spec in self.children_specs) + num_children = len(self.children_specs) + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + [ + "\n" + " " * indent + child.__repr__(indent) + for child in self.children_specs[1:] + ] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def __eq__(self, other: PyTree) -> bool: + if self is other: + return True + elif other.__class__ is self.__class__: + if str(self.type) != str(other.type): + return False + if self.context != other.context: + return False + elif self.children_specs != other.children_specs: + return False + return True + return NotImplemented + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def flatten_up_to(self, tree: PyTree) -> list[PyTree]: + def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: + if treespec.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if treespec.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != treespec.type: + raise ValueError( + f"Type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + children, context = flatten_fn(tree) + if len(children) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(children)}.", + ) + if context != treespec.context: + raise ValueError( + f"Node context mismatch for custom node type {treespec.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + treespec.type in STANDARD_DICT_TYPES + and node_type in STANDARD_DICT_TYPES + ) + if not both_standard_dict and node_type != treespec.type: + raise ValueError( + f"Node type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + if len(tree) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: + # dictionary types are compatible with each other + dict_context = ( + treespec.context + if treespec.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else treespec.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set 
= set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + children = [tree[key] for key in expected_keys] + else: + # node_type is treespec.type + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + children, context = flatten_fn(tree) + if ( + node_type is not deque # ignore mismatch of `maxlen` for deque + ) and context != treespec.context: + raise ValueError( + f"Node context mismatch for node type {treespec.type!r}; " + f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch + ) + + for subtree, subspec in zip(children, treespec.children_specs): + helper(subspec, subtree, subtrees) + + subtrees: list[PyTree] = [] + helper(self, tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + def __hash__(self) -> int: + node_type = self.type + if node_type is defaultdict: + default_factory, dict_context = self.context + hashable_context = (default_factory, tuple(dict_context)) + elif node_type in (dict, OrderedDict): + hashable_context = tuple(self.context) + elif node_type is None or node_type in BUILTIN_TYPES: + hashable_context = self.context + elif isinstance(self.context, ConstantNode): + hashable_context = self.context.value + else: + # The context for user-defined node types might not be hashable. + # Ignore it for hashing. + # This does not break the correctness that equal objects imply the + # same hash. This might increase the hash collision rate, but we + # don't care about that. + hashable_context = None + return hash((node_type, hashable_context, tuple(self.children_specs))) + + +# NOTE: subclassing a dataclass is subtle. In order to enable reasoning about +# this class with `dataclasses.fields`, etc., while having a simplified +# constructor that takes no argument, we wrap with `dataclass(init=True, ...)` +# again, with fields that have `init=False`. +@dataclasses.dataclass(init=True, frozen=True, eq=False, repr=False) +class LeafSpec(TreeSpec): + type: Any = dataclasses.field(default=None, init=False) + context: Context = dataclasses.field(default=None, init=False) + children_specs: list["TreeSpec"] = dataclasses.field( + default_factory=list, init=False + ) + + def __post_init__(self) -> None: + # Override `__post_init__` for `num_leaves` derivation. 
+ object.__setattr__(self, "num_nodes", 1) + object.__setattr__(self, "num_leaves", 1) + object.__setattr__(self, "num_children", 0) + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def tree_flatten( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> tuple[list[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + + def helper(node: PyTree, leaves: list[Any]) -> TreeSpec: + if tree_is_leaf(node, is_leaf=is_leaf): + leaves.append(node) + return _LEAF_SPEC + + node_type = _get_node_type(node) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + children, context = flatten_fn(node) + + # Recursively flatten the children + subspecs = [helper(child, leaves) for child in children] + return TreeSpec(node_type, context, subspecs) + + leaves: list[Any] = [] + treespec = helper(tree, leaves) + return leaves, treespec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + This is the inverse operation of `tree_flatten`. + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def tree_iter( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree.""" + if tree_is_leaf(tree, is_leaf=is_leaf): + yield tree + else: + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + yield from tree_iter(child, is_leaf=is_leaf) + + +def tree_leaves( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> list[Any]: + """Get a list of leaves of a pytree.""" + return list(tree_iter(tree, is_leaf=is_leaf)) + + +def tree_structure( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree, is_leaf=is_leaf)[1] + + +def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. 
+        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
+            flattening step. The function should have a single argument with signature
+            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be treated
+            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
+            node is a leaf or not. If the function is not specified, the default pytree registry will be used.
+
+    Returns:
+        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
+        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
+        is the tuple of values at corresponding nodes in ``rests``.
+    """
+    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
+    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
+    return treespec.unflatten(map(func, *flat_args))
+
+
+def tree_map_(
+    func: Callable[..., Any],
+    tree: PyTree,
+    *rests: PyTree,
+    is_leaf: Optional[Callable[[PyTree], bool]] = None,
+) -> PyTree:
+    """Like :func:`tree_map`, but do an in-place call on each leaf and return the original tree.
+
+    See also :func:`tree_map`.
+
+    Args:
+        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
+            corresponding leaves of the pytrees.
+        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
+            argument to function ``func``.
+        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
+            ``tree`` or has ``tree`` as a prefix.
+        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
+            flattening step. The function should have a single argument with signature
+            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree will be treated
+            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
+            node is a leaf or not. If the function is not specified, the default pytree registry will be used.
+
+    Returns:
+        The original ``tree``, with the value at each leaf given by the side effect of
+        ``func(x, *xs)`` (not its return value) where ``x`` is the value at the corresponding leaf
+        in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
+    """
+    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
+    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
+    deque(map(func, *flat_args), maxlen=0)  # consume and exhaust the iterable
+    return tree
+
+
+Type2 = tuple[type[T], type[S]]
+Type3 = tuple[type[T], type[S], type[U]]
+if sys.version_info >= (3, 10):
+    TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
+else:
+    TypeAny = Union[type[Any], tuple[type[Any], ...]]
+
+Fn2 = Callable[[Union[T, S]], R]
+Fn3 = Callable[[Union[T, S, U]], R]
+Fn = Callable[[T], R]
+FnAny = Callable[[Any], R]
+
+MapOnlyFn = Callable[[T], Callable[[Any], Any]]
+
+
+# These specializations help with type inference on the lambda passed to this
+# function
+@overload
+def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
+    ...
+
+
+@overload
+def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
+    ...
+
+
+@overload
+def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
+    ...
+
+
+# This specialization is needed for the implementations below that call
+@overload
+def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
+    ...
+ + +@overload +def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], / +) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + if isinstance(type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) + and isinstance(type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(type_or_types_or_pred): + pred = type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + @functools.wraps(func) + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + type_or_types_or_pred: type[T], + /, + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Type2[T, S], + /, + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Type3[T, S, U], + /, + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: TypeAny, + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + type_or_types_or_pred: Callable[[Any], bool], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + type_or_types_or_pred: type[T], + /, + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Type2[T, S], + /, + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Type3[T, S, U], + /, + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: TypeAny, + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + type_or_types_or_pred: Callable[[Any], bool], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... 
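+
+
+# Illustrative usage of the ``tree_map_only`` family (``weight`` below is an
+# assumed ``torch.Tensor``): only leaves matching the given type are passed to
+# ``func``; every other leaf is returned unchanged.
+#
+#     state = {"weight": weight, "step": 3}
+#     detached = tree_map_only(torch.Tensor, lambda t: t.detach(), state)
+#     # detached["weight"] is weight.detach(); detached["step"] is still 3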
+ + +def tree_map_only_( + type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + /, + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + type_or_types: type[T], + /, + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + type_or_types: Type2[T, S], + /, + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + type_or_types: Type3[T, S, U], + /, + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_all_only( + type_or_types: TypeAny, + /, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, type_or_types)) + + +@overload +def tree_any_only( + type_or_types: type[T], + /, + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + type_or_types: Type2[T, S], + /, + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + type_or_types: Type3[T, S, U], + /, + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_any_only( + type_or_types: TypeAny, + /, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. 
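# --- Illustrative sketch, not part of the upstream diff above ---
# How the broadcasting described in the comment above behaves on a small input. The
# import path ``torch.utils._pytree`` and the sample values are assumptions.
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten(([1, 2], 3))         # three leaves: 1, 2, 3
pytree._broadcast_to_and_flatten(0, spec)          # -> [0, 0, 0]  (a scalar broadcasts to every leaf)
pytree._broadcast_to_and_flatten((0, 1), spec)     # -> [0, 0, 1]  (a structural prefix broadcasts per subtree)
pytree._broadcast_to_and_flatten((0, 1, 2), spec)  # -> None       (structure mismatch)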
+def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[list[Any]]: + assert isinstance(treespec, TreeSpec) + + if tree_is_leaf(tree, is_leaf=is_leaf): + return [tree] * treespec.num_leaves + if treespec.is_leaf(): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: list[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: list["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if treespec.is_leaf(): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context, cls=EnumEncoder) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." 
+ ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]: + if "__enum__" in obj: + modname, _, classname = obj["fqn"].partition(":") + mod = importlib.import_module(modname) + enum_cls = mod + for attr in classname.split("."): + enum_cls = getattr(enum_cls, attr) + enum_cls = cast(type[Enum], enum_cls) + return enum_cls[obj["name"]] + return obj + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return _LEAF_SPEC + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"], object_hook=enum_object_hook) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_specs = [ + _json_to_treespec(child_string) for child_string in json_schema["children_spec"] + ] + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec)), cls=EnumEncoder) + return str_spec + + +@functools.lru_cache +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.", + category=FutureWarning, +) +def pytree_to_str(treespec: TreeSpec) -> str: + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +@deprecated( + "`str_to_pytree` is deprecated. 
Please use `treespec_loads` instead.", + category=FutureWarning, +) +def str_to_pytree(json: str) -> TreeSpec: + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: list[Any] = [] + for a in args: + leaves.extend(tree_iter(a)) + for a in kwargs.values(): + leaves.extend(tree_iter(a)) + return leaves + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + """ + _, treespec = tree_flatten(tree, is_leaf) + return list(_generate_key_paths((), tree, is_leaf)), treespec + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> list[tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. + """ + return list(_generate_key_paths((), tree, is_leaf)) + + +def _generate_key_paths( + key_path: KeyPath, + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + + node_type = _get_node_type(tree) + handler = SUPPORTED_NODES.get(node_type) + if not handler: + # This is a leaf + yield key_path, tree + return + + flatten_with_keys = handler.flatten_with_keys_fn + if flatten_with_keys: + key_children, _ = flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths((*key_path, k), c, is_leaf) + else: + # We registered this pytree but didn't add a flatten_with_keys_fn, complain. + raise ValueError( + f"Did not find a flatten_with_keys_fn for type: {node_type}. " + "Please pass a flatten_with_keys_fn argument to register_pytree_node." 
+ ) + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ + keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + return "".join([str(k) for k in kp]) + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + for k in kp: + obj = k.get(obj) + return obj diff --git a/phivenv/Lib/site-packages/torch/utils/_stats.py b/phivenv/Lib/site-packages/torch/utils/_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..efa017bf8bd10b7e1ef5d2e7a32e1932e48bc278 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_stats.py @@ -0,0 +1,30 @@ +# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE. +# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils +# AND SCRUB AWAY TORCH NOTIONS THERE. 
+import collections +import functools +from collections import OrderedDict +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + + +simple_call_counter: OrderedDict[str, int] = collections.OrderedDict() + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def count_label(label: str) -> None: + prev = simple_call_counter.setdefault(label, 0) + simple_call_counter[label] = prev + 1 + + +def count(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @functools.wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if fn.__qualname__ not in simple_call_counter: + simple_call_counter[fn.__qualname__] = 0 + simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1 + return fn(*args, **kwargs) + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/utils/_strobelight/__init__.py b/phivenv/Lib/site-packages/torch/utils/_strobelight/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e612ad9303720f2dbc23e07e991b29714f721de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ded53f36ebc711bf45e6db834c3da54b0be03c3a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_strobelight/cli_function_profiler.py b/phivenv/Lib/site-packages/torch/utils/_strobelight/cli_function_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..033ad81b8c1adff7221da4ad1a2c52e13a872af5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_strobelight/cli_function_profiler.py @@ -0,0 +1,312 @@ +# mypy: disallow-untyped-defs + +import functools +import logging +import os +import re +import subprocess +import time +from collections.abc import Sequence +from threading import Lock +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + + +logger = logging.getLogger("strobelight_function_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class StrobelightCLIProfilerError(Exception): + """ + Raised when an error happens during strobelight profiling + """ + + +def _pid_namespace_link(pid: Optional[int] = None) -> str: + """Returns the link to the process's namespace, example: pid:[4026531836]""" + PID_NAMESPACE_PATH = "/proc/{}/ns/pid" + pid = pid or os.getpid() + return os.readlink(PID_NAMESPACE_PATH.format(pid)) + + +def _pid_namespace(pid: Optional[int] = None) -> int: + """Returns the process's namespace id""" + pid = pid or os.getpid() + link = _pid_namespace_link(pid) + return int(link[link.find("[") + 
1 : -1]) + + +def _command_to_string(command: Sequence[str]) -> str: + return " ".join(command) + + +class StrobelightCLIFunctionProfiler: + """ + Note: this is a meta only tool. + + StrobelightCLIFunctionProfiler can be used to profile a python function and + generate a strobelight link with the results. It works on meta servers but + does not requries an fbcode target. + When stop_at_error is false(default), error during profiling does not prevent + the work function from running. + + Check function_profiler_example.py for an example. + """ + + # This lock is used to make sure only one thread is running the profiler at any point. + _lock = Lock() + + def __init__( + self, + *, + stop_at_error: bool = False, + max_profile_duration_sec: int = 60 * 10, + sample_each: float = 1e7, # sample each sample_each cycles. + run_user_name: str = "pytorch-strobelight-ondemand", + timeout_wait_for_running_sec: int = 60, + timeout_wait_for_finished_sec: int = 60, + recorded_env_variables: Optional[list[str]] = None, + sample_tags: Optional[list[str]] = None, + stack_max_len: int = 127, + async_stack_max_len: int = 127, + ): + self.stop_at_error = stop_at_error + self.max_profile_duration_sec = max_profile_duration_sec + self.sample_each = sample_each + self.run_user_name = run_user_name + self.timeout_wait_for_running_sec = timeout_wait_for_running_sec + self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec + # Results of the most recent run. + # Tracks the strobelight run id of the most recent run + self.current_run_id: Optional[int] = None + self.sample_tags = sample_tags + + def _run_async(self) -> None: + processId = os.getpid() + namespace = _pid_namespace(processId) + command = [ + "strobeclient", + "run", + "--profiler", + "pyperf", + "--event", + "cycles", + "--async", + "--sample-interval", + f"{int(self.sample_each)}", + "--duration-ms", + f"{int(self.max_profile_duration_sec * 1000)}", + "--pid", + f"{namespace}:{processId}", + ] + + if self.sample_tags: + command.append("--sample-tags") + command.append(",".join(self.sample_tags)) + + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in run_async:{output}" + ) + + if match := re.search(r"INFO Run Id: (-?\d+)", output): + self.current_run_id = int(match.group(1)) + return + + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, unexpected result {output}" + ) + + def _wait_for_running(self, counter: int = 0) -> None: + if counter > 20: + raise StrobelightCLIProfilerError( + "wait_for_running called more than 20 times" + ) + + command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in wait_for_running:{output}" + ) + + if match := re.search("Profile run status: (.*)", output): + current_status = match.group(1) + if current_status == "RUNNING": + return + elif current_status == "PREPARING": + time.sleep(10) + self._wait_for_running(counter + 1) + return + else: + raise 
StrobelightCLIProfilerError(f"unexpected {current_status} phase") + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _stop_run(self) -> None: + command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, return code is not 0 :{output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Success!"): + return + else: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, got {current_status} result" + ) + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _get_results(self) -> None: + command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, return code is not 0 : {output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Profile run status: PROCESSING"): + time.sleep(10) + self._get_results() + return + elif not current_status.__contains__("Profile run finished with SUCCESS"): + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, unexpected response {output}" + ) + + for item in re.findall( + r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))", + output, + ): + logger.info(item[0]) + + def _stop_strobelight_no_throw( + self, + collect_results: bool, + ) -> None: + try: + # call stop run + self._stop_run() + logger.info("strobelight profiling stopped") + + logger.debug("collection stopped") + + if not collect_results: + return + + self._get_results() + except Exception: + logger.warning("error during stop_strobelight", exc_info=True) + + # Return true if strobelight started and is running. Never throw. 
+ def _start_strobelight(self) -> bool: + strobelight_started = False + try: + self._run_async() + strobelight_started = True + logger.info("strobelight run id is: %s", self.current_run_id) + self._wait_for_running() + logger.info("strobelight profiling running") + return True + + except Exception: + logger.warning("error during start_strobelight:", exc_info=True) + if strobelight_started: + self._stop_strobelight_no_throw(collect_results=False) + return False + + def profile( + self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs + ) -> Optional[_R]: + self.current_run_id = None + + if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): + if not locked: + if self.stop_at_error: + raise StrobelightCLIProfilerError("concurrent runs not supported") + + logger.warning("concurrent runs not supported") + return work_function(*args, **kwargs) + + started = self._start_strobelight() + if not started: + if self.stop_at_error: + StrobelightCLIFunctionProfiler._lock.release() + raise StrobelightCLIProfilerError( + "failed to start strobelight profiling" + ) + result = work_function(*args, **kwargs) + StrobelightCLIFunctionProfiler._lock.release() + return result + + try: + logger.debug("collection started") + result = work_function(*args, **kwargs) + self._stop_strobelight_no_throw(collect_results=True) + StrobelightCLIFunctionProfiler._lock.release() + return result + except Exception as error: + logger.warning("work function throw exception", exc_info=True) + self._stop_strobelight_no_throw(collect_results=False) + StrobelightCLIFunctionProfiler._lock.release() + raise error + return None + + +# A function decorator that wraps profile, if no profiler is provided one with +# default args is created. A function can be annotated as: +# @strobelight() +# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) +# @strobelight(stop_at_error=True,...) 
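# --- Illustrative usage sketch, not part of the upstream diff above ---
# Strobelight is a Meta-internal profiler; outside that environment the profiler fails to
# start and, with the default stop_at_error=False, the decorated function still runs
# unprofiled. The work function and sample tag below are made up for this example.
@strobelight(sample_tags=["example"], stop_at_error=False)
def heavy_work(n: int) -> int:
    return sum(i * i for i in range(n))

# heavy_work(100_000) executes under the profiler when strobeclient is available.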
+def strobelight( + profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]: + if not profiler: + profiler = StrobelightCLIFunctionProfiler(**kwargs) + + def strobelight_inner( + work_function: Callable[_P, _R] + ) -> Callable[_P, Optional[_R]]: + @functools.wraps(work_function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + return profiler.profile(work_function, *args, **kwargs) + + return wrapper_function + + return strobelight_inner diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__init__.py b/phivenv/Lib/site-packages/torch/utils/_sympy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..226bc9ba022b759b6bcf93cce2b80c39ef2ed3e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a8d4b6b258f4b29e6185d778ad6f450667e1bac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..957dae61283937c0d5db401fbc5f644111b38c98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/interp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb7fadacdaddd69726a2fb69affdbf46879517a6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/numbers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb065c4308356cba0fc9352fbd47dea845d02184 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/printers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..479685616efab6bf65925cefe9d07c52c1c7d650 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/reference.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa267b0cdc499e000ebd9c5a87ef64ccd366a2a4 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/singleton_int.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059aab6fe504fa9ece93d16c010daf788bfd37e1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/solve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..959483c80b72befb1a5e0b17a511952fe6faf230 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/symbol.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce404c725eb68545bbfe62c9cd0b3ceb3a53dd0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/_sympy/__pycache__/value_ranges.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/functions.py b/phivenv/Lib/site-packages/torch/utils/_sympy/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..08c1e98ff256acbc3c4b4ca09cef6fdd44b3618c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/functions.py @@ -0,0 +1,1407 @@ +# mypy: allow-untyped-defs +import functools +import math +import operator +import sys +from typing import Callable, Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union +from typing_extensions import TypeVarTuple, Unpack + +import sympy +from sympy import S +from sympy.core import sympify +from sympy.core.expr import Expr +from sympy.core.function import Application +from sympy.core.logic import _torf, fuzzy_and, fuzzy_or +from sympy.core.numbers import equal_valued +from sympy.core.operations import LatticeOp, ShortCircuit +from sympy.core.sorting import ordered +from sympy.core.traversal import walk +from sympy.printing.precedence import PRECEDENCE +from sympy.utilities.iterables import sift + +from .numbers import int_oo + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +_T = TypeVar("_T", bound=SupportsFloat) +_Ts = TypeVarTuple("_Ts") + +# Portions of this file are adapted from the Sympy codebase, which was +# licensed as follows: +# +# Copyright (c) 2006-2023 SymPy Development Team +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of SymPy nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + +__all__ = [ + "FloorDiv", + "ModularIndexing", + "Where", + "PythonMod", + "Mod", + "CleanDiv", + "CeilToInt", + "FloorToInt", + "CeilDiv", + "IntTrueDiv", + "FloatTrueDiv", + "LShift", + "RShift", + "IsNonOverlappingAndDenseIndicator", + "TruncToFloat", + "TruncToInt", + "RoundToInt", + "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", + "Identity", +] + + +def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: + # No need to check that two args are not the same, since expr is pr-optimized but we do it anyway. + return ( + expr.is_Add + and len(expr._args) == 2 + and expr._args[0].is_symbol + and expr._args[1].is_symbol + and expr._args[0] is not expr._args[1] + ) + + +def _keep_float( + f: Callable[[Unpack[_Ts]], _T] +) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: + @functools.wraps(f) + def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: + r: Union[_T, sympy.Float] = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + +def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]: + if None in (x, y): + return None + return x == y + + +def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic: + """ + Fast path for sympy.gcd, using a simple factoring strategy. + + We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, + where n is the greatest common integer factor and e is the largest + syntactic common factor (i.e., common sub-expression) in p and q. + Then the gcd returned is n*e, cancelling which we would be left with + p1 + p2 and q0. + + Note that further factoring of p1 + p2 and q0 might be possible with + sympy.factor (which uses domain-specific theories). E.g., we are unable + to find that x*y + x + y + 1 is divisible by x + 1. More generally, + when q is of the form q1 + q2 (instead of being already factored) it + might be necessary to fall back on sympy.gcd. + """ + + def integer_coefficient(x: sympy.Basic) -> int: + integer_coefficients: list[int] = [ + abs(int(arg)) + for arg in sympy.Mul.make_args(x) + if isinstance(arg, (int, sympy.Integer)) + ] + return math.prod(integer_coefficients) + + def integer_factor(expr: sympy.Basic) -> int: + integer_factors: Iterable[int] = map( + integer_coefficient, sympy.Add.make_args(expr) + ) + return functools.reduce(math.gcd, integer_factors) + + gcd: int = math.gcd(integer_factor(p), integer_factor(q)) + p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12 + + base_splits: list[tuple[sympy.Basic, ...]] = list( + map(sympy.Mul.make_args, sympy.Add.make_args(p)) + ) + divisor_split: tuple[sympy.Basic, ...] 
= sympy.Mul.make_args(q) + for x in divisor_split: + if all(x in base_split for base_split in base_splits): + gcd = gcd * x # type: ignore[operator] # remove in py3.12 + return gcd # type: ignore[return-value] # remove in py3.12 + + +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. +class FloorDiv(sympy.Function): + """ + We maintain this so that: + 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. + 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf + """ + + nargs: tuple[int, ...] = (2,) + precedence: int = 35 # lower precedence than add + is_integer: bool = True + + @property + def base(self) -> sympy.Basic: + return self.args[0] + + @property + def divisor(self) -> sympy.Basic: + return self.args[1] + + def _sympystr(self, printer: sympy.printing.StrPrinter) -> str: + base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5) + divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5) + return f"({base}//{divisor})" + + # Automatic evaluation. + # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval + @classmethod + def eval( + cls, base: sympy.Integer, divisor: sympy.Integer + ) -> Union[sympy.Basic, None]: + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor + + # We don't provide the same error message as in Python because SymPy + # makes it difficult to check the types. 
+ if divisor.is_zero: + raise ZeroDivisionError("division by zero") + if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in ( + int_oo, + -int_oo, + sympy.oo, + -sympy.oo, + ): + return sympy.nan + if base is sympy.nan or divisor is sympy.nan: + return sympy.nan + + if base.is_zero: + return sympy.S.Zero + if base.is_integer and equal_valued(divisor, 1): + return base + if base.is_integer and equal_valued(divisor, -1): + return sympy.Mul(base, -1) + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + r = float(base) / float(divisor) + if r == math.inf: + return int_oo + elif r == -math.inf: + return -int_oo + elif math.isnan(r): + return sympy.nan + else: + return sympy.Integer(math.floor(r)) + if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): + return sympy.Integer(int(base) // int(divisor)) + if isinstance(base, FloorDiv): + return FloorDiv(base.args[0], base.args[1] * divisor) + + # Expands (x + y) // b into x // b + y // b. + # This only works if floor is an identity, i.e. x / b is an integer. + if isinstance(divisor, sympy.Integer): + quotients = 0 + terms = [] + for term in sympy.Add.make_args(base): + quotient = term / divisor + + if quotient.is_integer: + terms.append(term) + quotients += quotient + + if len(terms) != 0: + # Passing evaluate = False since expression will be optimized during the subtraction post its construction. + return ( + FloorDiv(base - sympy.Add(*terms, evaluate=False), divisor) + + quotients + ) + + try: + gcd = simple_floordiv_gcd(base, divisor) + if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add): + gcd = sympy.gcd(base, divisor) + if not equal_valued(gcd, 1): + return FloorDiv( + sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) + ) + except sympy.PolynomialError: + pass # https://github.com/pytorch/pytorch/issues/108276 + + return None + + def _ccode(self, printer): + base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5) + divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5) + return f"floor({base}/{divisor})" + + +class ModularIndexing(sympy.Function): + """ + ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus + """ + + nargs: tuple[int, ...] 
= (3,) + is_integer: bool = True + precedence: int = 35 # lower precedence than add + + @classmethod + def eval( + cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer + ) -> Optional[sympy.Basic]: + if base == 0 or modulus == 1: + return sympy.S.Zero + + if ( + isinstance(base, sympy.Integer) + and isinstance(divisor, sympy.Integer) + and isinstance(modulus, sympy.Integer) + ): + return (base // divisor) % modulus + + try: + if divisor != 1: + gcd = sympy.gcd(base, divisor) + if gcd != 1: + return ModularIndexing( + sympy.simplify(base / gcd), + sympy.simplify(divisor / gcd), + modulus, + ) + except sympy.PolynomialError: + pass # https://github.com/pytorch/pytorch/issues/108276 + + if isinstance(base, sympy.Add): + new_terms: list[sympy.Integer] = [] + all_positive: bool = True + for term in base.args: + if sympy.gcd(term, modulus * divisor) != modulus * divisor: + if (isinstance(term, sympy.Integer) and term < 0) or ( + isinstance(term, sympy.Mul) + and isinstance(term.args[0], sympy.Integer) + and term.args[0] < 0 + ): + # workaround for https://github.com/triton-lang/triton/issues/619, + # if there are negative terms, // produces wrong result + # TODO if https://github.com/triton-lang/triton/issues/619 is fixed + # this optimization would become valid + all_positive = False + break + else: + new_terms.append(term) + + if len(new_terms) != len(base.args) and all_positive: + return ModularIndexing(sum(new_terms), divisor, modulus) + + if isinstance(base, FloorDiv): + return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) + + return None + + def _eval_is_nonnegative(self) -> Optional[bool]: + p, q = self.args[:2] + return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] + + +class Where(sympy.Function): + """ + Good ol' ternary operator + """ + + nargs: tuple[int, ...] = (3,) + precedence: int = 35 # lower precedence than add + + def _eval_is_integer(self) -> Optional[bool]: + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self) -> Optional[bool]: + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self) -> Optional[bool]: + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + + @classmethod + def eval( + cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic + ) -> Optional[sympy.Basic]: + if c == sympy.true: + return p + elif c == sympy.false: + return q + return None + + +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): + nargs: tuple[int, ...] = (2,) + + precedence: int = 35 # lower precedence than add + is_integer: bool = True + + @classmethod + def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + return p % q + + # If q == 2, it's a matter of whether p is odd or even. 
+ if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + if sympy.Mod(p, q) == 0: + return S.Zero + + return None + + # NB: args[1] for PythonMod + def _eval_is_nonnegative(self) -> Optional[bool]: + return True if self.args[1].is_positive else None # type: ignore[attr-defined] + + def _eval_is_nonpositive(self) -> Optional[bool]: + return True if self.args[1].is_negative else None # type: ignore[attr-defined] + + +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + precedence: int = 35 # lower precedence than add + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + +class CleanDiv(FloorDiv): + """ + Div where we can assume no rounding. + This is to enable future optimizations. + """ + + +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + def _ccode(self, printer): + number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5) + return f"ceil({number})" + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, int_oo): + return -int_oo + if isinstance(number, sympy.Integer): + return number + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + +class CeilDiv(sympy.Function): + """ + Div used in indexing that rounds up. 
+ """ + + is_integer = True + + def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) + if sympy.gcd(base, divisor) == divisor: + return CleanDiv(base, divisor) + else: + return FloorDiv(base + (divisor - 1), divisor) + + +class LShift(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, base, shift): + if shift < 0: + raise ValueError("negative shift count") + return base * 2**shift + + +class RShift(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, base, shift): + if shift < 0: + raise ValueError("negative shift count") + return FloorDiv(base, 2**shift) + + +class MinMaxBase(Expr, LatticeOp): # type: ignore[misc] + def __new__(cls, *original_args, **assumptions): + from sympy.core.parameters import global_parameters + + evaluate = assumptions.pop("evaluate", global_parameters.evaluate) + args = (sympify(arg) for arg in original_args) + + # See the comment in _satisfy_unique_summations_symbols. + unique_summations_symbols = ( + None + if not evaluate + else cls._satisfy_unique_summations_symbols(original_args) + ) + + if evaluate: + try: + # first standard filter, for cls.zero and cls.identity + # also reshape Max(a, Max(b, c)) to Max(a, b, c) + args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment] + except ShortCircuit: + return cls.zero # type: ignore[attr-defined] + + # No need to run _collapse_arguments and _find_localzeros, see the comment + # in _satisfy_unique_summations_symbols. + if unique_summations_symbols is None: + # remove redundant args that are easily identified + args = cls._collapse_arguments(args, **assumptions) + + # find local zeros + args = cls._find_localzeros(args, **assumptions) + + args = frozenset(args) + + if not args: + return cls.identity # type: ignore[attr-defined] + + if len(args) == 1: + return list(args).pop() + + # base creation + obj = Expr.__new__(cls, *ordered(args), **assumptions) + obj._argset = args + + obj.unique_summations_symbols = unique_summations_symbols + return obj + + @classmethod + def _satisfy_unique_summations_symbols( + cls, args + ) -> Optional[set[sympy.core.symbol.Symbol]]: + """ + One common case in some models is building expressions of the form + max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...). + For such expressions, we call the Max constructor X times (once for each nested + max) and the expression gets flattened. + + An expensive cost in constructing those expressions is running _collapse_arguments + and _find_localzeros. However, those two optimizations are unnecessary when the args + to max are all of the form a+b, c+d, ..etc where each term uses a unique set of symbols. + + This function is used to detect such properties of the expressions we are building + and if so inform that we do not need to run those optimizations. To detect those, + we store a property in the expression that tells that this expression is a min/max + operation over terms that use unique symbols "unique_summations_symbols". This property + also memoize the set of symbols used in all the terms to make it faster to detect this + property inductively. + + When we apply max to add a new term, all we need to do is check if the new term uses + unique symbols (with respect to existing terms and itself). + Example: + t = Max(a+b, c+d) ==> satisfies the property + Max(t, h+j) ==> h,j not in [a,b,c,d] => satisfy the property. + + The function returns None if the new expression does not satisfy the unique_summations_symbols + property. 
Otherwise, it returns a new set of unique symbols. + """ + if len(args) != 2: + return None + + (lhs, rhs) = ( + (args[1], args[0]) + if isinstance(args[1], MinMaxBase) + else (args[0], args[1]) + ) + + if not _is_symbols_binary_summation(rhs): + return None + + # base case max(a+b, c+d) ==> satisfies the property if a+b and c+d use unique symbols. + if _is_symbols_binary_summation(lhs): + return cls._unique_symbols(args) + + # inductive case max(t, h+j) ==> satisfies the property if h, j not in t.unique_summations_symbols + if isinstance(lhs, MinMaxBase): + lhs_unique_summations_symbols = getattr( + lhs, "unique_summations_symbols", None + ) + if lhs_unique_summations_symbols is not None: + return cls._unique_symbols([rhs], lhs_unique_summations_symbols) + + return None + + @classmethod + def _unique_symbols( + cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None + ) -> Optional[set[sympy.core.symbol.Symbol]]: + """ + Return seen_symbols if all atoms in all args are all unique symbols, + else returns None. initial_set can be used to represent initial value for seen_symbols + """ + seen_symbols = set() if initial_set is None else initial_set + for arg in args: + for element in arg.atoms(): + if not isinstance(element, sympy.core.symbol.Symbol): + return None + elif element in seen_symbols: + return None + else: + seen_symbols.add(element) + return seen_symbols + + @classmethod + def _collapse_arguments(cls, args, **assumptions): + """Remove redundant args. + + Examples + ======== + + >>> from sympy import Min, Max + >>> from sympy.abc import a, b, c, d, e + + Any arg in parent that appears in any + parent-like function in any of the flat args + of parent can be removed from that sub-arg: + + >>> Min(a, Max(b, Min(a, c, d))) + Min(a, Max(b, Min(c, d))) + + If the arg of parent appears in an opposite-than parent + function in any of the flat args of parent that function + can be replaced with the arg: + + >>> Min(a, Max(b, Min(c, d, Max(a, e)))) + Min(a, Max(b, Min(a, c, d))) + """ + if not args: + return args + args = list(ordered(args)) + if cls is Min: + other = Max + else: + other = Min # type: ignore[assignment] + + # find global comparable max of Max and min of Min if a new + # value is being introduced in these args at position 0 of + # the ordered args + if args[0].is_number: + sifted = mins, maxs = [], [] # type: ignore[var-annotated] + for i in args: + for v in walk(i, Min, Max): + if v.args[0].is_comparable: + sifted[isinstance(v, Max)].append(v) + small = Min.identity + for i in mins: + v = i.args[0] + if v.is_number and (v < small) == True: # noqa: E712 + small = v + big = Max.identity + for i in maxs: + v = i.args[0] + if v.is_number and (v > big) == True: # noqa: E712 + big = v + # at the point when this function is called from __new__, + # there may be more than one numeric arg present since + # local zeros have not been handled yet, so look through + # more than the first arg + if cls is Min: + for arg in args: + if not arg.is_number: + break + if (arg < small) == True: # noqa: E712 + small = arg + elif cls == Max: + for arg in args: + if not arg.is_number: + break + if (arg > big) == True: # noqa: E712 + big = arg + T = None + if cls is Min: + if small != Min.identity: + other = Max + T = small + elif big != Max.identity: + other = Min # type: ignore[assignment] + T = big + if T is not None: + # remove numerical redundancy + for i in range(len(args)): + a = args[i] + if isinstance(a, other): + a0 = a.args[0] + if ( # noqa: E712 + (a0 > T) if other == Max 
else (a0 < T) # noqa: E712 + ) == True: # noqa: E712 + args[i] = cls.identity # type: ignore[attr-defined] + + # remove redundant symbolic args + def do(ai, a): + if not isinstance(ai, (Min, Max)): + return ai + cond = a in ai.args + if not cond: + return ai.func(*[do(i, a) for i in ai.args], evaluate=False) + if isinstance(ai, cls): + return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False) + return a + + for i, a in enumerate(args): + args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]] + + # factor out common elements as for + # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z)) + # and vice versa when swapping Min/Max -- do this only for the + # easy case where all functions contain something in common; + # trying to find some optimal subset of args to modify takes + # too long + + def factor_minmax(args): + is_other = lambda arg: isinstance(arg, other) # noqa: E731 + other_args, remaining_args = sift(args, is_other, binary=True) + if not other_args: + return args + + # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v}) + arg_sets = [set(arg.args) for arg in other_args] + common = set.intersection(*arg_sets) + if not common: + return args + + new_other_args = list(common) + arg_sets_diff = [arg_set - common for arg_set in arg_sets] + + # If any set is empty after removing common then all can be + # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b) + if all(arg_sets_diff): + other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff] + new_other_args.append(cls(*other_args_diff, evaluate=False)) + + other_args_factored = other(*new_other_args, evaluate=False) + return remaining_args + [other_args_factored] + + if len(args) > 1: + args = factor_minmax(args) + + return args + + @classmethod + def _new_args_filter(cls, arg_sequence): + """ + Generator filtering args. + + first standard filter, for cls.zero and cls.identity. + Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``, + and check arguments for comparability + """ + for arg in arg_sequence: + # pre-filter, checking comparability of arguments + if ( + not isinstance(arg, Expr) + or arg.is_extended_real is False + or (arg.is_number and not arg.is_comparable) + ): + raise ValueError(f"The argument '{arg}' is not comparable.") + + if arg == cls.zero: # type: ignore[attr-defined] + raise ShortCircuit(arg) + elif arg == cls.identity: # type: ignore[attr-defined] + continue + elif arg.func == cls: + yield from arg.args + else: + yield arg + + @classmethod + def _find_localzeros(cls, values, **options): + """ + Sequentially allocate values to localzeros. + + When a value is identified as being more extreme than another member it + replaces that member; if this is never true, then the value is simply + appended to the localzeros. 
+ + Unlike the sympy implementation, we only look for zero and one, we don't + do generic is connected test pairwise which is slow + """ + + # First, collapse all numeric arguments + other_values = set() + num_value = None + for arg in values: + if arg.is_Number: + if num_value is None: + num_value = arg + else: + if cls is Max: + num_value = max(num_value, arg) + elif cls is Min: + num_value = min(num_value, arg) + else: + raise AssertionError(f"impossible {cls}") + else: + other_values.add(arg) + + # Special cases when there is only one symbolic value + if num_value is None: + return other_values + + if len(other_values) == 0: + return {num_value} + + if len(other_values) == 1: + other_value = next(iter(other_values)) + if num_value in (0.0, 0) and other_value.is_nonnegative: + return other_values if cls is Max else {num_value} + if num_value == 1 and other_value.is_positive: + return other_values if cls is Max else {num_value} + + other_values.add(num_value) + return other_values + + _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 + _eval_is_antihermitian = lambda s: _torf( # noqa: E731 + i.is_antihermitian for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_commutative = lambda s: _torf( # noqa: E731 + i.is_commutative for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 + _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 + _eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731 + _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) # noqa: E731 + _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) # noqa: E731 + _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) # noqa: E731 + _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) # noqa: E731 + _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) # noqa: E731 + _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) # noqa: E731 + _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 + _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 + _eval_is_nonnegative = lambda s: _torf( # noqa: E731 + i.is_nonnegative for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_nonpositive = lambda s: _torf( # noqa: E731 + i.is_nonpositive for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 + _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 + _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731 + _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) # noqa: E731 + _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) # noqa: E731 + _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 + _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 + _eval_is_extended_real = lambda s: _torf( # noqa: E731 + i.is_extended_real for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_transcendental = lambda s: _torf( # noqa: E731 + i.is_transcendental for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 + + +class Max(MinMaxBase, Application): # type: ignore[misc] + r""" + Return, if possible, the maximum value of the list. 
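+ + Example (an illustrative sketch): ``Max(1, 5)`` collapses to ``5`` via the + numeric handling in ``MinMaxBase._find_localzeros``, while ``Max(0, x)`` stays + symbolic unless ``x`` is known to be nonnegative.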
+ """ + + zero = S.Infinity + identity = S.NegativeInfinity + + def _eval_is_positive(self): # type:ignore[override] + return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): # type:ignore[override] + return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] + + def _eval_is_negative(self): # type:ignore[override] + return fuzzy_and(a.is_negative for a in self.args) + + +class Min(MinMaxBase, Application): # type: ignore[misc] + """ + Return, if possible, the minimum value of the list. + """ + + zero = S.NegativeInfinity + identity = S.Infinity + + def _eval_is_positive(self): # type:ignore[override] + return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): # type:ignore[override] + return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] + + def _eval_is_negative(self): # type:ignore[override] + return fuzzy_or(a.is_negative for a in self.args) + + +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +# Prevent people from overflowing pow +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp is int_oo: + return int_oo + + # TODO: microoptimization is to avoid overflowing into arbitrary precision + # and detect overflow prior to doing operations + + result = half_exp * half_exp + if result > sys.maxsize: + return int_oo + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize: + return int_oo + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + + precedence: int = 50 # precedence of mul + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): + r = safe_pow(base, exp) + if r in (-int_oo, int_oo): + return r + return sympy.Integer(r) + if isinstance(exp, sympy.Integer): + # Rely on regular sympy Pow for this (note that iterated + # multiplication turns into a Pow anyway, you can't escape!!) + return sympy.Pow(base, exp) + if exp in (int_oo, sympy.oo): + if base.is_nonnegative: + return int_oo + elif base.is_negative: + return sympy.zoo # this is apparently what (-2)**sympy.oo does + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occuring +class FloatPow(sympy.Function): + is_real = True + + precedence: int = 60 # precedence of pow + + @classmethod + def eval(cls, base, exp): + # NB: These test sympy.Number, not sympy.Float, because: + # - Sometimes we may have sympy.oo or int_oo, and that's not a Float + # (but coerces to math.Inf) + # - Sometimes Float(0.0) will unpredictably decay to Integer(0), + # but we should still accept it in floatey contexts + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning + + +# Overloaded to be compatible with regular Python. +# https://github.com/pytorch/pytorch/issues/90900 +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. 
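+# +# Illustrative sketch of that contract: FloatTrueDiv(sympy.Float(1.0), sympy.Float(4.0)) +# constant-folds to Float(0.25) in ``eval`` below, while fully symbolic arguments +# are left unevaluated.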
+class FloatTrueDiv(sympy.Function): + is_real = True + + precedence: int = 35 # lower precedence than add + + @classmethod + def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_real = True + + precedence: int = 35 # lower precedence than add + + @classmethod + def eval(cls, base, divisor): + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if ( + isinstance(base, sympy.Number) + and isinstance(divisor, sympy.Number) + and ( + base in (int_oo, -int_oo, sympy.oo, -sympy.oo) + or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo) + ) + ): + # Don't have to worry about precision here, you're getting zero or + # inf from the division + return sympy.Float(float(base) / float(divisor)) + if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): + return sympy.Float(int(base) / int(divisor)) + + def _ccode(self, printer): + base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) + divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) + return f"((int){base}/(int){divisor})" + + +# TODO: As an indicator, this != 0 implies == 1 (and vice versa). +# Because we do not have the ability to guard on the stride permutation +# at the moment, it is hard to make further inferences when this is true, +# as although we know the tensor is contiguous in *some* layout, we don't +# know which one (however, you could, for example, make the inference that +# reshaping this to a 1D tensor can be guard-free.) +class IsNonOverlappingAndDenseIndicator(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, *args): + assert len(args) % 2 == 0 + dim = len(args) // 2 + sizes = args[0:dim] + strides = args[dim:] + + # sym_node imported in torch.__init__. Local import to avoid an import cycle + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + if all(isinstance(a, sympy.Integer) for a in args): + return eval_is_non_overlapping_and_dense( + [int(a) for a in sizes], [int(a) for a in strides] + ) + + if dim == 1: + # Manually implement the rank one short circuit + if strides[0].is_Number and strides[0] == 1: + return 1 + + if sizes[0].is_Number and sizes[0] < 2: + return 1 + + # return 0 case covered by case above + + # TODO: Inability to access size-obliviousness sucks: if we have a + # size oblivious test on a size-like unbacked SymInt, we could + # confidently return zero when we have a size-like u0 stride + # and a size-like u1 size. Maybe a fancy ValueRanges analysis for + # this function could help figure this out. 
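+ # Illustrative example: a contiguous layout such as sizes=(2, 3), strides=(3, 1) + # makes every argument an Integer, so the branch above delegates to + # eval_is_non_overlapping_and_dense and the indicator evaluates to 1; the + # remaining branches below cover integral strides and otherwise leave the + # expression unevaluated.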
+ + if all(isinstance(a, sympy.Integer) for a in strides): + assert dim != 0 + # When all strides are integral, we can sort, and the size for the + # largest stride doesn't matter and can be arbitrarily symbolic + s_sizes, s_strides = zip( + *sorted(zip(sizes, strides), key=operator.itemgetter(1)) + ) + # Put something arbitrary in the max size spot, it'll be ignored + if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): + s_sizes = s_sizes[:-1] + (42,) + # We can reuse the regular eval, because it is invariant to + # permutation of dimensions + return eval_is_non_overlapping_and_dense( + [int(a) for a in s_sizes], [int(a) for a in s_strides] + ) + + return None + + +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number in (sympy.oo, int_oo): + return int_oo + if number in (-sympy.oo, -int_oo): + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(math.trunc(float(number))) + + +# This is float -> int +class RoundToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + + if number is sympy.oo: + return int_oo + if number is -sympy.oo: + return -int_oo + if isinstance(number, sympy.Number): + return sympy.Integer(round(float(number), 0)) + + +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } + + +# NB: Like Round, this only ever returns floats. ndigits cannot be None +class RoundDecimal(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number, ndigits): + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: + return number + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) + if number is int_oo: + return sympy.oo + if number is -int_oo: + return -sympy.oo + + +class Identity(sympy.Function): + """ + Prevents expansion and other optimizations + """ + + precedence = 10 + + def __repr__(self): # type: ignore[override] + return f"Identity({self.args[0]})" + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_integer(self): + return self.args[0].is_integer # type: ignore[attr-defined] + + def _eval_expand_identity(self, **hints): + # Removes the identity op. 
+ return self.args[0] + + def __int__(self) -> int: + return int(self.args[0]) + + def __float__(self) -> float: + return float(self.args[0]) + + +def make_opaque_unary_fn(name): + class OpaqueUnaryFn(sympy.Function): + """ + Unlike the builtin sympy functions on real numbers like sympy.sqrt, + these equivalents do not do any nontrivial reasoning besides + constant propagation. This helps avoid performing transformations + that are valid for real numbers but are invalid for floating point; + in particular, while we are willing to make optimizations that change + numerics for Tensor compute, we are NOT willing to make optimziations + that change numerics for size compute. + """ + + _torch_handler_name = name + _torch_unpickler = make_opaque_unary_fn + + @classmethod + def eval(cls, a): + if isinstance(a, (sympy.Integer, sympy.Float)): + # Python converts to float64 before computing, c.f. + # >>> math.sin(2**53+1) + # -0.848925964814655 + # >>> math.sin(float(2**53+1)) + # -0.848925964814655 + try: + return sympy.Float(getattr(math, name)(float(a))) + # Just use sympy semantics for infinity/overflow, you might get some + # weird objects but ask silly questions, get silly answers + except OverflowError: + return getattr(sympy, name)(a) + elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]: + if a is int_oo: + a = sympy.oo + if a is -int_oo: + a = -sympy.oo + if name == "log2": + return sympy.log(a, 2) + return getattr(sympy, name)(a) + return None + + nm = "OpaqueUnaryFn_" + name + OpaqueUnaryFn.__name__ = nm + OpaqueUnaryFn.__qualname__ = nm + + return OpaqueUnaryFn + + +# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py +OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt") +OpaqueUnaryFn_cos = make_opaque_unary_fn("cos") +OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh") +OpaqueUnaryFn_sin = make_opaque_unary_fn("sin") +OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh") +OpaqueUnaryFn_tan = make_opaque_unary_fn("tan") +OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh") +OpaqueUnaryFn_asin = make_opaque_unary_fn("asin") +OpaqueUnaryFn_acos = make_opaque_unary_fn("acos") +OpaqueUnaryFn_atan = make_opaque_unary_fn("atan") +OpaqueUnaryFn_exp = make_opaque_unary_fn("exp") +OpaqueUnaryFn_log = make_opaque_unary_fn("log") +OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh") +OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2") + + +def make_opaque_bitwise_fn(name, real_op_name): + if name == "bitwise_and": + prec = PRECEDENCE["BitwiseAnd"] + elif name == "bitwise_or": + prec = PRECEDENCE["BitwiseOr"] + else: + raise AssertionError(f"unrecognized {name}") + + class BitwiseFn(sympy.Function): + _torch_handler_name = name + precedence: int = prec + _torch_unpickler = functools.partial( + make_opaque_bitwise_fn, real_op_name=real_op_name + ) + + @classmethod + def eval(cls, a, b): + if a.is_Boolean and b.is_Boolean: + return getattr(operator, real_op_name)(a, b) + if a.is_Boolean: + a = sympy.Integer(1 if a else 0) + if b.is_Boolean: + b = sympy.Integer(1 if b else 0) + if isinstance(a, (sympy.Integer, int)) and isinstance( + b, (sympy.Integer, int) + ): + return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b))) + return None + + BitwiseFn.__name__ = "BitwiseFn_" + name + return BitwiseFn + + +BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_") +BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_") diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/interp.py b/phivenv/Lib/site-packages/torch/utils/_sympy/interp.py new 
file mode 100644 index 0000000000000000000000000000000000000000..06e01dc5b9352a9161984d4a17815eb2c4f3533c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/interp.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +""" +This is a simple interpreter for Sympy expressions that dispatches to +classes following the torch._inductor.virtualized calling convention. +For directness, the interpreter takes the handler directly rather than +consulting the TLS. It does not use most of the methods on the full +handler; only those with corresponding Sympy expressions. To see an example +of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis. +""" + +import functools +import logging +from typing import Any, Union + +import sympy +from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom + +import torch + +from .functions import ( + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, + CeilToInt, + CleanDiv, + FloatPow, + FloatTrueDiv, + FloorDiv, + FloorToInt, + Identity, + IntTrueDiv, + IsNonOverlappingAndDenseIndicator, + Max, + Min, + Mod, + ModularIndexing, + OpaqueUnaryFn_log2, + PowByNatural, + PythonMod, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, + Where, +) + + +log = logging.getLogger(__name__) + + +# TODO: Dedupe this with SYMPY_INTERP + + +@functools.cache +def handlers(): + # TODO add CeilDiv (it doesn't appear in the index_expr) + + # TODO default to some decompositions if the interpreter doesn't have them + # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a) + + HANDLERS = { + sympy.Or: "or_", + sympy.And: "and_", + sympy.Eq: "eq", + sympy.Ne: "ne", + sympy.Lt: "lt", + sympy.Gt: "gt", + sympy.Le: "le", + sympy.Ge: "ge", + sympy.Not: "not_", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", + FloorDiv: "floordiv", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", + Where: "where", + sympy.Add: "add", + sympy.Mul: "mul", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", + Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup + sympy.Mod: "mod", + sympy.Abs: "abs", + sympy.log: "log", + sympy.exp: "exp", + sympy.Min: "minimum", + sympy.Max: "maximum", + Min: "minimum", + Max: "maximum", + ModularIndexing: "modular_indexing", + sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", + sympy.Piecewise: "piecewise", + Identity: "identity", + IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", + RoundDecimal: "round_decimal", + # TODO: do the rest of the opaque unary functions... 
+ OpaqueUnaryFn_log2: "log2", + BitwiseFn_bitwise_and: "bitwise_and", + BitwiseFn_bitwise_or: "bitwise_or", + } + # TODO: This is kind of pointless, we shouldn't be generating sympy.sin + # for these functions, they should be Opaque instead + for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: + HANDLERS[getattr(sympy, name)] = name + + return HANDLERS + + +ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"} + + +def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64): + # Special cases + if isinstance(expr, sympy.Pow) and isinstance( + expr.args[1], sympy.core.numbers.Half + ): + return analysis.sqrt(args[0]) + if isinstance(expr, ToFloat): + return analysis.to_dtype(args[0], torch.float64) + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + + # Fastpath for n-ary integral addition + if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"): + r = analysis.sym_sum(args) + log.debug("sym_sum(%s) -> %s", args, r) + return r + + if hasattr(expr.func, "_torch_handler_name"): + handler_name = expr.func._torch_handler_name + else: + handler_name = handlers()[expr.func] + handler = getattr(analysis, handler_name) + try: + if handler_name in ASSOCIATIVE_OPS: + assert len(args) > 1 + acc = handler(args[0], args[1]) + for i in range(2, len(args)): + acc = handler(acc, args[i]) + log.debug("%s(%s) -> %s", handler_name, args, acc) + return acc + else: + r = handler(*args) + log.debug("%s(%s) -> %s", handler_name, args, r) + return r + except NotImplementedError: + raise + except Exception: + log.warning("failed while executing %s(%s)", handler_name, args) + raise + + +_nil = object() + + +def sympy_interp( + analysis, + env: dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, + missing_handler=None, +): + # Handle base cases + dtype = None + if isinstance(expr, BooleanAtom): + dtype = torch.bool + elif isinstance(expr, sympy.Integer): + dtype = torch.int64 + elif isinstance(expr, sympy.Number): + dtype = torch.double + + if dtype is not None: + return analysis.constant(expr, dtype) + elif isinstance(expr, sympy.Symbol): + if (r := env.get(expr, _nil)) is not _nil: + return r + elif missing_handler: + return missing_handler(expr) + else: + raise KeyError(expr) + + # Recursive case + return _run_sympy_handler( + analysis, + [ + sympy_interp( + analysis, + env, + arg, + index_dtype=index_dtype, + missing_handler=missing_handler, + ) + for arg in expr.args + ], # type: ignore[arg-type] + expr, + index_dtype=index_dtype, + ) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/numbers.py b/phivenv/Lib/site-packages/torch/utils/_sympy/numbers.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d79848f5c3e098ccb69441fec411dffcaffde27 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/numbers.py @@ -0,0 +1,397 @@ +# mypy: allow-untyped-defs +import mpmath.libmp as mlib # type: ignore[import-untyped] +import sympy +from sympy import Expr +from sympy.core.decorators import _sympifyit +from sympy.core.expr import AtomicExpr +from sympy.core.numbers import Number +from sympy.core.parameters import global_parameters +from sympy.core.singleton import S, Singleton + + +class IntInfinity(Number, metaclass=Singleton): + r"""Positive integer infinite quantity. + + Integer infinity is a value in an extended integers which + is greater than all other integers. We distinguish it from + sympy's existing notion of infinity in that it reports that + it is_integer. + + Infinity is a singleton, and can be accessed by ``S.IntInfinity``, + or can be imported as ``int_oo``. + """ + + # NB: We can't actually mark this as infinite, as integer and infinite are + # inconsistent assumptions in sympy. We also report that we are complex, + # different from sympy.oo + + is_integer = True + is_commutative = True + is_number = True + is_extended_real = True + is_comparable = True + is_extended_positive = True + is_prime = False + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _sympystr(self, printer): + return "int_oo" + + def _eval_subs(self, old, new): + if self == old: + return new + + # We could do these, not sure about it + """ + def _eval_evalf(self, prec=None): + return Float('inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in (S.Infinity, S.NegativeInfinity): + return other + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.NegativeInfinity + if other is S.NegativeInfinity: + return S.Infinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.NegativeIntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return S.Infinity # truediv produces float + return S.NegativeInfinity # truediv produces float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.NegativeIntInfinity + + def _eval_power(self, expt): + if expt.is_extended_positive: + return S.IntInfinity + if expt.is_extended_negative: + return S.Zero + if expt is S.NaN: + return S.NaN + if 
expt is S.ComplexInfinity: + return S.NaN + if expt.is_extended_real is False and expt.is_number: + from sympy.functions.elementary.complexes import re + + expt_real = re(expt) + if expt_real.is_positive: + return S.ComplexInfinity + if expt_real.is_negative: + return S.Zero + if expt_real.is_zero: + return S.NaN + + return self ** expt.evalf() + + def _as_mpf_val(self, prec): + return mlib.finf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.IntInfinity + + def __ne__(self, other): + return other is not S.IntInfinity + + def __gt__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __ge__(self, other): + if other is S.Infinity: + return sympy.false # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + def __lt__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __le__(self, other): + if other is S.Infinity: + return sympy.true # sympy.oo > int_oo + elif other is S.IntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + +int_oo = S.IntInfinity + + +class NegativeIntInfinity(Number, metaclass=Singleton): + """Negative integer infinite quantity. + + NegativeInfinity is a singleton, and can be accessed + by ``S.NegativeInfinity``. 
+ + See Also + ======== + + IntInfinity + """ + + # Ensure we get dispatched to before plain numbers + _op_priority = 100.0 + + is_integer = True + is_extended_real = True + is_commutative = True + is_comparable = True + is_extended_negative = True + is_number = True + is_prime = False + + __slots__ = () + + def __new__(cls): + return AtomicExpr.__new__(cls) + + def _eval_subs(self, old, new): + if self == old: + return new + + def _sympystr(self, printer): + return "-int_oo" + + """ + def _eval_evalf(self, prec=None): + return Float('-inf') + + def evalf(self, prec=None, **options): + return self._eval_evalf(prec) + """ + + @_sympifyit("other", NotImplemented) + def __add__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.Infinity: + return S.Infinity + if other in (S.IntInfinity, S.NaN): + return S.NaN + return self + return Number.__add__(self, other) + + __radd__ = __add__ + + @_sympifyit("other", NotImplemented) + def __sub__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other is S.NegativeInfinity: + return S.Infinity + if other in (S.NegativeIntInfinity, S.NaN): + return S.NaN + return self + return Number.__sub__(self, other) + + @_sympifyit("other", NotImplemented) + def __rsub__(self, other): + return (-self).__add__(other) + + @_sympifyit("other", NotImplemented) + def __mul__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other.is_zero or other is S.NaN: + return S.NaN + if other.is_extended_positive: + return self + return S.IntInfinity + return Number.__mul__(self, other) + + __rmul__ = __mul__ + + @_sympifyit("other", NotImplemented) + def __truediv__(self, other): + if isinstance(other, Number) and global_parameters.evaluate: + if other in ( + S.Infinity, + S.IntInfinity, + S.NegativeInfinity, + S.NegativeIntInfinity, + S.NaN, + ): + return S.NaN + if other.is_extended_nonnegative: + return self + return S.Infinity # truediv returns float + return Number.__truediv__(self, other) + + def __abs__(self): + return S.IntInfinity + + def __neg__(self): + return S.IntInfinity + + def _eval_power(self, expt): + if expt.is_number: + if expt in ( + S.NaN, + S.Infinity, + S.NegativeInfinity, + S.IntInfinity, + S.NegativeIntInfinity, + ): + return S.NaN + + if isinstance(expt, sympy.Integer) and expt.is_extended_positive: + if expt.is_odd: + return S.NegativeIntInfinity + else: + return S.IntInfinity + + inf_part = S.IntInfinity**expt + s_part = S.NegativeOne**expt + if inf_part == 0 and s_part.is_finite: + return inf_part + if ( + inf_part is S.ComplexInfinity + and s_part.is_finite + and not s_part.is_zero + ): + return S.ComplexInfinity + return s_part * inf_part + + def _as_mpf_val(self, prec): + return mlib.fninf + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return other is S.NegativeIntInfinity + + def __ne__(self, other): + return other is not S.NegativeIntInfinity + + def __gt__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.false + + def __ge__(self, other): + if other is S.NegativeInfinity: + return sympy.true # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.false + + def __lt__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is 
S.NegativeIntInfinity: + return sympy.false # consistency with sympy.oo + else: + return sympy.true + + def __le__(self, other): + if other is S.NegativeInfinity: + return sympy.false # -sympy.oo < -int_oo + elif other is S.NegativeIntInfinity: + return sympy.true # consistency with sympy.oo + else: + return sympy.true + + @_sympifyit("other", NotImplemented) + def __mod__(self, other): + if not isinstance(other, Expr): + return NotImplemented + return S.NaN + + __rmod__ = __mod__ + + def floor(self): + return self + + def ceiling(self): + return self + + def as_powers_dict(self): + return {S.NegativeOne: 1, S.IntInfinity: 1} diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/printers.py b/phivenv/Lib/site-packages/torch/utils/_sympy/printers.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfb57dacb0bde531034160e00ef578864844542 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/printers.py @@ -0,0 +1,506 @@ +import sys +from typing import Optional + +import sympy +from sympy.printing.precedence import PRECEDENCE, precedence +from sympy.printing.str import StrPrinter + + +INDEX_TYPE = "int64_t" +INDEX_TYPE_MAX = (1 << 63) - 1 +INDEX_TYPE_MIN = -1 << 63 + + +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python +class ExprPrinter(StrPrinter): + # override this so that _print_FloorDiv is used + printmethod = "_torch_sympystr" + + def _print_Mul(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, "*", precedence(expr)) + + def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: + return self.stringify(expr.args, " + ", precedence(expr)) + + def _print_Relational(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr)) + + def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"]) + + def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"]) + + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent + def _print_Mod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str: + s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + return f"({s})" + + def _print_CleanDiv(self, expr: sympy.Expr) -> str: + return self._print_FloorDiv(expr) + + def _print_Identity(self, expr: sympy.Expr) -> str: + return self._print(expr.args[0]) + + def _print_Float(self, expr: sympy.Expr) -> str: + if expr._prec == 53: + # IEEE-754 double precision have 53 bits. SymPy prints them with + # 15 digits, but we need 17 for round-trip correctness + return str(sympy.Float(expr, dps=17)) + else: + # We don't use other precisions in pytorch + return str(expr) + + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. 
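+ # + # Illustrative example: under this rule Pow(x, 3) prints as "x*x*x" and + # Pow(x, 0) prints as "1"; a negative or non-integer exponent trips the + # asserts below.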
+ def _print_Pow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + +class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # NB: We use sym_float here because the printer is used for cache + # serialization, and cache guards get evaluated with SymInt to + # propagate guards to the parent ShapeEnv. However, this comes at a + # runtime cost for guards involving float. If this is unacceptable + # overhead, what you want to do is have two separate printers for + # SymInt, one for when the inputs are guaranteed to be int, and + # another for when they could be SymInt. + # + # NB: sym_min/sym_max also have this problem, but I chose not to fix + # those. + # + # See https://github.com/pytorch/pytorch/issues/142507 for more + # context. 
+ return f"torch.sym_float({self._print(expr.args[0])})" + + def _print_And(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " and ", precedence(expr)) + + def _print_Or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " or ", precedence(expr)) + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + if div != "1": + x = f"({x} // {div})" + return f"({x} % {mod})" + + def _print_Infinity(self, expr: sympy.Expr) -> str: + return "math.inf" + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args) + return f"{x} // {div}" + + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + return f"math.sqrt({self._print(expr)})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return self._helper_sqrt(expr.args[0]) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float + return f"math.trunc({self._print(expr.args[0])})" + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion + def _print_Max(self, expr: sympy.Expr) -> str: + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr: sympy.Expr) -> str: + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return 
f"math.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.log2({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr: sympy.Expr) -> str: + suffix = "LL" if sys.platform in ["darwin", "win32"] else "L" + i = int(expr) + if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN: + raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}") + elif i == INDEX_TYPE_MIN: + assert i == (-1) << 63 + # Writing -9223372036854775808L makes the value overflow + # as it is parsed as -(9223372036854775808L) by the C/C++ compiler + return f"(-1{suffix} << 63)" + return f"{i}{suffix}" + + def _print_Where(self, expr: sympy.Expr) -> str: + c, p, q = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))" + else: + x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))" + mod = self.doprint(mod) + return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))" + return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))" + + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"static_cast<double>({self._print(expr.args[0])})" + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + return f"c10::div_mod({x}, {div})" + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow.
Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + # Implement the special-case of 2**x for now + base, exp = expr.args + if base == 2: + return f"(1 << ({self._print(exp)}))" + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + + def _print_Pow(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + base, exp = expr.args + + if exp == 0.5 or exp == -0.5: + base = self._print(base) + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + if exp.is_integer: + exp = int(exp) + if exp > 0: + r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + elif exp < -1: + r = ( + "1.0/(" + + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"]) + + ")" + ) + elif exp == -1: + r = "1.0/" + self._print(base) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + else: + # TODO: float vs double + return f"std::pow({base}, {float(exp)})" + + def _print_Rational(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min<{INDEX_TYPE}>({il})" + + def _print_Max(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max<{INDEX_TYPE}>({il})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert 
len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return f"std::sqrt({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + return f"std::log2({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr: sympy.Expr) -> str: + return "true" + + def _print_BooleanFalse(self, expr: sympy.Expr) -> str: + return "false" + + def _print_Infinity(self, expr: sympy.Expr) -> str: + return "std::numeric_limits<double>::infinity()" + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + return f"-{self._print_Infinity(expr)}" diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/reference.py b/phivenv/Lib/site-packages/torch/utils/_sympy/reference.py new file mode 100644 index 0000000000000000000000000000000000000000..023008d9869b1e083c46a3d3c8feda4187edbbcd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/reference.py @@ -0,0 +1,581 @@ +# mypy: allow-untyped-defs +import math +import operator +from typing import Union + +import sympy + +import torch +from torch.utils._sympy.functions import ( + _keep_float, + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Max, + Min, + Mod, + OpaqueUnaryFn_exp, + OpaqueUnaryFn_log, + OpaqueUnaryFn_log2, + OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, +) + + +# The sympy interpretation of operators. It will also sometimes work with +# plain int/float, but if you do certain operations you will get out a +# sympy.Basic in the end. If you want the Python/FX traceable interpretation, +# check PythonReferenceAnalysis.
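+# +# Illustrative example: ReferenceAnalysis.add(sympy.Symbol("s0"), 1) yields the +# sympy expression s0 + 1 (via _keep_float), whereas the plain-Python analysis +# further below operates directly on int/float values.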
+# NB: For magic methods this needs to use normal magic methods +# so that test_magic_methods works +class ReferenceAnalysis: + @staticmethod + def constant(c, dtype): + return sympy.sympify(c) + + @staticmethod + def or_(a, b): + return a | b + + @staticmethod + def and_(a, b): + return a & b + + @staticmethod + def eq(a, b): + if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr): + return sympy.Eq(a, b) + return a == b + + @classmethod + def ne(cls, a, b): + return cls.not_(cls.eq(a, b)) + + @staticmethod + def lt(a, b): + return a < b + + @staticmethod + def gt(a, b): + return a > b + + @staticmethod + def le(a, b): + return a <= b + + @staticmethod + def ge(a, b): + return a >= b + + @staticmethod + def not_(a): + assert not isinstance(a, bool) + return ~a + + @staticmethod + def reciprocal(x): + return FloatTrueDiv(1.0, x) + + @staticmethod + def square(x): + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + @staticmethod + def mod(x, y): + return Mod(x, y) + + @staticmethod + def abs(x): + return abs(x) + + @staticmethod + def neg(x): + return -x + + @staticmethod + def truediv(a, b): + return FloatTrueDiv(a, b) + + @staticmethod + def int_truediv(a, b): + return IntTrueDiv(a, b) + + @staticmethod + def floordiv(a, b): + return FloorDiv(a, b) + + @staticmethod + def truncdiv(a, b): + raise NotImplementedError("TODO: truncdiv") + + @staticmethod + def add(a, b): + return _keep_float(operator.add)(a, b) + + @classmethod + def sym_sum(cls, args): + return sympy.Add(*args) + + @staticmethod + def mul(a, b): + return _keep_float(operator.mul)(a, b) + + @staticmethod + def sub(a, b): + return _keep_float(operator.sub)(a, b) + + @staticmethod + def exp(x): + return OpaqueUnaryFn_exp(x) + + @staticmethod + def log(x): + return OpaqueUnaryFn_log(x) + + @staticmethod + def log2(x): + return OpaqueUnaryFn_log2(x) + + @staticmethod + def sqrt(x): + return OpaqueUnaryFn_sqrt(x) + + @staticmethod + def pow(a, b): + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) + + @staticmethod + def minimum(a, b): + return Min(a, b) + + @staticmethod + def maximum(a, b): + return Max(a, b) + + @staticmethod + def round_to_int(a, dtype): + return RoundToInt(a) + + @staticmethod + def round_decimal(a, b): + return RoundDecimal(a, b) + + @staticmethod + def bitwise_and(a, b): + return BitwiseFn_bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return BitwiseFn_bitwise_or(a, b) + + +# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain +# Python types and is FX traceable. Inheritance here is purely for code +# sharing (TODO: considering splitting out a BaseReferenceAnalysis). 
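+# +# For instance (illustrative): PythonReferenceAnalysis.minimum(3, 5) returns the +# plain int 3 via torch.sym_min, and PythonReferenceAnalysis.sym_sum([1, 2, 3]) +# folds the list left to right into 6.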
+class PythonReferenceAnalysis(ReferenceAnalysis): + @staticmethod + def constant(c, dtype): + if dtype is torch.int64: + return int(c) + elif dtype is torch.double: + return float(c) + elif dtype is torch.bool: + return bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + + @staticmethod + def not_(a): + return torch.sym_not(a) + + @classmethod + def sym_sum(cls, args): + if len(args) == 0: + return 0 + if len(args) == 1: + return args[0] + acc = cls.add(args[0], args[1]) + for i in range(2, len(args)): + acc = cls.add(acc, args[i]) + return acc + + @staticmethod + def floordiv(a, b): + return a // b + + @staticmethod + def mod(x, y): + return x % y + + @staticmethod + def truncdiv(a, b): + return a / b + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return torch.sym_float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + @staticmethod + def exp(x): + raise AssertionError("exp is not valid shape sympy expr") + + @staticmethod + def log(x): + raise AssertionError("log is not valid shape sympy expr") + + @staticmethod + def log2(x): + return torch._sym_log2(x) # type: ignore[attr-defined] + + @staticmethod + def sqrt(x): + return torch._sym_sqrt(x) # type: ignore[attr-defined] + + @staticmethod + def minimum(a, b): + return torch.sym_min(a, b) + + @staticmethod + def maximum(a, b): + return torch.sym_max(a, b) + + @staticmethod + def floor_to_int(x, dtype): + return math.floor(x) + + @staticmethod + def ceil_to_int(x, dtype): + return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) + + @staticmethod + def bitwise_and(a, b): + return a & b + + @staticmethod + def bitwise_or(a, b): + return a | b + + +# Like PythonReferenceAnalysis, but some export-unfriendly choices of +# operators to make things faster +class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis): + @staticmethod + def sym_sum(args): + return torch.sym_sum(args) + + +def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.prims.convert_element_type.default(x, dtype) + + +# Suppose we have some int/float arguments. This diagram commutes: +# +# int/float -- PythonReferenceAnalysis.op --> int/float +# | | +# | | +# torch.tensor(..., dtype=torch.int64/torch.float64) +# | | +# V V +# Tensor -- TensorReferenceAnalysis.op --> Tensor +# +# NB: int before and after must be representable in int64 (we will +# insert guards accordingly.) +# +# This is guaranteed to be FX traceable with OpOverloads only. +class TensorReferenceAnalysis: + # NB: This is actually dead, because with Proxy tracing the factory + # function isn't traced correctly. Here for completeness. 
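+ # + # Illustrative sketch of the diagram above: int operands are first lifted to + # tensors (e.g. via torch.ops.aten.scalar_tensor with dtype=torch.int64), and + # an op such as TensorReferenceAnalysis.add then dispatches to + # torch.ops.aten.add.Tensor on those tensors.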
+ @staticmethod + def constant(c, dtype): + d: Union[int, float, bool] + if dtype is torch.int64: + d = int(c) + elif dtype is torch.double: + d = float(c) + elif dtype is torch.bool: + d = bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + return torch.ops.aten.scalar_tensor.default(d, dtype=dtype) + + @staticmethod + def or_(a, b): + return torch.ops.aten.logical_or.default(a, b) + + @staticmethod + def and_(a, b): + return torch.ops.aten.logical_and.default(a, b) + + @staticmethod + def bitwise_and(a, b): + return torch.ops.aten.bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return torch.ops.aten.bitwise_or(a, b) + + @staticmethod + def eq(a, b): + return torch.ops.aten.eq.Tensor(a, b) + + @classmethod + def ne(cls, a, b): + return torch.ops.aten.ne.Tensor(a, b) + + @staticmethod + def lt(a, b): + return torch.ops.aten.lt.Tensor(a, b) + + @staticmethod + def gt(a, b): + return torch.ops.aten.gt.Tensor(a, b) + + @staticmethod + def le(a, b): + return torch.ops.aten.le.Tensor(a, b) + + @staticmethod + def ge(a, b): + return torch.ops.aten.ge.Tensor(a, b) + + @staticmethod + def not_(a): + return torch.ops.aten.logical_not.default(a) + + @staticmethod + def reciprocal(x): + return torch.ops.aten.reciprocal.default(x) + + @staticmethod + def square(x): + # TODO: maybe composite implicit autograd doesn't work here? + return torch.ops.aten.square.default(x) + + @staticmethod + def trunc_to_int(x, dtype): + return _to_dtype(torch.ops.aten.trunc.default(x), dtype) + + @staticmethod + def ceil_to_int(x, dtype): + return _to_dtype(torch.ops.aten.ceil.default(x), dtype) + + @staticmethod + def floor_to_int(x, dtype): + return _to_dtype(torch.ops.aten.floor.default(x), dtype) + + @staticmethod + def floor(x): + return torch.ops.aten.floor.default(x) + + @staticmethod + def ceil(x): + return torch.ops.aten.ceil.default(x) + + @staticmethod + def to_dtype(x, dtype): + return _to_dtype(x, dtype) + + @staticmethod + def mod(x, y): + # TODO: https://github.com/pytorch/pytorch/pull/133654 + raise NotImplementedError( + "no C-style modulus operation available from frontend atm" + ) + + @staticmethod + def abs(x): + return torch.ops.aten.abs.default(x) + + @staticmethod + def neg(x): + return torch.ops.aten.neg.default(x) + + @staticmethod + def truediv(a, b): + return torch.ops.aten.true_divide.Tensor(a, b) + + @staticmethod + def int_truediv(a, b): + raise NotImplementedError( + "Python int truediv difficult to implement in PyTorch atm" + ) + + # TODO: This is wrong, CPython has a custom implementation of true + # division that results in higher precision when the floats are + # sufficiently large. 
Short term fix: add a guard here + return torch.ops.aten.true_divide.default( + _to_dtype(a, torch.float64), _to_dtype(b, torch.float64) + ) + + @staticmethod + def floordiv(a, b): + return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor") + + @staticmethod + def truncdiv(a, b): + raise NotImplementedError( + "no C-style truncdiv operation available from frontend atm" + ) + + @staticmethod + def add(a, b): + return torch.ops.aten.add.Tensor(a, b) + + @staticmethod + def mul(a, b): + return torch.ops.aten.mul.Tensor(a, b) + + @staticmethod + def sub(a, b): + return torch.ops.aten.sub.Tensor(a, b) + + @staticmethod + def exp(x): + return torch.ops.aten.exp.default(x) + + @staticmethod + def log(x): + return torch.ops.aten.log.default(x) + + @staticmethod + def log2(x): + return torch.ops.aten.log2.default(x) + + @staticmethod + def sqrt(x): + return torch.ops.aten.sqrt.default(x) + + @staticmethod + def sin(x): + return torch.ops.aten.sin.default(x) + + @staticmethod + def cos(x): + return torch.ops.aten.cos.default(x) + + @staticmethod + def tanh(x): + return torch.ops.aten.tanh.default(x) + + @staticmethod + def sinh(x): + return torch.ops.aten.sinh.default(x) + + @staticmethod + def cosh(x): + return torch.ops.aten.cosh.default(x) + + @staticmethod + def tan(x): + return torch.ops.aten.tan.default(x) + + @staticmethod + def acos(x): + return torch.ops.aten.acos.default(x) + + @staticmethod + def atan(x): + return torch.ops.aten.atan.default(x) + + @staticmethod + def asin(x): + return torch.ops.aten.asin.default(x) + + @staticmethod + def pow(a, b): + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def pow_by_natural(a, b): + # NB: pow handles int x int fine + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def minimum(a, b): + return torch.ops.aten.minimum.default(a, b) + + @staticmethod + def maximum(a, b): + return torch.ops.aten.maximum.default(a, b) + + @staticmethod + def round_to_int(a, dtype): + return torch.ops.aten.round.default(a) + + @staticmethod + def round_decimal(a, b): + raise NotImplementedError( + "round decimal doesn't support Tensor second argument atm" + ) + + # return torch.ops.aten.round.decimals(a, b) diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/singleton_int.py b/phivenv/Lib/site-packages/torch/utils/_sympy/singleton_int.py new file mode 100644 index 0000000000000000000000000000000000000000..5b75520e9bb55edcccd82eb87ddf069d6655e293 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/singleton_int.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +import sympy +from sympy.multipledispatch import dispatch + + +__all__ = ["SingletonInt"] + + +class SingletonInt(sympy.AtomicExpr): + # This is probably not super important unless we are in multiple dispatch + # situations with other more exotic Expr types. 
+ _op_priority = 99999 + + def __new__(cls, *args, coeff=None, **kwargs): + instance = super().__new__(cls, *args, **kwargs) + return instance + + # The semantics of this class should match that of NestedIntSymNodeImpl in + # c10/core/NestedIntSymNodeImpl.h + def __init__(self, val, *, coeff=1): + self._val = val + self._coeff = coeff + super().__init__() + + # See NOTE [ Inequalities with nested int ] + def _eval_Eq(self, other): + if ( + isinstance(other, SingletonInt) + and other._val == self._val + and self._coeff == other._coeff + ): + return sympy.true + else: + return sympy.false + + # This is necessary so that calling expr.free_symbols on exprs that contain + # this Singleton does not error + @property + def free_symbols(self): + return set() + + def __mul__(self, other): + if isinstance(other, SingletonInt): + raise ValueError( + "SingletonInt cannot be multiplied by another SingletonInt" + ) + return SingletonInt(self._val, coeff=self._coeff * other) + + def __rmul__(self, other): + if isinstance(other, SingletonInt): + raise ValueError( + "SingletonInt cannot be multiplied by another SingletonInt" + ) + return SingletonInt(self._val, coeff=self._coeff * other) + + # Make sure we promptly raise an error instead of falling back to building + # an expression tree. There are probably more ops, how can we be exhaustive? + def __add__(self, other): + raise NotImplementedError("NYI") + + def __sub__(self, other): + raise NotImplementedError("NYI") + + def __truediv__(self, other): + raise NotImplementedError("NYI") + + def __floordiv__(self, other): + raise NotImplementedError("NYI") + + def __mod__(self, other): + raise NotImplementedError("NYI") + + +# See NOTE [ Inequalities with nested int ] +@dispatch(sympy.Integer, SingletonInt) +def _eval_is_ge(a, b): + if a < 2: + return sympy.false + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") + + +@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef] +def _eval_is_ge(a, b): # noqa: F811 + if b <= 2: + return sympy.true + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") + + +@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef] +def _eval_is_ge(a, b): # noqa: F811 + if a._val == b._val: + if a._coeff >= b._coeff: + return sympy.true + else: + return sympy.false + raise ValueError("Symbolic SingletonInt: Relation is indeterminate") diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/solve.py b/phivenv/Lib/site-packages/torch/utils/_sympy/solve.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc352ceb5e8cec8f2148ecc171a8973fd42ed74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/solve.py @@ -0,0 +1,178 @@ +import logging +from typing import Optional + +import sympy + +from torch.utils._sympy.functions import FloorDiv + + +log = logging.getLogger(__name__) + +_MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = { + sympy.Eq: sympy.Eq, + sympy.Ne: sympy.Ne, + sympy.Ge: sympy.Le, + sympy.Gt: sympy.Lt, + sympy.Le: sympy.Ge, + sympy.Lt: sympy.Gt, +} + +INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le) + + +def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]: + return _MIRROR_REL_OP.get(type, None) + + +# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side. +# +# Returns a tuple of: +# 1. The simplified expression +# 2. The expression on the right-hand side +# +# Returns 'None' if it can't reach a state where the only thing in the left +# hand side is 'thing'. 
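+#
+# Editorial examples (illustrative only, assuming 'a' is a plain sympy.Symbol):
+# try_solve(sympy.Eq(a + 2, 10), a) is expected to return (Eq(a, 8), 8), and
+# try_solve(sympy.Lt(2 * a, 6), a) to return (a < 3, 3), since the positive
+# coefficient can safely be divided out.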
+# +# 'trials': number of times 'try_solve' will try to isolate 'thing' to the +# left-hand side. +# +# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into +# inequalities. +def try_solve( + expr: sympy.Basic, + thing: sympy.Basic, + trials: int = 5, + floordiv_inequality: bool = True, +) -> Optional[tuple[sympy.Rel, sympy.Expr]]: + mirror = mirror_rel_op(type(expr)) + + # Ignore unsupported expressions: + # - Those that are not relational operations + # - Those that don't have a mirror (just avoiding unexpected classes) + if not isinstance(expr, sympy.Rel) or mirror is None: + log.debug("expression with unsupported type: %s", type(expr)) + return None + + lhs_has_thing = expr.lhs.has(thing) + rhs_has_thing = expr.rhs.has(thing) + + # Give up when 'thing' appears on both sides of the relational expression. + # That is because, as is, we assume the thing we are trying to isolate is + # only on the right-hand side. + if lhs_has_thing and rhs_has_thing: + log.debug("thing (%s) found in both sides of expression: %s", thing, expr) + return None + + # Try considering both LHS and RHS by mirroring the original expression: + # a < b ==> b > a + expressions = [] + + # Add each version of 'expr' if 'thing' is in its left-hand side. + if lhs_has_thing: + expressions.append(expr) + if rhs_has_thing: + expressions.append(mirror(expr.rhs, expr.lhs)) + + for e in expressions: + if e is None: + continue + + assert isinstance(e, sympy.Rel) + + for _ in range(trials): + trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality) + # Stop if there was no change in this trial. + if trial == e: + break + e = trial # type: ignore[assignment] + + # Return if we were able to isolate 'thing' on the left-hand side. + if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) + return e, e.rhs + + return None + + +def _try_isolate_lhs( + e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool +) -> sympy.Basic: + op = type(e) + + if isinstance(e, sympy.Rel): + # Move any constants in the left-hand side to the right-hand side. + lhs_not_thing = ( + sum(a for a in e.lhs.args if not a.has(thing)) + if isinstance(e.lhs, sympy.Add) + else 0 + ) + e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined] + + # Divide both sides by the factors that don't contain thing. + if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul): + lhs, rhs = e.args + other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)]) + + # If we can't tell whether 'other' is negative or positive, we do nothing. + # That is because we don't know whether we have mirror the operation or not. + # We also divide only when we know 'rhs' is not zero. + if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not ( + not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero + ): + # Divide both sides by 'other'. + lhs = lhs / other + rhs = rhs / other + + # If 'e' is an inequality and 'other' is negative, we have to + # mirror the expression. 
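+            # Editorial worked example: -2*a < 6 is divided through by -2 and
+            # the relation is mirrored, yielding a > -3.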
+ if isinstance(e, INEQUALITY_TYPES) and other.is_negative: + op = mirror_rel_op(op) # type: ignore[assignment] + + assert op is not None + e = op(lhs, rhs) + + ################################################################################ + # left-hand side is FloorDiv + ################################################################################ + # + # Given the expression: a // b op c + # where 'op' is a relational operation, these rules only work if: + # - b > 0 + # - c is an integer + if ( + floordiv_inequality + and isinstance(e, sympy.Rel) + and isinstance(e.lhs, FloorDiv) + and e.lhs.divisor.is_positive + and e.rhs.is_integer + ): + # a // b == expr + # => a >= (b * expr) and a < (b * (expr + 1)) + if isinstance(e, sympy.Eq): + numerator, denominator = e.lhs.args + return sympy.And( + sympy.Ge(numerator, (e.rhs * denominator)), # type: ignore[arg-type] + sympy.Lt(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type] + ) + # a // b != expr + # => a < (b * expr) or a >= (b * (expr + 1)) + if isinstance(e, sympy.Ne): + numerator, denominator = e.lhs.args + return sympy.Or( + sympy.Lt(numerator, (e.rhs * denominator)), # type: ignore[arg-type] + sympy.Ge(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type] + ) + # The transformations below only work if b is positive. + # Note: we only have this information for constants. + # a // b > expr => a >= b * (expr + 1) + # a // b >= expr => a >= b * expr + if isinstance(e, (sympy.Gt, sympy.Ge)): + quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) # type: ignore[arg-type] + return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type] + # a // b < expr => a < b * expr + # a // b <= expr => a < b * (expr + 1) + if isinstance(e, (sympy.Lt, sympy.Le)): + quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) # type: ignore[arg-type] + return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type] + + return e diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/symbol.py b/phivenv/Lib/site-packages/torch/utils/_sympy/symbol.py new file mode 100644 index 0000000000000000000000000000000000000000..4c92a68787395d96fa687317cfd5da2cb54737b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/symbol.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +""" +This file contains canonical definitions for our symbol naming conventions, +across torch.fx.experimental.symbolic_shapes and torch._inductor. The +intention is: + +1. To make it easily greppable where all the sites we use a prefix are +2. Make it possible to easily tell if we can introduce a new prefix without + introducing a conflict + +You can occasionally test if prefixes have been hardcoded by renaming prefixes +in this file and seeing what breaks. +""" + +from collections.abc import Iterable +from enum import auto, Enum +from typing import Union + +import sympy + + +class SymT(Enum): + SIZE = auto() + FLOAT = auto() + UNBACKED_INT = auto() + UNBACKED_FLOAT = auto() + # Inductor: The intermediates in inner_fn tmp0, one generated per ops call. + # If one of these shows up in an indexing expression, that means an + # indirect load is happening. + TMP = auto() + # Inductor: Placeholder variable that is later replaced with TMP + INDIRECT = auto() + # Inductor: Some size expressions are replaced with a precomputed size ps0 + # which is computed host side, and then directly reused in the kernel, so + # we don't repeatedly recompute it on device. 
+ PRECOMPUTED_SIZE = auto() + # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced + # dim in the loop + INDEX = auto() + # Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over + # reduced dim(s) in the loop + R0_INDEX = auto() + R1_INDEX = auto() + # Inductor: In templated kernels torch._inductor.kernel, we have a hook to + # store the final output and append epilogue fusions. To do this, we must + # know what the indexes the outputs range over. NB: These will also + # advertise as INDEX, this is... probably OK? + TEMPLATE_INDEX = auto() + # Inductor: iteration domain for blockIdx.x/blockIdx.y + XBLOCK = auto() + YBLOCK = auto() + ZBLOCK = auto() + # Inductor: this is used solely for dynamic_reshape_indexer + VIEW = auto() + # Alternate (non-modular) indexing used in halide kernels + HALIDE = auto() + + +# Invariant: there must not be a prefix which is a prefix of another string, +# as this introduces ambiguity +prefix_str = { + SymT.SIZE: "s", # integer + SymT.UNBACKED_INT: "u", # integer + # Prefix z here is chosen to avoid false aliasing in symbol_is_type test + # DO NOT add a "z" type. You also need to avoid conflicts on these + # prefixes but this is somewhat easier to manage + SymT.FLOAT: "zf", + SymT.UNBACKED_FLOAT: "zuf", + SymT.TMP: "tmp", + SymT.PRECOMPUTED_SIZE: "ps", + SymT.INDEX: "i", + SymT.R0_INDEX: "r0_", + SymT.R1_INDEX: "r1_", + SymT.TEMPLATE_INDEX: "idx", + SymT.XBLOCK: "x", + SymT.YBLOCK: "y", + SymT.ZBLOCK: "z", + SymT.INDIRECT: "indirect", # false aliasing? + SymT.VIEW: "view", + SymT.HALIDE: "h", +} + + +def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol: + # TODO: maybe put the assumptions here directly + return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs) + + +# This type is a little wider than it should be, because free_symbols says +# that it contains Basic, rather than Symbol +def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool: + assert isinstance(sym, sympy.Symbol) + name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK + if isinstance(prefix, SymT): + return name_str.startswith(prefix_str[prefix]) + else: + return name_str.startswith(tuple(prefix_str[p] for p in prefix)) + + +def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool: + return any(symbol_is_type(v, prefix) for v in e.free_symbols) diff --git a/phivenv/Lib/site-packages/torch/utils/_sympy/value_ranges.py b/phivenv/Lib/site-packages/torch/utils/_sympy/value_ranges.py new file mode 100644 index 0000000000000000000000000000000000000000..f19184aee41aa77b20e755ac63963c962fc81e45 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_sympy/value_ranges.py @@ -0,0 +1,1052 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import math +import operator +from typing import ( + Callable, + Generic, + Optional, + overload, + SupportsFloat, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import TypeGuard + +import sympy +from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom + +import torch +from torch._logging import LazyString +from torch._prims_common import dtype_to_type + +from .functions import ( + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + OpaqueUnaryFn_exp, + OpaqueUnaryFn_log, + OpaqueUnaryFn_log2, + OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + 
TruncToInt, +) +from .interp import sympy_interp +from .numbers import int_oo, IntInfinity, NegativeIntInfinity + + +log = logging.getLogger(__name__) + +__all__ = ["ValueRanges", "bound_sympy"] + +_T = TypeVar("_T", sympy.Expr, SympyBoolean) + + +class ValueRangeError(RuntimeError): + pass + + +# Like sympify, but supports less stuff, and also ensures that direct +# sympy expressions don't have free variables +def simple_sympify(e): + if isinstance(e, bool): + return sympy.true if e else sympy.false + elif isinstance(e, int): + return sympy.Integer(e) + elif isinstance(e, float): + # infinity is special; we use it to bracket integers as well + if math.isinf(e): + return sympy.oo if e > 0 else -sympy.oo + return sympy.Float(e) + elif isinstance(e, sympy.Expr): + assert e.is_number, e + # NaNs can occur when doing things like 0 * sympy.oo, but it is better + # if the operator notices this and takes care of it, because sometimes + # the NaN is inappropriate (for example, for ints, the [-oo, oo] range + # should go to zero when multiplied with [0, 0]) + assert e != sympy.nan + return e + elif isinstance(e, BooleanAtom): + return e + else: + raise AssertionError(f"not simple sympy type {type(e)}: {e}") + + +# Sympy atomics only. Unlike <=, it also works on Sympy bools. +def sympy_generic_le(lower, upper): + if isinstance(lower, sympy.Expr): + assert isinstance(upper, sympy.Expr) + # instead of lower <= upper, we do upper >= lower since upper is mostly int_oo + # and we have better code paths there. + return upper >= lower + else: + # only negative condition is True > False + assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( + lower, + upper, + ) + return not (lower and not upper) + + +def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]: + return vr.is_bool + + +def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]: + return not vr.is_bool + + +ExprIn = Union[int, float, sympy.Expr] +BoolIn = Union[bool, SympyBoolean] +AllIn = Union[ExprIn, BoolIn] +ExprFn = Callable[[sympy.Expr], sympy.Expr] +ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr] +BoolFn = Callable[[SympyBoolean], SympyBoolean] +BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean] +AllFn = Union[ExprFn, BoolFn] +AllFn2 = Union[ExprFn2, BoolFn2] + + +@dataclasses.dataclass(frozen=True) +class ValueRanges(Generic[_T]): + if TYPE_CHECKING: + # ruff doesn't understand circular references but mypy does + ExprVR = ValueRanges[sympy.Expr] # noqa: F821 + BoolVR = ValueRanges[SympyBoolean] # noqa: F821 + AllVR = Union[ExprVR, BoolVR] + + # Although the type signature here suggests you can pass any + # sympy expression, in practice the analysis here only works + # with constant sympy expressions + lower: _T + upper: _T + is_bool: bool + is_int: bool + is_float: bool + + def __repr__(self) -> str: + return f"VR[{self.lower}, {self.upper}]" + + @overload + def __init__( + self: ValueRanges[sympy.Expr], + lower: ExprIn, + upper: ExprIn, + ) -> None: + ... + + @overload + def __init__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + lower: BoolIn, + upper: BoolIn, + ) -> None: + ... 
+ + def __init__(self, lower: AllIn, upper: AllIn) -> None: + lower = simple_sympify(lower) + upper = simple_sympify(upper) + # TODO: when the bounds have free variables, this may be + # nontrivial to actually verify + try: + if not sympy_generic_le(lower, upper): + raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]") + except TypeError as e: + raise TypeError(f"Could not compare {lower} <= {upper}") from e + + is_bool_lower = isinstance(lower, SympyBoolean) + is_bool_upper = isinstance(upper, SympyBoolean) + assert is_bool_lower == is_bool_upper, (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + if isinstance(lower, sympy.Integer) and upper == sympy.oo: + upper = int_oo + if isinstance(upper, sympy.Integer) and lower == -sympy.oo: + lower = -int_oo + # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed + integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity) + is_int_lower = isinstance(lower, integer_types) + is_int_upper = isinstance(upper, integer_types) + + # Because this is a frozen class + object.__setattr__(self, "lower", lower) + object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints + # + # NB: is_bool_lower == is_bool_upper, so we only need to check one + object.__setattr__(self, "is_bool", is_bool_lower) + object.__setattr__( + self, + "is_int", + not self.is_bool and is_int_lower and is_int_upper, + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) + + def boolify(self) -> ValueRanges[SympyBoolean]: + if vr_is_bool(self): + return self + elif self == ValueRanges.unknown(): + return ValueRanges.unknown_bool() + else: + raise AssertionError(f"not bool like {self}") + + def __contains__(self, x: AllIn) -> bool: + return ValueRanges.wrap(x).issubset(self) + + def issubset(self, other): + if other is self.unknown_int(): + return True + return sympy_generic_le(other.lower, self.lower) and sympy_generic_le( + self.upper, other.upper + ) + + def tighten(self, other) -> ValueRanges: + """Given two ValueRanges, returns their intersection""" + return self & other + + # Intersection + @overload + def __and__( + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], + ) -> ValueRanges[sympy.Expr]: + ... + + @overload + def __and__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], + ) -> ValueRanges[SympyBoolean]: + ... 
+ + def __and__(self: AllVR, other: AllVR) -> AllVR: + if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): + return self + if self in (ValueRanges.unknown(), ValueRanges.unknown_int()): + return other + assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) + if self.is_bool: + return ValueRanges( + sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) + ) + else: + return ValueRanges( + sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper) + ) + + # Union + @overload + def __or__( + self: ValueRanges[sympy.Expr], + other: ValueRanges[sympy.Expr], + ) -> ValueRanges[sympy.Expr]: + ... + + @overload + def __or__( # type: ignore[misc] + self: ValueRanges[SympyBoolean], + other: ValueRanges[SympyBoolean], + ) -> ValueRanges[SympyBoolean]: + ... + + def __or__(self: AllVR, other: AllVR) -> AllVR: + if ValueRanges.unknown() in (self, other): + return ValueRanges.unknown() + assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) + if self.is_bool: + return ValueRanges( + sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper) + ) + else: + return ValueRanges( + sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper) + ) + + def is_singleton(self) -> bool: + return self.lower == self.upper + + @staticmethod + @functools.cache + def unknown() -> ValueRanges[sympy.Expr]: + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + @functools.cache + def unknown_int() -> ValueRanges[sympy.Expr]: + return ValueRanges(-int_oo, int_oo) + + @staticmethod + @functools.cache + def unknown_bool() -> ValueRanges[SympyBoolean]: + return ValueRanges(sympy.false, sympy.true) + + @overload + @staticmethod + # work around the fact that bool and int overlap + def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap] + ... + + @overload + @staticmethod + def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: # type: ignore[misc] + ... + + @staticmethod + def wrap(arg: Union[AllIn, AllVR]) -> AllVR: + if isinstance(arg, ValueRanges): + return arg + if isinstance(arg, float) and math.isnan(arg): + return ValueRanges.unknown() + # arg is either ExprIn or BoolIn, but we don't know it here + return ValueRanges(arg, arg) # type: ignore[arg-type] + + @staticmethod + def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + """Increasing: x <= y => f(x) <= f(y).""" + x = ValueRanges.wrap(x) + return ValueRanges(fn(x.lower), fn(x.upper)) + + @overload + @staticmethod + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + ... + + @overload + @staticmethod + def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: # type: ignore[misc] + ... 
+ + @staticmethod + def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR: + """Decreasing: x <= y => f(x) >= f(y).""" + x = ValueRanges.wrap(x) + # consistently either Expr or Bool, but we don't know it here + return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type] + + @staticmethod + def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + """It's increasing or decreasing.""" + x = ValueRanges.wrap(x) + l = fn(x.lower) + u = fn(x.upper) + return ValueRanges(min(l, u), max(l, u)) + + @staticmethod + def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + """Fn is convex and has a minimum at 0.""" + x = ValueRanges.wrap(x) + if 0 in x: + upper = max(fn(x.lower), fn(x.upper)) + upper = simple_sympify(upper) + if isinstance(upper, sympy.Float) or upper == sympy.oo: + return ValueRanges(0.0, upper) + return ValueRanges(0, upper) + return ValueRanges.monotone_map(x, fn) + + @overload + @staticmethod + def coordinatewise_increasing_map( + x: Union[ExprIn, ExprVR], + y: Union[ExprIn, ExprVR], + fn: ExprFn2, + ) -> ExprVR: + ... + + @overload + @staticmethod + def coordinatewise_increasing_map( # type: ignore[misc] + x: Union[BoolIn, BoolVR], + y: Union[BoolIn, BoolVR], + fn: BoolFn2, + ) -> BoolVR: + ... + + @staticmethod + def coordinatewise_increasing_map( + x: Union[AllIn, AllVR], + y: Union[AllIn, AllVR], + fn: AllFn2, + ) -> AllVR: + """ + It's increasing on each coordinate. + + Mathematically: + For every 1 <= i <= n and x_i <= y_i we have that + f(x1, .., xn) <= f(x1, , yi, ..., xn) + """ + x, y = ValueRanges.wrap(x), ValueRanges.wrap(y) + return ValueRanges( + fn(x.lower, y.lower), # type: ignore[arg-type] + fn(x.upper, y.upper), # type: ignore[arg-type] + ) + + @classmethod + def coordinatewise_monotone_map(cls, x, y, fn): + """It's increasing or decreasing on each coordinate.""" + x, y = cls.wrap(x), cls.wrap(y) + products = [ + fn(a, b) + for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper]) + ] + return ValueRanges(min(products), max(products)) + + +class SymPyValueRangeAnalysis: + """ + It gives bounds on a SymPy operator given bounds on its arguments + See the function `bound_sympy` for a function that applies this logic to a full SymPy expression + """ + + @staticmethod + def constant(value, dtype): + if isinstance(value, ValueRanges): + assert value.is_singleton() + value = value.lower + # NB: value is NOT a sympy expression, it's a constant! 
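+        # Editorial examples: constant(3, torch.int64) is expected to return
+        # VR[3, 3], while constant(float("nan"), torch.float64) gives up and
+        # returns the unknown range below.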
+ is_python = isinstance(value, (int, float, bool)) + assert is_python or isinstance( + value, (BooleanAtom, sympy.Integer, sympy.Number) + ) + + # using nan makes subsequent computation throw, and for the purposes of optimization + # returning -math.inf - math.inf is equivalent to giving up + if isinstance(value, SupportsFloat) and math.isnan(value): + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges.unknown_int() + + if is_python: + type_ = dtype_to_type(dtype) + value = type_(value) + else: + # We do a type check on a best-effort basis + # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision + if dtype == torch.bool: + assert isinstance(value, BooleanAtom) + elif dtype.is_floating_point: + assert not value.is_finite or value.is_real + else: + # dtype is intXX + assert value.is_integer + + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + elif dtype == torch.bool: + return ValueRanges.unknown_bool() + elif not dtype.is_floating_point: + return ValueRanges.unknown_int() + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) + + @staticmethod + def not_(a): + a = ValueRanges.wrap(a) + a = a.boolify() + assert a.is_bool + return ValueRanges.decreasing_map(a, sympy.Not) + + @staticmethod + def or_(a, b): + return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or) + + @staticmethod + def and_(a, b): + return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And) + + @staticmethod + def _bool_to_int(x): + if x.is_singleton(): + return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0)) + else: + return ValueRanges(sympy.Integer(0), sympy.Integer(1)) + + @classmethod + def bitwise_and(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.and_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + lower = min(a.lower, b.lower) + if lower < 0 and lower != -sympy.oo and lower != -int_oo: + # If both lower bounds are negative, then bits start like + # 1...10..., so the smallest possible value is 1...101...1. + # Thus, we need to find the next smallest power of 2 (inclusive). 
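+                # Editorial worked example: if lower == -10, then
+                # int(-lower - 1).bit_length() == (9).bit_length() == 4, so the
+                # bound becomes -(1 << 4) == -16, which covers e.g.
+                # (-10) & (-5) == -14.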
+ try: + lower = -(1 << int(-lower - 1).bit_length()) + except Exception: + lower = -int_oo + else: + lower = 0 + return ValueRanges(lower, max(a.upper, b.upper)) + + @classmethod + def bitwise_or(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.or_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + upper = max(a.upper, b.upper) + if upper == 0: + upper = 0 + elif upper > 0 and upper != sympy.oo and upper != int_oo: + # If both upper bounds are positive, then the largest + # possible value is 01...1, so we need to find + # next largest power of 2 (exclusive), minus 1 + try: + upper = (1 << int(upper).bit_length()) - 1 + except Exception: + upper = int_oo + elif upper < 0: + upper = -1 + return ValueRanges(min(a.lower, b.lower), upper) + + @staticmethod + def eq(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton() and a.lower == b.lower: + return ValueRanges.wrap(sympy.true) + elif a.lower > b.upper or b.lower > a.upper: # ranges disjoint + return ValueRanges.wrap(sympy.false) + return ValueRanges(sympy.false, sympy.true) + + @classmethod + def ne(cls, a, b): + return cls.not_(cls.eq(a, b)) + + @classmethod + def identity(cls, a): + return ValueRanges.wrap(a) + + @classmethod + def lt(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + assert a.is_bool == b.is_bool + if a.is_bool: + return cls.and_(cls.not_(a), b) + else: + if a.upper < b.lower: + return ValueRanges.wrap(sympy.true) + elif a.lower >= b.upper: + return ValueRanges.wrap(sympy.false) + return ValueRanges(sympy.false, sympy.true) + + @classmethod + def gt(cls, a, b): + return cls.lt(b, a) + + @classmethod + def le(cls, a, b): + return cls.not_(cls.gt(a, b)) + + @classmethod + def ge(cls, a, b): + return cls.not_(cls.lt(a, b)) + + @staticmethod + def add(a, b): + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) + + @classmethod + def mul(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + + assert a.is_bool == b.is_bool + if a.is_bool: + return cls.and_(a, b) + + def safe_mul(a, b): + # Make unknown() * wrap(0.0) == wrap(0.0) + if a == 0.0 or a == 0: + return a + elif b == 0.0 or b == 0: + return b + else: + return a * b + + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) + + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) + + @staticmethod + def truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(FloatTrueDiv) + ) + + @staticmethod + def floordiv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b: + return ValueRanges.unknown_int() + products = [] + for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): + r = FloorDiv(x, y) + if r is sympy.nan: + products.append((sympy.sign(x) * sympy.sign(y)) * int_oo) + else: + products.append(r) + + return ValueRanges(min(products), max(products)) + + @classmethod + def mod(cls, x, y): + x = ValueRanges.wrap(x) + y = ValueRanges.wrap(y) + # nb. 
We implement C semantics + + def c_mod(a, b): + ret = abs(a) % abs(b) + if a < 0: + ret *= -1 + return ret + + def c_div(a, b): + x = a / b + return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x + + if 0 in y: + return ValueRanges.unknown_int() + elif y.is_singleton(): + y_val = abs(y.lower) + # If it wraps, we need to take the whole interval + + # The function is locally linear if they are in the same class + if c_div(x.lower, y_val) == c_div(x.upper, y_val): + return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val)) + if x.upper < 0: + # Negative case + return ValueRanges(-y_val + 1, 0) + elif x.lower > 0: + # Positive case + return ValueRanges(0, y_val - 1) + else: + # Mixed case + lower = max(-y_val + 1, x.lower) + upper = min(y_val - 1, x.upper) + return ValueRanges(lower, upper) + else: + # Too difficult, we bail out + upper = cls.abs(y).upper - 1 + return ValueRanges(-upper, upper) + + @classmethod + def modular_indexing(cls, a, b, c): + return cls.mod(cls.floordiv(a, b), c) + + @classmethod + def is_non_overlapping_and_dense_indicator(cls, *args): + return ValueRanges.unknown_int() + + @classmethod + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, int_oo), PowByNatural + ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) + ) + + @classmethod + def pow(cls, a, b): + return ValueRanges.unknown() + + # We could implement all this, but for floating point pow, is there + # really a point? + """ + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + + # Not implemented yet. It's a bit tricky + # If you want to implement it, compute the partial derivatives of a ** b + # and check the ranges where the function is increasing / decreasing + # Another non-tight way of doing this is defaulting to doing noting that for a > 0, a ** b == exp(b * log(a)) + # If this second option is implemented, by carefult about the types and possible infinities here and there. 
+ if not b.is_singleton(): + return ValueRanges.unknown() + + b = b.lower + if a.is_singleton(): + a = a.lower + r = a**b + if not r.is_finite: + return ValueRanges.unknown() + return ValueRanges.wrap(r) + + if b == 0: + if not a.lower.is_finite: + return ValueRanges.unknown() + return ValueRanges.wrap(1.0) + + if b < 0: + a = cls.reciprocal(a) + b = -b + + if a == ValueRanges.unknown(): + return ValueRanges.unknown() + + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) + else: + return ValueRanges.unknown() + """ + + @staticmethod + def reciprocal(x): + """Needed as it's used in pow, but it won't appear on a SymPy expression""" + x = ValueRanges.wrap(x) + if 0 in x: + return ValueRanges.unknown() + else: + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) # type: ignore[operator] + + @staticmethod + def abs(x): + return ValueRanges.convex_min_zero_map(x, abs) + + @staticmethod + def exp(x): + return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp) + + @staticmethod + def log(x): + x = ValueRanges.wrap(x) + if x.lower <= 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_log) + + @staticmethod + def log2(x): + x = ValueRanges.wrap(x) + if x.lower <= 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_log2) + + @classmethod + def minimum(cls, a, b): + return cls.min_or_max(a, b, sympy.Min) + + @classmethod + def maximum(cls, a, b): + return cls.min_or_max(a, b, sympy.Max) + + @staticmethod + def min_or_max(a, b, fn): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) + + @classmethod + def floor_to_int(cls, x, dtype): + return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) + + @classmethod + def ceil_to_int(cls, x, dtype): + return ValueRanges.increasing_map( + x, sympy.functions.elementary.integers.ceiling + ) + + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. 
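+    # Editorial illustration: in float64 the widest binade that still contains
+    # non-integers is [2**51, 2**52), where values are spaced 0.5 apart; for
+    # x = 2.0**51 + 0.5 both floor(x) == 2**51 and ceil(x) == 2**51 + 1 are
+    # exactly representable, so the floor/ceiling below lose no precision.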
+ + @classmethod + def floor(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) + + @classmethod + def ceil(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.ceiling) + ) + + @classmethod + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + + return ValueRanges.increasing_map(number, fn) + + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + + # It's used in some models on symints + @staticmethod + def sqrt(x): + x = ValueRanges.wrap(x) + if x.lower < 0: + return ValueRanges.unknown() + return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt) + + @staticmethod + def where(a, b, c): + b = ValueRanges.wrap(b) + c = ValueRanges.wrap(c) + a = a.boolify() + # We sometimes write unknown without specifying the type correctly + # In particular, we do that when initialising the bounds for loads in bounds.py + assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c) + if b.is_bool: + return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper)) + else: + return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper)) + + # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise. + # We just return the value range of the expression and its corresponding condition as a tuple + # and defer the analysis to piecewise + @staticmethod + def expr_cond_pair(a, b): + b = b.boolify() + return (a, b) + + # piecewise function can be used to convert a SymBool to SymInt: + # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise. + # + # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair. + # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True. 
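+    # Editorial worked example: for Piecewise((1, bool_expr), (0, True)) with
+    # bool_expr bounded by VR[False, True], the expr_cond_pair results are
+    # (VR[1, 1], VR[False, True]) and (VR[0, 0], VR[True, True]); both
+    # conditions can be True, so piecewise returns VR[1, 1] | VR[0, 0],
+    # i.e. VR[0, 1].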
+ @staticmethod + def piecewise(*ranges): + init_range = None + for expr_range, cond_range in ranges: + if sympy.true in cond_range: + if init_range is None: + init_range = expr_range + else: + init_range = init_range | expr_range + return init_range + + @staticmethod + def cos(x): + # TODO: We should tighten value ranges + # If input range span is pi + 2*pi*k, then output range is (-1, 1) + # otherwise the minimum of the value of the function on the extremes + return ValueRanges(-1.0, 1.0) + + @staticmethod + def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ + x = ValueRanges.wrap(x) + if x.lower > 0: + return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) + elif x.upper < 0: + return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) + return ValueRanges(0.0, sympy.oo) + """ + + @staticmethod + def sin(x): + # TODO: We should tighten value ranges + # See details on cos + return ValueRanges(-1.0, 1.0) + + @staticmethod + def sinh(x): + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def tan(x): + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def tanh(x): + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) + + @staticmethod + def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ + x = ValueRanges.wrap(x) + if -1 <= x.lower and x.upper <= 1: + return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) + return ValueRanges.unknown() + """ + + @staticmethod + def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ + x = ValueRanges.wrap(x) + if -1 <= x.lower and x.upper <= 1: + return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) + return ValueRanges.unknown() + """ + + @staticmethod + def atan(x): + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + + @staticmethod + def trunc(x): + return ValueRanges.increasing_map(x, TruncToFloat) + + +def bound_sympy( + expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None +) -> ValueRanges: + log.debug( + "bound_sympy(%s)%s", + expr, + LazyString( + lambda: ( + "\n" + + "\n".join( + f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols + ) + if ranges + else "" + ) + ), + ) + if isinstance(expr, sympy.Number): + return ValueRanges.wrap(expr) + + ranges = ranges or {} + + # If there's a tracing context, augment available constrained ranges. 
+ context = torch._guards.TracingContext.try_get() + if context and context.fake_mode.shape_env: + if ranges: + ranges = {**context.fake_mode.shape_env.var_to_range, **ranges} + else: + ranges = context.fake_mode.shape_env.var_to_range + + def missing_handler(s): + if s.is_integer: # type: ignore[attr-defined] + if s.is_positive: # type: ignore[attr-defined] + vr = ValueRanges(1, int_oo) + elif s.is_nonnegative: # type: ignore[attr-defined] + vr = ValueRanges(0, int_oo) + else: + vr = ValueRanges.unknown_int() + else: + # Don't bother trying very hard here + vr = ValueRanges.unknown() + return vr + + return sympy_interp( + SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler + ) diff --git a/phivenv/Lib/site-packages/torch/utils/_thunk.py b/phivenv/Lib/site-packages/torch/utils/_thunk.py new file mode 100644 index 0000000000000000000000000000000000000000..368e263ca38b95c08d5b430931e047d6bebe0470 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_thunk.py @@ -0,0 +1,28 @@ +from typing import Callable, Generic, Optional, TypeVar + + +R = TypeVar("R") + + +class Thunk(Generic[R]): + """ + A simple lazy evaluation implementation that lets you delay + execution of a function. It properly handles releasing the + function once it is forced. + """ + + f: Optional[Callable[[], R]] + r: Optional[R] + + __slots__ = ["f", "r"] + + def __init__(self, f: Callable[[], R]): + self.f = f + self.r = None + + def force(self) -> R: + if self.f is None: + return self.r # type: ignore[return-value] + self.r = self.f() + self.f = None + return self.r diff --git a/phivenv/Lib/site-packages/torch/utils/_traceback.py b/phivenv/Lib/site-packages/torch/utils/_traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..02fa9b6003a424baf057f9bf31da1cae57ec2373 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_traceback.py @@ -0,0 +1,260 @@ +# mypy: allow-untyped-defs +import contextlib +import inspect +import os.path +import tempfile +import traceback +from types import TracebackType +from typing import Optional + + +# This file contains utilities for ensuring dynamically compile()'d +# code fragments display their line numbers in backtraces. +# +# The constraints: +# +# - We don't have control over the user exception printer (in particular, +# we cannot assume the linecache trick will work, c.f. +# https://stackoverflow.com/q/50515651/23845 ) +# +# - We don't want to create temporary files every time we compile() +# some code; file creation should happen lazily only at exception +# time. Arguably, you *should* be willing to write out your +# generated Python code to file system, but in some situations +# (esp. library code) it would violate user expectation to write +# to the file system, so we try to avoid it. In particular, we'd +# like to keep the files around, so users can open up the files +# mentioned in the trace; if the file is invisible, we want to +# avoid clogging up the filesystem. +# +# If this is not a constraint for you, there is a substantially simpler +# way to implement the functionality in this PR: instead of using +# eval/exec directly, just always write a Python file to filesystem +# and compile that. +# +# - You have control over a context where the compiled code will get +# executed, so that we can interpose while the stack is unwinding +# (otherwise, we have no way to interpose on the exception printing +# process.) 
+# +# There are two things you have to do to make use of the utilities here: +# +# - When you compile your source code, you must save its string source +# in its f_globals under the magic name "__compile_source__" +# +# - Before running the compiled code, enter the +# report_compile_source_on_error() context manager. + + +@contextlib.contextmanager +def report_compile_source_on_error(): + try: + yield + except Exception as exc: + tb = exc.__traceback__ + + # Walk the traceback, looking for frames that have + # source attached + stack = [] + while tb is not None: + filename = tb.tb_frame.f_code.co_filename + source = tb.tb_frame.f_globals.get("__compile_source__") + + if filename == "" and source is not None: + # What black magic are we doing here? Intuitively, what + # we would like to do is overwrite the co_filename on any + # frames that were generated from exec/eval so that they + # point to a temporary file that has the actual line + # information, so Python's default error printer can print + # useful line information on it. + # + # Writing out the temporary file is easy. But overwriting + # co_filename is not! You can't modify the code object + # associated with a frame. You can, however, reconstruct + # a traceback with entirely new frames from scratch, so that's + # what we do. But there's another problem, which is how to + # make the frame? + # + # The black magic is we make a frankenstein frame and code + # object which resembles the original frame/code enough so + # that it will print properly under traceback and the default + # error printer, but IT IS NOT THE ORIGINAL FRAME (you + # couldn't, e.g., execute its code with different variables + # and expect it to work.) + + # Don't delete the temporary file so the user can inspect it + # TODO: This creates a temporary file for every frame, but we + # technically only need one per distinct __compile_source__ + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".py" + ) as f: + f.write(source) + # Create a frame. Python doesn't let you construct + # FrameType directly, so just make one with compile + frame = tb.tb_frame + code = compile("__inspect_currentframe()", f.name, "eval") + code = code.replace(co_name=frame.f_code.co_name) + # Python 3.11 only + if hasattr(frame.f_code, "co_linetable"): + # We can't copy ALL of the metadata over, because you + # can cause Python to segfault this way. What exactly + # do we need? We need enough information for + # traceback to be able to print the exception + # correctly. Code reading Lib/traceback.py reveals + # that traceback calls code.co_positions() in order to + # get the augmented line/col numbers. Objects/codeobject.c, + # specifically _PyCode_InitAddressRange, reveals that + # this iterator is initialized from co_linetable and + # co_firstfileno. So copy these we must! 
+ code = code.replace( # type: ignore[call-arg] + co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined] + co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined] + ) + fake_frame = eval( + code, + frame.f_globals, + {**frame.f_locals, "__inspect_currentframe": inspect.currentframe}, + ) + fake_tb = TracebackType(None, fake_frame, tb.tb_lasti, tb.tb_lineno) + stack.append(fake_tb) + else: + stack.append(tb) + + tb = tb.tb_next + + # Reconstruct the linked list + tb_next = None + for tb in reversed(stack): + tb.tb_next = tb_next + tb_next = tb + + raise exc.with_traceback(tb_next) # noqa: B904 + + +def shorten_filename(fn, *, base=None): + """Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.""" + if base is None: + base = os.path.dirname(os.path.dirname(__file__)) + # Truncate torch/foo.py to foo.py + try: + prefix = os.path.commonpath([fn, base]) + except ValueError: + return fn + else: + return fn[len(prefix) + 1 :] + + +def format_frame(frame, *, base=None, line=False): + """ + Format a FrameSummary in a short way, without printing full absolute path or code. + + The idea is the result fits on a single line. + """ + extra_line = "" + if line: + extra_line = f"{frame.line} # " + return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}" + + +def format_traceback_short(tb): + """Format a TracebackType in a short way, printing only the inner-most frame.""" + return format_frame(traceback.extract_tb(tb)[-1]) + + +class CapturedTraceback: + __slots__ = ["tb", "skip"] + + def __init__(self, tb, skip=0): + self.tb = tb + self.skip = skip + + def cleanup(self): + self.tb = None + + def summary(self): + import torch._C._profiler + + if self.tb is None: + # TODO: Maybe indicate that the traceback was elided? + return traceback.StackSummary() + + return _extract_symbolized_tb( + torch._C._profiler.symbolize_tracebacks([self.tb])[0], self.skip + ) + + def __getstate__(self): + return ( + None, + { + "tb": None, # TB is not pickleable + "skip": self.skip, + }, + ) + + @staticmethod + def extract(*, script=False, cpp=False, skip=0): + """ + Like traceback.extract_stack(), but faster (approximately 20x faster); it + is fast enough that you can unconditionally log stacks this way as part of + normal execution. It returns a torch._C._profiler.CapturedTraceback + object that must be formatted specially with format_captured_tb. + + By default, this only reports Python backtraces (like extract_stack). You + can set the script/cpp kwargs to also turn on TorchScript/C++ trace + reporting. + """ + import torch._C._profiler + + if script or cpp: + assert skip == 0, "skip with script/cpp NYI" + + return CapturedTraceback( + torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp), + # Elide extract() frame if we don't have script/cpp frames. If + # we do have those frames, it doesn't work so force zero. + 0 if script or cpp else skip + 1, + ) + + def format(self): + """ + Formats a single torch._C._profiler.CapturedTraceback into a list of + strings equivalent to the output of traceback.format_list. Note that if + pass it CapturedTraceback with C++ traces, it is better not to use this + function and use the batch formatting API format_captured_tbs to amortize + the cost of symbolization + """ + return traceback.format_list(self.summary()) + + @staticmethod + def format_all(tbs): + """ + Bulk version of CapturedTraceback.format. Returns a list of list of strings. 
+ """ + import torch._C._profiler + + # Directly populate tracebacks that already have cached summaries + rs: list[Optional[list[str]]] = [] + delayed_idxs = [] + for i, tb in enumerate(tbs): + if tb.tb is None: + rs.append([]) + else: + rs.append(None) + delayed_idxs.append(i) + + torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs]) + for i in delayed_idxs: + rs[i] = traceback.format_list(tbs[i].summary()) + + return rs + + +def _extract_symbolized_tb(tb, skip): + """ + Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of + pre-processed stack trace entries. + """ + stack = traceback.StackSummary() + for f in reversed(tb[skip:]): + stack.append(traceback.FrameSummary(f["filename"], f["line"], f["name"])) + return stack diff --git a/phivenv/Lib/site-packages/torch/utils/_triton.py b/phivenv/Lib/site-packages/torch/utils/_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..504e956e558edbccef57f2fc48e204fdedcdcc3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_triton.py @@ -0,0 +1,169 @@ +import functools +import hashlib +from typing import Any + + +@functools.cache +def has_triton_package() -> bool: + try: + from triton.compiler.compiler import triton_key + + return triton_key is not None + except ImportError: + return False + except RuntimeError: + return False + + +@functools.cache +def _device_supports_tma() -> bool: + import torch + + return ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ) + + +@functools.cache +def has_triton_experimental_host_tma() -> bool: + if has_triton_package(): + if _device_supports_tma(): + try: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + return True + except ImportError: + pass + + return False + + +@functools.cache +def has_triton_tensor_descriptor_host_tma() -> bool: + if has_triton_package(): + if _device_supports_tma(): + try: + from triton.tools.tensor_descriptor import ( # noqa: F401 + TensorDescriptor, + ) + + return True + except ImportError: + pass + + return False + + +@functools.cache +def has_triton_tma() -> bool: + return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma() + + +@functools.cache +def has_triton_tma_device() -> bool: + if has_triton_package(): + import torch + + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + # old API + try: + from triton.language.extra.cuda import ( # noqa: F401 + experimental_device_tensormap_create1d, + experimental_device_tensormap_create2d, + ) + + return True + except ImportError: + pass + + # new API + try: + from triton.language import make_tensor_descriptor # noqa: F401 + + return True + except ImportError: + pass + + return False + + +@functools.lru_cache(None) +def has_triton_stable_tma_api() -> bool: + if has_triton_package(): + import torch + + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + try: + from triton.language import make_tensor_descriptor # noqa: F401 + + return True + except ImportError: + pass + return False + + +@functools.cache +def has_triton() -> bool: + if not has_triton_package(): + return False + + from torch._dynamo.device_interface import get_interface_for_device + + def cuda_extra_check(device_interface: Any) -> bool: + return 
device_interface.Worker.get_device_properties().major >= 7 + + def cpu_extra_check(device_interface: Any) -> bool: + import triton.backends + + return "cpu" in triton.backends.backends + + def _return_true(device_interface: Any) -> bool: + return True + + triton_supported_devices = { + "cuda": cuda_extra_check, + "xpu": _return_true, + "cpu": cpu_extra_check, + } + + def is_device_compatible_with_triton() -> bool: + for device, extra_check in triton_supported_devices.items(): + device_interface = get_interface_for_device(device) + if device_interface.is_available() and extra_check(device_interface): + return True + return False + + return is_device_compatible_with_triton() + + +@functools.cache +def triton_backend() -> Any: + from triton.compiler.compiler import make_backend + from triton.runtime.driver import driver + + target = driver.active.get_current_target() + return make_backend(target) + + +@functools.cache +def triton_hash_with_backend() -> str: + from triton.compiler.compiler import triton_key + + backend = triton_backend() + key = f"{triton_key()}-{backend.hash()}" + + # Hash is upper case so that it can't contain any Python keywords. + return hashlib.sha256(key.encode("utf-8")).hexdigest().upper() diff --git a/phivenv/Lib/site-packages/torch/utils/_typing_utils.py b/phivenv/Lib/site-packages/torch/utils/_typing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8677432e476eb7b23e2e46ad7d6d2f8288803d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_typing_utils.py @@ -0,0 +1,14 @@ +"""Miscellaneous utilities to aid with typing.""" + +from typing import Optional, TypeVar + + +# Helper to turn Optional[T] into T when we know None either isn't +# possible or should trigger an exception. +T = TypeVar("T") + + +def not_none(obj: Optional[T]) -> T: + if obj is None: + raise TypeError("Invariant encountered: value was None when it should not be") + return obj diff --git a/phivenv/Lib/site-packages/torch/utils/_zip.py b/phivenv/Lib/site-packages/torch/utils/_zip.py new file mode 100644 index 0000000000000000000000000000000000000000..c29431bd500c0c7ebbd21a9ee961bc7693ee3e73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/_zip.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +import argparse +import glob +import os +from pathlib import Path +from zipfile import ZipFile + + +# Exclude some standard library modules to: +# 1. Slim down the final zipped file size +# 2. Remove functionality we don't want to support. +DENY_LIST = [ + # Interface to unix databases + "dbm", + # ncurses bindings (terminal interfaces) + "curses", + # Tcl/Tk GUI + "tkinter", + "tkinter", + # Tests for the standard library + "test", + "tests", + "idle_test", + "__phello__.foo.py", + # importlib frozen modules. These are already baked into CPython. 
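+    # (these are importlib._bootstrap / _bootstrap_external, compiled into the
+    #  interpreter at build time, so shipping their .py sources only adds size)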
+ "_bootstrap.py", + "_bootstrap_external.py", +] + +strip_file_dir = "" + + +def remove_prefix(text, prefix): + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +def write_to_zip(file_path, strip_file_path, zf, prepend_str=""): + stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/") + path = Path(stripped_file_path) + if path.name in DENY_LIST: + return + zf.write(file_path, stripped_file_path) + + +def main() -> None: + global strip_file_dir + parser = argparse.ArgumentParser(description="Zip py source") + parser.add_argument("paths", nargs="*", help="Paths to zip.") + parser.add_argument( + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--strip-dir", + "--strip_dir", + help="The absolute directory we want to remove from zip", + ) + parser.add_argument( + "--prepend-str", + "--prepend_str", + help="A string to prepend onto all paths of a file in the zip", + default="", + ) + parser.add_argument("--zip-name", "--zip_name", help="Output zip name") + + args = parser.parse_args() + + zip_file_name = args.install_dir + "/" + args.zip_name + strip_file_dir = args.strip_dir + prepend_str = args.prepend_str + zf = ZipFile(zip_file_name, mode="w") + + for p in sorted(args.paths): + if os.path.isdir(p): + files = glob.glob(p + "/**/*.py", recursive=True) + for file_path in sorted(files): + # strip the absolute path + write_to_zip( + file_path, strip_file_dir + "/", zf, prepend_str=prepend_str + ) + else: + write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/phivenv/Lib/site-packages/torch/utils/backcompat/__init__.py b/phivenv/Lib/site-packages/torch/utils/backcompat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b0cd00c238f21527db509e50a058ab8f7d597d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/backcompat/__init__.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +from torch._C import ( + _get_backcompat_broadcast_warn, + _get_backcompat_keepdim_warn, + _set_backcompat_broadcast_warn, + _set_backcompat_keepdim_warn, +) + + +class Warning: + def __init__(self, setter, getter): + self.setter = setter + self.getter = getter + + def set_enabled(self, value): + self.setter(value) + + def get_enabled(self): + return self.getter() + + enabled = property(get_enabled, set_enabled) + + +broadcast_warning = Warning( + _set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn +) +keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn) diff --git a/phivenv/Lib/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17191d79f28811f2c4613056389390069f414b38 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/backcompat/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/backend_registration.py b/phivenv/Lib/site-packages/torch/utils/backend_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..429adba3eca1dc4a5fd75bdc449bafc468b8d251 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/backend_registration.py @@ -0,0 +1,440 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch._C import _get_privateuse1_backend_name, 
_rename_privateuse1_backend +from torch.overrides import handle_torch_function, has_torch_function_unary + + +__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"] + +# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get +# renamed-backend name for `privateuse1`, but the func will cause an +# error with torch.jit.script, so we use the global variable named +# `_privateuse1_backend_name`. +_privateuse1_backend_name = "privateuseone" + + +def rename_privateuse1_backend(backend_name: str) -> None: + r""" + Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs. + + The steps are: + + (1) (In C++) implement kernels for various torch operations, and register them + to the PrivateUse1 dispatch key. + (2) (In python) call torch.utils.rename_privateuse1_backend("foo") + + You can now use "foo" as an ordinary device string in python. + + Note: this API can only be called once per process. Attempting to change + the external backend after it's already been set will result in an error. + + Note(AMP): If you want to support AMP on your device, you can register a custom backend module. + The backend must register a custom backend module with ``torch._register_device_module("foo", BackendModule)``. + BackendModule needs to have the following API's: + + (1) ``get_amp_supported_dtype() -> List[torch.dtype]`` + get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype. + + Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's: + + (1) ``_is_in_bad_fork() -> bool`` + Return ``True`` if now it is in bad_fork, else return ``False``. + + (2) ``manual_seed_all(seed int) -> None`` + Sets the seed for generating random numbers for your devices. + + (3) ``device_count() -> int`` + Returns the number of "foo"s available. + + (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor`` + Returns a list of ByteTensor representing the random number states of all devices. + + (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None`` + Sets the random number generator state of the specified "foo" device. + + And there are some common funcs: + + (1) ``is_available() -> bool`` + Returns a bool indicating if "foo" is currently available. + + (2) ``current_device() -> int`` + Returns the index of a currently selected device. + + For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend + For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example + + Example:: + + >>> # xdoctest: +SKIP("failing") + >>> torch.utils.rename_privateuse1_backend("foo") + # This will work, assuming that you've implemented the right C++ kernels + # to implement torch.ones. 
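+    # (on the C++ side this typically means registering kernels for the
+    # PrivateUse1 dispatch key, e.g. via the TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) macro)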
+ >>> a = torch.ones(2, device="foo") + + """ + _rename_privateuse1_backend(backend_name) + global _privateuse1_backend_name + _privateuse1_backend_name = backend_name + + +def _check_register_once(module, attr): + if hasattr(module, attr): + raise RuntimeError( + f"The custom device module of {module} has already been registered with {attr}" + ) + + +def _normalization_device( + custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None +) -> int: + def _get_current_device_index(): + _get_device_index = "current_device" + if hasattr(torch, custom_backend_name) and hasattr( + getattr(torch, custom_backend_name), _get_device_index + ): + return getattr(getattr(torch, custom_backend_name), _get_device_index)() + else: + # The default device index is 0. + return 0 + + if device is None: + return _get_current_device_index() + # if isinstance(device, str), this means that the parameter passed in is in the string format "foo:0" + # convert str object to torch.device object, and then process it uniformly + elif isinstance(device, str): + device = torch.device(device) + + # variable devcie can only be torch.device type or int type + if isinstance(device, torch.device): + if device.type != custom_backend_name: + raise RuntimeError(f"Invalid device, must be {custom_backend_name} device") + elif device.index is None: + device_idx = _get_current_device_index() + else: + device_idx = device.index + # if isinstance(device, int), we can take the index number directly + else: + device_idx = device + return device_idx + + +def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None: + @property # type: ignore[misc] + def wrap_tensor_backend(self: torch.Tensor) -> bool: + if has_torch_function_unary(self): + # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 + return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined] + return self.device.type == custom_backend_name + + _check_register_once(torch.Tensor, f"is_{custom_backend_name}") + wrap_tensor_backend.fget.__name__ = f"is_{custom_backend_name}" # type: ignore[attr-defined] + setattr(torch.Tensor, f"is_{custom_backend_name}", wrap_tensor_backend) + + def wrap_tensor_to( + self: torch.Tensor, + device: Optional[Union[int, torch.device]] = None, + non_blocking=False, + **kwargs, + ) -> torch.Tensor: + r"""Perform Tensor device conversion. Call the to operator implementation. + + .. note:: + If the ``self`` Tensor already + has the correct :class:`torch.device`, then ``self`` is returned. + Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`. + + Args: + device (int, optional): if specified, all parameters will be copied to that device + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument. 
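+
+        Example (an illustrative sketch; assumes the backend has been renamed to
+        "foo" and working kernels exist for it)::
+
+            >>> # xdoctest: +SKIP("requires a custom backend")
+            >>> torch.utils.rename_privateuse1_backend("foo")
+            >>> torch.utils.generate_methods_for_privateuse1_backend()
+            >>> t = torch.empty(2).foo(0)   # same as .to(device="foo:0")
+            >>> t.is_foo
+            True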
+ """ + if has_torch_function_unary(self): + return handle_torch_function( + wrap_tensor_to, + (self,), + self, + device=device, + non_blocking=False, + **kwargs, + ) + device_idx = _normalization_device(custom_backend_name, device) + return self.to( + device=torch.device(f"{custom_backend_name}:{device_idx}"), + non_blocking=non_blocking, + **kwargs, + ) + + _check_register_once(torch.Tensor, custom_backend_name) + wrap_tensor_to.__name__ = custom_backend_name + setattr(torch.Tensor, custom_backend_name, wrap_tensor_to) + + +def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None: + # Generate Module attributes and methods depends on Tensor methods, + # so we need to check whether Tensor methods is already registered. + if not hasattr(torch.Tensor, custom_backend_name): + raise RuntimeError( + f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module." + f"Because torch.Tensor doesn't has the method {custom_backend_name}()." + f"For this error, you can try setting for_tensor=True." + ) + + def wrap_module_to( + self: torch.nn.modules.module.T, + device: Optional[Union[int, torch.device]] = None, + ) -> torch.nn.modules.module.T: + r"""Move all model parameters and buffers to the custom device. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on device while being optimized. + + .. note:: + This method modifies the module in-place. + + Args: + device (int, optional): if specified, all parameters will be copied to that device + """ + return self._apply(lambda t: getattr(t, custom_backend_name)(device)) + + _check_register_once(torch.nn.Module, custom_backend_name) + setattr(torch.nn.Module, custom_backend_name, wrap_module_to) + + +def _generate_packed_sequence_methods_for_privateuse1_backend( + custom_backend_name: str, +) -> None: + # Generate PackedSequence Module attributes and methods depends on Tensor methods, + # so we need to check whether Tensor methods is already registered. + if not hasattr(torch.Tensor, f"is_{custom_backend_name}") or not hasattr( + torch.Tensor, custom_backend_name + ): + raise RuntimeError( + f"Can not automatically generate is_{custom_backend_name}() or " + f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence." + f"Because torch.Tensor doesn't has the method is_{custom_backend_name}()" + f"or {custom_backend_name}()." + f"For this error, you can try setting for_tensor=True." + ) + + @property # type: ignore[misc] + def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool: + return self.data.device.type == custom_backend_name + + _check_register_once(torch.nn.utils.rnn.PackedSequence, f"is_{custom_backend_name}") + setattr( + torch.nn.utils.rnn.PackedSequence, + f"is_{custom_backend_name}", + wrap_tensor_backend, + ) + + def wrap_module_to( + self: torch.nn.utils.rnn.PackedSequence, *args, **kwargs + ) -> torch.nn.utils.rnn.PackedSequence: + r"""Move all model parameters and buffers to the custom device. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on device while being optimized. + + .. note:: + This method modifies the module in-place. 
+ + Args: + device (int, optional): if specified, all parameters will be copied to that device + """ + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( + *args, **kwargs + ) + if ex.device.type == custom_backend_name: + return self.to(*args, **kwargs) + kwargs.update({"device": custom_backend_name}) + return self.to(*args, **kwargs) + + _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name) + setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to) + + +def _generate_storage_methods_for_privateuse1_backend( + custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None +) -> None: + # Attribute is registered in the _StorageBase class + # and UntypedStorage obtains through inheritance. + @property # type: ignore[misc] + def wrap_storage_backend(self: torch.storage._StorageBase) -> bool: + r"""Return the internal :class:`torch.UntypedStorage`.""" + return self.device.type == custom_backend_name + + _check_register_once(torch.storage._StorageBase, f"is_{custom_backend_name}") + setattr( + torch.storage._StorageBase, f"is_{custom_backend_name}", wrap_storage_backend + ) + + def wrap_storage_to(self, device=None, non_blocking=False): + r"""Return a copy of this object in custom device memory. + + If this object is already in device memory and on the correct device, then + no copy is performed and the original object is returned. + + Args: + device (int): The destination device id. Defaults to the current device. + non_blocking (bool): If ``True`` and the source is in pinned memory, + the copy will be asynchronous with respect to the host. Otherwise, + the argument has no effect. + """ + # There should be a judgment related to storage device and a judgment related to storage type, + # but it depends on the extended function, so this part is temporarily omitted in the automatic generation. + device_idx = _normalization_device(custom_backend_name, device) + + if getattr(self, f"is_{custom_backend_name}"): + # storage has already on expected device. + if self.get_device() == device_idx: + return self + # For sparse storage, custom need to extend the implementation by themselves. + if self.is_sparse: + raise RuntimeError( + f"Can not support a sparse storage move to {custom_backend_name} backend" + ) + # create untyped_storage and copy data + untyped_storage = torch.UntypedStorage( + self.size(), device=torch.device(f"{custom_backend_name}:{device_idx}") + ) + untyped_storage.copy_(self, non_blocking) + return untyped_storage + + _check_register_once(torch.storage._StorageBase, custom_backend_name) + setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to) + + # Register the corresponding attribute for the TypedStorage class. + # When the TypedStorage class is removed, the registration is also removed. 
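+    # Rough sketch of the generated methods in use (assumes a backend named "foo"
+    # with real kernels and memory management behind it):
+    #   s = torch.UntypedStorage(16)   # 16-byte CPU storage
+    #   s.is_foo                       # -> False
+    #   s.foo(0)                       # copies the storage to device "foo:0"
+    # The same is_foo / foo() pair is attached to TypedStorage below.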
+ + @property # type: ignore[misc] + def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool: + torch.storage._warn_typed_storage_removal() + return self._untyped_storage.device.type == custom_backend_name + + _check_register_once(torch.TypedStorage, f"is_{custom_backend_name}") + setattr( + torch.storage.TypedStorage, + f"is_{custom_backend_name}", + wrap_typed_storage_backend, + ) + + def wrap_typed_storage_to( + self: torch.storage.TypedStorage, device=None, non_blocking=False, **kwargs + ) -> torch.storage.TypedStorage: + torch.storage._warn_typed_storage_removal() + if unsupported_dtype and self.dtype in unsupported_dtype: + raise RuntimeError( + f"Cannot create {custom_backend_name} storage " + f"as {self.dtype} dtype is not supported by this backend" + ) + custom_backend_storage: torch.UntypedStorage = getattr( + self._untyped_storage, custom_backend_name + )(device, non_blocking, **kwargs) + return self._new_wrapped_storage(custom_backend_storage) + + _check_register_once(torch.TypedStorage, custom_backend_name) + setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to) + + +def generate_methods_for_privateuse1_backend( + for_tensor: bool = True, + for_module: bool = True, + for_packed_sequence: bool = True, + for_storage: bool = False, + unsupported_dtype: Optional[list[torch.dtype]] = None, +) -> None: + r""" + Automatically generate attributes and methods for the custom backend after rename privateuse1 backend. + + In the default scenario, storage-related methods will not be generated automatically. + + When you implement kernels for various torch operations, and register them to the PrivateUse1 dispatch key. + And call the function torch.rename_privateuse1_backend("foo") to rename your backend name. + At this point, you can easily register specific methods and attributes by calling this function. + Just like torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo(), torch.Storage.is_foo. + + Note: We recommend you use generic functions (check devices are equal or to(device=)). + We provide these methods for convenience only and they will be "monkey patched" onto the objects + and so will not be properly typed. For Storage methods generate, if you need to support sparse data storage, + you need to extend the implementation yourself. + + Args: + for_tensor (bool): whether register related methods for torch.Tensor class. + for_module (bool): whether register related methods for torch.nn.Module class. + for_storage (bool): whether register related methods for torch.Storage class. + unsupported_dtype (List[torch.dtype]): takes effect only when the storage method needs to be generated, + indicating that the storage does not support the torch.dtype type. + + Example:: + + >>> # xdoctest: +SKIP("failing") + >>> torch.utils.rename_privateuse1_backend("foo") + >>> torch.utils.generate_methods_for_privateuse1_backend() + # Then automatically generate backend-related attributes and methods. 
+ >>> a = torch.tensor(2).foo() + >>> a.is_foo + >>> hasattr(torch.nn.Module, 'foo') + """ + custom_backend_name = _get_privateuse1_backend_name() + + if for_tensor: + _generate_tensor_methods_for_privateuse1_backend(custom_backend_name) + + if for_module: + _generate_module_methods_for_privateuse1_backend(custom_backend_name) + + if for_storage: + _generate_storage_methods_for_privateuse1_backend( + custom_backend_name, unsupported_dtype + ) + + if for_packed_sequence: + _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name) + + +def _get_custom_mod_func(func_name: str): + r""" + Return the func named `func_name` defined in custom device module. If not defined, + return `None`. And the func is registered with `torch.utils.rename_privateuse1_backend('foo')` + and `torch._register_device_module('foo', BackendModule)`. + If the custom device module or the func is not defined, it will give warning or error message. + Args: + func_name (str): return the callable func named func_name defined in custom device module. + Example:: + class DummyfooModule: + @staticmethod + def is_available(): + return True + @staticmethod + def func_name(*args, **kwargs): + .... + torch.utils.rename_privateuse1_backend("foo") + torch._register_device_module("foo", DummyfooModule) + foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available") + if foo_is_available_func: + foo_is_available = foo_is_available_func() + func_ = torch.utils.backend_registration._get_custom_mod_func("func_name") + if func_: + result = func_(*args, **kwargs) + Attention: This function is not meant to be used directly by users, which is why + it is marked as private. It is a convenience function for backend implementers to + more easily call the hooks into their backend extensions. + """ + assert isinstance( + func_name, str + ), f"func_name must be `str`, but got `{type(func_name)}`." + backend_name = _get_privateuse1_backend_name() + custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] + function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] + if custom_device_mod is None or function is None: + message = f"Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend " + message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And " + message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. 
\n" + raise RuntimeError(message) + return function diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/__init__.py b/phivenv/Lib/site-packages/torch/utils/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6c4b8102eebc086c399057bf81c3de57d9c632 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/__init__.py @@ -0,0 +1,6 @@ +from torch.utils.benchmark.utils.common import * # noqa: F403 +from torch.utils.benchmark.utils.timer import * # noqa: F403 +from torch.utils.benchmark.utils.compare import * # noqa: F403 +from torch.utils.benchmark.utils.fuzzer import * # noqa: F403 +from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import * # noqa: F403 +from torch.utils.benchmark.utils.sparse_fuzzer import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471f68aae80129cd0978b3b0811d0e905885b28a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__init__.py b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f901e5a178bfb3fafb643bfadf0891233f6dec24 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec3d4ce0937d393dc15be9042ca557192c0319ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/compare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c3b3c04e2e4e80f085e72445198856daaa965b4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47d96a003ebe85387f38ba9126e9853c4de2fd55 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ca9d9e45944cb681e113e3b95eada13c1ee87b33 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c13888e104fd60f179042ac59a8a2bcf8e90377a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/compare.py b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/compare.py new file mode 100644 index 0000000000000000000000000000000000000000..3596ffbbac68f055cc0c34b4aa1ec5a0544e4f47 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/compare.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +"""Example of Timer and Compare APIs: + +$ python -m examples.compare +""" + +import pickle +import sys +import time + +import torch + +import torch.utils.benchmark as benchmark_utils + + +class FauxTorch: + """Emulate different versions of pytorch. + + In normal circumstances this would be done with multiple processes + writing serialized measurements, but this simplifies that model to + make the example clearer. + """ + def __init__(self, real_torch, extra_ns_per_element): + self._real_torch = real_torch + self._extra_ns_per_element = extra_ns_per_element + + def extra_overhead(self, result): + # time.sleep has a ~65 us overhead, so only fake a + # per-element overhead if numel is large enough. 
+ numel = int(result.numel()) + if numel > 5000: + time.sleep(numel * self._extra_ns_per_element * 1e-9) + return result + + def add(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.add(*args, **kwargs)) + + def mul(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.mul(*args, **kwargs)) + + def cat(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.cat(*args, **kwargs)) + + def matmul(self, *args, **kwargs): + return self.extra_overhead(self._real_torch.matmul(*args, **kwargs)) + + +def main(): + tasks = [ + ("add", "add", "torch.add(x, y)"), + ("add", "add (extra +0)", "torch.add(x, y + zero)"), + ] + + serialized_results = [] + repeats = 2 + timers = [ + benchmark_utils.Timer( + stmt=stmt, + globals={ + "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns), + "x": torch.ones((size, 4)), + "y": torch.ones((1, 4)), + "zero": torch.zeros(()), + }, + label=label, + sub_label=sub_label, + description=f"size: {size}", + env=branch, + num_threads=num_threads, + ) + for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)] + for label, sub_label, stmt in tasks + for size in [1, 10, 100, 1000, 10000, 50000] + for num_threads in [1, 4] + ] + + for i, timer in enumerate(timers * repeats): + serialized_results.append(pickle.dumps( + timer.blocked_autorange(min_run_time=0.05) + )) + print(f"\r{i + 1} / {len(timers) * repeats}", end="") + sys.stdout.flush() + print() + + comparison = benchmark_utils.Compare([ + pickle.loads(i) for i in serialized_results + ]) + + print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n") + comparison.print() + + print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n") + comparison.trim_significant_figures() + comparison.colorize() + comparison.print() + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/fuzzer.py b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..1e388b0f922409233d1f652f2aa04ff8b9e58842 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/fuzzer.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +"""Example of the Timer and Fuzzer APIs: + +$ python -m examples.fuzzer +""" + +import sys + +import torch.utils.benchmark as benchmark_utils + + +def main(): + add_fuzzer = benchmark_utils.Fuzzer( + parameters=[ + [ + benchmark_utils.FuzzedParameter( + name=f"k{i}", + minval=16, + maxval=16 * 1024, + distribution="loguniform", + ) for i in range(3) + ], + benchmark_utils.FuzzedParameter( + name="d", + distribution={2: 0.6, 3: 0.4}, + ), + ], + tensors=[ + [ + benchmark_utils.FuzzedTensor( + name=name, + size=("k0", "k1", "k2"), + dim_parameter="d", + probability_contiguous=0.75, + min_elements=64 * 1024, + max_elements=128 * 1024, + ) for name in ("x", "y") + ], + ], + seed=0, + ) + + n = 250 + measurements = [] + for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)): + x, x_order = tensors["x"], str(tensor_properties["x"]["order"]) + y, y_order = tensors["y"], str(tensor_properties["y"]["order"]) + shape = ", ".join(tuple(f'{i:>4}' for i in x.shape)) + + description = "".join([ + f"{x.numel():>7} | {shape:<16} | ", + f"{'contiguous' if x.is_contiguous() else x_order:<12} | ", + f"{'contiguous' if y.is_contiguous() else y_order:<12} | ", + ]) + + timer = benchmark_utils.Timer( + stmt="x + y", + globals=tensors, + description=description, + ) + + 
measurements.append(timer.blocked_autorange(min_run_time=0.1)) + measurements[-1].metadata = {"numel": x.numel()} + print(f"\r{i + 1} / {n}", end="") + sys.stdout.flush() + print() + + # More string munging to make pretty output. + print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") + + def time_fn(m): + return m.median / m.metadata["numel"] + measurements.sort(key=time_fn) + + template = f"{{:>6}}{' ' * 19}Size Shape{' ' * 13}X order Y order\n{'-' * 80}" + print(template.format("Best:")) + for m in measurements[:15]: + print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}") + + print("\n" + template.format("Worst:")) + for m in measurements[-15:]: + print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}") + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/op_benchmark.py b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/op_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..8730aaf6dd8405a5a875f3b7e895866d2bb955dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/op_benchmark.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs +"""Example use of Timer and op fuzzers to measure kernel performance. + +$ python -m examples.op_benchmark +""" + +import numpy as np +import torch + +from torch.utils.benchmark import Timer +from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer +from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer +import operator + + +_MEASURE_TIME = 1.0 + + +def assert_dicts_equal(dict_0, dict_1): + """Builtin dict comparison will not compare numpy arrays. + e.g. + x = {"a": np.ones((2, 1))} + x == x # Raises ValueError + """ + assert set(dict_0.keys()) == set(dict_0.keys()) + assert all(np.all(v == dict_1[k]) for k, v in dict_0.items() if k != "dtype") + + +def run(n, stmt, fuzzer_cls): + float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) + int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n) + raw_results = [] + for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)): + float_tensors, float_tensor_params, float_params = float_values + int_tensors, int_tensor_params, int_params = int_values + + # This benchmark assumes that the two fuzzers generate identically + # sized and strided Tensors, since the same seed is used. 
+ assert_dicts_equal(float_params, int_params) + assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"]) + + float_measurement, int_measurement = ( + Timer( + stmt, + globals=tensors, + ).blocked_autorange(min_run_time=_MEASURE_TIME) + for tensors in (float_tensors, int_tensors) + ) + + descriptions = [] + for name in float_tensors: + shape_str = "(" + ", ".join([ + f"2 ** {int(np.log2(i))}" + if 2 ** int(np.log2(i)) == i and i > 1 + else str(i) + for i in float_tensors[name].shape + ]) + ")" + order = float_tensor_params[name]["order"] + order_str = ("" if all(order == np.arange(len(order))) else str(tuple(order))) + steps = float_tensor_params[name]["steps"] + steps_str = str(steps) if sum(steps) > len(steps) else "" + descriptions.append((name, shape_str, order_str, steps_str)) + raw_results.append((float_measurement, int_measurement, descriptions)) + + print(f"\r{i + 1} / {n}", end="") + print() + + parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0 + for float_measurement, int_measurement, descriptions in raw_results: + t_float = float_measurement.median * 1e6 + t_int = int_measurement.median * 1e6 + rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2 + parsed_results.append((t_float, t_int, rel_diff, descriptions)) + for name, shape, order, steps in descriptions: + name_len = max(name_len, len(name)) + shape_len = max(shape_len, len(shape)) + order_len = max(order_len, len(order)) + steps_len = max(steps_len, len(steps)) + + parsed_results.sort(key=operator.itemgetter(2)) + + print(f"stmt: {stmt}") + print(f" diff faster{'':>17}{' ' * name_len} ", end="") + print(f"{'shape'.ljust(shape_len)}{'':>16}{'order'.ljust(order_len)}", end="") + print(f" steps\n{'-' * 100}") + for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]: + for t_float, t_int, rel_diff, descriptions in results: + time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"] + time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]]) + for t_str, (name, shape, order, steps) in zip(time_str, descriptions): + name = f"{name}:".ljust(name_len + 1) + shape = shape.ljust(shape_len + 10) + order = order.ljust(order_len) + print(f"{t_str} {name} {shape}| {order} | {steps}") + print(spacer) + + +def main(): + run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer) + run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer) + run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer) + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/simple_timeit.py b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/simple_timeit.py new file mode 100644 index 0000000000000000000000000000000000000000..47d7ab4e495e7b2d21ad312ca86aa35a406a9aca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/simple_timeit.py @@ -0,0 +1,25 @@ +"""Trivial use of Timer API: + +$ python -m examples.simple_timeit +""" + +import torch + +import torch.utils.benchmark as benchmark_utils + + +def main() -> None: + timer = benchmark_utils.Timer( + stmt="x + y", + globals={"x": torch.ones((4, 8)), "y": torch.ones((1, 8))}, + label="Broadcasting add (4x8)", + ) + + for i in range(3): + print(f"Run: {i}\n{'-' * 40}") + print(f"timeit:\n{timer.timeit(10000)}\n") + print(f"autorange:\n{timer.blocked_autorange()}\n\n") + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py 
b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ee623c2dcc88fb257e5dbd591a344ec28695557f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +"""Microbenchmarks for the torch.fft module""" +from argparse import ArgumentParser +from collections import namedtuple +from collections.abc import Iterable + +import torch +import torch.fft +from torch.utils import benchmark +from torch.utils.benchmark.op_fuzzers.spectral import SpectralOpFuzzer + + +def _dim_options(ndim): + if ndim == 1: + return [None] + elif ndim == 2: + return [0, 1, None] + elif ndim == 3: + return [0, 1, 2, (0, 1), (0, 2), None] + raise ValueError(f"Expected ndim in range 1-3, got {ndim}") + + +def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, device: str, samples: int, + probability_regular: float): + cuda = device == 'cuda' + spectral_fuzzer = SpectralOpFuzzer(seed=seed, dtype=dtype, cuda=cuda, + probability_regular=probability_regular) + results = [] + for tensors, tensor_params, params in spectral_fuzzer.take(samples): + shape = [params['k0'], params['k1'], params['k2']][:params['ndim']] + str_shape = ' x '.join([f"{s:<4}" for s in shape]) + sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}" + for dim in _dim_options(params['ndim']): + for nthreads in (1, 4, 16) if not cuda else (1,): + measurement = benchmark.Timer( + stmt='func(x, dim=dim)', + globals={'func': function, 'x': tensors['x'], 'dim': dim}, + label=f"{name}_{device}", + sub_label=sub_label, + description=f"dim={dim}", + num_threads=nthreads, + ).blocked_autorange(min_run_time=1) + measurement.metadata = { + 'name': name, + 'device': device, + 'dim': dim, + 'shape': shape, + } + measurement.metadata.update(tensor_params['x']) + results.append(measurement) + return results + + +Benchmark = namedtuple('Benchmark', ['name', 'function', 'dtype']) +BENCHMARKS = [ + Benchmark('fft_real', torch.fft.fftn, torch.float32), + Benchmark('fft_complex', torch.fft.fftn, torch.complex64), + Benchmark('ifft', torch.fft.ifftn, torch.complex64), + Benchmark('rfft', torch.fft.rfftn, torch.float32), + Benchmark('irfft', torch.fft.irfftn, torch.complex64), +] +BENCHMARK_MAP = {b.name: b for b in BENCHMARKS} +BENCHMARK_NAMES = [b.name for b in BENCHMARKS] +DEVICE_NAMES = ['cpu', 'cuda'] + +def _output_csv(file, results): + file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n') + for measurement in results: + metadata = measurement.metadata + device, dim, shape, name, numel, contiguous = ( + metadata['device'], metadata['dim'], metadata['shape'], + metadata['name'], metadata['numel'], metadata['is_contiguous']) + + if isinstance(dim, Iterable): + dim_str = '-'.join(str(d) for d in dim) + else: + dim_str = str(dim) + shape_str = 'x'.join(str(s) for s in shape) + + print(name, device, measurement.task_spec.num_threads, numel, shape_str, contiguous, dim_str, # type: ignore[possibly-undefined] + measurement.mean * 1e6, measurement.median * 1e6, measurement.iqr * 1e6, + sep=',', file=file) + + +if __name__ == '__main__': + parser = ArgumentParser(description=__doc__) + parser.add_argument('--device', type=str, choices=DEVICE_NAMES, nargs='+', default=DEVICE_NAMES) + parser.add_argument('--bench', type=str, choices=BENCHMARK_NAMES, nargs='+', default=BENCHMARK_NAMES) + 
parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--samples', type=int, default=10) + parser.add_argument('--probability-regular', '--probability_regular', type=float, default=1.0) + parser.add_argument('-o', '--output', type=str) + args = parser.parse_args() + + num_benchmarks = len(args.device) * len(args.bench) + i = 0 + results = [] + for device in args.device: + for bench in (BENCHMARK_MAP[b] for b in args.bench): + results += run_benchmark( + name=bench.name, function=bench.function, dtype=bench.dtype, + seed=args.seed, device=device, samples=args.samples, + probability_regular=args.probability_regular) + i += 1 + print(f'Completed {bench.name} benchmark on {device} ({i} of {num_benchmarks})') + + if args.output is not None: + with open(args.output, 'w') as f: + _output_csv(f, results) + + compare = benchmark.Compare(results) + compare.trim_significant_figures() + compare.colorize() + compare.print() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14ba04d7625b60266ce74eea081febee6bb890e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82ca7964564a2acdedacd04f15758bd3fd5d5c62 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98f0279a96969234fb5dab4ee05ed4fbe60b1e0e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c3be47d7c68c8215540edc7c9d88c1b0a7f6673 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c123526a0e5473f7a9269be705d64d3510fb95d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcfb1c763df6be9ab23b6d3a85e9cdfeb02b00ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/binary.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/binary.py new file mode 100644 index 0000000000000000000000000000000000000000..223967c5f6a3556f29b2b3e5049070c8183f3f36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/binary.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class BinaryOpFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False): + super().__init__( + parameters=[ + # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x` and `y`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + # Moreover, `y` will occasionally have singleton + # dimensions in order to test broadcasting. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + + [ + FuzzedParameter( + name=f"y_k{i}", + distribution={ + ParameterAlias(f"k{i}"): 0.8, + 1: 0.2, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x` and `y`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"{name}_step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) + for i in range(3) + for name in ("x", "y") + ], + + # Repeatable entropy for downstream applications. 
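+                # (the drawn value is surfaced in the `params` dict yielded by
+                #  Fuzzer.take(), so callers can use it e.g. as a per-sample seed)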
+ FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + FuzzedTensor( + name="y", + size=("y_k0", "y_k1", "y_k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2d0fc2cb6e80b100ddf53e2a42097b3f564a66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_binary.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class BinaryOpSparseFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False): + super().__init__( + parameters=[ + # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + FuzzedParameter( + name="sparse_dim", + distribution={1: 0.4, 2: 0.4, 3: 0.2}, + strict=True + ), + # Shapes for `x` and `y`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + # Moreover, `y` will occasionally have singleton + # dimensions in order to test broadcasting. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"y_k{i}", + distribution={ + ParameterAlias(f"k{i}"): 1.0}, + strict=True, + ) for i in range(3) + ], + FuzzedParameter( + name="density", + distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3}, + ), + FuzzedParameter( + name="coalesced", + distribution={True: 0.5, False: 0.5}, + ), + # Repeatable entropy for downstream applications. 
+ FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedSparseTensor( + name="x", + size=("k0", "k1", "k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + density="density", + coalesced="coalesced", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + dtype=dtype, + cuda=cuda, + ), + FuzzedSparseTensor( + name="y", + size=("y_k0", "y_k1", "y_k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + density="density", + coalesced="coalesced", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..8b75415c65e33749439bc722cf6c022e9683337a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/sparse_unary.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs + +import numpy as np +import torch +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedSparseTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + +class UnaryOpSparseFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False): + super().__init__( + parameters=[ + # Sparse dim parameter of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim_parameter", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + FuzzedParameter( + name="sparse_dim", + distribution={1: 0.4, 2: 0.4, 3: 0.2}, + strict=True + ), + # Shapes for `x`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. 
/ len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + FuzzedParameter( + name="density", + distribution={0.1: 0.4, 0.05: 0.3, 0.01: 0.3}, + ), + FuzzedParameter( + name="coalesced", + distribution={True: 0.5, False: 0.5}, + ), + FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedSparseTensor( + name="x", + size=("k0", "k1", "k2"), + dim_parameter="dim_parameter", + sparse_dim="sparse_dim", + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + density="density", + coalesced="coalesced", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..ebed54c6e0d98a9eb6612d4b6a9181b61a017b79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/spectral.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +import math + +import torch +from torch.utils import benchmark +from torch.utils.benchmark import FuzzedParameter, FuzzedTensor, ParameterAlias + + +__all__ = ['SpectralOpFuzzer'] + +MIN_DIM_SIZE = 16 +MAX_DIM_SIZE = 16 * 1024 + +def power_range(upper_bound, base): + return (base ** i for i in range(int(math.log(upper_bound, base)) + 1)) + +# List of regular numbers from MIN_DIM_SIZE to MAX_DIM_SIZE +# These numbers factorize into multiples of prime factors 2, 3, and 5 only +# and are usually the fastest in FFT implementations. +REGULAR_SIZES = [] +for i in power_range(MAX_DIM_SIZE, 2): + for j in power_range(MAX_DIM_SIZE // i, 3): + ij = i * j + for k in power_range(MAX_DIM_SIZE // ij, 5): + ijk = ij * k + if ijk > MIN_DIM_SIZE: + REGULAR_SIZES.append(ijk) +REGULAR_SIZES.sort() + +class SpectralOpFuzzer(benchmark.Fuzzer): + def __init__(self, *, seed: int, dtype=torch.float64, + cuda: bool = False, probability_regular: float = 1.0): + super().__init__( + parameters=[ + # Dimensionality of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("ndim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x`. + # It is important to test all shapes, however + # regular sizes are especially important to the FFT and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the regular numbers + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=MIN_DIM_SIZE, + maxval=MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_regular_{i}", + distribution={size: 1. / len(REGULAR_SIZES) for size in REGULAR_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_regular_{i}"): probability_regular, + ParameterAlias(f"k_any_{i}"): 1 - probability_regular, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x`. (Benchmarks strided memory access.) 
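            # (A step of s over-allocates that dimension by a factor of s and then
            #  slices it with stride s; see `FuzzedTensor.steps`. Values greater than
            #  one therefore exercise non-contiguous memory access.)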
+ [ + FuzzedParameter( + name=f"step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) for i in range(3) + ], + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("step_0", "step_1", "step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="ndim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/unary.py b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/unary.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc80ca1df1409bb493084b8d757eacdce94fa3f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/op_fuzzers/unary.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +import numpy as np +import torch + +from torch.utils.benchmark import Fuzzer, FuzzedParameter, ParameterAlias, FuzzedTensor + + +_MIN_DIM_SIZE = 16 +_MAX_DIM_SIZE = 16 * 1024 ** 2 +_POW_TWO_SIZES = tuple(2 ** i for i in range( + int(np.log2(_MIN_DIM_SIZE)), + int(np.log2(_MAX_DIM_SIZE)) + 1, +)) + + +class UnaryOpFuzzer(Fuzzer): + def __init__(self, seed, dtype=torch.float32, cuda=False): + super().__init__( + parameters=[ + # Dimensionality of x. (e.g. 1D, 2D, or 3D.) + FuzzedParameter("dim", distribution={1: 0.3, 2: 0.4, 3: 0.3}, strict=True), + + # Shapes for `x`. + # It is important to test all shapes, however + # powers of two are especially important and therefore + # warrant special attention. This is done by generating + # both a value drawn from all integers between the min and + # max allowed values, and another from only the powers of two + # (both distributions are loguniform) and then randomly + # selecting between the two. + [ + FuzzedParameter( + name=f"k_any_{i}", + minval=_MIN_DIM_SIZE, + maxval=_MAX_DIM_SIZE, + distribution="loguniform", + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k_pow2_{i}", + distribution={size: 1. / len(_POW_TWO_SIZES) for size in _POW_TWO_SIZES} + ) for i in range(3) + ], + [ + FuzzedParameter( + name=f"k{i}", + distribution={ + ParameterAlias(f"k_any_{i}"): 0.8, + ParameterAlias(f"k_pow2_{i}"): 0.2, + }, + strict=True, + ) for i in range(3) + ], + + # Steps for `x`. (Benchmarks strided memory access.) + [ + FuzzedParameter( + name=f"x_step_{i}", + distribution={1: 0.8, 2: 0.06, 4: 0.06, 8: 0.04, 16: 0.04}, + ) for i in range(3) + ], + + # Repeatable entropy for downstream applications. 
+ FuzzedParameter(name="random_value", minval=0, maxval=2 ** 32 - 1, distribution="uniform"), + ], + tensors=[ + FuzzedTensor( + name="x", + size=("k0", "k1", "k2"), + steps=("x_step_0", "x_step_1", "x_step_2"), + probability_contiguous=0.75, + min_elements=4 * 1024, + max_elements=32 * 1024 ** 2, + max_allocation_bytes=2 * 1024**3, # 2 GB + dim_parameter="dim", + dtype=dtype, + cuda=cuda, + ), + ], + seed=seed, + ) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__init__.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ee0edcec88187b07a3993201ef0942746a3963 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b58bc73dab91f3d3141fe4b1397eb07d37bda3c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/_stubs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bac5289948300f1259d9dced9189745097dd7f1e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..276b5bce3d112c16de6af222b62bd8547d24627c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compare.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ebee12d941836349abe80ec5b27ebf4050e5808 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/compile.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3d67bebdadf3f4e9244c1b04508844d7cae385e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ba234a4a25401670c0d10ebead2b57177c43f2f Binary 
files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9116a65664dd68f73a167541a82707e027a1ef5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe8f4351643fd1e7ae7702131cb30dfe4d7153d6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/__pycache__/timer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/_stubs.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/_stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8be0c09df1c81b28283e1eb974eb8c32bbbb6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/_stubs.py @@ -0,0 +1,41 @@ +from typing import Any, Callable +from typing_extensions import Protocol, runtime_checkable + + +class TimerClass(Protocol): + """This is the portion of the `timeit.Timer` API used by benchmark utils.""" + def __init__( + self, + stmt: str, + setup: str, + timer: Callable[[], float], + globals: dict[str, Any], + **kwargs: Any, + ) -> None: + ... + + def timeit(self, number: int) -> float: + ... + + +@runtime_checkable +class TimeitModuleType(Protocol): + """Modules generated from `timeit_template.cpp`.""" + def timeit(self, number: int) -> float: + ... + + +class CallgrindModuleType(Protocol): + """Replicates the valgrind endpoints in `torch._C`. + + These bindings are used to collect Callgrind profiles on earlier versions + of PyTorch and will eventually be removed. + """ + __file__: str + __name__: str + + def _valgrind_supported_platform(self) -> bool: + ... + + def _valgrind_toggle(self) -> None: + ... diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/common.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..34b3a628d40c67d99209181e90d469caa4f3c749 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/common.py @@ -0,0 +1,356 @@ +"""Base shared classes and utilities.""" + +import collections +import contextlib +import dataclasses +import os +import shutil +import tempfile +import textwrap +import time +from typing import cast, Any, Optional +from collections.abc import Iterable, Iterator +import uuid + +import torch + + +__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"] + + +_MAX_SIGNIFICANT_FIGURES = 4 +_MIN_CONFIDENCE_INTERVAL = 25e-9 # 25 ns + +# Measurement will include a warning if the distribution is suspect. All +# runs are expected to have some variation; these parameters set the +# thresholds. +_IQR_WARN_THRESHOLD = 0.1 +_IQR_GROSS_WARN_THRESHOLD = 0.25 + + +@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True) +class TaskSpec: + """Container for information used to define a Timer. 
(except globals)""" + stmt: str + setup: str + global_setup: str = "" + label: Optional[str] = None + sub_label: Optional[str] = None + description: Optional[str] = None + env: Optional[str] = None + num_threads: int = 1 + + @property + def title(self) -> str: + """Best effort attempt at a string label for the measurement.""" + if self.label is not None: + return self.label + (f": {self.sub_label}" if self.sub_label else "") + elif "\n" not in self.stmt: + return self.stmt + (f": {self.sub_label}" if self.sub_label else "") + return ( + f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n" + f"{textwrap.indent(self.stmt, ' ')}" + ) + + def setup_str(self) -> str: + return ( + "" if (self.setup == "pass" or not self.setup) + else f"setup:\n{textwrap.indent(self.setup, ' ')}" if "\n" in self.setup + else f"setup: {self.setup}" + ) + + def summarize(self) -> str: + """Build TaskSpec portion of repr string for other containers.""" + sections = [ + self.title, + self.description or "", + self.setup_str(), + ] + return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i]) + +_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec)) + + +@dataclasses.dataclass(init=True, repr=False) +class Measurement: + """The result of a Timer measurement. + + This class stores one or more measurements of a given statement. It is + serializable and provides several convenience methods + (including a detailed __repr__) for downstream consumers. + """ + number_per_run: int + raw_times: list[float] + task_spec: TaskSpec + metadata: Optional[dict[Any, Any]] = None # Reserved for user payloads. + + def __post_init__(self) -> None: + self._sorted_times: tuple[float, ...] = () + self._warnings: tuple[str, ...] = () + self._median: float = -1.0 + self._mean: float = -1.0 + self._p25: float = -1.0 + self._p75: float = -1.0 + + def __getattr__(self, name: str) -> Any: + # Forward TaskSpec fields for convenience. + if name in _TASKSPEC_FIELDS: + return getattr(self.task_spec, name) + return super().__getattribute__(name) + + # ========================================================================= + # == Convenience methods for statistics =================================== + # ========================================================================= + # + # These methods use raw time divided by number_per_run; this is an + # extrapolation and hides the fact that different number_per_run will + # result in different amortization of overheads, however if Timer has + # selected an appropriate number_per_run then this is a non-issue, and + # forcing users to handle that division would result in a poor experience. + @property + def times(self) -> list[float]: + return [t / self.number_per_run for t in self.raw_times] + + @property + def median(self) -> float: + self._lazy_init() + return self._median + + @property + def mean(self) -> float: + self._lazy_init() + return self._mean + + @property + def iqr(self) -> float: + self._lazy_init() + return self._p75 - self._p25 + + @property + def significant_figures(self) -> int: + """Approximate significant figure estimate. + + This property is intended to give a convenient way to estimate the + precision of a measurement. It only uses the interquartile region to + estimate statistics to try to mitigate skew from the tails, and + uses a static z value of 1.645 since it is not expected to be used + for small values of `n`, so z can approximate `t`. 
+ + The significant figure estimation used in conjunction with the + `trim_sigfig` method to provide a more human interpretable data + summary. __repr__ does not use this method; it simply displays raw + values. Significant figure estimation is intended for `Compare`. + """ + self._lazy_init() + n_total = len(self._sorted_times) + lower_bound = int(n_total // 4) + upper_bound = int(torch.tensor(3 * n_total / 4).ceil()) + interquartile_points: tuple[float, ...] = self._sorted_times[lower_bound:upper_bound] + std = torch.tensor(interquartile_points).std(unbiased=False).item() + sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item() + + # Rough estimates. These are by no means statistically rigorous. + confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL) + relative_ci = torch.tensor(self._median / confidence_interval).log10().item() + num_significant_figures = int(torch.tensor(relative_ci).floor()) + return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES) + + @property + def has_warnings(self) -> bool: + self._lazy_init() + return bool(self._warnings) + + def _lazy_init(self) -> None: + if self.raw_times and not self._sorted_times: + self._sorted_times = tuple(sorted(self.times)) + _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64) + self._median = _sorted_times.quantile(.5).item() + self._mean = _sorted_times.mean().item() + self._p25 = _sorted_times.quantile(.25).item() + self._p75 = _sorted_times.quantile(.75).item() + + def add_warning(msg: str) -> None: + rel_iqr = self.iqr / self.median * 100 + self._warnings += ( + f" WARNING: Interquartile range is {rel_iqr:.1f}% " + f"of the median measurement.\n {msg}", + ) + + if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD): + add_warning("This suggests significant environmental influence.") + elif not self.meets_confidence(_IQR_WARN_THRESHOLD): + add_warning("This could indicate system fluctuation.") + + + def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool: + return self.iqr / self.median < threshold + + @property + def title(self) -> str: + return self.task_spec.title + + @property + def env(self) -> str: + return ( + "Unspecified env" if self.taskspec.env is None + else cast(str, self.taskspec.env) + ) + + @property + def as_row_name(self) -> str: + return self.sub_label or self.stmt or "[Unknown]" + + def __repr__(self) -> str: + """ + Example repr: + + Broadcasting add (4x8) + Median: 5.73 us + IQR: 2.25 us (4.01 to 6.26) + 372 measurements, 100 runs per measurement, 1 thread + WARNING: Interquartile range is 39.4% of the median measurement. + This suggests significant environmental influence. 
+ """ + self._lazy_init() + skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n" + n = len(self._sorted_times) + time_unit, time_scale = select_unit(self._median) + iqr_filter = '' if n >= 4 else skip_line + + repr_str = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit} + {iqr_filter}IQR: {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f}) + {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''} +{newline.join(self._warnings)}""".strip() # noqa: B950 + + return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l) + + @staticmethod + def merge(measurements: Iterable["Measurement"]) -> list["Measurement"]: + """Convenience method for merging replicates. + + Merge will extrapolate times to `number_per_run=1` and will not + transfer any metadata. (Since it might differ between replicates) + """ + grouped_measurements: collections.defaultdict[TaskSpec, list[Measurement]] = collections.defaultdict(list) + for m in measurements: + grouped_measurements[m.task_spec].append(m) + + def merge_group(task_spec: TaskSpec, group: list["Measurement"]) -> "Measurement": + times: list[float] = [] + for m in group: + # Different measurements could have different `number_per_run`, + # so we call `.times` which normalizes the results. + times.extend(m.times) + + return Measurement( + number_per_run=1, + raw_times=times, + task_spec=task_spec, + metadata=None, + ) + + return [merge_group(t, g) for t, g in grouped_measurements.items()] + + +def select_unit(t: float) -> tuple[str, float]: + """Determine how to scale times for O(1) magnitude. + + This utility is used to format numbers for human consumption. + """ + time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(torch.tensor(t).log10().item() // 3), "s") + time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit] + return time_unit, time_scale + + +def unit_to_english(u: str) -> str: + return { + "ns": "nanosecond", + "us": "microsecond", + "ms": "millisecond", + "s": "second", + }[u] + + +def trim_sigfig(x: float, n: int) -> float: + """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)""" + assert n == int(n) + magnitude = int(torch.tensor(x).abs().log10().ceil().item()) + scale = 10 ** (magnitude - n) + return float(torch.tensor(x / scale).round() * scale) + + +def ordered_unique(elements: Iterable[Any]) -> list[Any]: + return list(collections.OrderedDict(dict.fromkeys(elements)).keys()) + + +@contextlib.contextmanager +def set_torch_threads(n: int) -> Iterator[None]: + prior_num_threads = torch.get_num_threads() + try: + torch.set_num_threads(n) + yield + finally: + torch.set_num_threads(prior_num_threads) + + +def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str: + """Create a temporary directory. The caller is responsible for cleanup. + + This function is conceptually similar to `tempfile.mkdtemp`, but with + the key additional feature that it will use shared memory if the + `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an + implementation detail, but an important one for cases where many Callgrind + measurements are collected at once. (Such as when collecting + microbenchmarks.) + + This is an internal utility, and is exported solely so that microbenchmarks + can reuse the util. 
+ """ + use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true") + if use_dev_shm: + root = "/dev/shm/pytorch_benchmark_utils" + assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}" + assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)." + os.makedirs(root, exist_ok=True) + + # Because we're working in shared memory, it is more important than + # usual to clean up ALL intermediate files. However we don't want every + # worker to walk over all outstanding directories, so instead we only + # check when we are sure that it won't lead to contention. + if gc_dev_shm: + for i in os.listdir(root): + owner_file = os.path.join(root, i, "owner.pid") + if not os.path.exists(owner_file): + continue + + with open(owner_file) as f: + owner_pid = int(f.read()) + + if owner_pid == os.getpid(): + continue + + try: + # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python + os.kill(owner_pid, 0) + + except OSError: + print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.") + shutil.rmtree(os.path.join(root, i)) + + else: + root = tempfile.gettempdir() + + # We include the time so names sort by creation time, and add a UUID + # to ensure we don't collide. + name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}" + path = os.path.join(root, name) + os.makedirs(path, exist_ok=False) + + if use_dev_shm: + with open(os.path.join(path, "owner.pid"), "w") as f: + f.write(str(os.getpid())) + + return path diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compare.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compare.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fd5af8b2cbf506fc9ceacb03ff490486bd909c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compare.py @@ -0,0 +1,345 @@ +# mypy: allow-untyped-defs +"""Display class to aggregate and print the results of many measurements.""" +import collections +import enum +import itertools as it +from typing import Optional + +from torch.utils.benchmark.utils import common +from torch import tensor as _tensor +import operator + +__all__ = ["Colorize", "Compare"] + +BEST = "\033[92m" +GOOD = "\033[34m" +BAD = "\033[2m\033[91m" +VERY_BAD = "\033[31m" +BOLD = "\033[1m" +TERMINATE = "\033[0m" + + +class Colorize(enum.Enum): + NONE = "none" + COLUMNWISE = "columnwise" + ROWWISE = "rowwise" + + +# Classes to separate internal bookkeeping from what is rendered. 
+class _Column: + def __init__( + self, + grouped_results: list[tuple[Optional[common.Measurement], ...]], + time_scale: float, + time_unit: str, + trim_significant_figures: bool, + highlight_warnings: bool, + ): + self._grouped_results = grouped_results + self._flat_results = [*it.chain.from_iterable(grouped_results)] + self._time_scale = time_scale + self._time_unit = time_unit + self._trim_significant_figures = trim_significant_figures + self._highlight_warnings = ( + highlight_warnings + and any(r.has_warnings for r in self._flat_results if r) + ) + leading_digits = [ + int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None + for r in self._flat_results + ] + unit_digits = max(d for d in leading_digits if d is not None) + decimal_digits = min( + max(m.significant_figures - digits, 0) + for digits, m in zip(leading_digits, self._flat_results) + if (m is not None) and (digits is not None) + ) if self._trim_significant_figures else 1 + length = unit_digits + decimal_digits + (1 if decimal_digits else 0) + self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}" + + def get_results_for(self, group): + return self._grouped_results[group] + + def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]): + if value is None: + return " " * len(self.num_to_str(1, estimated_sigfigs, None)) + + if self._trim_significant_figures: + value = common.trim_sigfig(value, estimated_sigfigs) + + return self._template.format( + value, + f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "") + + +def optional_min(seq): + l = list(seq) + return None if len(l) == 0 else min(l) + + +class _Row: + def __init__(self, results, row_group, render_env, env_str_len, + row_name_str_len, time_scale, colorize, num_threads=None): + super().__init__() + self._results = results + self._row_group = row_group + self._render_env = render_env + self._env_str_len = env_str_len + self._row_name_str_len = row_name_str_len + self._time_scale = time_scale + self._colorize = colorize + self._columns: tuple[_Column, ...] 
= () + self._num_threads = num_threads + + def register_columns(self, columns: tuple[_Column, ...]): + self._columns = columns + + def as_column_strings(self): + concrete_results = [r for r in self._results if r is not None] + env = f"({concrete_results[0].env})" if self._render_env else "" + env = env.ljust(self._env_str_len + 4) + output = [" " + env + concrete_results[0].as_row_name] + for m, col in zip(self._results, self._columns or ()): + if m is None: + output.append(col.num_to_str(None, 1, None)) + else: + output.append(col.num_to_str( + m.median / self._time_scale, + m.significant_figures, + m.iqr / m.median if m.has_warnings else None + )) + return output + + @staticmethod + def color_segment(segment, value, best_value): + if value <= best_value * 1.01 or value <= best_value + 100e-9: + return BEST + BOLD + segment + TERMINATE * 2 + if value <= best_value * 1.1: + return GOOD + BOLD + segment + TERMINATE * 2 + if value >= best_value * 5: + return VERY_BAD + BOLD + segment + TERMINATE * 2 + if value >= best_value * 2: + return BAD + segment + TERMINATE * 2 + + return segment + + def row_separator(self, overall_width): + return ( + [f"{self._num_threads} threads: ".ljust(overall_width, "-")] + if self._num_threads is not None else [] + ) + + def finalize_column_strings(self, column_strings, col_widths): + best_values = [-1 for _ in column_strings] + if self._colorize == Colorize.ROWWISE: + row_min = min(r.median for r in self._results if r is not None) + best_values = [row_min for _ in column_strings] + elif self._colorize == Colorize.COLUMNWISE: + best_values = [ + optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None) + for column in (self._columns or ()) + ] + + row_contents = [column_strings[0].ljust(col_widths[0])] + for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values): + col_str = col_str.center(width) + if self._colorize != Colorize.NONE and result is not None and best_value is not None: + col_str = self.color_segment(col_str, result.median, best_value) + row_contents.append(col_str) + return row_contents + + +class Table: + def __init__( + self, + results: list[common.Measurement], + colorize: Colorize, + trim_significant_figures: bool, + highlight_warnings: bool + ): + assert len({r.label for r in results}) == 1 + + self.results = results + self._colorize = colorize + self._trim_significant_figures = trim_significant_figures + self._highlight_warnings = highlight_warnings + self.label = results[0].label + self.time_unit, self.time_scale = common.select_unit( + min(r.median for r in results) + ) + + self.row_keys = common.ordered_unique([self.row_fn(i) for i in results]) + self.row_keys.sort(key=operator.itemgetter(slice(2))) # preserve stmt order + self.column_keys = common.ordered_unique([self.col_fn(i) for i in results]) + self.rows, self.columns = self.populate_rows_and_columns() + + @staticmethod + def row_fn(m: common.Measurement) -> tuple[int, Optional[str], str]: + return m.num_threads, m.env, m.as_row_name + + @staticmethod + def col_fn(m: common.Measurement) -> Optional[str]: + return m.description + + def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, ...]]: + rows: list[_Row] = [] + columns: list[_Column] = [] + ordered_results: list[list[Optional[common.Measurement]]] = [ + [None for _ in self.column_keys] + for _ in self.row_keys + ] + row_position = {key: i for i, key in enumerate(self.row_keys)} + col_position = {key: i for i, key in 
enumerate(self.column_keys)} + for r in self.results: + i = row_position[self.row_fn(r)] + j = col_position[self.col_fn(r)] + ordered_results[i][j] = r + + unique_envs = {r.env for r in self.results} + render_env = len(unique_envs) > 1 + env_str_len = max(len(i) for i in unique_envs) if render_env else 0 + + row_name_str_len = max(len(r.as_row_name) for r in self.results) + + prior_num_threads = -1 + prior_env = "" + row_group = -1 + rows_by_group: list[list[list[Optional[common.Measurement]]]] = [] + for (num_threads, env, _), row in zip(self.row_keys, ordered_results): + thread_transition = (num_threads != prior_num_threads) + if thread_transition: + prior_num_threads = num_threads + prior_env = "" + row_group += 1 + rows_by_group.append([]) + rows.append( + _Row( + results=row, + row_group=row_group, + render_env=(render_env and env != prior_env), + env_str_len=env_str_len, + row_name_str_len=row_name_str_len, + time_scale=self.time_scale, + colorize=self._colorize, + num_threads=num_threads if thread_transition else None, + ) + ) + rows_by_group[-1].append(row) + prior_env = env + + for i in range(len(self.column_keys)): + grouped_results = [tuple(row[i] for row in g) for g in rows_by_group] + column = _Column( + grouped_results=grouped_results, + time_scale=self.time_scale, + time_unit=self.time_unit, + trim_significant_figures=self._trim_significant_figures, + highlight_warnings=self._highlight_warnings,) + columns.append(column) + + rows_tuple, columns_tuple = tuple(rows), tuple(columns) + for ri in rows_tuple: + ri.register_columns(columns_tuple) + return rows_tuple, columns_tuple + + def render(self) -> str: + string_rows = [[""] + self.column_keys] + string_rows.extend(r.as_column_strings() for r in self.rows) + num_cols = max(len(i) for i in string_rows) + for sr in string_rows: + sr.extend(["" for _ in range(num_cols - len(sr))]) + + col_widths = [max(len(j) for j in i) for i in zip(*string_rows)] + finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))] + overall_width = len(finalized_columns[0]) + for string_row, row in zip(string_rows[1:], self.rows): + finalized_columns.extend(row.row_separator(overall_width)) + finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths))) + + newline = "\n" + has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results) + return f""" +[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}] +{newline.join(finalized_columns)} + +Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}). +{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:] + + +class Compare: + """Helper class for displaying the results of many measurements in a + formatted table. + + The table format is based on the information fields provided in + :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`, + `num_threads`, etc). + + The table can be directly printed using :meth:`print` or casted as a `str`. + + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + Args: + results: List of Measurment to display. 
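    Example (illustrative; shapes and labels are arbitrary)::

        from torch.utils.benchmark import Timer, Compare

        results = []
        for n in (64, 1024):
            for num_threads in (1, 4):
                timer = Timer(
                    stmt="x + y",
                    setup=f"x = torch.ones(({n}, {n})); y = torch.ones(({n}, 1))",
                    label="Broadcasting add",
                    sub_label=f"{n} x {n}",
                    description="eager",
                    num_threads=num_threads,
                )
                results.append(timer.blocked_autorange(min_run_time=0.2))

        compare = Compare(results)
        compare.trim_significant_figures()
        compare.colorize()
        compare.print()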
+ """ + def __init__(self, results: list[common.Measurement]): + self._results: list[common.Measurement] = [] + self.extend_results(results) + self._trim_significant_figures = False + self._colorize = Colorize.NONE + self._highlight_warnings = False + + def __str__(self): + return "\n".join(self._render()) + + def extend_results(self, results): + """Append results to already stored ones. + + All added results must be instances of ``Measurement``. + """ + for r in results: + if not isinstance(r, common.Measurement): + raise ValueError( + "Expected an instance of `Measurement`, " f"got {type(r)} instead." + ) + self._results.extend(results) + + def trim_significant_figures(self): + """Enables trimming of significant figures when building the formatted table.""" + self._trim_significant_figures = True + + def colorize(self, rowwise=False): + """Colorize formatted table. + + Colorize columnwise by default. + """ + self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE + + def highlight_warnings(self): + """Enables warning highlighting when building formatted table.""" + self._highlight_warnings = True + + def print(self): + """Print formatted table""" + print(str(self)) + + def _render(self): + results = common.Measurement.merge(self._results) + grouped_results = self._group_by_label(results) + output = [self._layout(group) for group in grouped_results.values()] + return output + + def _group_by_label(self, results: list[common.Measurement]): + grouped_results: collections.defaultdict[str, list[common.Measurement]] = collections.defaultdict(list) + for r in results: + grouped_results[r.label].append(r) + return grouped_results + + def _layout(self, results: list[common.Measurement]): + table = Table( + results, + self._colorize, + self._trim_significant_figures, + self._highlight_warnings + ) + return table.render() diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compile.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compile.py new file mode 100644 index 0000000000000000000000000000000000000000..8320c007bd907435e67fe1e1c7868e7eb9a114b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/compile.py @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch._dynamo +from torch._dynamo.testing import CompileCounterWithBackend +from torch.utils.benchmark import Timer + + +__all__ = ["bench_all", "benchmark_compile"] + + +_warned_tensor_cores = False +_default_float_32_precision = torch.get_float32_matmul_precision() + +try: + from tabulate import tabulate + + HAS_TABULATE = True +except ModuleNotFoundError: + HAS_TABULATE = False + tabulate = None # type: ignore[assignment] + print("tabulate is not installed, please pip install tabulate to use this utility") + +if HAS_TABULATE: + def _enable_tensor_cores(): + global _warned_tensor_cores + + if torch.cuda.is_available(): + if torch.backends.cuda.matmul.allow_tf32 is False and torch.cuda.get_device_capability() >= (8, 0): + torch.set_float32_matmul_precision("high") + if not _warned_tensor_cores: + print("Your GPU supports tensor cores") + print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`") + _warned_tensor_cores = True + + def _disable_tensor_cores(): + torch.set_float32_matmul_precision(_default_float_32_precision) + + def bench_loop( + model: Union[torch.nn.Module, Callable], + sample_input: Union[torch.Tensor, Any], + num_iters: int = 5, + optimizer: 
Optional[torch.optim.Optimizer] = None, + loss_fn: Optional[Callable] = None, + ): + # Define the statement and setup for the benchmark + if optimizer and loss_fn: + # Training mode + stmt = """ + output = model(sample_input) + loss = loss_fn(output) if loss_fn else output.sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + """ + else: + # Inference mode + stmt = "model(sample_input)" + + # Create the Timer object + timer = Timer( + stmt=stmt, + globals={"model": model, "sample_input": sample_input, "optimizer": optimizer, "loss_fn": loss_fn}, + ) + + + result = timer.timeit(number=num_iters) + + # Get the average time per iteration in milliseconds + avg_time = result.mean * 1000 + return round(avg_time, 2) + + def benchmark_compile( + model: Union[torch.nn.Module, Callable], + sample_input: Union[torch.Tensor, Any], + num_iters: int = 5, + backend: Optional[str] = None, + mode: Optional[str] = "default", + optimizer: Optional[torch.optim.Optimizer] = None, + loss_fn : Union[torch.nn.Module, Callable, None] = None, + ): + """ + Use this utility to benchmark torch.compile + """ + if backend: + try: + torch._dynamo.reset() + compile_counter_with_backend = CompileCounterWithBackend(backend) + opt_model = torch.compile(model, backend=compile_counter_with_backend, mode=mode) + + # Compilation only happens after the first inference + compilation_time = bench_loop(opt_model, sample_input, 1, optimizer, loss_fn) + + running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn) + + if compile_counter_with_backend.frame_count == 0: + raise RuntimeError("No compilation occurred during benchmarking.") + + if compile_counter_with_backend.frame_count > 1: + raise RuntimeError("Recompilation occurred during benchmarking.") + + except Exception as e: + print(e) + print(f"Failed to compile {backend} with mode {mode}") + return None, None + else: + opt_model = model + compilation_time = None + running_time = bench_loop(opt_model, sample_input, num_iters, optimizer, loss_fn) + + compilation_time = round(compilation_time, 2) if compilation_time else None + running_time = round(running_time, 2) if running_time else None + + + return compilation_time, running_time + + + def bench_all( + model : Union[torch.nn.Module, Callable], + sample_input: Union[torch.Tensor, Any], + num_iters : int = 5, + optimizer: Optional[torch.optim.Optimizer] = None, + loss_fn : Union[torch.nn.Module, Callable, None] = None, + ): + """ + This is a simple utility that can be used to benchmark torch.compile + In particular it ensures that your GPU is setup to use tensor cores if it supports its + It also tries out all the main backends and prints a table of results so you can easily compare them all + Many of the backendds have their own optional dependencies so please pip install them seperately + + You will get one table for inference and another for training + If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer + + The important warnings are + Your GPU supports tensor cores + we will enable it automatically by setting `torch.set_float32_matmul_precision('high')` + + If a compilation fails for any reason including the dependency not being included + then we will print Failed to compile {backend} with mode {mode} + """ + field_names = ["Train/Inference", "Backend", "Mode", "Compilation Time", "Average Running Time"] + table = [] + + + eager_time = None + torch._dynamo.reset() + _, eager_time = benchmark_compile(model, sample_input, num_iters, None, None, 
optimizer) + table.append( + [("Training" if optimizer else "Inference"), "Eager", "-", "-", f"{eager_time} ms"] + ) + + for backend in torch._dynamo.list_backends(): + + if backend == "inductor": + mode_options = cast(list[Optional[str]], list(torch._inductor.list_mode_options().keys())) + [None] + for mode in mode_options: + if mode == "default": + continue + torch._dynamo.reset() + try: + if torch.cuda.is_available(): + _enable_tensor_cores() + compilation_time, running_time = benchmark_compile( + model, sample_input, num_iters, backend, mode, optimizer, loss_fn) + finally: + if torch.cuda.is_available(): + _disable_tensor_cores() + table.append([ + ("Training" if optimizer else "Inference"), + backend if backend else "-", + mode if mode is not None else "-", + f"{compilation_time} ms " if compilation_time else "-", + f"{running_time} ms " if running_time else "-", + ]) + + else: + torch._dynamo.reset() + compilation_time, running_time = benchmark_compile( + model, sample_input, num_iters, backend, None, optimizer, loss_fn) + + if running_time is not None: + table.append([ + ("Training" if optimizer else "Inference"), + backend, "-", + f"{compilation_time} ms " or "-", + f"{running_time} ms ", + ]) + + + return tabulate(table, headers=field_names, tablefmt="github") diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/cpp_jit.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/cpp_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..00085de84c008da93eb5d387cb602f2f8144e538 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/cpp_jit.py @@ -0,0 +1,172 @@ +"""JIT C++ strings into executables.""" +import atexit +import os +import re +import shutil +import textwrap +import threading +from typing import Any, Optional + +import torch +from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType +from torch.utils.benchmark.utils.common import _make_temp_dir +from torch.utils import cpp_extension + + +LOCK = threading.Lock() +SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0] + +# We calculate uuid once at import time so that separate processes will have +# separate build roots, but threads will share the same build root. +# `cpp_extension` uses build root as part of the cache key, so per-invocation +# uuid's (e.g. different build root per _compile_template call) would lead to +# a 0% cache hit rate and spurious recompilation. Consider the following: +# ``` +# setup = "auto x = torch::ones({1024, 1024});" +# stmt = "torch::mm(x, x);" +# for num_threads in [1, 2, 4, 8]: +# print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange()) +# ```` +# `setup` and `stmt` do not change, so we can reuse the executable from the +# first pass through the loop. +_BUILD_ROOT: Optional[str] = None + +def _get_build_root() -> str: + global _BUILD_ROOT + if _BUILD_ROOT is None: + _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build") + atexit.register(shutil.rmtree, _BUILD_ROOT) + return _BUILD_ROOT + + +# BACK_TESTING_NOTE: +# There are two workflows where this code could be used. One is the obvious +# case where someone simply builds or installs PyTorch and uses Timer. +# The other is that the entire `torch/utils/benchmark` folder from a CURRENT +# PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch +# source code. This is what we refer to here as "back testing". 
The rationale +# is that we might want to use current tooling to study some aspect of an +# earlier version of PyTorch. (e.g. a regression.) +# +# The problem is that Timer relies on several aspects of core PyTorch, namely +# some binding functions for Valgrind symbols in `torch._C` and the +# `torch.__config__._cxx_flags()` method. If we were to naively copy code +# around this wouldn't work as the symbols of interest aren't present in +# earlier versions of PyTorch. In order to work around this, we must add back +# testing shims. These shims will never activate during normal use, but will +# allow Timer to function outside of the "correct" version of PyTorch by +# emulating functionality that was added later. +# +# These shims are temporary, and as Timer becomes more integrated with +# PyTorch the cost and complexity of such shims will increase. Once back +# testing is no longer required (which is to say we have done enough historic +# analysis and the shims no longer justify their maintenance and code +# complexity costs) back testing paths will be removed. + +CXX_FLAGS: Optional[list[str]] +if hasattr(torch.__config__, "_cxx_flags"): + try: + CXX_FLAGS = torch.__config__._cxx_flags().strip().split() + if CXX_FLAGS is not None and "-g" not in CXX_FLAGS: + CXX_FLAGS.append("-g") + # remove "-W" flags to allow build benchmarks + # with a relaxed constraint of compiler versions + if CXX_FLAGS is not None: + CXX_FLAGS = list(filter(lambda x: not x.startswith("-W"), CXX_FLAGS)) + + except RuntimeError: + # We are in FBCode. + CXX_FLAGS = None +else: + # FIXME: Remove when back testing is no longer required. + CXX_FLAGS = ["-O2", "-fPIC", "-g"] + +EXTRA_INCLUDE_PATHS: list[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")] +CONDA_PREFIX = os.getenv("CONDA_PREFIX") +if CONDA_PREFIX is not None: + # Load will automatically search /usr/include, but not conda include. + EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include")) + + +COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None +def get_compat_bindings() -> CallgrindModuleType: + with LOCK: + global COMPAT_CALLGRIND_BINDINGS + if COMPAT_CALLGRIND_BINDINGS is None: + COMPAT_CALLGRIND_BINDINGS = cpp_extension.load( + name="callgrind_bindings", + sources=[os.path.join( + SOURCE_ROOT, + "valgrind_wrapper", + "compat_bindings.cpp" + )], + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + ) + return COMPAT_CALLGRIND_BINDINGS + + +def _compile_template( + *, + stmt: str, + setup: str, + global_setup: str, + src: str, + is_standalone: bool +) -> Any: + for before, after, indentation in ( + ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0), + ("// SETUP_TEMPLATE_LOCATION", setup, 4), + ("// STMT_TEMPLATE_LOCATION", stmt, 8) + ): + # C++ doesn't care about indentation so this code isn't load + # bearing the way it is with Python, but this makes the source + # look nicer if a human has to look at it. + src = re.sub( + before, + textwrap.indent(after, " " * indentation)[indentation:], + src + ) + + # We want to isolate different Timers. However `cpp_extension` will + # cache builds which will significantly reduce the cost of repeated + # invocations. + with LOCK: + name = f"timer_cpp_{abs(hash(src))}" + build_dir = os.path.join(_get_build_root(), name) + os.makedirs(build_dir, exist_ok=True) + + src_path = os.path.join(build_dir, "timer_src.cpp") + with open(src_path, "w") as f: + f.write(src) + + # `cpp_extension` has its own locking scheme, so we don't need our lock. 
+ return cpp_extension.load( + name=name, + sources=[src_path], + build_directory=build_dir, + extra_cflags=CXX_FLAGS, + extra_include_paths=EXTRA_INCLUDE_PATHS, + is_python_module=not is_standalone, + is_standalone=is_standalone, + ) + + +def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType: + template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp") + with open(template_path) as f: + src: str = f.read() + + module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False) + assert isinstance(module, TimeitModuleType) + return module + + +def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str: + template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp") + with open(template_path) as f: + src: str = f.read() + + target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True) + assert isinstance(target, str) + return target diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/fuzzer.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..208c83d22c1aaf62b7af46dc03ced9425e073c6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/fuzzer.py @@ -0,0 +1,462 @@ +# mypy: allow-untyped-defs +import functools +import itertools as it +from typing import Any, Callable, Optional, Union + +import torch + + +__all__ = [ + "Fuzzer", + "FuzzedParameter", "ParameterAlias", + "FuzzedTensor", +] + + +_DISTRIBUTIONS = ( + "loguniform", + "uniform", +) + + +class FuzzedParameter: + """Specification for a parameter to be generated during fuzzing.""" + def __init__( + self, + name: str, + minval: Optional[Union[int, float]] = None, + maxval: Optional[Union[int, float]] = None, + distribution: Optional[Union[str, dict[Any, float]]] = None, + strict: bool = False, + ): + """ + Args: + name: + A string name with which to identify the parameter. + FuzzedTensors can reference this string in their + specifications. + minval: + The lower bound for the generated value. See the description + of `distribution` for type behavior. + maxval: + The upper bound for the generated value. Type behavior is + identical to `minval`. + distribution: + Specifies the distribution from which this parameter should + be drawn. There are three possibilities: + - "loguniform" + Samples between `minval` and `maxval` (inclusive) such + that the probabilities are uniform in log space. As a + concrete example, if minval=1 and maxval=100, a sample + is as likely to fall in [1, 10) as it is [10, 100]. + - "uniform" + Samples are chosen with uniform probability between + `minval` and `maxval` (inclusive). If either `minval` + or `maxval` is a float then the distribution is the + continuous uniform distribution; otherwise samples + are constrained to the integers. + - dict: + If a dict is passed, the keys are taken to be choices + for the variables and the values are interpreted as + probabilities. (And must sum to one.) + If a dict is passed, `minval` and `maxval` must not be set. + Otherwise, they must be set. + strict: + If a parameter is strict, it will not be included in the + iterative resampling process which Fuzzer uses to find a + valid parameter configuration. 
This allows an author to + prevent skew from resampling for a given parameter (for + instance, a low size limit could inadvertently bias towards + Tensors with fewer dimensions) at the cost of more iterations + when generating parameters. + """ + self._name = name + self._minval = minval + self._maxval = maxval + self._distribution = self._check_distribution(distribution) + self.strict = strict + + @property + def name(self): + return self._name + + def sample(self, state): + if self._distribution == "loguniform": + return self._loguniform(state) + + if self._distribution == "uniform": + return self._uniform(state) + + if isinstance(self._distribution, dict): + return self._custom_distribution(state) + + def _check_distribution(self, distribution): + if not isinstance(distribution, dict): + assert distribution in _DISTRIBUTIONS + else: + assert not any(i < 0 for i in distribution.values()), "Probabilities cannot be negative" + assert abs(sum(distribution.values()) - 1) <= 1e-5, "Distribution is not normalized" + assert self._minval is None + assert self._maxval is None + + return distribution + + def _loguniform(self, state): + import numpy as np + output = int(2 ** state.uniform( + low=np.log2(self._minval) if self._minval is not None else None, + high=np.log2(self._maxval) if self._maxval is not None else None, + )) + if self._minval is not None and output < self._minval: + return self._minval + if self._maxval is not None and output > self._maxval: + return self._maxval + return output + + def _uniform(self, state): + if isinstance(self._minval, int) and isinstance(self._maxval, int): + return int(state.randint(low=self._minval, high=self._maxval + 1)) + return state.uniform(low=self._minval, high=self._maxval) + + def _custom_distribution(self, state): + import numpy as np + # If we directly pass the keys to `choice`, numpy will convert + # them to numpy dtypes. + index = state.choice( + np.arange(len(self._distribution)), + p=tuple(self._distribution.values())) + return list(self._distribution.keys())[index] + + +class ParameterAlias: + """Indicates that a parameter should alias the value of another parameter. + + When used in conjunction with a custom distribution, this allows fuzzed + tensors to represent a broader range of behaviors. For example, the + following sometimes produces Tensors which broadcast: + + Fuzzer( + parameters=[ + FuzzedParameter("x_len", 4, 1024, distribution="uniform"), + + # `y` will either be size one, or match the size of `x`. + FuzzedParameter("y_len", distribution={ + 0.5: 1, + 0.5: ParameterAlias("x_len") + }), + ], + tensors=[ + FuzzedTensor("x", size=("x_len",)), + FuzzedTensor("y", size=("y_len",)), + ], + ) + + Chains of alias' are allowed, but may not contain cycles. + """ + def __init__(self, alias_to): + self.alias_to = alias_to + + def __repr__(self): + return f"ParameterAlias[alias_to: {self.alias_to}]" + + +def dtype_size(dtype): + if dtype == torch.bool: + return 1 + if dtype.is_floating_point or dtype.is_complex: + return int(torch.finfo(dtype).bits / 8) + return int(torch.iinfo(dtype).bits / 8) + + +def prod(values, base=1): + """np.prod can overflow, so for sizes the product should be done in Python. + + Even though np.prod type promotes to int64, it can still overflow in which + case the negative value will pass the size check and OOM when attempting to + actually allocate the Tensor. 
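    For example (illustrative), `prod((2 ** 32, 2 ** 32))` returns `2 ** 64` exactly,
    whereas the same product in a fixed-width int64 accumulator would wrap around and
    could slip past the size check as a negative value.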
+ """ + return functools.reduce(lambda x, y: int(x) * int(y), values, base) + + +class FuzzedTensor: + def __init__( + self, + name: str, + size: tuple[Union[str, int], ...], + steps: Optional[tuple[Union[str, int], ...]] = None, + probability_contiguous: float = 0.5, + min_elements: Optional[int] = None, + max_elements: Optional[int] = None, + max_allocation_bytes: Optional[int] = None, + dim_parameter: Optional[str] = None, + roll_parameter: Optional[str] = None, + dtype=torch.float32, + cuda=False, + tensor_constructor: Optional[Callable] = None + ): + """ + Args: + name: + A string identifier for the generated Tensor. + size: + A tuple of integers or strings specifying the size of the generated + Tensor. String values will replaced with a concrete int during the + generation process, while ints are simply passed as literals. + steps: + An optional tuple with the same length as `size`. This indicates + that a larger Tensor should be allocated, and then sliced to + produce the generated Tensor. For instance, if size is (4, 8) + and steps is (1, 4), then a tensor `t` of size (4, 32) will be + created and then `t[:, ::4]` will be used. (Allowing one to test + Tensors with strided memory.) + probability_contiguous: + A number between zero and one representing the chance that the + generated Tensor has a contiguous memory layout. This is achieved by + randomly permuting the shape of a Tensor, calling `.contiguous()`, + and then permuting back. This is applied before `steps`, which can + also cause a Tensor to be non-contiguous. + min_elements: + The minimum number of parameters that this Tensor must have for a + set of parameters to be valid. (Otherwise they are resampled.) + max_elements: + Like `min_elements`, but setting an upper bound. + max_allocation_bytes: + Like `max_elements`, but for the size of Tensor that must be + allocated prior to slicing for `steps` (if applicable). For + example, a FloatTensor with size (1024, 1024) and steps (4, 4) + would have 1M elements, but would require a 64 MB allocation. + dim_parameter: + The length of `size` and `steps` will be truncated to this value. + This allows Tensors of varying dimensions to be generated by the + Fuzzer. + dtype: + The PyTorch dtype of the generated Tensor. + cuda: + Whether to place the Tensor on a GPU. + tensor_constructor: + Callable which will be used instead of the default Tensor + construction method. This allows the author to enforce properties + of the Tensor (e.g. it can only have certain values). The dtype and + concrete shape of the Tensor to be created will be passed, and + concrete values of all parameters will be passed as kwargs. Note + that transformations to the result (permuting, slicing) will be + performed by the Fuzzer; the tensor_constructor is only responsible + for creating an appropriately sized Tensor. 
+ """ + self._name = name + self._size = size + self._steps = steps + self._probability_contiguous = probability_contiguous + self._min_elements = min_elements + self._max_elements = max_elements + self._max_allocation_bytes = max_allocation_bytes + self._dim_parameter = dim_parameter + self._dtype = dtype + self._cuda = cuda + self._tensor_constructor = tensor_constructor + + @property + def name(self): + return self._name + + @staticmethod + def default_tensor_constructor(size, dtype, **kwargs): + if dtype.is_floating_point or dtype.is_complex: + return torch.rand(size=size, dtype=dtype, device="cpu") + else: + return torch.randint(1, 127, size=size, dtype=dtype, device="cpu") + + def _make_tensor(self, params, state): + import numpy as np + size, steps, allocation_size = self._get_size_and_steps(params) + constructor = ( + self._tensor_constructor or + self.default_tensor_constructor + ) + + raw_tensor = constructor(size=allocation_size, dtype=self._dtype, **params) + if self._cuda: + raw_tensor = raw_tensor.cuda() + + # Randomly permute the Tensor and call `.contiguous()` to force re-ordering + # of the memory, and then permute it back to the original shape. + dim = len(size) + order = np.arange(dim) + if state.rand() > self._probability_contiguous: + while dim > 1 and np.all(order == np.arange(dim)): + order = state.permutation(raw_tensor.dim()) + + raw_tensor = raw_tensor.permute(tuple(order)).contiguous() + raw_tensor = raw_tensor.permute(tuple(np.argsort(order))) + + slices = [slice(0, size * step, step) for size, step in zip(size, steps)] + tensor = raw_tensor[tuple(slices)] + + properties = { + "numel": int(tensor.numel()), + "order": order, + "steps": steps, + "is_contiguous": tensor.is_contiguous(), + "dtype": str(self._dtype), + } + + return tensor, properties + + def _get_size_and_steps(self, params): + dim = ( + params[self._dim_parameter] + if self._dim_parameter is not None + else len(self._size) + ) + + def resolve(values, dim): + """Resolve values into concrete integers.""" + values = tuple(params.get(i, i) for i in values) + if len(values) > dim: + values = values[:dim] + if len(values) < dim: + values = values + tuple(1 for _ in range(dim - len(values))) + return values + + size = resolve(self._size, dim) + steps = resolve(self._steps or (), dim) + allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps)) + return size, steps, allocation_size + + def satisfies_constraints(self, params): + size, _, allocation_size = self._get_size_and_steps(params) + # Product is computed in Python to avoid integer overflow. + num_elements = prod(size) + assert num_elements >= 0 + + allocation_bytes = prod(allocation_size, base=dtype_size(self._dtype)) + + def nullable_greater(left, right): + if left is None or right is None: + return False + return left > right + + return not any(( + nullable_greater(num_elements, self._max_elements), + nullable_greater(self._min_elements, num_elements), + nullable_greater(allocation_bytes, self._max_allocation_bytes), + )) + + +class Fuzzer: + def __init__( + self, + parameters: list[Union[FuzzedParameter, list[FuzzedParameter]]], + tensors: list[Union[FuzzedTensor, list[FuzzedTensor]]], + constraints: Optional[list[Callable]] = None, + seed: Optional[int] = None + ): + """ + Args: + parameters: + List of FuzzedParameters which provide specifications + for generated parameters. Iterable elements will be + unpacked, though arbitrary nested structures will not. 
+ tensors: + List of FuzzedTensors which define the Tensors which + will be created each step based on the parameters for + that step. Iterable elements will be unpacked, though + arbitrary nested structures will not. + constraints: + List of callables. They will be called with params + as kwargs, and if any of them return False the current + set of parameters will be rejected. + seed: + Seed for the RandomState used by the Fuzzer. This will + also be used to set the PyTorch random seed so that random + ops will create reproducible Tensors. + """ + import numpy as np + if seed is None: + seed = int(np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64)) + self._seed = seed + self._parameters = Fuzzer._unpack(parameters, FuzzedParameter) + self._tensors = Fuzzer._unpack(tensors, FuzzedTensor) + self._constraints = constraints or () + + p_names = {p.name for p in self._parameters} + t_names = {t.name for t in self._tensors} + name_overlap = p_names.intersection(t_names) + if name_overlap: + raise ValueError(f"Duplicate names in parameters and tensors: {name_overlap}") + + self._rejections = 0 + self._total_generated = 0 + + @staticmethod + def _unpack(values, cls): + return tuple(it.chain.from_iterable( + [[i] if isinstance(i, cls) else i for i in values] + )) + + def take(self, n): + import numpy as np + state = np.random.RandomState(self._seed) + torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64)) + for _ in range(n): + params = self._generate(state) + tensors = {} + tensor_properties = {} + for t in self._tensors: + tensor, properties = t._make_tensor(params, state) + tensors[t.name] = tensor + tensor_properties[t.name] = properties + yield tensors, tensor_properties, params + + @property + def rejection_rate(self): + if not self._total_generated: + return 0. 
+ return self._rejections / self._total_generated + + def _generate(self, state): + strict_params: dict[str, Union[float, int, ParameterAlias]] = {} + for _ in range(1000): + candidate_params: dict[str, Union[float, int, ParameterAlias]] = {} + for p in self._parameters: + if p.strict: + if p.name in strict_params: + candidate_params[p.name] = strict_params[p.name] + else: + candidate_params[p.name] = p.sample(state) + strict_params[p.name] = candidate_params[p.name] + else: + candidate_params[p.name] = p.sample(state) + + candidate_params = self._resolve_aliases(candidate_params) + + self._total_generated += 1 + if not all(f(candidate_params) for f in self._constraints): + self._rejections += 1 + continue + + if not all(t.satisfies_constraints(candidate_params) for t in self._tensors): + self._rejections += 1 + continue + + return candidate_params + raise ValueError("Failed to generate a set of valid parameters.") + + @staticmethod + def _resolve_aliases(params): + params = dict(params) + alias_count = sum(isinstance(v, ParameterAlias) for v in params.values()) + + keys = list(params.keys()) + while alias_count: + for k in keys: + v = params[k] + if isinstance(v, ParameterAlias): + params[k] = params[v.alias_to] + alias_count_new = sum(isinstance(v, ParameterAlias) for v in params.values()) + if alias_count == alias_count_new: + raise ValueError(f"ParameterAlias cycle detected\n{params}") + + alias_count = alias_count_new + + return params diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..0d4575726a44e2c46549689f61f0513c4438507e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union +from numbers import Number +import torch +from torch.utils.benchmark import FuzzedTensor +import math + +class FuzzedSparseTensor(FuzzedTensor): + def __init__( + self, + name: str, + size: tuple[Union[str, int], ...], + min_elements: Optional[int] = None, + max_elements: Optional[int] = None, + dim_parameter: Optional[str] = None, + sparse_dim: Optional[str] = None, + nnz: Optional[str] = None, + density: Optional[str] = None, + coalesced: Optional[str] = None, + dtype=torch.float32, + cuda=False + ): + """ + Args: + name: + A string identifier for the generated Tensor. + size: + A tuple of integers or strings specifying the size of the generated + Tensor. String values will replaced with a concrete int during the + generation process, while ints are simply passed as literals. + min_elements: + The minimum number of parameters that this Tensor must have for a + set of parameters to be valid. (Otherwise they are resampled.) + max_elements: + Like `min_elements`, but setting an upper bound. + dim_parameter: + The length of `size` will be truncated to this value. + This allows Tensors of varying dimensions to be generated by the + Fuzzer. + sparse_dim: + The number of sparse dimensions in a sparse tensor. + density: + This value allows tensors of varying sparsities to be generated by the Fuzzer. + coalesced: + The sparse tensor format permits uncoalesced sparse tensors, + where there may be duplicate coordinates in the indices. + dtype: + The PyTorch dtype of the generated Tensor. + cuda: + Whether to place the Tensor on a GPU. 
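A sketch of how this class is typically driven (parameter names and ranges are illustrative choices); note that `_make_tensor` looks up parameters literally named `density`, `coalesced`, and `sparse_dim`, so those names must be present in the Fuzzer's parameter set.

```python
from torch.utils.benchmark import Fuzzer, FuzzedParameter
from torch.utils.benchmark.utils.sparse_fuzzer import FuzzedSparseTensor

fuzzer = Fuzzer(
    parameters=[
        FuzzedParameter("k", minval=16, maxval=4096, distribution="loguniform"),
        FuzzedParameter("density", distribution={0.1: 0.4, 0.5: 0.6}),
        FuzzedParameter("coalesced", distribution={True: 0.5, False: 0.5}),
        FuzzedParameter("sparse_dim", distribution={1: 1.0}),
    ],
    tensors=[FuzzedSparseTensor("x", size=("k",), sparse_dim="sparse_dim")],
    seed=0,
)

for tensors, properties, params in fuzzer.take(2):
    x = tensors["x"]
    print(params["k"], properties["x"]["sparsity"], x.is_coalesced())
```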
+ """ + super().__init__(name=name, size=size, min_elements=min_elements, + max_elements=max_elements, dim_parameter=dim_parameter, dtype=dtype, cuda=cuda) + self._density = density + self._coalesced = coalesced + self._sparse_dim = sparse_dim + + @staticmethod + def sparse_tensor_constructor(size, dtype, sparse_dim, nnz, is_coalesced): + """sparse_tensor_constructor creates a sparse tensor with coo format. + + Note that when `is_coalesced` is False, the number of elements is doubled but the number of indices + represents the same amount of number of non zeros `nnz`, i.e, this is virtually the same tensor + with the same sparsity pattern. Moreover, most of the sparse operation will use coalesce() method + and what we want here is to get a sparse tensor with the same `nnz` even if this is coalesced or not. + + In the other hand when `is_coalesced` is True the number of elements is reduced in the coalescing process + by an unclear amount however the probability to generate duplicates indices are low for most of the cases. + This decision was taken on purpose to maintain the construction cost as low as possible. + """ + if isinstance(size, Number): + size = [size] * sparse_dim + assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments' + v_size = [nnz] + list(size[sparse_dim:]) + if dtype.is_floating_point: + v = torch.rand(size=v_size, dtype=dtype, device="cpu") + else: + v = torch.randint(1, 127, size=v_size, dtype=dtype, device="cpu") + + i = torch.rand(sparse_dim, nnz, device="cpu") + i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) + i = i.to(torch.long) + + if not is_coalesced: + v = torch.cat([v, torch.randn_like(v)], 0) + i = torch.cat([i, i], 1) + + x = torch.sparse_coo_tensor(i, v, torch.Size(size)) + if is_coalesced: + x = x.coalesce() + return x + + def _make_tensor(self, params, state): + size, _, _ = self._get_size_and_steps(params) + density = params['density'] + nnz = math.ceil(sum(size) * density) + assert nnz <= sum(size) + + is_coalesced = params['coalesced'] + sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size) + sparse_dim = min(sparse_dim, len(size)) + tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced) + + if self._cuda: + tensor = tensor.cuda() + sparse_dim = tensor.sparse_dim() + dense_dim = tensor.dense_dim() + is_hybrid = len(size[sparse_dim:]) > 0 + + properties = { + "numel": int(tensor.numel()), + "shape": tensor.size(), + "is_coalesced": tensor.is_coalesced(), + "density": density, + "sparsity": 1.0 - density, + "sparse_dim": sparse_dim, + "dense_dim": dense_dim, + "is_hybrid": is_hybrid, + "dtype": str(self._dtype), + } + return tensor, properties diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timeit_template.cpp b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timeit_template.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afb9f570b6f6c31944f87e3c5ca7ca69bff3e70c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timeit_template.cpp @@ -0,0 +1,43 @@ +/* C++ template for Timer.timeit + +This template will be consumed by `cpp_jit.py`, and will replace: + `GLOBAL_SETUP_TEMPLATE_LOCATION`, + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ +#include + +#include +#include +#include +#include + +// Global setup. (e.g. 
#includes) +// GLOBAL_SETUP_TEMPLATE_LOCATION + +double timeit(int n) { + pybind11::gil_scoped_release no_gil; + + // Setup + // SETUP_TEMPLATE_LOCATION + + { + // Warmup + // STMT_TEMPLATE_LOCATION + } + + // Main loop + auto start_time = std::chrono::high_resolution_clock::now(); + for (const auto loop_idx : c10::irange(n)) { + (void)loop_idx; + // STMT_TEMPLATE_LOCATION + } + auto end_time = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end_time - start_time).count(); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("timeit", &timeit); +} diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timer.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d8cf0859f7bac1d149c7661ca8ed895b5ff88b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/timer.py @@ -0,0 +1,541 @@ +"""Timer class based on the timeit.Timer class, but torch aware.""" +import enum +import timeit +import textwrap +from typing import overload, Any, Callable, NoReturn, Optional, Union + +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType +from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface + + +__all__ = ["Timer", "timer", "Language"] + + +if torch.backends.cuda.is_built() and torch.cuda.is_available(): # type: ignore[no-untyped-call] + def timer() -> float: + torch.cuda.synchronize() + return timeit.default_timer() +elif torch.xpu.is_available(): + def timer() -> float: + torch.xpu.synchronize() + return timeit.default_timer() +elif torch._C._get_privateuse1_backend_name() != "privateuseone": + privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \ + if torch._C._get_privateuse1_backend_name() != "cpu" else None + + def timer() -> float: + if privateuse1_device_handler: + privateuse1_device_handler.synchronize() + return timeit.default_timer() +else: + timer = timeit.default_timer + + +class Language(enum.Enum): + PYTHON = 0 + CPP = 1 + + +class CPPTimer: + def __init__( + self, + stmt: str, + setup: str, + global_setup: str, + timer: Callable[[], float], + globals: dict[str, Any], + ) -> None: + if timer is not timeit.default_timer: + raise NotImplementedError( + "PyTorch was built with CUDA and a GPU is present; however " + "Timer does not yet support GPU measurements. If your " + "code is CPU only, pass `timer=timeit.default_timer` to the " + "Timer's constructor to indicate this. (Note that this will " + "produce incorrect results if the GPU is in fact used, as " + "Timer will not synchronize CUDA.)" + ) + + if globals: + raise ValueError("C++ timing does not support globals.") + + self._stmt: str = textwrap.dedent(stmt) + self._setup: str = textwrap.dedent(setup) + self._global_setup: str = textwrap.dedent(global_setup) + self._timeit_module: Optional[TimeitModuleType] = None + + def timeit(self, number: int) -> float: + if self._timeit_module is None: + self._timeit_module = cpp_jit.compile_timeit_template( + stmt=self._stmt, + setup=self._setup, + global_setup=self._global_setup, + ) + + return self._timeit_module.timeit(number) + + +class Timer: + """Helper class for measuring execution time of PyTorch statements. 
+ + For a full tutorial on how to use this class, see: + https://pytorch.org/tutorials/recipes/recipes/benchmark.html + + The PyTorch Timer is based on `timeit.Timer` (and in fact uses + `timeit.Timer` internally), but with several key differences: + + 1) Runtime aware: + Timer will perform warmups (important as some elements of PyTorch are + lazily initialized), set threadpool size so that comparisons are + apples-to-apples, and synchronize asynchronous CUDA functions when + necessary. + + 2) Focus on replicates: + When measuring code, and particularly complex kernels / models, + run-to-run variation is a significant confounding factor. It is + expected that all measurements should include replicates to quantify + noise and allow median computation, which is more robust than mean. + To that effect, this class deviates from the `timeit` API by + conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`. + (Exact algorithms are discussed in method docstrings.) The `timeit` + method is replicated for cases where an adaptive strategy is not + desired. + + 3) Optional metadata: + When defining a Timer, one can optionally specify `label`, `sub_label`, + `description`, and `env`. (Defined later) These fields are included in + the representation of result object and by the `Compare` class to group + and display results for comparison. + + 4) Instruction counts + In addition to wall times, Timer can run a statement under Callgrind + and report instructions executed. + + Directly analogous to `timeit.Timer` constructor arguments: + + `stmt`, `setup`, `timer`, `globals` + + PyTorch Timer specific constructor arguments: + + `label`, `sub_label`, `description`, `env`, `num_threads` + + Args: + stmt: Code snippet to be run in a loop and timed. + + setup: Optional setup code. Used to define variables used in `stmt` + + global_setup: (C++ only) + Code which is placed at the top level of the file for things like + `#include` statements. + + timer: + Callable which returns the current time. If PyTorch was built + without CUDA or there is no GPU present, this defaults to + `timeit.default_timer`; otherwise it will synchronize CUDA before + measuring the time. + + globals: + A dict which defines the global variables when `stmt` is being + executed. This is the other method for providing variables which + `stmt` needs. + + label: + String which summarizes `stmt`. For instance, if `stmt` is + "torch.nn.functional.relu(torch.add(x, 1, out=out))" + one might set label to "ReLU(x + 1)" to improve readability. + + sub_label: + Provide supplemental information to disambiguate measurements + with identical stmt or label. For instance, in our example + above sub_label might be "float" or "int", so that it is easy + to differentiate: + "ReLU(x + 1): (float)" + + "ReLU(x + 1): (int)" + when printing Measurements or summarizing using `Compare`. + + description: + String to distinguish measurements with identical label and + sub_label. The principal use of `description` is to signal to + `Compare` the columns of data. For instance one might set it + based on the input size to create a table of the form: :: + + | n=1 | n=4 | ... + ------------- ... + ReLU(x + 1): (float) | ... | ... | ... + ReLU(x + 1): (int) | ... | ... | ... + + + using `Compare`. It is also included when printing a Measurement. + + env: + This tag indicates that otherwise identical tasks were run in + different environments, and are therefore not equivalent, for + instance when A/B testing a change to a kernel. 
`Compare` will + treat Measurements with different `env` specification as distinct + when merging replicate runs. + + num_threads: + The size of the PyTorch threadpool when executing `stmt`. Single + threaded performance is important as both a key inference workload + and a good indicator of intrinsic algorithmic efficiency, so the + default is set to one. This is in contrast to the default PyTorch + threadpool size which tries to utilize all cores. + """ + + _timer_cls: type[TimerClass] = timeit.Timer + + def __init__( + self, + stmt: str = "pass", + setup: str = "pass", + global_setup: str = "", + timer: Callable[[], float] = timer, + globals: Optional[dict[str, Any]] = None, + label: Optional[str] = None, + sub_label: Optional[str] = None, + description: Optional[str] = None, + env: Optional[str] = None, + num_threads: int = 1, + language: Union[Language, str] = Language.PYTHON, + ): + if not isinstance(stmt, str): + raise ValueError("Currently only a `str` stmt is supported.") + + # We copy `globals` to prevent mutations from leaking. + # (For instance, `eval` adds the `__builtins__` key) + self._globals = dict(globals or {}) + + timer_kwargs = {} + if language in (Language.PYTHON, "py", "python"): + # Include `torch` if not specified as a convenience feature. + self._globals.setdefault("torch", torch) + self._language: Language = Language.PYTHON + if global_setup: + raise ValueError( + f"global_setup is C++ only, got `{global_setup}`. Most " + "likely this code can simply be moved to `setup`." + ) + + elif language in (Language.CPP, "cpp", "c++"): + assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped." + self._timer_cls = CPPTimer + setup = ("" if setup == "pass" else setup) + self._language = Language.CPP + timer_kwargs["global_setup"] = global_setup + + else: + raise ValueError(f"Invalid language `{language}`.") + + # Convenience adjustment so that multi-line code snippets defined in + # functions do not IndentationError (Python) or look odd (C++). The + # leading newline removal is for the initial newline that appears when + # defining block strings. For instance: + # textwrap.dedent(""" + # print("This is a stmt") + # """) + # produces '\nprint("This is a stmt")\n'. + # + # Stripping this down to 'print("This is a stmt")' doesn't change + # what gets executed, but it makes __repr__'s nicer. + stmt = textwrap.dedent(stmt) + stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip() + setup = textwrap.dedent(setup) + setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() + + self._timer = self._timer_cls( + stmt=stmt, + setup=setup, + timer=timer, + globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals), + **timer_kwargs, + ) + self._task_spec = common.TaskSpec( + stmt=stmt, + setup=setup, + global_setup=global_setup, + label=label, + sub_label=sub_label, + description=description, + env=env, + num_threads=num_threads, + ) + + def _timeit(self, number: int) -> float: + # Even calling a timer in C++ takes ~50 ns, so no real operation should + # take less than 1 ns. (And this prevents divide by zero errors.) + return max(self._timer.timeit(number), 1e-9) + + def timeit(self, number: int = 1000000) -> common.Measurement: + """Mirrors the semantics of timeit.Timer.timeit(). + + Execute the main statement (`stmt`) `number` times. 
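A minimal end-to-end sketch reusing the ReLU example from the constructor docstring (the size, labels, and iteration count are arbitrary choices):

```python
import torch
from torch.utils.benchmark import Timer

t = Timer(
    stmt="torch.nn.functional.relu(torch.add(x, 1, out=out))",
    setup="x = torch.randn(1024); out = torch.empty_like(x)",
    label="ReLU(x + 1)",
    sub_label="float",
    description="n=1024",
    num_threads=1,
)
m = t.timeit(1000)   # one Measurement from 1000 executions of `stmt`
print(m.median)      # median time per execution, in seconds
```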
+ https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit + """ + with common.set_torch_threads(self._task_spec.num_threads): + # Warmup + self._timeit(number=max(int(number // 100), 2)) + + return common.Measurement( + number_per_run=number, + raw_times=[self._timeit(number=number)], + task_spec=self._task_spec + ) + + def repeat(self, repeat: int = -1, number: int = -1) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None: + raise NotImplementedError("See `Timer.blocked_autorange.`") + + def _threaded_measurement_loop( + self, + number: int, + time_hook: Callable[[], float], + stop_hook: Callable[[list[float]], bool], + min_run_time: float, + max_run_time: Optional[float] = None, + callback: Optional[Callable[[int, float], NoReturn]] = None + ) -> list[float]: + total_time = 0.0 + can_stop = False + times: list[float] = [] + with common.set_torch_threads(self._task_spec.num_threads): + while (total_time < min_run_time) or (not can_stop): + time_spent = time_hook() + times.append(time_spent) + total_time += time_spent + if callback: + callback(number, time_spent) + can_stop = stop_hook(times) + if max_run_time and total_time > max_run_time: + break + return times + + def _estimate_block_size(self, min_run_time: float) -> int: + with common.set_torch_threads(self._task_spec.num_threads): + # Estimate the block size needed for measurement to be negligible + # compared to the inner loop. This also serves as a warmup. + overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item() + number = 1 + while True: + time_taken = self._timeit(number) + relative_overhead = overhead / time_taken + if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000: + break + if time_taken > min_run_time: + break + # Avoid overflow in C++ pybind11 interface + if number * 10 > 2147483647: + break + number *= 10 + return number + + def blocked_autorange( + self, + callback: Optional[Callable[[int, float], NoReturn]] = None, + min_run_time: float = 0.2, + ) -> common.Measurement: + """Measure many replicates while keeping timer overhead to a minimum. + + At a high level, blocked_autorange executes the following pseudo-code:: + + `setup` + + total_time = 0 + while total_time < min_run_time + start = timer() + for _ in range(block_size): + `stmt` + total_time += (timer() - start) + + Note the variable `block_size` in the inner loop. The choice of block + size is important to measurement quality, and must balance two + competing objectives: + + 1) A small block size results in more replicates and generally + better statistics. + + 2) A large block size better amortizes the cost of `timer` + invocation, and results in a less biased measurement. This is + important because CUDA synchronization time is non-trivial + (order single to low double digit microseconds) and would + otherwise bias the measurement. + + blocked_autorange sets block_size by running a warmup period, + increasing block size until timer overhead is less than 0.1% of + the overall computation. This value is then used for the main + measurement loop. + + Returns: + A `Measurement` object that contains measured runtimes and + repetition counts, and can be used to compute statistics. + (mean, median, etc.) 
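An illustrative invocation of the above (the workload and `min_run_time` are arbitrary choices):

```python
import torch
from torch.utils.benchmark import Timer

m = Timer(
    stmt="torch.mm(a, b)",
    setup="a = torch.randn(256, 256); b = torch.randn(256, 256)",
).blocked_autorange(min_run_time=0.5)

print(m)                                # human readable summary of the collected replicates
print(m.mean, m.median, len(m.times))   # per-run statistics and replicate count
```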
+ """ + number = self._estimate_block_size(min_run_time) + + def time_hook() -> float: + return self._timeit(number) + + def stop_hook(times: list[float]) -> bool: + return True + + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, + min_run_time=min_run_time, + callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + def adaptive_autorange( + self, + threshold: float = 0.1, + *, + min_run_time: float = 0.01, + max_run_time: float = 10.0, + callback: Optional[Callable[[int, float], NoReturn]] = None, + ) -> common.Measurement: + """Similar to `blocked_autorange` but also checks for variablility in measurements + and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached. + + + At a high level, adaptive_autorange executes the following pseudo-code:: + + `setup` + + times = [] + while times.sum < max_run_time + start = timer() + for _ in range(block_size): + `stmt` + times.append(timer() - start) + + enough_data = len(times)>3 and times.sum > min_run_time + small_iqr=times.iqr/times.mean float: + return self._timeit(number) + + def stop_hook(times: list[float]) -> bool: + if len(times) > 3: + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ).meets_confidence(threshold=threshold) + return False + times = self._threaded_measurement_loop( + number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback) + + return common.Measurement( + number_per_run=number, + raw_times=times, + task_spec=self._task_spec + ) + + @overload + def collect_callgrind( + self, + number: int, + *, + repeats: None, + collect_baseline: bool, + retain_out_file: bool, + ) -> valgrind_timer_interface.CallgrindStats: + ... + + @overload + def collect_callgrind( + self, + number: int, + *, + repeats: int, + collect_baseline: bool, + retain_out_file: bool, + ) -> tuple[valgrind_timer_interface.CallgrindStats, ...]: + ... + + def collect_callgrind( + self, + number: int = 100, + *, + repeats: Optional[int] = None, + collect_baseline: bool = True, + retain_out_file: bool = False, + ) -> Any: + """Collect instruction counts using Callgrind. + + Unlike wall times, instruction counts are deterministic + (modulo non-determinism in the program itself and small amounts of + jitter from the Python interpreter.) This makes them ideal for detailed + performance analysis. This method runs `stmt` in a separate process + so that Valgrind can instrument the program. Performance is severely + degraded due to the instrumentation, however this is ameliorated by + the fact that a small number of iterations is generally sufficient to + obtain good measurements. + + In order to to use this method `valgrind`, `callgrind_control`, and + `callgrind_annotate` must be installed. + + Because there is a process boundary between the caller (this process) + and the `stmt` execution, `globals` cannot contain arbitrary in-memory + data structures. (Unlike timing methods) Instead, globals are + restricted to builtins, `nn.Modules`'s, and TorchScripted functions/modules + to reduce the surprise factor from serialization and subsequent + deserialization. The `GlobalsBridge` class provides more detail on this + subject. Take particular care with nn.Modules: they rely on pickle and + you may need to add an import to `setup` for them to transfer properly. 
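A usage sketch, assuming `valgrind`, `callgrind_control`, and `callgrind_annotate` are installed as required above; the model and shapes are arbitrary, and the Tensor input is created in `setup` because Tensors may not be passed through `globals`.

```python
import torch
from torch.utils.benchmark import Timer
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CopyIfCallgrind

stats = Timer(
    stmt="y = model(x)",
    setup="x = torch.ones((8, 8))",   # Tensors cannot be passed through `globals`
    globals={"model": CopyIfCallgrind(torch.nn.Linear(8, 8))},
).collect_callgrind(number=10)

print(stats.counts(denoise=True))         # total instructions, CPython noise stripped
print(stats.stats(inclusive=False)[:10])  # entries with the highest exclusive counts
```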
+ + By default, a profile for an empty statement will be collected and + cached to indicate how many instructions are from the Python loop which + drives `stmt`. + + Returns: + A `CallgrindStats` object which provides instruction counts and + some basic facilities for analyzing and manipulating results. + """ + if not isinstance(self._task_spec.stmt, str): + raise ValueError("`collect_callgrind` currently only supports string `stmt`") + + if repeats is not None and repeats < 1: + raise ValueError("If specified, `repeats` must be >= 1") + + # Check that the statement is valid. It doesn't guarantee success, but it's much + # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in + # the parent process rather than the valgrind subprocess. + self._timeit(1) + is_python = (self._language == Language.PYTHON) + assert is_python or not self._globals + result = valgrind_timer_interface.wrapper_singleton().collect_callgrind( + task_spec=self._task_spec, + globals=self._globals, + number=number, + repeats=repeats or 1, + collect_baseline=collect_baseline and is_python, + is_python=is_python, + retain_out_file=retain_out_file, + ) + + return (result[0] if repeats is None else result) diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86df155e1dcd8d781690c59f9835f6106fd77296 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b92da8c24db1732355271065baca8083efcd8c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h new file mode 100644 index 0000000000000000000000000000000000000000..2e39be7f73f05ce7be4464021852126e08c313e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h @@ -0,0 +1,129 @@ + +/* + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (callgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of callgrind, a valgrind tool for cache simulation + and call tree tracing. + + Copyright (C) 2003-2017 Josef Weidendorfer. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (callgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + +#ifndef __CALLGRIND_H +#define __CALLGRIND_H + +#include "valgrind.h" + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE ORDER OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end. + + The identification ('C','T') for Callgrind has historical + reasons: it was called "Calltree" before. Besides, ('C','G') would + clash with cachegrind. + */ + +typedef + enum { + VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'), + VG_USERREQ__ZERO_STATS, + VG_USERREQ__TOGGLE_COLLECT, + VG_USERREQ__DUMP_STATS_AT, + VG_USERREQ__START_INSTRUMENTATION, + VG_USERREQ__STOP_INSTRUMENTATION + } Vg_CallgrindClientRequest; + +/* Dump current state of cost centers, and zero them afterwards */ +#define CALLGRIND_DUMP_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS, \ + 0, 0, 0, 0, 0) + +/* Dump current state of cost centers, and zero them afterwards. + The argument is appended to a string stating the reason which triggered + the dump. This string is written as a description field into the + profile data dump. */ +#define CALLGRIND_DUMP_STATS_AT(pos_str) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DUMP_STATS_AT, \ + pos_str, 0, 0, 0, 0) + +/* Zero cost centers */ +#define CALLGRIND_ZERO_STATS \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__ZERO_STATS, \ + 0, 0, 0, 0, 0) + +/* Toggles collection state. + The collection state specifies whether the happening of events + should be noted or if they are to be ignored. 
Events are noted + by increment of counters in a cost center */ +#define CALLGRIND_TOGGLE_COLLECT \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__TOGGLE_COLLECT, \ + 0, 0, 0, 0, 0) + +/* Start full callgrind instrumentation if not already switched on. + When cache simulation is done, it will flush the simulated cache; + this will lead to an artificial cache warmup phase afterwards with + cache misses which would not have happened in reality. */ +#define CALLGRIND_START_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__START_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +/* Stop full callgrind instrumentation if not already switched off. + This flushes Valgrinds translation cache, and does no additional + instrumentation afterwards, which effectivly will run at the same + speed as the "none" tool (ie. at minimal slowdown). + Use this to bypass Callgrind aggregation for uninteresting code parts. + To start Callgrind in this mode to ignore the setup phase, use + the option "--instr-atstart=no". */ +#define CALLGRIND_STOP_INSTRUMENTATION \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STOP_INSTRUMENTATION, \ + 0, 0, 0, 0, 0) + +#endif /* __CALLGRIND_H */ diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..532eec13708e9294ad58e63f59eafeedf21bf707 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp @@ -0,0 +1,35 @@ +/* Used to collect profiles of old versions of PyTorch. */ +#include +#include + +bool _valgrind_supported_platform() { +#if defined(NVALGRIND) + return false; +#else + return true; +#endif +} + +void _valgrind_toggle() { +#if defined(NVALGRIND) + TORCH_CHECK(false, "Valgrind is not supported."); +#else + CALLGRIND_TOGGLE_COLLECT; +#endif +} + +void _valgrind_toggle_and_dump_stats() { +#if defined(NVALGRIND) + TORCH_CHECK(false, "Valgrind is not supported."); +#else + // NB: See note in Module.cpp + CALLGRIND_TOGGLE_COLLECT; + CALLGRIND_DUMP_STATS; +#endif +} + +PYBIND11_MODULE(callgrind_bindings, m) { + m.def("_valgrind_supported_platform", &_valgrind_supported_platform); + m.def("_valgrind_toggle", &_valgrind_toggle); + m.def("_valgrind_toggle_and_dump_stats", &_valgrind_dump_stats); +} diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c95d0da311b950864d1f67149fcbac8361c8a0a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp @@ -0,0 +1,68 @@ +/* C++ template for Timer.collect_callgrind + +This template will be consumed by `cpp_jit.py`, and will replace: + `GLOBAL_SETUP_TEMPLATE_LOCATION`, + `SETUP_TEMPLATE_LOCATION` + and + `STMT_TEMPLATE_LOCATION` +sections with user provided statements. +*/ + +#include +#include +#include + +#include + +// Global setup. (e.g. #includes) +// GLOBAL_SETUP_TEMPLATE_LOCATION + +#if defined(NVALGRIND) +static_assert(false); +#endif + +int main(int argc, char* argv[]) { + // This file should only be called inside of `Timer`, so we can adopt a + // very simple and rigid argument parsing scheme. 
+ TORCH_CHECK(argc == 9); + TORCH_CHECK(std::string(argv[1]) == "--number"); + auto number = std::stoi(argv[2]); + + TORCH_CHECK( + std::string(argv[3]) == "--number-warmup" || + std::string(argv[3]) == "--number_warmup"); + auto number_warmup = std::stoi(argv[4]); + + TORCH_CHECK(std::string(argv[5]) == "--repeats"); + auto repeats = std::stoi(argv[6]); + + TORCH_CHECK( + std::string(argv[7]) == "--number-threads" || + std::string(argv[7]) == "--number_threads"); + auto number_threads = std::stoi(argv[8]); + torch::set_num_threads(number_threads); + + // Setup + // SETUP_TEMPLATE_LOCATION + + // Warmup + for (const auto i : c10::irange(number_warmup)) { + (void)i; + // STMT_TEMPLATE_LOCATION + } + + // Main loop + for (const auto repeat : c10::irange(repeats)) { + (void)repeat; + CALLGRIND_TOGGLE_COLLECT; + + for (const auto i : c10::irange(number)) { + (void)i; + // STMT_TEMPLATE_LOCATION + } + + // NB: See note in Module.cpp + CALLGRIND_TOGGLE_COLLECT; + CALLGRIND_DUMP_STATS; + } +} diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..c30009d7e5184d71716b0e1968ea7e413fd42511 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -0,0 +1,910 @@ +"""Intermediate layer between `Timer` and `valgrind`.""" +import collections +import enum +import dataclasses +import itertools as it +import os +import pickle +import re +import shutil +import subprocess +import sys +import textwrap +from typing import ( + cast, Any, Callable, NamedTuple, + Optional, Union, TYPE_CHECKING) +from collections.abc import Iterator + +import torch +from torch.utils.benchmark.utils import common, cpp_jit +from torch.utils.benchmark.utils._stubs import CallgrindModuleType +import operator + + +__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"] + + +if TYPE_CHECKING: + CompletedProcessType = subprocess.CompletedProcess[str] +else: + CompletedProcessType = subprocess.CompletedProcess + + +class FunctionCount(NamedTuple): + # TODO(#105471): Rename the count field + count: int # type: ignore[assignment] + function: str + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class FunctionCounts: + """Container for manipulating Callgrind results. + + It supports: + 1) Addition and subtraction to combine or diff results. + 2) Tuple-like indexing. + 3) A `denoise` function which strips CPython calls which are known to + be non-deterministic and quite noisy. + 4) Two higher order methods (`filter` and `transform`) for custom + manipulation. + """ + _data: tuple[FunctionCount, ...] + inclusive: bool + truncate_rows: bool = True + + # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines + # the print settings. This is simply to allow hermetic unit tests. 
+ _linewidth: Optional[int] = None + + def __iter__(self) -> Iterator[FunctionCount]: + yield from self._data + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]: + data: Union[FunctionCount, tuple[FunctionCount, ...]] = self._data[item] + return ( + FunctionCounts(cast(tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False) + if isinstance(data, tuple) else data + ) + + def __repr__(self) -> str: + count_len = 0 + for c, _ in self: + # Account for sign in string length. + count_len = max(count_len, len(str(c)) + int(c < 0)) + + lines = [] + linewidth = self._linewidth or torch._tensor_str.PRINT_OPTS.linewidth + fn_str_len = max(linewidth - count_len - 4, 40) + for c, fn in self: + if len(fn) > fn_str_len: + left_len = int((fn_str_len - 5) // 2) + fn = fn[:left_len] + " ... " + fn[-(fn_str_len - left_len - 5):] + lines.append(f" {c:>{count_len}} {fn}") + + if self.truncate_rows and len(lines) > 18: + lines = lines[:9] + ["...".rjust(count_len + 2)] + lines[-9:] + + if not self.inclusive: + lines.extend(["", f"Total: {self.sum()}"]) + + return "\n".join([super().__repr__()] + lines) + + def __add__( + self, + other: "FunctionCounts", + ) -> "FunctionCounts": + return self._merge(other, lambda c: c) + + def __sub__( + self, + other: "FunctionCounts", + ) -> "FunctionCounts": + return self._merge(other, operator.neg) + + def __mul__(self, other: Union[int, float]) -> "FunctionCounts": + return self._from_dict({ + fn: int(c * other) for c, fn in self._data + }, self.inclusive) + + def transform(self, map_fn: Callable[[str], str]) -> "FunctionCounts": + """Apply `map_fn` to all of the function names. + + This can be used to regularize function names (e.g. stripping irrelevant + parts of the file path), coalesce entries by mapping multiple functions + to the same name (in which case the counts are added together), etc. + """ + counts: collections.defaultdict[str, int] = collections.defaultdict(int) + for c, fn in self._data: + counts[map_fn(fn)] += c + + return self._from_dict(counts, self.inclusive) + + def filter(self, filter_fn: Callable[[str], bool]) -> "FunctionCounts": + """Keep only the elements where `filter_fn` applied to function name returns True.""" + return FunctionCounts(tuple(i for i in self if filter_fn(i.function)), self.inclusive) + + def sum(self) -> int: + return sum(c for c, _ in self) + + def denoise(self) -> "FunctionCounts": + """Remove known noisy instructions. + + Several instructions in the CPython interpreter are rather noisy. These + instructions involve unicode to dictionary lookups which Python uses to + map variable names. FunctionCounts is generally a content agnostic + container, however this is sufficiently important for obtaining + reliable results to warrant an exception.""" + return self.filter(lambda fn: "dictobject.c:lookdict_unicode" not in fn) + + def _merge( + self, + second: "FunctionCounts", + merge_fn: Callable[[int], int] + ) -> "FunctionCounts": + assert self.inclusive == second.inclusive, "Cannot merge inclusive and exclusive counts." 
+ counts: collections.defaultdict[str, int] = collections.defaultdict(int) + for c, fn in self: + counts[fn] += c + + for c, fn in second: + counts[fn] += merge_fn(c) + + return self._from_dict(counts, self.inclusive) + + @staticmethod + def _from_dict(counts: dict[str, int], inclusive: bool) -> "FunctionCounts": + flat_counts = (FunctionCount(c, fn) for fn, c in counts.items() if c) + return FunctionCounts(tuple(sorted(flat_counts, reverse=True)), inclusive) + + +@dataclasses.dataclass(repr=False, eq=False, frozen=True) +class CallgrindStats: + """Top level container for Callgrind results collected by Timer. + + Manipulation is generally done using the FunctionCounts class, which is + obtained by calling `CallgrindStats.stats(...)`. Several convenience + methods are provided as well; the most significant is + `CallgrindStats.as_standardized()`. + """ + task_spec: common.TaskSpec + number_per_run: int + built_with_debug_symbols: bool + baseline_inclusive_stats: FunctionCounts + baseline_exclusive_stats: FunctionCounts + stmt_inclusive_stats: FunctionCounts + stmt_exclusive_stats: FunctionCounts + stmt_callgrind_out: Optional[str] + + def __repr__(self) -> str: + base_stats = self.baseline_exclusive_stats + output = f""" +{super().__repr__()} +{self.task_spec.summarize()} + {'':>25}All{'':>10}Noisy symbols removed + Instructions: {self.counts(denoise=False):>12}{'':>15}{self.counts(denoise=True):>12} + Baseline: {base_stats.sum():>12}{'':>15}{base_stats.denoise().sum():>12} +{self.number_per_run} runs per measurement, {self.task_spec.num_threads} thread{'s' if self.task_spec.num_threads > 1 else ''} +""".strip() + if not self.built_with_debug_symbols: + output += textwrap.dedent(""" + Warning: PyTorch was not built with debug symbols. + Source information may be limited. Rebuild with + REL_WITH_DEB_INFO=1 for more detailed results.""") + return output + + def stats(self, inclusive: bool = False) -> FunctionCounts: + """Returns detailed function counts. + + Conceptually, the FunctionCounts returned can be thought of as a tuple + of (count, path_and_function_name) tuples. + + `inclusive` matches the semantics of callgrind. If True, the counts + include instructions executed by children. `inclusive=True` is useful + for identifying hot spots in code; `inclusive=False` is useful for + reducing noise when diffing counts from two different runs. (See + CallgrindStats.delta(...) for more details) + """ + return self.stmt_inclusive_stats if inclusive else self.stmt_exclusive_stats + + def counts(self, *, denoise: bool = False) -> int: + """Returns the total number of instructions executed. + + See `FunctionCounts.denoise()` for an explanation of the `denoise` arg. + """ + stats = self.stmt_exclusive_stats + return (stats.denoise() if denoise else stats).sum() + + # FIXME: Once 3.7 is the minimum version, type annotate `other` per PEP 563 + def delta( + self, + other: "CallgrindStats", + inclusive: bool = False, + ) -> FunctionCounts: + """Diff two sets of counts. + + One common reason to collect instruction counts is to determine the + the effect that a particular change will have on the number of instructions + needed to perform some unit of work. If a change increases that number, the + next logical question is "why". This generally involves looking at what part + if the code increased in instruction count. This function automates that + process so that one can easily diff counts on both an inclusive and + exclusive basis. 
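A sketch of that workflow (the two statements stand in for a "before" and "after" of some change, and Valgrind must be installed as described for `collect_callgrind`):

```python
from torch.utils.benchmark import Timer

stats_a = Timer("y = x + x", setup="x = torch.ones((64,))").collect_callgrind(number=10)
stats_b = Timer("y = x * 2", setup="x = torch.ones((64,))").collect_callgrind(number=10)

# Standardize first so equivalent call sites cancel despite differing build
# paths, then diff: positive entries are executed more often by `stats_b`.
delta = stats_b.as_standardized().delta(stats_a.as_standardized(), inclusive=False)
print(delta)
```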
+ """ + return self.stats(inclusive=inclusive) - other.stats(inclusive=inclusive) + + def as_standardized(self) -> "CallgrindStats": + """Strip library names and some prefixes from function strings. + + When comparing two different sets of instruction counts, on stumbling + block can be path prefixes. Callgrind includes the full filepath + when reporting a function (as it should). However, this can cause + issues when diffing profiles. If a key component such as Python + or PyTorch was built in separate locations in the two profiles, which + can result in something resembling:: + + 23234231 /tmp/first_build_dir/thing.c:foo(...) + 9823794 /tmp/first_build_dir/thing.c:bar(...) + ... + 53453 .../aten/src/Aten/...:function_that_actually_changed(...) + ... + -9823794 /tmp/second_build_dir/thing.c:bar(...) + -23234231 /tmp/second_build_dir/thing.c:foo(...) + + Stripping prefixes can ameliorate this issue by regularizing the + strings and causing better cancellation of equivalent call sites + when diffing. + """ + def strip(stats: FunctionCounts) -> FunctionCounts: + transforms = ( + # PyTorch may have been built in different locations. + (r"^.+build/\.\./", "build/../"), + (r"^.+/" + re.escape("build/aten/"), "build/aten/"), + + # "Python" and "Objects" come from CPython. + (r"^.+/" + re.escape("Python/"), "Python/"), + (r"^.+/" + re.escape("Objects/"), "Objects/"), + + # Strip library name. e.g. `libtorch.so` + (r"\s\[.+\]$", ""), + ) + + for before, after in transforms: + stats = stats.transform(lambda fn: re.sub(before, after, fn)) + + return stats + + return CallgrindStats( + task_spec=self.task_spec, + number_per_run=self.number_per_run, + built_with_debug_symbols=self.built_with_debug_symbols, + baseline_inclusive_stats=strip(self.baseline_inclusive_stats), + baseline_exclusive_stats=strip(self.baseline_exclusive_stats), + stmt_inclusive_stats=strip(self.stmt_inclusive_stats), + stmt_exclusive_stats=strip(self.stmt_exclusive_stats), + + # `as_standardized` will change symbol names, so the contents will + # no longer map directly to `callgrind.out` + stmt_callgrind_out=None, + ) + + +class Serialization(enum.Enum): + PICKLE = 0 + TORCH = 1 + TORCH_JIT = 2 + + +_GLOBALS_ALLOWED_TYPES: dict[Serialization, tuple[Any, ...]] = { + Serialization.PICKLE: (str, bytes, bool, int, float, complex), + Serialization.TORCH_JIT: (torch.jit.ScriptFunction, torch.jit.ScriptModule), + Serialization.TORCH: (torch.nn.Module,), +} + + +class CopyIfCallgrind: + """Signal that a global may be replaced with a deserialized copy. + + See `GlobalsBridge` for why this matters. 
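For reference, a sketch of the three allowed wrapper categories (the values are placeholders); the resulting dict would be passed as `globals=` to a `Timer` whose `stmt` refers to these names.

```python
import torch
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CopyIfCallgrind

def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

callgrind_globals = {
    "n": CopyIfCallgrind(128),                          # builtin -> pickle
    "model": CopyIfCallgrind(torch.nn.Linear(4, 4)),    # nn.Module -> torch.save / torch.load
    "fn": CopyIfCallgrind(torch.jit.script(add_one)),   # ScriptFunction -> torch.jit.save / load
}
# Each wrapper also accepts an optional setup="..." string which is run in the
# subprocess before the value is loaded (e.g. imports needed for unpickling).
```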
+ """ + def __init__(self, value: Any, *, setup: Optional[str] = None): + for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): + if any(isinstance(value, t) for t in supported_types): + self._value: Any = value + self._setup: Optional[str] = setup + self._serialization: Serialization = method + break + else: + supported_str = "\n".join([ + getattr(t, "__name__", repr(t)) + for t in it.chain(_GLOBALS_ALLOWED_TYPES.values())]) + + raise ValueError( + f"Unsupported type: {type(value)}\n" + f"`collect_callgrind` restricts globals to the following types:\n" + f"{textwrap.indent(supported_str, ' ')}" + ) + + @property + def value(self) -> Any: + return self._value + + @property + def setup(self) -> Optional[str]: + return self._setup + + @property + def serialization(self) -> Serialization: + return self._serialization + + @staticmethod + def unwrap_all(globals: dict[str, Any]) -> dict[str, Any]: + return { + k: (v.value if isinstance(v, CopyIfCallgrind) else v) + for k, v in globals.items() + } + + +class GlobalsBridge: + """Handle the transfer of (certain) globals when collecting Callgrind statistics. + + Key takeaway: Any globals passed must be wrapped in `CopyIfCallgrind` to + work with `Timer.collect_callgrind`. + + Consider the following code snippet: + ``` + import pickle + import timeit + + class Counter: + value = 0 + + def __call__(self): + self.value += 1 + + counter = Counter() + timeit.Timer("counter()", globals={"counter": counter}).timeit(10) + print(counter.value) # 10 + + timeit.Timer( + "counter()", + globals={"counter": pickle.loads(pickle.dumps(counter))} + ).timeit(20) + print(counter.value) # Still 10 + ``` + + In the first case, `stmt` is executed using the objects in `globals`; + however, the addition of serialization and deserialization changes the + semantics and may meaningfully change behavior. + + This is a practical consideration when collecting Callgrind statistics. + Unlike `exec` based execution (which `timeit` uses under the hood) which + can share in-memory data structures with the caller, Callgrind collection + requires an entirely new process in order to run under Valgrind. This means + that any data structures used for statement execution will have to be + serialized and deserialized in the subprocess. + + In order to avoid surprising semantics from (user invisible) process + boundaries, what can be passed through `globals` is severely restricted + for `Timer.collect_callgrind`. It is expected that most setup should be + achievable (albeit perhaps less ergonomically) by passing a `setup` + string. + + There are, however, exceptions. One such class are TorchScripted functions. + Because they require a concrete file with source code it is not possible + to define them using a `setup` string. Another group are torch.nn.Modules, + whose construction can be complex and prohibitively cumbersome to coerce + into a `setup` string. Finally, most builtin types are sufficiently well + behaved and sufficiently common to warrant allowing as well. (e.g. + `globals={"n": 1}` is very convenient.) + + Fortunately, all have well defined serialization semantics. This class + is responsible for enabling the Valgrind subprocess to use elements in + `globals` so long as they are an allowed type. + + Caveats: + The user is required to acknowledge this serialization by wrapping + elements in `globals` with `CopyIfCallgrind`. 
+ + While ScriptFunction and ScriptModule are expected to save and load + quite robustly, it is up to the user to ensure that an nn.Module can + un-pickle successfully. + + `torch.Tensor` and `np.ndarray` are deliberately excluded. The + serialization/deserialization process perturbs the representation of a + tensor in ways that could result in incorrect measurements. For example, + if a tensor lives in pinned CPU memory, this fact would not be preserved + by a dump, and that will in turn change the performance of certain CUDA + operations. + """ + + def __init__(self, globals: dict[str, Any], data_dir: str) -> None: + self._globals: dict[str, CopyIfCallgrind] = {} + self._data_dir = data_dir + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + if globals.get("torch", torch) is not torch: + raise ValueError("`collect_callgrind` does not support mocking out `torch`.") + + for name, value in globals.items(): + if name in ("torch", "__builtins__"): + # Torch will be imported by the collection script, and + # __builtins__ is added by Timer. + continue + + if not isinstance(value, CopyIfCallgrind): + raise ValueError( + "`collect_callgrind` requires that globals be wrapped in " + "`CopyIfCallgrind` so that serialization is explicit." + ) + + self._globals[name] = value + + def construct(self) -> str: + load_lines = [] + for name, wrapped_value in self._globals.items(): + if wrapped_value.setup is not None: + load_lines.append(textwrap.dedent(wrapped_value.setup)) + + if wrapped_value.serialization == Serialization.PICKLE: + path = os.path.join(self._data_dir, f"{name}.pkl") + load_lines.append( + f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") + with open(path, "wb") as f: + pickle.dump(wrapped_value.value, f) + + elif wrapped_value.serialization == Serialization.TORCH: + path = os.path.join(self._data_dir, f"{name}.pt") + # TODO: Figure out if we can use torch.serialization.add_safe_globals here + # Using weights_only=False after the change in + # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573 + load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)") + torch.save(wrapped_value.value, path) + + elif wrapped_value.serialization == Serialization.TORCH_JIT: + path = os.path.join(self._data_dir, f"{name}.pt") + load_lines.append(f"{name} = torch.jit.load({repr(path)})") + with open(path, "wb") as f: + torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call] + + else: + raise NotImplementedError( + f"Unknown serialization method: {wrapped_value.serialization}") + + return "\n".join(load_lines) + + +class _ValgrindWrapper: + def __init__(self) -> None: + self._bindings_module: Optional[CallgrindModuleType] = None + valgrind_symbols = ( + "_valgrind_supported_platform", + "_valgrind_toggle", + "_valgrind_toggle_and_dump_stats", + ) + if all(hasattr(torch._C, symbol) for symbol in valgrind_symbols): + self._supported_platform: bool = torch._C._valgrind_supported_platform() + + else: + print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.") + self._bindings_module = cpp_jit.get_compat_bindings() + assert all(hasattr(self._bindings_module, symbol) for symbol in valgrind_symbols) + self._supported_platform = self._bindings_module._valgrind_supported_platform() + + self._commands_available: dict[str, bool] = {} + if self._supported_platform: + # Only bother checking on supported platforms. 
+ for cmd in ("valgrind", "callgrind_control", "callgrind_annotate"): + self._commands_available[cmd] = not subprocess.run( + ["which", cmd], + capture_output=True, + check=False, + ).returncode + + self._build_type: Optional[str] = None + build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) # type: ignore[no-untyped-call] + if build_search is not None: + self._build_type = build_search.groups()[0].split(",")[0] + + def _validate(self) -> None: + if not self._supported_platform: + raise OSError("Valgrind is not supported on this platform.") + + missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available] + if missing_cmds: + raise OSError("Missing: " + ", ".join(missing_cmds)) + + def collect_callgrind( + self, + task_spec: common.TaskSpec, + globals: dict[str, Any], + *, + number: int, + repeats: int, + collect_baseline: bool, + is_python: bool, + retain_out_file: bool, + ) -> tuple[CallgrindStats, ...]: + """Collect stats, and attach a reference run which can be used to filter interpreter overhead.""" + self._validate() + assert is_python or not collect_baseline + + *task_stats, baseline_stats = self._invoke( + task_spec=task_spec, + globals=globals, + number=number, + repeats=repeats, + collect_baseline=collect_baseline, + is_python=is_python, + retain_out_file=retain_out_file, + ) + assert len(task_stats) == repeats + + return tuple( + CallgrindStats( + task_spec=task_spec, + number_per_run=number, + built_with_debug_symbols=self._build_type == "RelWithDebInfo", + baseline_inclusive_stats=baseline_stats[0], + baseline_exclusive_stats=baseline_stats[1], + stmt_inclusive_stats=stmt_inclusive_stats, + stmt_exclusive_stats=stmt_exclusive_stats, + stmt_callgrind_out=out_contents, + ) + for stmt_inclusive_stats, stmt_exclusive_stats, out_contents in task_stats + ) + + def _invoke( + self, + *, + task_spec: common.TaskSpec, + globals: dict[str, Any], + number: int, + repeats: int, + collect_baseline: bool, + is_python: bool, + retain_out_file: bool, + ) -> tuple[tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]: + """Core invocation method for Callgrind collection. + + Valgrind operates by effectively replacing the CPU with an emulated + version which allows it to instrument any code at the cost of severe + performance degradation. This has the practical effect that in order + to collect Callgrind statistics, a new process has to be created + running under `valgrind`. The steps for this process are: + + 1) Create a scratch directory. + 2) Codegen a run script. (_ValgrindWrapper._construct_script) + Inside the run script: + * Validate that Python and torch match the parent process + * Validate that it is indeed running under valgrind + * Execute `setup` and warm up `stmt` + * Begin collecting stats + * Run the `stmt` loop + * Stop collecting stats + 3) Parse the run results. + 4) Cleanup the scratch directory. 
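+
+        For reference, the resulting command (in the Python case) is roughly:
+
+            valgrind --tool=callgrind --callgrind-out-file=<scratch>/callgrind.out
+                --dump-line=yes --dump-instr=yes --instr-atstart=yes
+                --collect-atstart=no python <scratch>/timer_callgrind.py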
+ """ + working_dir = common._make_temp_dir(prefix="callgrind") + data_dir = os.path.join(working_dir, "data") + script_file = os.path.join(working_dir, "timer_callgrind.py") + callgrind_out = os.path.join(working_dir, "callgrind.out") + error_log = os.path.join(working_dir, "error.txt") + stat_log = os.path.join(working_dir, "callgrind_stat.txt") + stdout_stderr_log = os.path.join(working_dir, "stdout_stderr.log") + + def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]: + # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/ + f_stdout_stderr = open(stdout_stderr_log, "wb") + try: + invocation = subprocess.run( + args, + stdout=f_stdout_stderr, + stderr=subprocess.STDOUT, + **kwargs, + ) + with open(stdout_stderr_log) as f: + return invocation, f.read() + finally: + f_stdout_stderr.close() + + try: + if is_python: + if self._bindings_module is not None: + shutil.copy( + self._bindings_module.__file__, + os.path.join(working_dir, os.path.split(self._bindings_module.__file__)[1]) + ) + + script_file = os.path.join(working_dir, "timer_callgrind.py") + with open(script_file, "w") as f: + f.write(self._construct_script( + task_spec, + globals=GlobalsBridge(globals, data_dir), + number=number, + repeats=repeats, + collect_baseline=collect_baseline, + error_log=error_log, + stat_log=stat_log, + bindings=self._bindings_module)) + + run_loop_cmd = ["python", script_file] + else: + assert not collect_baseline + run_loop_exec = cpp_jit.compile_callgrind_template( + stmt=task_spec.stmt, + setup=task_spec.setup, + global_setup=task_spec.global_setup, + ) + run_loop_cmd = [ + run_loop_exec, + "--number", str(number), + "--number-warmup", str(min(number, 10)), + "--repeats", str(repeats), + "--number-threads", str(task_spec.num_threads), + ] + + valgrind_invocation, valgrind_invocation_output = run([ + "valgrind", + "--tool=callgrind", + f"--callgrind-out-file={callgrind_out}", + "--dump-line=yes", + "--dump-instr=yes", + "--instr-atstart=yes", + "--collect-atstart=no", + ] + run_loop_cmd) + + if valgrind_invocation.returncode: + error_report = "" + if os.path.exists(error_log): + with open(error_log) as f: + error_report = f.read() + if not error_report: + error_report = "Unknown error.\n" + valgrind_invocation_output + + raise OSError(f"Failed to collect callgrind profile:\n{error_report}") + + def parse_output(fpath: str, inclusive: bool) -> FunctionCounts: + _annotate_invocation, annotate_invocation_output = run([ + "callgrind_annotate", + f"--inclusive={'yes' if inclusive else 'no'}", + "--threshold=100", + "--show-percs=no", + fpath + ], check=True) + + total_pattern = re.compile(r"^([0-9,]+)\s+PROGRAM TOTALS") + begin_pattern = re.compile(r"Ir\s+file:function") + function_pattern = re.compile(r"^\s*([0-9,]+)\s+(.+:.+)$") + + class ScanState(enum.Enum): + SCANNING_FOR_TOTAL = 0 + SCANNING_FOR_START = 1 + PARSING = 2 + + scan_state = ScanState.SCANNING_FOR_TOTAL + fn_counts = [] + for l in annotate_invocation_output.splitlines(keepends=False): + if scan_state == ScanState.SCANNING_FOR_TOTAL: + total_match = total_pattern.match(l) + if total_match: + program_totals = int(total_match.groups()[0].replace(",", "")) + scan_state = ScanState.SCANNING_FOR_START + + elif scan_state == ScanState.SCANNING_FOR_START: + if begin_pattern.match(l): + scan_state = ScanState.PARSING + + else: + assert scan_state == ScanState.PARSING + fn_match = function_pattern.match(l) + if fn_match: + ir_str, file_function = fn_match.groups() + ir = 
int(ir_str.replace(",", "")) + if ir == program_totals: # type: ignore[possibly-undefined] + # Callgrind includes some top level red herring symbols when + # a program dumps multiple profiles. + continue + fn_counts.append(FunctionCount(ir, file_function)) + + elif re.match(r"-+", l): + # Ignore heading separator lines. + continue + + else: + break + + assert scan_state == ScanState.PARSING, f"Failed to parse {fpath}" + return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive) + + def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, Optional[str]]: + if i == repeats and not collect_baseline: + # Null baseline. + return ( + FunctionCounts((), inclusive=True), + FunctionCounts((), inclusive=False), + None, + ) + + fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files. + callgrind_out_contents: Optional[str] = None + if retain_out_file: + with open(fpath) as f: + callgrind_out_contents = f.read() + + return ( + parse_output(fpath, inclusive=True), + parse_output(fpath, inclusive=False), + callgrind_out_contents + ) + + return tuple(read_results(i) for i in range(repeats + 1)) + finally: + shutil.rmtree(working_dir) + + @staticmethod + def _construct_script( + task_spec: common.TaskSpec, + globals: GlobalsBridge, + *, + number: int, + repeats: int, + collect_baseline: bool, + error_log: str, + stat_log: str, + bindings: Optional[CallgrindModuleType], + ) -> str: + def block_stmt(stmt: str, indent: int = 0) -> str: + """Partially unroll benchmark loop. + + The naive template looks something like: + "for _ in range({number}): {stmt}" + + However a loop in Python is surprisingly expensive, and significantly + increases the number of background Python instructions. So instead we + partially unroll the loops, with a block size of 100 chosen to keep + the instruction overhead from `range` low while also not ballooning + the size of the generated file. + """ + block_size = 100 + loop_count = number // block_size + if loop_count == 1: + # There is no point in having `for _ in range(1): ...` rather + # than just `...`, and this lets us save shave a few background + # instructions. + loop_count = 0 + remainder = number - block_size * loop_count + blocked_stmt = "" + + if loop_count: + unrolled_stmts = textwrap.indent("\n".join([stmt] * block_size), " " * 4) + blocked_stmt += f"for _ in range({loop_count}):\n{unrolled_stmts}\n" + + if remainder: + blocked_stmt += "\n".join([stmt] * remainder) + + return textwrap.indent(blocked_stmt, " " * indent) + + pass_baseline = ( + "callgrind_bindings._valgrind_toggle()\n" + f"{block_stmt('pass')}\n" + "callgrind_bindings._valgrind_toggle_and_dump_stats()" + ) + + return textwrap.dedent(r""" + import gc + import os + import pickle + import subprocess + import sys + import time + + # Mitigate https://github.com/pytorch/pytorch/issues/37377 + # which can sometimes cause the subprocess call to fail. 
+ import numpy as np + + import torch + torch.set_num_threads({num_threads}) + + {bindings_import} + + PID = os.getpid() + + def log_failure(msg): + with open({error_log_repr}, "wt") as f: + f.write(msg) + sys.exit(1) + + def check_result(completed_process): + if completed_process.returncode: + log_failure(f"Command failed: {{' '.join(completed_process.args)}}") + return completed_process + + # ============================================================================= + # == Check that subprocess matches parent ===================================== + # ============================================================================= + if os.path.realpath(sys.executable) != "{parent_interpreter}": + log_failure( + "Interpreter mismatch:\n" + f" {{os.path.realpath(sys.executable)}}\n vs.\n {parent_interpreter}" + ) + + if torch.__file__ != "{torch_file}": + log_failure( + "PyTorch does not match expected file:\n" + f" {{torch.__file__}}\n vs.\n {torch_file}" + ) + + # ============================================================================= + # == User specified setup ===================================================== + # ============================================================================= + # Load serialized globals + {load_globals} + + # User setup str + {setup} + + for _ in range({warmup_number}): + {indented_stmt} + + # ============================================================================= + # == Callgrind management ===================================================== + # ============================================================================= + with open("{stat_log}", "wb") as stat_file: + # If many instances of callgrind are running at once, the output of + # `callgrind_control` may exceed 16kb which would cause `subprocess.PIPE` + # to deadlock. So instead we use a file. 
+ callgrind_stat = check_result(subprocess.run( + ["callgrind_control", "--stat"], + stdout=stat_file, + stderr=subprocess.STDOUT, + )) + + with open("{stat_log}", "rt") as stat_file: + stat_lines = stat_file.read().splitlines() + + if f"PID {{PID}}: python {{__file__}}" not in stat_lines: + log_failure("Process does not appear to be running callgrind.") + + gc.collect() + time.sleep(0.01) + + # ============================================================================= + # == User code block ========================================================== + # ============================================================================= + for _ in range({repeats}): + callgrind_bindings._valgrind_toggle() + {blocked_stmt} + callgrind_bindings._valgrind_toggle_and_dump_stats() + gc.collect() + + {baseline} + """).strip().format( + indented_stmt=textwrap.indent(task_spec.stmt, " " * 4), + blocked_stmt=block_stmt(task_spec.stmt, indent=4), + baseline=(pass_baseline if collect_baseline else ""), + number=number, + repeats=repeats, + load_globals=globals.construct(), + setup=task_spec.setup, + warmup_number=min(number, 10), + num_threads=task_spec.num_threads, + error_log_repr=repr(error_log), + stat_log=stat_log, + parent_interpreter=os.path.realpath(sys.executable), + torch_file=torch.__file__, + bindings_import=( + "import torch._C as callgrind_bindings" if bindings is None + else f"import {bindings.__name__} as callgrind_bindings"), + ) + + +CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None +def wrapper_singleton() -> _ValgrindWrapper: + global CALLGRIND_SINGLETON + if CALLGRIND_SINGLETON is None: + CALLGRIND_SINGLETON = _ValgrindWrapper() + return CALLGRIND_SINGLETON diff --git a/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h new file mode 100644 index 0000000000000000000000000000000000000000..cf227c56a91dbf0ceb65ff8bb97b8e2e244f2a63 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h @@ -0,0 +1,7157 @@ +/* -*- c -*- + ---------------------------------------------------------------- + + Notice that the following BSD-style license applies to this one + file (valgrind.h) only. The rest of Valgrind is licensed under the + terms of the GNU General Public License, version 2, unless + otherwise indicated. See the COPYING file in the source + distribution for details. + + ---------------------------------------------------------------- + + This file is part of Valgrind, a dynamic binary instrumentation + framework. + + Copyright (C) 2000-2017 Julian Seward. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. The origin of this software must not be misrepresented; you must + not claim that you wrote the original software. If you use this + software in a product, an acknowledgment in the product + documentation would be appreciated but is not required. + + 3. Altered source versions must be plainly marked as such, and must + not be misrepresented as being the original software. + + 4. The name of the author may not be used to endorse or promote + products derived from this software without specific prior written + permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS + OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE + GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------------------------------------------------------- + + Notice that the above BSD-style license applies to this one file + (valgrind.h) only. The entire rest of Valgrind is licensed under + the terms of the GNU General Public License, version 2. See the + COPYING file in the source distribution for details. + + ---------------------------------------------------------------- +*/ + + +/* This file is for inclusion into client (your!) code. + + You can use these macros to manipulate and query Valgrind's + execution inside your own programs. + + The resulting executables will still run without Valgrind, just a + little bit more slowly than they otherwise would, but otherwise + unchanged. When not running on valgrind, each client request + consumes very few (eg. 7) instructions, so the resulting performance + loss is negligible unless you plan to execute client requests + millions of times per second. Nevertheless, if that is still a + problem, you can compile with the NVALGRIND symbol defined (gcc + -DNVALGRIND) so that client requests are not even compiled in. */ + +#ifndef __VALGRIND_H +#define __VALGRIND_H + + +/* ------------------------------------------------------------------ */ +/* VERSION NUMBER OF VALGRIND */ +/* ------------------------------------------------------------------ */ + +/* Specify Valgrind's version number, so that user code can + conditionally compile based on our version number. Note that these + were introduced at version 3.6 and so do not exist in version 3.5 + or earlier. The recommended way to use them to check for "version + X.Y or later" is (eg) + +#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__) \ + && (__VALGRIND_MAJOR__ > 3 \ + || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6)) +*/ +#define __VALGRIND_MAJOR__ 3 +#define __VALGRIND_MINOR__ 17 + + +#include + +/* Nb: this file might be included in a file compiled with -ansi. So + we can't use C++ style "//" comments nor the "asm" keyword (instead + use "__asm__"). */ + +/* Derive some tags indicating what the target platform is. Note + that in this file we're using the compiler's CPP symbols for + identifying architectures, which are different to the ones we use + within the rest of Valgrind. Note, __powerpc__ is active for both + 32 and 64-bit PPC, whereas __powerpc64__ is only active for the + latter (on Linux, that is). 
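+
+   For example, an x86_64 Linux build defines __linux__ and __x86_64__ and
+   therefore selects PLAT_amd64_linux below.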
+ + Misc note: how to find out what's predefined in gcc by default: + gcc -Wp,-dM somefile.c +*/ +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_arm64_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + + +#if defined(__APPLE__) && defined(__i386__) +# define PLAT_x86_darwin 1 +#elif defined(__APPLE__) && defined(__x86_64__) +# define PLAT_amd64_darwin 1 +#elif (defined(__MINGW32__) && defined(__i386__)) \ + || defined(__CYGWIN32__) \ + || (defined(_WIN32) && defined(_M_IX86)) +# define PLAT_x86_win32 1 +#elif (defined(__MINGW32__) && defined(__x86_64__)) \ + || (defined(_WIN32) && defined(_M_X64)) +/* __MINGW32__ and _WIN32 are defined in 64 bit mode as well. */ +# define PLAT_amd64_win64 1 +#elif defined(__linux__) && defined(__i386__) +# define PLAT_x86_linux 1 +#elif defined(__linux__) && defined(__x86_64__) && !defined(__ILP32__) +# define PLAT_amd64_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__) +# define PLAT_ppc32_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF != 2 +/* Big Endian uses ELF version 1 */ +# define PLAT_ppc64be_linux 1 +#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) && _CALL_ELF == 2 +/* Little Endian uses ELF version 2 */ +# define PLAT_ppc64le_linux 1 +#elif defined(__linux__) && defined(__arm__) && !defined(__aarch64__) +# define PLAT_arm_linux 1 +#elif defined(__linux__) && defined(__aarch64__) && !defined(__arm__) +# define PLAT_arm64_linux 1 +#elif defined(__linux__) && defined(__s390__) && defined(__s390x__) +# define PLAT_s390x_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==64) +# define PLAT_mips64_linux 1 +#elif defined(__linux__) && defined(__mips__) && (__mips==32) +# define PLAT_mips32_linux 1 +#elif defined(__linux__) && defined(__nanomips__) +# define PLAT_nanomips_linux 1 +#elif defined(__sun) && defined(__i386__) +# define PLAT_x86_solaris 1 +#elif defined(__sun) && defined(__x86_64__) +# define PLAT_amd64_solaris 1 +#else +/* If we're not compiling for our target platform, don't generate + any inline asms. */ +# if !defined(NVALGRIND) +# define NVALGRIND 1 +# endif +#endif + + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS. There is nothing */ +/* in here of use to end-users -- skip to the next section. */ +/* ------------------------------------------------------------------ */ + +/* + * VALGRIND_DO_CLIENT_REQUEST(): a statement that invokes a Valgrind client + * request. Accepts both pointers and integers as arguments. + * + * VALGRIND_DO_CLIENT_REQUEST_STMT(): a statement that invokes a Valgrind + * client request that does not return a value. + + * VALGRIND_DO_CLIENT_REQUEST_EXPR(): a C expression that invokes a Valgrind + * client request and whose value equals the client request result. Accepts + * both pointers and integers as arguments. Note that such calls are not + * necessarily pure functions -- they may have side effects. 
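+ *
+ * For illustration only: with a hypothetical request code REQ and argument
+ * ARG, a caller could write
+ *
+ *    result = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, REQ, ARG, 0, 0, 0, 0);
+ *
+ * where the first argument (0 here) is the value returned when the program
+ * is not running under Valgrind.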
+ */ + +#define VALGRIND_DO_CLIENT_REQUEST(_zzq_rlval, _zzq_default, \ + _zzq_request, _zzq_arg1, _zzq_arg2, \ + _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (_zzq_rlval) = VALGRIND_DO_CLIENT_REQUEST_EXPR((_zzq_default), \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#define VALGRIND_DO_CLIENT_REQUEST_STMT(_zzq_request, _zzq_arg1, \ + _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + do { (void) VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + (_zzq_request), (_zzq_arg1), (_zzq_arg2), \ + (_zzq_arg3), (_zzq_arg4), (_zzq_arg5)); } while (0) + +#if defined(NVALGRIND) + +/* Define NVALGRIND to completely remove the Valgrind magic sequence + from the compiled code (analogous to NDEBUG's effects on + assert()) */ +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + (_zzq_default) + +#else /* ! NVALGRIND */ + +/* The following defines the magic code sequences which the JITter + spots and handles magically. Don't look too closely at them as + they will rot your brain. + + The assembly code sequences for all architectures is in this one + file. This is because this file must be stand-alone, and we don't + want to have multiple files. + + For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default + value gets put in the return slot, so that everything works when + this is executed not under Valgrind. Args are passed in a memory + block, and so there's no intrinsic limit to the number that could + be passed, but it's currently five. + + The macro args are: + _zzq_rlval result lvalue + _zzq_default default value (result returned when running on real CPU) + _zzq_request request code + _zzq_arg1..5 request params + + The other two macros are used to support function wrapping, and are + a lot simpler. VALGRIND_GET_NR_CONTEXT returns the value of the + guest's NRADDR pseudo-register and whatever other information is + needed to safely run the call original from the wrapper: on + ppc64-linux, the R2 value at the divert point is also needed. This + information is abstracted into a user-visible type, OrigFn. + + VALGRIND_CALL_NOREDIR_* behaves the same as the following on the + guest, but guarantees that the branch instruction will not be + redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64: + branch-and-link-to-r11. VALGRIND_CALL_NOREDIR is just text, not a + complete inline asm, since it needs to be combined with more magic + inline asm stuff to be useful. +*/ + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || (defined(PLAT_x86_win32) && defined(__GNUC__)) \ + || defined(PLAT_x86_solaris) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "roll $3, %%edi ; roll $13, %%edi\n\t" \ + "roll $29, %%edi ; roll $19, %%edi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EDX = client_request ( %EAX ) */ \ + "xchgl %%ebx,%%ebx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + "xchgl %%ecx,%%ecx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%EAX */ \ + "xchgl %%edx,%%edx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgl %%edi,%%edi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || (PLAT_x86_win32 && __GNUC__) + || PLAT_x86_solaris */ + +/* ------------------------- x86-Win32 ------------------------- */ + +#if defined(PLAT_x86_win32) && !defined(__GNUC__) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#if defined(_MSC_VER) + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm rol edi, 3 __asm rol edi, 13 \ + __asm rol edi, 29 __asm rol edi, 19 + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + valgrind_do_client_request_expr((uintptr_t)(_zzq_default), \ + (uintptr_t)(_zzq_request), (uintptr_t)(_zzq_arg1), \ + (uintptr_t)(_zzq_arg2), (uintptr_t)(_zzq_arg3), \ + (uintptr_t)(_zzq_arg4), (uintptr_t)(_zzq_arg5)) + +static __inline uintptr_t +valgrind_do_client_request_expr(uintptr_t _zzq_default, uintptr_t _zzq_request, + uintptr_t _zzq_arg1, uintptr_t _zzq_arg2, + uintptr_t _zzq_arg3, uintptr_t _zzq_arg4, + uintptr_t _zzq_arg5) +{ + volatile uintptr_t _zzq_args[6]; + volatile unsigned int _zzq_result; + _zzq_args[0] = (uintptr_t)(_zzq_request); + _zzq_args[1] = (uintptr_t)(_zzq_arg1); + _zzq_args[2] = (uintptr_t)(_zzq_arg2); + _zzq_args[3] = (uintptr_t)(_zzq_arg3); + _zzq_args[4] = (uintptr_t)(_zzq_arg4); + _zzq_args[5] = (uintptr_t)(_zzq_arg5); + __asm { __asm lea eax, _zzq_args __asm mov edx, _zzq_default + __SPECIAL_INSTRUCTION_PREAMBLE + /* %EDX = client_request ( %EAX ) */ + __asm xchg ebx,ebx + __asm mov _zzq_result, edx + } + return _zzq_result; +} + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %EAX = guest_NRADDR */ \ + __asm xchg ecx,ecx \ + __asm mov __addr, eax \ + } \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_EAX ERROR + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm { __SPECIAL_INSTRUCTION_PREAMBLE \ + __asm xchg edi,edi \ + } \ + } while (0) + +#else +#error Unsupported compiler. +#endif + +#endif /* PLAT_x86_win32 */ + +/* ----------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) \ + || (defined(PLAT_amd64_win64) && defined(__GNUC__)) + +typedef + struct { + unsigned long int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rolq $3, %%rdi ; rolq $13, %%rdi\n\t" \ + "rolq $61, %%rdi ; rolq $51, %%rdi\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RDX = client_request ( %RAX ) */ \ + "xchgq %%rbx,%%rbx" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %RAX = guest_NRADDR */ \ + "xchgq %%rcx,%%rcx" \ + : "=a" (__addr) \ + : \ + : "cc", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_RAX \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%RAX */ \ + "xchgq %%rdx,%%rdx\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "xchgq %%rdi,%%rdi\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------- amd64-Win64 ------------------------- */ + +#if defined(PLAT_amd64_win64) && !defined(__GNUC__) + +#error Unsupported compiler. + +#endif /* PLAT_amd64_win64 */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rlwinm 0,0,3,0,31 ; rlwinm 0,0,13,0,31\n\t" \ + "rlwinm 0,0,29,0,31 ; rlwinm 0,0,19,0,31\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned int _zzq_args[6]; \ + unsigned int _zzq_result; \ + unsigned int* _zzq_ptr; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R11 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +#if defined(PLAT_ppc64le_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + unsigned long int r2; /* what tocptr do we need? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ + "rotldi 0,0,61 ; rotldi 0,0,51\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({ unsigned long int _zzq_args[6]; \ + unsigned long int _zzq_result; \ + unsigned long int* _zzq_ptr; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + _zzq_ptr = _zzq_args; \ + __asm__ volatile("mr 3,%1\n\t" /*default*/ \ + "mr 4,%2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = client_request ( %R4 ) */ \ + "or 1,1,1\n\t" \ + "mr %0,3" /*result*/ \ + : "=b" (_zzq_result) \ + : "b" (_zzq_default), "b" (_zzq_ptr) \ + : "cc", "memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR */ \ + "or 2,2,2\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %R3 = guest_NRADDR_GPR2 */ \ + "or 4,4,4\n\t" \ + "mr %0,3" \ + : "=b" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->r2 = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R12 */ \ + "or 3,3,3\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or 5,5,5\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "mov r12, r12, ror #3 ; mov r12, r12, ror #13 \n\t" \ + "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("mov r3, %1\n\t" /*default*/ \ + "mov r4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = client_request ( R4 ) */ \ + "orr r10, r10, r10\n\t" \ + "mov %0, r3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "cc","memory", "r3", "r4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* R3 = guest_NRADDR */ \ + "orr r11, r11, r11\n\t" \ + "mov %0, r3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "r3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir *%R4 */ \ + "orr r12, r12, r12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr r9, r9, r9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------- */ + +#if defined(PLAT_arm64_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? 
*/ + } + OrigFn; + +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "ror x12, x12, #3 ; ror x12, x12, #13 \n\t" \ + "ror x12, x12, #51 ; ror x12, x12, #61 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("mov x3, %1\n\t" /*default*/ \ + "mov x4, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = client_request ( X4 ) */ \ + "orr x10, x10, x10\n\t" \ + "mov %0, x3" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" ((unsigned long int)(_zzq_default)), \ + "r" (&_zzq_args[0]) \ + : "cc","memory", "x3", "x4"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* X3 = guest_NRADDR */ \ + "orr x11, x11, x11\n\t" \ + "mov %0, x3" \ + : "=r" (__addr) \ + : \ + : "cc", "memory", "x3" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* branch-and-link-to-noredir X8 */ \ + "orr x12, x12, x12\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "orr x9, x9, x9\n\t" \ + : : : "cc", "memory" \ + ); \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------ s390x-linux ------------------------ */ + +#if defined(PLAT_s390x_linux) + +typedef + struct { + unsigned long int nraddr; /* where's the code? */ + } + OrigFn; + +/* __SPECIAL_INSTRUCTION_PREAMBLE will be used to identify Valgrind specific + * code. This detection is implemented in platform specific toIR.c + * (e.g. VEX/priv/guest_s390_decoder.c). 
+ */ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "lr 15,15\n\t" \ + "lr 1,1\n\t" \ + "lr 2,2\n\t" \ + "lr 3,3\n\t" + +#define __CLIENT_REQUEST_CODE "lr 2,2\n\t" +#define __GET_NR_CONTEXT_CODE "lr 3,3\n\t" +#define __CALL_NO_REDIR_CODE "lr 4,4\n\t" +#define __VEX_INJECT_IR_CODE "lr 5,5\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile(/* r2 = args */ \ + "lgr 2,%1\n\t" \ + /* r3 = default */ \ + "lgr 3,%2\n\t" \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CLIENT_REQUEST_CODE \ + /* results = r3 */ \ + "lgr %0, 3\n\t" \ + : "=d" (_zzq_result) \ + : "a" (&_zzq_args[0]), "0" (_zzq_default) \ + : "cc", "2", "3", "memory" \ + ); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __GET_NR_CONTEXT_CODE \ + "lgr %0, 3\n\t" \ + : "=a" (__addr) \ + : \ + : "cc", "3", "memory" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_R1 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + __CALL_NO_REDIR_CODE + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + __VEX_INJECT_IR_CODE); \ + } while (0) + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ---------------- */ + +#if defined(PLAT_mips32_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; + +/* .word 0x342 + * .word 0x742 + * .word 0xC2 + * .word 0x4C2*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "srl $0, $0, 13\n\t" \ + "srl $0, $0, 29\n\t" \ + "srl $0, $0, 3\n\t" \ + "srl $0, $0, 19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* T3 = client_request ( T4 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* %t9 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11" \ + ); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir *%t9 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- mips64-linux ---------------- */ + +#if defined(PLAT_mips64_linux) + +typedef + struct { + unsigned long nraddr; /* where's the code? 
*/ + } + OrigFn; + +/* dsll $0,$0, 3 + * dsll $0,$0, 13 + * dsll $0,$0, 29 + * dsll $0,$0, 19*/ +#define __SPECIAL_INSTRUCTION_PREAMBLE \ + "dsll $0,$0, 3 ; dsll $0,$0,13\n\t" \ + "dsll $0,$0,29 ; dsll $0,$0,19\n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned long int _zzq_args[6]; \ + volatile unsigned long int _zzq_result; \ + _zzq_args[0] = (unsigned long int)(_zzq_request); \ + _zzq_args[1] = (unsigned long int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned long int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned long int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned long int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned long int)(_zzq_arg5); \ + __asm__ volatile("move $11, %1\n\t" /*default*/ \ + "move $12, %2\n\t" /*ptr*/ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = client_request ( $12 ) */ \ + "or $13, $13, $13\n\t" \ + "move %0, $11\n\t" /*result*/ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$11", "$12", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $11 = guest_NRADDR */ \ + "or $14, $14, $14\n\t" \ + "move %0, $11" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$11"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or $15, $15, $15\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or $11, $11, $11\n\t" \ + ); \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +#if defined(PLAT_nanomips_linux) + +typedef + struct { + unsigned int nraddr; /* where's the code? 
*/ + } + OrigFn; +/* + 8000 c04d srl zero, zero, 13 + 8000 c05d srl zero, zero, 29 + 8000 c043 srl zero, zero, 3 + 8000 c053 srl zero, zero, 19 +*/ + +#define __SPECIAL_INSTRUCTION_PREAMBLE "srl[32] $zero, $zero, 13 \n\t" \ + "srl[32] $zero, $zero, 29 \n\t" \ + "srl[32] $zero, $zero, 3 \n\t" \ + "srl[32] $zero, $zero, 19 \n\t" + +#define VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + _zzq_default, _zzq_request, \ + _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ + __extension__ \ + ({ volatile unsigned int _zzq_args[6]; \ + volatile unsigned int _zzq_result; \ + _zzq_args[0] = (unsigned int)(_zzq_request); \ + _zzq_args[1] = (unsigned int)(_zzq_arg1); \ + _zzq_args[2] = (unsigned int)(_zzq_arg2); \ + _zzq_args[3] = (unsigned int)(_zzq_arg3); \ + _zzq_args[4] = (unsigned int)(_zzq_arg4); \ + _zzq_args[5] = (unsigned int)(_zzq_arg5); \ + __asm__ volatile("move $a7, %1\n\t" /* default */ \ + "move $t0, %2\n\t" /* ptr */ \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = client_request( $t0 ) */ \ + "or[32] $t0, $t0, $t0\n\t" \ + "move %0, $a7\n\t" /* result */ \ + : "=r" (_zzq_result) \ + : "r" (_zzq_default), "r" (&_zzq_args[0]) \ + : "$a7", "$t0", "memory"); \ + _zzq_result; \ + }) + +#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ + { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ + volatile unsigned long int __addr; \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + /* $a7 = guest_NRADDR */ \ + "or[32] $t1, $t1, $t1\n\t" \ + "move %0, $a7" /*result*/ \ + : "=r" (__addr) \ + : \ + : "$a7"); \ + _zzq_orig->nraddr = __addr; \ + } + +#define VALGRIND_CALL_NOREDIR_T9 \ + __SPECIAL_INSTRUCTION_PREAMBLE \ + /* call-noredir $25 */ \ + "or[32] $t2, $t2, $t2\n\t" + +#define VALGRIND_VEX_INJECT_IR() \ + do { \ + __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ + "or[32] $t3, $t3, $t3\n\t" \ + ); \ + } while (0) + +#endif +/* Insert assembly code for other platforms here... */ + +#endif /* NVALGRIND */ + + +/* ------------------------------------------------------------------ */ +/* PLATFORM SPECIFICS for FUNCTION WRAPPING. This is all very */ +/* ugly. It's the least-worst tradeoff I can think of. */ +/* ------------------------------------------------------------------ */ + +/* This section defines magic (a.k.a appalling-hack) macros for doing + guaranteed-no-redirection macros, so as to get from function + wrappers to the functions they are wrapping. The whole point is to + construct standard call sequences, but to do the call itself with a + special no-redirect call pseudo-instruction that the JIT + understands and handles specially. This section is long and + repetitious, and I can't see a way to make it shorter. + + The naming scheme is as follows: + + CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc} + + 'W' stands for "word" and 'v' for "void". Hence there are + different macros for calling arity 0, 1, 2, 3, 4, etc, functions, + and for each, the possibility of returning a word-typed result, or + no result. +*/ + +/* Use these to write the name of your wrapper. NOTE: duplicates + VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h. NOTE also: inserts + the default behaviour equivalance class tag "0000" into the name. + See pub_tool_redir.h for details -- normally you don't need to + think about this, though. */ + +/* Use an extra level of macroisation so as to ensure the soname/fnname + args are fully macro-expanded before pasting them together. 
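+
+   For illustration only (adapted from the function-wrapping example in the
+   Valgrind manual): a wrapper for a function  int foo(int, int)  defined in
+   the main executable (soname NONE) could be written as
+
+      int I_WRAP_SONAME_FNNAME_ZU(NONE, foo)(int x, int y)
+      {
+         int    result;
+         OrigFn fn;
+         VALGRIND_GET_ORIG_FN(fn);
+         CALL_FN_W_WW(result, fn, x, y);
+         return result;
+      }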
*/ +#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd + +#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgw00000ZU_,soname,_,fnname) + +#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgw00000ZZ_,soname,_,fnname) + +/* Use this macro from within a wrapper function to collect the + context (address and possibly other info) of the original function. + Once you have that you can then use it in one of the CALL_FN_ + macros. The type of the argument _lval is OrigFn. */ +#define VALGRIND_GET_ORIG_FN(_lval) VALGRIND_GET_NR_CONTEXT(_lval) + +/* Also provide end-user facilities for function replacement, rather + than wrapping. A replacement function differs from a wrapper in + that it has no way to get hold of the original function being + called, and hence no way to call onwards to it. In a replacement + function, VALGRIND_GET_ORIG_FN always returns zero. */ + +#define I_REPLACE_SONAME_FNNAME_ZU(soname,fnname) \ + VG_CONCAT4(_vgr00000ZU_,soname,_,fnname) + +#define I_REPLACE_SONAME_FNNAME_ZZ(soname,fnname) \ + VG_CONCAT4(_vgr00000ZZ_,soname,_,fnname) + +/* Derivatives of the main macros below, for calling functions + returning void. */ + +#define CALL_FN_v_v(fnptr) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_v(_junk,fnptr); } while (0) + +#define CALL_FN_v_W(fnptr, arg1) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_W(_junk,fnptr,arg1); } while (0) + +#define CALL_FN_v_WW(fnptr, arg1,arg2) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) + +#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) + +#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) + +#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) + +#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) + +#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7) \ + do { volatile unsigned long _junk; \ + CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) + +/* ----------------- x86-{linux,darwin,solaris} ---------------- */ + +#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) \ + || defined(PLAT_x86_solaris) + +/* These regs are trashed by the hidden call. No need to mention eax + as gcc can already see that, plus causes gcc to bomb. */ +#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movl %%esp,%%edi\n\t" \ + "andl $0xfffffff0,%%esp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movl %%edi,%%esp\n\t" + +/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned + long) == 4. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned 
long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + 
"pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $12, %%esp\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $8, %%esp\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ 
volatile( \ + VALGRIND_ALIGN_STACK \ + "subl $4, %%esp\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "pushl 48(%%eax)\n\t" \ + "pushl 44(%%eax)\n\t" \ + "pushl 40(%%eax)\n\t" \ + "pushl 36(%%eax)\n\t" \ + "pushl 32(%%eax)\n\t" \ + "pushl 28(%%eax)\n\t" \ + "pushl 24(%%eax)\n\t" \ + "pushl 20(%%eax)\n\t" \ + "pushl 16(%%eax)\n\t" \ + "pushl 12(%%eax)\n\t" \ + "pushl 8(%%eax)\n\t" \ + "pushl 4(%%eax)\n\t" \ + "movl (%%eax), %%eax\n\t" /* target->%eax */ \ + VALGRIND_CALL_NOREDIR_EAX \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "edi" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_x86_linux || PLAT_x86_darwin || PLAT_x86_solaris */ + +/* ---------------- amd64-{linux,darwin,solaris} --------------- */ + +#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) \ + || defined(PLAT_amd64_solaris) + +/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi", \ + "rdi", "r8", "r9", "r10", "r11" + +/* This is all pretty complex. It's so as to make stack unwinding + work reliably. See bug 243270. The basic problem is the sub and + add of 128 of %rsp in all of the following macros. If gcc believes + the CFA is in %rsp, then unwinding may fail, because what's at the + CFA is not what gcc "expected" when it constructs the CFIs for the + places where the macros are instantiated. + + But we can't just add a CFI annotation to increase the CFA offset + by 128, to match the sub of 128 from %rsp, because we don't know + whether gcc has chosen %rsp as the CFA at that point, or whether it + has chosen some other register (eg, %rbp). In the latter case, + adding a CFI annotation to change the CFA offset is simply wrong. + + So the solution is to get hold of the CFA using + __builtin_dwarf_cfa(), put it in a known register, and add a + CFI annotation to say what the register is. We choose %rbp for + this (perhaps perversely), because: + + (1) %rbp is already subject to unwinding. 
If a new register was + chosen then the unwinder would have to unwind it in all stack + traces, which is expensive, and + + (2) %rbp is already subject to precise exception updates in the + JIT. If a new register was chosen, we'd have to have precise + exceptions for it too, which reduces performance of the + generated code. + + However .. one extra complication. We can't just whack the result + of __builtin_dwarf_cfa() into %rbp and then add %rbp to the + list of trashed registers at the end of the inline assembly + fragments; gcc won't allow %rbp to appear in that list. Hence + instead we need to stash %rbp in %r15 for the duration of the asm, + and say that %r15 is trashed instead. gcc seems happy to go with + that. + + Oh .. and this all needs to be conditionalised so that it is + unchanged from before this commit, when compiled with older gccs + that don't support __builtin_dwarf_cfa. Furthermore, since + this header file is freestanding, it has to be independent of + config.h, and so the following conditionalisation cannot depend on + configure time checks. + + Although it's not clear from + 'defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM)', + this expression excludes Darwin. + .cfi directives in Darwin assembly appear to be completely + different and I haven't investigated how they work. + + For even more entertainment value, note we have to use the + completely undocumented __builtin_dwarf_cfa(), which appears to + really compute the CFA, whereas __builtin_frame_address(0) claims + to but actually doesn't. See + https://bugs.kde.org/show_bug.cgi?id=243270#c47 +*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"r"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + "movq %%rbp, %%r15\n\t" \ + "movq %2, %%rbp\n\t" \ + ".cfi_remember_state\n\t" \ + ".cfi_def_cfa rbp, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "movq %%r15, %%rbp\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "movq %%rsp,%%r14\n\t" \ + "andq $0xfffffffffffffff0,%%rsp\n\t" +#define VALGRIND_RESTORE_STACK \ + "movq %%r14,%%rsp\n\t" + +/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned + long) == 8. */ + +/* NB 9 Sept 07. There is a nasty kludge here in all these CALL_FN_ + macros. In order not to trash the stack redzone, we need to drop + %rsp by 128 before the hidden call, and restore afterwards. The + nastyness is that it is only by luck that the stack still appears + to be unwindable during the hidden call - since then the behaviour + of any routine using this macro does not match what the CFI data + says. Sigh. + + Why is this important? Imagine that a wrapper has a stack + allocated local, and passes to the hidden call, a pointer to it. + Because gcc does not know about the hidden call, it may allocate + that local in the redzone. Unfortunately the hidden call may then + trash it before it comes to use it. So we must step clear of the + redzone, for the duration of the hidden call, to make it safe. 
+ + Probably the same problem afflicts the other redzone-style ABIs too + (ppc64-linux); but for those, the stack is + self describing (none of this CFI nonsense) so at least messing + with the stack pointer doesn't give a danger of non-unwindable + stack. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = 
(unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ 
+ "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ 
+ VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $136,%%rsp\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + VALGRIND_ALIGN_STACK \ + "subq $128,%%rsp\n\t" \ + "pushq 96(%%rax)\n\t" \ + "pushq 88(%%rax)\n\t" \ + "pushq 80(%%rax)\n\t" \ + "pushq 72(%%rax)\n\t" \ + "pushq 64(%%rax)\n\t" \ + "pushq 56(%%rax)\n\t" \ + "movq 48(%%rax), %%r9\n\t" \ + "movq 40(%%rax), %%r8\n\t" \ + "movq 32(%%rax), %%rcx\n\t" \ + "movq 24(%%rax), %%rdx\n\t" \ + "movq 16(%%rax), %%rsi\n\t" \ + "movq 8(%%rax), %%rdi\n\t" \ + "movq (%%rax), %%rax\n\t" /* target->%rax */ \ + VALGRIND_CALL_NOREDIR_RAX \ + VALGRIND_RESTORE_STACK \ + VALGRIND_CFI_EPILOGUE \ + : /*out*/ "=a" (_res) \ + : /*in*/ "a" (&_argvec[0]) 
__FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r14", "r15" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_amd64_linux || PLAT_amd64_darwin || PLAT_amd64_solaris */ + +/* ------------------------ ppc32-linux ------------------------ */ + +#if defined(PLAT_ppc32_linux) + +/* This is useful for finding out about the on-stack stuff: + + extern int f9 ( int,int,int,int,int,int,int,int,int ); + extern int f10 ( int,int,int,int,int,int,int,int,int,int ); + extern int f11 ( int,int,int,int,int,int,int,int,int,int,int ); + extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int ); + + int g9 ( void ) { + return f9(11,22,33,44,55,66,77,88,99); + } + int g10 ( void ) { + return f10(11,22,33,44,55,66,77,88,99,110); + } + int g11 ( void ) { + return f11(11,22,33,44,55,66,77,88,99,110,121); + } + int g12 ( void ) { + return f12(11,22,33,44,55,66,77,88,99,110,121,132); + } +*/ + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rlwinm 1,1,0,0,27\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc32-linux, + sizeof(unsigned long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, 
"r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", 
__CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = 
(__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-16\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; 
\ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "addi 1,1,-32\n\t" \ + /* arg12 */ \ + "lwz 3,48(11)\n\t" \ + "stw 3,20(1)\n\t" \ + /* arg11 */ \ + "lwz 3,44(11)\n\t" \ + "stw 3,16(1)\n\t" \ + /* arg10 */ \ + "lwz 3,40(11)\n\t" \ + "stw 3,12(1)\n\t" \ + /* arg9 */ \ + "lwz 3,36(11)\n\t" \ + "stw 3,8(1)\n\t" \ + /* args1-8 */ \ + "lwz 3,4(11)\n\t" /* arg1->r3 */ \ + "lwz 4,8(11)\n\t" \ + "lwz 5,12(11)\n\t" \ + "lwz 6,16(11)\n\t" /* arg4->r6 */ \ + "lwz 7,20(11)\n\t" \ + "lwz 8,24(11)\n\t" \ + "lwz 9,28(11)\n\t" \ + "lwz 10,32(11)\n\t" /* arg8->r10 */ \ + "lwz 11,0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + VALGRIND_RESTORE_STACK \ + "mr %0,3" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc32_linux */ + +/* ------------------------ ppc64-linux ------------------------ */ + +#if defined(PLAT_ppc64be_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 
2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 
48(11)\n\t" /* arg6->r8 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ 
\ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + 
_argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 11,%1\n\t" \ + "std 2,-16(11)\n\t" /* save tocptr */ \ + "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(11)\n\t" \ + "std 3,136(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(11)\n\t" \ + "std 3,128(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(11)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(11)\n\t" \ + "std 3,112(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(11)\n\t" /* arg1->r3 */ \ + "ld 4, 16(11)\n\t" /* arg2->r4 */ \ + "ld 5, 24(11)\n\t" /* arg3->r5 */ \ + "ld 6, 32(11)\n\t" /* arg4->r6 */ \ + "ld 7, 40(11)\n\t" /* arg5->r7 */ \ + "ld 8, 48(11)\n\t" /* arg6->r8 */ \ + "ld 9, 56(11)\n\t" /* arg7->r9 */ \ + "ld 10, 64(11)\n\t" /* arg8->r10 */ \ + "ld 11, 0(11)\n\t" /* target->r11 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ + "mr 11,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(11)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64be_linux */ + +/* 
------------------------- ppc64le-linux ----------------------- */ +#if defined(PLAT_ppc64le_linux) + +/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "lr", "ctr", "xer", \ + "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ + "r0", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ + "r11", "r12", "r13" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +#define VALGRIND_ALIGN_STACK \ + "mr 28,1\n\t" \ + "rldicr 1,1,0,59\n\t" +#define VALGRIND_RESTORE_STACK \ + "mr 1,28\n\t" + +/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned + long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+0]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+1]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+2]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } 
while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+3]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+4]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+5]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + 
volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+6]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+7]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+8]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 
24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+9]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+10]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-128\n\t" /* expand stack frame */ \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 
48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+11]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ + "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3+12]; \ + volatile unsigned long _res; \ + /* _argvec[0] holds current r2 across the call */ \ + _argvec[1] = (unsigned long)_orig.r2; \ + _argvec[2] = (unsigned long)_orig.nraddr; \ + _argvec[2+1] = (unsigned long)arg1; \ + _argvec[2+2] = (unsigned long)arg2; \ + _argvec[2+3] = (unsigned long)arg3; \ + _argvec[2+4] = (unsigned long)arg4; \ + _argvec[2+5] = (unsigned long)arg5; \ + _argvec[2+6] = (unsigned long)arg6; \ + _argvec[2+7] = (unsigned long)arg7; \ + _argvec[2+8] = (unsigned long)arg8; \ + _argvec[2+9] = (unsigned long)arg9; \ + _argvec[2+10] = (unsigned long)arg10; \ + _argvec[2+11] = (unsigned long)arg11; \ + _argvec[2+12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "mr 12,%1\n\t" \ + "std 2,-16(12)\n\t" /* save tocptr */ \ + "ld 2,-8(12)\n\t" /* use nraddr's tocptr */ \ + "addi 1,1,-144\n\t" /* expand stack frame */ \ + /* arg12 */ \ + "ld 3,96(12)\n\t" \ + "std 3,120(1)\n\t" \ + /* arg11 */ \ + "ld 3,88(12)\n\t" \ + "std 3,112(1)\n\t" \ + /* arg10 */ \ 
+ "ld 3,80(12)\n\t" \ + "std 3,104(1)\n\t" \ + /* arg9 */ \ + "ld 3,72(12)\n\t" \ + "std 3,96(1)\n\t" \ + /* args1-8 */ \ + "ld 3, 8(12)\n\t" /* arg1->r3 */ \ + "ld 4, 16(12)\n\t" /* arg2->r4 */ \ + "ld 5, 24(12)\n\t" /* arg3->r5 */ \ + "ld 6, 32(12)\n\t" /* arg4->r6 */ \ + "ld 7, 40(12)\n\t" /* arg5->r7 */ \ + "ld 8, 48(12)\n\t" /* arg6->r8 */ \ + "ld 9, 56(12)\n\t" /* arg7->r9 */ \ + "ld 10, 64(12)\n\t" /* arg8->r10 */ \ + "ld 12, 0(12)\n\t" /* target->r12 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R12 \ + "mr 12,%1\n\t" \ + "mr %0,3\n\t" \ + "ld 2,-16(12)\n\t" /* restore tocptr */ \ + VALGRIND_RESTORE_STACK \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[2]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r28" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_ppc64le_linux */ + +/* ------------------------- arm-linux ------------------------- */ + +#if defined(PLAT_arm_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4", "r12", "r14" + +/* Macros to save and align the stack before making a function + call and restore it afterwards as gcc may not keep the stack + pointer aligned if it doesn't realise calls are being made + to other functions. */ + +/* This is a bit tricky. We store the original stack pointer in r10 + as it is callee-saves. gcc doesn't allow the use of r11 for some + reason. Also, we can't directly "bic" the stack pointer in thumb + mode since r13 isn't an allowed register number in that context. + So use r4 as a temporary, since that is about to get trashed + anyway, just after each use of this macro. Side effect is we need + to be very careful about any future changes, since + VALGRIND_ALIGN_STACK simply assumes r4 is usable. */ +#define VALGRIND_ALIGN_STACK \ + "mov r10, sp\n\t" \ + "mov r4, sp\n\t" \ + "bic r4, r4, #7\n\t" \ + "mov sp, r4\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, r10\n\t" + +/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned + long) == 4. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, 
arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned 
long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "push {r0, r1, r2, r3} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "push {r0} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile 
unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #4 \n\t" \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "push {r0, r1} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr r0, [%1, #40] \n\t" \ + "ldr r1, [%1, #44] \n\t" \ + "ldr r2, [%1, #48] \n\t" \ + "push {r0, r1, r2} \n\t" \ + "ldr r0, [%1, #20] \n\t" \ + "ldr r1, [%1, #24] \n\t" \ + "ldr r2, [%1, #28] \n\t" \ + "ldr r3, [%1, #32] \n\t" \ + "ldr r4, [%1, #36] \n\t" \ + "push {r0, r1, r2, r3, r4} \n\t" \ + "ldr r0, [%1, #4] \n\t" \ + "ldr r1, [%1, #8] \n\t" \ + "ldr r2, [%1, #12] \n\t" \ + "ldr r3, [%1, #16] \n\t" \ + "ldr r4, [%1] \n\t" /* target->r4 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ + VALGRIND_RESTORE_STACK \ + "mov %0, r0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "r10" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm_linux */ + +/* ------------------------ arm64-linux ------------------------ */ + +#if defined(PLAT_arm64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS \ + "x0", "x1", "x2", "x3","x4", "x5", "x6", "x7", "x8", "x9", \ + "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", \ + "x18", "x19", "x20", "x30", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", \ + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", \ + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", \ + "v26", "v27", "v28", "v29", "v30", "v31" + +/* x21 is callee-saved, so we can use it to save and restore SP around + the hidden call. 
*/ +#define VALGRIND_ALIGN_STACK \ + "mov x21, sp\n\t" \ + "bic sp, x21, #15\n\t" +#define VALGRIND_RESTORE_STACK \ + "mov sp, x21\n\t" + +/* These CALL_FN_ macros assume that on arm64-linux, + sizeof(unsigned long) == 8. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : 
/*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned 
long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x20 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11) \ + do { \ + volatile OrigFn _orig = 
(orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10,arg11, \ + arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + VALGRIND_ALIGN_STACK \ + "sub sp, sp, #0x30 \n\t" \ + "ldr x0, [%1, #8] \n\t" \ + "ldr x1, [%1, #16] \n\t" \ + "ldr x2, [%1, #24] \n\t" \ + "ldr x3, [%1, #32] \n\t" \ + "ldr x4, [%1, #40] \n\t" \ + "ldr x5, [%1, #48] \n\t" \ + "ldr x6, [%1, #56] \n\t" \ + "ldr x7, [%1, #64] \n\t" \ + "ldr x8, [%1, #72] \n\t" \ + "str x8, [sp, #0] \n\t" \ + "ldr x8, [%1, #80] \n\t" \ + "str x8, [sp, #8] \n\t" \ + "ldr x8, [%1, #88] \n\t" \ + "str x8, [sp, #16] \n\t" \ + "ldr x8, [%1, #96] \n\t" \ + "str x8, [sp, #24] \n\t" \ + "ldr x8, [%1] \n\t" /* target->x8 */ \ + VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_X8 \ + VALGRIND_RESTORE_STACK \ + "mov %0, x0" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS, "x21" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_arm64_linux */ + +/* ------------------------- s390x-linux ------------------------- */ + +#if defined(PLAT_s390x_linux) + +/* Similar workaround as amd64 (see above), but we use r11 as frame + pointer and save the old r11 in r7. r11 might be used for + argvec, therefore we copy argvec in r1 since r1 is clobbered + after the call anyway. 
*/ +#if defined(__GNUC__) && defined(__GCC_HAVE_DWARF2_CFI_ASM) +# define __FRAME_POINTER \ + ,"d"(__builtin_dwarf_cfa()) +# define VALGRIND_CFI_PROLOGUE \ + ".cfi_remember_state\n\t" \ + "lgr 1,%1\n\t" /* copy the argvec pointer in r1 */ \ + "lgr 7,11\n\t" \ + "lgr 11,%2\n\t" \ + ".cfi_def_cfa r11, 0\n\t" +# define VALGRIND_CFI_EPILOGUE \ + "lgr 11, 7\n\t" \ + ".cfi_restore_state\n\t" +#else +# define __FRAME_POINTER +# define VALGRIND_CFI_PROLOGUE \ + "lgr 1,%1\n\t" +# define VALGRIND_CFI_EPILOGUE +#endif + +/* Nb: On s390 the stack pointer is properly aligned *at all times* + according to the s390 GCC maintainer. (The ABI specification is not + precise in this regard.) Therefore, VALGRIND_ALIGN_STACK and + VALGRIND_RESTORE_STACK are not defined here. */ + +/* These regs are trashed by the hidden call. Note that we overwrite + r14 in s390_irgen_noredir (VEX/priv/guest_s390_irgen.c) to give the + function a proper return address. All others are ABI defined call + clobbers. */ +#if defined(__VX__) || defined(__S390_VX__) +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", \ + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \ + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", \ + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" +#else +#define __CALLER_SAVED_REGS "0", "1", "2", "3", "4", "5", "14", \ + "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7" +#endif + +/* Nb: Although r11 is modified in the asm snippets below (inside + VALGRIND_CFI_PROLOGUE) it is not listed in the clobber section, for + two reasons: + (1) r11 is restored in VALGRIND_CFI_EPILOGUE, so effectively it is not + modified + (2) GCC will complain that r11 cannot appear inside a clobber section, + when compiled with -O -fno-omit-frame-pointer + */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 1, 0(1)\n\t" /* target->r1 */ \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "d" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +/* The call abi has the arguments in r2-r6 and stack */ +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1, arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + 
VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1, arg2, arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1, arg2, arg3, arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1, arg2, arg3, arg4, arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-160\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,160\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-168\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,168\n\t" \ + 
VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-176\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,176\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-184\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,184\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-192\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,192\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", 
__CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-200\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,200\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-208\n\t" \ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,208\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1, arg2, arg3, arg4, arg5, \ + arg6, arg7 ,arg8, arg9, arg10, arg11, arg12)\ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)arg1; \ + _argvec[2] = (unsigned long)arg2; \ + _argvec[3] = (unsigned long)arg3; \ + _argvec[4] = (unsigned long)arg4; \ + _argvec[5] = (unsigned long)arg5; \ + _argvec[6] = (unsigned long)arg6; \ + _argvec[7] = (unsigned long)arg7; \ + _argvec[8] = (unsigned long)arg8; \ + _argvec[9] = (unsigned long)arg9; \ + _argvec[10] = (unsigned long)arg10; \ + _argvec[11] = (unsigned long)arg11; \ + _argvec[12] = (unsigned long)arg12; \ + __asm__ volatile( \ + VALGRIND_CFI_PROLOGUE \ + "aghi 15,-216\n\t" 
\ + "lg 2, 8(1)\n\t" \ + "lg 3,16(1)\n\t" \ + "lg 4,24(1)\n\t" \ + "lg 5,32(1)\n\t" \ + "lg 6,40(1)\n\t" \ + "mvc 160(8,15), 48(1)\n\t" \ + "mvc 168(8,15), 56(1)\n\t" \ + "mvc 176(8,15), 64(1)\n\t" \ + "mvc 184(8,15), 72(1)\n\t" \ + "mvc 192(8,15), 80(1)\n\t" \ + "mvc 200(8,15), 88(1)\n\t" \ + "mvc 208(8,15), 96(1)\n\t" \ + "lg 1, 0(1)\n\t" \ + VALGRIND_CALL_NOREDIR_R1 \ + "aghi 15,216\n\t" \ + VALGRIND_CFI_EPILOGUE \ + "lgr %0, 2\n\t" \ + : /*out*/ "=d" (_res) \ + : /*in*/ "a" (&_argvec[0]) __FRAME_POINTER \ + : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS,"6","7" \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + + +#endif /* PLAT_s390x_linux */ + +/* ------------------------- mips32-linux ----------------------- */ + +#if defined(PLAT_mips32_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. */ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16\n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" /* arg1*/ \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned 
long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "subu $29, $29, 16 \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 16 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 24\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 24 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "nop\n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) 
\n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 32\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 32 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned 
long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 40\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 40 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 48\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ 
+ "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 48 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "subu $29, $29, 8 \n\t" \ + "sw $28, 0($29) \n\t" \ + "sw $31, 4($29) \n\t" \ + "lw $4, 20(%1) \n\t" \ + "subu $29, $29, 56\n\t" \ + "sw $4, 16($29) \n\t" \ + "lw $4, 24(%1) \n\t" \ + "sw $4, 20($29) \n\t" \ + "lw $4, 28(%1) \n\t" \ + "sw $4, 24($29) \n\t" \ + "lw $4, 32(%1) \n\t" \ + "sw $4, 28($29) \n\t" \ + "lw $4, 36(%1) \n\t" \ + "sw $4, 32($29) \n\t" \ + "lw $4, 40(%1) \n\t" \ + "sw $4, 36($29) \n\t" \ + "lw $4, 44(%1) \n\t" \ + "sw $4, 40($29) \n\t" \ + "lw $4, 48(%1) \n\t" \ + "sw $4, 44($29) \n\t" \ + "lw $4, 4(%1) \n\t" \ + "lw $5, 8(%1) \n\t" \ + "lw $6, 12(%1) \n\t" \ + "lw $7, 16(%1) \n\t" \ + "lw $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "addu $29, $29, 56 \n\t" \ + "lw $28, 0($29) \n\t" \ + "lw $31, 4($29) \n\t" \ + "addu $29, $29, 8 \n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_mips32_linux */ + +/* ------------------------- nanomips-linux -------------------- */ + +#if defined(PLAT_nanomips_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$t4", "$t5", "$a0", "$a1", "$a2", \ +"$a3", "$a4", "$a5", "$a6", "$a7", "$t0", "$t1", "$t2", "$t3", \ +"$t8","$t9", "$at" + +/* These CALL_FN_ macros assume that on mips-linux, sizeof(unsigned + long) == 4. 
*/ + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[1]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[2]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[3]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[4]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[5]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[6]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + 
"lw $a4,20(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[7]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[8]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[9]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + __asm__ volatile( \ + "lw $t9, 0(%1)\n\t" \ + "lw $a0, 4(%1)\n\t" \ + "lw $a1, 8(%1)\n\t" \ + "lw $a2,12(%1)\n\t" \ + "lw $a3,16(%1)\n\t" \ + "lw $a4,20(%1)\n\t" \ + "lw $a5,24(%1)\n\t" \ + "lw $a6,28(%1)\n\t" \ + "lw $a7,32(%1)\n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[10]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); 
\ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[11]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long _argvec[12]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long 
_argvec[13]; \ + volatile unsigned long _res; \ + _argvec[0] = (unsigned long)_orig.nraddr; \ + _argvec[1] = (unsigned long)(arg1); \ + _argvec[2] = (unsigned long)(arg2); \ + _argvec[3] = (unsigned long)(arg3); \ + _argvec[4] = (unsigned long)(arg4); \ + _argvec[5] = (unsigned long)(arg5); \ + _argvec[6] = (unsigned long)(arg6); \ + _argvec[7] = (unsigned long)(arg7); \ + _argvec[8] = (unsigned long)(arg8); \ + _argvec[9] = (unsigned long)(arg9); \ + _argvec[10] = (unsigned long)(arg10); \ + _argvec[11] = (unsigned long)(arg11); \ + _argvec[12] = (unsigned long)(arg12); \ + __asm__ volatile( \ + "addiu $sp, $sp, -16 \n\t" \ + "lw $t9,36(%1) \n\t" \ + "sw $t9, 0($sp) \n\t" \ + "lw $t9,40(%1) \n\t" \ + "sw $t9, 4($sp) \n\t" \ + "lw $t9,44(%1) \n\t" \ + "sw $t9, 8($sp) \n\t" \ + "lw $t9,48(%1) \n\t" \ + "sw $t9,12($sp) \n\t" \ + "lw $t9, 0(%1) \n\t" \ + "lw $a0, 4(%1) \n\t" \ + "lw $a1, 8(%1) \n\t" \ + "lw $a2,12(%1) \n\t" \ + "lw $a3,16(%1) \n\t" \ + "lw $a4,20(%1) \n\t" \ + "lw $a5,24(%1) \n\t" \ + "lw $a6,28(%1) \n\t" \ + "lw $a7,32(%1) \n\t" \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $a0 \n\t" \ + "addiu $sp, $sp, 16 \n\t" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) _res; \ + } while (0) + +#endif /* PLAT_nanomips_linux */ + +/* ------------------------- mips64-linux ------------------------- */ + +#if defined(PLAT_mips64_linux) + +/* These regs are trashed by the hidden call. */ +#define __CALLER_SAVED_REGS "$2", "$3", "$4", "$5", "$6", \ +"$7", "$8", "$9", "$10", "$11", "$12", "$13", "$14", "$15", "$24", \ +"$25", "$31" + +/* These CALL_FN_ macros assume that on mips64-linux, + sizeof(long long) == 8. */ + +#define MIPS64_LONG2REG_CAST(x) ((long long)(long)x) + +#define CALL_FN_W_v(lval, orig) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[1]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + __asm__ volatile( \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "0" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_W(lval, orig, arg1) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[2]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" /* arg1*/ \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[3]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + + +#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ + do { \ + 
volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[4]; \ + volatile unsigned long long _res; \ + _argvec[0] = _orig.nraddr; \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[5]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[6]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[7]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[8]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = 
MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[9]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + __asm__ volatile( \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1) \n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[10]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + __asm__ volatile( \ + "dsubu $29, $29, 8\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 8\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ + arg7,arg8,arg9,arg10) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[11]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + 
_argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + __asm__ volatile( \ + "dsubu $29, $29, 16\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 16\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[12]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + __asm__ volatile( \ + "dsubu $29, $29, 24\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 24\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ + arg6,arg7,arg8,arg9,arg10, \ + arg11,arg12) \ + do { \ + volatile OrigFn _orig = (orig); \ + volatile unsigned long long _argvec[13]; \ + volatile unsigned long long _res; \ + _argvec[0] = MIPS64_LONG2REG_CAST(_orig.nraddr); \ + _argvec[1] = MIPS64_LONG2REG_CAST(arg1); \ + _argvec[2] = MIPS64_LONG2REG_CAST(arg2); \ + _argvec[3] = MIPS64_LONG2REG_CAST(arg3); \ + _argvec[4] = MIPS64_LONG2REG_CAST(arg4); \ + _argvec[5] = MIPS64_LONG2REG_CAST(arg5); \ + _argvec[6] = MIPS64_LONG2REG_CAST(arg6); \ + _argvec[7] = MIPS64_LONG2REG_CAST(arg7); \ + _argvec[8] = MIPS64_LONG2REG_CAST(arg8); \ + _argvec[9] = MIPS64_LONG2REG_CAST(arg9); \ + _argvec[10] = MIPS64_LONG2REG_CAST(arg10); \ + _argvec[11] = MIPS64_LONG2REG_CAST(arg11); \ + _argvec[12] = MIPS64_LONG2REG_CAST(arg12); \ + __asm__ volatile( \ + "dsubu $29, $29, 32\n\t" \ + "ld $4, 72(%1)\n\t" \ + "sd $4, 0($29)\n\t" \ + "ld $4, 80(%1)\n\t" \ + "sd $4, 8($29)\n\t" \ + "ld $4, 88(%1)\n\t" \ + "sd $4, 16($29)\n\t" \ + "ld $4, 96(%1)\n\t" \ + "sd $4, 24($29)\n\t" \ + "ld $4, 8(%1)\n\t" \ + "ld $5, 16(%1)\n\t" \ + "ld $6, 24(%1)\n\t" \ + "ld $7, 32(%1)\n\t" \ + "ld $8, 40(%1)\n\t" \ + "ld $9, 48(%1)\n\t" \ + "ld $10, 
56(%1)\n\t" \ + "ld $11, 64(%1)\n\t" \ + "ld $25, 0(%1)\n\t" /* target->t9 */ \ + VALGRIND_CALL_NOREDIR_T9 \ + "daddu $29, $29, 32\n\t" \ + "move %0, $2\n" \ + : /*out*/ "=r" (_res) \ + : /*in*/ "r" (&_argvec[0]) \ + : /*trash*/ "memory", __CALLER_SAVED_REGS \ + ); \ + lval = (__typeof__(lval)) (long)_res; \ + } while (0) + +#endif /* PLAT_mips64_linux */ + +/* ------------------------------------------------------------------ */ +/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS. */ +/* */ +/* ------------------------------------------------------------------ */ + +/* Some request codes. There are many more of these, but most are not + exposed to end-user view. These are the public ones, all of the + form 0x1000 + small_number. + + Core ones are in the range 0x00000000--0x0000ffff. The non-public + ones start at 0x2000. +*/ + +/* These macros are used by tools -- they must be public, but don't + embed them into other programs. */ +#define VG_USERREQ_TOOL_BASE(a,b) \ + ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16)) +#define VG_IS_TOOL_USERREQ(a, b, v) \ + (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000)) + +/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! + This enum comprises an ABI exported by Valgrind to programs + which use client requests. DO NOT CHANGE THE NUMERIC VALUES OF THESE + ENTRIES, NOR DELETE ANY -- add new ones at the end of the most + relevant group. */ +typedef + enum { VG_USERREQ__RUNNING_ON_VALGRIND = 0x1001, + VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002, + + /* These allow any function to be called from the simulated + CPU but run on the real CPU. Nb: the first arg passed to + the function is always the ThreadId of the running + thread! So CLIENT_CALL0 actually requires a 1 arg + function, etc. */ + VG_USERREQ__CLIENT_CALL0 = 0x1101, + VG_USERREQ__CLIENT_CALL1 = 0x1102, + VG_USERREQ__CLIENT_CALL2 = 0x1103, + VG_USERREQ__CLIENT_CALL3 = 0x1104, + + /* Can be useful in regression testing suites -- eg. can + send Valgrind's output to /dev/null and still count + errors. */ + VG_USERREQ__COUNT_ERRORS = 0x1201, + + /* Allows the client program and/or gdbserver to execute a monitor + command. */ + VG_USERREQ__GDB_MONITOR_COMMAND = 0x1202, + + /* Allows the client program to change a dynamic command line + option. */ + VG_USERREQ__CLO_CHANGE = 0x1203, + + /* These are useful and can be interpreted by any tool that + tracks malloc() et al, by using vg_replace_malloc.c. */ + VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301, + VG_USERREQ__RESIZEINPLACE_BLOCK = 0x130b, + VG_USERREQ__FREELIKE_BLOCK = 0x1302, + /* Memory pool support. */ + VG_USERREQ__CREATE_MEMPOOL = 0x1303, + VG_USERREQ__DESTROY_MEMPOOL = 0x1304, + VG_USERREQ__MEMPOOL_ALLOC = 0x1305, + VG_USERREQ__MEMPOOL_FREE = 0x1306, + VG_USERREQ__MEMPOOL_TRIM = 0x1307, + VG_USERREQ__MOVE_MEMPOOL = 0x1308, + VG_USERREQ__MEMPOOL_CHANGE = 0x1309, + VG_USERREQ__MEMPOOL_EXISTS = 0x130a, + + /* Allow printfs to valgrind log. */ + /* The first two pass the va_list argument by value, which + assumes it is the same size as or smaller than a UWord, + which generally isn't the case. Hence are deprecated. + The second two pass the vargs by reference and so are + immune to this problem. */ + /* both :: char* fmt, va_list vargs (DEPRECATED) */ + VG_USERREQ__PRINTF = 0x1401, + VG_USERREQ__PRINTF_BACKTRACE = 0x1402, + /* both :: char* fmt, va_list* vargs */ + VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404, + + /* Stack support. 
*/ + VG_USERREQ__STACK_REGISTER = 0x1501, + VG_USERREQ__STACK_DEREGISTER = 0x1502, + VG_USERREQ__STACK_CHANGE = 0x1503, + + /* Wine support */ + VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601, + + /* Querying of debug info. */ + VG_USERREQ__MAP_IP_TO_SRCLOC = 0x1701, + + /* Disable/enable error reporting level. Takes a single + Word arg which is the delta to this thread's error + disablement indicator. Hence 1 disables or further + disables errors, and -1 moves back towards enablement. + Other values are not allowed. */ + VG_USERREQ__CHANGE_ERR_DISABLEMENT = 0x1801, + + /* Some requests used for Valgrind internal, such as + self-test or self-hosting. */ + /* Initialise IR injection */ + VG_USERREQ__VEX_INIT_FOR_IRI = 0x1901, + /* Used by Inner Valgrind to inform Outer Valgrind where to + find the list of inner guest threads */ + VG_USERREQ__INNER_THREADS = 0x1902 + } Vg_ClientRequest; + +#if !defined(__GNUC__) +# define __extension__ /* */ +#endif + + +/* Returns the number of Valgrinds this code is running under. That + is, 0 if running natively, 1 if running under Valgrind, 2 if + running under Valgrind which is running under another Valgrind, + etc. */ +#define RUNNING_ON_VALGRIND \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* if not */, \ + VG_USERREQ__RUNNING_ON_VALGRIND, \ + 0, 0, 0, 0, 0) \ + + +/* Discard translation of code in the range [_qzz_addr .. _qzz_addr + + _qzz_len - 1]. Useful if you are debugging a JITter or some such, + since it provides a way to make sure valgrind will retranslate the + invalidated area. Returns no value. */ +#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DISCARD_TRANSLATIONS, \ + _qzz_addr, _qzz_len, 0, 0, 0) + +#define VALGRIND_INNER_THREADS(_qzz_addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__INNER_THREADS, \ + _qzz_addr, 0, 0, 0, 0) + + +/* These requests are for getting Valgrind itself to print something. + Possibly with a backtrace. This is a really ugly hack. The return value + is the number of characters printed, excluding the "**** " part at the + start and the backtrace (if present). */ + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +/* Modern GCC will optimize the static routine out if unused, + and unused attribute will shut down warnings about it. */ +static int VALGRIND_PRINTF(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF(const char *format, ...) +{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + +#if defined(__GNUC__) || defined(__INTEL_COMPILER) && !defined(_MSC_VER) +static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...) + __attribute__((format(__printf__, 1, 2), __unused__)); +#endif +static int +#if defined(_MSC_VER) +__inline +#endif +VALGRIND_PRINTF_BACKTRACE(const char *format, ...) 
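/* Illustrative usage only (not something the header itself requires):
   both printf-style requests are ordinary varargs calls and return the
   number of characters printed, or 0 when not running under Valgrind,
   e.g.

       int n = VALGRIND_PRINTF("freed %d blocks\n", nblocks);
       VALGRIND_PRINTF_BACKTRACE("unexpected state in %s\n", tag);

   where nblocks and tag are hypothetical variables in the caller. */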
+{ +#if defined(NVALGRIND) + (void)format; + return 0; +#else /* NVALGRIND */ +#if defined(_MSC_VER) || defined(__MINGW64__) + uintptr_t _qzz_res; +#else + unsigned long _qzz_res; +#endif + va_list vargs; + va_start(vargs, format); +#if defined(_MSC_VER) || defined(__MINGW64__) + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (uintptr_t)format, + (uintptr_t)&vargs, + 0, 0, 0); +#else + _qzz_res = VALGRIND_DO_CLIENT_REQUEST_EXPR(0, + VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, + (unsigned long)format, + (unsigned long)&vargs, + 0, 0, 0); +#endif + va_end(vargs); + return (int)_qzz_res; +#endif /* NVALGRIND */ +} + + +/* These requests allow control to move from the simulated CPU to the + real CPU, calling an arbitrary function. + + Note that the current ThreadId is inserted as the first argument. + So this call: + + VALGRIND_NON_SIMD_CALL2(f, arg1, arg2) + + requires f to have this signature: + + Word f(Word tid, Word arg1, Word arg2) + + where "Word" is a word-sized type. + + Note that these client requests are not entirely reliable. For example, + if you call a function with them that subsequently calls printf(), + there's a high chance Valgrind will crash. Generally, your prospects of + these working are made higher if the called function does not refer to + any global variables, and does not refer to any libc or other functions + (printf et al). Any kind of entanglement with libc or dynamic linking is + likely to have a bad outcome, for tricky reasons which we've grappled + with a lot in the past. +*/ +#define VALGRIND_NON_SIMD_CALL0(_qyy_fn) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL0, \ + _qyy_fn, \ + 0, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL1, \ + _qyy_fn, \ + _qyy_arg1, 0, 0, 0) + +#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL2, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, 0, 0) + +#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0 /* default return */, \ + VG_USERREQ__CLIENT_CALL3, \ + _qyy_fn, \ + _qyy_arg1, _qyy_arg2, \ + _qyy_arg3, 0) + + +/* Counts the number of errors that have been recorded by a tool. Nb: + the tool must record the errors with VG_(maybe_record_error)() or + VG_(unique_error)() for them to be counted. */ +#define VALGRIND_COUNT_ERRORS \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR( \ + 0 /* default return */, \ + VG_USERREQ__COUNT_ERRORS, \ + 0, 0, 0, 0, 0) + +/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing + when heap blocks are allocated in order to give accurate results. This + happens automatically for the standard allocator functions such as + malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete, + delete[], etc. + + But if your program uses a custom allocator, this doesn't automatically + happen, and Valgrind will not do as well. For example, if you allocate + superblocks with mmap() and then allocates chunks of the superblocks, all + Valgrind's observations will be at the mmap() level and it won't know that + the chunks should be considered separate entities. In Memcheck's case, + that means you probably won't get heap block overrun detection (because + there won't be redzones marked as unaddressable) and you definitely won't + get any leak detection. 
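
   As a purely illustrative sketch of the annotations described below
   (my_alloc, my_free, carve_from_superblock and return_to_superblock are
   hypothetical names; rzB is passed as 0 and is_zeroed as 0 here):

     void* my_alloc(size_t n)
     {
        void* p = carve_from_superblock(n);
        VALGRIND_MALLOCLIKE_BLOCK(p, n, 0, 0);
        return p;
     }

     void my_free(void* p)
     {
        VALGRIND_FREELIKE_BLOCK(p, 0);
        return_to_superblock(p);
     }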
+ + The following client requests allow a custom allocator to be annotated so + that it can be handled accurately by Valgrind. + + VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated + by a malloc()-like function. For Memcheck (an illustrative case), this + does two things: + + - It records that the block has been allocated. This means any addresses + within the block mentioned in error messages will be + identified as belonging to the block. It also means that if the block + isn't freed it will be detected by the leak checker. + + - It marks the block as being addressable and undefined (if 'is_zeroed' is + not set), or addressable and defined (if 'is_zeroed' is set). This + controls how accesses to the block by the program are handled. + + 'addr' is the start of the usable block (ie. after any + redzone), 'sizeB' is its size. 'rzB' is the redzone size if the allocator + can apply redzones -- these are blocks of padding at the start and end of + each block. Adding redzones is recommended as it makes it much more likely + Valgrind will spot block overruns. `is_zeroed' indicates if the memory is + zeroed (or filled with another predictable value), as is the case for + calloc(). + + VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a + heap block -- that will be used by the client program -- is allocated. + It's best to put it at the outermost level of the allocator if possible; + for example, if you have a function my_alloc() which calls + internal_alloc(), and the client request is put inside internal_alloc(), + stack traces relating to the heap block will contain entries for both + my_alloc() and internal_alloc(), which is probably not what you want. + + For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out + custom blocks from within a heap block, B, that has been allocated with + malloc/calloc/new/etc, then block B will be *ignored* during leak-checking + -- the custom blocks will take precedence. + + VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK. For + Memcheck, it does two things: + + - It records that the block has been deallocated. This assumes that the + block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - It marks the block as being unaddressable. + + VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a + heap block is deallocated. + + VALGRIND_RESIZEINPLACE_BLOCK informs a tool about reallocation. For + Memcheck, it does four things: + + - It records that the size of a block has been changed. This assumes that + the block was annotated as having been allocated via + VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. + + - If the block shrunk, it marks the freed memory as being unaddressable. + + - If the block grew, it marks the new area as undefined and defines a red + zone past the end of the new block. + + - The V-bits of the overlap between the old and the new block are preserved. + + VALGRIND_RESIZEINPLACE_BLOCK should be put after allocation of the new block + and before deallocation of the old block. + + In many cases, these three client requests will not be enough to get your + allocator working well with Memcheck. More specifically, if your allocator + writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call + will be necessary to mark the memory as addressable just before the zeroing + occurs, otherwise you'll get a lot of invalid write errors. 
For example, + you'll need to do this if your allocator recycles freed blocks, but it + zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK). + Alternatively, if your allocator reuses freed blocks for allocator-internal + data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary. + + Really, what's happening is a blurring of the lines between the client + program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the + memory should be considered unaddressable to the client program, but the + allocator knows more than the rest of the client program and so may be able + to safely access it. Extra client requests are necessary for Valgrind to + understand the distinction between the allocator and the rest of the + program. + + Ignored if addr == 0. +*/ +#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MALLOCLIKE_BLOCK, \ + addr, sizeB, rzB, is_zeroed, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_RESIZEINPLACE_BLOCK(addr, oldSizeB, newSizeB, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__RESIZEINPLACE_BLOCK, \ + addr, oldSizeB, newSizeB, rzB, 0) + +/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. + Ignored if addr == 0. +*/ +#define VALGRIND_FREELIKE_BLOCK(addr, rzB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__FREELIKE_BLOCK, \ + addr, rzB, 0, 0, 0) + +/* Create a memory pool. */ +#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, 0, 0) + +/* Create a memory pool with some flags specifying extended behaviour. + When flags is zero, the behaviour is identical to VALGRIND_CREATE_MEMPOOL. + + The flag VALGRIND_MEMPOOL_METAPOOL specifies that the pieces of memory + associated with the pool using VALGRIND_MEMPOOL_ALLOC will be used + by the application as superblocks to dole out MALLOC_LIKE blocks using + VALGRIND_MALLOCLIKE_BLOCK. In other words, a meta pool is a "2 levels" + pool : first level is the blocks described by VALGRIND_MEMPOOL_ALLOC. + The second level blocks are described using VALGRIND_MALLOCLIKE_BLOCK. + Note that the association between the pool and the second level blocks + is implicit : second level blocks will be located inside first level + blocks. It is necessary to use the VALGRIND_MEMPOOL_METAPOOL flag + for such 2 levels pools, as otherwise valgrind will detect overlapping + memory blocks, and will abort execution (e.g. during leak search). + + Such a meta pool can also be marked as an 'auto free' pool using the flag + VALGRIND_MEMPOOL_AUTO_FREE, which must be OR-ed together with the + VALGRIND_MEMPOOL_METAPOOL. For an 'auto free' pool, VALGRIND_MEMPOOL_FREE + will automatically free the second level blocks that are contained + inside the first level block freed with VALGRIND_MEMPOOL_FREE. + In other words, calling VALGRIND_MEMPOOL_FREE will cause implicit calls + to VALGRIND_FREELIKE_BLOCK for all the second level blocks included + in the first level block. + Note: it is an error to use the VALGRIND_MEMPOOL_AUTO_FREE flag + without the VALGRIND_MEMPOOL_METAPOOL flag. +*/ +#define VALGRIND_MEMPOOL_AUTO_FREE 1 +#define VALGRIND_MEMPOOL_METAPOOL 2 +#define VALGRIND_CREATE_MEMPOOL_EXT(pool, rzB, is_zeroed, flags) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CREATE_MEMPOOL, \ + pool, rzB, is_zeroed, flags, 0) + +/* Destroy a memory pool. 
*/ +#define VALGRIND_DESTROY_MEMPOOL(pool) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__DESTROY_MEMPOOL, \ + pool, 0, 0, 0, 0) + +/* Associate a piece of memory with a memory pool. */ +#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_ALLOC, \ + pool, addr, size, 0, 0) + +/* Disassociate a piece of memory from a memory pool. */ +#define VALGRIND_MEMPOOL_FREE(pool, addr) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_FREE, \ + pool, addr, 0, 0, 0) + +/* Disassociate any pieces outside a particular range. */ +#define VALGRIND_MEMPOOL_TRIM(pool, addr, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_TRIM, \ + pool, addr, size, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MOVE_MEMPOOL(poolA, poolB) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MOVE_MEMPOOL, \ + poolA, poolB, 0, 0, 0) + +/* Resize and/or move a piece associated with a memory pool. */ +#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__MEMPOOL_CHANGE, \ + pool, addrA, addrB, size, 0) + +/* Return 1 if a mempool exists, else 0. */ +#define VALGRIND_MEMPOOL_EXISTS(pool) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MEMPOOL_EXISTS, \ + pool, 0, 0, 0, 0) + +/* Mark a piece of memory as being a stack. Returns a stack id. + start is the lowest addressable stack byte, end is the highest + addressable stack byte. */ +#define VALGRIND_STACK_REGISTER(start, end) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__STACK_REGISTER, \ + start, end, 0, 0, 0) + +/* Unmark the piece of memory associated with a stack id as being a + stack. */ +#define VALGRIND_STACK_DEREGISTER(id) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_DEREGISTER, \ + id, 0, 0, 0, 0) + +/* Change the start and end address of the stack id. + start is the new lowest addressable stack byte, end is the new highest + addressable stack byte. */ +#define VALGRIND_STACK_CHANGE(id, start, end) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__STACK_CHANGE, \ + id, start, end, 0, 0) + +/* Load PDB debug info for Wine PE image_map. */ +#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__LOAD_PDB_DEBUGINFO, \ + fd, ptr, total_size, delta, 0) + +/* Map a code address to a source file name and line number. buf64 + must point to a 64-byte buffer in the caller's address space. The + result will be dumped in there and is guaranteed to be zero + terminated. If no info is found, the first byte is set to zero. */ +#define VALGRIND_MAP_IP_TO_SRCLOC(addr, buf64) \ + (unsigned)VALGRIND_DO_CLIENT_REQUEST_EXPR(0, \ + VG_USERREQ__MAP_IP_TO_SRCLOC, \ + addr, buf64, 0, 0, 0) + +/* Disable error reporting for this thread. Behaves in a stack like + way, so you can safely call this multiple times provided that + VALGRIND_ENABLE_ERROR_REPORTING is called the same number of times + to re-enable reporting. The first call of this macro disables + reporting. Subsequent calls have no effect except to increase the + number of VALGRIND_ENABLE_ERROR_REPORTING calls needed to re-enable + reporting. Child threads do not inherit this setting from their + parents -- they are always created with reporting enabled. */ +#define VALGRIND_DISABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + 1, 0, 0, 0, 0) + +/* Re-enable error reporting, as per comments on + VALGRIND_DISABLE_ERROR_REPORTING. 
*/ +#define VALGRIND_ENABLE_ERROR_REPORTING \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CHANGE_ERR_DISABLEMENT, \ + -1, 0, 0, 0, 0) + +/* Execute a monitor command from the client program. + If a connection is opened with GDB, the output will be sent + according to the output mode set for vgdb. + If no connection is opened, output will go to the log output. + Returns 1 if command not recognised, 0 otherwise. */ +#define VALGRIND_MONITOR_COMMAND(command) \ + VALGRIND_DO_CLIENT_REQUEST_EXPR(0, VG_USERREQ__GDB_MONITOR_COMMAND, \ + command, 0, 0, 0, 0) + + +/* Change the value of a dynamic command line option. + Note that unknown or not dynamically changeable options + will cause a warning message to be output. */ +#define VALGRIND_CLO_CHANGE(option) \ + VALGRIND_DO_CLIENT_REQUEST_STMT(VG_USERREQ__CLO_CHANGE, \ + option, 0, 0, 0, 0) + + +#undef PLAT_x86_darwin +#undef PLAT_amd64_darwin +#undef PLAT_x86_win32 +#undef PLAT_amd64_win64 +#undef PLAT_x86_linux +#undef PLAT_amd64_linux +#undef PLAT_ppc32_linux +#undef PLAT_ppc64be_linux +#undef PLAT_ppc64le_linux +#undef PLAT_arm_linux +#undef PLAT_s390x_linux +#undef PLAT_mips32_linux +#undef PLAT_mips64_linux +#undef PLAT_nanomips_linux +#undef PLAT_x86_solaris +#undef PLAT_amd64_solaris + +#endif /* __VALGRIND_H */ diff --git a/phivenv/Lib/site-packages/torch/utils/bottleneck/__init__.py b/phivenv/Lib/site-packages/torch/utils/bottleneck/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/bottleneck/__main__.py b/phivenv/Lib/site-packages/torch/utils/bottleneck/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7b941641c1aecf881808c206218c984a9be85a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/bottleneck/__main__.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +import argparse +import cProfile +import pstats +import sys +import os + +import torch +from torch.autograd import profiler +from torch.utils.collect_env import get_env_info + + +def redirect_argv(new_argv): + sys.argv[:] = new_argv[:] + + +def compiled_with_cuda(sysinfo): + if sysinfo.cuda_compiled_version: + return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}' + return 'not compiled w/ CUDA' + + +env_summary = """ +-------------------------------------------------------------------------------- + Environment Summary +-------------------------------------------------------------------------------- +PyTorch {pytorch_version}{debug_str} {cuda_compiled} +Running with Python {py_version} and {cuda_runtime} + +`{pip_version} list` truncated output: +{pip_list_output} +""".strip() + + +def run_env_analysis(): + print('Running environment analysis...') + info = get_env_info() + + result: dict[str, str] = {} + + debug_str = '' + if info.is_debug_build: + debug_str = ' DEBUG' + + cuda_avail = '' + if info.is_cuda_available: + cuda = info.cuda_runtime_version + if cuda is not None: + cuda_avail = 'CUDA ' + cuda + else: + cuda = 'CUDA unavailable' + + pip_version = info.pip_version + pip_list_output = info.pip_packages + if pip_list_output is None: + pip_list_output = 'Unable to fetch' + + result = { + 'debug_str': debug_str, + 'pytorch_version': info.torch_version, + 'cuda_compiled': compiled_with_cuda(info), + 'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}', + 'cuda_runtime': cuda_avail, + 'pip_version': pip_version, + 'pip_list_output': pip_list_output, + } + + return env_summary.format(**result) + + +def 
run_cprofile(code, globs, launch_blocking=False): + print('Running your script with cProfile') + prof = cProfile.Profile() + prof.enable() + exec(code, globs, None) + prof.disable() + return prof + + +cprof_summary = """ +-------------------------------------------------------------------------------- + cProfile output +-------------------------------------------------------------------------------- +""".strip() + + +def print_cprofile_summary(prof, sortby='tottime', topk=15): + print(cprof_summary) + cprofile_stats = pstats.Stats(prof).sort_stats(sortby) + cprofile_stats.print_stats(topk) + + +def run_autograd_prof(code, globs): + def run_prof(use_cuda=False): + with profiler.profile(use_cuda=use_cuda) as prof: + exec(code, globs, None) + return prof + + print('Running your script with the autograd profiler...') + result = [run_prof(use_cuda=False)] + if torch.cuda.is_available(): + result.append(run_prof(use_cuda=True)) + else: + result.append(None) + + return result + + +autograd_prof_summary = """ +-------------------------------------------------------------------------------- + autograd profiler output ({mode} mode) +-------------------------------------------------------------------------------- + {description} +{cuda_warning} +{output} +""".strip() + + +def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15): + valid_sortby = ['cpu_time', 'cuda_time', 'cpu_time_total', 'cuda_time_total', 'count'] + if sortby not in valid_sortby: + warn = ('WARNING: invalid sorting option for autograd profiler results: {}\n' + 'Expected `cpu_time`, `cpu_time_total`, or `count`. ' + 'Defaulting to `cpu_time`.') + print(warn.format(sortby)) + sortby = 'cpu_time' + + if mode == 'CUDA': + cuda_warning = ('\n\tBecause the autograd profiler uses the CUDA event API,\n' + '\tthe CUDA time column reports approximately max(cuda_time, cpu_time).\n' + '\tPlease ignore this output if your code does not use CUDA.\n') + else: + cuda_warning = '' + + sorted_events = sorted(prof.function_events, + key=lambda x: getattr(x, sortby), reverse=True) + topk_events = sorted_events[:topk] + + result = { + 'mode': mode, + 'description': f'top {topk} events sorted by {sortby}', + 'output': torch.autograd.profiler_util._build_table(topk_events), + 'cuda_warning': cuda_warning + } + + print(autograd_prof_summary.format(**result)) + + +descript = """ +`bottleneck` is a tool that can be used as an initial step for debugging +bottlenecks in your program. + +It summarizes runs of your script with the Python profiler and PyTorch\'s +autograd profiler. Because your script will be profiled, please ensure that it +exits in a finite amount of time. + +For more complicated uses of the profilers, please see +https://docs.python.org/3/library/profile.html and +https://pytorch.org/docs/main/autograd.html#profiler for more information. +""".strip() + + +def parse_args(): + parser = argparse.ArgumentParser(description=descript) + parser.add_argument('scriptfile', type=str, + help='Path to the script to be run. ' + 'Usually run with `python path/to/script`.') + parser.add_argument('args', type=str, nargs=argparse.REMAINDER, + help='Command-line arguments to be passed to the script.') + return parser.parse_args() + + +def cpu_time_total(autograd_prof): + return sum(event.cpu_time_total for event in autograd_prof.function_events) + + +def main(): + args = parse_args() + + # Customizable constants. 
+ scriptfile = args.scriptfile + scriptargs = [] if args.args is None else args.args + scriptargs.insert(0, scriptfile) + cprofile_sortby = 'tottime' + cprofile_topk = 15 + autograd_prof_sortby = 'cpu_time_total' + autograd_prof_topk = 15 + + redirect_argv(scriptargs) + + sys.path.insert(0, os.path.dirname(scriptfile)) + with open(scriptfile, 'rb') as stream: + code = compile(stream.read(), scriptfile, 'exec') + globs = { + '__file__': scriptfile, + '__name__': '__main__', + '__package__': None, + '__cached__': None, + } + + print(descript) + + env_summary = run_env_analysis() + + if torch.cuda.is_available(): + torch.cuda.init() + cprofile_prof = run_cprofile(code, globs) + autograd_prof_cpu, autograd_prof_cuda = run_autograd_prof(code, globs) + + print(env_summary) + print_cprofile_summary(cprofile_prof, cprofile_sortby, cprofile_topk) + + if not torch.cuda.is_available(): + print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk) + return + + # Print both the result of the CPU-mode and CUDA-mode autograd profilers + # if their execution times are very different. + cuda_prof_exec_time = cpu_time_total(autograd_prof_cuda) + if len(autograd_prof_cpu.function_events) > 0: + cpu_prof_exec_time = cpu_time_total(autograd_prof_cpu) + pct_diff = (cuda_prof_exec_time - cpu_prof_exec_time) / cuda_prof_exec_time + if abs(pct_diff) > 0.05: + print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk) + + print_autograd_prof_summary(autograd_prof_cuda, 'CUDA', autograd_prof_sortby, autograd_prof_topk) + +if __name__ == '__main__': + main() diff --git a/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cfe22b238942e7404c31ce5c583d147dd1ee103 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__main__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__main__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16338055b2a7cce4eff397df30c747f6f2f82773 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/bottleneck/__pycache__/__main__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/bundled_inputs.py b/phivenv/Lib/site-packages/torch/utils/bundled_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..f362fddde7c8d488dc97b608726a0eebf08deec5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/bundled_inputs.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +from typing import Any, TypeVar, Optional, NamedTuple, Union, Callable +from collections.abc import Sequence +import textwrap +import torch +from torch._C import TupleType, ListType +from torch.jit._recursive import wrap_cpp_module + + +T = TypeVar("T") + +MAX_RAW_TENSOR_SIZE = 16 + +class InflatableArg(NamedTuple): + """Helper type for bundled inputs. + + 'value' is the compressed/deflated input that is stored in the model. Value + must be of the same type as the argument to the function that it is a deflated + input for. + + 'fmt' is a formatable code string that is executed to inflate the compressed data into + the appropriate input. It can use 'value' as an input to the format str. 
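As a concrete illustration of the deflate/inflate pattern this docstring describes (a minimal sketch; the shape and dtype are arbitrary), the stub below stores only a single-element tensor in the model, and the fmt string rebuilds a full random tensor when the input is inflated, the same pattern bundle_randn() uses at the end of this file:

import torch
from torch.utils.bundled_inputs import InflatableArg

# Deflated value: a tiny stub that expands to the desired shape without
# materializing the full storage.
stub = torch.zeros(1, dtype=torch.float32).expand(8, 32)

# 'fmt' is executed at inflation time with the stored value substituted for {};
# it must evaluate to a value of the same type as 'value'.
arg = InflatableArg(value=stub, fmt="torch.randn_like({})")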
It must result + in a value of the same type as 'value'. + + 'fmt_fn' is a formatable function code string that is executed to inflate the compressed + data into the appropriate input. It must result in a value of the same type as 'value'. + The function name should be the formatable part of the string. + + Note: Only top level InflatableArgs can be inflated. i.e. you cannot place + an inflatable arg inside of some other structure. You should instead create + an inflatable arg such that the fmt code string returns the full structure + of your input. + """ + + value: Any + fmt: str = "{}" + fmt_fn: str = "" + + +def bundle_inputs( + model: torch.jit.ScriptModule, + inputs: Union[Optional[Sequence[tuple[Any, ...]]], dict[Callable, Optional[Sequence[tuple[Any, ...]]]]], + info: Optional[Union[list[str], dict[Callable, list[str]]]] = None, + *, + _receive_inflate_expr: Optional[list[str]] = None, +) -> torch.jit.ScriptModule: + """Create and return a copy of the specified model with inputs attached. + + The original model is not mutated or changed in any way. + + Models with bundled inputs can be invoked in a uniform manner by + benchmarking and code coverage tools. + + If inputs is passed in as a list then the inputs will be bundled for 'forward'. + If inputs is instead passed in as a map then all the methods specified in the map + will have their corresponding inputs bundled. Info should match watchever type is + chosen for the inputs. + + The returned model will support the following methods: + + `get_all_bundled_inputs_for_() -> List[Tuple[Any, ...]]` + Returns a list of tuples suitable for passing to the model like + `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)` + + `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]` + Returns a dictionary mapping function names to a metadata dictionary. + This nested dictionary maps preset strings like: + 'get_inputs_function_name' -> the name of a function attribute in this model that can be + run to get back a list of inputs corresponding to that function. + 'info' -> the user provided extra information about the bundled inputs + + If forward has bundled inputs then these following functions will also be defined on the returned module: + + `get_all_bundled_inputs() -> List[Tuple[Any, ...]]` + Returns a list of tuples suitable for passing to the model like + `for inp in model.get_all_bundled_inputs(): model(*inp)` + + `get_num_bundled_inputs() -> int` + Equivalent to `len(model.get_all_bundled_inputs())`, + but slightly easier to call from C++. + + Inputs can be specified in one of two ways: + + - The model can define `_generate_bundled_inputs_for_`. + If the user chooses this method inputs[] should map to None + + - The `inputs` argument to this function can be a dictionary mapping functions to a + list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_. + Alternatively if only bundling inputs for forward the map can be omitted and a singular list of inputs + can be provided instead. + + The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a + list of inputs, the inner tuple is the list of args that together make up one input. + For inputs of functions that take one arg, this will be a tuple of length one. The Any, ... + is the actual data that makes up the args, e.g. a tensor. + + Info is an optional parameter that maps functions to a list of strings providing extra information about that + function's bundled inputs. 
Alternatively if only bundling inputs for forward the map can be omitted and + a singular list of information can be provided instead. This could be descriptions, expected outputs, etc. + - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']} + + This function will attempt to optimize arguments so that (e.g.) + arguments like `torch.zeros(1000)` will be represented compactly. + Only top-level arguments will be optimized. + Tensors in lists or tuples will not. + """ + if not isinstance(model, torch.jit.ScriptModule): + raise Exception("Only ScriptModule is supported.") # noqa: TRY002 + + ignored_methods, ignored_attrs = _get_bundled_inputs_attributes_and_methods(model) + clone = torch._C._hack_do_not_use_clone_module_with_class( # type: ignore[attr-defined] + model._c, + ignored_methods, + ignored_attrs, + ) + + # The above cloning function returns a torch._C.scriptmodule and we need a torch.jit.scriptmodule. + # Fortunately theres a function in _recursive that does exactly that conversion. + cloned_module = wrap_cpp_module(clone) + if isinstance(inputs, dict): + assert isinstance(info, dict) or info is None + augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info) + else: + assert isinstance(info, list) or info is None + augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info) + return cloned_module + +def augment_model_with_bundled_inputs( + model: torch.jit.ScriptModule, + inputs: Optional[Sequence[tuple[Any, ...]]] = None, + _receive_inflate_expr: Optional[list[str]] = None, # For debugging. + info: Optional[list[str]] = None, # Optional argument to provide info about forward or its inputs + skip_size_check=False, +) -> None: + """Add bundled sample inputs to a model for the forward function. + + Models with bundled inputs can be invoked in a uniform manner by + benchmarking and code coverage tools. + + Augmented models will support the following methods: + + `get_all_bundled_inputs() -> List[Tuple[Any, ...]]` + Returns a list of tuples suitable for passing to the model like + `for inp in model.get_all_bundled_inputs(): model(*inp)` + + `get_num_bundled_inputs() -> int` + Equivalent to `len(model.get_all_bundled_inputs())`, + but slightly easier to call from C++. + + `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]` + Returns a dictionary mapping function names to a metadata dictionary. + This nested dictionary maps preset strings like: + 'get_inputs_function_name' -> the name of a function attribute in this model that can be + run to get back a list of inputs corresponding to that function. + 'info' -> the user provided extra information about the bundled inputs + + Inputs can be specified in one of two ways: + + - The model can define `_generate_bundled_inputs_for_forward`. + If the user chooses this method inputs should be None + + - `inputs` is a list of inputs of form List[Tuple[Any, ...]]. A list of tuples where the elements + of each tuple are the args that make up one input. 
+ """ + if not isinstance(model, torch.jit.ScriptModule): + raise Exception("Only ScriptModule is supported.") # noqa: TRY002 + + forward: Callable = model.forward + + # Sometimes forward won't have a name attached so just in case + if not hasattr(forward, "__name__"): + forward.__name__ = 'forward' + augment_many_model_functions_with_bundled_inputs( + model, + inputs={forward : inputs}, + _receive_inflate_expr=_receive_inflate_expr, + info={forward : info} if info else None, + skip_size_check=skip_size_check, + ) + + +def augment_many_model_functions_with_bundled_inputs( + model: torch.jit.ScriptModule, + inputs: dict[Callable, Optional[Sequence[tuple[Any, ...]]]], + _receive_inflate_expr: Optional[list[str]] = None, # For debugging. + info: Optional[dict[Callable, list[str]]] = None, # Optional argument to provide info about the function or its inputs + skip_size_check=False, +) -> None: + """Add bundled sample inputs to a model for an arbitrary list of public functions. + + Models with bundled inputs can be invoked in a uniform manner by + benchmarking and code coverage tools. + + Augmented models will support the following methods: + + `get_all_bundled_inputs_for_() -> List[Tuple[Any, ...]]` + Returns a list of tuples suitable for passing to the model like + `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)` + + `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]` + Returns a dictionary mapping function names to a metadata dictionary. + This nested dictionary maps preset strings like: + 'get_inputs_function_name' -> the name of a function attribute in this model that can be + run to get back a list of inputs corresponding to that function. + 'info' -> the user provided extra information about the bundled inputs + + If forward has bundled inputs then these following functions are also defined: + + `get_all_bundled_inputs() -> List[Tuple[Any, ...]]` + Returns a list of tuples suitable for passing to the model like + `for inp in model.get_all_bundled_inputs(): model(*inp)` + + `get_num_bundled_inputs() -> int` + Equivalent to `len(model.get_all_bundled_inputs())`, + but slightly easier to call from C++. + + Inputs can be specified in one of two ways: + + - The model can define `_generate_bundled_inputs_for_`. + If the user chooses this method inputs[] should map to None + + - The `inputs` argument to this function can be a dictionary mapping functions to a + list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_. + The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a + list of inputs, the inner tuple is the list of args that together make up one input. + For inputs of functions that take one arg, this will be a tuple of length one. The Any, ... + is the actual data that makes up the args, e.g. a tensor. + + Info is an optional parameter that maps functions to a list of strings providing extra information about that + function's bundled inputs. This could be descriptions, expected outputs, etc. + - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']} + + This function will attempt to optimize arguments so that (e.g.) + arguments like `torch.zeros(1000)` will be represented compactly. + Only top-level arguments will be optimized. + Tensors in lists or tuples will not. 
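To make the calling convention above concrete, here is a minimal sketch (the module and input shapes are purely illustrative) that bundles a single input for forward and reads it back through the generated accessors:

import torch
from torch.utils.bundled_inputs import bundle_inputs

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

scripted = torch.jit.script(AddOne())

# inputs is a list of tuples; each tuple holds the args for one bundled input.
bundled = bundle_inputs(scripted, inputs=[(torch.zeros(4),)])

# The returned copy exposes the accessors described in the docstrings above.
for inp in bundled.get_all_bundled_inputs():
    bundled(*inp)
print(bundled.get_bundled_inputs_functions_and_info())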
+ """ + if not isinstance(model, torch.jit.ScriptModule): + raise Exception("Only ScriptModule is supported.") # noqa: TRY002 + + if not inputs: + raise Exception("Please provide inputs for at least 1 function") # noqa: TRY002 + + if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"): + raise Exception( # noqa: TRY002 + "Models can only be augmented with bundled inputs once. " + "This Model seems to have already been augmented with " + "bundled inputs. Please start afresh with one that " + "doesn't have bundled inputs.", + ) + + get_bundled_inputs_functions_and_info_template = "" + + for function, input_list in inputs.items(): + if hasattr(function, "__name__"): + function_name = function.__name__ + else: + if hasattr(function, "name"): + function_name = function.name # type: ignore[attr-defined] + else: + raise Exception( # noqa: TRY002 + 'At least one of your functions has no attribute name please ensure all have one. m.foo.name = "foo"') + + + if input_list is not None and not isinstance(input_list, Sequence): + raise TypeError(f"Error inputs for function {function_name} is not a Sequence") + + function_arg_types = [arg.type for arg in function.schema.arguments[1:]] # type: ignore[attr-defined] + deflated_inputs_type: ListType = ListType(TupleType(function_arg_types)) + model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, []) + + if hasattr(model, "_generate_bundled_inputs_for_" + function_name): + if input_list is not None: + raise Exception( # noqa: TRY002 + f"inputs[{function_name}] is not None, but _generate_bundled_inputs_for_{function_name} is already defined" + ) + # Model author already defined _generate_bundled_inputs_for_. + elif input_list is None or len(input_list) == 0: + raise Exception( # noqa: TRY002 + f"inputs for {function_name} must be specified if " + f"_generate_bundled_inputs_for_{function_name} is not already defined" + ) + else: + # Iterate over the inputs and args in each input. + # Accumulate `deflated_inputs` as (possibly) compressed values + # and `parts` to be joined into the expression that unpacks them. + deflated_inputs = [] + parts = [] + for inp_idx, args in enumerate(input_list): + if not isinstance(args, tuple) and not isinstance(args, list): # type: ignore[arg-type] + raise TypeError( + f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List" + ) + deflated_args = [] + parts.append("(") + for arg_idx, arg in enumerate(args): + inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name) + deflated, inflater, helper_definition = _inflate_expr( + arg, + f"deflated[{inp_idx}][{arg_idx}]", + inflate_helper_fn_name, + skip_size_check=skip_size_check, + ) + deflated_args.append(deflated) + parts.append(f" {inflater},") + if helper_definition: + model.define(textwrap.dedent(helper_definition)) + deflated_inputs.append(tuple(deflated_args)) + parts.append("),") + parts.append("") + expr = "\n".join(parts) + + # Back-channel return this expr for debugging. + if _receive_inflate_expr is not None: + _receive_inflate_expr.append(expr) + setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs) + definition = textwrap.dedent(""" + def _generate_bundled_inputs_for_{name}(self): + deflated = self._bundled_inputs_deflated_{name} + return [ + {expr} + ] + """).format(expr=expr, name=function_name) + model.define(definition) + + # Define get_all_bundled_inputs_for_ that caches the generated inputs. 
+ model.define(textwrap.dedent(""" + def get_all_bundled_inputs_for_{name}(self): + all_inputs = self._generate_bundled_inputs_for_{name}() + assert all_inputs is not None + return all_inputs + """).format(name=function_name)) + + # Add to the high level helper methods + inputs_info = repr(info[function]) if info and function in info else '[]' + get_bundled_inputs_functions_and_info_template += f""" + temp_dict : Dict[str,List[str]] = {{}} + info: List[str] = {inputs_info} + + temp_dict['info'] = info + temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{function_name}'] + all_inputs['{function_name}'] = temp_dict + """ + + # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided + if function_name == 'forward': + model.define(textwrap.dedent(""" + def get_all_bundled_inputs(self): + return self.get_all_bundled_inputs_for_forward() + """)) + model.define(textwrap.dedent(""" + def get_num_bundled_inputs(self): + return len(self.get_all_bundled_inputs_for_forward()) + """)) + + # Define some high level helper methods that act on all bundled inputs + model.define(textwrap.dedent(f""" + def get_bundled_inputs_functions_and_info(self): + all_inputs : Dict[str, Dict[str,List[str]]] = {{}} + {get_bundled_inputs_functions_and_info_template} + return all_inputs + """)) + +def _inflate_expr( + arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False +) -> tuple[Union[T, torch.Tensor], str, Optional[str]]: + # Allow custom inflation expressions any object. + # For example, calling custom image-decoding ops. + # Or just use "{}" as the format string to ignore size limits. + if isinstance(arg, InflatableArg): + if arg.fmt_fn: + if arg.fmt not in ["{}", ""]: + raise Exception( # noqa: TRY002 + f"Bundled input argument at position '{ref}' has " + f"both arg.fmt_fn => \n{arg.fmt_fn} " + f"\n and arg.fmt => {arg.fmt}. " + "Please choose `arg.fmt` if the deflater is straightforward or " + "`arg.fmt_fn` if you need a function." + ) + + helper_definition = arg.fmt_fn.format(inflate_helper_fn_name) + expr = f"self.{inflate_helper_fn_name}({ref})" + + return arg.value, expr, helper_definition + else: + return arg.value, arg.fmt.format(ref), None + + if isinstance(arg, torch.Tensor): + # Small-storage tensors can just be saved directly. + if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check: + return arg, ref, None + # Small contiguous tensors can be cloned to have small storage. + # TODO: Should we do this even for non-contiguous tensors? + if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE: + return arg.clone(), ref, None + # Example inputs commonly come from torch.zeros, torch.ones, or torch.full. + # These can be represented compactly. + for fmt in [torch.contiguous_format, torch.channels_last]: + if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item(): + return (arg.flatten()[0].clone().expand(*arg.size()), + f"{ref}.contiguous(memory_format={fmt})", None) + # Prevent big tensors from being bundled by default. + # TODO: Provide more useful diagnostics. + raise Exception( # noqa: TRY002 + f"Bundled input argument at position '{ref}' is " + f"a tensor with storage size {arg._typed_storage().size()}. " + f"You probably don't want to bundle this as an input. 
" + ) + else: + return arg, ref, None + +def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> tuple[list[str], list[str]]: + methods: list[str] = [] + attributes: list[str] = [] + + # Has bundled inputs for forward + if hasattr(script_module, 'get_all_bundled_inputs'): + methods.append('get_all_bundled_inputs') + methods.append('get_num_bundled_inputs') + methods.append('run_on_bundled_input') + + if hasattr(script_module, 'get_bundled_inputs_functions_and_info'): + methods.append('get_bundled_inputs_functions_and_info') + all_info = script_module.get_bundled_inputs_functions_and_info() + for function_name in all_info: + methods.append("get_all_bundled_inputs_for_" + function_name) + methods.append("_generate_bundled_inputs_for_" + function_name) + attributes.append("_bundled_inputs_deflated_" + function_name) + + bundled_inputs_fn = getattr( + script_module, + f"get_all_bundled_inputs_for_{function_name}" + ) + num_bundled_inputs: int = len(bundled_inputs_fn()) + + # Check inflate helper functions for each function, argument and bundled input + func = getattr(script_module, function_name) + for arg_idx in range(len(func.schema.arguments) - 1): + for input_idx in range(num_bundled_inputs): + helper_fn_name = _get_inflate_helper_fn_name( + arg_idx=arg_idx, + input_idx=input_idx, + function_name=function_name + ) + # if the arg has an InflatableArg with fmt_fn, add the helper function name + if hasattr(script_module, helper_fn_name): + methods.append(helper_fn_name) + + return (methods, attributes) + + +def _get_inflate_helper_fn_name( + arg_idx: int, + input_idx: int, + function_name: str, +) -> str: + return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}" + + + +def bundle_randn(*size, dtype=None): + """Generate a tensor that will be inflated with torch.randn.""" + stub = torch.zeros(1, dtype=dtype).expand(*size) + return InflatableArg(value=stub, fmt="torch.randn_like({})") + + +def bundle_large_tensor(t): + """Wrap a tensor to allow bundling regardless of size.""" + return InflatableArg(value=t, fmt="{}") diff --git a/phivenv/Lib/site-packages/torch/utils/checkpoint.py b/phivenv/Lib/site-packages/torch/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0bfeb586654176641a38f85a587efe8878625b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/checkpoint.py @@ -0,0 +1,1586 @@ +# mypy: allow-untyped-defs +import contextlib +import platform +import uuid +import warnings +import weakref +from collections import defaultdict +from typing import * # noqa: F403 +import enum +from weakref import ReferenceType + +import torch +import torch.fx.traceback as fx_traceback +from torch.utils._pytree import tree_map +from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode +from torch.utils._python_dispatch import TorchDispatchMode + +__all__ = [ + "checkpoint", + "checkpoint_sequential", + "CheckpointError", + "CheckpointFunction", + "check_backward_validity", + "detach_variable", + "get_device_states", + "set_device_states", + "noop_context_fn", + "set_checkpoint_early_stop", + "DefaultDeviceType", + "set_checkpoint_debug_enabled", + "CheckpointPolicy", + "SelectiveCheckpointContext", + "create_selective_checkpoint_contexts", + "SAC_IGNORED_OPS", +] + +_DEFAULT_DETERMINISM_MODE = "default" + +_checkpoint_debug_enabled: Optional[bool] = None + + +@contextlib.contextmanager +def set_checkpoint_debug_enabled(enabled: Optional[bool]): + """ + Context manager that sets whether 
checkpoint should print additional debug + information when running. See the ``debug`` flag for + :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that + when set, this context manager overrides the value of ``debug`` passed to + checkpoint. To defer to the local setting, pass ``None`` to this context. + + Args: + enabled (bool): Whether checkpoint should print debug information. + Default is 'None'. + """ + global _checkpoint_debug_enabled + try: + prev = _checkpoint_debug_enabled + _checkpoint_debug_enabled = enabled + yield + finally: + _checkpoint_debug_enabled = prev + + +def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + x = inp.detach() + x.requires_grad = inp.requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__, + ) + + +def check_backward_validity(inputs: Iterable[Any]) -> None: + if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): + warnings.warn( + "None of the inputs have requires_grad=True. Gradients will be None" + ) + + +def _get_device_module(device="cuda"): + if device == "meta": + return torch.device("meta") + device_module = getattr(torch, device) + return device_module + + +class DefaultDeviceType: + r""" + A class that manages the default device type for checkpointing. + + If no non-CPU tensors are present, the default device type will + be used. The default value is 'cuda'. The device type is used in + the checkpointing process when determining which device states + to save and restore for recomputation. + """ + + _default_device_type = "cuda" + + @staticmethod + def set_device_type(device: str = "cuda"): + """ + Set the default device type for checkpointing. + + Args: + device (str): The device type to be set as default. Default is 'cuda'. + """ + DefaultDeviceType._default_device_type = device + + @staticmethod + def get_device_type() -> str: + """ + Get the current default device type for checkpointing. + + Returns: + str: The current default device type. + """ + return DefaultDeviceType._default_device_type + + +def _infer_device_type(*args): + device_types = [] + + def add_device_types(arg): + nonlocal device_types + if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu": + device_types.append(arg.device.type) + tree_map(add_device_types, args) + + device_types_set = set(device_types) + if len(device_types_set) > 1: + warnings.warn( + "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. " + "Device state will only be saved for devices of a single device type, and the remaining " + "devices will be ignored. Consequently, if any checkpointed functions involve randomness, " + "this may result in incorrect gradients. (Note that if CUDA devices are among the devices " + "detected, it will be prioritized; otherwise, the first device encountered will be selected.)" + f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}" + ) + if len(device_types) == 0: + return DefaultDeviceType.get_device_type() + elif "cuda" in device_types_set: + return "cuda" + else: + return device_types[0] + + +# We can't know if the run_fn will internally move some args to different devices, +# which would require logic to preserve rng states for those devices as well. 
+# We could paranoically stash and restore ALL the rng states for all visible devices, +# but that seems very wasteful for most cases. Compromise: Stash the RNG state for +# the device of all Tensor args. +# +# To consider: maybe get_device_states and set_device_states should reside in torch/random.py? +def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: + # This will not error out if "arg" is a CPU tensor or a non-tensor type because + # the conditionals short-circuit. + fwd_device_ids = [] + + def add_device_ids(arg): + nonlocal fwd_device_ids + if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: + fwd_device_ids.append(arg.get_device()) + tree_map(add_device_ids, args) + + fwd_device_states = [] + device_module = _get_device_module(_infer_device_type(*args)) + for device_id in fwd_device_ids: + with device_module.device(device_id): + fwd_device_states.append(device_module.get_rng_state()) + + return fwd_device_ids, fwd_device_states + + +def set_device_states(devices, states, *, device_type=None) -> None: + """Sets random number generator states for the specified devices. + + Args: + devices: Device ids to set states for. + states: States to set. + device_type: ``device_type`` of the devices to set states for. Default + is the device returned by a call to ``DefaultDeviceType.get_device_type()``, + which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``. + """ + if device_type is None: + device_type = DefaultDeviceType.get_device_type() + if device_type == "meta": + return + device_module = _get_device_module(device_type) + for device, state in zip(devices, states): + with device_module.device(device): + device_module.set_rng_state(state) + + +def _get_autocast_kwargs(device_type="cuda"): + if torch.amp.is_autocast_available(device_type): + device_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(device_type), + "dtype": torch.get_autocast_dtype(device_type), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + else: + device_autocast_kwargs = None + + cpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled('cpu'), + "dtype": torch.get_autocast_dtype('cpu'), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + + return device_autocast_kwargs, cpu_autocast_kwargs + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.device_type = _infer_device_type(*args) + ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( + ctx.device_type + ) + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_device_in_fwd = False + device_module = _get_device_module(ctx.device_type) + if getattr(device_module, "_initialized", False): + ctx.had_device_in_fwd = True + ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. 
+ ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "When use_reentrant=True, torch.utils.checkpoint is incompatible" + " with .grad() or passing an `inputs` parameter to .backward()." + " To resolve this error, you can either set use_reentrant=False," + " or call .backward() without passing the `inputs` argument." + ) + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + + # Fill in inputs with appropriate saved tensors. + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. + rng_devices = [] + if ctx.preserve_rng_state and ctx.had_device_in_fwd: + rng_devices = ctx.fwd_devices + with torch.random.fork_rng( + devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type + ): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_device_in_fwd: + set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) + detached_inputs = detach_variable(tuple(inputs)) + + device_autocast_ctx = torch.amp.autocast( + device_type=ctx.device_type, **ctx.device_autocast_kwargs + ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() + with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined] + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary" + ) + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs + ) + + return (None, None) + grads + + +def noop_context_fn(): + return contextlib.nullcontext(), contextlib.nullcontext() + +# Note: [torch.compile and checkpoint] +# TorchDynamo does not step inside utils.checkpoint function. The flow +# looks likes this +# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by +# speculatively checking if the forward function is safe to trace. +# 2) If yes, then Dynamo-generated Fx graph has the wrapped higher +# order op. As a result, TorchDynamo does not look inside utils.checkpoint. +# 3) If not, then TorchDynamo falls back to eager by performing a graph +# break. And here, the following disable wrapper ensures that +# TorchDynamo does not trigger again on the frames created by +# utils.checkpoint innards. 
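Before the definition below, a minimal usage sketch of the non-reentrant variant (the layer and tensor shapes are illustrative only):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# Activations inside `layer` are not kept for backward; they are recomputed
# when y.sum().backward() runs.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()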
+@torch._disable_dynamo +def checkpoint( + function, + *args, + use_reentrant: Optional[bool] = None, + context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, + determinism_check: str = _DEFAULT_DETERMINISM_MODE, + debug: bool = False, + **kwargs +): + r"""Checkpoint a model or part of the model. + + Activation checkpointing is a technique that trades compute for memory. + Instead of keeping tensors needed for backward alive until they are used in + gradient computation during backward, forward computation in checkpointed + regions omits saving tensors for backward and recomputes them during the + backward pass. Activation checkpointing can be applied to any part of a + model. + + There are currently two checkpointing implementations available, determined + by the :attr:`use_reentrant` parameter. It is recommended that you use + ``use_reentrant=False``. Please refer the note below for a discussion of + their differences. + + .. warning:: + + If the :attr:`function` invocation during the backward pass differs + from the forward pass, e.g., due to a global variable, the checkpointed + version may not be equivalent, potentially causing an + error being raised or leading to silently incorrect gradients. + + .. warning:: + + The ``use_reentrant`` parameter should be passed explicitly. In version + 2.4 we will raise an exception if ``use_reentrant`` is not passed. + If you are using the ``use_reentrant=True`` variant, please refer to the + note below for important considerations and potential limitations. + + .. note:: + + The reentrant variant of checkpoint (``use_reentrant=True``) and + the non-reentrant variant of checkpoint (``use_reentrant=False``) + differ in the following ways: + + * Non-reentrant checkpoint stops recomputation as soon as all needed + intermediate activations have been recomputed. This feature is enabled + by default, but can be disabled with :func:`set_checkpoint_early_stop`. + Reentrant checkpoint always recomputes :attr:`function` in its + entirety during the backward pass. + + * The reentrant variant does not record the autograd graph during the + forward pass, as it runs with the forward pass under + :func:`torch.no_grad`. The non-reentrant version does record the + autograd graph, allowing one to perform backward on the graph within + checkpointed regions. + + * The reentrant checkpoint only supports the + :func:`torch.autograd.backward` API for the backward pass without its + `inputs` argument, while the non-reentrant version supports all ways + of performing the backward pass. + + * At least one input and output must have ``requires_grad=True`` for the + reentrant variant. If this condition is unmet, the checkpointed part + of the model will not have gradients. The non-reentrant version does + not have this requirement. + + * The reentrant version does not consider tensors in nested structures + (e.g., custom objects, lists, dicts, etc) as participating in + autograd, while the non-reentrant version does. + + * The reentrant checkpoint does not support checkpointed regions with + detached tensors from the computational graph, whereas the + non-reentrant version does. For the reentrant variant, if the + checkpointed segment contains tensors detached using ``detach()`` or + with :func:`torch.no_grad`, the backward pass will raise an error. + This is because ``checkpoint`` makes all the outputs require gradients + and this causes issues when a tensor is defined to have no gradient in + the model. 
To avoid this, detach the tensors outside of the + ``checkpoint`` function. + + Args: + function: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. Note that under torch.compile, + this flag doesn't take effect and we always preserve RNG state. + Default: ``True`` + use_reentrant(bool): + specify whether to use the activation checkpoint variant that + requires reentrant autograd. This parameter should be passed + explicitly. In version 2.5 we will raise an exception if + ``use_reentrant`` is not passed. If ``use_reentrant=False``, + ``checkpoint`` will use an implementation that does not require + reentrant autograd. This allows ``checkpoint`` to support additional + functionality, such as working as expected with + ``torch.autograd.grad`` and support for keyword arguments input into + the checkpointed function. + context_fn(Callable, optional): A callable returning a tuple of two + context managers. The function and its recomputation will be run + under the first and second context managers respectively. + This argument is only supported if ``use_reentrant=False``. + determinism_check(str, optional): A string specifying the determinism + check to perform. By default it is set to ``"default"`` which + compares the shapes, dtypes, and devices of the recomputed tensors + against those the saved tensors. To turn off this check, specify + ``"none"``. Currently these are the only two supported values. + Please open an issue if you would like to see more determinism + checks. This argument is only supported if ``use_reentrant=False``, + if ``use_reentrant=True``, the determinism check is always disabled. + debug(bool, optional): If ``True``, error messages will also include + a trace of the operators ran during the original forward computation + as well as the recomputation. This argument is only supported if + ``use_reentrant=False``. + args: tuple containing inputs to the :attr:`function` + + Returns: + Output of running :attr:`function` on :attr:`*args` + """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint: the use_reentrant parameter should be " + "passed explicitly. In version 2.5 we will raise an exception " + "if use_reentrant is not passed. use_reentrant=False is " + "recommended, but if you need to preserve the current default " + "behavior, you can pass use_reentrant=True. Refer to docs for more " + "details on the differences between the two variants.", + stacklevel=2 + ) + use_reentrant = True + + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs and use_reentrant: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + if use_reentrant: + if context_fn is not noop_context_fn or debug is not False: + raise ValueError( + "Passing `context_fn` or `debug` is only supported when " + "use_reentrant=False." 
+ ) + return CheckpointFunction.apply(function, preserve, *args) + else: + gen = _checkpoint_without_reentrant_generator( + function, preserve, context_fn, determinism_check, debug, *args, **kwargs + ) + # Runs pre-forward logic + next(gen) + ret = function(*args, **kwargs) + # Runs post-forward logic + try: + next(gen) + except StopIteration: + return ret + + +def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): + r"""Checkpoint a sequential model to save memory. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a model in various segments + and checkpoint each segment. All segments except the last will not store + the intermediate activations. The inputs of each checkpointed segment will + be saved for re-running the segment in the backward pass. + + .. warning:: + The ``use_reentrant`` parameter should be passed explicitly. In version + 2.4 we will raise an exception if ``use_reentrant`` is not passed. + If you are using the ``use_reentrant=True` variant, please see + :func:`~torch.utils.checkpoint.checkpoint` for + the important considerations and limitations of this variant. It is + recommended that you use ``use_reentrant=False``. + + .. warning: + Since PyTorch 1.4, it allows only one Tensor as the input and + intermediate outputs, just like :class:`torch.nn.Sequential`. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or + functions (comprising the model) to run sequentially. + segments: Number of chunks to create in the model + input: A Tensor that is input to :attr:`functions` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. + Default: ``True`` + use_reentrant(bool): + specify whether to use the activation checkpoint variant that + requires reentrant autograd. This parameter should be passed + explicitly. In version 2.5 we will raise an exception if + ``use_reentrant`` is not passed. If ``use_reentrant=False``, + ``checkpoint`` will use an implementation that does not require + reentrant autograd. This allows ``checkpoint`` to support additional + functionality, such as working as expected with + ``torch.autograd.grad`` and support for keyword arguments input into + the checkpointed function. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> # xdoctest: +SKIP("stub") + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_sequential(model, chunks, input_var) + """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " + "parameter should be passed explicitly. " + "In version 2.5 we will raise an exception if use_reentrant " + "is not passed. use_reentrant=False is " + "recommended, but if you need to preserve the current default " + "behavior, you can pass use_reentrant=True. Refer to docs for more " + "details on the differences between the two variants." 
+ ) + use_reentrant = True + + # Hack for keyword-only parameter in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def run_function(start, end, functions): + def forward(input): + for j in range(start, end + 1): + input = functions[j](input) + return input + + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = list(functions.children()) + + segment_size = len(functions) // segments + # the last chunk has to be non-volatile + end = -1 + for start in range(0, segment_size * (segments - 1), segment_size): + end = start + segment_size - 1 + input = checkpoint( + run_function(start, end, functions), + input, + use_reentrant=use_reentrant, + preserve_rng_state=preserve, + ) + return run_function(end + 1, len(functions) - 1, functions)(input) + + +def _internal_assert(cond): + if not cond: + raise AssertionError( + "Something went unexpectedly wrong in activation checkpoint. " + "Please report this bug by filing an issue to PyTorch." + ) + + +# NOTE [ Nestable Checkpoint ] +# +# The semantics of nested checkpoint can be defined by two basic rules. +# Following the two rules leads to an important implication that is central +# to motivating the design. +# +# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden +# from any outer layers of checkpoint. +# +# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its +# parent checkpoint. +# +# Implication: To recompute any given saved tensor, we need to recompute all of +# the checkpoints wrapping it. +# +# Why is this implied? To unpack a saved tensor X during backward we need to +# recompute the inner-most checkpoint (#1), and in order to recompute that +# checkpoint I need to have its inputs, which are managed by that checkpoint's +# parent (#2), which thus also needs to be recomputed first. Continue this line +# of reasoning and we realize that in order to unpack X, all checkpoints that +# were active at the time X was saved need to be recomputed. (unless we have +# already done so in that backward for some other saved tensor). +# +# In practice, we use a noop autograd Function to save inputs as saved tensors. +# During unpack calling ctx.saved_tensor triggers the parent checkpoint to +# recompute. +# +# Rule 3. We should start recomputation as if there are no checkpoints currently +# active. Checkpoints encountered during recomputation are still +# respected. +# +# When we start recomputation, we push the saved variable hook meant for +# recomputation on the stack. See examples in Rule 6 for more context. +# +# * * * * +# +# Beyond the basic semantics specific to nested checkpoint, we impose several +# more constraints that may apply to checkpointing in general. +# +# Rule 4. Lifetime of recomputed tensors +# +# Recomputed tensors are considered specific to particular invocations +# of backward and are always cleared immediately as they are unpacked +# Particularly, we require this to happen even if retain_graph=True. +# +# [ Implementation details of Rule 4 ] +# +# If we were okay with recomputed tensors staying alive after backward is run +# with retain_graph=True, we would store recomputed variables as the values of a +# WeakKeyDictionary and pack strong references to the keys, so that as we +# backward, those packed keys would be cleared as long as retain_graph=False. 
+# Clearing the packed key clears the corresponding entry in the WKD. +# +# If we wish recomputed variables to be immediately cleared as we unpack them in +# the retain_graph=True case, we cannot rely on the packed keys to be cleared by +# backward automatically. Instead of packing the strong reference to the key +# directly, we pack a container object, which we manually clear as we unpack. +# +# An important detail is that if a second backward happens, the second +# recomputation needs to reset the container with a newly created key. +# +# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we +# know we need. +# +# [ Implementation details of Rule 5 ] +# +# During recomputation, raise an exception if the number of recomputed tensors +# matches the number of tensors that we expected to recompute. We wrap the +# recomputation call with a try-catch to catch this specific exception. See +# Rule #6 below for some examples. +# +# Rule 6. We support doing backward inside checkpoint context +# +# [ retain_graph is True] +# +# def fn(x): +# y = x.sin() +# z = y.cos() +# gx, = torch.autograd.grad(z, x, retains_grad=True) +# return gx, z +# +# out = checkpoint(fn)(inp) +# out.backward() +# +# Because z is saved by cos while checkpoint is enabled, it would not be +# actually saved, and so the .grad() call inside must trigger a recomputation. +# +# During recomputation the "inner pack hook" has two responsibilities: +# +# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors +# 2) Pack the actual tensor (detached) so that one may perform backward on the +# recomputed graph. The tensors saved to this graph will live until the end +# of recomputation, or die earlier if someone performs backward with +# retain_graph=False. +# +# More generally performing backward on the recomputed graph occurs in the +# following cases: +# - If backward is performed inside forward, +# - During the original forward IF early-stop is disabled +# - During the original backward +# - If there are multiple .grad()/.backward() calls, we would perform backward +# on the recomputed graph even if early-stop is enabled (see the example below) +# +# [ retain_graph is False ] +# +# The example below shows what happens if during recomputation we find that some +# of the tensors we are trying to recompute have already been cleared. +# +# Spoiler: we don't do anything special, we just skip over them! +# +# def fn(x): +# y = x.sin() # (1) +# z = y.cos() # (2) +# gx, = torch.autograd.grad(z, x) # (3) +# return x.cos() * gx # (4) +# +# out = checkpoint(fn)(inp) +# out.backward() # (5) +# +# 1, 2. Don't save x and y since we are inside a checkpoint. +# 3. Trigger a recompute of fn since x and y weren't saved. +# And depending on whether early stop is enabled, either stop at (2) or +# continue running the function. +# Because we are running backward with retain_graph=False, we clear x and y's +# holders. +# 4. Don't save x since we are inside a checkpoint. +# 5. Calling backward triggers another recompute of fn. During recompute, we see +# that x and y have already been cleared in the original graph as indicated +# by holder=None. We skip over them. We still save x at (4) (since its holder +# is still alive.) + +_enable_checkpoint_early_stop = True + + +@contextlib.contextmanager +def set_checkpoint_early_stop(enable: bool): + """Context manager that sets whether checkpoint should stop recomputation early. 
+ + By default, non-reentrant checkpoint stops recomputation as soon as it + has computed all needed Tensors. This context manager can be used to disable + that feature if it is problematic for your specific application. + + This context manager only needs to be active when forward is run. It does + not need to be active during backward. + + Example:: + + >>> # xdoctest: +SKIP(failing) + >>> message = "saved tensors default hooks are disabled" + >>> with set_checkpoint_early_stop(False): + ... # Any checkpoint under this context manager will respect this + ... # context manager, even if its backward is performed outside. + ... out = checkpoint(fn, inputs) + ... + >>> out.backward() + """ + global _enable_checkpoint_early_stop + try: + prev = _enable_checkpoint_early_stop + _enable_checkpoint_early_stop = enable + yield + finally: + _enable_checkpoint_early_stop = prev + + +class _Handle: + pass + + +class _Holder: + def __init__(self): + self.handles: Dict[int, Optional[_Handle]] = {} + + +class _NoopSaveInputs(torch.autograd.Function): + @staticmethod + def forward(*args): + return torch.empty((0,)) + + @staticmethod + def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: + # Only tensors can be saved with ctx.save_for_backward, everything else + # is captured by get_args, which is saved directly on ctx + tensor_indices, tensors = zip( + *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] + ) + idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} + # args but with tensors replaced with None as placeholders + args = [None if isinstance(o, torch.Tensor) else o for o in inputs] + + def get_args(saved_tensors): + # restore the placeholders with the original tensors grabbed from + # ctx.saved_tensors (which may be saved on a parent checkpoint if + # this checkpoint is nested, and that would trigger a recursive + # unpack!) + ret = [ + saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o + for i, o in enumerate(args) + ] + # grab the tail since we also saved the dummy to avoid having to explicitly + # handle the case where there are no tensor inputs + return ret[1:] + + ctx.get_args = get_args + ctx.save_for_backward(*tensors) + + @staticmethod + def backward(ctx, *grad_outputs): + raise AssertionError("Did not expect to backward on this graph") + + +class _CheckpointFrame: + def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): + self.recompute_fn = recompute_fn + self.input_saver = None + self.weak_holders: List[ReferenceType] = [] + # We store this as a weakkeydictionary so that in the case of a partial + # backward, the entries in the dict are cleared alongside the Holder + # which will be removed when the SavedVariable is cleared. 
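+        # Illustrative shape of this structure (an editorial sketch, not upstream
+        # documentation):
+        #   self.recomputed[gid][handle] -> recomputed tensor
+        # where `gid` is the current graph task id and `handle` is the _Handle
+        # held by a _Holder. The unpack hook reads the tensor and then sets
+        # holder.handles[gid] = None, which lets the WeakKeyDictionary entry be
+        # collected immediately (see Rule 4 above).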
+ self.recomputed: DefaultDict[ + int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] + ] = defaultdict(weakref.WeakKeyDictionary) + # We need both recomp_counter and recomputed since they can diverge + # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885 + self.recomp_counter: DefaultDict[int, int] = defaultdict(int) + self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) + + # See Rule 5 + self.early_stop = early_stop + + # Debugging + self.metadata_fn = metadata_fn + self.unpack_error_cb = unpack_error_cb + self.x_metadatas = [] + self.forward_completed = False + self.ignore_saved_mismatch = False + + def check_recomputed_tensors_match(self, gid): + if self.ignore_saved_mismatch: + # TODO: we can probably make this check stricter by checking that + # the metadata of the first tensors still match. + return + # NOTE [ Error handling for checkpoint ] + # + # At a high level, we need to check that the tensors saved + # during original forward matches tensors saved during recompute + # This means handling 3 cases: + # + # 1. During recompute, more tensors were saved. + # + # Usually this is hidden due to the StopRecomputationError + # but if early stop is not enabled, or we would have errored + # anyway because there aren't enough weak_holders. But we + # do want to have a nice error. See the _recomputation_hook + # for details. + if not len(self.weak_holders) == self.recomp_counter[gid]: + # 2. During recompute, fewer tensors were saved + # + # We know that everytime we save something do original forward + # we append to weak_holder, and every time we save a tensor + # during recompute we increment recompute_counter. + raise CheckpointError( + "torch.utils.checkpoint: A different number of tensors was saved " + "during the original forward and recomputation.\n" + f"Number of tensors saved during forward: {len(self.weak_holders)}\n" + f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n" + f"{_debug_tip_msg}" + ) + + # 3. During recompute, the same tensors were saved, but they + # have different metadata + nb_meta_different = [] + for idx, weak_holder in enumerate(self.weak_holders): + holder = weak_holder() + if holder is None: + continue + # We've seen all holders since we iterate over them in order + # For every holder that is still alive now, it must've been + # alive when we saw it during recompute, therefore, the + # gid must be set. + _internal_assert(gid in holder.handles) + # We know this is the first unpack, so it couldn't have been set + # to None yet. + _internal_assert(holder.handles[gid] is not None) + # We always set these together in the recomputation hook + _internal_assert(holder.handles[gid] in self.recomputed[gid]) + # see pack hook, x_metadata is 1:1 with weak_holders. 
+            x_meta = self.x_metadatas[idx]
+            recomputed_x = self.recomputed[gid][holder.handles[gid]]
+            if x_meta != self.metadata_fn(recomputed_x):
+                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
+
+        if len(nb_meta_different) > 0:
+            mismatched_tensors = ""
+            for idx, x_meta, recomputed_meta in nb_meta_different:
+                mismatched_tensors += (
+                    f"tensor at position {idx}:\n"
+                    f"saved metadata: {x_meta}\n"
+                    f"recomputed metadata: {recomputed_meta}\n"
+                )
+            raise CheckpointError(
+                "torch.utils.checkpoint: Recomputed values for the following tensors "
+                "have different metadata than during the forward pass.\n"
+                f"{mismatched_tensors}.\n"
+                f"{_debug_tip_msg}"
+            )
+
+
+_debug_tip_msg = """
+Tip: To see a more detailed error message, either pass `debug=True` to
+`torch.utils.checkpoint.checkpoint(...)` or wrap the code block
+with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
+enable checkpoint debug mode globally.
+"""
+
+
+_checkpoint_error_template = """ \
+An error happened while unpacking tensors; dumping logs of latest computation
+because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
+Scroll all the way down for guidance on how to navigate these logs.
+
++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
+| 1. Stack traces of the operators that ran in the original forward            |
++------------------------------------------------------------------------------+
+
+{forward_traces}
++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
+| 2. Stack traces of the operators that ran during recomputation               |
++------------------------------------------------------------------------------+
+
+{recompute_traces}
++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
+| 3. Log of operators in the original forward and recomputation                |
++------------------------------------------------------------------------------+
+(Scroll up to correlate stack traces with each operation listed below. This
+ helps identify their source in the code.)
+
+IMPORTANT: Differences in "detach" calls between the original forward and the
+           recomputation are expected. They are introduced by the checkpointing
+           mechanism and can be ignored.
+
+Operations executed during the original forward:
+
+{forward_ops}
+
+Operations executed during recomputation:
+
+{recompute_ops}
+
++------------------------------------------------------------------------------+
+ ERROR: Detected non-determinism while running activation checkpointing
+
+ You are seeing this error because you passed `debug=True` to checkpoint and
+ the tensors saved during the original forward differ from those saved during
+ recomputation. This can happen if different operators were run in the
+ original forward and in the recomputation.
+
+ To identify where the mismatch may be coming from, you can do the following:
+
+ 1) Compare the operators run during the original forward and recomputation to
+    see where they differ. These operators are printed above in the order they
+    were executed.
+
+ 2) Review the stack trace for each operator to locate its invocation source.
+    Each operator's stack trace is printed in execution order.
+
+ Note that the logs can be quite long. Here's how they are structured:
+ (Tip: you can Ctrl-f for these headers)
+
+ 1. Stack traces of the operators that ran in the original forward
+ 2. Stack traces of the operators that ran during recomputation
+ 3.
Log of operators in the original forward and recomputation + 4. Error message <--- You are here +-------------------------------------------------------------------------------- +""" + +class CheckpointError(RuntimeError): + pass + + +def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: + # This function returns the context_fn and error_cb to be used by the + # checkpointing mechanism. error_cb is invoked when an error is detected + # during unpack. + + # record_context_cpp is not support on non-linux non-x86_64 platforms + cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' + + class CaptureLogs: + def __init__(self): + self.logs = None + self.tbs = None + + def get_context_manager(self): + @contextlib.contextmanager + def logging_mode(): + with LoggingTensorMode(), \ + capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: + self.logs, self.tbs = logs_and_tb + yield logs_and_tb + return logging_mode() + + capture_logs_fwd = CaptureLogs() + capture_logs_recompute = CaptureLogs() + + def unpack_error_cb(e: CheckpointError): + def get_str_tb(label, capture_logs): + out = "" + total_len = len(capture_logs.logs) + for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): + out += f"{log} ({i + 1} of {total_len} in {label})\n\n" + found_torch_dispatch = False + for line in tb: + # Start printing stack trace only after __torch_dispatch__ is found + is_torch_dispatch = line['name'] == '__torch_dispatch__' + if not found_torch_dispatch and not is_torch_dispatch: + continue + elif is_torch_dispatch: + found_torch_dispatch = True + continue + out += f"{line['filename']}:{line['line']}:{line['name']}\n" + out += "\n\n" + return out + assert capture_logs_fwd.logs is not None + assert capture_logs_recompute.logs is not None + raise CheckpointError( + _checkpoint_error_template.format( + forward_traces=get_str_tb("original", capture_logs_fwd), + recompute_traces=get_str_tb("recompute", capture_logs_recompute), + forward_ops="\n".join(capture_logs_fwd.logs), + recompute_ops="\n".join(capture_logs_recompute.logs) + ) + ) from e + + def context_fn(): + return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() + + return context_fn, unpack_error_cb + +def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: + # These properties are fast to check, easy to understand + return { + "shape": x.shape, + "dtype": x.dtype, + "device": x.device + } + +_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { + _DEFAULT_DETERMINISM_MODE: _default_meta_extractor, + "none": lambda _: None, +} + +# See Rule 5 +class _StopRecomputationError(Exception): + pass + + +class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self, target_frame_ref: ReferenceType, gid: int): + def pack_hook(x): + x = x.detach() if x.requires_grad else x + target_frame = target_frame_ref() + assert target_frame is not None # appease mypy + recomp_idx = target_frame.recomp_counter[gid] + target_frame.recomp_counter[gid] += 1 + + if recomp_idx >= len(target_frame.weak_holders): + assert not target_frame.early_stop + if not target_frame.forward_completed: + # We run into this case when early stop is not enabled and do + # grad within checkpoint. + # We need to set this flag, so we don't error out later when + # we check if the number of tensors saved during forward and + # recomputation match. 
+ target_frame.ignore_saved_mismatch = True + return x + raise CheckpointError( + "torch.utils.checkpoint: trying to save more tensors during " + "recomputation than during the original forward pass.\n" + f"{_debug_tip_msg}" + ) + + holder = target_frame.weak_holders[recomp_idx]() + + # This holder may have been cleared because someone may have called + # backward within forward. If so, we don't need to save. + if holder is not None: + _internal_assert(holder.handles.get(gid, None) is None) + holder.handles[gid] = _Handle() + target_frame.recomputed[gid][holder.handles[gid]] = x + + if target_frame.early_stop and target_frame.recomp_counter[gid] == len( + target_frame.weak_holders + ): + raise _StopRecomputationError + # See Rule 6: [ retain_graph is True ] above + return x + + def unpack_hook(x): + # See Rule 6: [ retain_graph is True ] above for an example of when + # the graph created during recomputation could be backwarded. + return x + + super().__init__(pack_hook, unpack_hook) + + +# torch._disable_dynamo creates a reference cycle with decorated function +# This function is used to ensure that the decorated function does not have +# a closure, so that other objects aren't also kept alive. +# https://github.com/pytorch/pytorch/issues/154642 +# Note: does not work when fn is compiled +@torch._disable_dynamo +def _run_fn_with_dynamo_disabled(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self, frame): + def pack_hook(x): + # See Rule 4 above + holder = _Holder() + frame.weak_holders.append(weakref.ref(holder)) + # Save metadata to detect non-determinism + if frame.metadata_fn is not None: + with torch.no_grad(): + frame.x_metadatas.append(frame.metadata_fn(x)) + return holder + + def unpack_hook(holder): + gid = torch._C._current_graph_task_id() + if gid == -1: + # generate a temporary id if we trigger unpack outside of a backward call + gid = int(uuid.uuid4()) + + if not frame.is_recomputed[gid]: + ctx = frame.input_saver.grad_fn + args = ctx.get_args(ctx.saved_tensors) + + try: + with _recomputation_hook( + weakref.ref(frame), gid + ), torch.autograd.enable_grad(): + # See Note: [compiled autograd and checkpoint unpack hook] + _run_fn_with_dynamo_disabled(frame.recompute_fn, *args) + except _StopRecomputationError: + pass + frame.is_recomputed[gid] = True + frame.check_recomputed_tensors_match(gid) + + _internal_assert(gid in holder.handles) + + if holder.handles[gid] is None: + raise CheckpointError( + "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " + "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " + "so only once. Otherwise please open an issue with details on your use case." 
+ ) + _internal_assert(holder.handles[gid] in frame.recomputed[gid]) + ret = frame.recomputed[gid][holder.handles[gid]] + holder.handles[gid] = None + return ret + + if frame.unpack_error_cb is not None: + def unpack_hook_with_error_cb(holder): + try: + return unpack_hook(holder) + except CheckpointError as e: + frame.unpack_error_cb(e) + super().__init__(pack_hook, unpack_hook_with_error_cb) + else: + super().__init__(pack_hook, unpack_hook) + + +def _is_compiling(func, args, kwargs): + # Check if we are under AOTAutograd tracing + # Checking that a functional mode is active should always do what we want + return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None + + +class _VersionWrapper: + # Check that cached tensors are not mutated. + def __init__(self, val): + self.val: Union[torch.Tensor, Any] = val + self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + + def get_val(self, allow_cache_entry_mutation): + if self.version is not None and not allow_cache_entry_mutation: + if self.val._version != self.version: + # Can we give user a stack trace of where the mutation happened? + raise RuntimeError( + "Tensor cached during selective activation checkpoint has been mutated" + ) + return self.val + + +def _maybe_detach(x, any_ret_has_alias_info): + # We detach for two separate reasons: + # - For view ops, we need to ensure that when the tensor is returned from + # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr + # - Avoid reference cycles + # For case 1, it is not enough to check whether x has differentiable dtype + # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. + # when the tensor is a view. + if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): + with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + # Ensure that view performed beneath autograd properly propagates + # version counter. TODO: Use reentrant_dispatch instead of + # manually manipulating dispatch keys. Using reentrant_dispatch + # would respect inference_mode, though that is not relevant for + # this case. + x = x.detach() + return x + + +class SelectiveCheckpointContext: + """ + Context passed to policy function during selective checkpointing. + + This class is used to pass relevant metadata to the policy function during + selective checkpointing. The metadata includes whether the current invocation + of the policy function is during recomputation or not. + + Example: + >>> # xdoctest: +SKIP(stub) + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> print(ctx.is_recompute) + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + def __init__(self, *, is_recompute): + self.is_recompute = is_recompute + + +class CheckpointPolicy(enum.Enum): + """ + Enum for specifying the policy for checkpointing during backpropagation. 
+
+    The following policies are supported:
+
+    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
+      pass and will not be recomputed during the backward pass
+    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
+      forward pass and will be recomputed during the backward pass
+
+    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
+    by other subsystems like `torch.compile`.
+
+    .. note::
+        A policy function that always returns ``PREFER_RECOMPUTE`` is
+        equivalent to vanilla checkpointing.
+
+        A policy function that returns ``PREFER_SAVE`` for every op is
+        NOT equivalent to not using checkpointing. Using such a policy would
+        save additional tensors not limited to ones that are actually needed for
+        gradient computation.
+    """
+    MUST_SAVE = 0
+    PREFER_SAVE = 1
+    MUST_RECOMPUTE = 2
+    PREFER_RECOMPUTE = 3
+
+
+def _policy_from_bool(b):
+    # For backward compatibility
+    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
+
+
+SAC_IGNORED_OPS = {
+    # AC inserts different number of detach during forward and recompute.
+    torch.ops.aten.detach.default,
+    # AC's determinism check invokes additional metadata ops during forward.
+    # With subclasses involved, these metadata ops become dispatchable; this
+    # can result in incorrectness if these ops are selected to be cached.
+    torch.ops.prim.device.default,
+} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
+
+
+class _CachingTorchDispatchMode(TorchDispatchMode):
+    # Used together with _CachedTorchDispatchMode to implement SAC.
+    def __init__(self, policy_fn, storage):
+        self.policy_fn = policy_fn
+        self.storage = storage
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if func in SAC_IGNORED_OPS:
+            return func(*args, **kwargs)
+
+        kwargs = {} if kwargs is None else kwargs
+        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
+                                func, *args, **kwargs)
+        if isinstance(policy, bool):
+            policy = _policy_from_bool(policy)
+
+        is_compiling = _is_compiling(func, args, kwargs)
+
+        if is_compiling:
+            # Overwrite each node's "recompute" tag to add in the user annotation.
+            fx_traceback.current_meta["recompute"] = policy
+
+        out = func(*args, **kwargs)
+
+        # HOPs don't support func._schema
+        # HOPs don't alias -> this is always true today and will be always true for a long time
+        # TODO HOPs don't mutate -> this is always true today but will not be true forever
+        if isinstance(func, torch._ops.HigherOrderOperator):
+            any_ret_has_alias_info = False
+        else:
+            any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
+
+        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
+            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
+        return out
+
+class _CachedTorchDispatchMode(TorchDispatchMode):
+    # Used together with _CachingTorchDispatchMode to implement SAC.
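+    # Illustrative sketch of how the two modes cooperate (editorial note, not a
+    # public API; create_selective_checkpoint_contexts below is the intended
+    # entry point). Both modes share one `storage` dict: checkpoint runs the
+    # original forward under the caching mode and the recomputation under the
+    # cached mode.
+    #
+    #   storage = defaultdict(list)
+    #   caching = _CachingTorchDispatchMode(policy_fn, storage)
+    #   cached = _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation=False)
+    #   with caching:
+    #       out = fn(*args)   # ops the policy marks *_SAVE are appended to storage
+    #   with cached:
+    #       out = fn(*args)   # those ops are popped from storage instead of re-running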
+ def __init__(self, policy_fn, storage, allow_cache_entry_mutation): + self.policy_fn = policy_fn + self.storage = storage + self.allow_cache_entry_mutation = allow_cache_entry_mutation + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), + func, *args, **kwargs) + if isinstance(policy, bool): + policy = _policy_from_bool(policy) + + is_compiling = _is_compiling(func, args, kwargs) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + storage = self.storage.get(func) + if storage is None: + raise RuntimeError(f"{func} encountered during backward, but not found in storage") + if len(storage) == 0: + raise RuntimeError( + "Trying to backward an extra time. You are only allowed to backward once " + "on any region computed under selective activation checkpoint." + ) + out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) + else: + out = func(*args, **kwargs) + return out + + +def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): + """ + Helper to avoid recomputing certain ops during activation checkpointing. + + Use this with `torch.utils.checkpoint.checkpoint` to control which + operations are recomputed during the backward pass. + + Args: + policy_fn_or_list (Callable or List): + - If a policy function is provided, it should accept a + :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and + kwargs to the op, and return a :class:`CheckpointPolicy` enum value + indicating whether the execution of the op should be recomputed or not. + - If a list of operations is provided, it is equivalent to a policy + returning `CheckpointPolicy.MUST_SAVE` for the specified + operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other + operations. + allow_cache_entry_mutation (bool, optional): By default, an error is + raised if any tensors cached by selective activation checkpoint are + mutated in order to ensure correctness. If set to `True`, this check + is disabled. + Returns: + A tuple of two context managers. + + Example: + >>> # xdoctest: +REQUIRES(LINUX) + >>> import functools + >>> + >>> x = torch.rand(10, 10, requires_grad=True) + >>> y = torch.rand(10, 10, requires_grad=True) + >>> + >>> ops_to_save = [ + >>> torch.ops.aten.mm.default, + >>> ] + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> if op in ops_to_save: + >>> return CheckpointPolicy.MUST_SAVE + >>> else: + >>> return CheckpointPolicy.PREFER_RECOMPUTE + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> # or equivalently + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) + >>> + >>> def fn(x, y): + >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + # NB: If grad_mode is disabled, checkpoint would not run forward under + # context_fn anyway, so proceed as usual. + if isinstance(policy_fn_or_list, list): + for op in policy_fn_or_list: + if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + _extra_msg = ( + "Please update the OpOverloadPacket to a specific OpOverload." 
+ "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + raise ValueError( + f"Expected op in `op_list` to be an OpOverload but got: {op} " + f"of type {type(op)}. {_extra_msg}" + ) + + def policy_fn(ctx, op, *args, **kwargs): + if op in policy_fn_or_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): + policy_fn = policy_fn_or_list + else: + raise TypeError("policy_fn_or_list must be either a function or a list of ops.") + + storage: Dict[Any, List[Any]] = defaultdict(list) + return ( + _CachingTorchDispatchMode(policy_fn, storage), + _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), + ) + +# NB: this helper wraps fn before calling checkpoint_impl. kwargs and +# saving/restoring of global state is handled here. + +def _checkpoint_without_reentrant_generator( + fn, + preserve_rng_state=True, + context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, + determinism_check: str = _DEFAULT_DETERMINISM_MODE, + debug: bool = False, + *args, + **kwargs +): + """Checkpointing without reentrant autograd. + + Args: + fn: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. + Default: ``True`` + context_fn(Callable, optional): A callable returning a tuple of two + context managers. The function and its recomputation will be run + under the first and second context managers respectively. + determinism_check(str, optional): A string specifying the determinism + check to perform. By default it is set to ``"default"`` which + compares the shapes, dtypes, and devices of the recomputed tensors + against those the saved tensors. To turn off this check, specify + ``"none"``. Currently these are the only two supported values. + Please open an issue if you would like to see more determinism + checks. + debug(bool, optional): If ``True``, error messages will also include + a trace of the operators ran during the original forward computation + as well as the recomputation. + *args: Arguments to pass in to the given ``function``. + **kwargs: Keyword arguments to pass into the given ``function``. 
+ """ + unpack_error_cb = None + + if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: + if context_fn != noop_context_fn: + raise ValueError( + "debug=True is incompatible with non-default context_fn" + ) + context_fn, unpack_error_cb = _get_debug_context_and_cb() + + if determinism_check in _allowed_determinism_checks_to_fns: + metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] + else: + raise ValueError( + f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " + f"but got {determinism_check}" + ) + + device_type = _infer_device_type(*args) + device_module = _get_device_module(device_type) + forward_context, recompute_context = context_fn() + if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: + assert ( + isinstance(forward_context, TorchDispatchMode) and + isinstance(recompute_context, TorchDispatchMode) + ), \ + "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \ + "must generate a tuple of two `TorchDispatchMode`s." + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type) + + if preserve_rng_state: + fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function. + # If they do so, we raise an error.) + had_device_in_fwd = False + if getattr(device_module, "_initialized", False): + had_device_in_fwd = True + fwd_devices, fwd_device_states = get_device_states(*args) + + def recompute_fn(*inputs): + kwargs, *args = inputs + # This will be called later during recomputation. This wrapping enables + # the necessary global state to be captured. + rng_devices = [] + if preserve_rng_state and had_device_in_fwd: + rng_devices = fwd_devices + with torch.random.fork_rng( + devices=rng_devices, enabled=preserve_rng_state, device_type=device_type + ): + if preserve_rng_state: + torch.set_rng_state(fwd_cpu_state) + if had_device_in_fwd: + set_device_states(fwd_devices, fwd_device_states, device_type=device_type) + + device_autocast_ctx = torch.amp.autocast( + device_type=device_type, **device_autocast_kwargs + ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() + with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] + fn(*args, **kwargs) + + new_frame = _CheckpointFrame( + recompute_fn, + _enable_checkpoint_early_stop, + unpack_error_cb, + metadata_fn + ) + dummy = torch.empty((0,), requires_grad=True) + new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) + + # When ambient grad_mode is False + if new_frame.input_saver.grad_fn is None: + yield + return + + with _checkpoint_hook(new_frame), forward_context: + yield + new_frame.forward_completed = True + + if getattr(device_module, "_initialized", False) and \ + preserve_rng_state and not had_device_in_fwd: # type: ignore[possibly-undefined] + # Device was not initialized before running the forward, so we didn't + # stash the device state. + raise RuntimeError( + "PyTorch's device state was initialized in the forward pass " + "of a Checkpoint, which is not allowed. Please open an issue " + "if you need this feature." 
+ ) + + return + +# Note: [compiled autograd and checkpoint unpack hook] +# When tracing via compiled autograd, this hook will be visible to the +# compiler if the forward of this checkpointed region ran in eager. +# If the forward had ran under compile, it would have been wrapped in a +# higher order op. See Note: [torch.compile and checkpoint]. +# +# Since we run the recomputation hook under a enable_grad context, +# AOTDispatch will trace a joint graph for this hook, and may +# save different activations than in eager. This conflicts with the +# strict activation count checks in `frame.check_recomputed_tensors_match`. +# So, we disable this hook to force it to recompute eager checkpointed regions +# in eager. This could be removed if we can disable the partitioner for this +# graph segment. diff --git a/phivenv/Lib/site-packages/torch/utils/collect_env.py b/phivenv/Lib/site-packages/torch/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..7336a55382a4818929cdf4d545cc76e68b24a929 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/collect_env.py @@ -0,0 +1,697 @@ +# mypy: allow-untyped-defs + +# Unlike the rest of the PyTorch this file must be python2 compliant. +# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +import datetime +import json +import locale +import re +import subprocess +import sys +import os +from collections import namedtuple + + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple('SystemEnv', [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', +]) + +COMMON_PATTERNS = [ + "torch", + "numpy", + "triton", + "optree", +] + +NVIDIA_PATTERNS = [ + "cuda-cudart", + "cuda-cupti", + "cuda-libraries", + "cuda-opencl", + "cuda-nvrtc", + "cuda-runtime", + "cublas", + "cudnn", + "cufft", + "curand", + "cusolver", + "cusparse", + "nccl", + "nvjitlink", + "nvtx", +] + +CONDA_PATTERNS = [ + "cudatoolkit", + "soumith", + "mkl", + "magma", +] + +PIP_PATTERNS = [ + "mypy", + "flake8", + "onnx", +] + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + err = raw_err.decode(enc) + return rc, output.strip(), err.strip() + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = 
run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out.split('\n')[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + conda = os.environ.get('CONDA_EXE', 'conda') + out = run_and_read_all(run_lambda, "{} list".format(conda)) + if out is None: + return out + + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") + and any(name in line for name in patterns) + ) + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation + # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get('CUDNN_LIBRARY') + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +# example outputs of CPU infos +# * linux +# Architecture: x86_64 +# CPU op-mode(s): 32-bit, 64-bit +# Address sizes: 46 bits physical, 48 bits virtual +# Byte Order: Little Endian +# CPU(s): 128 +# On-line CPU(s) list: 0-127 +# Vendor ID: GenuineIntel +# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# CPU family: 6 +# Model: 106 +# Thread(s) per core: 2 +# Core(s) per socket: 32 +# Socket(s): 2 +# Stepping: 6 +# BogoMIPS: 5799.78 +# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr +# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl +# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 +# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand +# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced +# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap +# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 +# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq +# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities +# Virtualization features: +# Hypervisor vendor: KVM +# Virtualization type: full +# Caches (sum of all): +# L1d: 3 MiB (64 instances) +# L1i: 2 MiB (64 instances) +# L2: 80 MiB (64 instances) +# L3: 108 MiB (2 instances) +# NUMA: +# NUMA node(s): 2 +# NUMA node0 CPU(s): 0-31,64-95 +# NUMA node1 CPU(s): 32-63,96-127 +# Vulnerabilities: +# Itlb multihit: Not affected +# L1tf: Not affected +# Mds: Not affected +# Meltdown: Not affected +# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown +# Retbleed: Not affected +# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp +# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization +# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence +# Srbds: Not affected +# Tsx async abort: Not affected 
+# * win32 +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU0 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 +# +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU1 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 + +def get_cpu_info(run_lambda): + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': + rc, out, err = run_lambda( + 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ + Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ + | ConvertTo-Json"' + ) + if rc == 0: + lst = [] + try: + obj = json.loads(out) + if type(obj) is list: + for o in obj: + lst.append("----------------------") + lst.extend([f"{k}: {v}" for (k, v) in o.items()]) + else: + lst.extend([f"{k}: {v}" for (k, v) in obj.items()]) + except ValueError as e: + lst.append(out) + lst.append(str(e)) + out = "\n".join(lst) + elif get_platform() == 'darwin': + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = 'None' + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + ret = run_and_read_all( + run_lambda, + 'powershell.exe "gwmi -Class Win32_OperatingSystem | Select-Object -Property Caption,\ + OSArchitecture,Version | ConvertTo-Json"', + ) + try: + obj = json.loads(ret) + ret = f'{obj["Caption"]} ({obj["Version"]} {obj["OSArchitecture"]})' + except ValueError as e: + ret += f"\n{str(e)}" + return ret + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + return platform.platform() + + +def get_libc_version(): + import platform + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + + pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' + + os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' + # People generally have pip as `pip` or `pip3` + # But here it is invoked as `python -mpip` + out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) + if out is None: + return pip_version, out + + filtered_out = '\n'.join( + line + for line in out.splitlines() + if any(name in line for name in patterns) + ) + + return pip_version, filtered_out + + +def get_cachingallocator_config(): + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + if not ca_config: + ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get('CUDA_MODULE_LOADING', '') + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + +def get_env_info(): + """ + Collects environment information to aid in debugging. + + The returned environment information contains details on torch version, is debug build + or not, cuda compiled version, gcc version, clang version, cmake version, operating + system, libc version, python version, python platform, CUDA availability, CUDA + runtime version, CUDA module loading config, GPU model and configuration, Nvidia + driver version, cuDNN version, pip version and versions of relevant pip and + conda packages, HIP runtime version, MIOpen runtime version, + Caching allocator config, XNNPACK availability and CPU information. + + Returns: + SystemEnv (namedtuple): A tuple containining various environment details + and system information. 
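+
+    Example (illustrative; actual values depend on the local machine)::
+
+        >>> info = get_env_info()
+        >>> info.torch_version      # e.g. "2.4.0", or "N/A" if torch could not be imported
+        >>> info.is_cuda_available  # "True" / "False" as a string, or "N/A"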
+ """ + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else 'N/A' + + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + ) + +env_info_fmt = """ +PyTorch version: {torch_version} +Is debug build: {is_debug_build} +CUDA used to build PyTorch: {cuda_compiled_version} +ROCM used to build PyTorch: {hip_compiled_version} + +OS: {os} +GCC version: {gcc_version} +Clang version: {clang_version} +CMake version: {cmake_version} +Libc version: {libc_version} + +Python version: {python_version} +Python platform: {python_platform} +Is CUDA available: {is_cuda_available} +CUDA runtime version: {cuda_runtime_version} +CUDA_MODULE_LOADING set to: {cuda_module_loading} +GPU models and configuration: {nvidia_gpu_models} +Nvidia driver version: {nvidia_driver_version} +cuDNN version: {cudnn_version} +HIP runtime version: {hip_runtime_version} +MIOpen runtime version: {miopen_runtime_version} +Is XNNPACK available: {is_xnnpack_available} + +CPU: +{cpu_info} + +Versions of relevant libraries: +{pip_packages} +{conda_packages} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, 
tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + """ + Returns a pretty string of environment information. + + This function retrieves environment information by calling the `get_env_info` function + and then formats the information into a human-readable string. The retrieved environment + information is listed in the document of `get_env_info`. + This function is used in `python collect_env.py` that should be executed when reporting a bug. + + Returns: + str: A pretty string of the environment information. 
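+
+    Example (a minimal illustrative call, assuming this module is importable as
+    ``torch.utils.collect_env``)::
+
+        >>> from torch.utils.collect_env import get_pretty_env_info
+        >>> print(get_pretty_env_info())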
+ """ + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + print(msg, file=sys.stderr) + + + +if __name__ == '__main__': + main() diff --git a/phivenv/Lib/site-packages/torch/utils/cpp_backtrace.py b/phivenv/Lib/site-packages/torch/utils/cpp_backtrace.py new file mode 100644 index 0000000000000000000000000000000000000000..d47092fbd2b8bfeb4192955293ec91b013ac172d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/cpp_backtrace.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +from torch._C import _get_cpp_backtrace + +def get_cpp_backtrace(frames_to_skip=0, maximum_number_of_frames=64) -> str: + r""" + Return a string containing the C++ stack trace of the current thread. + + Args: + frames_to_skip (int): the number of frames to skip from the top of the stack + maximum_number_of_frames (int): the maximum number of frames to return + """ + return _get_cpp_backtrace(frames_to_skip, maximum_number_of_frames) diff --git a/phivenv/Lib/site-packages/torch/utils/cpp_extension.py b/phivenv/Lib/site-packages/torch/utils/cpp_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..867e2b8d1709b08ead702f5a05faf1f9443b1541 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/cpp_extension.py @@ -0,0 +1,2995 @@ +# mypy: allow-untyped-defs +import copy +import glob +import importlib +import importlib.abc +import os +import re +import shlex +import shutil +import setuptools +import subprocess +import sys +import sysconfig +import collections +from pathlib import Path +import errno +import logging + +logger = logging.getLogger(__name__) + +import torch +import torch._appdirs +from .file_baton import FileBaton +from ._cpp_extension_versioner import ExtensionVersioner +from .hipify import hipify_python +from .hipify.hipify_python import GeneratedFileCleaner +from typing import Optional, Union +from torch.torch_version import TorchVersion, Version + +from setuptools.command.build_ext import build_ext + +IS_WINDOWS = sys.platform == 'win32' +IS_MACOS = sys.platform.startswith('darwin') +IS_LINUX = sys.platform.startswith('linux') +LIB_EXT = '.pyd' if IS_WINDOWS else '.so' +EXEC_EXT = '.exe' if IS_WINDOWS else '' +CLIB_PREFIX = '' if IS_WINDOWS else 'lib' +CLIB_EXT = '.dll' if IS_WINDOWS else '.so' +SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared' + +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib') + + +SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else () +MINIMUM_GCC_VERSION = (5, 0, 0) +MINIMUM_MSVC_VERSION = (19, 0, 24215) + +VersionRange = tuple[tuple[int, ...], tuple[int, ...]] +VersionMap = dict[str, VersionRange] +# The following values were taken from the following GitHub gist that +# summarizes the minimum valid major versions of g++/clang++ for 
each supported +# CUDA version: https://gist.github.com/ax3l/9489132 +# Or from include/crt/host_config.h in the CUDA SDK +# The second value is the exclusive(!) upper bound, i.e. min <= version < max +CUDA_GCC_VERSIONS: VersionMap = { + '11.0': (MINIMUM_GCC_VERSION, (10, 0)), + '11.1': (MINIMUM_GCC_VERSION, (11, 0)), + '11.2': (MINIMUM_GCC_VERSION, (11, 0)), + '11.3': (MINIMUM_GCC_VERSION, (11, 0)), + '11.4': ((6, 0, 0), (12, 0)), + '11.5': ((6, 0, 0), (12, 0)), + '11.6': ((6, 0, 0), (12, 0)), + '11.7': ((6, 0, 0), (12, 0)), +} + +MINIMUM_CLANG_VERSION = (3, 3, 0) +CUDA_CLANG_VERSIONS: VersionMap = { + '11.1': (MINIMUM_CLANG_VERSION, (11, 0)), + '11.2': (MINIMUM_CLANG_VERSION, (12, 0)), + '11.3': (MINIMUM_CLANG_VERSION, (12, 0)), + '11.4': (MINIMUM_CLANG_VERSION, (13, 0)), + '11.5': (MINIMUM_CLANG_VERSION, (13, 0)), + '11.6': (MINIMUM_CLANG_VERSION, (14, 0)), + '11.7': (MINIMUM_CLANG_VERSION, (14, 0)), +} + +__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension", + "CppExtension", "CUDAExtension", "SyclExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available", + "verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"] +# Taken directly from python stdlib < 3.9 +# See https://github.com/pytorch/pytorch/issues/48617 +def _nt_quote_args(args: Optional[list[str]]) -> list[str]: + """Quote command-line arguments for DOS/Windows conventions. + + Just wraps every argument which contains blanks in double quotes, and + returns a new argument list. + """ + # Cover None-type + if not args: + return [] + return [f'"{arg}"' if ' ' in arg else arg for arg in args] + +def _find_cuda_home() -> Optional[str]: + """Find the CUDA install path.""" + # Guess #1 + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # Guess #2 + nvcc_path = shutil.which("nvcc") + if nvcc_path is not None: + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + else: + # Guess #3 + if IS_WINDOWS: + cuda_homes = glob.glob( + 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') + if len(cuda_homes) == 0: + cuda_home = '' + else: + cuda_home = cuda_homes[0] + else: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + if cuda_home and not torch.cuda.is_available(): + logger.warning("No CUDA runtime is found, using CUDA_HOME='%s'", cuda_home) + return cuda_home + +def _find_rocm_home() -> Optional[str]: + """Find the ROCm install path.""" + # Guess #1 + rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') + if rocm_home is None: + # Guess #2 + hipcc_path = shutil.which('hipcc') + if hipcc_path is not None: + rocm_home = os.path.dirname(os.path.dirname( + os.path.realpath(hipcc_path))) + # can be either /hip/bin/hipcc or /bin/hipcc + if os.path.basename(rocm_home) == 'hip': + rocm_home = os.path.dirname(rocm_home) + else: + # Guess #3 + fallback_path = '/opt/rocm' + if os.path.exists(fallback_path): + rocm_home = fallback_path + if rocm_home and torch.version.hip is None: + logger.warning("No ROCm runtime is found, using ROCM_HOME='%s'", rocm_home) + return rocm_home + +def _find_sycl_home() -> Optional[str]: + sycl_home = None + icpx_path = shutil.which('icpx') + # Guess 1: for source code build developer/user, we'll have icpx in PATH, + # which will tell us the SYCL_HOME location. 
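+    # Illustrative example (assumed layout, not guaranteed): if icpx resolves to
+    # /opt/intel/oneapi/compiler/latest/bin/icpx, the two dirname() calls below
+    # yield SYCL_HOME=/opt/intel/oneapi/compiler/latest.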
+ if icpx_path is not None: + sycl_home = os.path.dirname(os.path.dirname( + os.path.realpath(icpx_path))) + + # Guess 2: for users install Pytorch with XPU support, the sycl runtime is + # inside intel-sycl-rt, which is automatically installed via pip dependency. + else: + try: + files = importlib.metadata.files('intel-sycl-rt') or [] + for f in files: + if f.name == "libsycl.so": + sycl_home = os.path.dirname(Path(f.locate()).parent.resolve()) + break + except importlib.metadata.PackageNotFoundError: + logger.warning("Trying to find SYCL_HOME from intel-sycl-rt package, but it is not installed.") + return sycl_home + +def _join_rocm_home(*paths) -> str: + """ + Join paths with ROCM_HOME, or raises an error if it ROCM_HOME is not set. + + This is basically a lazy way of raising an error for missing $ROCM_HOME + only once we need to get any ROCm-specific path. + """ + if ROCM_HOME is None: + raise OSError('ROCM_HOME environment variable is not set. ' + 'Please set it to your ROCm install root.') + return os.path.join(ROCM_HOME, *paths) + +def _join_sycl_home(*paths) -> str: + """ + Join paths with SYCL_HOME, or raises an error if it SYCL_HOME is not found. + + This is basically a lazy way of raising an error for missing SYCL_HOME + only once we need to get any SYCL-specific path. + """ + if SYCL_HOME is None: + raise OSError('SYCL runtime is not dected. Please setup the pytorch ' + 'prerequisites for Intel GPU following the instruction in ' + 'https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support ' + 'or install intel-sycl-rt via pip.') + + return os.path.join(SYCL_HOME, *paths) + + + +ABI_INCOMPATIBILITY_WARNING = ( + " !! WARNING !!" + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + "Your compiler (%s) may be ABI-incompatible with PyTorch!" + "Please use a compiler that is ABI-compatible with GCC 5.0 and above." + "See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html." + "See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6" + "for instructions on how to install GCC 5 or higher." + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + " !! WARNING !!" +) +WRONG_COMPILER_WARNING = ( + " !! WARNING !!" + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + "Your compiler (%s) is not compatible with the compiler Pytorch was" + "built with for this platform, which is %s on %s. Please" + "use %s to to compile your extension. Alternatively, you may" + "compile PyTorch from source using %s, and then you can also use" + "%s to compile your extension." + "See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help" + "with compiling PyTorch from source." + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + " !! WARNING !!" +) +CUDA_MISMATCH_MESSAGE = ( + "The detected CUDA version (%s) mismatches the version that was used to compile" + "PyTorch (%s). Please make sure to use the same CUDA versions." +) +CUDA_MISMATCH_WARN = ( + "The detected CUDA version (%s) has a minor version mismatch with the version that was used to compile PyTorch (%s). Most likely this shouldn't be a problem." +) +CUDA_NOT_FOUND_MESSAGE = ( + "CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH" + "environment variable or add NVCC to your system PATH. The extension compilation will fail." 
+) +ROCM_HOME = _find_rocm_home() if (torch.cuda._is_compiled() and torch.version.hip) else None +HIP_HOME = _join_rocm_home('hip') if ROCM_HOME else None +IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False +ROCM_VERSION = None +if torch.version.hip is not None: + ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) + +CUDA_HOME = _find_cuda_home() if (torch.cuda._is_compiled() and torch.version.cuda) else None +CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH') +SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None + +# PyTorch releases have the version pattern major.minor.patch, whereas when +# PyTorch is built from source, we append the git commit hash, which gives +# it the below pattern. +BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+') + +COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/wd4624', '/wd4067', '/wd4068', '/EHsc'] + +MSVC_IGNORE_CUDAFE_WARNINGS = [ + 'base_class_has_different_dll_interface', + 'field_without_dll_interface', + 'dll_interface_conflict_none_assumed', + 'dll_interface_conflict_dllexport_assumed' +] + +COMMON_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + '--expt-relaxed-constexpr' +] + +COMMON_HIP_FLAGS = [ + '-D__HIP_PLATFORM_AMD__=1', + '-DUSE_ROCM=1', + '-DHIPBLAS_V2', +] + +if not IS_WINDOWS: + COMMON_HIP_FLAGS.append('-fPIC') + +COMMON_HIPCC_FLAGS = [ + '-DCUDA_HAS_FP16=1', + '-D__HIP_NO_HALF_OPERATORS__=1', + '-D__HIP_NO_HALF_CONVERSIONS__=1', + '-DHIP_ENABLE_WARP_SYNC_BUILTINS=1' +] + + + +def _get_sycl_arch_list(): + if 'TORCH_XPU_ARCH_LIST' in os.environ: + return os.environ.get('TORCH_XPU_ARCH_LIST') + arch_list = torch.xpu.get_arch_list() + # Dropping dg2* archs since they lack hardware support for fp64 and require + # special consideration from the user. If needed these platforms can + # be requested thru TORCH_XPU_ARCH_LIST environment variable. + arch_list = [x for x in arch_list if not x.startswith('dg2')] + return ','.join(arch_list) + + +# If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled +# for default spir64 target and avoid device specific compilations entirely. Further, kernels +# will be JIT compiled at runtime. +def _append_sycl_targets_if_missing(cflags): + if any(flag.startswith('-fsycl-targets=') for flag in cflags): + # do nothing: user has manually specified sycl targets + return + if _get_sycl_arch_list() != '': + # AOT (spir64_gen) + JIT (spir64) + cflags.append('-fsycl-targets=spir64_gen,spir64') + else: + # JIT (spir64) + cflags.append('-fsycl-targets=spir64') + +def _get_sycl_device_flags(cflags): + # We need last occurence of -fsycl-targets as it will be the one taking effect. + # So searching in reversed list. 
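+    # Illustrative (assumed) input: cflags ending in ['-fsycl-targets=spir64',
+    # '-fsycl-targets=spir64_gen,spir64'] -- scanning the reversed list places the
+    # most recently specified -fsycl-targets flag first in `flags` below.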
+ flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')] + assert flags, "bug: -fsycl-targets should have been ammended to cflags" + + arch_list = _get_sycl_arch_list() + if arch_list != '': + flags += [f'-Xs "-device {arch_list}"'] + return flags + +_COMMON_SYCL_FLAGS = [ + '-fsycl', +] + +_SYCL_DLINK_FLAGS = [ + *_COMMON_SYCL_FLAGS, + '-fsycl-link', + '--offload-compress', +] + +JIT_EXTENSION_VERSIONER = ExtensionVersioner() + +PLAT_TO_VCVARS = { + 'win32' : 'x86', + 'win-amd64' : 'x86_amd64', +} + +min_supported_cpython = "0x03090000" # Python 3.9 hexcode + +def get_cxx_compiler(): + if IS_WINDOWS: + compiler = os.environ.get('CXX', 'cl') + else: + compiler = os.environ.get('CXX', 'c++') + return compiler + +def _is_binary_build() -> bool: + return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__) + + +def _accepted_compilers_for_platform() -> list[str]: + # gnu-c++ and gnu-cc are the conda gcc compilers + return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc', 'clang++', 'clang'] + +def _maybe_write(filename, new_content): + r''' + Equivalent to writing the content into the file but will not touch the file + if it already had the right content (to avoid triggering recompile). + ''' + if os.path.exists(filename): + with open(filename) as f: + content = f.read() + + if content == new_content: + # The file already contains the right thing! + return + + with open(filename, 'w') as source_file: + source_file.write(new_content) + +def get_default_build_root() -> str: + """ + Return the path to the root folder under which extensions will built. + + For each extension module built, there will be one folder underneath the + folder returned by this function. For example, if ``p`` is the path + returned by this function and ``ext`` the name of an extension, the build + folder for the extension will be ``p/ext``. + + This directory is **user-specific** so that multiple users on the same + machine won't meet permission issues. + """ + return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions')) + + +def check_compiler_ok_for_platform(compiler: str) -> bool: + """ + Verify that the compiler is the expected one for the current platform. + + Args: + compiler (str): The compiler executable to check. + + Returns: + True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS, + and always True for Windows. + """ + if IS_WINDOWS: + return True + compiler_path = shutil.which(compiler) + if compiler_path is None: + return False + # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'. 
+ compiler_path = os.path.realpath(compiler_path) + # Check the compiler name + if any(name in compiler_path for name in _accepted_compilers_for_platform()): + return True + # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag + env = os.environ.copy() + env['LC_ALL'] = 'C' # Don't localize output + try: + version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + except subprocess.CalledProcessError: + # If '-v' fails, try '--version' + version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + if IS_LINUX: + # Check for 'gcc' or 'g++' for sccache wrapper + pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) + results = re.findall(pattern, version_string) + if len(results) != 1: + # Clang is also a supported compiler on Linux + # Though on Ubuntu it's sometimes called "Ubuntu clang version" + return 'clang version' in version_string + compiler_path = os.path.realpath(results[0].strip()) + # On RHEL/CentOS c++ is a gcc compiler wrapper + if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string: + return True + return any(name in compiler_path for name in _accepted_compilers_for_platform()) + if IS_MACOS: + # Check for 'clang' or 'clang++' + return version_string.startswith("Apple clang") + return False + + +def get_compiler_abi_compatibility_and_version(compiler) -> tuple[bool, TorchVersion]: + """ + Determine if the given compiler is ABI-compatible with PyTorch alongside its version. + + Args: + compiler (str): The compiler executable name to check (e.g. ``g++``). + Must be executable in a shell process. + + Returns: + A tuple that contains a boolean that defines if the compiler is (likely) ABI-incompatible with PyTorch, + followed by a `TorchVersion` string that contains the compiler version separated by dots. + """ + if not _is_binary_build(): + return (True, TorchVersion('0.0.0')) + if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']: + return (True, TorchVersion('0.0.0')) + + # First check if the compiler is one of the expected ones for the particular platform. + if not check_compiler_ok_for_platform(compiler): + logger.warning(WRONG_COMPILER_WARNING, compiler, _accepted_compilers_for_platform()[0], sys.platform, _accepted_compilers_for_platform()[0]) + return (False, TorchVersion('0.0.0')) + + if IS_MACOS: + # There is no particular minimum version we need for clang, so we're good here. 
+ return (True, TorchVersion('0.0.0')) + try: + if IS_LINUX: + minimum_required_version = MINIMUM_GCC_VERSION + versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) + version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.') + else: + minimum_required_version = MINIMUM_MSVC_VERSION + compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) + version = ['0', '0', '0'] if match is None else list(match.groups()) + except Exception: + _, error, _ = sys.exc_info() + logger.warning('Error checking compiler version for %s: %s', compiler, error) + return (False, TorchVersion('0.0.0')) + + # convert alpha-numeric string to numeric string + # amdclang++ returns str like 0.0.0git, others return 0.0.0 + numeric_version = [re.sub(r'\D', '', v) for v in version] + + if tuple(map(int, numeric_version)) >= minimum_required_version: + return (True, TorchVersion('.'.join(numeric_version))) + + compiler = f'{compiler} {".".join(numeric_version)}' + logger.warning(ABI_INCOMPATIBILITY_WARNING, compiler) + + return (False, TorchVersion('.'.join(numeric_version))) + + +def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None: + if not CUDA_HOME: + raise RuntimeError(CUDA_NOT_FOUND_MESSAGE) + + nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc.exe' if IS_WINDOWS else 'nvcc') + if not os.path.exists(nvcc): + raise FileNotFoundError(f"nvcc not found at '{nvcc}'. Ensure CUDA path '{CUDA_HOME}' is correct.") + + cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS) + cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str) + if cuda_version is None: + return + + cuda_str_version = cuda_version.group(1) + cuda_ver = Version(cuda_str_version) + if torch.version.cuda is None: + return + + torch_cuda_version = Version(torch.version.cuda) + if cuda_ver != torch_cuda_version: + # major/minor attributes are only available in setuptools>=49.4.0 + if getattr(cuda_ver, "major", None) is None: + raise ValueError("setuptools>=49.4.0 is required") + if cuda_ver.major != torch_cuda_version.major: + raise RuntimeError(CUDA_MISMATCH_MESSAGE, cuda_str_version, torch.version.cuda) + logger.warning(CUDA_MISMATCH_WARN, cuda_str_version, torch.version.cuda) + + if not (sys.platform.startswith('linux') and + os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') not in ['ON', '1', 'YES', 'TRUE', 'Y'] and + _is_binary_build()): + return + + cuda_compiler_bounds: VersionMap = CUDA_CLANG_VERSIONS if compiler_name.startswith('clang') else CUDA_GCC_VERSIONS + + if cuda_str_version not in cuda_compiler_bounds: + logger.warning('There are no %s version bounds defined for CUDA version %s', compiler_name, cuda_str_version) + else: + min_compiler_version, max_excl_compiler_version = cuda_compiler_bounds[cuda_str_version] + # Special case for 11.4.0, which has lower compiler bounds than 11.4.1 + if "V11.4.48" in cuda_version_str and cuda_compiler_bounds == CUDA_GCC_VERSIONS: + max_excl_compiler_version = (11, 0) + min_compiler_version_str = '.'.join(map(str, min_compiler_version)) + max_excl_compiler_version_str = '.'.join(map(str, max_excl_compiler_version)) + + version_bound_str = f'>={min_compiler_version_str}, <{max_excl_compiler_version_str}' + + if compiler_version < TorchVersion(min_compiler_version_str): + raise RuntimeError( + f'The current installed version of {compiler_name} ({compiler_version}) is less ' 
+ f'than the minimum required version by CUDA {cuda_str_version} ({min_compiler_version_str}). ' + f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' + ) + if compiler_version >= TorchVersion(max_excl_compiler_version_str): + raise RuntimeError( + f'The current installed version of {compiler_name} ({compiler_version}) is greater ' + f'than the maximum required version by CUDA {cuda_str_version}. ' + f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' + ) + +# Specify Visual Studio C runtime library for hipcc +def _set_hipcc_runtime_lib(is_standalone, debug): + if is_standalone: + if debug: + COMMON_HIP_FLAGS.append('-fms-runtime-lib=static_dbg') + else: + COMMON_HIP_FLAGS.append('-fms-runtime-lib=static') + else: + if debug: + COMMON_HIP_FLAGS.append('-fms-runtime-lib=dll_dbg') + else: + COMMON_HIP_FLAGS.append('-fms-runtime-lib=dll') + +def _append_sycl_std_if_no_std_present(cflags): + if not any(flag.startswith('-sycl-std=') for flag in cflags): + cflags.append('-sycl-std=2020') + + +def _wrap_sycl_host_flags(cflags): + host_cxx = get_cxx_compiler() + host_cflags = [ + f'-fsycl-host-compiler={host_cxx}', + shlex.quote(f'-fsycl-host-compiler-options={cflags}'), + ] + return host_cflags + + +class BuildExtension(build_ext): + """ + A custom :mod:`setuptools` build extension . + + This :class:`setuptools.build_ext` subclass takes care of passing the + minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed + C++/CUDA/SYCL compilation (and support for CUDA/SYCL files in general). + + When using :class:`BuildExtension`, it is allowed to supply a dictionary + for ``extra_compile_args`` (rather than the usual list) that maps from + languages/compilers (the only expected values are ``cxx``, ``nvcc`` or + ``sycl``) to a list of additional compiler flags to supply to the compiler. + This makes it possible to supply different flags to the C++, CUDA and SYCL + compiler during mixed compilation. + + ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we + attempt to build using the Ninja backend. Ninja greatly speeds up + compilation compared to the standard ``setuptools.build_ext``. + Fallbacks to the standard distutils backend if Ninja is not available. + + .. note:: + By default, the Ninja backend uses #CPUS + 2 workers to build the + extension. This may use up too many resources on some systems. One + can control the number of workers by setting the `MAX_JOBS` environment + variable to a non-negative number. + """ + + @classmethod + def with_options(cls, **options): + """Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options.""" + class cls_with_options(cls): # type: ignore[misc, valid-type] + def __init__(self, *args, **kwargs): + kwargs.update(options) + super().__init__(*args, **kwargs) + + return cls_with_options + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False) + + self.use_ninja = kwargs.get('use_ninja', True) + if self.use_ninja: + # Test if we can use ninja. Fallback otherwise. + msg = ('Attempted to use ninja as the BuildExtension backend but ' + '%s. 
Falling back to using the slow distutils backend.') + if not is_ninja_available(): + logger.warning(msg, 'we could not find ninja.') + self.use_ninja = False + + def finalize_options(self) -> None: + super().finalize_options() + if self.use_ninja: + self.force = True + + def build_extensions(self) -> None: + compiler_name, compiler_version = self._check_abi() + + cuda_ext = False + sycl_ext = False + extension_iter = iter(self.extensions) + extension = next(extension_iter, None) + while not (cuda_ext and sycl_ext) and extension: + for source in extension.sources: + _, ext = os.path.splitext(source) + if ext == '.cu': + cuda_ext = True + elif ext == '.sycl': + sycl_ext = True + + # This check accounts on a case when cuda and sycl sources + # are mixed in the same extension. We can stop checking + # sources if both are found or there is no more sources. + if cuda_ext and sycl_ext: + break + + extension = next(extension_iter, None) + + if sycl_ext: + assert self.use_ninja, "ninja is required to build sycl extensions." + + if cuda_ext and not IS_HIP_EXTENSION: + _check_cuda_version(compiler_name, compiler_version) + + for extension in self.extensions: + # Ensure at least an empty list of flags for 'cxx', 'nvcc' and 'sycl' when + # extra_compile_args is a dict. Otherwise, default torch flags do + # not get passed. Necessary when only one of 'cxx', 'nvcc' or 'sycl' is + # passed to extra_compile_args in CUDAExtension or SyclExtension, i.e. + # CUDAExtension(..., extra_compile_args={'cxx': [...]}) + # or + # CUDAExtension(..., extra_compile_args={'nvcc': [...]}) + if isinstance(extension.extra_compile_args, dict): + for ext in ['cxx', 'nvcc', 'sycl']: + if ext not in extension.extra_compile_args: + extension.extra_compile_args[ext] = [] + + self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H') + + if IS_HIP_EXTENSION: + self._hipify_compile_flags(extension) + + if extension.py_limited_api: + # compile any extension that has passed in py_limited_api to the + # Extension constructor with the Py_LIMITED_API flag set to our + # min supported CPython version. + # See https://docs.python.org/3/c-api/stable.html#c.Py_LIMITED_API + self._add_compile_flag(extension, f'-DPy_LIMITED_API={min_supported_cpython}') + else: + # pybind11 is not CPython API stable so don't add these flags used when + # compiling pybind11 when pybind11 is not even used. otherwise, the build + # logs are confusing. + # See note [Pybind11 ABI constants] + for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + val = getattr(torch._C, f"_PYBIND11_{name}") + if val is not None and not IS_WINDOWS: + self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') + self._define_torch_extension_name(extension) + + if 'nvcc_dlink' in extension.extra_compile_args: + assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}." + + # Register .cu, .cuh, .hip, .mm and .sycl as valid source extensions. + # NOTE: At the moment .sycl is not a standard extension for SYCL supported + # by compiler. Here we introduce a torch level convention that SYCL sources + # should have .sycl file extension. + self.compiler.src_extensions += ['.cu', '.cuh', '.hip', '.sycl'] + if torch.backends.mps.is_built(): + self.compiler.src_extensions += ['.mm'] + # Save the original _compile method for later. 
+ if self.compiler.compiler_type == 'msvc': + self.compiler._cpp_extensions += ['.cu', '.cuh'] + original_compile = self.compiler.compile + original_spawn = self.compiler.spawn + else: + original_compile = self.compiler._compile + + def append_std17_if_no_std_present(cflags) -> None: + # NVCC does not allow multiple -std to be passed, so we avoid + # overriding the option if the user explicitly passed it. + cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}=' + cpp_flag_prefix = cpp_format_prefix.format('std') + cpp_flag = cpp_flag_prefix + 'c++17' + if not any(flag.startswith(cpp_flag_prefix) for flag in cflags): + cflags.append(cpp_flag) + + def unix_cuda_flags(cflags): + cflags = (COMMON_NVCC_FLAGS + + ['--compiler-options', "'-fPIC'"] + + cflags + _get_cuda_arch_flags(cflags)) + + # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid + # overriding the option if the user explicitly passed it. + _ccbin = os.getenv("CC") + if ( + _ccbin is not None + and not any(flag.startswith(('-ccbin', '--compiler-bindir')) for flag in cflags) + ): + cflags.extend(['-ccbin', _ccbin]) + + return cflags + + def convert_to_absolute_paths_inplace(paths): + # Helper function. See Note [Absolute include_dirs] + if paths is not None: + for i in range(len(paths)): + if not os.path.isabs(paths[i]): + paths[i] = os.path.abspath(paths[i]) + + def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None: + # Copy before we make any modifications. + cflags = copy.deepcopy(extra_postargs) + try: + original_compiler = self.compiler.compiler_so + if _is_cuda_file(src): + nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')] + self.compiler.set_executable('compiler_so', nvcc) + if isinstance(cflags, dict): + cflags = cflags['nvcc'] + if IS_HIP_EXTENSION: + cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags) + else: + cflags = unix_cuda_flags(cflags) + elif isinstance(cflags, dict): + cflags = cflags['cxx'] + if IS_HIP_EXTENSION: + cflags = COMMON_HIP_FLAGS + cflags + append_std17_if_no_std_present(cflags) + + original_compile(obj, src, ext, cc_args, cflags, pp_opts) + finally: + # Put the original compiler back in place. + self.compiler.set_executable('compiler_so', original_compiler) + + def unix_wrap_ninja_compile(sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None): + r"""Compiles sources by outputting a ninja file and running it.""" + # NB: I copied some lines from self.compiler (which is an instance + # of distutils.UnixCCompiler). See the following link. + # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 + # This can be fragile, but a lot of other repos also do this + # (see https://github.com/search?q=_setup_compile&type=Code) + # so it is probably OK; we'll also get CI signal if/when + # we update our python version (which is when distutils can be + # upgraded) + + # Use absolute path for output_dir so that the object file paths + # (`objects`) get generated with absolute paths. 
+ output_dir = os.path.abspath(output_dir) + + # See Note [Absolute include_dirs] + convert_to_absolute_paths_inplace(self.compiler.include_dirs) + + _, objects, extra_postargs, pp_opts, _ = \ + self.compiler._setup_compile(output_dir, macros, + include_dirs, sources, + depends, extra_postargs) + common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs) + extra_cc_cflags = self.compiler.compiler_so[1:] + with_cuda = any(map(_is_cuda_file, sources)) + with_sycl = any(map(_is_sycl_file, sources)) + + # extra_postargs can be either: + # - a dict mapping cxx/nvcc/sycl to extra flags + # - a list of extra flags. + if isinstance(extra_postargs, dict): + post_cflags = extra_postargs['cxx'] + else: + post_cflags = list(extra_postargs) + if IS_HIP_EXTENSION: + post_cflags = COMMON_HIP_FLAGS + post_cflags + append_std17_if_no_std_present(post_cflags) + + cuda_post_cflags = None + cuda_cflags = None + if with_cuda: + cuda_cflags = common_cflags + if isinstance(extra_postargs, dict): + cuda_post_cflags = extra_postargs['nvcc'] + else: + cuda_post_cflags = list(extra_postargs) + if IS_HIP_EXTENSION: + cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags) + cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags + else: + cuda_post_cflags = unix_cuda_flags(cuda_post_cflags) + append_std17_if_no_std_present(cuda_post_cflags) + cuda_cflags = [shlex.quote(f) for f in cuda_cflags] + cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags] + + if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: + cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink']) + cuda_dlink_post_cflags = [shlex.quote(f) for f in cuda_dlink_post_cflags] + else: + cuda_dlink_post_cflags = None + + sycl_post_cflags = None + sycl_cflags = None + sycl_dlink_post_cflags = None + if with_sycl: + sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS + if isinstance(extra_postargs, dict): + sycl_post_cflags = extra_postargs['sycl'] + else: + sycl_post_cflags = list(extra_postargs) + _append_sycl_targets_if_missing(sycl_post_cflags) + append_std17_if_no_std_present(sycl_cflags) + _append_sycl_std_if_no_std_present(sycl_cflags) + host_cflags = extra_cc_cflags + common_cflags + post_cflags + append_std17_if_no_std_present(host_cflags) + # escaping quoted arguments to pass them thru SYCL compiler + host_cflags = [item.replace('"', '\\\\"') for item in host_cflags] + host_cflags = ' '.join(host_cflags) + # Note the order: shlex.quote sycl_flags first, _wrap_sycl_host_flags + # second. Reason is that sycl host flags are quoted, space containing + # strings passed to SYCL compiler. 
+ sycl_cflags = [shlex.quote(f) for f in sycl_cflags] + sycl_cflags += _wrap_sycl_host_flags(host_cflags) + sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() + sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_post_cflags) + sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags] + + _write_ninja_file_and_compile_objects( + sources=sources, + objects=objects, + cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags], + post_cflags=[shlex.quote(f) for f in post_cflags], + cuda_cflags=cuda_cflags, + cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, + sycl_cflags=sycl_cflags, + sycl_post_cflags=sycl_post_cflags, + sycl_dlink_post_cflags=sycl_dlink_post_cflags, + build_directory=output_dir, + verbose=True, + with_cuda=with_cuda, + with_sycl=with_sycl) + + # Return *all* object filenames, not just the ones we just built. + return objects + + def win_cuda_flags(cflags): + return (COMMON_NVCC_FLAGS + + cflags + _get_cuda_arch_flags(cflags)) + + def win_hip_flags(cflags): + return (COMMON_HIPCC_FLAGS + COMMON_HIP_FLAGS + cflags + _get_rocm_arch_flags(cflags)) + + def win_wrap_single_compile(sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None): + + self.cflags = copy.deepcopy(extra_postargs) + extra_postargs = None + + def spawn(cmd): + # Using regex to match src, obj and include files + src_regex = re.compile('/T(p|c)(.*)') + src_list = [ + m.group(2) for m in (src_regex.match(elem) for elem in cmd) + if m + ] + + obj_regex = re.compile('/Fo(.*)') + obj_list = [ + m.group(1) for m in (obj_regex.match(elem) for elem in cmd) + if m + ] + + include_regex = re.compile(r'((\-|\/)I.*)') + include_list = [ + m.group(1) + for m in (include_regex.match(elem) for elem in cmd) if m + ] + + if len(src_list) >= 1 and len(obj_list) >= 1: + src = src_list[0] + obj = obj_list[0] + if _is_cuda_file(src): + if IS_HIP_EXTENSION: + nvcc = _get_hipcc_path() + else: + nvcc = _join_cuda_home('bin', 'nvcc') + if isinstance(self.cflags, dict): + cflags = self.cflags['nvcc'] + elif isinstance(self.cflags, list): + cflags = self.cflags + else: + cflags = [] + + if IS_HIP_EXTENSION: + cflags = win_hip_flags(cflags) + else: + cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env'] + for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: + cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags + for flag in COMMON_MSVC_FLAGS: + cflags = ['-Xcompiler', flag] + cflags + cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags + elif isinstance(self.cflags, dict): + cflags = COMMON_MSVC_FLAGS + self.cflags['cxx'] + append_std17_if_no_std_present(cflags) + cmd += cflags + elif isinstance(self.cflags, list): + cflags = COMMON_MSVC_FLAGS + self.cflags + append_std17_if_no_std_present(cflags) + cmd += cflags + + return original_spawn(cmd) + + try: + self.compiler.spawn = spawn + return original_compile(sources, output_dir, macros, + include_dirs, debug, extra_preargs, + extra_postargs, depends) + finally: + self.compiler.spawn = original_spawn + + def win_wrap_ninja_compile(sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, + is_standalone=False): + if not self.compiler.initialized: + self.compiler.initialize() + output_dir = os.path.abspath(output_dir) + + # Note [Absolute include_dirs] + # Convert relative path in self.compiler.include_dirs to absolute path if any. 
+ # For ninja build, the build location is not local, but instead, the build happens + # in a script-created build folder. Thus, relative paths lose their correctness. + # To be consistent with jit extension, we allow user to enter relative include_dirs + # in setuptools.setup, and we convert the relative path to absolute path here. + convert_to_absolute_paths_inplace(self.compiler.include_dirs) + + _, objects, extra_postargs, pp_opts, _ = \ + self.compiler._setup_compile(output_dir, macros, + include_dirs, sources, + depends, extra_postargs) + # Replace space with \ when using hipcc (hipcc passes includes to clang without ""s so clang sees space in include paths as new argument) + if IS_HIP_EXTENSION: + pp_opts = ["-I{}".format(s[2:].replace(" ", "\\")) if s.startswith('-I') else s for s in pp_opts] + common_cflags = extra_preargs or [] + cflags = [] + if debug: + cflags.extend(self.compiler.compile_options_debug) + else: + cflags.extend(self.compiler.compile_options) + cflags = cflags + common_cflags + pp_opts + COMMON_MSVC_FLAGS + if IS_HIP_EXTENSION: + _set_hipcc_runtime_lib(is_standalone, debug) + common_cflags.extend(COMMON_HIP_FLAGS) + else: + common_cflags.extend(COMMON_MSVC_FLAGS) + with_cuda = any(map(_is_cuda_file, sources)) + + # extra_postargs can be either: + # - a dict mapping cxx/nvcc to extra flags + # - a list of extra flags. + if isinstance(extra_postargs, dict): + post_cflags = extra_postargs['cxx'] + else: + post_cflags = list(extra_postargs) + if IS_HIP_EXTENSION: + post_cflags = COMMON_HIP_FLAGS + post_cflags + append_std17_if_no_std_present(post_cflags) + + cuda_post_cflags = None + cuda_cflags = None + if with_cuda: + cuda_cflags = ['-std=c++17'] + for common_cflag in common_cflags: + cuda_cflags.append('-Xcompiler') + cuda_cflags.append(common_cflag) + if not IS_HIP_EXTENSION: + cuda_cflags.append('--use-local-env') + for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: + cuda_cflags.append('-Xcudafe') + cuda_cflags.append('--diag_suppress=' + ignore_warning) + cuda_cflags.extend(pp_opts) + if isinstance(extra_postargs, dict): + cuda_post_cflags = extra_postargs['nvcc'] + else: + cuda_post_cflags = list(extra_postargs) + if IS_HIP_EXTENSION: + cuda_post_cflags = win_hip_flags(cuda_post_cflags) + else: + cuda_post_cflags = win_cuda_flags(cuda_post_cflags) + cflags = _nt_quote_args(cflags) + post_cflags = _nt_quote_args(post_cflags) + if with_cuda: + cuda_cflags = _nt_quote_args(cuda_cflags) + cuda_post_cflags = _nt_quote_args(cuda_post_cflags) + if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: + cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink']) + else: + cuda_dlink_post_cflags = None + + _write_ninja_file_and_compile_objects( + sources=sources, + objects=objects, + cflags=cflags, + post_cflags=post_cflags, + cuda_cflags=cuda_cflags, + cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, + sycl_cflags=None, + sycl_post_cflags=None, + sycl_dlink_post_cflags=None, + build_directory=output_dir, + verbose=True, + with_cuda=with_cuda, + with_sycl=False) + + # Return *all* object filenames, not just the ones we just built. + return objects + # Monkey-patch the _compile or compile method. 
+ # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 + if self.compiler.compiler_type == 'msvc': + if self.use_ninja: + self.compiler.compile = win_wrap_ninja_compile + else: + self.compiler.compile = win_wrap_single_compile + else: + if self.use_ninja: + self.compiler.compile = unix_wrap_ninja_compile + else: + self.compiler._compile = unix_wrap_single_compile + + build_ext.build_extensions(self) + + def get_ext_filename(self, ext_name): + # Get the original shared library name. For Python 3, this name will be + # suffixed with ".so", where will be something like + # cpython-37m-x86_64-linux-gnu. + ext_filename = super().get_ext_filename(ext_name) + # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI + # component. This makes building shared libraries with setuptools that + # aren't Python modules nicer. + if self.no_python_abi_suffix: + # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"]. + ext_filename_parts = ext_filename.split('.') + # Omit the second to last element. + without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] + ext_filename = '.'.join(without_abi) + return ext_filename + + def _check_abi(self) -> tuple[str, TorchVersion]: + # On some platforms, like Windows, compiler_cxx is not available. + if hasattr(self.compiler, 'compiler_cxx'): + compiler = self.compiler.compiler_cxx[0] + else: + compiler = get_cxx_compiler() + _, version = get_compiler_abi_compatibility_and_version(compiler) + # Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set. + if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ: + msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.' + 'This may lead to multiple activations of the VC env.' 
+ 'Please set `DISTUTILS_USE_SDK=1` and try again.') + raise UserWarning(msg) + return compiler, version + + def _add_compile_flag(self, extension, flag): + extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args) + if isinstance(extension.extra_compile_args, dict): + for args in extension.extra_compile_args.values(): + args.append(flag) + else: + extension.extra_compile_args.append(flag) + + # Simple hipify, replace the first occurrence of CUDA with HIP + # in flags starting with "-" and containing "CUDA", but exclude -I flags + def _hipify_compile_flags(self, extension): + if isinstance(extension.extra_compile_args, dict) and 'nvcc' in extension.extra_compile_args: + modified_flags = [] + for flag in extension.extra_compile_args['nvcc']: + if flag.startswith("-") and "CUDA" in flag and not flag.startswith("-I"): + # check/split flag into flag and value + parts = flag.split("=", 1) + if len(parts) == 2: + flag_part, value_part = parts + # replace fist instance of "CUDA" with "HIP" only in the flag and not flag value + modified_flag_part = flag_part.replace("CUDA", "HIP", 1) + modified_flag = f"{modified_flag_part}={value_part}" + else: + # replace fist instance of "CUDA" with "HIP" in flag + modified_flag = flag.replace("CUDA", "HIP", 1) + modified_flags.append(modified_flag) + logger.info('Modified flag: %s -> %s', flag, modified_flag) + else: + modified_flags.append(flag) + extension.extra_compile_args['nvcc'] = modified_flags + + def _define_torch_extension_name(self, extension): + # pybind11 doesn't support dots in the names + # so in order to support extensions in the packages + # like torch._C, we take the last part of the string + # as the library name + names = extension.name.split('.') + name = names[-1] + define = f'-DTORCH_EXTENSION_NAME={name}' + self._add_compile_flag(extension, define) + + +def CppExtension(name, sources, *args, **kwargs): + """ + Create a :class:`setuptools.Extension` for C++. + + Convenience method that creates a :class:`setuptools.Extension` with the + bare minimum (but often sufficient) arguments to build a C++ extension. + + All arguments are forwarded to the :class:`setuptools.Extension` + constructor. Full list arguments can be found at + https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference + + .. warning:: + The PyTorch python API (as provided in libtorch_python) cannot be built + with the flag ``py_limited_api=True``. When this flag is passed, it is + the user's responsibility in their library to not use APIs from + libtorch_python (in particular pytorch/python bindings) and to only use + APIs from libtorch (aten objects, operators and the dispatcher). For + example, to give access to custom ops from python, the library should + register the ops through the dispatcher. + + Contrary to CPython setuptools, who does not define -DPy_LIMITED_API + as a compile flag when py_limited_api is specified as an option for + the "bdist_wheel" command in ``setup``, PyTorch does! We will specify + -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, + safety, and sanity in order to encourage best practices. To target a + different version, set min_supported_cpython to the hexcode of the + CPython version of choice. + + Example: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) + >>> from setuptools import setup + >>> from torch.utils.cpp_extension import BuildExtension, CppExtension + >>> setup( + ... name='extension', + ... ext_modules=[ + ... CppExtension( + ... 
name='extension', + ... sources=['extension.cpp'], + ... extra_compile_args=['-g'], + ... extra_link_args=['-Wl,--no-as-needed', '-lm']) + ... ], + ... cmdclass={ + ... 'build_ext': BuildExtension + ... }) + """ + include_dirs = kwargs.get('include_dirs', []) + include_dirs += include_paths() + kwargs['include_dirs'] = include_dirs + + library_dirs = kwargs.get('library_dirs', []) + library_dirs += library_paths() + kwargs['library_dirs'] = library_dirs + + libraries = kwargs.get('libraries', []) + libraries.append('c10') + libraries.append('torch') + libraries.append('torch_cpu') + if not kwargs.get('py_limited_api', False): + # torch_python uses more than the python limited api + libraries.append('torch_python') + if IS_WINDOWS: + libraries.append("sleef") + + kwargs['libraries'] = libraries + + kwargs['language'] = 'c++' + return setuptools.Extension(name, sources, *args, **kwargs) + + +def CUDAExtension(name, sources, *args, **kwargs): + """ + Create a :class:`setuptools.Extension` for CUDA/C++. + + Convenience method that creates a :class:`setuptools.Extension` with the + bare minimum (but often sufficient) arguments to build a CUDA/C++ + extension. This includes the CUDA include path, library path and runtime + library. + + All arguments are forwarded to the :class:`setuptools.Extension` + constructor. Full list arguments can be found at + https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference + + .. warning:: + The PyTorch python API (as provided in libtorch_python) cannot be built + with the flag ``py_limited_api=True``. When this flag is passed, it is + the user's responsibility in their library to not use APIs from + libtorch_python (in particular pytorch/python bindings) and to only use + APIs from libtorch (aten objects, operators and the dispatcher). For + example, to give access to custom ops from python, the library should + register the ops through the dispatcher. + + Contrary to CPython setuptools, who does not define -DPy_LIMITED_API + as a compile flag when py_limited_api is specified as an option for + the "bdist_wheel" command in ``setup``, PyTorch does! We will specify + -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, + safety, and sanity in order to encourage best practices. To target a + different version, set min_supported_cpython to the hexcode of the + CPython version of choice. + + Example: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) + >>> from setuptools import setup + >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension + >>> setup( + ... name='cuda_extension', + ... ext_modules=[ + ... CUDAExtension( + ... name='cuda_extension', + ... sources=['extension.cpp', 'extension_kernel.cu'], + ... extra_compile_args={'cxx': ['-g'], + ... 'nvcc': ['-O2']}, + ... extra_link_args=['-Wl,--no-as-needed', '-lcuda']) + ... ], + ... cmdclass={ + ... 'build_ext': BuildExtension + ... }) + + Compute capabilities: + + By default the extension will be compiled to run on all archs of the cards visible during the + building process of the extension, plus PTX. If down the road a new card is installed the + extension may need to be recompiled. If a visible card has a compute capability (CC) that's + newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch + will make nvcc fall back to building kernels with the newest version of PTX your nvcc does + support (see below for details on PTX). 
+ + You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which + CCs you want the extension to support: + + ``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py`` + ``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py`` + + The +PTX option causes extension kernel binaries to include PTX instructions for the specified + CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >= + the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with + CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to + provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on + those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better + off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6, + "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but + "8.0 8.6" would be better. + + Note that while it's possible to include all supported archs, the more archs get included the + slower the building process will be, as it will build a separate kernel image for each arch. + + Note that CUDA-11.5 nvcc will hit internal compiler error while parsing torch/extension.h on Windows. + To workaround the issue, move python binding logic to pure C++ file. + + Example use: + #include + at::Tensor SigmoidAlphaBlendForwardCuda(....) + + Instead of: + #include + torch::Tensor SigmoidAlphaBlendForwardCuda(...) + + Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460 + Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48 + + Relocatable device code linking: + + If you want to reference device symbols across compilation units (across object files), + the object files need to be built with `relocatable device code` (-rdc=true or -dc). + An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore. + `Relocatable device code` is less optimized so it needs to be used only on object files that need it. + Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step + helps reduce the protentional perf degradation of `-rdc`. + Note that it needs to be used at both steps to be useful. + + If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step. + There is also a case where `-dlink` is used without `-rdc`: + when an extension is linked against a static lib containing rdc-compiled objects + like the [NVSHMEM library](https://developer.nvidia.com/nvshmem). + + Note: Ninja is required to build a CUDA Extension with RDC linking. + + Example: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) + >>> CUDAExtension( + ... name='cuda_extension', + ... sources=['extension.cpp', 'extension_kernel.cu'], + ... dlink=True, + ... dlink_libraries=["dlink_lib"], + ... extra_compile_args={'cxx': ['-g'], + ... 
'nvcc': ['-O2', '-rdc=true']}) + """ + library_dirs = kwargs.get('library_dirs', []) + library_dirs += library_paths(device_type="cuda") + kwargs['library_dirs'] = library_dirs + + libraries = kwargs.get('libraries', []) + libraries.append('c10') + libraries.append('torch') + libraries.append('torch_cpu') + if not kwargs.get('py_limited_api', False): + # torch_python uses more than the python limited api + libraries.append('torch_python') + if IS_HIP_EXTENSION: + libraries.append('amdhip64') + libraries.append('c10_hip') + libraries.append('torch_hip') + else: + libraries.append('cudart') + libraries.append('c10_cuda') + libraries.append('torch_cuda') + kwargs['libraries'] = libraries + + include_dirs = kwargs.get('include_dirs', []) + + if IS_HIP_EXTENSION: + build_dir = os.getcwd() + hipify_result = hipify_python.hipify( + project_directory=build_dir, + output_directory=build_dir, + header_include_dirs=include_dirs, + includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only + extra_files=[os.path.abspath(s) for s in sources], + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, # don't hipify everything in includes path + ) + + hipified_sources = set() + for source in sources: + s_abs = os.path.abspath(source) + hipified_s_abs = (hipify_result[s_abs].hipified_path if (s_abs in hipify_result and + hipify_result[s_abs].hipified_path is not None) else s_abs) + # setup() arguments must *always* be /-separated paths relative to the setup.py directory, + # *never* absolute paths + hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir)) + + sources = list(hipified_sources) + + include_dirs += include_paths(device_type="cuda") + kwargs['include_dirs'] = include_dirs + + kwargs['language'] = 'c++' + + dlink_libraries = kwargs.get('dlink_libraries', []) + dlink = kwargs.get('dlink', False) or dlink_libraries + if dlink: + extra_compile_args = kwargs.get('extra_compile_args', {}) + + extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', []) + extra_compile_args_dlink += ['-dlink'] + extra_compile_args_dlink += [f'-L{x}' for x in library_dirs] + extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries] + + if (torch.version.cuda is not None) and TorchVersion(torch.version.cuda) >= '11.2': + extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2 + + extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink + + kwargs['extra_compile_args'] = extra_compile_args + + return setuptools.Extension(name, sources, *args, **kwargs) + + +def SyclExtension(name, sources, *args, **kwargs): + r""" + Creates a :class:`setuptools.Extension` for SYCL/C++. + + Convenience method that creates a :class:`setuptools.Extension` with the + bare minimum (but often sufficient) arguments to build a SYCL/C++ + extension. + + All arguments are forwarded to the :class:`setuptools.Extension` + constructor. + + .. warning:: + The PyTorch python API (as provided in libtorch_python) cannot be built + with the flag ``py_limited_api=True``. When this flag is passed, it is + the user's responsibility in their library to not use APIs from + libtorch_python (in particular pytorch/python bindings) and to only use + APIs from libtorch (aten objects, operators and the dispatcher). For + example, to give access to custom ops from python, the library should + register the ops through the dispatcher. 
+ + Contrary to CPython setuptools, who does not define -DPy_LIMITED_API + as a compile flag when py_limited_api is specified as an option for + the "bdist_wheel" command in ``setup``, PyTorch does! We will specify + -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, + safety, and sanity in order to encourage best practices. To target a + different version, set min_supported_cpython to the hexcode of the + CPython version of choice. + + Example: + >>> # xdoctest: +SKIP + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT) + >>> from torch.utils.cpp_extension import BuildExtension, SyclExtension + >>> setup( + ... name='xpu_extension', + ... ext_modules=[ + ... SyclExtension( + ... name='xpu_extension', + ... sources=['extension.cpp', 'extension_kernel.cpp'], + ... extra_compile_args={'cxx': ['-g', '-std=c++20', '-fPIC']}) + ... ], + ... cmdclass={ + ... 'build_ext': BuildExtension + ... }) + + By default the extension will be compiled to run on all archs of the cards visible during the + building process of the extension. If down the road a new card is installed the + extension may need to be recompiled. You can override the default behavior using + `TORCH_XPU_ARCH_LIST` to explicitly specify which device architectures you want the extension + to support: + + ``TORCH_XPU_ARCH_LIST="pvc,xe-lpg" python build_my_extension.py`` + + Note that while it's possible to include all supported archs, the more archs get included the + slower the building process will be, as it will build a separate kernel image for each arch. + + Note: Ninja is required to build SyclExtension. + """ + library_dirs = kwargs.get("library_dirs", []) + library_dirs += library_paths() + kwargs["library_dirs"] = library_dirs + + libraries = kwargs.get("libraries", []) + libraries.append("c10") + libraries.append("c10_xpu") + libraries.append("torch") + libraries.append("torch_cpu") + if not kwargs.get('py_limited_api', False): + # torch_python uses more than the python limited api + libraries.append("torch_python") + libraries.append("torch_xpu") + kwargs["libraries"] = libraries + + include_dirs = kwargs.get("include_dirs", []) + include_dirs += include_paths() + kwargs["include_dirs"] = include_dirs + + kwargs["language"] = "c++" + + return setuptools.Extension(name, sources, *args, **kwargs) + +def include_paths(device_type: str = "cpu") -> list[str]: + """ + Get the include paths required to build a C++ or CUDA or SYCL extension. + + Args: + device_type: Defaults to "cpu". + Returns: + A list of include path strings. + """ + lib_include = os.path.join(_TORCH_PATH, 'include') + paths = [ + lib_include, + # Remove this once torch/torch.h is officially no longer supported for C++ extensions. + os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'), + ] + if device_type == "cuda" and IS_HIP_EXTENSION: + paths.append(os.path.join(lib_include, 'THH')) + paths.append(_join_rocm_home('include')) + elif device_type == "cuda": + cuda_home_include = _join_cuda_home('include') + # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home. 
+ # but gcc doesn't like having /usr/include passed explicitly + if cuda_home_include != '/usr/include': + paths.append(cuda_home_include) + + # Support CUDA_INC_PATH env variable supported by CMake files + if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \ + cuda_inc_path != '/usr/include': + paths.append(cuda_inc_path) + if CUDNN_HOME is not None: + paths.append(os.path.join(CUDNN_HOME, 'include')) + elif device_type == "xpu": + paths.append(_join_sycl_home('include')) + paths.append(_join_sycl_home('include', 'sycl')) + return paths + + +def library_paths(device_type: str = "cpu") -> list[str]: + """ + Get the library paths required to build a C++ or CUDA extension. + + Args: + device_type: Defaults to "cpu". + + Returns: + A list of library path strings. + """ + # We need to link against libtorch.so + paths = [TORCH_LIB_PATH] + + if device_type == "cuda" and IS_HIP_EXTENSION: + lib_dir = 'lib' + paths.append(_join_rocm_home(lib_dir)) + if HIP_HOME is not None: + paths.append(os.path.join(HIP_HOME, 'lib')) + elif device_type == "cuda": + if IS_WINDOWS: + lib_dir = os.path.join('lib', 'x64') + else: + lib_dir = 'lib64' + if (not os.path.exists(_join_cuda_home(lib_dir)) and + os.path.exists(_join_cuda_home('lib'))): + # 64-bit CUDA may be installed in 'lib' (see e.g. gh-16955) + # Note that it's also possible both don't exist (see + # _find_cuda_home) - in that case we stay with 'lib64'. + lib_dir = 'lib' + + paths.append(_join_cuda_home(lib_dir)) + if CUDNN_HOME is not None: + paths.append(os.path.join(CUDNN_HOME, lib_dir)) + elif device_type == "xpu": + if IS_WINDOWS: + lib_dir = os.path.join('lib', 'x64') + else: + lib_dir = 'lib64' + if (not os.path.exists(_join_sycl_home(lib_dir)) and + os.path.exists(_join_sycl_home('lib'))): + lib_dir = 'lib' + + paths.append(_join_sycl_home(lib_dir)) + + return paths + + +def load(name, + sources: Union[str, list[str]], + extra_cflags=None, + extra_cuda_cflags=None, + extra_sycl_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + verbose=False, + with_cuda: Optional[bool] = None, + with_sycl: Optional[bool] = None, + is_python_module=True, + is_standalone=False, + keep_intermediates=True): + """ + Load a PyTorch C++ extension just-in-time (JIT). + + To load an extension, a Ninja build file is emitted, which is used to + compile the given sources into a dynamic library. This library is + subsequently loaded into the current Python process as a module and + returned from this function, ready for use. + + By default, the directory to which the build file is emitted and the + resulting library compiled to is ``/torch_extensions/``, where + ```` is the temporary folder on the current platform and ```` + the name of the extension. This location can be overridden in two ways. + First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it + replaces ``/torch_extensions`` and all extensions will be compiled + into subfolders of this directory. Second, if the ``build_directory`` + argument to this function is supplied, it overrides the entire path, i.e. + the library will be compiled into that folder directly. + + To compile the sources, the default system compiler (``c++``) is used, + which can be overridden by setting the ``CXX`` environment variable. To pass + additional arguments to the compilation process, ``extra_cflags`` or + ``extra_ldflags`` can be provided. For example, to compile your extension + with optimizations, pass ``extra_cflags=['-O3']``. 
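For instance, a minimal sketch of such a call (the module name and source file here are hypothetical):

    >>> # xdoctest: +SKIP
    >>> from torch.utils.cpp_extension import load
    >>> fast_ext = load(
    ...     name='fast_ext',
    ...     sources=['fast_ext.cpp'],
    ...     extra_cflags=['-O3'],
    ...     verbose=True)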
You can also use + ``extra_cflags`` to pass further include directories. + + CUDA support with mixed compilation is provided. Simply pass CUDA source + files (``.cu`` or ``.cuh``) along with other sources. Such files will be + detected and compiled with nvcc rather than the C++ compiler. This includes + passing the CUDA lib64 directory as a library directory, and linking + ``cudart``. You can pass additional flags to nvcc via + ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various + heuristics for finding the CUDA install directory are used, which usually + work fine. If not, setting the ``CUDA_HOME`` environment variable is the + safest option. + + SYCL support with mixed compilation is provided. Simply pass SYCL source + files (``.sycl``) along with other sources. Such files will be detected + and compiled with SYCL compiler (such as Intel DPC++ Compiler) rather + than the C++ compiler. You can pass additional flags to SYCL compiler + via ``extra_sycl_cflags``, just like with ``extra_cflags`` for C++. + SYCL compiler is expected to be found via system PATH environment + variable. + + Args: + name: The name of the extension to build. This MUST be the same as the + name of the pybind11 module! + sources: A list of relative or absolute paths to C++ source files. + extra_cflags: optional list of compiler flags to forward to the build. + extra_cuda_cflags: optional list of compiler flags to forward to nvcc + when building CUDA sources. + extra_sycl_cflags: optional list of compiler flags to forward to SYCL + compiler when building SYCL sources. + extra_ldflags: optional list of linker flags to forward to the build. + extra_include_paths: optional list of include directories to forward + to the build. + build_directory: optional path to use as build workspace. + verbose: If ``True``, turns on verbose logging of load steps. + with_cuda: Determines whether CUDA headers and libraries are added to + the build. If set to ``None`` (default), this value is + automatically determined based on the existence of ``.cu`` or + ``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers + and libraries to be included. + with_sycl: Determines whether SYCL headers and libraries are added to + the build. If set to ``None`` (default), this value is + automatically determined based on the existence of ``.sycl`` in + ``sources``. Set it to `True`` to force SYCL headers and + libraries to be included. + is_python_module: If ``True`` (default), imports the produced shared + library as a Python module. If ``False``, behavior depends on + ``is_standalone``. + is_standalone: If ``False`` (default) loads the constructed extension + into the process as a plain dynamic library. If ``True``, build a + standalone executable. + + Returns: + If ``is_python_module`` is ``True``: + Returns the loaded PyTorch extension as a Python module. + + If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``: + Returns nothing. (The shared library is loaded into the process as + a side effect.) + + If ``is_standalone`` is ``True``. + Return the path to the executable. (On Windows, TORCH_LIB_PATH is + added to the PATH environment variable as a side effect.) + + Example: + >>> # xdoctest: +SKIP + >>> from torch.utils.cpp_extension import load + >>> module = load( + ... name='extension', + ... sources=['extension.cpp', 'extension_kernel.cu'], + ... extra_cflags=['-O2'], + ... 
verbose=True) + """ + return _jit_compile( + name, + [sources] if isinstance(sources, str) else sources, + extra_cflags, + extra_cuda_cflags, + extra_sycl_cflags, + extra_ldflags, + extra_include_paths, + build_directory or _get_build_directory(name, verbose), + verbose, + with_cuda, + with_sycl, + is_python_module, + is_standalone, + keep_intermediates=keep_intermediates) + +def _get_pybind11_abi_build_flags(): + # Note [Pybind11 ABI constants] + # + # Pybind11 before 2.4 used to build an ABI strings using the following pattern: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" + # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" + # + # This was done in order to further narrow down the chances of compiler ABI incompatibility + # that can cause a hard to debug segfaults. + # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties + # captured during PyTorch native library compilation in torch/csrc/Module.cpp + + abi_cflags = [] + for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + pval = getattr(torch._C, f"_PYBIND11_{pname}") + if pval is not None and not IS_WINDOWS: + abi_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') + return abi_cflags + +def check_compiler_is_gcc(compiler): + if not IS_LINUX: + return False + + env = os.environ.copy() + env['LC_ALL'] = 'C' # Don't localize output + try: + version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + try: + version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + return False + # Check for 'gcc' or 'g++' for sccache wrapper + pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) + results = re.findall(pattern, version_string) + if len(results) != 1: + return False + compiler_path = os.path.realpath(results[0].strip()) + # On RHEL/CentOS c++ is a gcc compiler wrapper + if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string: + return True + return False + +def _check_and_build_extension_h_precompiler_headers( + extra_cflags, + extra_include_paths, + is_standalone=False): + r''' + Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules. + GCC offical manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html + PCH only works when built pch file(header.h.gch) and build target have the same build parameters. So, We need + add a signature file to record PCH file parameters. If the build parameters(signature) changed, it should rebuild + PCH file. + + Note: + 1. Windows and MacOS have different PCH mechanism. We only support Linux currently. + 2. It only works on GCC/G++. 
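As a rough sketch of how this precompiled header path gets exercised (the module name and source string are made up; ``use_pch=True`` simply opts into the helper described above, which only has an effect on Linux with GCC):

    >>> # xdoctest: +SKIP
    >>> from torch.utils.cpp_extension import load_inline
    >>> src = 'at::Tensor twice(at::Tensor x) { return x + x; }'
    >>> mod = load_inline(name='pch_demo', cpp_sources=[src],
    ...                   functions=['twice'], use_pch=True)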
+ ''' + if not IS_LINUX: + return + + compiler = get_cxx_compiler() + + b_is_gcc = check_compiler_is_gcc(compiler) + if b_is_gcc is False: + return + + head_file = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h') + head_file_pch = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch') + head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign') + + def listToString(s): + # initialize an empty string + string = "" + if s is None: + return string + + # traverse in the string + for element in s: + string += (element + ' ') + # return string + return string + + def format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags, torch_include_dirs, extra_cflags, extra_include_paths): + return re.sub( + r"[ \n]+", + " ", + f""" + {compiler} -x c++-header {head_file} -o {head_file_pch} {torch_include_dirs} {extra_include_paths} {extra_cflags} {common_cflags} + """, + ).strip() + + def command_to_signature(cmd): + signature = cmd.replace(' ', '_') + return signature + + def check_pch_signature_in_file(file_path, signature): + b_exist = os.path.isfile(file_path) + if b_exist is False: + return False + + with open(file_path) as file: + # read all content of a file + content = file.read() + # check if string present in a file + return signature == content + + def _create_if_not_exist(path_dir): + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError(f"Fail to create path {path_dir}") from exc + + def write_pch_signature_to_file(file_path, pch_sign): + _create_if_not_exist(os.path.dirname(file_path)) + with open(file_path, "w") as f: + f.write(pch_sign) + f.close() + + def build_precompile_header(pch_cmd): + try: + subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}") from e + + extra_cflags_str = listToString(extra_cflags) + extra_include_paths_str = " ".join( + [f"-I{include}" for include in extra_include_paths] if extra_include_paths else [] + ) + + lib_include = os.path.join(_TORCH_PATH, 'include') + torch_include_dirs = [ + f"-I {lib_include}", + # Python.h + "-I {}".format(sysconfig.get_path("include")), + # torch/all.h + "-I {}".format(os.path.join(lib_include, 'torch', 'csrc', 'api', 'include')), + ] + + torch_include_dirs_str = listToString(torch_include_dirs) + + common_cflags = [] + if not is_standalone: + common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H'] + + common_cflags += ['-std=c++17', '-fPIC'] + common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] + common_cflags_str = listToString(common_cflags) + + pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str) + pch_sign = command_to_signature(pch_cmd) + + if os.path.isfile(head_file_pch) is not True: + build_precompile_header(pch_cmd) + write_pch_signature_to_file(head_file_signature, pch_sign) + else: + b_same_sign = check_pch_signature_in_file(head_file_signature, pch_sign) + if b_same_sign is False: + build_precompile_header(pch_cmd) + write_pch_signature_to_file(head_file_signature, pch_sign) + +def remove_extension_h_precompiler_headers(): + def _remove_if_file_exists(path_file): + if os.path.exists(path_file): + os.remove(path_file) + + head_file_pch = 
os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.gch') + head_file_signature = os.path.join(_TORCH_PATH, 'include', 'torch', 'extension.h.sign') + + _remove_if_file_exists(head_file_pch) + _remove_if_file_exists(head_file_signature) + +def load_inline(name, + cpp_sources, + cuda_sources=None, + sycl_sources=None, + functions=None, + extra_cflags=None, + extra_cuda_cflags=None, + extra_sycl_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + verbose=False, + with_cuda=None, + with_sycl=None, + is_python_module=True, + with_pytorch_error_handling=True, + keep_intermediates=True, + use_pch=False, + no_implicit_headers=False): + r''' + Load a PyTorch C++ extension just-in-time (JIT) from string sources. + + This function behaves exactly like :func:`load`, but takes its sources as + strings rather than filenames. These strings are stored to files in the + build directory, after which the behavior of :func:`load_inline` is + identical to :func:`load`. + + See `the + tests `_ + for good examples of using this function. + + Sources may omit two required parts of a typical non-inline C++ extension: + the necessary header includes, as well as the (pybind11) binding code. More + precisely, strings passed to ``cpp_sources`` are first concatenated into a + single ``.cpp`` file. This file is then prepended with ``#include + `` + + Furthermore, if the ``functions`` argument is supplied, bindings will be + automatically generated for each function specified. ``functions`` can + either be a list of function names, or a dictionary mapping from function + names to docstrings. If a list is given, the name of each function is used + as its docstring. + + The sources in ``cuda_sources`` are concatenated into a separate ``.cu`` + file and prepended with ``torch/types.h``, ``cuda.h`` and + ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled + separately, but ultimately linked into a single library. Note that no + bindings are generated for functions in ``cuda_sources`` per se. To bind + to a CUDA kernel, you must create a C++ function that calls it, and either + declare or define this C++ function in one of the ``cpp_sources`` (and + include its name in ``functions``). + + The sources in ``sycl_sources`` are concatenated into a separate ``.sycl`` + file and prepended with ``torch/types.h``, ``sycl/sycl.hpp`` includes. + The ``.cpp`` and ``.sycl`` files are compiled separately, but ultimately + linked into a single library. Note that no bindings are generated for + functions in ``sycl_sources`` per se. To bind to a SYCL kernel, you must + create a C++ function that calls it, and either declare or define this + C++ function in one of the ``cpp_sources`` (and include its name + in ``functions``). + + + + See :func:`load` for a description of arguments omitted below. + + Args: + cpp_sources: A string, or list of strings, containing C++ source code. + cuda_sources: A string, or list of strings, containing CUDA source code. + sycl_sources: A string, or list of strings, containing SYCL source code. + functions: A list of function names for which to generate function + bindings. If a dictionary is given, it should map function names to + docstrings (which are otherwise just the function names). + with_cuda: Determines whether CUDA headers and libraries are added to + the build. If set to ``None`` (default), this value is + automatically determined based on whether ``cuda_sources`` is + provided. 
Set it to ``True`` to force CUDA headers
+            and libraries to be included.
+        with_sycl: Determines whether SYCL headers and libraries are added to
+            the build. If set to ``None`` (default), this value is
+            automatically determined based on whether ``sycl_sources`` is
+            provided. Set it to ``True`` to force SYCL headers
+            and libraries to be included.
+        with_pytorch_error_handling: Determines whether pytorch error and
+            warning macros are handled by pytorch instead of pybind. To do
+            this, each function ``foo`` is called via an intermediary ``_safe_foo``
+            function. This redirection might cause issues in obscure cases
+            of cpp. This flag should be set to ``False`` when this redirect
+            causes issues.
+        no_implicit_headers: If ``True``, skips automatically adding headers, most notably
+            ``#include <torch/extension.h>`` and ``#include <torch/types.h>`` lines.
+            Use this option to improve cold start times when you
+            already include the necessary headers in your source code. Default: ``False``.
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
+        >>> from torch.utils.cpp_extension import load_inline
+        >>> source = """
+        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+          return x.sin() + y.sin();
+        }
+        """
+        >>> module = load_inline(name='inline_extension',
+        ...                      cpp_sources=[source],
+        ...                      functions=['sin_add'])
+
+    .. note::
+        Since load_inline will just-in-time compile the source code, please ensure
+        that you have the right toolchains installed in the runtime. For example,
+        when loading C++, make sure a C++ compiler is available. If you're loading
+        a CUDA extension, you will need to additionally install the corresponding CUDA
+        toolkit (nvcc and any other dependencies your code has). Compiling toolchains
+        are not included when you install torch and must be additionally installed.
+
+        During compiling, by default, the Ninja backend uses #CPUS + 2 workers to build
+        the extension. This may use up too many resources on some systems. One
+        can control the number of workers by setting the `MAX_JOBS` environment
+        variable to a non-negative number.
+    '''
+    build_directory = build_directory or _get_build_directory(name, verbose)
+
+    if isinstance(cpp_sources, str):
+        cpp_sources = [cpp_sources]
+    cuda_sources = cuda_sources or []
+    if isinstance(cuda_sources, str):
+        cuda_sources = [cuda_sources]
+    sycl_sources = sycl_sources or []
+    if isinstance(sycl_sources, str):
+        sycl_sources = [sycl_sources]
+
+    if not no_implicit_headers:
+        cpp_sources.insert(0, '#include <torch/extension.h>')
+
+    if use_pch is True:
+        # Using PreCompile Header('torch/extension.h') to reduce compile time.
+        _check_and_build_extension_h_precompiler_headers(extra_cflags, extra_include_paths)
+    else:
+        remove_extension_h_precompiler_headers()
+
+    # If `functions` is supplied, we create the pybind11 bindings for the user.
+    # Here, `functions` is (or becomes, after some processing) a map from
+    # function names to function docstrings.
+    if functions is not None:
+        module_def = []
+        module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
+        if isinstance(functions, str):
+            functions = [functions]
+        if isinstance(functions, list):
+            # Make the function docstring the same as the function name.
+            functions = {f: f for f in functions}
+        elif not isinstance(functions, dict):
+            raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")
+        for function_name, docstring in functions.items():
+            if with_pytorch_error_handling:
+                module_def.append(f'm.def("{function_name}", torch::wrap_pybind_function({function_name}), "{docstring}");')
+            else:
+                module_def.append(f'm.def("{function_name}", {function_name}, "{docstring}");')
+        module_def.append('}')
+        cpp_sources += module_def
+
+    cpp_source_path = os.path.join(build_directory, 'main.cpp')
+    _maybe_write(cpp_source_path, "\n".join(cpp_sources))
+
+    sources = [cpp_source_path]
+
+    if cuda_sources:
+        if not no_implicit_headers:
+            cuda_sources.insert(0, '#include <torch/types.h>')
+            cuda_sources.insert(1, '#include <cuda.h>')
+            cuda_sources.insert(2, '#include <cuda_runtime.h>')
+
+        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
+        _maybe_write(cuda_source_path, "\n".join(cuda_sources))
+
+        sources.append(cuda_source_path)
+
+    if sycl_sources:
+        if not no_implicit_headers:
+            sycl_sources.insert(0, '#include <torch/types.h>')
+            sycl_sources.insert(1, '#include <sycl/sycl.hpp>')
+
+        sycl_source_path = os.path.join(build_directory, 'sycl.sycl')
+        _maybe_write(sycl_source_path, "\n".join(sycl_sources))
+
+        sources.append(sycl_source_path)
+
+    return _jit_compile(
+        name,
+        sources,
+        extra_cflags,
+        extra_cuda_cflags,
+        extra_sycl_cflags,
+        extra_ldflags,
+        extra_include_paths,
+        build_directory,
+        verbose,
+        with_cuda,
+        with_sycl,
+        is_python_module,
+        is_standalone=False,
+        keep_intermediates=keep_intermediates)
+
+
+def _jit_compile(name,
+                 sources,
+                 extra_cflags,
+                 extra_cuda_cflags,
+                 extra_sycl_cflags,
+                 extra_ldflags,
+                 extra_include_paths,
+                 build_directory: str,
+                 verbose: bool,
+                 with_cuda: Optional[bool],
+                 with_sycl: Optional[bool],
+                 is_python_module,
+                 is_standalone,
+                 keep_intermediates=True) -> None:
+    if is_python_module and is_standalone:
+        raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")
+
+    if with_cuda is None:
+        with_cuda = any(map(_is_cuda_file, sources))
+    with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
+    if with_sycl is None:
+        with_sycl = any(map(_is_sycl_file, sources))
+    old_version = JIT_EXTENSION_VERSIONER.get_version(name)
+    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
+        name,
+        sources,
+        build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
+        build_directory=build_directory,
+        with_cuda=with_cuda,
+        with_sycl=with_sycl,
+        is_python_module=is_python_module,
+        is_standalone=is_standalone,
+    )
+    if version > 0:
+        if version != old_version and verbose:
+            logger.info('The input conditions for extension module %s have changed.', name)
+            logger.info('Bumping to version %s and re-building as %s_v%s...', version, name, version)
+        name = f'{name}_v{version}'
+
+    baton = FileBaton(os.path.join(build_directory, 'lock'))
+    if baton.try_acquire():
+        try:
+            if version != old_version:
+                with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
+                    if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
+                        hipify_result = hipify_python.hipify(
+                            project_directory=build_directory,
+                            output_directory=build_directory,
+                            header_include_dirs=(extra_include_paths if extra_include_paths is not None else []),
+                            extra_files=[os.path.abspath(s) for s in sources],
+                            ignores=[_join_rocm_home('*'), os.path.join(_TORCH_PATH, '*')],  # no need to hipify ROCm or PyTorch headers
+                            show_detailed=verbose,
+                            show_progress=verbose,
+                            is_pytorch_extension=True,
clean_ctx=clean_ctx + ) + + hipified_sources = set() + for source in sources: + s_abs = os.path.abspath(source) + hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs) + + sources = list(hipified_sources) + + _write_ninja_file_and_build_library( + name=name, + sources=sources, + extra_cflags=extra_cflags or [], + extra_cuda_cflags=extra_cuda_cflags or [], + extra_sycl_cflags=extra_sycl_cflags or [], + extra_ldflags=extra_ldflags or [], + extra_include_paths=extra_include_paths or [], + build_directory=build_directory, + verbose=verbose, + with_cuda=with_cuda, + with_sycl=with_sycl, + is_standalone=is_standalone) + elif verbose: + logger.debug('No modifications detected for re-loaded extension module %s, skipping build step...', name) + finally: + baton.release() + else: + baton.wait() + + if verbose: + logger.info('Loading extension module %s...', name) + + if is_standalone: + return _get_exec_path(name, build_directory) + + return _import_module_from_library(name, build_directory, is_python_module) + +def _get_hipcc_path(): + if IS_WINDOWS: + # mypy thinks ROCM_VERSION is None but it will never be None here + hipcc_exe = 'hipcc.exe' if ROCM_VERSION >= (6, 4) else 'hipcc.bat' # type: ignore[operator] + return _join_rocm_home('bin', hipcc_exe) + else: + return _join_rocm_home('bin', 'hipcc') + +def _write_ninja_file_and_compile_objects( + sources: list[str], + objects, + cflags, + post_cflags, + cuda_cflags, + cuda_post_cflags, + cuda_dlink_post_cflags, + sycl_cflags, + sycl_post_cflags, + sycl_dlink_post_cflags, + build_directory: str, + verbose: bool, + with_cuda: Optional[bool], + with_sycl: Optional[bool]) -> None: + verify_ninja_availability() + + compiler = get_cxx_compiler() + + get_compiler_abi_compatibility_and_version(compiler) + if with_cuda is None: + with_cuda = any(map(_is_cuda_file, sources)) + if with_sycl is None: + with_sycl = any(map(_is_sycl_file, sources)) + build_file_path = os.path.join(build_directory, 'build.ninja') + if verbose: + logger.debug('Emitting ninja build file %s...', build_file_path) + + # Create build_directory if it does not exist + if not os.path.exists(build_directory): + if verbose: + logger.debug('Creating directory %s...', build_directory) + # This is like mkdir -p, i.e. will also create parent directories. + os.makedirs(build_directory, exist_ok=True) + + _write_ninja_file( + path=build_file_path, + cflags=cflags, + post_cflags=post_cflags, + cuda_cflags=cuda_cflags, + cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, + sycl_cflags=sycl_cflags, + sycl_post_cflags=sycl_post_cflags, + sycl_dlink_post_cflags=sycl_dlink_post_cflags, + sources=sources, + objects=objects, + ldflags=None, + library_target=None, + with_cuda=with_cuda, + with_sycl=with_sycl) + if verbose: + logger.info('Compiling objects...') + _run_ninja_build( + build_directory, + verbose, + # It would be better if we could tell users the name of the extension + # that failed to build but there isn't a good way to get it here. 
+ error_prefix='Error compiling objects for extension') + + +def _write_ninja_file_and_build_library( + name, + sources: list[str], + extra_cflags, + extra_cuda_cflags, + extra_sycl_cflags, + extra_ldflags, + extra_include_paths, + build_directory: str, + verbose: bool, + with_cuda: Optional[bool], + with_sycl: Optional[bool], + is_standalone: bool = False) -> None: + verify_ninja_availability() + + compiler = get_cxx_compiler() + + get_compiler_abi_compatibility_and_version(compiler) + if with_cuda is None: + with_cuda = any(map(_is_cuda_file, sources)) + if with_sycl is None: + with_sycl = any(map(_is_sycl_file, sources)) + extra_ldflags = _prepare_ldflags( + extra_ldflags or [], + with_cuda, + verbose, + is_standalone) + build_file_path = os.path.join(build_directory, 'build.ninja') + if verbose: + logger.debug('Emitting ninja build file %s...', build_file_path) + + # Create build_directory if it does not exist + if not os.path.exists(build_directory): + if verbose: + logger.debug('Creating directory %s...', build_directory) + # This is like mkdir -p, i.e. will also create parent directories. + os.makedirs(build_directory, exist_ok=True) + + # NOTE: Emitting a new ninja build file does not cause re-compilation if + # the sources did not change, so it's ok to re-emit (and it's fast). + _write_ninja_file_to_build_library( + path=build_file_path, + name=name, + sources=sources, + extra_cflags=extra_cflags or [], + extra_cuda_cflags=extra_cuda_cflags or [], + extra_sycl_cflags=extra_sycl_cflags or [], + extra_ldflags=extra_ldflags or [], + extra_include_paths=extra_include_paths or [], + with_cuda=with_cuda, + with_sycl=with_sycl, + is_standalone=is_standalone) + + if verbose: + logger.info('Building extension module %s...', name) + _run_ninja_build( + build_directory, + verbose, + error_prefix=f"Error building extension '{name}'") + + +def is_ninja_available(): + """Return ``True`` if the `ninja `_ build system is available on the system, ``False`` otherwise.""" + try: + subprocess.check_output('ninja --version'.split()) + except Exception: + return False + else: + return True + + +def verify_ninja_availability(): + """Raise ``RuntimeError`` if `ninja `_ build system is not available on the system, does nothing otherwise.""" + if not is_ninja_available(): + raise RuntimeError("Ninja is required to load C++ extensions (pip install ninja to get it)") + + +def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): + if IS_WINDOWS: + python_lib_path = os.path.join(sys.base_exec_prefix, 'libs') + + extra_ldflags.append('c10.lib') + if with_cuda: + extra_ldflags.append('c10_cuda.lib') + extra_ldflags.append('torch_cpu.lib') + if with_cuda: + extra_ldflags.append('torch_cuda.lib') + # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it. 
+ # Related issue: https://github.com/pytorch/pytorch/issues/31611 + extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ') + extra_ldflags.append('torch.lib') + extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}') + if not is_standalone: + extra_ldflags.append('torch_python.lib') + extra_ldflags.append(f'/LIBPATH:{python_lib_path}') + + else: + extra_ldflags.append(f'-L{TORCH_LIB_PATH}') + extra_ldflags.append('-lc10') + if with_cuda: + extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda') + extra_ldflags.append('-ltorch_cpu') + if with_cuda: + extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda') + extra_ldflags.append('-ltorch') + if not is_standalone: + extra_ldflags.append('-ltorch_python') + + if is_standalone: + extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}") + + if with_cuda: + if verbose: + logger.info('Detected CUDA files, patching ldflags') + if IS_WINDOWS: + extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}') + extra_ldflags.append('cudart.lib') + if CUDNN_HOME is not None: + extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}') + elif not IS_HIP_EXTENSION: + extra_lib_dir = "lib64" + if (not os.path.exists(_join_cuda_home(extra_lib_dir)) and + os.path.exists(_join_cuda_home("lib"))): + # 64-bit CUDA may be installed in "lib" + # Note that it's also possible both don't exist (see _find_cuda_home) - in that case we stay with "lib64" + extra_lib_dir = "lib" + extra_ldflags.append(f'-L{_join_cuda_home(extra_lib_dir)}') + extra_ldflags.append('-lcudart') + if CUDNN_HOME is not None: + extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}') + elif IS_HIP_EXTENSION: + extra_ldflags.append(f'-L{_join_rocm_home("lib")}') + extra_ldflags.append('-lamdhip64') + return extra_ldflags + + +def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: + """ + Determine CUDA arch flags to use. + + For an arch, say "6.1", the added compile flag will be + ``-gencode=arch=compute_61,code=sm_61``. + For an added "+PTX", an additional + ``-gencode=arch=compute_xx,code=compute_xx`` is added. + + See select_compute_arch.cmake for corresponding named and supported arches + when building with CMake. + """ + # If cflags is given, there may already be user-provided arch flags in it + # (from `extra_compile_args`) + if cflags is not None: + for flag in cflags: + if 'TORCH_EXTENSION_NAME' in flag: + continue + if 'arch' in flag: + return [] + + # Note: keep combined names ("arch1+arch2") above single names, otherwise + # string replacement may not do the right thing + named_arches = collections.OrderedDict([ + ('Kepler+Tesla', '3.7'), + ('Kepler', '3.5+PTX'), + ('Maxwell+Tegra', '5.3'), + ('Maxwell', '5.0;5.2+PTX'), + ('Pascal', '6.0;6.1+PTX'), + ('Volta+Tegra', '7.2'), + ('Volta', '7.0+PTX'), + ('Turing', '7.5+PTX'), + ('Ampere+Tegra', '8.7'), + ('Ampere', '8.0;8.6+PTX'), + ('Ada', '8.9+PTX'), + ('Hopper', '9.0+PTX'), + ('Blackwell+Tegra', '10.1'), + ('Blackwell', '10.0;10.3;12.0;12.1+PTX'), + ]) + + supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', + '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a', + '10.0', '10.0a', '10.1', '10.1a', '10.3', '10.3a', '12.0', + '12.0a', '12.1', '12.1a'] + valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] + + # The default is sm_30 for CUDA 9.x and 10.x + # First check for an env var (same as used by the main setup.py) + # Can be one or more architectures, e.g. 
"6.1" or "3.5;5.2;6.0;6.1;7.0+PTX" + # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake + _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + + # If not given, determine what's best for the GPU / CUDA version that can be found + if not _arch_list: + logger.warning( + "TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n" + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.") + arch_list = [] + # the assumption is that the extension should run on any of the currently visible cards, + # which could be of different types - therefore all archs for visible cards should be included + for i in range(torch.cuda.device_count()): + capability = torch.cuda.get_device_capability(i) + supported_sm = [int("".join(re.findall(r"\d+", arch.split('_')[1]))) + for arch in torch.cuda.get_arch_list() if 'sm_' in arch] + max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) + # Capability of the device may be higher than what's supported by the user's + # NVCC, causing compilation error. User's NVCC is expected to match the one + # used to build pytorch, so we use the maximum supported capability of pytorch + # to clamp the capability. + capability = min(max_supported_sm, capability) + arch = f'{capability[0]}.{capability[1]}' + if arch not in arch_list: + arch_list.append(arch) + arch_list = sorted(arch_list) + arch_list[-1] += '+PTX' + else: + # Deal with lists that are ' ' separated (only deal with ';' after) + _arch_list = _arch_list.replace(' ', ';') + # Expand named arches + for named_arch, archval in named_arches.items(): + _arch_list = _arch_list.replace(named_arch, archval) + + arch_list = _arch_list.split(';') + + flags = [] + for arch in arch_list: + if arch not in valid_arch_strings: + raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported") + else: + # Handle both single and double-digit architecture versions + version = arch.split('+')[0] # Remove "+PTX" if present + major, minor = version.split('.') + num = f"{major}{minor}" + flags.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if arch.endswith('+PTX'): + flags.append(f'-gencode=arch=compute_{num},code=compute_{num}') + + return sorted(set(flags)) + + +def _get_rocm_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: + # If cflags is given, there may already be user-provided arch flags in it + # (from `extra_compile_args`). If user also specified -fgpu-rdc or -fno-gpu-rdc, we + # assume they know what they're doing. Otherwise, we force -fno-gpu-rdc default. + has_gpu_rdc_flag = False + if cflags is not None: + has_custom_flags = False + for flag in cflags: + if 'amdgpu-target' in flag or 'offload-arch' in flag: + has_custom_flags = True + elif 'gpu-rdc' in flag: + has_gpu_rdc_flag = True + if has_custom_flags: + return [] if has_gpu_rdc_flag else ['-fno-gpu-rdc'] + # Use same defaults as used for building PyTorch + # Allow env var to override, just like during initial cmake build. + _archs = os.environ.get('PYTORCH_ROCM_ARCH', None) + if not _archs: + archFlags = torch._C._cuda_getArchFlags() + if archFlags: + archs = archFlags.split() + else: + archs = [] + else: + archs = _archs.replace(' ', ';').split(';') + flags = [f'--offload-arch={arch}' for arch in archs] + flags += [] if has_gpu_rdc_flag else ['-fno-gpu-rdc'] + return flags + +def _get_build_directory(name: str, verbose: bool) -> str: + """ + Get the build directory for an extension. 
+ + Args: + name: The name of the extension + verbose: Whether to print verbose information + + Returns: + The path to the build directory + """ + root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR') + if root_extensions_directory is None: + root_extensions_directory = get_default_build_root() + cu_str = ('cpu' if torch.version.cuda is None else + f'cu{torch.version.cuda.replace(".", "")}') + python_version = f'py{sys.version_info.major}{sys.version_info.minor}{getattr(sys, "abiflags", "")}' + build_folder = f'{python_version}_{cu_str}' + + root_extensions_directory = os.path.join( + root_extensions_directory, build_folder) + + if verbose: + logger.info('Using %s as PyTorch extensions root...', root_extensions_directory) + + build_directory = os.path.join(root_extensions_directory, name) + if not os.path.exists(build_directory): + if verbose: + logger.debug('Creating extension directory %s...', build_directory) + # This is like mkdir -p, i.e. will also create parent directories. + os.makedirs(build_directory, exist_ok=True) + + return build_directory + + +def _get_num_workers(verbose: bool) -> Optional[int]: + max_jobs = os.environ.get('MAX_JOBS') + if max_jobs is not None and max_jobs.isdigit(): + if verbose: + logger.debug('Using envvar MAX_JOBS (%s) as the number of workers...', max_jobs) + return int(max_jobs) + if verbose: + logger.info( + 'Allowing ninja to set a default number of workers... ' + '(overridable by setting the environment variable MAX_JOBS=N)' + ) + return None + + +def _get_vc_env(vc_arch: str) -> dict[str, str]: + try: + from setuptools import distutils # type: ignore[attr-defined] + return distutils._msvccompiler._get_vc_env(vc_arch) + except AttributeError: + try: + from setuptools._distutils import _msvccompiler + return _msvccompiler._get_vc_env(vc_arch) # type: ignore[attr-defined] + except AttributeError: + from setuptools._distutils.compilers.C import msvc + return msvc._get_vc_env(vc_arch) # type: ignore[attr-defined] + +def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None: + command = ['ninja', '-v'] + num_workers = _get_num_workers(verbose) + if num_workers is not None: + command.extend(['-j', str(num_workers)]) + env = os.environ.copy() + # Try to activate the vc env for the users + if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env: + from setuptools import distutils # type: ignore[attr-defined] + + plat_name = distutils.util.get_platform() + plat_spec = PLAT_TO_VCVARS[plat_name] + vc_env = {k.upper(): v for k, v in _get_vc_env(plat_spec).items()} + for k, v in env.items(): + uk = k.upper() + if uk not in vc_env: + vc_env[uk] = v + env = vc_env + try: + sys.stdout.flush() + sys.stderr.flush() + # Warning: don't pass stdout=None to subprocess.run to get output. + # subprocess.run assumes that sys.__stdout__ has not been modified and + # attempts to write to it by default. However, when we call _run_ninja_build + # from ahead-of-time cpp extensions, the following happens: + # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. + # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 + # (it probably shouldn't do this) + # 2) subprocess.run (on POSIX, with no stdout override) relies on + # __stdout__ not being detached: + # https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214 + # To work around this, we pass in the fileno directly and hope that + # it is valid. 
+ stdout_fileno = 1 + subprocess.run( + command, + shell=IS_WINDOWS and IS_HIP_EXTENSION, + stdout=stdout_fileno if verbose else subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=build_directory, + check=True, + env=env) + except subprocess.CalledProcessError as e: + # Python 2 and 3 compatible way of getting the error object. + _, error, _ = sys.exc_info() + # error.output contains the stdout and stderr of the build attempt. + message = error_prefix + # `error` is a CalledProcessError (which has an `output`) attribute, but + # mypy thinks it's Optional[BaseException] and doesn't narrow + if hasattr(error, 'output') and error.output: # type: ignore[union-attr] + message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}" # type: ignore[union-attr] + raise RuntimeError(message) from e + + +def _get_exec_path(module_name, path): + if IS_WINDOWS and TORCH_LIB_PATH not in os.getenv('PATH', '').split(';'): + torch_lib_in_path = any( + os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH) + for p in os.getenv('PATH', '').split(';') + ) + if not torch_lib_in_path: + os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}" + return os.path.join(path, f'{module_name}{EXEC_EXT}') + + +def _import_module_from_library(module_name, path, is_python_module): + filepath = os.path.join(path, f"{module_name}{LIB_EXT}") + if is_python_module: + # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path + spec = importlib.util.spec_from_file_location(module_name, filepath) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert isinstance(spec.loader, importlib.abc.Loader) + spec.loader.exec_module(module) + return module + else: + torch.ops.load_library(filepath) + return filepath + + +def _write_ninja_file_to_build_library(path, + name, + sources, + extra_cflags, + extra_cuda_cflags, + extra_sycl_cflags, + extra_ldflags, + extra_include_paths, + with_cuda, + with_sycl, + is_standalone) -> None: + extra_cflags = [flag.strip() for flag in extra_cflags] + extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags] + extra_sycl_cflags = [flag.strip() for flag in extra_sycl_cflags] + extra_ldflags = [flag.strip() for flag in extra_ldflags] + extra_include_paths = [flag.strip() for flag in extra_include_paths] + + # Turn into absolute paths so we can emit them into the ninja build + # file wherever it is. + user_includes = [os.path.abspath(file) for file in extra_include_paths] + + # include_paths() gives us the location of torch/extension.h + # TODO generalize with_cuda as specific device type. + if with_cuda: + system_includes = include_paths("cuda") + else: + system_includes = include_paths("cpu") + # sysconfig.get_path('include') gives us the location of Python.h + # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS + # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder + python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix') + if python_include_path is not None: + system_includes.append(python_include_path) + + common_cflags = [] + if not is_standalone: + common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}') + common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') + + common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] + + # Windows does not understand `-isystem` and quotes flags later. 
+ if IS_WINDOWS: + common_cflags += [f'-I{include}' for include in user_includes + system_includes] + else: + common_cflags += [f'-I{shlex.quote(include)}' for include in user_includes] + common_cflags += [f'-isystem {shlex.quote(include)}' for include in system_includes] + + if IS_WINDOWS: + cflags = common_cflags + ['/std:c++17'] + extra_cflags + cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS + cflags = _nt_quote_args(cflags) + else: + cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags + + if with_cuda and IS_HIP_EXTENSION: + cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_flags += _get_rocm_arch_flags(cuda_flags) + cuda_flags += extra_cuda_cflags + elif with_cuda: + cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() + if IS_WINDOWS: + for flag in COMMON_MSVC_FLAGS: + cuda_flags = ['-Xcompiler', flag] + cuda_flags + for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS: + cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags + cuda_flags = cuda_flags + ['-std=c++17'] + cuda_flags = _nt_quote_args(cuda_flags) + cuda_flags += _nt_quote_args(extra_cuda_cflags) + else: + cuda_flags += ['--compiler-options', "'-fPIC'"] + cuda_flags += extra_cuda_cflags + if not any(flag.startswith('-std=') for flag in cuda_flags): + cuda_flags.append('-std=c++17') + cc_env = os.getenv("CC") + if cc_env is not None: + cuda_flags = ['-ccbin', cc_env] + cuda_flags + else: + cuda_flags = None + + if with_sycl: + sycl_cflags = cflags + _COMMON_SYCL_FLAGS + sycl_cflags += extra_sycl_cflags + _append_sycl_targets_if_missing(sycl_cflags) + _append_sycl_std_if_no_std_present(sycl_cflags) + host_cflags = cflags + # escaping quoted arguments to pass them thru SYCL compiler + host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags] + host_cflags = ' '.join(host_cflags) + sycl_cflags += _wrap_sycl_host_flags(host_cflags) + sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() + sycl_dlink_post_cflags += _get_sycl_device_flags(sycl_cflags) + else: + sycl_cflags = None + sycl_dlink_post_cflags = None + + def object_file_path(source_file: str) -> str: + # '/path/to/file.cpp' -> 'file' + file_name = os.path.splitext(os.path.basename(source_file))[0] + if _is_cuda_file(source_file) and with_cuda: + # Use a different object filename in case a C++ and CUDA file have + # the same filename but different extension (.cpp vs. .cu). + target = f'{file_name}.cuda.o' + elif _is_sycl_file(source_file) and with_sycl: + target = f'{file_name}.sycl.o' + else: + target = f'{file_name}.o' + return target + + objects = [object_file_path(src) for src in sources] + ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags + + # The darwin linker needs explicit consent to ignore unresolved symbols. 
+ if IS_MACOS: + ldflags.append('-undefined dynamic_lookup') + elif IS_WINDOWS: + ldflags = _nt_quote_args(ldflags) + + ext = EXEC_EXT if is_standalone else LIB_EXT + library_target = f'{name}{ext}' + + _write_ninja_file( + path=path, + cflags=cflags, + post_cflags=None, + cuda_cflags=cuda_flags, + cuda_post_cflags=None, + cuda_dlink_post_cflags=None, + sycl_cflags=sycl_cflags, + sycl_post_cflags=[], + sycl_dlink_post_cflags=sycl_dlink_post_cflags, + sources=sources, + objects=objects, + ldflags=ldflags, + library_target=library_target, + with_cuda=with_cuda, + with_sycl=with_sycl) + + +def _write_ninja_file(path, + cflags, + post_cflags, + cuda_cflags, + cuda_post_cflags, + cuda_dlink_post_cflags, + sycl_cflags, + sycl_post_cflags, + sycl_dlink_post_cflags, + sources, + objects, + ldflags, + library_target, + with_cuda, + with_sycl) -> None: + r"""Write a ninja file that does the desired compiling and linking. + + `path`: Where to write this file + `cflags`: list of flags to pass to $cxx. Can be None. + `post_cflags`: list of flags to append to the $cxx invocation. Can be None. + `cuda_cflags`: list of flags to pass to $nvcc. Can be None. + `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None. + `cuda_dlink_post_cflags`: list of flags to append to the $nvcc device code link invocation. Can be None. + `sycl_cflags`: list of flags to pass to SYCL compiler. Can be None. + `sycl_post_cflags`: list of flags to append to the SYCL compiler invocation. Can be None. + `sycl_dlink_post_cflags`: list of flags to append to the SYCL compiler device code link invocation. Can be None. +e. + `sources`: list of paths to source files + `objects`: list of desired paths to objects, one per source. + `ldflags`: list of flags to pass to linker. Can be None. + `library_target`: Name of the output library. Can be None; in that case, + we do no linking. + `with_cuda`: If we should be compiling with CUDA. + """ + def sanitize_flags(flags): + if flags is None: + return [] + else: + return [flag.strip() for flag in flags] + + cflags = sanitize_flags(cflags) + post_cflags = sanitize_flags(post_cflags) + cuda_cflags = sanitize_flags(cuda_cflags) + cuda_post_cflags = sanitize_flags(cuda_post_cflags) + cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags) + sycl_cflags = sanitize_flags(sycl_cflags) + sycl_post_cflags = sanitize_flags(sycl_post_cflags) + sycl_dlink_post_cflags = sanitize_flags(sycl_dlink_post_cflags) + ldflags = sanitize_flags(ldflags) + + # Sanity checks... + assert len(sources) == len(objects) + assert len(sources) > 0 + + compiler = get_cxx_compiler() + + # Version 1.3 is required for the `deps` directive. 
+ config = ['ninja_required_version = 1.3'] + config.append(f'cxx = {compiler}') + if with_cuda or cuda_dlink_post_cflags: + if "PYTORCH_NVCC" in os.environ: + nvcc = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here + else: + if IS_HIP_EXTENSION: + nvcc = _get_hipcc_path() + else: + nvcc = _join_cuda_home('bin', 'nvcc') + config.append(f'nvcc = {nvcc}') + if with_sycl or sycl_dlink_post_cflags: + sycl = 'icx' if IS_WINDOWS else 'icpx' + config.append(f'sycl = {sycl}') + + if IS_HIP_EXTENSION: + post_cflags = COMMON_HIP_FLAGS + post_cflags + flags = [f'cflags = {" ".join(cflags)}'] + flags.append(f'post_cflags = {" ".join(post_cflags)}') + if with_cuda: + flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') + flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') + if with_sycl: + flags.append(f'sycl_cflags = {" ".join(sycl_cflags)}') + flags.append(f'sycl_post_cflags = {" ".join(sycl_post_cflags)}') + flags.append(f'sycl_dlink_post_cflags = {" ".join(sycl_dlink_post_cflags)}') + flags.append(f'ldflags = {" ".join(ldflags)}') + + # Turn into absolute paths so we can emit them into the ninja build + # file wherever it is. + sources = [os.path.abspath(file) for file in sources] + + # See https://ninja-build.org/build.ninja.html for reference. + compile_rule = ['rule compile'] + if IS_WINDOWS: + compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl" + compile_rule.append( + f' command = {compiler_name} /showIncludes $cflags -c $in /Fo$out $post_cflags') + if not IS_HIP_EXTENSION: + compile_rule.append(' deps = msvc') + else: + compile_rule.append( + ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags') + compile_rule.append(' depfile = $out.d') + compile_rule.append(' deps = gcc') + + if with_cuda: + cuda_compile_rule = ['rule cuda_compile'] + nvcc_gendeps = '' + # --generate-dependencies-with-compile is not supported by ROCm + # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time. + if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1': + cuda_compile_rule.append(' depfile = $out.d') + cuda_compile_rule.append(' deps = gcc') + # Note: non-system deps with nvcc are only supported + # on Linux so use --generate-dependencies-with-compile + # to make this work on Windows too. + nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' + cuda_compile_rule.append( + f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') + + if with_sycl: + sycl_compile_rule = ['rule sycl_compile'] + # SYCL compiler does not recognize .sycl extension automatically, + # so we pass '-x c++' explicitly notifying compiler of file format + sycl_compile_rule.append( + ' command = $sycl $sycl_cflags -c -x c++ $in -o $out $sycl_post_cflags') + + + # Emit one build rule per source to enable incremental build. 
+ build = [] + for source_file, object_file in zip(sources, objects): + is_cuda_source = _is_cuda_file(source_file) and with_cuda + is_sycl_source = _is_sycl_file(source_file) and with_sycl + if is_cuda_source: + rule = 'cuda_compile' + elif is_sycl_source: + rule = 'sycl_compile' + else: + rule = 'compile' + if IS_WINDOWS: + source_file = source_file.replace(':', '$:') + object_file = object_file.replace(':', '$:') + source_file = source_file.replace(" ", "$ ") + object_file = object_file.replace(" ", "$ ") + build.append(f'build {object_file}: {rule} {source_file}') + + if cuda_dlink_post_cflags: + cuda_devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o') + cuda_devlink_rule = ['rule cuda_devlink'] + cuda_devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags') + cuda_devlink = [f'build {cuda_devlink_out}: cuda_devlink {" ".join(objects)}'] + objects += [cuda_devlink_out] + else: + cuda_devlink_rule, cuda_devlink = [], [] + + if sycl_dlink_post_cflags: + sycl_devlink_out = os.path.join(os.path.dirname(objects[0]), 'sycl_dlink.o') + sycl_devlink_rule = ['rule sycl_devlink'] + sycl_devlink_rule.append(' command = $sycl $in -o $out $sycl_dlink_post_cflags') + sycl_devlink = [f'build {sycl_devlink_out}: sycl_devlink {" ".join(objects)}'] + objects += [sycl_devlink_out] + else: + sycl_devlink_rule, sycl_devlink = [], [] + + if library_target is not None: + link_rule = ['rule link'] + if IS_WINDOWS: + cl_paths = subprocess.check_output(['where', + 'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n') + if len(cl_paths) >= 1: + cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:') + else: + raise RuntimeError("MSVC is required to load C++ extensions") + link_rule.append(f' command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out') + else: + link_rule.append(' command = $cxx $in $ldflags -o $out') + + link = [f'build {library_target}: link {" ".join(objects)}'] + + default = [f'default {library_target}'] + else: + link_rule, link, default = [], [], [] + + # 'Blocks' should be separated by newlines, for visual benefit. + blocks = [config, flags, compile_rule] + if with_cuda: + blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] + if with_sycl: + blocks.append(sycl_compile_rule) # type: ignore[possibly-undefined] + blocks += [cuda_devlink_rule, sycl_devlink_rule, link_rule, build, cuda_devlink, sycl_devlink, link, default] + content = "\n\n".join("\n".join(b) for b in blocks) + # Ninja requires a new lines at the end of the .ninja file + content += "\n" + _maybe_write(path, content) + +def _join_cuda_home(*paths) -> str: + """ + Join paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set. + + This is basically a lazy way of raising an error for missing $CUDA_HOME + only once we need to get any CUDA-specific path. + """ + if CUDA_HOME is None: + raise OSError('CUDA_HOME environment variable is not set. 
' + 'Please set it to your CUDA install root.') + return os.path.join(CUDA_HOME, *paths) + + +def _is_cuda_file(path: str) -> bool: + valid_ext = ['.cu', '.cuh'] + if IS_HIP_EXTENSION: + valid_ext.append('.hip') + return os.path.splitext(path)[1] in valid_ext + +def _is_sycl_file(path: str) -> bool: + valid_ext = ['.sycl'] + return os.path.splitext(path)[1] in valid_ext diff --git a/phivenv/Lib/site-packages/torch/utils/data/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b4ab60ecc23667035a28c8d234d447bf7230ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/__init__.py @@ -0,0 +1,77 @@ +from torch.utils.data.dataloader import ( + _DatasetKind, + DataLoader, + default_collate, + default_convert, + get_worker_info, +) +from torch.utils.data.datapipes._decorator import ( + argument_validation, + functional_datapipe, + guaranteed_datapipes_determinism, + non_deterministic, + runtime_validation, + runtime_validation_disabled, +) +from torch.utils.data.datapipes.datapipe import ( + DataChunk, + DFIterDataPipe, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import ( + ChainDataset, + ConcatDataset, + Dataset, + IterableDataset, + random_split, + StackDataset, + Subset, + TensorDataset, +) +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, + SubsetRandomSampler, + WeightedRandomSampler, +) + + +__all__ = [ + "BatchSampler", + "ChainDataset", + "ConcatDataset", + "DFIterDataPipe", + "DataChunk", + "DataLoader", + "Dataset", + "DistributedSampler", + "IterDataPipe", + "IterableDataset", + "MapDataPipe", + "RandomSampler", + "Sampler", + "SequentialSampler", + "StackDataset", + "Subset", + "SubsetRandomSampler", + "TensorDataset", + "WeightedRandomSampler", + "_DatasetKind", + "argument_validation", + "default_collate", + "default_convert", + "functional_datapipe", + "get_worker_info", + "guaranteed_datapipes_determinism", + "non_deterministic", + "random_split", + "runtime_validation", + "runtime_validation_disabled", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..960280ba30f45f9d980bfb6e3a0a0db12f08bc6a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43bc95f8a2708a2142f327973890c5fc811bfdae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/backward_compatibility.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataloader.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataloader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6a3891abd9afc7d177e17ebbe2d5f137ca1cb21 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataloader.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataset.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc618be2b2c5a073ac82d1ac8f9e9ce5a3b961cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6972385409a68a40ac0706955538cd0a06bff2f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4441f82b247aaf71477097d2e32d6e318d59cd4c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b730f464cb0ff943a7cb29566242d615d62ab94b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/graph_settings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/__pycache__/sampler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60005accb73a38cb2af84ab7ea937fdb5df5d861 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/__pycache__/sampler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36af1789a2c3b2ed9fd2161d686b5d12596312a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/__init__.py @@ -0,0 +1,54 @@ +# mypy: allow-untyped-defs +r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. + +A lot of multiprocessing is used in data loading, which only supports running +functions defined in global environment (py2 can't serialize static methods). +Therefore, for code tidiness we put these functions into different files in this +folder. +""" + +import atexit +import sys + +# old private location of the ExceptionWrapper that some users rely on: +from torch._utils import ExceptionWrapper + + +IS_WINDOWS = sys.platform == "win32" + + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + + +python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. 
+ +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" + + +try: + import numpy + + HAS_NUMPY = True +except ModuleNotFoundError: + HAS_NUMPY = False + + +def _set_python_exit_flag(): + global python_exit_status + python_exit_status = True + + +atexit.register(_set_python_exit_flag) + + +from . import collate, fetch, pin_memory, signal_handling, worker diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed012d8e624fbf5557c73e9fcb577ce0d89e686 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01f14c608723c37985485a161eb66d1d0cc56618 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..699087650ab3f257cc091eefb88806659808c705 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cfd3102e61102c2996ab701c9c6318e557e6a90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd4387fdb6e1fba24365c96e62dac9c5518e184 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85a5b8e022504c0082e12f4b39173645bc6fda73 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/collate.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..d32f71f1a2e55e17630ec752ad02dafa41c6e61d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/collate.py @@ -0,0 +1,398 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. 
+ +These methods are used to collate samples fetched from dataset into Tensor(s). +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. + +`default_collate` and `default_convert` are exposed to users via 'dataloader.py'. +""" + +import collections +import contextlib +import copy +import re +from typing import Callable, Optional, Union + +import torch + + +np_str_obj_array_pattern = re.compile(r"[SaUO]") + + +def default_convert(data): + r""" + Convert each NumPy array element into a :class:`torch.Tensor`. + + If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`. + If the input is not an NumPy array, it is left unchanged. + This is used as the default function for collation when both `batch_sampler` and `batch_size` + are NOT defined in :class:`~torch.utils.data.DataLoader`. + + The general input type to output type mapping is similar to that + of :func:`~torch.utils.data.default_collate`. See the description there for more details. + + Args: + data: a single data point to be converted + + Examples: + >>> # xdoctest: +SKIP + >>> # Example with `int` + >>> default_convert(0) + 0 + >>> # Example with NumPy array + >>> default_convert(np.array([0, 1])) + tensor([0, 1]) + >>> # Example with NamedTuple + >>> Point = namedtuple('Point', ['x', 'y']) + >>> default_convert(Point(0, 0)) + Point(x=0, y=0) + >>> default_convert(Point(np.array(0), np.array(0))) + Point(x=tensor(0), y=tensor(0)) + >>> # Example with List + >>> default_convert([np.array([0, 1]), np.array([2, 3])]) + [tensor([0, 1]), tensor([2, 3])] + """ + elem_type = type(data) + if isinstance(data, torch.Tensor): + return data + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + # array of string classes and object + if ( + elem_type.__name__ == "ndarray" + and np_str_obj_array_pattern.search(data.dtype.str) is not None + ): + return data + return torch.as_tensor(data) + elif isinstance(data, collections.abc.Mapping): + try: + if isinstance(data, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(data) + clone.update({key: default_convert(data[key]) for key in data}) + return clone + else: + return elem_type({key: default_convert(data[key]) for key in data}) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return {key: default_convert(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple + return elem_type(*(default_convert(d) for d in data)) + elif isinstance(data, tuple): + return [default_convert(d) for d in data] # Backwards compatibility. + elif isinstance(data, collections.abc.Sequence) and not isinstance( + data, (str, bytes) + ): + try: + if isinstance(data, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. 
+ clone = copy.copy(data) # type: ignore[arg-type] + for i, d in enumerate(data): + clone[i] = default_convert(d) + return clone + else: + return elem_type([default_convert(d) for d in data]) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [default_convert(d) for d in data] + else: + return data + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}" +) + + +def collate( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + r""" + General collate function that handles collection type of element within each batch. + + The function also opens function registry to deal with specific element types. `default_collate_fn_map` + provides default collate functions for tensors, numpy arrays, numbers and strings. + + Args: + batch: a single batch to be collated + collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function. + If the element type isn't present in this dictionary, + this function will go through each key of the dictionary in the insertion order to + invoke the corresponding collate function if the element type is a subclass of the key. + + Examples: + >>> def collate_tensor_fn(batch, *, collate_fn_map): + ... # Extend this function to handle batch of tensors + ... return torch.stack(batch, 0) + >>> def custom_collate(batch): + ... collate_map = {torch.Tensor: collate_tensor_fn} + ... return collate(batch, collate_fn_map=collate_map) + >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` + >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn}) + + Note: + Each collate function requires a positional argument for batch and a keyword argument + for the dictionary of collate functions as `collate_fn_map`. + """ + elem = batch[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type]( + batch, collate_fn_map=collate_fn_map + ) + + if isinstance(elem, collections.abc.Mapping): + try: + if isinstance(elem, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(elem) + clone.update( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + return clone + else: + return elem_type( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. 
+ return { + key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + } + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type( + *( + collate(samples, collate_fn_map=collate_fn_map) + for samples in zip(*batch) + ) + ) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] # Backwards compatibility. + else: + try: + if isinstance(elem, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(elem) # type: ignore[arg-type] + for i, samples in enumerate(transposed): + clone[i] = collate(samples, collate_fn_map=collate_fn_map) + return clone + else: + return elem_type( + [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + ) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) + + +def collate_tensor_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + elem = batch[0] + out = None + if elem.is_nested: + raise RuntimeError( + "Batches of nested tensors are not currently supported by the default collate_fn; " + "please provide a custom collate_fn to handle them appropriately." + ) + if elem.layout in { + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_bsr, + torch.sparse_csc, + torch.sparse_bsc, + }: + raise RuntimeError( + "Batches of sparse tensors are not currently supported by the default collate_fn; " + "please provide a custom collate_fn to handle them appropriately." 
+ ) + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem._typed_storage()._new_shared(numel, device=elem.device) + out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return torch.stack(batch, 0, out=out) + + +def collate_numpy_array_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + elem = batch[0] + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map) + + +def collate_numpy_scalar_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + return torch.as_tensor(batch) + + +def collate_float_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + return torch.tensor(batch, dtype=torch.float64) + + +def collate_int_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + return torch.tensor(batch) + + +def collate_str_fn( + batch, + *, + collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, +): + return batch + + +default_collate_fn_map: dict[Union[type, tuple[type, ...]], Callable] = { + torch.Tensor: collate_tensor_fn +} +with contextlib.suppress(ImportError): + import numpy as np + + # For both ndarray and memmap (subclass of ndarray) + default_collate_fn_map[np.ndarray] = collate_numpy_array_fn + # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html + # Skip string scalars + default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn +default_collate_fn_map[float] = collate_float_fn +default_collate_fn_map[int] = collate_int_fn +default_collate_fn_map[str] = collate_str_fn +default_collate_fn_map[bytes] = collate_str_fn + + +def default_collate(batch): + r""" + Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size. + + The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a + Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type. + This is used as the default function for collation when + `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`. 
+ + Here is the general input type (based on the type of the element within the batch) to output type mapping: + + * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size) + * NumPy Arrays -> :class:`torch.Tensor` + * `float` -> :class:`torch.Tensor` + * `int` -> :class:`torch.Tensor` + * `str` -> `str` (unchanged) + * `bytes` -> `bytes` (unchanged) + * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]` + * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), + default_collate([V2_1, V2_2, ...]), ...]` + + Args: + batch: a single batch to be collated + + Examples: + >>> # xdoctest: +SKIP + >>> # Example with a batch of `int`s: + >>> default_collate([0, 1, 2, 3]) + tensor([0, 1, 2, 3]) + >>> # Example with a batch of `str`s: + >>> default_collate(['a', 'b', 'c']) + ['a', 'b', 'c'] + >>> # Example with `Map` inside the batch: + >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) + {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} + >>> # Example with `NamedTuple` inside the batch: + >>> Point = namedtuple('Point', ['x', 'y']) + >>> default_collate([Point(0, 0), Point(1, 1)]) + Point(x=tensor([0, 1]), y=tensor([0, 1])) + >>> # Example with `Tuple` inside the batch: + >>> default_collate([(0, 1), (2, 3)]) + [tensor([0, 2]), tensor([1, 3])] + >>> # Example with `List` inside the batch: + >>> default_collate([[0, 1], [2, 3]]) + [tensor([0, 2]), tensor([1, 3])] + >>> # Two options to extend `default_collate` to handle specific type + >>> # Option 1: Write custom collate function and invoke `default_collate` + >>> def custom_collate(batch): + ... elem = batch[0] + ... if isinstance(elem, CustomType): # Some custom condition + ... return ... + ... else: # Fall back to `default_collate` + ... return default_collate(batch) + >>> # Option 2: In-place modify `default_collate_fn_map` + >>> def collate_customtype_fn(batch, *, collate_fn_map=None): + ... return ... + >>> default_collate_fn_map.update(CustomType, collate_customtype_fn) + >>> default_collate(batch) # Handle `CustomType` automatically + """ + return collate(batch, collate_fn_map=default_collate_fn_map) diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/fetch.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..793226c49fa80e2d6d5230d147c0b56465a0541a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/fetch.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset. + +This logic is shared in both single- and multi-processing data loading. 
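As a rough sketch of how these fetchers are driven (internal machinery, not a
public API; the index list below is made up), the loader iterator obtains a
fetcher from `_DatasetKind.create_fetcher` and then repeatedly calls `fetch`
with one batch of indices at a time:

    fetcher = _MapDatasetFetcher(dataset, auto_collation=True,
                                 collate_fn=default_collate, drop_last=False)
    batch = fetcher.fetch([0, 1, 2, 3])  # indexes the dataset, then collates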
+""" + + +class _BaseDatasetFetcher: + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + self.dataset = dataset + self.auto_collation = auto_collation + self.collate_fn = collate_fn + self.drop_last = drop_last + + def fetch(self, possibly_batched_index): + raise NotImplementedError + + +class _IterableDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super().__init__(dataset, auto_collation, collate_fn, drop_last) + self.dataset_iter = iter(dataset) + self.ended = False + + def fetch(self, possibly_batched_index): + if self.ended: + raise StopIteration + + if self.auto_collation: + data = [] + for _ in possibly_batched_index: + try: + data.append(next(self.dataset_iter)) + except StopIteration: + self.ended = True + break + if len(data) == 0 or ( + self.drop_last and len(data) < len(possibly_batched_index) + ): + raise StopIteration + else: + data = next(self.dataset_iter) + return self.collate_fn(data) + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def fetch(self, possibly_batched_index): + if self.auto_collation: + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data = self.dataset.__getitems__(possibly_batched_index) + else: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..19d127313d41140ede58e35ac31db0baa1192ee4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/pin_memory.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import collections +import copy +import queue + +import torch +from torch._utils import ExceptionWrapper + +from . import MP_STATUS_CHECK_INTERVAL + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. + torch.set_num_threads(1) + + torch.multiprocessing._set_thread_name("pt_data_pin") + + if device == "cuda": + torch.cuda.set_device(device_id) + elif device == "xpu": + torch.xpu.set_device(device_id) # type: ignore[attr-defined] + elif device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + custom_device_mod.set_device(device_id) + elif device is None: + torch.accelerator.set_device_index(device_id) + + def do_one_step(): + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + return + idx, data = r + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): + try: + data = pin_memory(data, device) + except Exception: + data = ExceptionWrapper( + where=f"in pin memory thread for device {device_id}" + ) + r = (idx, data) + while not done_event.is_set(): + try: + out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) + break + except queue.Full: + continue + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. 
+ while not done_event.is_set(): + # Make sure that we don't preserve any object from one iteration + # to the next + do_one_step() + + +def pin_memory(data, device=None): + if isinstance(data, torch.Tensor): + return data.pin_memory(device) + elif isinstance(data, (str, bytes)): + return data + elif isinstance(data, collections.abc.Mapping): + try: + if isinstance(data, collections.abc.MutableMapping): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) + clone.update( + {k: pin_memory(sample, device) for k, sample in data.items()} + ) + return clone + else: + return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg] + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return {k: pin_memory(sample, device) for k, sample in data.items()} + elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple + return type(data)(*(pin_memory(sample, device) for sample in data)) + elif isinstance(data, tuple): + return [ + pin_memory(sample, device) for sample in data + ] # Backwards compatibility. + elif isinstance(data, collections.abc.Sequence): + try: + if isinstance(data, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) # type: ignore[arg-type] + for i, item in enumerate(data): + clone[i] = pin_memory(item, device) + return clone + return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg] + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [pin_memory(sample, device) for sample in data] + elif hasattr(data, "pin_memory"): + return data.pin_memory() + else: + return data diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..30e2204d089307e91c5d4f9fce69eb6c8c6b692e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/signal_handling.py @@ -0,0 +1,79 @@ +# mypy: allow-untyped-defs +r"""Signal handling for multiprocessing data loading. + +NOTE [ Signal handling in multiprocessing data loading ] + +In cases like DataLoader, if a worker process dies due to bus error/segfault +or just hang, the main process will hang waiting for data. This is difficult +to avoid on PyTorch side as it can be caused by limited shm, or other +libraries users call in the workers. In this file and `DataLoader.cpp`, we make +our best effort to provide some error message to users when such unfortunate +events happen. + +When a _BaseDataLoaderIter starts worker processes, their pids are registered in a +defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ] +via `_set_worker_pids`. + +When an error happens in a worker process, the main process received a SIGCHLD, +and Python will eventually call the handler registered below +(in `_set_SIGCHLD_handler`). 
In the handler, the `_error_if_any_worker_fails` +call checks all registered worker pids and raise proper error message to +prevent main process from hanging waiting for data from worker. + +Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, +`_set_worker_signal_handlers` is called to register critical signal handlers +(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error +message to stderr before triggering the default handler. So a message will also +be printed from the worker process when it is killed by such signals. + +See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of +this signal handling design and other mechanism we implement to make our +multiprocessing data loading robust to errors. +""" + +import signal +import threading + +# Some of the following imported functions are not used in this file, but are to +# be used `_utils.signal_handling.XXXXX`. +from torch._C import ( # noqa: F401 + _error_if_any_worker_fails, + _remove_worker_pids, + _set_worker_pids, + _set_worker_signal_handlers, +) + +from . import IS_WINDOWS + + +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if IS_WINDOWS: + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined] + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + assert callable(previous_handler) + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True diff --git a/phivenv/Lib/site-packages/torch/utils/data/_utils/worker.py b/phivenv/Lib/site-packages/torch/utils/data/_utils/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6ff8e0c4cf3287fb8aef751fd68fa41d5f3a52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/_utils/worker.py @@ -0,0 +1,374 @@ +# mypy: allow-untyped-defs +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import os +import queue +import random +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch._utils import ExceptionWrapper + +from . import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling + + +if TYPE_CHECKING: + from torch.utils.data import Dataset + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import BOOL, DWORD, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. 
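# In `_worker_loop` (defined later in this file), `ManagerWatchdog.is_alive()`
# gates every pass of the worker's receive loop, so a worker stops polling the
# index queue soon after its parent (manager) process has died.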
+ class ManagerWatchdog: + def __init__(self) -> None: + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess( + SYNCHRONIZE, 0, self.manager_pid + ) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = ( + self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + ) + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self) -> None: + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info: Optional["WorkerInfo"] = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError( + f"Cannot assign attributes to {self.__class__.__name__} objects" + ) + return super().__setattr__(key, val) + + def __repr__(self): + items = [f"{k}={getattr(self, k)}" for k in self.__keys] + return f"{self.__class__.__name__}({', '.join(items)})" + + +def get_worker_info() -> Optional[WorkerInfo]: + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. 
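        A minimal sketch of that sharding pattern (``start`` and ``end`` are
        hypothetical attributes of the user's iterable dataset, not part of any
        API here):

            >>> # xdoctest: +SKIP
            >>> def worker_init_fn(worker_id):
            ...     info = torch.utils.data.get_worker_info()
            ...     ds = info.dataset  # this worker's copy of the dataset
            ...     per_worker = (ds.end - ds.start + info.num_workers - 1) // info.num_workers
            ...     ds.start = ds.start + info.id * per_worker
            ...     ds.end = min(ds.start + per_worker, ds.end)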
+ """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: Optional[int] = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +# TODO: Implement `SeedSequence` like object for `torch.random` +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. 
+ for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.multiprocessing._set_thread_name("pt_data_worker") + + torch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + torch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + from torch.utils.data import IterDataPipe + from torch.utils.data.graph_settings import apply_random_seed + + shared_rng = torch.Generator() + if isinstance(dataset, IterDataPipe): + assert shared_seed is not None + shared_rng.manual_seed(shared_seed) + dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset + ) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + except Exception: + init_exception = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. 
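            # Summary of the messages handled below: `index_queue` delivers either
            # a `_ResumeIteration` marker (re-seed and re-create the fetcher when
            # workers are reused), `None` (final shutdown signal), or an
            # `(idx, index)` work item whose fetched result is pushed back onto
            # `data_queue` as `(idx, data)`.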
+ iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + if isinstance(dataset, IterDataPipe): + assert r.seed is not None + shared_rng.manual_seed(r.seed) + dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + continue + elif r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) # type: ignore[possibly-undefined] + except Exception as e: + if ( + isinstance(e, StopIteration) + and dataset_kind == _DatasetKind.Iterable + ): + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() diff --git a/phivenv/Lib/site-packages/torch/utils/data/backward_compatibility.py b/phivenv/Lib/site-packages/torch/utils/data/backward_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..31061e849dc5b19d53605b43e1e06ce3aec5588c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/backward_compatibility.py @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from typing_extensions import deprecated as _deprecated + + +@_deprecated( + "Usage of `backward_compatibility.worker_init_fn` is deprecated " + "as `DataLoader` automatically applies sharding in every worker", + category=FutureWarning, +) +def worker_init_fn(worker_id): + pass diff --git a/phivenv/Lib/site-packages/torch/utils/data/dataloader.py b/phivenv/Lib/site-packages/torch/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..91a933b5d64c200a090d2594110903ecd7fdc4b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/dataloader.py @@ -0,0 +1,1664 @@ +# mypy: allow-untyped-defs +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. 
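A minimal usage sketch (``my_dataset`` stands in for any map-style `Dataset`):

    loader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=4)
    for batch in loader:
        ...  # each `batch` is built by `collate_fn` (`default_collate` by default)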
+""" +from __future__ import annotations + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self + +import torch +import torch.distributed as dist +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import _utils +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, + _MapDataPipeSerializationWrapper, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + +__all__ = [ + "DataLoader", + "get_worker_info", + "default_collate", + "default_convert", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[list[_T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + + Used as sampler for :class:`~torch.utils.data.IterableDataset`. 
+ """ + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + assert info is not None + total_workers = info.num_workers + datapipe = info.dataset + assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + datapipe, total_workers, global_worker_id + ) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class DataLoader(Generic[_T_co]): + r""" + Data loader combines a dataset and a sampler, and provides an iterable over the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. 
(default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If + ``None``, the default + `multiprocessing context `_ # noqa: D401 + of your operating system will + be used. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + ``base_seed`` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise, if value of ``num_workers > 0`` default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shut down + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is + ``True``. If not given, the current :ref:`accelerator` will be the + default. This argument is discouraged and subject to deprecated. + in_order (bool, optional): If ``False``, the data loader will not enforce that batches + are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + + .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data + distribution being fed to the trainer in cases with imbalanced data. 
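    A sketch of the custom-type pinning mentioned above (``SimpleCustomBatch``
    is an illustrative name, not something provided by this module): automatic
    memory pinning only recognises tensors and collections of tensors, so a
    custom batch object should expose its own ``pin_memory`` method, e.g.::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed = list(zip(*data))
                self.inp = torch.stack(transposed[0], 0)
                self.tgt = torch.stack(transposed[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self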
+ """ + + dataset: Dataset[_T_co] + batch_size: Optional[int] + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Union[Sampler, Iterable] + pin_memory_device: str + prefetch_factor: Optional[int] + _iterator: Optional[_BaseDataLoaderIter] + __initialized = False + + def __init__( + self, + dataset: Dataset[_T_co], + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = None, + sampler: Union[Sampler, Iterable, None] = None, + batch_sampler: Union[Sampler[list], Iterable[list], None] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + in_order: bool = True, + ) -> None: + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + self.in_order = in_order + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). 
Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + f"batch_sampler option, but got batch_sampler={batch_sampler}" + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> _BaseDataLoaderIter: + if self.num_workers == 0: + 
return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = torch.multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + "multiprocessing_context option " + f"should specify a valid start method in {valid_start_methods!r}, but got " + f"multiprocessing_context={multiprocessing_context!r}" + ) + multiprocessing_context = torch.multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise TypeError( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + f"multiprocessing_context={multiprocessing_context}" + ) + else: + raise ValueError( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + f"num_workers={self.num_workers}" + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + f"{attr} attribute should not be set after {self.__class__.__name__} is initialized" + ) + + super().__setattr__(attr, val) + + def __iter__(self) -> _BaseDataLoaderIter: + # When using a single worker the returned iterator should be + # created everytime to avoid resetting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). 
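+            #
+            # A minimal sketch (illustrative only; `ShardedIterable`, `start`, and
+            # `end` are assumed names, and it needs `import math`) of one common way
+            # for a user-defined `IterableDataset.__iter__` to configure each replica
+            # differently via `torch.utils.data.get_worker_info()`, so that
+            # multi-process loading yields each sample exactly once:
+            #
+            #   class ShardedIterable(torch.utils.data.IterableDataset):
+            #       def __init__(self, start, end):
+            #           self.start, self.end = start, end
+            #
+            #       def __iter__(self):
+            #           info = torch.utils.data.get_worker_info()
+            #           if info is None:  # single-process data loading
+            #               return iter(range(self.start, self.end))
+            #           # give each worker a disjoint slice of the index range
+            #           per_worker = math.ceil((self.end - self.start) / info.num_workers)
+            #           lo = self.start + info.id * per_worker
+            #           hi = min(lo + per_worker, self.end)
+            #           return iter(range(lo, hi))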
+            #
+            # To provide a further warning, we track if `__len__` was called on the
+            # `DataLoader`, save the returned value in `self._len_called`, and warn
+            # if the iterator ends up yielding more than this number of samples.
+
+            # Cannot statically verify that dataset is Sized
+            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
+            if (
+                self.batch_size is not None
+            ):  # IterableDataset doesn't allow custom sampler or batch_sampler
+                from math import ceil
+
+                if self.drop_last:
+                    length = length // self.batch_size
+                else:
+                    length = ceil(length / self.batch_size)
+            return length
+        else:
+            return len(self._index_sampler)
+
+    def check_worker_number_rationality(self):
+        # This function checks whether the dataloader's worker number is reasonable given
+        # the current system's resources. The current rule is that if the number of workers
+        # this DataLoader will create is bigger than the number of logical CPUs it is allowed
+        # to use, then we pop up a warning so the user pays attention.
+        #
+        # e.g. If the current system has 2 physical CPUs with 16 cores each, and each core
+        # supports 2 threads, then the total number of logical CPUs is 2 * 16 * 2 = 64. Let's
+        # say the current DataLoader process can use half of them, i.e. 32; then the rational
+        # maximum number of workers initiated from this process is 32.
+        # Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
+        # So the warning message is triggered to notify the user to lower the worker number if
+        # necessary.
+        #
+        #
+        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
+        # available (available on most Linux systems, but not OSX and Windows).
+        # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
+        # it doesn't respect cpuset.
+        # We don't take threading into account since each worker process is single-threaded
+        # at this time.
+        #
+        # We don't set any threading flags (e.g. OMP_NUM_THREADS, MKL_NUM_THREADS, etc.)
+        # other than setting `torch.set_num_threads` to 1 in the worker process. If the
+        # passed-in functions use 3rd-party modules that rely on those threading flags to
+        # determine how many threads to create (e.g. numpy, etc.), then it is the caller's
+        # responsibility to set those flags correctly.
+        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+            suggested_max_worker_msg = (
+                (
+                    (
+                        "Our suggested max number of worker in current system is {}{}, which is smaller "
+                        "than what this DataLoader is going to create."
+                    ).format(
+                        num_worker_suggest,
+                        (
+                            ""
+                            if cpuset_checked
+                            else " (`cpuset` is not taken into account)"
+                        ),
+                    )
+                )
+                if num_worker_suggest is not None
+                else (
+                    "DataLoader is not able to compute a suggested max number of worker in current system."
+                )
+            )
+
+            warn_msg = (
+                f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
+                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
+                "lower the worker number to avoid potential slowness/freeze if necessary."
+ ) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satisfy mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + + +class _BaseDataLoaderIter: + def __init__(self, loader: DataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + # If pin_memory_device not set, default behaviour is current accelerator. + # If pin_memory_device is set but pin_memory is not set, the default + # behaviour false. + if len(loader.pin_memory_device) == 0: + if loader.pin_memory and not torch.accelerator.is_available(): + warn_msg = ( + "'pin_memory' argument is set as true but no accelerator is found, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory and torch.accelerator.is_available() + self._pin_memory_device = None + # Currently, pin_memory would raise error on the MPS backend (see + # https://github.com/pytorch/pytorch/issues/86060), so forcibly + # disable pin_memory on MPS. Remove this restriction once pinned + # memory allocation for MPS is fixed. + if ( + self._pin_memory + and (acc := torch.accelerator.current_accelerator()) is not None + and acc.type == "mps" + ): + self._pin_memory = False + warn_msg = ( + "'pin_memory' argument is set as true but not supported on MPS now, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + else: + if not loader.pin_memory: + warn_msg = ( + "'pin_memory_device' is set but 'pin_memory' argument is not set, " + "then device pinned memory won't be used." 
+ "please set 'pin_memory' to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = ( + torch.empty((), dtype=torch.int64) + .random_(generator=loader.generator) + .item() + ) + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" + + def __iter__(self) -> Self: + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + # TODO(https://github.com/pytorch/pytorch/issues/76750) + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}" + f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. " + ) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. 
+ # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super().__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler.""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may already be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. 
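+    #
+    # A stripped-down sketch (illustrative only; the names are made up) of this
+    # "exit flag guarding `__del__`" pattern:
+    #
+    #   import atexit
+    #
+    #   _python_exit_status = False
+    #
+    #   def _set_python_exit_flag():
+    #       global _python_exit_status
+    #       _python_exit_status = True
+    #
+    #   atexit.register(_set_python_exit_flag)
+    #
+    #   class _Iter:
+    #       def __del__(self):
+    #           if _python_exit_status is True or _python_exit_status is None:
+    #               return  # interpreter is shutting down; rely on daemonic children
+    #           self._shutdown_workers()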
Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. 
All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. 
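+    #
+    # A tiny sketch (illustrative only, not taken from this file) of the
+    # `cancel_join_thread` pattern discussed in (3b): the producer tells the
+    # queue's feeder thread not to be joined at exit, so shutdown cannot block
+    # on a reader that has already gone away.
+    #
+    #   import multiprocessing as mp
+    #
+    #   q = mp.Queue()
+    #   q.cancel_join_thread()  # only safe if no other process still reads/writes q
+    #   q.put(("task", 0))
+    #   # ... the process can now exit without hanging in the queue finalizer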
+ # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `worker_result_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. 
+ # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + self._in_order = loader.in_order + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = torch.multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. 
+ index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + current_device = -1 + if self._pin_memory_device == "cuda": + current_device = torch.cuda.current_device() + elif self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr( + torch, torch._C._get_privateuse1_backend_name() + ) + current_device = custom_device_mod.current_device() + elif self._pin_memory_device is None: + current_device = torch.accelerator.current_device_index() + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue # type: ignore[assignment] + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False): + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). 
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. + self._workers_status = [True for i in range(self._num_workers)] + # A list of integers representing how many tasks are outstanding for each worker + # Incremented when a task is dispatched to the worker + # Decremented when that data has been given to the main thread + # Each worker should have at most self._prefetch_factor tasks outstanding + self._workers_num_tasks = [0 for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError( + f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" + ) from e + if isinstance(e, queue.Empty): + return (False, None) + + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. 
+ # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. 
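+    #
+    # A short sketch (illustrative only; not part of this file) of the two
+    # mitigations suggested by the error message raised above when the
+    # open-files limit is hit:
+    #
+    #   import resource
+    #   import torch.multiprocessing as mp
+    #
+    #   # Option 1: pass tensors via the file_system sharing strategy, which keeps
+    #   # file names rather than open file descriptors alive.
+    #   mp.set_sharing_strategy("file_system")
+    #
+    #   # Option 2: raise this process's soft RLIMIT_NOFILE up to the hard limit
+    #   # (the in-process equivalent of `ulimit -n`).
+    #   soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+    #   resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))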
+ # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. + if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + f"DataLoader timed out after {self._timeout} seconds" + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. 
+ # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info.get(self._rcvd_idx, None) + if info: + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + worker_id, data = self._task_info.pop(self._rcvd_idx) + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + if not self._in_order: + # don't store it for later, process now + # delete from self._task_info immediately + # this keeps the object size manageable + worker_id = self._task_info.pop(idx)[0] + return self._process_data(data, worker_id) + # store out-of-order samples + self._task_info[idx] += (data,) + else: + worker_id = self._task_info.pop(idx)[0] + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + def _try_put_index(self): + max_tasks = self._prefetch_factor * self._num_workers + assert self._tasks_outstanding < max_tasks + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + if self._in_order: + break + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum( + self._workers_status + ): + # when self._in_order is False, distribute work to a worker if it has capacity + # _workers_status is updated only in this thread, so the sum is guaranteed > 0 + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined] + self._task_info[self._send_idx] = (worker_queue_idx,) + self._workers_num_tasks[worker_queue_idx] += 1 + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data, worker_idx): + self._workers_num_tasks[worker_idx] -= 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. + + assert self._workers_status[worker_id] or ( + self._persistent_workers and shutdown + ) + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. 
+ q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joinning is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + + self._workers_status[worker_id] = False + + assert self._workers_done_event.is_set() == shutdown + + def _shutdown_workers(self): + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, "_pin_memory_thread"): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + # If we are using workers_status with persistent_workers + # we have to shut it down because the worker is paused + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. + # + # FIXME: Unfortunately, for Windows, we are missing a worker + # error detection mechanism here in this function, as it + # doesn't provide a SIGCHLD handler. 
+ if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() + + # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` + @staticmethod + def _clean_up_worker(w): + try: + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + finally: + if w.is_alive(): + w.terminate() + + def __del__(self): + self._shutdown_workers() diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50e00d755075be0a66b374edb5511e5a3d2e3701 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__init__.py @@ -0,0 +1 @@ +from torch.utils.data.datapipes import dataframe as dataframe, iter as iter, map as map diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ee2ad9e5f592a6b2402b4f1a11813b7a4eb53bf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b687c89b6092266e57578cf794de3f4693f45c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59444069b0b40f3ca8d1701170415f3efaaf236d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaa177d8b60caaa438326388f7a4fa72351be60 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c0171f8ebfc17aebb41436aa03555da0990a50f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58c0fa82e51436f5014ba51987c987337e85945b 
Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b01be741a4ecb036a25510a39c39fb6342b16c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_decorator.py @@ -0,0 +1,213 @@ +# mypy: allow-untyped-defs +import inspect +from functools import wraps +from typing import Any, Callable, get_type_hints, Optional, Union + +from torch.utils.data.datapipes._typing import _DataPipeMeta +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +###################################################### +# Functional API +###################################################### +class functional_datapipe: + name: str + + def __init__(self, name: str, enable_df_api_tracing=False) -> None: + """ + Define a functional datapipe. + + Args: + enable_df_api_tracing - if set, any returned DataPipe would accept + DataFrames API in tracing mode. + """ + self.name = name + self.enable_df_api_tracing = enable_df_api_tracing + + def __call__(self, cls): + if issubclass(cls, IterDataPipe): + if isinstance(cls, type): # type: ignore[arg-type] + if not isinstance(cls, _DataPipeMeta): + raise TypeError( + "`functional_datapipe` can only decorate IterDataPipe" + ) + # with non_deterministic decorator + else: + if not isinstance(cls, non_deterministic) and not ( + hasattr(cls, "__self__") + and isinstance(cls.__self__, non_deterministic) + ): + raise TypeError( + "`functional_datapipe` can only decorate IterDataPipe" + ) + IterDataPipe.register_datapipe_as_function( + self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing + ) + elif issubclass(cls, MapDataPipe): + MapDataPipe.register_datapipe_as_function(self.name, cls) + + return cls + + +###################################################### +# Determinism +###################################################### +_determinism: bool = False + + +class guaranteed_datapipes_determinism: + prev: bool + + def __init__(self) -> None: + global _determinism + self.prev = _determinism + _determinism = True + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _determinism + _determinism = self.prev + + +class non_deterministic: + cls: Optional[type[IterDataPipe]] = None + # TODO: Lambda for picking + deterministic_fn: Callable[[], bool] + + def __init__(self, arg: Union[type[IterDataPipe], Callable[[], bool]]) -> None: + # 1. Decorator doesn't have any argument + if isinstance(arg, type): # type: ignore[arg-type] + if not issubclass(arg, IterDataPipe): # type: ignore[arg-type] + raise TypeError( + "Only `IterDataPipe` can be decorated with `non_deterministic`" + f", but {arg.__name__} is found" + ) + self.cls = arg # type: ignore[assignment] + # 2. Decorator has an argument of a function + # This class should behave differently given different inputs. Use this + # function to verify the determinism for each instance. + # When the function returns True, the instance is non-deterministic. Otherwise, + # the instance is a deterministic DataPipe. 
+ elif isinstance(arg, Callable): # type:ignore[arg-type] + self.deterministic_fn = arg # type: ignore[assignment, misc] + else: + raise TypeError(f"{arg} can not be decorated by non_deterministic") + + def __call__(self, *args, **kwargs): + global _determinism + # Decorate IterDataPipe + if self.cls is not None: + if _determinism: + raise TypeError( + f"{self.cls.__name__} is non-deterministic, but you set 'guaranteed_datapipes_determinism'. " + "You can turn off determinism for this DataPipe if that is acceptable " + "for your application" + ) + return self.cls(*args, **kwargs) # type: ignore[call-arg] + + # Decorate with a functional argument + if not ( + isinstance(args[0], type) + and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] + ): + raise TypeError( + f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" + ) + self.cls = args[0] + return self.deterministic_wrapper_fn + + def deterministic_wrapper_fn(self, *args, **kwargs) -> IterDataPipe: + res = self.deterministic_fn(*args, **kwargs) # type: ignore[call-arg, misc] + if not isinstance(res, bool): + raise TypeError( + "deterministic_fn of `non_deterministic` decorator is required " + f"to return a boolean value, but {type(res)} is found" + ) + global _determinism + if _determinism and res: + raise TypeError( + f"{self.cls.__name__} is non-deterministic with the inputs, but you set " # type: ignore[union-attr] + "'guaranteed_datapipes_determinism'. You can turn off determinism " + "for this DataPipe if that is acceptable for your application" + ) + return self.cls(*args, **kwargs) # type: ignore[call-arg, misc] + + +###################################################### +# Type validation +###################################################### +# Validate each argument of DataPipe with hint as a subtype of the hint. 
+def argument_validation(f): + signature = inspect.signature(f) + hints = get_type_hints(f) + + @wraps(f) + def wrapper(*args, **kwargs): + bound = signature.bind(*args, **kwargs) + for argument_name, value in bound.arguments.items(): + if argument_name in hints and isinstance( + hints[argument_name], _DataPipeMeta + ): + hint = hints[argument_name] + if not isinstance(value, IterDataPipe): + raise TypeError( + f"Expected argument '{argument_name}' as a IterDataPipe, but found {type(value)}" + ) + if not value.type.issubtype(hint.type): + raise TypeError( + f"Expected type of argument '{argument_name}' as a subtype of " + f"hint {hint.type}, but found {value.type}" + ) + + return f(*args, **kwargs) + + return wrapper + + +# Default value is True +_runtime_validation_enabled: bool = True + + +class runtime_validation_disabled: + prev: bool + + def __init__(self) -> None: + global _runtime_validation_enabled + self.prev = _runtime_validation_enabled + _runtime_validation_enabled = False + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _runtime_validation_enabled + _runtime_validation_enabled = self.prev + + +# Runtime checking +# Validate output data is subtype of return hint +def runtime_validation(f): + # TODO: + # Can be extended to validate '__getitem__' and nonblocking + if f.__name__ != "__iter__": + raise TypeError( + f"Can not decorate function {f.__name__} with 'runtime_validation'" + ) + + @wraps(f) + def wrapper(self): + global _runtime_validation_enabled + if not _runtime_validation_enabled: + yield from f(self) + else: + it = f(self) + for d in it: + if not self.type.issubtype_of_instance(d): + raise RuntimeError( + f"Expected an instance as subtype of {self.type}, but found {d}({type(d)})" + ) + yield d + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..735874aec393ec64b29419bf7b2a909ca37e9d38 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_hook_iterator.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +import functools +import inspect +from enum import Enum + +import torch + + +class _SnapshotState(Enum): + r""" + These are the snapshotting-related states that IterDataPipes can be in. 
+ + `NotStarted` - allows you to restore a snapshot and create an iterator with reset + `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe + `Iterating` - can restore, will reset if you create a new iterator + """ + + NotStarted = 0 + Restored = 1 + Iterating = 2 + + +def _simplify_obj_name(obj) -> str: + """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages.""" + if inspect.isfunction(obj): + return obj.__name__ + else: + return repr(obj) + + +def _strip_datapipe_from_name(name: str) -> str: + return name.replace("IterDataPipe", "").replace("MapDataPipe", "") + + +def _generate_input_args_string(obj): + """Generate a string for the input arguments of an object.""" + signature = inspect.signature(obj.__class__) + input_param_names = set(signature.parameters.keys()) + result = [] + for name, value in inspect.getmembers(obj): + if name in input_param_names: + result.append((name, _simplify_obj_name(value))) + return ", ".join([f"{name}={value}" for name, value in result]) + + +def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False): + output_string = ( + f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})" + ) + if simplify_dp_name: + output_string = _strip_datapipe_from_name(output_string) + return output_string + + +def _gen_invalid_iterdatapipe_msg(datapipe): + return ( + "This iterator has been invalidated because another iterator has been created " + f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n" + "This may be caused multiple references to the same IterDataPipe. We recommend " + "using `.fork()` if that is necessary." + ) + + +_feedback_msg = ( + "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free " + "to comment on this issue: https://github.com/pytorch/data/issues/45." +) + + +def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None: + r""" + Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception. + + In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well. + """ + if next_method_exists: + # This is the case where `IterDataPipe` has both `__iter__` and `__next__`. + # The `_valid_iterator_id` should either be never set (`None`), or set by at most one + # iterator (`0`). Otherwise, it means there are multiple iterators. + if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0: + extra_msg = "\nNote that this exception is raised inside your IterDataPipe's a `__next__` method" + raise RuntimeError( + _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg + ) + elif ( + hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True + ): + if hasattr(datapipe, "_check_valid_iterator_id"): + if not datapipe._check_valid_iterator_id(iterator_id): + raise RuntimeError( + "This iterator has been invalidated, because a new iterator has been created " + f"from one of the ChildDataPipes of " + f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + + _feedback_msg + ) + else: + raise RuntimeError( + "ChildDataPipe must have method `_check_valid_iterator_id`." 
+ ) + elif datapipe._valid_iterator_id != iterator_id: + raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg) + + +def _set_datapipe_valid_iterator_id(datapipe): + """Given a DataPipe, updates its valid iterator ID and reset the DataPipe.""" + if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True: + if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"): + datapipe._set_main_datapipe_valid_iterator_id() # reset() is called within this method when appropriate + else: + raise RuntimeError( + "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`." + ) + else: + if datapipe._valid_iterator_id is None: + datapipe._valid_iterator_id = 0 + else: + datapipe._valid_iterator_id += 1 + datapipe.reset() + return datapipe._valid_iterator_id + + +def hook_iterator(namespace): + r""" + Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. + + This is done for the purpose of profiling and checking if an iterator is still valid. + """ + + def profiler_record_fn_context(datapipe): + if not hasattr(datapipe, "_profile_name"): + datapipe._profile_name = _generate_iterdatapipe_msg( + datapipe, simplify_dp_name=True + ) + return torch.autograd.profiler.record_function(datapipe._profile_name) + + class IteratorDecorator: + r""" + Wrap the iterator and modifying its `__next__` method. + + This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function. + Those `__iter__` method commonly returns `self` but not necessarily. + """ + + def __init__(self, iterator, datapipe, iterator_id, has_next_method): + self.iterator = iterator + self.datapipe = datapipe + self.iterator_id = iterator_id + self._profiler_enabled = torch.autograd._profiler_enabled() + # Check if `__iter__` returns `self` and `DataPipe` has `__next__` + self.self_and_has_next_method = ( + self.iterator is self.datapipe and has_next_method + ) + + def __iter__(self): + return self + + def _get_next(self): + """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.""" + _check_iterator_valid(self.datapipe, self.iterator_id) + result = next(self.iterator) + if not self.self_and_has_next_method: + self.datapipe._number_of_samples_yielded += 1 + return result + + def __next__(self): + # TODO: Add try-except to in-place reduce traceback from the Exception + # See: https://github.com/pytorch/data/issues/284 + if self._profiler_enabled: + with profiler_record_fn_context(self.datapipe): + return self._get_next() + else: # Decided against using `contextlib.nullcontext` for performance reasons + return self._get_next() + + def __getattr__(self, name): + return getattr(self.iterator, name) + + func = namespace["__iter__"] + + # ``__iter__`` of IterDataPipe is a generator function + if inspect.isgeneratorfunction(func): + + @functools.wraps(func) + def wrap_generator(*args, **kwargs): + gen = func(*args, **kwargs) + datapipe = args[0] + if datapipe._fast_forward_iterator: + it = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + datapipe._snapshot_state = _SnapshotState.Iterating + while True: + try: + yield next(it) + except StopIteration: + return + iterator_id = _set_datapipe_valid_iterator_id( + datapipe + ) # This ID is tied to each created iterator + _profiler_enabled = torch.autograd._profiler_enabled() + try: + if _profiler_enabled: + with profiler_record_fn_context(datapipe): + response = gen.send(None) + else: + response = gen.send(None) + + while True: + 
datapipe._number_of_samples_yielded += 1 + request = yield response + # Pass through here every time `__next__` is called + if _profiler_enabled: + with profiler_record_fn_context(datapipe): + _check_iterator_valid(datapipe, iterator_id) + response = gen.send(request) + else: # Decided against using `contextlib.nullcontext` for performance reasons + _check_iterator_valid(datapipe, iterator_id) + response = gen.send(request) + except StopIteration: + return + except Exception as e: + # TODO: Simplify the traceback message to skip over `response = gen.send(None)` + # Part of https://github.com/pytorch/data/issues/284 + datapipe = args[0] + msg = "thrown by __iter__ of" + single_iterator_msg = "single iterator per IterDataPipe constraint" + if hasattr(e.args, "__len__"): + full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})" + if len(e.args) == 0 or not isinstance( + e.args[0], str + ): # If an exception message doesn't exist + e.args = (f"\nThis exception is {full_msg}",) + elif msg not in e.args[0] and single_iterator_msg not in e.args[0]: + e.args = ( + e.args[0] + f"\nThis exception is {full_msg}", + ) + e.args[1:] + raise + + namespace["__iter__"] = wrap_generator + else: # ``__iter__`` of IterDataPipe is NOT a generator function + # IterDataPipe is an iterator with both ``__iter__`` and ``__next__`` + # And ``__iter__`` may or may not return `self` + if "__next__" in namespace: # If `__next__` exists, put a wrapper around it + next_func = namespace["__next__"] + + @functools.wraps(next_func) + def wrap_next(*args, **kwargs): + datapipe = args[0] + if torch.autograd._profiler_enabled(): + with profiler_record_fn_context(datapipe): + result = next_func(*args, **kwargs) + else: + result = next_func(*args, **kwargs) + datapipe._number_of_samples_yielded += 1 + return result + + namespace["__next__"] = wrap_next + + # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but + # the user will be violating the iterator protocol. Potential issue: + # 1. Valid iterator ID may not update or checked properly + # 2. 
The number of samples yielded will be miscounted + + # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators + @functools.wraps(func) + def wrap_iter(*args, **kwargs): + iter_ret = func(*args, **kwargs) + datapipe = args[0] + datapipe._snapshot_state = _SnapshotState.Iterating + if datapipe._fast_forward_iterator: + iter_ret = datapipe._fast_forward_iterator + datapipe._fast_forward_iterator = None + return iter_ret + iterator_id = _set_datapipe_valid_iterator_id( + datapipe + ) # This ID is tied to each created iterator + return IteratorDecorator( + iter_ret, datapipe, iterator_id, "__next__" in namespace + ) + + namespace["__iter__"] = wrap_iter diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/_typing.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..14120bc2a2fef62a2124040e9435a03fae783ae5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/_typing.py @@ -0,0 +1,482 @@ +# mypy: allow-untyped-defs +# Taking reference from official Python typing +# https://github.com/python/cpython/blob/master/Lib/typing.py + +import collections +import functools +import numbers +import sys + +# Please check [Note: TypeMeta and TypeAlias] +# In case of metaclass conflict due to ABCMeta or _ProtocolMeta +# For Python 3.9, only Protocol in typing uses metaclass +from abc import ABCMeta +from collections.abc import Iterator + +# TODO: Use TypeAlias when Python 3.6 is deprecated +from typing import ( # type: ignore[attr-defined] + _eval_type, + _GenericAlias, + _tp_cache, + _type_check, + _type_repr, + Any, + ForwardRef, + Generic, + get_type_hints, + TypeVar, + Union, +) + +from torch.utils.data.datapipes._hook_iterator import _SnapshotState, hook_iterator + + +class GenericMeta(ABCMeta): # type: ignore[no-redef] + pass + + +class Integer(numbers.Integral): + pass + + +class Boolean(numbers.Integral): + pass + + +# Python 'type' object is not subscriptable +# Tuple[int, List, dict] -> valid +# tuple[int, list, dict] -> invalid +# Map Python 'type' to abstract base class +TYPE2ABC = { + bool: Boolean, + int: Integer, + float: numbers.Real, + complex: numbers.Complex, + dict: dict, + list: list, + set: set, + tuple: tuple, + None: type(None), +} + + +def issubtype(left, right, recursive=True): + r""" + Check if the left-side type is a subtype of the right-side type. + + If any of type is a composite type like `Union` and `TypeVar` with + bounds, it would be expanded into a list of types and check all + of left-side types are subtypes of either one from right-side types. 
+ """ + left = TYPE2ABC.get(left, left) + right = TYPE2ABC.get(right, right) + + if right is Any or left == right: + return True + + if isinstance(right, _GenericAlias): + if getattr(right, "__origin__", None) is Generic: + return True + + if right == type(None): + return False + + # Right-side type + constraints = _decompose_type(right) + + if len(constraints) == 0 or Any in constraints: + return True + + if left is Any: + return False + + # Left-side type + variants = _decompose_type(left) + + # all() will return True for empty variants + if len(variants) == 0: + return False + + return all( + _issubtype_with_constraints(variant, constraints, recursive) + for variant in variants + ) + + +def _decompose_type(t, to_list=True): + if isinstance(t, TypeVar): + if t.__bound__ is not None: + ts = [t.__bound__] + else: + # For T_co, __constraints__ is () + ts = list(t.__constraints__) + elif hasattr(t, "__origin__") and t.__origin__ == Union: + ts = t.__args__ + else: + if not to_list: + return None + ts = [t] + # Ignored: Generator has incompatible item type "object"; expected "Type[Any]" + ts = [TYPE2ABC.get(_t, _t) for _t in ts] # type: ignore[misc] + return ts + + +def _issubtype_with_constraints(variant, constraints, recursive=True): + r""" + Check if the variant is a subtype of either one from constraints. + + For composite types like `Union` and `TypeVar` with bounds, they + would be expanded for testing. + """ + if variant in constraints: + return True + + # [Note: Subtype for Union and TypeVar] + # Python typing is able to flatten Union[Union[...]] or Union[TypeVar]. + # But it couldn't flatten the following scenarios: + # - Union[int, TypeVar[Union[...]]] + # - TypeVar[TypeVar[...]] + # So, variant and each constraint may be a TypeVar or a Union. + # In these cases, all of inner types from the variant are required to be + # extraced and verified as a subtype of any constraint. And, all of + # inner types from any constraint being a TypeVar or a Union are + # also required to be extracted and verified if the variant belongs to + # any of them. + + # Variant + vs = _decompose_type(variant, to_list=False) + + # Variant is TypeVar or Union + if vs is not None: + return all(_issubtype_with_constraints(v, constraints, recursive) for v in vs) + + # Variant is not TypeVar or Union + if hasattr(variant, "__origin__") and variant.__origin__ is not None: + v_origin = variant.__origin__ + # In Python-3.9 typing library untyped generics do not have args + v_args = getattr(variant, "__args__", None) + else: + v_origin = variant + v_args = None + + # Constraints + for constraint in constraints: + cs = _decompose_type(constraint, to_list=False) + + # Constraint is TypeVar or Union + if cs is not None: + if _issubtype_with_constraints(variant, cs, recursive): + return True + # Constraint is not TypeVar or Union + else: + # __origin__ can be None for plain list, tuple, ... 
in Python 3.6 + if hasattr(constraint, "__origin__") and constraint.__origin__ is not None: + c_origin = constraint.__origin__ + if v_origin == c_origin: + if not recursive: + return True + # In Python-3.9 typing library untyped generics do not have args + c_args = getattr(constraint, "__args__", None) + if c_args is None or len(c_args) == 0: + return True + if ( + v_args is not None + and len(v_args) == len(c_args) + and all( + issubtype(v_arg, c_arg) + for v_arg, c_arg in zip(v_args, c_args) + ) + ): + return True + # Tuple[int] -> Tuple + else: + if v_origin == constraint: + return True + + return False + + +def issubinstance(data, data_type): + if not issubtype(type(data), data_type, recursive=False): + return False + + # In Python-3.9 typing library __args__ attribute is not defined for untyped generics + dt_args = getattr(data_type, "__args__", None) + if isinstance(data, tuple): + if dt_args is None or len(dt_args) == 0: + return True + if len(dt_args) != len(data): + return False + return all(issubinstance(d, t) for d, t in zip(data, dt_args)) + elif isinstance(data, (list, set)): + if dt_args is None or len(dt_args) == 0: + return True + t = dt_args[0] + return all(issubinstance(d, t) for d in data) + elif isinstance(data, dict): + if dt_args is None or len(dt_args) == 0: + return True + kt, vt = dt_args + return all( + issubinstance(k, kt) and issubinstance(v, vt) for k, v in data.items() + ) + + return True + + +# [Note: TypeMeta and TypeAlias] +# In order to keep compatibility for Python 3.6, use Meta for the typing. +# TODO: When PyTorch drops the support for Python 3.6, it can be converted +# into the Alias system and using `__class_getitem__` for DataPipe. The +# typing system will gain benefit of performance and resolving metaclass +# conflicts as elaborated in https://www.python.org/dev/peps/pep-0560/ + + +class _DataPipeType: + r"""Save type annotation in `param`.""" + + def __init__(self, param): + self.param = param + + def __repr__(self): + return _type_repr(self.param) + + def __eq__(self, other): + if isinstance(other, _DataPipeType): + return self.param == other.param + return NotImplemented + + def __hash__(self): + return hash(self.param) + + def issubtype(self, other): + if isinstance(other.param, _GenericAlias): + if getattr(other.param, "__origin__", None) is Generic: + return True + if isinstance(other, _DataPipeType): + return issubtype(self.param, other.param) + if isinstance(other, type): + return issubtype(self.param, other) + raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}") + + def issubtype_of_instance(self, other): + return issubinstance(other, self.param) + + +# Default type for DataPipe without annotation +_T_co = TypeVar("_T_co", covariant=True) +_DEFAULT_TYPE = _DataPipeType(Generic[_T_co]) + + +class _DataPipeMeta(GenericMeta): + r""" + Metaclass for `DataPipe`. + + Add `type` attribute and `__init_subclass__` based on the type, and validate the return hint of `__iter__`. + + Note that there is subclass `_IterDataPipeMeta` specifically for `IterDataPipe`. + """ + + type: _DataPipeType + + def __new__(cls, name, bases, namespace, **kwargs): + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now. 
+ cls.__origin__ = None + if "type" in namespace: + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + namespace["__type_class__"] = False + # For plain derived class without annotation + for base in bases: + if isinstance(base, _DataPipeMeta): + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + namespace.update( + {"type": _DEFAULT_TYPE, "__init_subclass__": _dp_init_subclass} + ) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + def __init__(self, name, bases, namespace, **kwargs): + super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload] + + # TODO: Fix isinstance bug + @_tp_cache + def _getitem_(self, params): + if params is None: + raise TypeError(f"{self.__name__}[t]: t can not be None") + if isinstance(params, str): + params = ForwardRef(params) + if not isinstance(params, tuple): + params = (params,) + + msg = f"{self.__name__}[t]: t must be a type" + params = tuple(_type_check(p, msg) for p in params) + + if isinstance(self.type.param, _GenericAlias): + orig = getattr(self.type.param, "__origin__", None) + if isinstance(orig, type) and orig is not Generic: + p = self.type.param[params] # type: ignore[index] + t = _DataPipeType(p) + l = len(str(self.type)) + 2 + name = self.__name__[:-l] + name = name + "[" + str(t) + "]" + bases = (self,) + self.__bases__ + return self.__class__( + name, + bases, + { + "__init_subclass__": _dp_init_subclass, + "type": t, + "__type_class__": True, + }, + ) + + if len(params) > 1: + raise TypeError( + f"Too many parameters for {self} actual {len(params)}, expected 1" + ) + + t = _DataPipeType(params[0]) + + if not t.issubtype(self.type): + raise TypeError( + f"Can not subclass a DataPipe[{t}] from DataPipe[{self.type}]" + ) + + # Types are equal, fast path for inheritance + if self.type == t: + return self + + name = self.__name__ + "[" + str(t) + "]" + bases = (self,) + self.__bases__ + + return self.__class__( + name, + bases, + {"__init_subclass__": _dp_init_subclass, "__type_class__": True, "type": t}, + ) + + # TODO: Fix isinstance bug + def _eq_(self, other): + if not isinstance(other, _DataPipeMeta): + return NotImplemented + if self.__origin__ is None or other.__origin__ is None: # type: ignore[has-type] + return self is other + return ( + self.__origin__ == other.__origin__ # type: ignore[has-type] + and self.type == other.type + ) + + # TODO: Fix isinstance bug + def _hash_(self): + return hash((self.__name__, self.type)) + + +class _IterDataPipeMeta(_DataPipeMeta): + r""" + Metaclass for `IterDataPipe` and inherits from `_DataPipeMeta`. + + Add various functions for behaviors specific to `IterDataPipe`. + """ + + def __new__(cls, name, bases, namespace, **kwargs): + if "reset" in namespace: + reset_func = namespace["reset"] + + @functools.wraps(reset_func) + def conditional_reset(*args, **kwargs): + r""" + Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`. + + This allows recently restored DataPipe to preserve its restored state during the initial `__iter__` call. + """ + datapipe = args[0] + if datapipe._snapshot_state in ( + _SnapshotState.Iterating, + _SnapshotState.NotStarted, + ): + # Reset `NotStarted` is necessary because the `source_datapipe` of a DataPipe might have + # already begun iterating. 
+ datapipe._number_of_samples_yielded = 0 + datapipe._fast_forward_iterator = None + reset_func(*args, **kwargs) + datapipe._snapshot_state = _SnapshotState.Iterating + + namespace["reset"] = conditional_reset + + if "__iter__" in namespace: + hook_iterator(namespace) + return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] + + +def _dp_init_subclass(sub_cls, *args, **kwargs): + # Add function for datapipe instance to reinforce the type + sub_cls.reinforce_type = reinforce_type + + # TODO: + # - add global switch for type checking at compile-time + + # Ignore internal type class + if getattr(sub_cls, "__type_class__", False): + return + + # Check if the string type is valid + if isinstance(sub_cls.type.param, ForwardRef): + base_globals = sys.modules[sub_cls.__module__].__dict__ + try: + param = _eval_type(sub_cls.type.param, base_globals, locals()) + sub_cls.type.param = param + except TypeError as e: + raise TypeError( + f"{sub_cls.type.param.__forward_arg__} is not supported by Python typing" + ) from e + + if "__iter__" in sub_cls.__dict__: + iter_fn = sub_cls.__dict__["__iter__"] + hints = get_type_hints(iter_fn) + if "return" in hints: + return_hint = hints["return"] + # Plain Return Hint for Python 3.6 + if return_hint == Iterator: + return + if not ( + hasattr(return_hint, "__origin__") + and ( + return_hint.__origin__ == Iterator + or return_hint.__origin__ == collections.abc.Iterator + ) + ): + raise TypeError( + "Expected 'Iterator' as the return annotation for `__iter__` of {}" + ", but found {}".format( + sub_cls.__name__, _type_repr(hints["return"]) + ) + ) + data_type = return_hint.__args__[0] + if not issubtype(data_type, sub_cls.type.param): + raise TypeError( + f"Expected return type of '__iter__' as a subtype of {sub_cls.type}," + f" but found {_type_repr(data_type)} for {sub_cls.__name__}" + ) + + +def reinforce_type(self, expected_type): + r""" + Reinforce the type for DataPipe instance. + + And the 'expected_type' is required to be a subtype of the original type + hint to restrict the type requirement of DataPipe instance. 
+ """ + if isinstance(expected_type, tuple): + expected_type = tuple[expected_type] # type: ignore[valid-type] + _type_check(expected_type, msg="'expected_type' must be a type") + + if not issubtype(expected_type, self.type.param): + raise TypeError( + f"Expected 'expected_type' as subtype of {self.type}, but found {_type_repr(expected_type)}" + ) + + self.type = _DataPipeType(expected_type) + return self diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b599aae89c36cda0a85aaa48ac1a4e2d3f238f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__init__.py @@ -0,0 +1,11 @@ +from torch.utils.data.datapipes.dataframe.dataframes import ( + CaptureDataFrame, + DFIterDataPipe, +) +from torch.utils.data.datapipes.dataframe.datapipes import DataFramesAsTuplesPipe + + +__all__ = ["CaptureDataFrame", "DFIterDataPipe", "DataFramesAsTuplesPipe"] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67a1ffca451a1eb5e52fda4c47d1b85c7612be57 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04f0548563f62c1347add9b8fbc61c8fca820448 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cc6ce62be372a29554b2fba10262984ad034547 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b993348eddfa7513ab1c73092965afc5c2b250dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..538a20aefa713fff9fc8864ae648000553f69627 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f81c19a61f5c825b08ade3a315d428df37d6e96b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + + +_pandas: Any = None +_WITH_PANDAS: Optional[bool] = None + + +def _try_import_pandas() -> bool: + try: + import pandas # type: ignore[import] + + global _pandas + _pandas = pandas + return True + except ImportError: + return False + + +# pandas used only for prototyping, will be shortly replaced with TorchArrow +def _with_pandas() -> bool: + global _WITH_PANDAS + if _WITH_PANDAS is None: + _WITH_PANDAS = _try_import_pandas() + return _WITH_PANDAS + + +class PandasWrapper: + @classmethod + def create_dataframe(cls, data, columns): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return _pandas.DataFrame(data, columns=columns) # type: ignore[union-attr] + + @classmethod + def is_dataframe(cls, data): + if not _with_pandas(): + return False + return isinstance(data, _pandas.core.frame.DataFrame) # type: ignore[union-attr] + + @classmethod + def is_column(cls, data): + if not _with_pandas(): + return False + return isinstance(data, _pandas.core.series.Series) # type: ignore[union-attr] + + @classmethod + def iterate(cls, data): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + yield from data.itertuples(index=False) + + @classmethod + def concat(cls, buffer): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return _pandas.concat(buffer) # type: ignore[union-attr] + + @classmethod + def get_item(cls, data, idx): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return data[idx : idx + 1] + + @classmethod + def get_len(cls, df): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return len(df.index) + + @classmethod + def get_columns(cls, df): + if not _with_pandas(): + raise RuntimeError("DataFrames prototype requires pandas to function") + return list(df.columns.values.tolist()) + + +# When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class) +default_wrapper = PandasWrapper + + +def get_df_wrapper(): + return default_wrapper + + +def set_df_wrapper(wrapper): + global default_wrapper + default_wrapper = wrapper + + +def create_dataframe(data, columns=None): + wrapper = get_df_wrapper() + return wrapper.create_dataframe(data, columns) + + +def is_dataframe(data): + wrapper = get_df_wrapper() + return wrapper.is_dataframe(data) + + +def get_columns(data): + wrapper = get_df_wrapper() + return wrapper.get_columns(data) + + +def is_column(data): + wrapper = get_df_wrapper() + return wrapper.is_column(data) + + +def concat(buffer): + wrapper = get_df_wrapper() + return wrapper.concat(buffer) + + +def iterate(data): + wrapper = get_df_wrapper() + return wrapper.iterate(data) + + +def get_item(data, idx): + wrapper = get_df_wrapper() + return wrapper.get_item(data, idx) + + +def get_len(df): + wrapper = get_df_wrapper() + return wrapper.get_len(df) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py 
b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py new file mode 100644 index 0000000000000000000000000000000000000000..f15649e74426f115a3550ed303814a9cc01bbe9d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/dataframes.py @@ -0,0 +1,457 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe.structures import DataChunkDF +from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe + + +# TODO(VitalyFedyunin): Add error when two different traces get combined + +__all__ = [ + "Capture", + "CaptureA", + "CaptureAdd", + "CaptureCall", + "CaptureControl", + "CaptureDataFrame", + "CaptureDataFrameWithDataPipeOps", + "CaptureF", + "CaptureGetAttr", + "CaptureGetItem", + "CaptureInitial", + "CaptureLikeMock", + "CaptureMul", + "CaptureSetItem", + "CaptureSub", + "CaptureVariable", + "CaptureVariableAssign", + "DataFrameTracer", + "DataFrameTracedOps", + "disable_capture", + "get_val", +] + + +def disable_capture(): + CaptureControl.disabled = True + + +class CaptureControl: + disabled = False + + +class DataFrameTracedOps(DFIterDataPipe): + def __init__(self, source_datapipe, output_var): + self.source_datapipe = source_datapipe + self.output_var = output_var + + def __iter__(self): + for item in self.source_datapipe: + yield self.output_var.apply_ops(item) + + +# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registred functions +DATAPIPES_OPS = [ + "_dataframes_as_tuples", + "groupby", + "_dataframes_filter", + "map", + "to_datapipe", + "shuffle", + "concat", + "batch", + "_dataframes_per_row", + "_dataframes_concat", + "_dataframes_shuffle", +] + +UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"] + + +class Capture: + # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures + + def __init__(self, schema_df=None): + self.ctx = {"operations": [], "variables": [], "schema_df": schema_df} + + def __str__(self): + return self._ops_str() + + def _ops_str(self): + res = "" + for op in self.ctx["operations"]: + if len(res) > 0: + res += "\n" + res += str(op) + return res + + def __getstate__(self): + # TODO(VitalyFedyunin): Currently can't pickle (why?) 
+ self.ctx["schema_df"] = None + for var in self.ctx["variables"]: + var.calculated_value = None + state = {} + for item in self.__dict__: + state[item] = getattr(self, item) + return state + + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + + def __getattr__(self, attrname): + if attrname == "kwarg" or attrname == "kwargs": + raise RuntimeError("no kwargs!") + if attrname in ["__deepcopy__"]: + raise AttributeError + result = CaptureGetAttr(self, attrname, ctx=self.ctx) + return result + + def __getitem__(self, key): + return CaptureGetItem(self, key, ctx=self.ctx) + + def __setitem__(self, key, value): + self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx)) + + def __add__(self, add_val): + res = CaptureAdd(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + self.ctx["operations"].append( + CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + ) + return var + + def __sub__(self, add_val): + res = CaptureSub(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + self.ctx["operations"].append( + CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + ) + return var + + def __mul__(self, add_val): + res = CaptureMul(self, add_val, ctx=self.ctx) + var = CaptureVariable(res, ctx=self.ctx) + t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) + self.ctx["operations"].append(t) + return var + + def _is_context_empty(self): + return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0 + + def apply_ops_2(self, dataframe): + # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + self.ctx["variables"][0].calculated_value = dataframe + for op in self.ctx["operations"]: + op.execute() + + @property + def columns(self): + self.apply_ops_2(self.ctx["schema_df"]) + value = self.execute() + return value.columns + + # TODO(VitalyFedyunin): Add tests + # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture + + def __call__(self, *args, **kwargs): + # TODO: Check if args or kwargs have more than one different context + if self._is_context_empty(): + # TODO: Allow CaptureA to take context from mock + for arg in args: + if isinstance(arg, Capture) and not arg._is_context_empty(): + self.ctx = arg.ctx + break + if self._is_context_empty(): + for k, v in kwargs.items(): + if isinstance(k, Capture) and not k._is_context_empty(): + self.ctx = k.ctx + break + if isinstance(v, Capture) and not v._is_context_empty(): + self.ctx = v.ctx + break + + res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs) + var = CaptureVariable(None, ctx=self.ctx) + t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res) + self.ctx["operations"].append(t) + return var + + +class CaptureF(Capture): + def __init__(self, ctx=None, **kwargs): + if ctx is None: + self.ctx = {"operations": [], "variables": []} + else: + self.ctx = ctx + self.kwargs = kwargs + + +class CaptureA(CaptureF): + def __str__(self): + return f"{self.kwargs['name']}" + + def execute(self): + value = self.kwargs["real_attribute"] + return value + + +class CaptureLikeMock: + def __init__(self, name): + import unittest.mock as mock + + # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead. 
+ get_target, attribute = mock._get_target(name) # type: ignore[attr-defined] + self.get_target = get_target + self.attribute = attribute + self.name = name + + def __enter__(self): + self.save = getattr(self.get_target(), self.attribute) + capt = CaptureA(name=self.name, real_attribute=self.save) + setattr(self.get_target(), self.attribute, capt) + + def __exit__(self, *exc_info): + setattr(self.get_target(), self.attribute, self.save) + + +class CaptureCall(Capture): + def __init__(self, callable, ctx=None, **kwargs): + if ctx is None: + self.ctx = {"operations": [], "variables": []} + else: + self.ctx = ctx + self.kwargs = kwargs + self.callable = callable + + def __str__(self): + return "{callable}({args},{kwargs})".format( + callable=self.callable, **self.kwargs + ) + + def execute(self): + # TODO: VitalyFedyunin execute kwargs and maybe nested structures + executed_args = [] + for arg in self.kwargs["args"]: + if isinstance(arg, Capture): + executed_args.append(arg.execute()) + else: + executed_args.append(arg) + left = get_val(self.callable) + return left(*executed_args, **self.kwargs["kwargs"]) + + +class CaptureVariableAssign(CaptureF): + def __str__(self): + variable = self.kwargs["variable"] + value = self.kwargs["value"] + return f"{variable} = {value}" + + def execute(self): + self.kwargs["variable"].calculated_value = self.kwargs["value"].execute() + + +class CaptureVariable(Capture): + # TODO(VitalyFedyunin): This should be atomic and thread safe + names_idx = 0 + + def __init__(self, value, ctx): + if CaptureControl.disabled: + raise RuntimeError("Attempting to create capture variable with capture off") + self.ctx = ctx + self.value = value + self.name = f"var_{CaptureVariable.names_idx}" + CaptureVariable.names_idx += 1 + self.ctx["variables"].append(self) + + def __str__(self): + return self.name + + def execute(self): + return self.calculated_value + + def apply_ops(self, dataframe): + # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) + self.ctx["variables"][0].calculated_value = dataframe + for op in self.ctx["operations"]: + op.execute() + return self.calculated_value + + +class CaptureGetItem(Capture): + def __init__(self, left, key, ctx): + self.ctx = ctx + self.left = left + self.key = key + + def __str__(self): + return f"{self.left}[{get_val(self.key)}]" + + def execute(self): + left = self.left.execute() + return left[self.key] + + +class CaptureSetItem(Capture): + def __init__(self, left, key, value, ctx): + self.ctx = ctx + self.left = left + self.key = key + self.value = value + + def __str__(self): + return f"{self.left}[{get_val(self.key)}] = {self.value}" + + def execute(self): + left = self.left.execute() + value = self.value.execute() + left[self.key] = value + + +class CaptureAdd(Capture): + def __init__(self, left, right, ctx): + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self): + return f"{self.left} + {self.right}" + + def execute(self): + return get_val(self.left) + get_val(self.right) + + +class CaptureMul(Capture): + def __init__(self, left, right, ctx): + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self): + return f"{self.left} * {self.right}" + + def execute(self): + return get_val(self.left) * get_val(self.right) + + +class CaptureSub(Capture): + def __init__(self, left, right, ctx): + self.ctx = ctx + self.left = left + self.right = right + + def __str__(self): + return f"{self.left} - {self.right}" + + def execute(self): + return get_val(self.left) 
- get_val(self.right) + + +class CaptureGetAttr(Capture): + def __init__(self, src, name, ctx): + self.ctx = ctx + self.src = src + self.name = name + + def __str__(self): + return f"{self.src}.{self.name}" + + def execute(self): + val = get_val(self.src) + return getattr(val, self.name) + + +def get_val(capture): + if isinstance(capture, Capture): + return capture.execute() + elif isinstance(capture, str): + return f'"{capture}"' + else: + return capture + + +class CaptureInitial(CaptureVariable): + def __init__(self, schema_df=None): + new_ctx: dict[str, list[Any]] = { + "operations": [], + "variables": [], + "schema_df": schema_df, + } + super().__init__(None, new_ctx) + self.name = f"input_{self.name}" + + +class CaptureDataFrame(CaptureInitial): + pass + + +class CaptureDataFrameWithDataPipeOps(CaptureDataFrame): + def as_datapipe(self): + return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self) + + def raw_iterator(self): + return self.as_datapipe().__iter__() + + def __iter__(self): + return iter(self._dataframes_as_tuples()) + + def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF): + dp = self._dataframes_per_row()._dataframes_concat(batch_size) + dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class) + dp._dp_contains_dataframe = True + return dp + + def groupby( + self, + group_key_fn, + *, + buffer_size=10000, + group_size=None, + guaranteed_group_size=None, + drop_remaining=False, + ): + dp = self._dataframes_per_row() + dp = dp.as_datapipe().groupby( + group_key_fn, + buffer_size=buffer_size, + group_size=group_size, + guaranteed_group_size=guaranteed_group_size, + drop_remaining=drop_remaining, + ) + return dp + + def shuffle(self, *args, **kwargs): + return self._dataframes_shuffle(*args, **kwargs) + + def filter(self, *args, **kwargs): + return self._dataframes_filter(*args, **kwargs) + + def collate(self, *args, **kwargs): + raise RuntimeError("Can't collate unbatched DataFrames stream") + + def __getattr__(self, attrname): # ? 
+ if attrname in UNIMPLEMENTED_ATTR: + raise AttributeError("Attempting to get ", attrname) + if attrname in DATAPIPES_OPS: + return (self.as_datapipe()).__getattr__(attrname) + return super().__getattr__(attrname) + + +@functional_datapipe("trace_as_dataframe") +class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc] + source_datapipe: Optional[Any] = None + + # TODO(VitalyFedyunin): Must implement all special functions of datapipes + + def set_shuffle_settings(self, *args, **kwargs): + pass + + def is_shardable(self): + return False + + def __init__(self, source_datapipe, schema_df=None): + self.source_datapipe = source_datapipe + if schema_df is None: + schema_df = next(iter(self.source_datapipe)) + super().__init__(schema_df=schema_df) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py new file mode 100644 index 0000000000000000000000000000000000000000..e92186f274c94e51ded4c5d9fa449bba033f053b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/datapipes.py @@ -0,0 +1,136 @@ +# mypy: allow-untyped-defs +import random +from typing import Any + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe + + +__all__ = [ + "ConcatDataFramesPipe", + "DataFramesAsTuplesPipe", + "ExampleAggregateAsDataFrames", + "FilterDataFramesPipe", + "PerRowDataFramesPipe", + "ShuffleDataFramesPipe", +] + + +@functional_datapipe("_dataframes_as_tuples") +class DataFramesAsTuplesPipe(IterDataPipe): + def __init__(self, source_datapipe): + self.source_datapipe = source_datapipe + + def __iter__(self): + for df in self.source_datapipe: + # for record in df.to_records(index=False): + yield from df_wrapper.iterate(df) + + +@functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True) +class PerRowDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe): + self.source_datapipe = source_datapipe + + def __iter__(self): + for df in self.source_datapipe: + # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup + for i in range(len(df)): + yield df[i : i + 1] + + +@functional_datapipe("_dataframes_concat", enable_df_api_tracing=True) +class ConcatDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe, batch=3): + self.source_datapipe = source_datapipe + self.n_batch = batch + + def __iter__(self): + buffer = [] + for df in self.source_datapipe: + buffer.append(df) + if len(buffer) == self.n_batch: + yield df_wrapper.concat(buffer) + buffer = [] + if len(buffer): + yield df_wrapper.concat(buffer) + + +@functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True) +class ShuffleDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe): + self.source_datapipe = source_datapipe + + def __iter__(self): + size = None + all_buffer: list[Any] = [] + for df in self.source_datapipe: + if size is None: + size = df_wrapper.get_len(df) + all_buffer.extend( + df_wrapper.get_item(df, i) for i in range(df_wrapper.get_len(df)) + ) + random.shuffle(all_buffer) + buffer = [] + for df in all_buffer: + buffer.append(df) + if len(buffer) == size: + yield df_wrapper.concat(buffer) + buffer = [] + if len(buffer): + yield df_wrapper.concat(buffer) + + 
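The concat and shuffle pipes above share one buffering pattern: per-row frames are accumulated into a buffer, a concatenated chunk is emitted whenever the buffer reaches the target size, and any remainder is flushed once the source is exhausted. A minimal stand-alone sketch of that pattern follows, using plain lists in place of DataFrames and a hypothetical `chunked` helper name (neither is part of the patched module).

def chunked(rows, size):
    # Mirror the buffer/concat loop used by the DataFrame pipes above:
    # collect rows, emit a chunk once `size` rows are buffered, then flush
    # whatever remains when the source runs out.
    buffer = []
    for row in rows:
        buffer.append(row)
        if len(buffer) == size:
            yield list(buffer)  # the real pipes call df_wrapper.concat(buffer) here
            buffer = []
    if buffer:
        yield list(buffer)

# Example: list(chunked(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]
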
+@functional_datapipe("_dataframes_filter", enable_df_api_tracing=True) +class FilterDataFramesPipe(DFIterDataPipe): + def __init__(self, source_datapipe, filter_fn): + self.source_datapipe = source_datapipe + self.filter_fn = filter_fn + + def __iter__(self): + size = None + all_buffer = [] + filter_res = [] + for df in self.source_datapipe: + if size is None: + size = len(df.index) + for i in range(len(df.index)): + all_buffer.append(df[i : i + 1]) + filter_res.append(self.filter_fn(df.iloc[i])) + + buffer = [] + for df, res in zip(all_buffer, filter_res): + if res: + buffer.append(df) + if len(buffer) == size: + yield df_wrapper.concat(buffer) + buffer = [] + if len(buffer): + yield df_wrapper.concat(buffer) + + +@functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True) +class ExampleAggregateAsDataFrames(DFIterDataPipe): + def __init__(self, source_datapipe, dataframe_size=10, columns=None): + self.source_datapipe = source_datapipe + self.columns = columns + self.dataframe_size = dataframe_size + + def _as_list(self, item): + try: + return list(item) + except ( + Exception + ): # TODO(VitalyFedyunin): Replace with better iterable exception + return [item] + + def __iter__(self): + aggregate = [] + for item in self.source_datapipe: + aggregate.append(self._as_list(item)) + if len(aggregate) == self.dataframe_size: + yield df_wrapper.create_dataframe(aggregate, columns=self.columns) + aggregate = [] + if len(aggregate) > 0: + yield df_wrapper.create_dataframe(aggregate, columns=self.columns) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py new file mode 100644 index 0000000000000000000000000000000000000000..7d585a37587becb4edf7610d0a94f8dc2c975270 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/dataframe/structures.py @@ -0,0 +1,22 @@ +from collections.abc import Iterator +from typing import Any + +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import DataChunk + + +__all__ = ["DataChunkDF"] + + +class DataChunkDF(DataChunk): + """DataChunkDF iterating over individual items inside of DataFrame containers, to access DataFrames user `raw_iterator`.""" + + def __iter__(self) -> Iterator[Any]: + for df in self.items: + yield from df_wrapper.iterate(df) + + def __len__(self) -> int: + total_len = 0 + for df in self.items: + total_len += df_wrapper.get_len(df) + return total_len diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py new file mode 100644 index 0000000000000000000000000000000000000000..303d7ea9e0486dbee0de28575cb88e16ce061f90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.py @@ -0,0 +1,416 @@ +import functools +import pickle +from collections.abc import Iterable, Iterator +from typing import Callable, Optional, TypeVar + +from torch.utils._import_utils import import_dill +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta +from torch.utils.data.datapipes.utils.common import ( + _deprecation_warning, + _iter_deprecated_functional_names, + _map_deprecated_functional_names, +) +from torch.utils.data.dataset import Dataset, IterableDataset + + +dill = import_dill() +HAS_DILL = dill is not None + +__all__ = [ + "DataChunk", + 
"DFIterDataPipe", + "IterDataPipe", + "MapDataPipe", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +UNTRACABLE_DATAFRAME_PIPES = [ + "batch", # As it returns DataChunks + "groupby", # As it returns DataChunks + "_dataframes_as_tuples", # As it unpacks DF + "trace_as_dataframe", # As it used to mark DF for tracing +] + + +class DataChunk(list[_T]): + def __init__(self, items: Iterable[_T]) -> None: + items = list(items) + super().__init__(items) + self.items = items + + def as_str(self, indent: str = "") -> str: + return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]" + + def __iter__(self) -> Iterator[_T]: + yield from super().__iter__() + + def raw_iterator(self) -> Iterator[_T]: + yield from self.items + + +class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): + r""" + Iterable-style DataPipe. + + All DataPipes that represent an iterable of data samples should subclass this. + This style of DataPipes is particularly useful when data come from a stream, or + when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its + elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its + method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should + override ``reset()`` if necessary. The common usages include resetting buffers, pointers, + and various state variables within the custom ``IterDataPipe``. + + Note: + Only `one` iterator can be valid for each ``IterDataPipe`` at a time, + and the creation a second iterator will invalidate the first one. This constraint is necessary because + some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators. + The code example below presents details on how this constraint looks in practice. + If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_. + + These DataPipes can be invoked in two ways, using the class constructor or applying their + functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes). + You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple + operations in succession. + + .. _GitHub IterDataPipe Single Iterator Issue: + https://github.com/pytorch/data/issues/45 + + Note: + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the DataPipe object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. 
+ + Examples: + General Usage: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> dp = IterableWrapper(range(10)) + >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor + >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0) + >>> list(filter_dp) + [2, 4, 6, 8, 10] + Single Iterator Constraint Example: + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> source_dp = IterableWrapper(range(10)) + >>> it1 = iter(source_dp) + >>> list(it1) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + >>> it1 = iter(source_dp) + >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` + >>> next(it2) + 0 + >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` + """ + + functions: dict[str, Callable] = {} + reduce_ex_hook: Optional[Callable] = None + getstate_hook: Optional[Callable] = None + str_hook: Optional[Callable] = None + repr_hook: Optional[Callable] = None + _valid_iterator_id: Optional[int] = None + _number_of_samples_yielded: int = 0 + _snapshot_state: _SnapshotState = _SnapshotState.NotStarted + _fast_forward_iterator: Optional[Iterator] = None + + def __iter__(self) -> Iterator[_T_co]: + return self + + def __getattr__(self, attribute_name): + if attribute_name in IterDataPipe.functions: + if attribute_name in _iter_deprecated_functional_names: + kwargs = _iter_deprecated_functional_names[attribute_name] + _deprecation_warning(**kwargs) + f = IterDataPipe.functions[attribute_name] + function = functools.partial(f, self) + functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",)) + return function + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attribute_name}" + ) + + @classmethod + def register_function(cls, function_name, function): + cls.functions[function_name] = function + + @classmethod + def register_datapipe_as_function( + cls, function_name, cls_to_register, enable_df_api_tracing=False + ): + if function_name in cls.functions: + raise Exception( # noqa: TRY002 + f"Unable to add DataPipe function name {function_name} as it is already taken" + ) + + def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs): + result_pipe = cls(source_dp, *args, **kwargs) + if isinstance(result_pipe, IterDataPipe): + if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe): + if function_name not in UNTRACABLE_DATAFRAME_PIPES: + result_pipe = result_pipe.trace_as_dataframe() + + return result_pipe + + function = functools.partial( + class_function, cls_to_register, enable_df_api_tracing + ) + functools.update_wrapper( + wrapper=function, wrapped=cls_to_register, assigned=("__doc__",) + ) + cls.functions[function_name] = function + + def __getstate__(self): + """ + Serialize `lambda` functions when `dill` is available. + + If this doesn't cover your custom DataPipe's use case, consider writing custom methods for + `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization. 
+ """ + state = self.__dict__ + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __reduce_ex__(self, *args, **kwargs): + if IterDataPipe.reduce_ex_hook is not None: + try: + return IterDataPipe.reduce_ex_hook(self) + except NotImplementedError: + pass + return super().__reduce_ex__(*args, **kwargs) + + @classmethod + def set_getstate_hook(cls, hook_fn): + if IterDataPipe.getstate_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing getstate_hook") + IterDataPipe.getstate_hook = hook_fn + + @classmethod + def set_reduce_ex_hook(cls, hook_fn): + if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing reduce_ex_hook") + IterDataPipe.reduce_ex_hook = hook_fn + + def __repr__(self): + if self.repr_hook is not None: + return self.repr_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __str__(self): + if self.str_hook is not None: + return self.str_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __dir__(self): + # for auto-completion in a REPL (e.g. Jupyter notebook) + return list(super().__dir__()) + list(self.functions.keys()) + + def reset(self) -> None: + r""" + Reset the `IterDataPipe` to the initial state. + + By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities, + they may want to override this method with implementations that + may clear the buffers and reset pointers of the DataPipe. + The `reset` method is always called when `__iter__` is called as part of `hook_iterator`. + """ + + +class DFIterDataPipe(IterDataPipe): + def _is_dfpipe(self): + return True + + +class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): + r""" + Map-style DataPipe. + + All datasets that represent a map from keys to data samples should subclass this. + Subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given, unique key. Subclasses can also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. + + These DataPipes can be invoked in two ways, using the class constructor or applying their + functional form onto an existing `MapDataPipe` (recommend, available to most but not all DataPipes). + + Note: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + DataPipe with non-integral indices/keys, a custom sampler must be provided. 
+ + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> batch_dp = map_dp_1.batch(batch_size=2) + >>> list(batch_dp) + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] + """ + + functions: dict[str, Callable] = {} + reduce_ex_hook: Optional[Callable] = None + getstate_hook: Optional[Callable] = None + str_hook: Optional[Callable] = None + repr_hook: Optional[Callable] = None + + def __getattr__(self, attribute_name): + if attribute_name in MapDataPipe.functions: + if attribute_name in _map_deprecated_functional_names: + kwargs = _map_deprecated_functional_names[attribute_name] + _deprecation_warning(**kwargs) + f = MapDataPipe.functions[attribute_name] + function = functools.partial(f, self) + functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",)) + return function + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attribute_name}" + ) + + @classmethod + def register_function(cls, function_name, function): + cls.functions[function_name] = function + + @classmethod + def register_datapipe_as_function(cls, function_name, cls_to_register): + if function_name in cls.functions: + raise Exception( # noqa: TRY002 + f"Unable to add DataPipe function name {function_name} as it is already taken" + ) + + def class_function(cls, source_dp, *args, **kwargs): + result_pipe = cls(source_dp, *args, **kwargs) + return result_pipe + + function = functools.partial(class_function, cls_to_register) + functools.update_wrapper( + wrapper=function, wrapped=cls_to_register, assigned=("__doc__",) + ) + cls.functions[function_name] = function + + def __getstate__(self): + """ + Serialize `lambda` functions when `dill` is available. + + If this doesn't cover your custom DataPipe's use case, consider writing custom methods for + `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization. + """ + state = self.__dict__ + if MapDataPipe.getstate_hook is not None: + return MapDataPipe.getstate_hook(state) + return state + + def __reduce_ex__(self, *args, **kwargs): + if MapDataPipe.reduce_ex_hook is not None: + try: + return MapDataPipe.reduce_ex_hook(self) + except NotImplementedError: + pass + return super().__reduce_ex__(*args, **kwargs) + + @classmethod + def set_getstate_hook(cls, hook_fn): + if MapDataPipe.getstate_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing getstate_hook") + MapDataPipe.getstate_hook = hook_fn + + @classmethod + def set_reduce_ex_hook(cls, hook_fn): + if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None: + raise RuntimeError("Attempt to override existing reduce_ex_hook") + MapDataPipe.reduce_ex_hook = hook_fn + + def __repr__(self): + if self.repr_hook is not None: + return self.repr_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __str__(self): + if self.str_hook is not None: + return self.str_hook(self) + # Instead of showing , return the class name + return str(self.__class__.__qualname__) + + def __dir__(self): + # for auto-completion in a REPL (e.g. 
Jupyter notebook) + return list(super().__dir__()) + list(self.functions.keys()) + + +class _DataPipeSerializationWrapper: + def __init__(self, datapipe): + self._datapipe = datapipe + + def __getstate__(self): + use_dill = False + try: + value = pickle.dumps(self._datapipe) + except Exception: + if HAS_DILL: + value = dill.dumps(self._datapipe) + use_dill = True + else: + raise + return (value, use_dill) + + def __setstate__(self, state): + value, use_dill = state + if use_dill: + self._datapipe = dill.loads(value) + else: + self._datapipe = pickle.loads(value) + + def __len__(self): + try: + return len(self._datapipe) + except Exception as e: + raise TypeError( + f"{type(self).__name__} instance doesn't have valid length" + ) from e + + +class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): + def __init__(self, datapipe: IterDataPipe[_T_co]): + super().__init__(datapipe) + self._datapipe_iter: Optional[Iterator[_T_co]] = None + + def __iter__(self) -> "_IterDataPipeSerializationWrapper": + self._datapipe_iter = iter(self._datapipe) + return self + + def __next__(self) -> _T_co: # type: ignore[type-var] + assert self._datapipe_iter is not None + return next(self._datapipe_iter) + + +class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe): + def __getitem__(self, idx): + return self._datapipe[idx] diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi b/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f8dc1887b02a14c032fe67b2b9fdd08793a3eb90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/datapipe.pyi @@ -0,0 +1,726 @@ +# @generated by torch/utils/data/datapipes/gen_pyi.py from datapipe.pyi.in +# mypy: allow-untyped-defs +# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection +# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt +# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other +# classes/objects here, even though we are not injecting extra code into them at the moment. + +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Literal, Optional, TypeVar, Union + +from torch.utils.data import Dataset, default_collate, IterableDataset +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +UNTRACABLE_DATAFRAME_PIPES: Any + +class DataChunk(list[_T]): + items: list[_T] + def __init__(self, items: Iterable[_T]) -> None: ... + def as_str(self, indent: str = "") -> str: ... + def __iter__(self) -> Iterator[_T]: ... + def raw_iterator(self) -> Iterator[_T]: ... + +class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): + functions: dict[str, Callable] = ... + reduce_ex_hook: Callable | None = ... + getstate_hook: Callable | None = ... + str_hook: Callable | None = ... + repr_hook: Callable | None = ... + def __getattr__(self, attribute_name: Any): ... + @classmethod + def register_function(cls, function_name: Any, function: Any) -> None: ... + @classmethod + def register_datapipe_as_function( + cls, + function_name: Any, + cls_to_register: Any, + ): ... + def __getstate__(self): ... + def __reduce_ex__(self, *args: Any, **kwargs: Any): ... 
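A rough sketch of how the serialization wrappers above behave (these classes are internal, and the ``dill`` fallback only applies when ``dill`` is installed):

import pickle
from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper
from torch.utils.data.datapipes.iter import IterableWrapper

def add_one(x):
    return x + 1

dp = IterableWrapper(range(5)).map(add_one)
wrapped = _IterDataPipeSerializationWrapper(dp)

# __getstate__ first tries pickle.dumps on the wrapped pipe and only falls back
# to dill when that raises, so a pipeline built from plain functions round-trips
# with the standard-library pickle alone.
restored = pickle.loads(pickle.dumps(wrapped))
print(list(restored))  # [1, 2, 3, 4, 5]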
+ @classmethod + def set_getstate_hook(cls, hook_fn: Any) -> None: ... + @classmethod + def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ... + # Functional form of 'BatcherMapDataPipe' + def batch( + self, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> MapDataPipe: + r""" + Create mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, + or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> batch_dp = dp.batch(batch_size=2) + >>> list(batch_dp) + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] + """ + # Functional form of 'ConcaterMapDataPipe' + def concat(self, *datapipes: MapDataPipe) -> MapDataPipe: + r""" + Concatenate multiple Map DataPipes (functional name: ``concat``). + + The new index of is the cumulative sum of source DataPipes. + For example, if there are 2 source DataPipes both with length 5, + index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to + elements of the first DataPipe, and 5 to 9 would refer to elements + of the second DataPipe. + + Args: + datapipes: Map DataPipes being concatenated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(3)) + >>> concat_dp = dp1.concat(dp2) + >>> list(concat_dp) + [0, 1, 2, 0, 1, 2] + """ + # Functional form of 'MapperMapDataPipe' + def map(self, fn: Callable = ...) -> MapDataPipe: + r""" + Apply the input function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source MapDataPipe + fn: Function being applied to each item + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + # Functional form of 'ShufflerIterDataPipe' + def shuffle(self, *, indices: Optional[list] = None) -> IterDataPipe: + r""" + Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``). + + When it is used with :class:`~torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed + for each worker process. + + Args: + datapipe: MapDataPipe being shuffled + indices: a list of indices of the MapDataPipe. 
If not provided, we assume it uses 0-based indexing + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> shuffle_dp = dp.shuffle().set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + >>> list(shuffle_dp) + [6, 1, 9, 5, 2, 4, 7, 3, 8, 0] + >>> # Reset seed for Shuffler + >>> shuffle_dp = shuffle_dp.set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + + Note: + Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an + ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to + the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order + of data during data-processing. + """ + # Functional form of 'ZipperMapDataPipe' + def zip(self, *datapipes: MapDataPipe[_T_co]) -> MapDataPipe: + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Map DataPipes being aggregated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(10, 13)) + >>> zip_dp = dp1.zip(dp2) + >>> list(zip_dp) + [(0, 10), (1, 11), (2, 12)] + """ + +class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): + functions: dict[str, Callable] = ... + reduce_ex_hook: Optional[Callable] = ... + getstate_hook: Optional[Callable] = ... + str_hook: Optional[Callable] = ... + repr_hook: Optional[Callable] = ... + _number_of_samples_yielded: int = ... + _snapshot_state: _SnapshotState = _SnapshotState.Iterating # noqa: PYI015 + _fast_forward_iterator: Optional[Iterator] = ... + def __getattr__(self, attribute_name: Any): ... + @classmethod + def register_function(cls, function_name: Any, function: Any) -> None: ... + @classmethod + def register_datapipe_as_function( + cls, + function_name: Any, + cls_to_register: Any, + enable_df_api_tracing: bool = ..., + ): ... + def __getstate__(self): ... + def __reduce_ex__(self, *args: Any, **kwargs: Any): ... + @classmethod + def set_getstate_hook(cls, hook_fn: Any) -> None: ... + @classmethod + def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ... + # Functional form of 'BatcherIterDataPipe' + def batch( + self, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> IterDataPipe: + r""" + Creates mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the + last batch if ``drop_last`` is set to ``False``. 
+ + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding, + defaults to ``DataChunk`` + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> dp = dp.batch(batch_size=3, drop_last=True) + >>> list(dp) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + # Functional form of 'CollatorIterDataPipe' + def collate( + self, + conversion: Union[Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None] = default_collate, + collate_fn: Optional[Callable] = None, + ) -> IterDataPipe: # fmt: skip + r""" + Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``). + + By default, it uses :func:`torch.utils.data.default_collate`. + + .. note:: + While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the + default behavior and `functools.partial` to specify any additional arguments. + + Args: + datapipe: Iterable DataPipe being collated + collate_fn: Customized collate function to collect and combine data or a batch of data. + Default function collates to Tensor(s) based on data type. + + Example: + >>> # xdoctest: +SKIP + >>> # Convert integer data to float Tensor + >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): + ... def __init__(self, start, end): + ... super(MyIterDataPipe).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + ... def __len__(self): + ... return self.end - self.start + ... + >>> ds = MyIterDataPipe(start=3, end=7) + >>> print(list(ds)) + [3, 4, 5, 6] + >>> def collate_fn(batch): + ... return torch.tensor(batch, dtype=torch.float) + ... + >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) + >>> print(list(collated_ds)) + [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] + """ + # Functional form of 'ConcaterIterDataPipe' + def concat(self, *datapipes: IterDataPipe) -> IterDataPipe: + r""" + Concatenates multiple Iterable DataPipes (functional name: ``concat``). + + The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones. + + Args: + datapipes: Iterable DataPipes being concatenated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> import random + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1 = IterableWrapper(range(3)) + >>> dp2 = IterableWrapper(range(5)) + >>> list(dp1.concat(dp2)) + [0, 1, 2, 0, 1, 2, 3, 4] + """ + # Functional form of 'DemultiplexerIterDataPipe' + def demux( + self, + num_instances: int, + classifier_fn: Callable[[_T_co], Optional[int]], + drop_none: bool = False, + buffer_size: int = 1000, + ) -> list[IterDataPipe]: + r""" + Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``). + + A list of the child DataPipes is returned from this operation. 
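As the note under ``collate`` above suggests, ``functools.partial`` can stand in for a hand-written ``collate_fn``; a small sketch:

import functools
import torch
from torch.utils.data.datapipes.iter import IterableWrapper

ds = IterableWrapper(range(3, 7))
collated = ds.collate(collate_fn=functools.partial(torch.tensor, dtype=torch.float32))
print(list(collated))  # [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]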
+ + Args: + datapipe: Iterable DataPipe being filtered + num_instances: number of instances of the DataPipe to create + classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None`` + drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None`` + buffer_size: this defines the maximum number of inputs that the buffer can hold across all child + DataPipes while waiting for their values to be yielded. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + + Examples: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def odd_or_even(n): + ... return n % 2 + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even) + >>> list(dp1) + [0, 2, 4] + >>> list(dp2) + [1, 3] + >>> # It can also filter out any element that gets `None` from the `classifier_fn` + >>> def odd_or_even_no_zero(n): + ... return n % 2 if n != 0 else None + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) + >>> list(dp1) + [2, 4] + >>> list(dp2) + [1, 3] + """ + # Functional form of 'FilterIterDataPipe' + def filter(self, filter_fn: Callable, input_col=None) -> IterDataPipe: + r""" + Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``). + + Args: + datapipe: Iterable DataPipe being filtered + filter_fn: Customized function mapping an element to a boolean. + input_col: Index or indices of data which ``filter_fn`` is applied, such as: + + - ``None`` as default to apply ``filter_fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def is_even(n): + ... return n % 2 == 0 + >>> dp = IterableWrapper(range(5)) + >>> filter_dp = dp.filter(filter_fn=is_even) + >>> list(filter_dp) + [0, 2, 4] + """ + # Functional form of 'ForkerIterDataPipe' + def fork( + self, + num_instances: int, + buffer_size: int = 1000, + copy: Optional[Literal["shallow", "deep"]] = None, + ) -> list[IterDataPipe]: + r""" + Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``). + + Args: + datapipe: Iterable DataPipe being copied + num_instances: number of instances of the datapipe to create + buffer_size: this restricts how far ahead the leading child DataPipe + can read relative to the slowest child DataPipe. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + copy: copy strategy to use for items yielded by each branch. Supported + options are ``None`` for no copying, ``"shallow"`` for shallow object + copies, and ``"deep"`` for deep object copies. Defaults to ``None``. + + Note: + All branches of the forked pipeline return the identical object unless + the copy parameter is supplied. If the object is mutable or contains + mutable objects, changing them in one branch will affect all others. 
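The note above about ``fork`` handing identical objects to every branch can be sketched as follows (the dictionaries are illustrative data):

from torch.utils.data.datapipes.iter import IterableWrapper

source = IterableWrapper([{"count": 0}])
dp1, dp2 = source.fork(num_instances=2)  # default copy=None: branches share objects
item1 = next(iter(dp1))
item2 = next(iter(dp2))
item1["count"] += 1
print(item2["count"])  # 1, because both branches received the same dict object

# With copy="deep", each branch yields its own deep copy, so a mutation in one
# branch is no longer visible in the other.
dp3, dp4 = IterableWrapper([{"count": 0}]).fork(num_instances=2, copy="deep")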
+ + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.fork(num_instances=2) + >>> list(dp1) + [0, 1, 2, 3, 4] + >>> list(dp2) + [0, 1, 2, 3, 4] + """ + # Functional form of 'GrouperIterDataPipe' + def groupby( + self, + group_key_fn: Callable[[_T_co], Any], + *, + keep_key: bool = False, + buffer_size: int = 10000, + group_size: Optional[int] = None, + guaranteed_group_size: Optional[int] = None, + drop_remaining: bool = False, + ) -> IterDataPipe: + r""" + Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``. + + (functional name: ``groupby``). + + The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group + will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full, + the DataPipe will yield the largest batch with the same key, provided that its size is larger + than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``. + + After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity + will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``. + + Args: + datapipe: Iterable datapipe to be grouped + group_key_fn: Function used to generate group key from the data of the source datapipe + keep_key: Option to yield the matching key along with the items in a tuple, + resulting in `(key, [items])` otherwise returning [items] + buffer_size: The size of buffer for ungrouped data + group_size: The max size of each group, a batch is yielded as soon as it reaches this size + guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full + drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer + when the buffer is full + + Example: + >>> import os + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def group_fn(file): + ... return os.path.basename(file).split(".")[0] + >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) + >>> dp0 = source_dp.groupby(group_key_fn=group_fn) + >>> list(dp0) + [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] + >>> # A group is yielded as soon as its size equals to `group_size` + >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2) + >>> list(dp1) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` + >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) + >>> list(dp2) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + """ + # Functional form of 'FileListerIterDataPipe' + def list_files( + self, + masks: Union[str, list[str]] = "", + *, + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, + length: int = -1, + ) -> IterDataPipe: + r""" + Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory. + + Multiple root directories can be provided (functional name: ``list_files``). 
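``groupby`` above also accepts ``keep_key=True`` to yield ``(key, [items])`` pairs; a short sketch with an illustrative key function:

import os
from torch.utils.data.datapipes.iter import IterableWrapper

def stem(path):
    return os.path.basename(path).split(".")[0]

source_dp = IterableWrapper(["a.png", "a.json", "b.png"])
dp = source_dp.groupby(group_key_fn=stem, keep_key=True)
print(list(dp))  # [('a', ['a.png', 'a.json']), ('b', ['b.png'])]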
+ + Args: + root: Root directory or a sequence of root directories + masks: Unix style filter string or string list for filtering file name(s) + recursive: Whether to return pathname from nested directories or not + abspath: Whether to return relative pathname or absolute pathname + non_deterministic: Whether to return pathname in sorted order or not. + If ``False``, the results yielded from each root directory will be sorted + length: Nominal length of the datapipe + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister + >>> dp = FileLister(root=".", recursive=True) + >>> list(dp) + ['example.py', './data/data.tar'] + """ + # Functional form of 'MapperIterDataPipe' + def map( + self, + fn: Callable, + input_col=None, + output_col=None, + ) -> IterDataPipe: + r""" + Applies a function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source Iterable DataPipe + fn: Function being applied over each item + input_col: Index or indices of data which ``fn`` is applied, such as: + + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified + only when ``input_col`` is not ``None`` + + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with + multiple indices, the left-most one is used, and other indices will be removed. + - Integer is used for list/tuple. ``-1`` represents to append result at the end. + - Key is used for dict. New key is acceptable. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = IterableWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` + >>> # Use `functools.partial` or explicitly define the function instead + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + # Functional form of 'MultiplexerIterDataPipe' + def mux(self, *datapipes) -> IterDataPipe: + r""" + Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``). + + As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, + and so on. It ends when the shortest input DataPipe is exhausted. + + Args: + datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> list(dp1.mux(dp2, dp3)) + [0, 10, 20, 1, 11, 21, 2, 12, 22] + """ + # Functional form of 'FileOpenerIterDataPipe' + def open_files( + self, + mode: str = "r", + encoding: Optional[str] = None, + length: int = -1, + ) -> IterDataPipe: + r""" + Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``). 
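The ``input_col``/``output_col`` arguments of ``map`` described above operate on structured samples; a minimal sketch over tuples:

from torch.utils.data.datapipes.iter import IterableWrapper

def times_ten(x):
    return x * 10

dp = IterableWrapper([("a", 1), ("b", 2)])
print(list(dp.map(times_ten, input_col=1)))                 # [('a', 10), ('b', 20)]
print(list(dp.map(times_ten, input_col=1, output_col=-1)))  # [('a', 1, 10), ('b', 2, 20)]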
+ + Args: + datapipe: Iterable datapipe that provides pathnames + mode: An optional string that specifies the mode in which + the file is opened by ``open()``. It defaults to ``r``, other options are + ``b`` for reading in binary mode and ``t`` for text mode. + encoding: An optional string that specifies the encoding of the + underlying file. It defaults to ``None`` to match the default encoding of ``open``. + length: Nominal length of the datapipe + + Note: + The opened file handles will be closed by Python's GC periodically. Users can choose + to close them explicitly. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) + >>> dp = FileOpener(dp) + >>> dp = StreamReader(dp) + >>> list(dp) + [('./abc.txt', 'abc')] + """ + # Functional form of 'StreamReaderIterDataPipe' + def read_from_stream(self, chunk: Optional[int] = None) -> IterDataPipe: + r""" + Given IO streams and their label names, yield bytes with label name as tuple. + + (functional name: ``read_from_stream``). + + Args: + datapipe: Iterable DataPipe provides label/URL and byte stream + chunk: Number of bytes to be read from stream per iteration. + If ``None``, all bytes will be read until the EOF. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader + >>> from io import StringIO + >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))]) + >>> list(StreamReader(dp, chunk=1)) + [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')] + """ + # Functional form of 'RoutedDecoderIterDataPipe' + def routed_decode( + self, + *handlers: Callable, + key_fn: Callable = ..., + ) -> IterDataPipe: + r""" + Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple. + + (functional name: ``routed_decode``) + + Args: + datapipe: Iterable datapipe that provides pathname and binary stream in tuples + handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder + handlers will be set as default. If multiple handles are provided, the priority + order follows the order of handlers (the first handler has the top priority) + key_fn: Function for decoder to extract key from pathname to dispatch handlers. + Default is set to extract file extension from pathname + + Note: + When ``key_fn`` is specified returning anything other than extension, the default + handler will not work and users need to specify custom handler. Custom handler + could use regex to determine the eligibility to handle data. + """ + # Functional form of 'ShardingFilterIterDataPipe' + def sharding_filter(self, sharding_group_filter=None) -> IterDataPipe: + r""" + Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``). + + After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the + original DataPipe, where `n` equals to the number of instances. + + Args: + source_datapipe: Iterable DataPipe that will be sharded + """ + # Functional form of 'ShufflerIterDataPipe' + def shuffle( + self, + *, + buffer_size: int = 10000, + unbatch_level: int = 0, + ) -> IterDataPipe: + r""" + Shuffle the input DataPipe with a buffer (functional name: ``shuffle``). + + The buffer with ``buffer_size`` is filled with elements from the datapipe first. 
Then, + each item will be yielded from the buffer by reservoir sampling via iterator. + + ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the + datapipe is not shuffled. In order to fully shuffle all elements from datapipe, + ``buffer_size`` is required to be greater than or equal to the size of datapipe. + + When it is used with :class:`torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed + for each worker process. + + Args: + datapipe: The IterDataPipe being shuffled + buffer_size: The buffer size for shuffling (default to ``10000``) + unbatch_level: Specifies if it is necessary to unbatch source data before + applying the shuffle + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> shuffle_dp = dp.shuffle() + >>> list(shuffle_dp) + [0, 4, 1, 6, 3, 2, 9, 5, 7, 8] + """ + # Functional form of 'UnBatcherIterDataPipe' + def unbatch(self, unbatch_level: int = 1) -> IterDataPipe: + r""" + Undos batching of data (functional name: ``unbatch``). + + In other words, it flattens the data up to the specified level within a batched DataPipe. + + Args: + datapipe: Iterable DataPipe being un-batched + unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``, + it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]]) + >>> dp1 = source_dp.unbatch() + >>> list(dp1) + [[0, 1], [2], [3, 4], [5], [6]] + >>> dp2 = source_dp.unbatch(unbatch_level=2) + >>> list(dp2) + [0, 1, 2, 3, 4, 5, 6] + """ + # Functional form of 'ZipperIterDataPipe' + def zip(self, *datapipes: IterDataPipe) -> IterDataPipe: + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + The output is stopped as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Iterable DataPipes being aggregated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> list(dp1.zip(dp2, dp3)) + [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] + """ + +class DFIterDataPipe(IterDataPipe): + def _is_dfpipe(self): ... + def __iter__(self): ... + +class _DataPipeSerializationWrapper: + def __init__(self, datapipe): ... + def __getstate__(self): ... + def __setstate__(self, state): ... + def __len__(self): ... + +class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): + def __iter__(self): ... + +class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe): + def __getitem__(self, idx): ... 
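Because every functional form above returns another DataPipe, the generated methods compose into pipelines; a small end-to-end sketch:

from torch.utils.data.datapipes.iter import IterableWrapper

def is_even(n):
    return n % 2 == 0

def square(n):
    return n * n

dp = (
    IterableWrapper(range(10))
    .filter(filter_fn=is_even)  # 0, 2, 4, 6, 8
    .map(square)                # 0, 4, 16, 36, 64
    .batch(batch_size=2)        # [0, 4], [16, 36], [64]
)
print(list(dp))  # [[0, 4], [16, 36], [64]]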
diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b2265da2d5c37bbec50a4e4f19d30d207b1d62 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/gen_pyi.py @@ -0,0 +1,336 @@ +# mypy: allow-untyped-defs +import os +from collections import defaultdict +from pathlib import Path +from typing import Any, Union +from typing_extensions import deprecated + + +try: + from torchgen.api.python import format_function_signature + from torchgen.utils import FileManager as FileManager +except ImportError: + import sys + + REPO_ROOT = Path(__file__).absolute().parents[4] + sys.path.insert(0, str(REPO_ROOT)) + + from torchgen.api.python import format_function_signature + from torchgen.utils import FileManager + + if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT): + del sys.path[0] + + +__all__: list[str] = [] # not intended to expose any symbols + + +def __dir__() -> list[str]: + return [] # appease public API test + + +@deprecated( + "`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.", + category=FutureWarning, +) +def materialize_lines(lines: list[str], indentation: int) -> str: + output = "" + new_line_with_indent = "\n" + " " * indentation + for i, line in enumerate(lines): + if i != 0: + output += new_line_with_indent + output += line.replace("\n", new_line_with_indent) + return output + + +@deprecated( + "`torch.utils.data.datapipes.gen_pyi.gen_from_template` is deprecated and will be removed in the future.", + category=FutureWarning, +) +def gen_from_template( + dir: str, + template_name: str, + output_name: str, + replacements: list[tuple[str, Any, int]], +): + template_path = os.path.join(dir, template_name) + output_path = os.path.join(dir, output_name) + + with open(template_path, encoding="utf-8") as f: + content = f.read() + for placeholder, lines, indentation in replacements: + with open(output_path, "w", encoding="utf-8") as f: + content = content.replace( + placeholder, materialize_lines(lines, indentation) + ) + f.write(content) + + +def find_file_paths(dir_paths: list[str], files_to_exclude: set[str]) -> set[str]: + """ + When given a path to a directory, returns the paths to the relevant files within it. + + This function does NOT recursive traverse to subdirectories. 
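For reference, the (deprecated) ``materialize_lines`` helper above joins template lines, indenting every line after the first; a quick sketch of its output (calling it emits a ``FutureWarning``):

from torch.utils.data.datapipes.gen_pyi import materialize_lines

text = materialize_lines(["def foo(): ...", "def bar(): ..."], indentation=4)
print(text)
# def foo(): ...
#     def bar(): ...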
+ """ + paths: set[str] = set() + for dir_path in dir_paths: + all_files = os.listdir(dir_path) + python_files = {fname for fname in all_files if ".py" == fname[-3:]} + filter_files = { + fname for fname in python_files if fname not in files_to_exclude + } + paths.update({os.path.join(dir_path, fname) for fname in filter_files}) + return paths + + +def extract_method_name(line: str) -> str: + """Extract method name from decorator in the form of "@functional_datapipe({method_name})".""" + if '("' in line: + start_token, end_token = '("', '")' + elif "('" in line: + start_token, end_token = "('", "')" + else: + raise RuntimeError( + f"Unable to find appropriate method name within line:\n{line}" + ) + start, end = line.find(start_token) + len(start_token), line.find(end_token) + return line[start:end] + + +def extract_class_name(line: str) -> str: + """Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):".""" + start_token = "class " + end_token = "(" + start, end = line.find(start_token) + len(start_token), line.find(end_token) + return line[start:end] + + +def parse_datapipe_file( + file_path: str, +) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]: + """Given a path to file, parses the file and returns a dictionary of method names to function signatures.""" + method_to_signature, method_to_class_name, special_output_type = {}, {}, set() + doc_string_dict = defaultdict(list) + with open(file_path, encoding="utf-8") as f: + open_paren_count = 0 + method_name, class_name, signature = "", "", "" + skip = False + for line in f: + if line.count('"""') % 2 == 1: + skip = not skip + if skip or '"""' in line: # Saving docstrings + doc_string_dict[method_name].append(line) + continue + if "@functional_datapipe" in line: + method_name = extract_method_name(line) + doc_string_dict[method_name] = [] + continue + if method_name and "class " in line: + class_name = extract_class_name(line) + continue + if method_name and ("def __init__(" in line or "def __new__(" in line): + if "def __new__(" in line: + special_output_type.add(method_name) + open_paren_count += 1 + start = line.find("(") + len("(") + line = line[start:] + if open_paren_count > 0: + open_paren_count += line.count("(") + open_paren_count -= line.count(")") + if open_paren_count == 0: + end = line.rfind(")") + signature += line[:end] + method_to_signature[method_name] = process_signature(signature) + method_to_class_name[method_name] = class_name + method_name, class_name, signature = "", "", "" + elif open_paren_count < 0: + raise RuntimeError( + "open parenthesis count < 0. This shouldn't be possible." 
+ ) + else: + signature += line.strip() + return ( + method_to_signature, + method_to_class_name, + special_output_type, + doc_string_dict, + ) + + +def parse_datapipe_files( + file_paths: set[str], +) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]: + methods_and_signatures = {} + methods_and_class_names = {} + methods_with_special_output_types = set() + methods_and_doc_strings = {} + for path in file_paths: + ( + method_to_signature, + method_to_class_name, + methods_needing_special_output_types, + doc_string_dict, + ) = parse_datapipe_file(path) + methods_and_signatures.update(method_to_signature) + methods_and_class_names.update(method_to_class_name) + methods_with_special_output_types.update(methods_needing_special_output_types) + methods_and_doc_strings.update(doc_string_dict) + return ( + methods_and_signatures, + methods_and_class_names, + methods_with_special_output_types, + methods_and_doc_strings, + ) + + +def split_outside_bracket(line: str, delimiter: str = ",") -> list[str]: + """Given a line of text, split it on comma unless the comma is within a bracket '[]'.""" + bracket_count = 0 + curr_token = "" + res = [] + for char in line: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + elif char == delimiter and bracket_count == 0: + res.append(curr_token) + curr_token = "" + continue + curr_token += char + res.append(curr_token) + return res + + +def process_signature(line: str) -> list[str]: + """ + Clean up a given raw function signature. + + This includes removing the self-referential datapipe argument, default + arguments of input functions, newlines, and spaces. + """ + tokens: list[str] = split_outside_bracket(line) + for i, token in enumerate(tokens): + tokens[i] = token.strip(" ") + if token == "cls": + tokens[i] = "self" + elif i > 0 and ("self" == tokens[i - 1]) and (tokens[i][0] != "*"): + # Remove the datapipe after 'self' or 'cls' unless it has '*' + tokens[i] = "" + elif "Callable =" in token: # Remove default argument if it is a function + head = token.rpartition("=")[0] + tokens[i] = head.strip(" ") + " = ..." + tokens = [t for t in tokens if t != ""] + return tokens + + +def get_method_definitions( + file_path: Union[str, list[str]], + files_to_exclude: set[str], + deprecated_files: set[str], + default_output_type: str, + method_to_special_output_type: dict[str, str], + root: str = "", +) -> list[str]: + """ + #.pyi generation for functional DataPipes Process. + + # 1. Find files that we want to process (exclude the ones who don't) + # 2. Parse method name and signature + # 3. 
Remove first argument after self (unless it is "*datapipes"), default args, and spaces + """ + if root == "": + root = str(Path(__file__).parent.resolve()) + file_path = [file_path] if isinstance(file_path, str) else file_path + file_path = [os.path.join(root, path) for path in file_path] + file_paths = find_file_paths( + file_path, files_to_exclude=files_to_exclude.union(deprecated_files) + ) + ( + methods_and_signatures, + methods_and_class_names, + methods_w_special_output_types, + methods_and_doc_strings, + ) = parse_datapipe_files(file_paths) + + for fn_name in method_to_special_output_type: + if fn_name not in methods_w_special_output_types: + methods_w_special_output_types.add(fn_name) + + method_definitions = [] + for method_name, arguments in methods_and_signatures.items(): + class_name = methods_and_class_names[method_name] + if method_name in methods_w_special_output_types: + output_type = method_to_special_output_type[method_name] + else: + output_type = default_output_type + doc_string = "".join(methods_and_doc_strings[method_name]) + if doc_string == "": + doc_string = " ..." + else: + doc_string = "\n" + doc_string + definition = format_function_signature(method_name, arguments, output_type) + method_definitions.append( + f"# Functional form of '{class_name}'\n" + + definition.removesuffix("...").rstrip() # remove "..." + + doc_string, + ) + method_definitions.sort( + key=lambda s: s.split("\n")[1] + ) # sorting based on method_name + + return method_definitions + + +# Defined outside of main() so they can be imported by TorchData +iterDP_file_path: str = "iter" +iterDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"} +iterDP_deprecated_files: set[str] = set() +iterDP_method_to_special_output_type: dict[str, str] = { + "demux": "list[IterDataPipe]", + "fork": "list[IterDataPipe]", +} + +mapDP_file_path: str = "map" +mapDP_files_to_exclude: set[str] = {"__init__.py", "utils.py"} +mapDP_deprecated_files: set[str] = set() +mapDP_method_to_special_output_type: dict[str, str] = {"shuffle": "IterDataPipe"} + + +def main() -> None: + """ + # Inject file into template datapipe.pyi.in. + + TODO: The current implementation of this script only generates interfaces for built-in methods. To generate + interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`. 
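To make the signature-cleaning step above concrete, here is a sketch of the internal helpers ``split_outside_bracket`` and ``process_signature`` applied to a raw ``__init__`` signature (the argument names in the string are illustrative):

from torch.utils.data.datapipes.gen_pyi import process_signature, split_outside_bracket

raw = "self, datapipe, fn: Callable = _default_fn, input_col=None"
print(split_outside_bracket(raw))
# ['self', ' datapipe', ' fn: Callable = _default_fn', ' input_col=None']
print(process_signature(raw))
# ['self', 'fn: Callable = ...', 'input_col=None']
# The datapipe argument after `self` is dropped and the callable default collapses to `...`.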
+ """ + iter_method_definitions = get_method_definitions( + iterDP_file_path, + iterDP_files_to_exclude, + iterDP_deprecated_files, + "IterDataPipe", + iterDP_method_to_special_output_type, + ) + + map_method_definitions = get_method_definitions( + mapDP_file_path, + mapDP_files_to_exclude, + mapDP_deprecated_files, + "MapDataPipe", + mapDP_method_to_special_output_type, + ) + + path = Path(__file__).absolute().parent + fm = FileManager(install_dir=path, template_dir=path, dry_run=False) + fm.write_with_template( + "datapipe.pyi", + "datapipe.pyi.in", + lambda: { + "IterDataPipeMethods": iter_method_definitions, + "MapDataPipeMethods": map_method_definitions, + }, + ) + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa808883b8da27f6776d16dd0fbee75ac66e6749 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__init__.py @@ -0,0 +1,65 @@ +from torch.utils.data.datapipes.iter.callable import ( + CollatorIterDataPipe as Collator, + MapperIterDataPipe as Mapper, +) +from torch.utils.data.datapipes.iter.combinatorics import ( + SamplerIterDataPipe as Sampler, + ShufflerIterDataPipe as Shuffler, +) +from torch.utils.data.datapipes.iter.combining import ( + ConcaterIterDataPipe as Concater, + DemultiplexerIterDataPipe as Demultiplexer, + ForkerIterDataPipe as Forker, + MultiplexerIterDataPipe as Multiplexer, + ZipperIterDataPipe as Zipper, +) +from torch.utils.data.datapipes.iter.filelister import ( + FileListerIterDataPipe as FileLister, +) +from torch.utils.data.datapipes.iter.fileopener import ( + FileOpenerIterDataPipe as FileOpener, +) +from torch.utils.data.datapipes.iter.grouping import ( + BatcherIterDataPipe as Batcher, + GrouperIterDataPipe as Grouper, + UnBatcherIterDataPipe as UnBatcher, +) +from torch.utils.data.datapipes.iter.routeddecoder import ( + RoutedDecoderIterDataPipe as RoutedDecoder, +) +from torch.utils.data.datapipes.iter.selecting import FilterIterDataPipe as Filter +from torch.utils.data.datapipes.iter.sharding import ( + ShardingFilterIterDataPipe as ShardingFilter, +) +from torch.utils.data.datapipes.iter.streamreader import ( + StreamReaderIterDataPipe as StreamReader, +) +from torch.utils.data.datapipes.iter.utils import ( + IterableWrapperIterDataPipe as IterableWrapper, +) + + +__all__ = [ + "Batcher", + "Collator", + "Concater", + "Demultiplexer", + "FileLister", + "FileOpener", + "Filter", + "Forker", + "Grouper", + "IterableWrapper", + "Mapper", + "Multiplexer", + "RoutedDecoder", + "Sampler", + "ShardingFilter", + "Shuffler", + "StreamReader", + "UnBatcher", + "Zipper", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cb759382f4ce5da005b482e832c78faafb2e39 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2d0d6f98633e4e826ef371c908b41b9c42cc7f5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..917188838bf735df9190d5467161a9c1933bcccd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb1533c1560c84c9d4489854d2cb6da8d0896175 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/combining.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c534ac921c9a4f150947fdc77bfa397c78233ca3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15a4d32265c272f8a7a313548a76e0aceba2c9ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f21ab8fb9616da6243abec96f987fdb7b820dbe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33321a7fc8c34ff59815601244ddc339b2295c27 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3c051cb1bdfa6ecd12ee798b82a81faf1fb098 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..914f0b13dc9a84ac63c5ef8bb149244430ebf034 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdbf909d8814ab7f2bf4b1b34e97e32008b919ab Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e33b05a8e61a39401b594f7d20f1616d02618879 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/callable.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/callable.py new file mode 100644 index 0000000000000000000000000000000000000000..740bd0e7024940347133c8d8212689e1fc60ab5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/callable.py @@ -0,0 +1,242 @@ +# mypy: allow-untyped-defs +import functools +from collections import namedtuple +from collections.abc import Iterator, Sized +from typing import Any, Callable, Optional, TypeVar, Union + +from torch.utils.data._utils.collate import default_collate +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import ( + _check_unpickable_fn, + validate_input_col, +) + + +__all__ = [ + "CollatorIterDataPipe", + "MapperIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("map") +class MapperIterDataPipe(IterDataPipe[_T_co]): + r""" + Applies a function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source Iterable DataPipe + fn: Function being applied over each item + input_col: Index or indices of data which ``fn`` is applied, such as: + + - ``None`` as default to apply ``fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified + only when ``input_col`` is not ``None`` + + - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with + multiple indices, the left-most one is used, and other indices will be removed. + - Integer is used for list/tuple. ``-1`` represents to append result at the end. + - Key is used for dict. New key is acceptable. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, Mapper + >>> def add_one(x): + ... 
return x + 1 + >>> dp = IterableWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` + >>> # Use `functools.partial` or explicitly define the function instead + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + + datapipe: IterDataPipe + fn: Callable + + def __init__( + self, + datapipe: IterDataPipe, + fn: Callable, + input_col=None, + output_col=None, + ) -> None: + super().__init__() + self.datapipe = datapipe + + _check_unpickable_fn(fn) + self.fn = fn # type: ignore[assignment] + + self.input_col = input_col + if input_col is None and output_col is not None: + raise ValueError("`output_col` must be None when `input_col` is None.") + if isinstance(output_col, (list, tuple)): + if len(output_col) > 1: + raise ValueError("`output_col` must be a single-element list or tuple") + output_col = output_col[0] + self.output_col = output_col + validate_input_col(fn, input_col) + + def _apply_fn(self, data): + if self.input_col is None and self.output_col is None: + return self.fn(data) + + if self.input_col is None: + res = self.fn(data) + elif isinstance(self.input_col, (list, tuple)): + args = tuple(data[col] for col in self.input_col) + res = self.fn(*args) + else: + res = self.fn(data[self.input_col]) + + # Copy tuple to list and run in-place modification because tuple is immutable. + if isinstance(data, tuple): + t_flag = True + data = list(data) + else: + t_flag = False + + if self.output_col is None: + if isinstance(self.input_col, (list, tuple)): + data[self.input_col[0]] = res + for idx in sorted(self.input_col[1:], reverse=True): + del data[idx] + else: + data[self.input_col] = res + else: + if self.output_col == -1: + data.append(res) + else: + data[self.output_col] = res + + # Convert list back to tuple + return tuple(data) if t_flag else data + + def __iter__(self) -> Iterator[_T_co]: + for data in self.datapipe: + yield self._apply_fn(data) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +def _collate_helper(conversion, item): + # TODO(VitalyFedyunin): Verify that item is any sort of batch + if len(item.items) > 1: + # TODO(VitalyFedyunin): Compact all batch dataframes into one + raise RuntimeError("Only supports one DataFrame per batch") + df = item[0] + columns_name = df_wrapper.get_columns(df) + tuple_names: list = [] + tuple_values: list = [] + + for name in conversion.keys(): + if name not in columns_name: + raise RuntimeError("Conversion keys missmatch") + + for name in columns_name: + if name in conversion: + if not callable(conversion[name]): + raise RuntimeError( + "Collate (DF)DataPipe requires callable as dict values" + ) + collation_fn = conversion[name] + else: + # TODO(VitalyFedyunin): Add default collation into df_wrapper + try: + import torcharrow.pytorch as tap # type: ignore[import] + + collation_fn = tap.rec.Default() + except Exception as e: + raise RuntimeError( + "unable to import default collation function from the TorchArrow" + ) from e + + tuple_names.append(str(name)) + value = collation_fn(df[name]) + tuple_values.append(value) + + # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here + # TODO(VitalyFedyunin): Instead of ignoring mypy error, make 
sure tuple_names is not empty + tpl_cls = namedtuple("CollateResult", tuple_names) # type: ignore[misc] + tuple = tpl_cls(*tuple_values) + return tuple + + +@functional_datapipe("collate") +class CollatorIterDataPipe(MapperIterDataPipe): + r""" + Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``). + + By default, it uses :func:`torch.utils.data.default_collate`. + + .. note:: + While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the + default behavior and `functools.partial` to specify any additional arguments. + + Args: + datapipe: Iterable DataPipe being collated + collate_fn: Customized collate function to collect and combine data or a batch of data. + Default function collates to Tensor(s) based on data type. + + Example: + >>> # xdoctest: +SKIP + >>> # Convert integer data to float Tensor + >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): + ... def __init__(self, start, end): + ... super(MyIterDataPipe).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + ... def __len__(self): + ... return self.end - self.start + ... + >>> ds = MyIterDataPipe(start=3, end=7) + >>> print(list(ds)) + [3, 4, 5, 6] + >>> def collate_fn(batch): + ... return torch.tensor(batch, dtype=torch.float) + ... + >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) + >>> print(list(collated_ds)) + [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] + """ + + def __init__( + self, + datapipe: IterDataPipe, + conversion: Union[ + Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None + ] = default_collate, + collate_fn: Optional[Callable] = None, + ) -> None: + # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]` + # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]` + if collate_fn is not None: + super().__init__(datapipe, fn=collate_fn) + else: + if callable(conversion): + super().__init__(datapipe, fn=conversion) + else: + # TODO(VitalyFedyunin): Validate passed dictionary + collate_fn = functools.partial(_collate_helper, conversion) + super().__init__(datapipe, fn=collate_fn) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combinatorics.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combinatorics.py new file mode 100644 index 0000000000000000000000000000000000000000..50fdb869020d3e6224e534004327255df8af6e05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combinatorics.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +import random +from collections.abc import Iterator, Sized +from typing import Optional, TypeVar + +import torch +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.sampler import Sampler, SequentialSampler + + +__all__ = [ + "SamplerIterDataPipe", + "ShufflerIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class SamplerIterDataPipe(IterDataPipe[_T_co]): + r""" + Generate sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`). + + Args: + datapipe: IterDataPipe to sample from + sampler: Sampler class to generate sample elements from input DataPipe. 
+ Default is :class:`SequentialSampler` for IterDataPipe + """ + + datapipe: IterDataPipe + sampler: Sampler + + def __init__( + self, + datapipe: IterDataPipe, + sampler: type[Sampler] = SequentialSampler, + sampler_args: Optional[tuple] = None, + sampler_kwargs: Optional[dict] = None, + ) -> None: + assert isinstance( + datapipe, Sized + ), "Sampler class requires input datapipe implemented `__len__`" + super().__init__() + self.datapipe = datapipe + self.sampler_args = () if sampler_args is None else sampler_args + self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs + # https://github.com/python/mypy/pull/9629 will solve + self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc] + + def __iter__(self) -> Iterator[_T_co]: + return iter(self.sampler) + + def __len__(self) -> int: + # Dataset has been tested as `Sized` + if isinstance(self.sampler, Sized): + return len(self.sampler) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("shuffle") +class ShufflerIterDataPipe(IterDataPipe[_T_co]): + r""" + Shuffle the input DataPipe with a buffer (functional name: ``shuffle``). + + The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then, + each item will be yielded from the buffer by reservoir sampling via iterator. + + ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the + datapipe is not shuffled. In order to fully shuffle all elements from datapipe, + ``buffer_size`` is required to be greater than or equal to the size of datapipe. + + When it is used with :class:`torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process + mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed + for each worker process. 
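A hedged sketch of pinning the shuffle order for a single pass via the ``set_seed`` method defined below (``IterableWrapper`` from torchdata is assumed, as in the other examples; the yielded order itself is not shown since it depends on the seeded RNG):
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(5)).shuffle(buffer_size=5).set_seed(1)
>>> sorted(dp) == list(range(5))  # same elements, permuted order
True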
+ + Args: + datapipe: The IterDataPipe being shuffled + buffer_size: The buffer size for shuffling (default to ``10000``) + unbatch_level: Specifies if it is necessary to unbatch source data before + applying the shuffle + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> shuffle_dp = dp.shuffle() + >>> list(shuffle_dp) + [0, 4, 1, 6, 3, 2, 9, 5, 7, 8] + """ + + datapipe: IterDataPipe[_T_co] + buffer_size: int + _buffer: list[_T_co] + _enabled: bool + _seed: Optional[int] + _rng: random.Random + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + *, + buffer_size: int = 10000, + unbatch_level: int = 0, + ) -> None: + super().__init__() + # TODO: Performance optimization + # buffer can be a fixed size and remove expensive `append()` and `len()` operations + self._buffer: list[_T_co] = [] + assert buffer_size > 0, "buffer_size should be larger than 0" + if unbatch_level == 0: + self.datapipe = datapipe + else: + self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level) + self.buffer_size = buffer_size + self._enabled = True + self._seed = None + self._rng = random.Random() + + def set_shuffle(self, shuffle=True): + self._enabled = shuffle + return self + + def set_seed(self, seed: int): + self._seed = seed + return self + + def __iter__(self) -> Iterator[_T_co]: + if not self._enabled: + yield from self.datapipe + else: + for x in self.datapipe: + if len(self._buffer) == self.buffer_size: + idx = self._rng.randint(0, len(self._buffer) - 1) + val, self._buffer[idx] = self._buffer[idx], x + yield val + else: + self._buffer.append(x) + while self._buffer: + idx = self._rng.randint(0, len(self._buffer) - 1) + yield self._buffer.pop(idx) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + def reset(self) -> None: + self._buffer = [] + if self._enabled: + if self._seed is None: + self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._rng.seed(self._seed) + self._seed = None + + def __getstate__(self): + state = ( + self.datapipe, + self.buffer_size, + self._enabled, + self._seed, + self._buffer, + self._rng.getstate(), + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.buffer_size, + self._enabled, + self._seed, + self._buffer, + rng_state, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._rng = random.Random() + self._rng.setstate(rng_state) + + def __del__(self): + self._buffer.clear() diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combining.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combining.py new file mode 100644 index 0000000000000000000000000000000000000000..85ca5360dbde972c043561c4aa628372686a85e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/combining.py @@ -0,0 +1,696 @@ +# mypy: allow-untyped-defs +import copy as copymodule +import warnings +from abc import ABC, abstractmethod +from collections import deque +from collections.abc import Iterator, Sized +from typing import Any, Callable, Literal, Optional, TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes._hook_iterator import _SnapshotState 
+from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, StreamWrapper + + +__all__ = [ + "ConcaterIterDataPipe", + "DemultiplexerIterDataPipe", + "ForkerIterDataPipe", + "MultiplexerIterDataPipe", + "ZipperIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("concat") +class ConcaterIterDataPipe(IterDataPipe): + r""" + Concatenates multiple Iterable DataPipes (functional name: ``concat``). + + The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones. + + Args: + datapipes: Iterable DataPipes being concatenated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> import random + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1 = IterableWrapper(range(3)) + >>> dp2 = IterableWrapper(range(5)) + >>> list(dp1.concat(dp2)) + [0, 1, 2, 0, 1, 2, 3, 4] + """ + + datapipes: tuple[IterDataPipe] + + def __init__(self, *datapipes: IterDataPipe): + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `IterDataPipe`") + self.datapipes = datapipes # type: ignore[assignment] + + def __iter__(self) -> Iterator: + for dp in self.datapipes: + yield from dp + + def __len__(self) -> int: + if all(isinstance(dp, Sized) for dp in self.datapipes): + return sum(len(dp) for dp in self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("fork") +class ForkerIterDataPipe(IterDataPipe): + r""" + Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``). + + Args: + datapipe: Iterable DataPipe being copied + num_instances: number of instances of the datapipe to create + buffer_size: this restricts how far ahead the leading child DataPipe + can read relative to the slowest child DataPipe. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + copy: copy strategy to use for items yielded by each branch. Supported + options are ``None`` for no copying, ``"shallow"`` for shallow object + copies, and ``"deep"`` for deep object copies. Defaults to ``None``. + + Note: + All branches of the forked pipeline return the identical object unless + the copy parameter is supplied. If the object is mutable or contains + mutable objects, changing them in one branch will affect all others. + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.fork(num_instances=2) + >>> list(dp1) + [0, 1, 2, 3, 4] + >>> list(dp2) + [0, 1, 2, 3, 4] + """ + + def __new__( + cls, + datapipe: IterDataPipe, + num_instances: int, + buffer_size: int = 1000, + copy: Optional[Literal["shallow", "deep"]] = None, + ): + if num_instances < 1: + raise ValueError( + f"Expected `num_instances` larger than 0, but {num_instances} is found" + ) + if num_instances == 1: + return datapipe + container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size, copy) # type: ignore[abstract] + return [_ChildDataPipe(container, i) for i in range(num_instances)] + + +class _ContainerTemplate(ABC): + r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" + + @abstractmethod + def get_next_element_by_instance(self, instance_id: int): + ... 
+ + @abstractmethod + def is_every_instance_exhausted(self) -> bool: + ... + + @abstractmethod + def reset(self) -> None: + ... + + @abstractmethod + def get_length_by_instance(self, instance_id: int): + r"""Raise TypeError if it's not supposed to be implemented to support `list(datapipe)`.""" + + +def _no_op(x): + return x + + +class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate): + r""" + Container to hold instance-specific information on behalf of ForkerIterDataPipe. + + It tracks the state of its child DataPipes, maintains the buffer, and yields the next value + as requested by the child DataPipes. + """ + + def __init__( + self, + datapipe: IterDataPipe, + num_instances: int, + buffer_size: int = 1000, + copy: Optional[Literal["shallow", "deep"]] = None, + ): + self.main_datapipe = datapipe + self._datapipe_iterator: Optional[Iterator[Any]] = None + self.num_instances = num_instances + self.buffer: deque = deque() + self.buffer_size = buffer_size + if self.buffer_size < 0: + warnings.warn( + "Unlimited buffer size is set for `fork`, " + "please be aware of OOM at random places", + UserWarning, + ) + if copy is None: + self.copy_fn = _no_op + elif copy == "shallow": + self.copy_fn = copymodule.copy + elif copy == "deep": + self.copy_fn = copymodule.deepcopy + else: + raise ValueError( + f"Unknown copy method `{copy}` requested, choose one of None, `shallow` or `deep`." + ) + + self.child_pointers: list[int] = [ + 0 + ] * num_instances # Indicate the indices of the next element to get + self.slowest_ptr = 0 # The index to read by the slowest child + self.leading_ptr = 0 # The index to read by the fastest child + self.end_ptr: Optional[int] = None # The index to stop child + self._child_stop: list[bool] = [True for _ in range(num_instances)] + + def __len__(self): + return len(self.main_datapipe) + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None and self._child_stop[instance_id]: + self._datapipe_iterator = iter(self.main_datapipe) + self._snapshot_state = _SnapshotState.Iterating + for i in range(self.num_instances): + self._child_stop[i] = False + try: + while not self._child_stop[instance_id]: + self.child_pointers[instance_id] += 1 + if ( + self.end_ptr is not None + and self.child_pointers[instance_id] == self.end_ptr + ): + self._child_stop[instance_id] = True + break + # Use buffer + if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr: + idx = self.child_pointers[instance_id] - self.slowest_ptr - 1 + return_val = self.buffer[idx] + else: # Retrieve one element from main datapipe + self.leading_ptr = self.child_pointers[instance_id] + try: + return_val = next(self._datapipe_iterator) # type: ignore[arg-type] + self.buffer.append(return_val) + except StopIteration: + self._child_stop[instance_id] = True + self._datapipe_iterator = None + self.end_ptr = self.leading_ptr + continue + if self.child_pointers[instance_id] == self.slowest_ptr + 1: + new_min = min( + self.child_pointers + ) # Can optimize by avoiding the call to min() + if self.slowest_ptr < new_min: + self.slowest_ptr = new_min + self.buffer.popleft() + if ( + self.buffer_size >= 0 + and self.leading_ptr > self.buffer_size + self.slowest_ptr + ): + raise BufferError( + "ForkerIterDataPipe buffer overflow," + + f"buffer size {self.buffer_size} is insufficient." 
+ ) + + yield self.copy_fn(return_val) # type: ignore[possibly-undefined] + finally: + self._child_stop[instance_id] = True + # Cleanup _datapipe_iterator for the case that fork exits earlier + if all(self._child_stop): + self._datapipe_iterator = None + self._cleanup() + + def is_every_instance_exhausted(self) -> bool: + return self.end_ptr is not None and all(self._child_stop) + + def get_length_by_instance(self, instance_id: int) -> int: + return len(self.main_datapipe) + + def reset(self) -> None: + self._datapipe_iterator = None + self.buffer = deque() + self.child_pointers = [0] * self.num_instances + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr = None + self._child_stop = [True for _ in range(self.num_instances)] + + def __getstate__(self): + state = ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.copy_fn, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.copy_fn, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._datapipe_iterator = None + self.buffer = deque() + self.child_pointers = [0] * self.num_instances + self.slowest_ptr = 0 + self.leading_ptr = 0 + self.end_ptr = None + self._child_stop = [True for _ in range(self.num_instances)] + + def _cleanup(self): + while self.buffer: + d = self.buffer.popleft() + StreamWrapper.close_streams(d) + + def __del__(self): + self._cleanup() + + +class _ChildDataPipe(IterDataPipe): + r""" + Iterable Datapipe that is a child of a main DataPipe. + + The instance of this class will pass its instance_id to get the next value from its main DataPipe. + + Note: + ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint. + Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes, + the previous iterators for all ChildDataPipes must be invalidated, with the exception when a ChildDataPipe + hasn't had an iterator created from it since the last invalidation. See the example below. + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> # Singler Iterator per IteraDataPipe Invalidation + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper(range(10)) + >>> cdp1, cdp2 = source_dp.fork(num_instances=2) + >>> it1, it2 = iter(cdp1), iter(cdp2) + >>> it3 = iter(cdp1) + >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`. + >>> it4 = iter(cdp2) + >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since + >>> # the last invalidation. 
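As a hedged continuation of the example above (the exact error type and message are an assumption, not verified here):
>>> # next(it1)  # expected to fail, e.g. with RuntimeError, because `it1` was invalidated when `it3` was created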
+ + Args: + main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)' + instance_id: integer identifier of this instance + """ + + _is_child_datapipe: bool = True + + def __init__(self, main_datapipe: IterDataPipe, instance_id: int): + assert isinstance(main_datapipe, _ContainerTemplate) + + self.main_datapipe: IterDataPipe = main_datapipe + self.instance_id = instance_id + + def __iter__(self): + # Note that the logic behind setting iterator ID and `reset` are handled within `hook_iterator` + # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called + return self.main_datapipe.get_next_element_by_instance(self.instance_id) + + def __len__(self): + return self.main_datapipe.get_length_by_instance(self.instance_id) + + # This method is called by `hook_iterator` in `_typing.py`. + def _set_main_datapipe_valid_iterator_id(self) -> int: + r""" + Update the valid iterator ID for both this DataPipe object and `main_datapipe`. + + `main_datapipe.reset()` is called when the ID is incremented to a new generation. + """ + # 1. First time any child iterator is created + if self.main_datapipe._valid_iterator_id is None: + self.main_datapipe._valid_iterator_id = 0 # type: ignore[attr-defined] + # 2. This instance was already in the same generation as `main_datapipe`, + # we need to increment the ID further by 1 + elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id: # type: ignore[has-type] + self.main_datapipe._valid_iterator_id += 1 # type: ignore[attr-defined] + # Whenever a new generation of iterator is created, the `main_datapipe` must reset + if not self.main_datapipe.is_every_instance_exhausted(): + warnings.warn( + "Some child DataPipes are not exhausted when __iter__ is called. We are resetting " + "the buffer and each child DataPipe will read from the start again.", + UserWarning, + ) + self.main_datapipe.reset() + # 3. Otherwise, the iterator is behind the others, so it will just need to catch up by setting + # the instance's iterator to match that of `main_datapipe` + self._valid_iterator_id = self.main_datapipe._valid_iterator_id + return self._valid_iterator_id + + # This method is called by `hook_iterator` in `_typing.py`. + def _check_valid_iterator_id(self, iterator_id) -> bool: + r"""Check the valid iterator ID against that of DataPipe object and that of `main_datapipe`.""" + return ( + iterator_id == self._valid_iterator_id + and iterator_id == self.main_datapipe._valid_iterator_id + ) + + +@functional_datapipe("demux") +class DemultiplexerIterDataPipe(IterDataPipe): + r""" + Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``). + + A list of the child DataPipes is returned from this operation. + + Args: + datapipe: Iterable DataPipe being filtered + num_instances: number of instances of the DataPipe to create + classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None`` + drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None`` + buffer_size: this defines the maximum number of inputs that the buffer can hold across all child + DataPipes while waiting for their values to be yielded. + Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + + Examples: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def odd_or_even(n): + ... 
return n % 2 + >>> source_dp = IterableWrapper(range(5)) + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even) + >>> list(dp1) + [0, 2, 4] + >>> list(dp2) + [1, 3] + >>> # It can also filter out any element that gets `None` from the `classifier_fn` + >>> def odd_or_even_no_zero(n): + ... return n % 2 if n != 0 else None + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) + >>> list(dp1) + [2, 4] + >>> list(dp2) + [1, 3] + """ + + def __new__( + cls, + datapipe: IterDataPipe, + num_instances: int, + classifier_fn: Callable[[_T_co], Optional[int]], + drop_none: bool = False, + buffer_size: int = 1000, + ): + if num_instances < 1: + raise ValueError( + f"Expected `num_instances` larger than 0, but {num_instances} is found" + ) + + _check_unpickable_fn(classifier_fn) + + # When num_instances == 1, demux can be replaced by filter, + # but keep it as Demultiplexer for the sake of consistency + # like throwing Error when classification result is out of o range + container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract] + return [_ChildDataPipe(container, i) for i in range(num_instances)] + + +class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate): + r""" + Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. + + It tracks the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value + as requested by the child DataPipes. + """ + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + num_instances: int, + classifier_fn: Callable[[_T_co], Optional[int]], + drop_none: bool, + buffer_size: int, + ): + self.main_datapipe = datapipe + self._datapipe_iterator: Optional[Iterator[Any]] = None + self.num_instances = num_instances + self.buffer_size = buffer_size + if self.buffer_size < 0: + warnings.warn( + "Unlimited buffer size is set for `demux`, " + "please be aware of OOM at random places", + UserWarning, + ) + self.current_buffer_usage = 0 + self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)] + self.classifier_fn = classifier_fn + self.drop_none = drop_none + self.main_datapipe_exhausted = False + self._child_stop: list[bool] = [True for _ in range(num_instances)] + + def _find_next(self, instance_id: int) -> _T_co: # type: ignore[type-var] + while True: + if self.main_datapipe_exhausted or self._child_stop[instance_id]: + raise StopIteration + if self._datapipe_iterator is None: + raise ValueError( + "_datapipe_iterator has not been set, likely because this private method is called directly " + "without invoking get_next_element_by_instance() first." + ) + value = next(self._datapipe_iterator) + classification = self.classifier_fn(value) + if classification is None and self.drop_none: + StreamWrapper.close_streams(value) + continue + if ( + classification is None + or classification >= self.num_instances + or classification < 0 + ): + raise ValueError( + f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " + + f"{classification} is returned." + ) + if classification == instance_id: + return value + self.child_buffers[classification].append(value) + self.current_buffer_usage += 1 + if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size: + raise BufferError( + f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient." 
+ ) + + def get_next_element_by_instance(self, instance_id: int): + if self._datapipe_iterator is None and self._child_stop[instance_id]: + self._datapipe_iterator = iter(self.main_datapipe) + self._snapshot_state = ( + _SnapshotState.Iterating + ) # This is necessary for the DataPipe to reset properly. + self.main_datapipe_exhausted = False + for i in range(self.num_instances): + self._child_stop[i] = False + + try: + while not self._child_stop[instance_id]: + if self.child_buffers[instance_id]: + self.current_buffer_usage -= 1 + yield self.child_buffers[instance_id].popleft() + else: + try: + yield self._find_next(instance_id) + except StopIteration: + self._child_stop[instance_id] = True + self.main_datapipe_exhausted = True + self._datapipe_iterator = None + finally: + self._child_stop[instance_id] = True + # Cleanup _datapipe_iterator for the case that demux exits earlier + if all(self._child_stop): + self._datapipe_iterator = None + if self.child_buffers[instance_id]: + self._cleanup(instance_id) + + def is_every_instance_exhausted(self) -> bool: + return self.main_datapipe_exhausted and all(self._child_stop) + + def get_length_by_instance(self, instance_id: int) -> int: + raise TypeError + + def reset(self) -> None: + self._datapipe_iterator = None + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self._child_stop = [True for _ in range(self.num_instances)] + self.main_datapipe_exhausted = False + + def __getstate__(self): + state = ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.classifier_fn, + self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.main_datapipe, + self.num_instances, + self.buffer_size, + self.classifier_fn, + self.drop_none, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._datapipe_iterator = None + self.current_buffer_usage = 0 + self.child_buffers = [deque() for _ in range(self.num_instances)] + self._child_stop = [True for _ in range(self.num_instances)] + self.main_datapipe_exhausted = False + + def _cleanup(self, instance_id: Optional[int] = None): + ids = ( + range(self.num_instances) + if instance_id is None + else [ + instance_id, + ] + ) + for i in ids: + q = self.child_buffers[i] + while q: + d = q.popleft() + StreamWrapper.close_streams(d) + + def __del__(self): + self._cleanup() + + +@functional_datapipe("mux") +class MultiplexerIterDataPipe(IterDataPipe): + r""" + Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``). + + As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, + and so on. It ends when the shortest input DataPipe is exhausted. 
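Per ``__len__`` below, when every input is ``Sized`` the nominal length is ``min(len(dp) for dp in datapipes) * len(datapipes)``; a small sketch using the lengths from the example that follows (3, 5 and 5):
>>> # xdoctest: +SKIP
>>> len(dp1.mux(dp2, dp3))
9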
+ + Args: + datapipes: Iterable DataPipes that will take turn to yield their elements, until the shortest DataPipe is exhausted + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> list(dp1.mux(dp2, dp3)) + [0, 10, 20, 1, 11, 21, 2, 12, 22] + """ + + def __init__(self, *datapipes): + self.datapipes = datapipes + self.buffer: list = ( + [] + ) # Store values to be yielded only when every iterator provides one + + def __iter__(self): + iterators = [iter(x) for x in self.datapipes] + while len(iterators): + for it in iterators: + try: + value = next(it) + self.buffer.append(value) + except StopIteration: + self.buffer.clear() + return + yield from self.buffer + self.buffer.clear() + + def __len__(self): + if all(isinstance(dp, Sized) for dp in self.datapipes): + return min(len(dp) for dp in self.datapipes) * len(self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + def reset(self) -> None: + self.buffer = [] + + def __getstate__(self): + state = ( + self.datapipes, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipes, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self.buffer = [] + + def __del__(self): + self.buffer.clear() + + +@functional_datapipe("zip") +class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + The output is stopped as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Iterable DataPipes being aggregated + + Example: + >>> # xdoctest: +REQUIRES(module:torchdata) + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) + >>> list(dp1.zip(dp2, dp3)) + [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] + """ + + datapipes: tuple[IterDataPipe] + + def __init__(self, *datapipes: IterDataPipe): + if not all(isinstance(dp, IterDataPipe) for dp in datapipes): + raise TypeError( + "All inputs are required to be `IterDataPipe` for `ZipIterDataPipe`." 
+ ) + super().__init__() + self.datapipes = datapipes # type: ignore[assignment] + + def __iter__(self) -> Iterator[tuple[_T_co]]: + iterators = [iter(datapipe) for datapipe in self.datapipes] + yield from zip(*iterators) + + def __len__(self) -> int: + if all(isinstance(dp, Sized) for dp in self.datapipes): + return min(len(dp) for dp in self.datapipes) + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/filelister.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/filelister.py new file mode 100644 index 0000000000000000000000000000000000000000..91226c7ee084d7ca4a575e61119f057561313eb5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/filelister.py @@ -0,0 +1,68 @@ +from collections.abc import Iterator, Sequence +from typing import Union + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe +from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root + + +__all__ = ["FileListerIterDataPipe"] + + +@functional_datapipe("list_files") +class FileListerIterDataPipe(IterDataPipe[str]): + r""" + Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory. + + Multiple root directories can be provided (functional name: ``list_files``). + + Args: + root: Root directory or a sequence of root directories + masks: Unix style filter string or string list for filtering file name(s) + recursive: Whether to return pathname from nested directories or not + abspath: Whether to return relative pathname or absolute pathname + non_deterministic: Whether to return pathname in sorted order or not. 
+ If ``False``, the results yielded from each root directory will be sorted + length: Nominal length of the datapipe + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister + >>> dp = FileLister(root=".", recursive=True) + >>> list(dp) + ['example.py', './data/data.tar'] + """ + + def __init__( + self, + root: Union[str, Sequence[str], IterDataPipe] = ".", + masks: Union[str, list[str]] = "", + *, + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, + length: int = -1, + ) -> None: + super().__init__() + if isinstance(root, str): + root = [root] + if not isinstance(root, IterDataPipe): + root = IterableWrapperIterDataPipe(root) + self.datapipe: IterDataPipe = root + self.masks: Union[str, list[str]] = masks + self.recursive: bool = recursive + self.abspath: bool = abspath + self.non_deterministic: bool = non_deterministic + self.length: int = length + + def __iter__(self) -> Iterator[str]: + for path in self.datapipe: + yield from get_file_pathnames_from_root( + path, self.masks, self.recursive, self.abspath, self.non_deterministic + ) + + def __len__(self) -> int: + if self.length == -1: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + return self.length diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/fileopener.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/fileopener.py new file mode 100644 index 0000000000000000000000000000000000000000..082c38a5b03598c059061c754791af93de174da6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/fileopener.py @@ -0,0 +1,77 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterable +from io import IOBase +from typing import Optional + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames + + +__all__ = [ + "FileOpenerIterDataPipe", +] + + +@functional_datapipe("open_files") +class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]): + r""" + Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``). + + Args: + datapipe: Iterable datapipe that provides pathnames + mode: An optional string that specifies the mode in which + the file is opened by ``open()``. It defaults to ``r``, other options are + ``b`` for reading in binary mode and ``t`` for text mode. + encoding: An optional string that specifies the encoding of the + underlying file. It defaults to ``None`` to match the default encoding of ``open``. + length: Nominal length of the datapipe + + Note: + The opened file handles will be closed by Python's GC periodically. Users can choose + to close them explicitly. 
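A hedged sketch of closing each stream explicitly once it has been consumed (the paths and pipeline are illustrative only):
>>> # xdoctest: +SKIP
>>> dp = FileOpener(FileLister(root="."), mode="rt")
>>> for pathname, stream in dp:
...     _ = stream.read()
...     stream.close()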
+ + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) + >>> dp = FileOpener(dp) + >>> dp = StreamReader(dp) + >>> list(dp) + [('./abc.txt', 'abc')] + """ + + def __init__( + self, + datapipe: Iterable[str], + mode: str = "r", + encoding: Optional[str] = None, + length: int = -1, + ): + super().__init__() + self.datapipe: Iterable = datapipe + self.mode: str = mode + self.encoding: Optional[str] = encoding + + if self.mode not in ("b", "t", "rb", "rt", "r"): + raise ValueError(f"Invalid mode {mode}") + # TODO: enforce typing for each instance based on mode, otherwise + # `argument_validation` with this DataPipe may be potentially broken + + if "b" in mode and encoding is not None: + raise ValueError("binary mode doesn't take an encoding argument") + + self.length: int = length + + # Remove annotation due to 'IOBase' is a general type and true type + # is determined at runtime based on mode. Some `DataPipe` requiring + # a subtype would cause mypy error. + def __iter__(self): + yield from get_file_binaries_from_pathnames( + self.datapipe, self.mode, self.encoding + ) + + def __len__(self): + if self.length == -1: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + return self.length diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/grouping.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..c2f3677d25de085297778a99a47e473d8e293291 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/grouping.py @@ -0,0 +1,322 @@ +# mypy: allow-untyped-defs +import warnings +from collections import defaultdict +from collections.abc import Iterator, Sized +from typing import Any, Callable, Optional, TypeVar + +import torch.utils.data.datapipes.iter.sharding +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + + +__all__ = [ + "BatcherIterDataPipe", + "GrouperIterDataPipe", + "UnBatcherIterDataPipe", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +def __getattr__(name: str): + if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]: + warnings.warn( + f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1" + f"Please use `{name}` from the `torch.utils.data.datapipes.iter.sharding`", + category=FutureWarning, + stacklevel=2, + ) + + return getattr(torch.utils.data.datapipes.iter.sharding, name) + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +@functional_datapipe("batch") +class BatcherIterDataPipe(IterDataPipe[DataChunk]): + r""" + Creates mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, or ``length % batch_size`` for the + last batch if ``drop_last`` is set to ``False``. 
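For instance, with the default ``drop_last=False`` the trailing partial batch is kept (a sketch assuming torchdata's ``IterableWrapper``):
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> list(IterableWrapper(range(10)).batch(batch_size=3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]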
+ + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding, + defaults to ``DataChunk`` + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> dp = dp.batch(batch_size=3, drop_last=True) + >>> list(dp) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + datapipe: IterDataPipe + batch_size: int + drop_last: bool + + def __init__( + self, + datapipe: IterDataPipe, + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> None: + assert batch_size > 0, "Batch size is required to be larger than 0!" + super().__init__() + self.datapipe = datapipe + self.batch_size = batch_size + self.drop_last = drop_last + self.wrapper_class = wrapper_class + + def __iter__(self) -> Iterator[DataChunk]: + batch: list = [] + for x in self.datapipe: + batch.append(x) + if len(batch) == self.batch_size: + yield self.wrapper_class(batch) + batch = [] + if len(batch) > 0: + if not self.drop_last: + yield self.wrapper_class(batch) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + if self.drop_last: + return len(self.datapipe) // self.batch_size + else: + return (len(self.datapipe) + self.batch_size - 1) // self.batch_size + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") + + +@functional_datapipe("unbatch") +class UnBatcherIterDataPipe(IterDataPipe): + r""" + Undos batching of data (functional name: ``unbatch``). + + In other words, it flattens the data up to the specified level within a batched DataPipe. + + Args: + datapipe: Iterable DataPipe being un-batched + unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``, + it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]]) + >>> dp1 = source_dp.unbatch() + >>> list(dp1) + [[0, 1], [2], [3, 4], [5], [6]] + >>> dp2 = source_dp.unbatch(unbatch_level=2) + >>> list(dp2) + [0, 1, 2, 3, 4, 5, 6] + """ + + def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1): + self.datapipe = datapipe + self.unbatch_level = unbatch_level + + def __iter__(self): + for element in self.datapipe: + yield from self._dive(element, unbatch_level=self.unbatch_level) + + def _dive(self, element, unbatch_level): + if unbatch_level < -1: + raise ValueError("unbatch_level must be -1 or >= 0") + if unbatch_level == -1: + if isinstance(element, (list, DataChunk)): + for item in element: + yield from self._dive(item, unbatch_level=-1) + else: + yield element + elif unbatch_level == 0: + yield element + else: + if isinstance(element, (list, DataChunk)): + for item in element: + yield from self._dive(item, unbatch_level=unbatch_level - 1) + else: + raise IndexError( + f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe" + ) + + +@functional_datapipe("groupby") +class GrouperIterDataPipe(IterDataPipe[DataChunk]): + r""" + Groups data from IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk`` with batch size up to ``group_size``. + + (functional name: ``groupby``). 
+ + The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group + will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full, + the DataPipe will yield the largest batch with the same key, provided that its size is larger + than ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``. + + After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity + will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``. + + Args: + datapipe: Iterable datapipe to be grouped + group_key_fn: Function used to generate group key from the data of the source datapipe + keep_key: Option to yield the matching key along with the items in a tuple, + resulting in `(key, [items])` otherwise returning [items] + buffer_size: The size of buffer for ungrouped data + group_size: The max size of each group, a batch is yielded as soon as it reaches this size + guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full + drop_remaining: Specifies if the group smaller than ``guaranteed_group_size`` will be dropped from buffer + when the buffer is full + + Example: + >>> import os + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def group_fn(file): + ... return os.path.basename(file).split(".")[0] + >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) + >>> dp0 = source_dp.groupby(group_key_fn=group_fn) + >>> list(dp0) + [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] + >>> # A group is yielded as soon as its size equals to `group_size` + >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2) + >>> list(dp1) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` + >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) + >>> list(dp2) + [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] + """ + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + group_key_fn: Callable[[_T_co], Any], + *, + keep_key: bool = False, + buffer_size: int = 10000, + group_size: Optional[int] = None, + guaranteed_group_size: Optional[int] = None, + drop_remaining: bool = False, + ): + _check_unpickable_fn(group_key_fn) + self.datapipe = datapipe + self.group_key_fn = group_key_fn + + self.keep_key = keep_key + self.max_buffer_size = buffer_size + self.buffer_elements: defaultdict[Any, list] = defaultdict(list) + self.curr_buffer_size = 0 + self.group_size = group_size + self.guaranteed_group_size = None + if group_size is not None and buffer_size is not None: + assert 0 < group_size <= buffer_size + self.guaranteed_group_size = group_size + if guaranteed_group_size is not None: + assert group_size is not None and 0 < guaranteed_group_size <= group_size + self.guaranteed_group_size = guaranteed_group_size + self.drop_remaining = drop_remaining + self.wrapper_class = DataChunk + + def _remove_biggest_key(self): + biggest_key = None + biggest_size = 0 + result_to_yield = None + for findkey in self.buffer_elements.keys(): + if len(self.buffer_elements[findkey]) > biggest_size: + biggest_size = len(self.buffer_elements[findkey]) + biggest_key = findkey + + if ( + self.guaranteed_group_size is 
not None + and biggest_size < self.guaranteed_group_size + and not self.drop_remaining + ): + raise RuntimeError( + "Failed to group items", str(self.buffer_elements[biggest_key]) + ) + + if ( + self.guaranteed_group_size is None + or biggest_size >= self.guaranteed_group_size + ): + result_to_yield = self.buffer_elements[biggest_key] + + self.curr_buffer_size -= biggest_size + del self.buffer_elements[biggest_key] + + return result_to_yield + + def __iter__(self): + for x in self.datapipe: + key = self.group_key_fn(x) + + self.buffer_elements[key].append(x) + self.curr_buffer_size += 1 + + if self.group_size is not None and self.group_size == len( + self.buffer_elements[key] + ): + result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key]) + yield (key, result) if self.keep_key else result + self.curr_buffer_size -= len(self.buffer_elements[key]) + del self.buffer_elements[key] + + if self.curr_buffer_size == self.max_buffer_size: + result_to_yield = self._remove_biggest_key() + if result_to_yield is not None: + result = self.wrapper_class(result_to_yield) + yield (key, result) if self.keep_key else result + + for key in tuple(self.buffer_elements.keys()): + result = self.wrapper_class(self.buffer_elements.pop(key)) + self.curr_buffer_size -= len(result) + yield (key, result) if self.keep_key else result + + def reset(self) -> None: + self.curr_buffer_size = 0 + self.buffer_elements = defaultdict(list) + + def __getstate__(self): + state = ( + self.datapipe, + self.group_key_fn, + self.keep_key, + self.max_buffer_size, + self.group_size, + self.guaranteed_group_size, + self.drop_remaining, + self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.group_key_fn, + self.keep_key, + self.max_buffer_size, + self.group_size, + self.guaranteed_group_size, + self.drop_remaining, + self.wrapper_class, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self.curr_buffer_size = 0 + self.buffer_elements = defaultdict(list) + + def __del__(self): + self.buffer_elements.clear() diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..65ca9181909c220e1ea49e4159a35f8906e803fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/routeddecoder.py @@ -0,0 +1,70 @@ +from collections.abc import Iterable, Iterator, Sized +from io import BufferedIOBase +from typing import Any, Callable + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import _deprecation_warning +from torch.utils.data.datapipes.utils.decoder import ( + basichandlers as decoder_basichandlers, + Decoder, + extension_extract_fn, + imagehandler as decoder_imagehandler, +) + + +__all__ = ["RoutedDecoderIterDataPipe"] + + +@functional_datapipe("routed_decode") +class RoutedDecoderIterDataPipe(IterDataPipe[tuple[str, Any]]): + r""" + Decodes binary streams from input DataPipe, yields pathname and decoded data in a tuple. 
+ + (functional name: ``routed_decode``) + + Args: + datapipe: Iterable datapipe that provides pathname and binary stream in tuples + handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder + handlers will be set as default. If multiple handles are provided, the priority + order follows the order of handlers (the first handler has the top priority) + key_fn: Function for decoder to extract key from pathname to dispatch handlers. + Default is set to extract file extension from pathname + + Note: + When ``key_fn`` is specified returning anything other than extension, the default + handler will not work and users need to specify custom handler. Custom handler + could use regex to determine the eligibility to handle data. + """ + + def __init__( + self, + datapipe: Iterable[tuple[str, BufferedIOBase]], + *handlers: Callable, + key_fn: Callable = extension_extract_fn, + ) -> None: + super().__init__() + self.datapipe: Iterable[tuple[str, BufferedIOBase]] = datapipe + if not handlers: + handlers = (decoder_basichandlers, decoder_imagehandler("torch")) + self.decoder = Decoder(*handlers, key_fn=key_fn) + _deprecation_warning( + type(self).__name__, + deprecation_version="1.12", + removal_version="1.13", + old_functional_name="routed_decode", + ) + + def add_handler(self, *handler: Callable) -> None: + self.decoder.add_handler(*handler) + + def __iter__(self) -> Iterator[tuple[str, Any]]: + for data in self.datapipe: + pathname = data[0] + result = self.decoder(data) + yield (pathname, result[pathname]) + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + return len(self.datapipe) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/selecting.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/selecting.py new file mode 100644 index 0000000000000000000000000000000000000000..e17c1869bc1cfe6ae10b5d58a4059aaa83604b6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/selecting.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterator +from typing import Callable, TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.datapipes.utils.common import ( + _check_unpickable_fn, + StreamWrapper, + validate_input_col, +) + + +__all__ = ["FilterIterDataPipe"] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +@functional_datapipe("filter") +class FilterIterDataPipe(IterDataPipe[_T_co]): + r""" + Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``). + + Args: + datapipe: Iterable DataPipe being filtered + filter_fn: Customized function mapping an element to a boolean. + input_col: Index or indices of data which ``filter_fn`` is applied, such as: + + - ``None`` as default to apply ``filter_fn`` to the data directly. + - Integer(s) is used for list/tuple. + - Key(s) is used for dict. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> def is_even(n): + ... 
return n % 2 == 0 + >>> dp = IterableWrapper(range(5)) + >>> filter_dp = dp.filter(filter_fn=is_even) + >>> list(filter_dp) + [0, 2, 4] + """ + + datapipe: IterDataPipe[_T_co] + filter_fn: Callable + + def __init__( + self, + datapipe: IterDataPipe[_T_co], + filter_fn: Callable, + input_col=None, + ) -> None: + super().__init__() + self.datapipe = datapipe + + _check_unpickable_fn(filter_fn) + self.filter_fn = filter_fn # type: ignore[assignment] + + self.input_col = input_col + validate_input_col(filter_fn, input_col) + + def _apply_filter_fn(self, data) -> bool: + if self.input_col is None: + return self.filter_fn(data) + elif isinstance(self.input_col, (list, tuple)): + args = tuple(data[col] for col in self.input_col) + return self.filter_fn(*args) + else: + return self.filter_fn(data[self.input_col]) + + def __iter__(self) -> Iterator[_T_co]: + for data in self.datapipe: + condition, filtered = self._returnIfTrue(data) + if condition: + yield filtered + else: + StreamWrapper.close_streams(data) + + def _returnIfTrue(self, data: _T) -> tuple[bool, _T]: + condition = self._apply_filter_fn(data) + + if df_wrapper.is_column(condition): + # We are operating on DataFrames filter here + result = [] + for idx, mask in enumerate(df_wrapper.iterate(condition)): + if mask: + result.append(df_wrapper.get_item(data, idx)) + if len(result): + return True, df_wrapper.concat(result) + else: + return False, None # type: ignore[return-value] + + if not isinstance(condition, bool): + raise ValueError( + "Boolean output is required for `filter_fn` of FilterIterDataPipe, got", + type(condition), + ) + + return condition, data diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/sharding.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..345f8ecbcb2cd978425abcb0979c899f51592449 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/sharding.py @@ -0,0 +1,101 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from enum import IntEnum + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +__all__ = [ + "SHARDING_PRIORITIES", + "ShardingFilterIterDataPipe", +] + + +class SHARDING_PRIORITIES(IntEnum): + DEFAULT = 1 + DISTRIBUTED = 2 + MULTIPROCESSING = 3 + + +class _ShardingIterDataPipe(IterDataPipe): + def apply_sharding( + self, + num_of_instances: int, + instance_id: int, + sharding_group: SHARDING_PRIORITIES, + ): + raise NotImplementedError + + +@functional_datapipe("sharding_filter") +class ShardingFilterIterDataPipe(_ShardingIterDataPipe): + r""" + Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``). + + After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the + original DataPipe, where `n` equals to the number of instances. 
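A minimal sketch of the resulting round-robin split across two shard instances (assuming torchdata's ``IterableWrapper``; shard id 0 is shown):
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(10)).sharding_filter()
>>> dp.apply_sharding(2, 0)
>>> list(dp)
[0, 2, 4, 6, 8]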
+ + Args: + source_datapipe: Iterable DataPipe that will be sharded + """ + + def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None): + self.source_datapipe = source_datapipe + self.sharding_group_filter = sharding_group_filter + self.groups: dict[int, tuple[int, int]] = {} + self.num_of_instances = 1 + self.instance_id = 0 + self._update_num_of_instances() + + def apply_sharding( + self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT + ): + if instance_id >= num_of_instances: + raise ValueError( + f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})" + ) + if sharding_group == SHARDING_PRIORITIES.DEFAULT: + if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups: + raise RuntimeError( + "ShardingFilter cannot mix DEFAULT and non DEFAULT groups" + ) + else: + if SHARDING_PRIORITIES.DEFAULT in self.groups: + raise RuntimeError( + "ShardingFilter cannot mix DEFAULT and non DEFAULT groups" + ) + self.groups[sharding_group] = (num_of_instances, instance_id) + self._update_num_of_instances() + + def _update_num_of_instances(self): + sorted_sharding_groups = [ + self.groups[key] + for key in sorted(self.groups.keys()) + if self.sharding_group_filter is None or key == self.sharding_group_filter + ] + + sorted_sharding_groups.reverse() + + self.num_of_instances = 1 + self.instance_id = 0 + + for group_num_of_instances, group_instance_id in sorted_sharding_groups: + self.instance_id += self.num_of_instances * group_instance_id + self.num_of_instances *= group_num_of_instances + + def __iter__(self): + for i, item in enumerate(self.source_datapipe): + if i % self.num_of_instances == self.instance_id: + yield item + + def __len__(self): + if isinstance(self.source_datapipe, Sized): + return len(self.source_datapipe) // self.num_of_instances + ( + 1 + if ( + self.instance_id < len(self.source_datapipe) % self.num_of_instances + ) + else 0 + ) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/streamreader.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/streamreader.py new file mode 100644 index 0000000000000000000000000000000000000000..afe69248774554d53aca20a995d1780a78ce7b92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/streamreader.py @@ -0,0 +1,46 @@ +from collections.abc import Iterator +from io import IOBase +from typing import Optional + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +__all__ = ["StreamReaderIterDataPipe"] + + +@functional_datapipe("read_from_stream") +class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]): + r""" + Given IO streams and their label names, yield bytes with label name as tuple. + + (functional name: ``read_from_stream``). + + Args: + datapipe: Iterable DataPipe provides label/URL and byte stream + chunk: Number of bytes to be read from stream per iteration. + If ``None``, all bytes will be read until the EOF. 
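For instance, with the default ``chunk=None`` the entire stream is read and yielded as a single tuple (a sketch mirroring the example below):
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
>>> from io import StringIO
>>> list(StreamReader(IterableWrapper([("alphabet", StringIO("abcde"))])))
[('alphabet', 'abcde')]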
+ + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader + >>> from io import StringIO + >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))]) + >>> list(StreamReader(dp, chunk=1)) + [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')] + """ + + def __init__( + self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: Optional[int] = None + ): + self.datapipe = datapipe + self.chunk = chunk + + def __iter__(self) -> Iterator[tuple[str, bytes]]: + for furl, stream in self.datapipe: + while True: + d = stream.read(self.chunk) + if not d: + stream.close() + break + yield (furl, d) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/utils.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f04f7906a02fb9d9a692f53b8dd680e5446ef4b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/iter/utils.py @@ -0,0 +1,54 @@ +# mypy: allow-untyped-defs +import copy +import warnings + +from torch.utils.data.datapipes.datapipe import IterDataPipe + + +__all__ = ["IterableWrapperIterDataPipe"] + + +class IterableWrapperIterDataPipe(IterDataPipe): + r""" + Wraps an iterable object to create an IterDataPipe. + + Args: + iterable: Iterable object to be wrapped into an IterDataPipe + deepcopy: Option to deepcopy input iterable object for each + iterator. The copy is made when the first element is read in ``iter()``. + + .. note:: + If ``deepcopy`` is explicitly set to ``False``, users should ensure + that the data pipeline doesn't contain any in-place operations over + the iterable instance to prevent data inconsistency across iterations. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.iter import IterableWrapper + >>> dp = IterableWrapper(range(10)) + >>> list(dp) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + + def __init__(self, iterable, deepcopy=True): + self.iterable = iterable + self.deepcopy = deepcopy + + def __iter__(self): + source_data = self.iterable + if self.deepcopy: + try: + source_data = copy.deepcopy(self.iterable) + # For the case that data cannot be deep-copied, + # all in-place operations will affect iterable variable. + # When this DataPipe is iterated second time, it will + # yield modified items. + except TypeError: + warnings.warn( + "The input iterable can not be deepcopied, " + "please be aware of in-place modification would affect source data." 
+ ) + yield from source_data + + def __len__(self): + return len(self.iterable) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07733b168b161db7293f8c4dada192c776351d53 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__init__.py @@ -0,0 +1,19 @@ +# Functional DataPipe +from torch.utils.data.datapipes.map.callable import MapperMapDataPipe as Mapper +from torch.utils.data.datapipes.map.combinatorics import ( + ShufflerIterDataPipe as Shuffler, +) +from torch.utils.data.datapipes.map.combining import ( + ConcaterMapDataPipe as Concater, + ZipperMapDataPipe as Zipper, +) +from torch.utils.data.datapipes.map.grouping import BatcherMapDataPipe as Batcher +from torch.utils.data.datapipes.map.utils import ( + SequenceWrapperMapDataPipe as SequenceWrapper, +) + + +__all__ = ["Batcher", "Concater", "Mapper", "SequenceWrapper", "Shuffler", "Zipper"] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb1cd38d529696b3daffd4a970dea971a076613 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ef1b5a2fca35814ee6263d09acea8b30ca4acc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/callable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea21bce6780c5a1f5b4154593461490fc8a66274 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49714b5384a6bba57c66d4ba1d3543138ad1850a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/combining.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78175d5330d2938856353729395feca65f5c6be0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/grouping.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..64318bf71b0311ee30bf00ba8e528085d01b912b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/callable.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/callable.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3e6b77cacc4c814562adc9dd9e40c84470c38e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/callable.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Callable, TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import MapDataPipe +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + + +__all__ = ["MapperMapDataPipe", "default_fn"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Default function to return each item directly +# In order to keep datapipe picklable, eliminates the usage +# of python lambda function +def default_fn(data): + return data + + +@functional_datapipe("map") +class MapperMapDataPipe(MapDataPipe[_T_co]): + r""" + Apply the input function over each item from the source DataPipe (functional name: ``map``). + + The function can be any regular Python function or partial object. Lambda + function is not recommended as it is not supported by pickle. + + Args: + datapipe: Source MapDataPipe + fn: Function being applied to each item + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper, Mapper + >>> def add_one(x): + ... return x + 1 + >>> dp = SequenceWrapper(range(10)) + >>> map_dp_1 = dp.map(add_one) + >>> list(map_dp_1) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + >>> map_dp_2 = Mapper(dp, lambda x: x + 1) + >>> list(map_dp_2) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + """ + + datapipe: MapDataPipe + fn: Callable + + def __init__( + self, + datapipe: MapDataPipe, + fn: Callable = default_fn, + ) -> None: + super().__init__() + self.datapipe = datapipe + _check_unpickable_fn(fn) + self.fn = fn # type: ignore[assignment] + + def __len__(self) -> int: + return len(self.datapipe) + + def __getitem__(self, index) -> _T_co: + return self.fn(self.datapipe[index]) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combinatorics.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combinatorics.py new file mode 100644 index 0000000000000000000000000000000000000000..209945a1ab3ca540025bfdffbc9bacd983dc7f03 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combinatorics.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +import random +from collections.abc import Iterator +from typing import Optional, TypeVar + +import torch +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +__all__ = ["ShufflerIterDataPipe"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +# @functional_datapipe('shuffle') +class ShufflerIterDataPipe(IterDataPipe[_T_co]): + r""" + Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``). + + When it is used with :class:`~torch.utils.data.DataLoader`, the methods to + set up random seed are different based on :attr:`num_workers`. + + For single-process mode (:attr:`num_workers == 0`), the random seed is set before + the :class:`~torch.utils.data.DataLoader` in the main process. 
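# Editorial sketch: lambdas are discouraged for ``map`` because they do not pickle;
# functools.partial over a named or builtin callable is the usual picklable alternative.
# Assumes the SequenceWrapper and Mapper classes added in this patch.
import operator
from functools import partial

from torch.utils.data.datapipes.map import SequenceWrapper

dp = SequenceWrapper(range(5)).map(partial(operator.mul, 3))
assert [dp[i] for i in range(len(dp))] == [0, 3, 6, 9, 12]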
For multi-process + mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed + for each worker process. + + Args: + datapipe: MapDataPipe being shuffled + indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> shuffle_dp = dp.shuffle().set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + >>> list(shuffle_dp) + [6, 1, 9, 5, 2, 4, 7, 3, 8, 0] + >>> # Reset seed for Shuffler + >>> shuffle_dp = shuffle_dp.set_seed(0) + >>> list(shuffle_dp) + [7, 8, 1, 5, 3, 4, 2, 0, 9, 6] + + Note: + Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an + ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to + the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order + of data during data-processing. + """ + + datapipe: MapDataPipe[_T_co] + _enabled: bool + _seed: Optional[int] + _rng: random.Random + + def __init__( + self, + datapipe: MapDataPipe[_T_co], + *, + indices: Optional[list] = None, + ) -> None: + super().__init__() + self.datapipe = datapipe + self.indices = list(range(len(datapipe))) if indices is None else indices + self._enabled = True + self._seed = None + self._rng = random.Random() + self._shuffled_indices: list = self.indices + + def set_shuffle(self, shuffle=True): + self._enabled = shuffle + return self + + def set_seed(self, seed: int): + self._seed = seed + return self + + def __iter__(self) -> Iterator[_T_co]: + if not self._enabled: + for idx in self.indices: + yield self.datapipe[idx] + else: + while self._shuffled_indices: + idx = self._shuffled_indices.pop() + yield self.datapipe[idx] + + def reset(self) -> None: + if self._enabled and self._seed is None: + self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self._rng.seed(self._seed) + self._seed = None + self._shuffled_indices = self._rng.sample(self.indices, len(self.indices)) + + def __len__(self) -> int: + return len(self.datapipe) + + def __getstate__(self): + state = ( + self.datapipe, + self.indices, + self._enabled, + self._seed, + self._rng.getstate(), + self._shuffled_indices, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(state) + return state + + def __setstate__(self, state): + ( + self.datapipe, + self.indices, + self._enabled, + self._seed, + rng_state, + self._shuffled_indices, + self._valid_iterator_id, + self._number_of_samples_yielded, + ) = state + self._rng = random.Random() + self._rng.setstate(rng_state) + + +MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combining.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combining.py new file mode 100644 index 0000000000000000000000000000000000000000..28e4c91ca5252c61c46f5fd00513e9a806a4a1f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/combining.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import MapDataPipe + + +__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"] + +_T_co = 
TypeVar("_T_co", covariant=True) + + +@functional_datapipe("concat") +class ConcaterMapDataPipe(MapDataPipe): + r""" + Concatenate multiple Map DataPipes (functional name: ``concat``). + + The new index of is the cumulative sum of source DataPipes. + For example, if there are 2 source DataPipes both with length 5, + index 0 to 4 of the resulting `ConcatMapDataPipe` would refer to + elements of the first DataPipe, and 5 to 9 would refer to elements + of the second DataPipe. + + Args: + datapipes: Map DataPipes being concatenated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(3)) + >>> concat_dp = dp1.concat(dp2) + >>> list(concat_dp) + [0, 1, 2, 0, 1, 2] + """ + + datapipes: tuple[MapDataPipe] + + def __init__(self, *datapipes: MapDataPipe): + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, MapDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `MapDataPipe`") + if not all(isinstance(dp, Sized) for dp in datapipes): + raise TypeError("Expected all inputs to be `Sized`") + self.datapipes = datapipes # type: ignore[assignment] + + def __getitem__(self, index) -> _T_co: # type: ignore[type-var] + offset = 0 + for dp in self.datapipes: + if index - offset < len(dp): + return dp[index - offset] + else: + offset += len(dp) + raise IndexError(f"Index {index} is out of range.") + + def __len__(self) -> int: + return sum(len(dp) for dp in self.datapipes) + + +@functional_datapipe("zip") +class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]): + r""" + Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``). + + This MataPipe is out of bound as soon as the shortest input DataPipe is exhausted. + + Args: + *datapipes: Map DataPipes being aggregated + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp1 = SequenceWrapper(range(3)) + >>> dp2 = SequenceWrapper(range(10, 13)) + >>> zip_dp = dp1.zip(dp2) + >>> list(zip_dp) + [(0, 10), (1, 11), (2, 12)] + """ + + datapipes: tuple[MapDataPipe[_T_co], ...] + + def __init__(self, *datapipes: MapDataPipe[_T_co]) -> None: + if len(datapipes) == 0: + raise ValueError("Expected at least one DataPipe, but got nothing") + if not all(isinstance(dp, MapDataPipe) for dp in datapipes): + raise TypeError("Expected all inputs to be `MapDataPipe`") + if not all(isinstance(dp, Sized) for dp in datapipes): + raise TypeError("Expected all inputs to be `Sized`") + self.datapipes = datapipes + + def __getitem__(self, index) -> tuple[_T_co, ...]: + res = [] + for dp in self.datapipes: + try: + res.append(dp[index]) + except IndexError as e: + raise IndexError( + f"Index {index} is out of range for one of the input MapDataPipes {dp}." 
+ ) from e + return tuple(res) + + def __len__(self) -> int: + return min(len(dp) for dp in self.datapipes) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/grouping.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..167b44fad1e18c913a5279cad1211e97c9a70fb0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/grouping.py @@ -0,0 +1,74 @@ +# mypy: allow-untyped-defs +from collections.abc import Sized +from typing import TypeVar + +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe + + +__all__ = ["BatcherMapDataPipe"] + + +_T = TypeVar("_T") + + +@functional_datapipe("batch") +class BatcherMapDataPipe(MapDataPipe[DataChunk]): + r""" + Create mini-batches of data (functional name: ``batch``). + + An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, + or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. + + Args: + datapipe: Iterable DataPipe being batched + batch_size: The size of each batch + drop_last: Option to drop the last batch if it's not full + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> batch_dp = dp.batch(batch_size=2) + >>> list(batch_dp) + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]] + """ + + datapipe: MapDataPipe + batch_size: int + drop_last: bool + + def __init__( + self, + datapipe: MapDataPipe[_T], + batch_size: int, + drop_last: bool = False, + wrapper_class: type[DataChunk] = DataChunk, + ) -> None: + assert batch_size > 0, "Batch size is required to be larger than 0!" + super().__init__() + self.datapipe = datapipe + self.batch_size = batch_size + self.drop_last = drop_last + self.wrapper_class = wrapper_class + + def __getitem__(self, index) -> DataChunk: + batch: list = [] + indices = range(index * self.batch_size, (index + 1) * self.batch_size) + try: + batch.extend(self.datapipe[i] for i in indices) + return self.wrapper_class(batch) + except IndexError as e: + if not self.drop_last and len(batch) > 0: + return self.wrapper_class(batch) + else: + raise IndexError(f"Index {index} is out of bound.") from e + + def __len__(self) -> int: + if isinstance(self.datapipe, Sized): + if self.drop_last: + return len(self.datapipe) // self.batch_size + else: + return (len(self.datapipe) + self.batch_size - 1) // self.batch_size + else: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/utils.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8be9047c3c456ce6d57a63a49c6494c8d49795bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/map/utils.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs +import copy +import warnings + +from torch.utils.data.datapipes.datapipe import MapDataPipe + + +__all__ = ["SequenceWrapperMapDataPipe"] + + +class SequenceWrapperMapDataPipe(MapDataPipe): + r""" + Wraps a sequence object into a MapDataPipe. + + Args: + sequence: Sequence object to be wrapped into an MapDataPipe + deepcopy: Option to deepcopy input sequence object + + .. 
note:: + If ``deepcopy`` is set to False explicitly, users should ensure + that data pipeline doesn't contain any in-place operations over + the iterable instance, in order to prevent data inconsistency + across iterations. + + Example: + >>> # xdoctest: +SKIP + >>> from torchdata.datapipes.map import SequenceWrapper + >>> dp = SequenceWrapper(range(10)) + >>> list(dp) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) + >>> dp['a'] + 100 + """ + + def __init__(self, sequence, deepcopy=True): + if deepcopy: + try: + self.sequence = copy.deepcopy(sequence) + except TypeError: + warnings.warn( + "The input sequence can not be deepcopied, " + "please be aware of in-place modification would affect source data" + ) + self.sequence = sequence + else: + self.sequence = sequence + + def __getitem__(self, index): + return self.sequence[index] + + def __len__(self): + return len(self.sequence) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__init__.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb26aacce0663eb7b6daab7c3762772ac44a51c3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87c140c56bef9205b087c6ec2ca16d0942d849a9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64284b933355b5c83824c85dab63f8271b781632 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5e19680150ba5c4f0c9606825555433a7db7192 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/common.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0454d6bb0f0239a63d8ca6d2f8f6811f9a3a38ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/common.py @@ -0,0 +1,412 @@ +# mypy: allow-untyped-defs +import fnmatch +import functools +import inspect +import os +import warnings +from collections.abc import Iterable +from io import IOBase +from typing import Any, 
Callable, Optional, Union + +from torch.utils._import_utils import dill_available + + +__all__ = [ + "validate_input_col", + "StreamWrapper", + "get_file_binaries_from_pathnames", + "get_file_pathnames_from_root", + "match_masks", + "validate_pathname_binary_tuple", +] + + +# BC for torchdata +DILL_AVAILABLE = dill_available() + + +def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]): + """ + Check that function used in a callable datapipe works with the input column. + + This simply ensures that the number of positional arguments matches the size + of the input column. The function must not contain any non-default + keyword-only arguments. + + Examples: + >>> # xdoctest: +SKIP("Failing on some CI machines") + >>> def f(a, b, *, c=1): + >>> return a + b + c + >>> def f_def(a, b=1, *, c=1): + >>> return a + b + c + >>> assert validate_input_col(f, [1, 2]) + >>> assert validate_input_col(f_def, 1) + >>> assert validate_input_col(f_def, [1, 2]) + + Notes: + If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments, + for example, f(a, *args), the validator will accept any size of input column + greater than or equal to the number of positional arguments. + (in this case, 1). + + Args: + fn: The function to check. + input_col: The input column to check. + + Raises: + ValueError: If the function is not compatible with the input column. + """ + try: + sig = inspect.signature(fn) + except ( + ValueError + ): # Signature cannot be inspected, likely it is a built-in fn or written in C + return + if isinstance(input_col, (list, tuple)): + input_col_size = len(input_col) + else: + input_col_size = 1 + + pos = [] + var_positional = False + non_default_kw_only = [] + + for p in sig.parameters.values(): + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + pos.append(p) + elif p.kind is inspect.Parameter.VAR_POSITIONAL: + var_positional = True + elif p.kind is inspect.Parameter.KEYWORD_ONLY: + if p.default is p.empty: + non_default_kw_only.append(p) + else: + continue + + if isinstance(fn, functools.partial): + fn_name = getattr(fn.func, "__name__", repr(fn.func)) + else: + fn_name = getattr(fn, "__name__", repr(fn)) + + if len(non_default_kw_only) > 0: + raise ValueError( + f"The function {fn_name} takes {len(non_default_kw_only)} " + f"non-default keyword-only parameters, which is not allowed." + ) + + if len(sig.parameters) < input_col_size: + if not var_positional: + raise ValueError( + f"The function {fn_name} takes {len(sig.parameters)} " + f"parameters, but {input_col_size} are required." + ) + else: + if len(pos) > input_col_size: + if any(p.default is p.empty for p in pos[input_col_size:]): + raise ValueError( + f"The function {fn_name} takes {len(pos)} " + f"positional parameters, but {input_col_size} are required." + ) + elif len(pos) < input_col_size: + if not var_positional: + raise ValueError( + f"The function {fn_name} takes {len(pos)} " + f"positional parameters, but {input_col_size} are required." + ) + + +def _is_local_fn(fn): + # Functions or Methods + if hasattr(fn, "__code__"): + return fn.__code__.co_flags & inspect.CO_NESTED + # Callable Objects + else: + if hasattr(fn, "__qualname__"): + return "" in fn.__qualname__ + fn_type = type(fn) + if hasattr(fn_type, "__qualname__"): + return "" in fn_type.__qualname__ + return False + + +def _check_unpickable_fn(fn: Callable): + """ + Check function is pickable or not. 
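# Editorial sketch: what ``input_col`` means for the callable datapipes checked by
# ``validate_input_col`` above. With input_col=0 only the first field of each sample is
# passed to the callable, whose arity is validated against the number of selected columns;
# the element that is kept/yielded is still the whole sample. Assumes the filter datapipe
# added earlier in this patch.
from torch.utils.data.datapipes.iter import IterableWrapper


def is_small(n):
    return n < 2


rows = IterableWrapper([(0, "a"), (1, "b"), (2, "c")])
small = rows.filter(filter_fn=is_small, input_col=0)
assert list(small) == [(0, "a"), (1, "b")]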
+ + If it is a lambda or local function, a UserWarning will be raised. If it's not a callable function, a TypeError will be raised. + """ + if not callable(fn): + raise TypeError(f"A callable function is expected, but {type(fn)} is provided.") + + # Extract function from partial object + # Nested partial function is automatically expanded as a single partial object + if isinstance(fn, functools.partial): + fn = fn.func + + # Local function + if _is_local_fn(fn) and not dill_available(): + warnings.warn( + "Local function is not supported by pickle, please use " + "regular python function or functools.partial instead." + ) + return + + # Lambda function + if hasattr(fn, "__name__") and fn.__name__ == "" and not dill_available(): + warnings.warn( + "Lambda function is not supported by pickle, please use " + "regular python function or functools.partial instead." + ) + return + + +def match_masks(name: str, masks: Union[str, list[str]]) -> bool: + # empty mask matches any input name + if not masks: + return True + + if isinstance(masks, str): + return fnmatch.fnmatch(name, masks) + + for mask in masks: + if fnmatch.fnmatch(name, mask): + return True + return False + + +def get_file_pathnames_from_root( + root: str, + masks: Union[str, list[str]], + recursive: bool = False, + abspath: bool = False, + non_deterministic: bool = False, +) -> Iterable[str]: + # print out an error message and raise the error out + def onerror(err: OSError): + warnings.warn(err.filename + " : " + err.strerror) + raise err + + if os.path.isfile(root): + path = root + if abspath: + path = os.path.abspath(path) + fname = os.path.basename(path) + if match_masks(fname, masks): + yield path + else: + for path, dirs, files in os.walk(root, onerror=onerror): + if abspath: + path = os.path.abspath(path) + if not non_deterministic: + files.sort() + for f in files: + if match_masks(f, masks): + yield os.path.join(path, f) + if not recursive: + break + if not non_deterministic: + # Note that this is in-place modifying the internal list from `os.walk` + # This only works because `os.walk` doesn't shallow copy before turn + # https://github.com/python/cpython/blob/f4c03484da59049eb62a9bf7777b963e2267d187/Lib/os.py#L407 + dirs.sort() + + +def get_file_binaries_from_pathnames( + pathnames: Iterable, mode: str, encoding: Optional[str] = None +): + if not isinstance(pathnames, Iterable): + pathnames = [ + pathnames, + ] + + if mode in ("b", "t"): + mode = "r" + mode + + for pathname in pathnames: + if not isinstance(pathname, str): + raise TypeError( + f"Expected string type for pathname, but got {type(pathname)}" + ) + yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) + + +def validate_pathname_binary_tuple(data: tuple[str, IOBase]): + if not isinstance(data, tuple): + raise TypeError( + f"pathname binary data should be tuple type, but it is type {type(data)}" + ) + if len(data) != 2: + raise TypeError( + f"pathname binary stream tuple length should be 2, but got {len(data)}" + ) + if not isinstance(data[0], str): + raise TypeError( + f"pathname within the tuple should have string type pathname, but it is type {type(data[0])}" + ) + if not isinstance(data[1], IOBase) and not isinstance(data[1], StreamWrapper): + raise TypeError( + f"binary stream within the tuple should have IOBase or" + f"its subclasses as type, but it is type {type(data[1])}" + ) + + +# Deprecated function names and its corresponding DataPipe type and kwargs for the `_deprecation_warning` function +_iter_deprecated_functional_names: dict[str, 
dict] = {} +_map_deprecated_functional_names: dict[str, dict] = {} + + +def _deprecation_warning( + old_class_name: str, + *, + deprecation_version: str, + removal_version: str, + old_functional_name: str = "", + old_argument_name: str = "", + new_class_name: str = "", + new_functional_name: str = "", + new_argument_name: str = "", + deprecate_functional_name_only: bool = False, +) -> None: + if new_functional_name and not old_functional_name: + raise ValueError( + "Old functional API needs to be specified for the deprecation warning." + ) + if new_argument_name and not old_argument_name: + raise ValueError( + "Old argument name needs to be specified for the deprecation warning." + ) + + if old_functional_name and old_argument_name: + raise ValueError( + "Deprecating warning for functional API and argument should be separated." + ) + + msg = f"`{old_class_name}()`" + if deprecate_functional_name_only and old_functional_name: + msg = f"{msg}'s functional API `.{old_functional_name}()` is" + elif old_functional_name: + msg = f"{msg} and its functional API `.{old_functional_name}()` are" + elif old_argument_name: + msg = f"The argument `{old_argument_name}` of {msg} is" + else: + msg = f"{msg} is" + msg = ( + f"{msg} deprecated since {deprecation_version} and will be removed in {removal_version}." + f"\nSee https://github.com/pytorch/data/issues/163 for details." + ) + + if new_class_name or new_functional_name: + msg = f"{msg}\nPlease use" + if new_class_name: + msg = f"{msg} `{new_class_name}()`" + if new_class_name and new_functional_name: + msg = f"{msg} or" + if new_functional_name: + msg = f"{msg} `.{new_functional_name}()`" + msg = f"{msg} instead." + + if new_argument_name: + msg = f"{msg}\nPlease use `{old_class_name}({new_argument_name}=)` instead." + + warnings.warn(msg, FutureWarning) + + +class StreamWrapper: + """ + StreamWrapper is introduced to wrap file handler generated by DataPipe operation like `FileOpener`. + + StreamWrapper would guarantee the wrapped file handler is closed when it's out of scope. 
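# Editorial sketch of the StreamWrapper behaviour described above: attribute access is
# proxied to the wrapped file object, and ``close_streams`` walks plain containers to
# close any wrappers it finds on a best-effort basis.
import io

from torch.utils.data.datapipes.utils.common import StreamWrapper

s = StreamWrapper(io.BytesIO(b"abc"), name="demo")
assert s.read() == b"abc"                    # forwarded to BytesIO.read via __getattr__
StreamWrapper.close_streams({"payload": s})  # finds and closes the wrapper inside the dict
assert s.closed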
+ """ + + session_streams: dict[Any, int] = {} + debug_unclosed_streams: bool = False + + def __init__(self, file_obj, parent_stream=None, name=None): + self.file_obj = file_obj + self.child_counter = 0 + self.parent_stream = parent_stream + self.close_on_last_child = False + self.name = name + self.closed = False + if parent_stream is not None: + if not isinstance(parent_stream, StreamWrapper): + raise RuntimeError( + f"Parent stream should be StreamWrapper, {type(parent_stream)} was given" + ) + parent_stream.child_counter += 1 + self.parent_stream = parent_stream + if StreamWrapper.debug_unclosed_streams: + StreamWrapper.session_streams[self] = 1 + + @classmethod + def close_streams(cls, v, depth=0): + """Traverse structure and attempts to close all found StreamWrappers on best effort basis.""" + if depth > 10: + return + if isinstance(v, StreamWrapper): + v.close() + else: + # Traverse only simple structures + if isinstance(v, dict): + for vv in v.values(): + cls.close_streams(vv, depth=depth + 1) + elif isinstance(v, (list, tuple)): + for vv in v: + cls.close_streams(vv, depth=depth + 1) + + def __getattr__(self, name): + file_obj = self.__dict__["file_obj"] + return getattr(file_obj, name) + + def close(self, *args, **kwargs): + if self.closed: + return + if StreamWrapper.debug_unclosed_streams: + del StreamWrapper.session_streams[self] + if hasattr(self, "parent_stream") and self.parent_stream is not None: + self.parent_stream.child_counter -= 1 + if ( + not self.parent_stream.child_counter + and self.parent_stream.close_on_last_child + ): + self.parent_stream.close() + try: + self.file_obj.close(*args, **kwargs) + except AttributeError: + pass + self.closed = True + + def autoclose(self): + """Automatically close stream when all child streams are closed or if there are none.""" + self.close_on_last_child = True + if self.child_counter == 0: + self.close() + + def __dir__(self): + attrs = list(self.__dict__.keys()) + list(StreamWrapper.__dict__.keys()) + attrs += dir(self.file_obj) + return list(set(attrs)) + + def __del__(self): + if not self.closed: + self.close() + + def __iter__(self): + yield from self.file_obj + + def __next__(self): + return next(self.file_obj) + + def __repr__(self): + if self.name is None: + return f"StreamWrapper<{self.file_obj!r}>" + else: + return f"StreamWrapper<{self.name},{self.file_obj!r}>" + + def __getstate__(self): + return self.file_obj + + def __setstate__(self, obj): + self.file_obj = obj diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/decoder.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e4872ada7b998179dd68fca8284de3ecdbd4ccef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/decoder.py @@ -0,0 +1,378 @@ +# mypy: allow-untyped-defs +# This file takes partial of the implementation from NVIDIA's webdataset at here: +# https://github.com/tmbdev/webdataset/blob/master/webdataset/autodecode.py + +import io +import json +import os.path +import pickle +import tempfile + +import torch +from torch.utils.data.datapipes.utils.common import StreamWrapper + + +__all__ = [ + "Decoder", + "ImageHandler", + "MatHandler", + "audiohandler", + "basichandlers", + "extension_extract_fn", + "handle_extension", + "imagehandler", + "mathandler", + "videohandler", +] + + +################################################################ +# handle basic datatypes 
+################################################################ +def basichandlers(extension: str, data): + """Transforms raw data (byte stream) into python objects. + + Looks at the extension and loads the data into a python object supporting + the corresponding extension. + + Args: + extension (str): The file extension + data (byte stream): Data to load into a python object. + + Returns: + object: The data loaded into a corresponding python object + supporting the extension. + + Example: + >>> import pickle + >>> data = pickle.dumps('some data') + >>> new_data = basichandlers('pickle', data) + >>> new_data + some data + + The transformation of data for extensions are: + - txt, text, transcript: utf-8 decoded data of str format + - cls, cls2, class, count, index, inx, id: int + - json, jsn: json loaded data + - pickle, pyd: pickle loaded data + - pt: torch loaded data + """ + + if extension in "txt text transcript": + return data.decode("utf-8") + + if extension in "cls cls2 class count index inx id".split(): + try: + return int(data) + except ValueError: + return None + + if extension in "json jsn": + return json.loads(data) + + if extension in "pyd pickle".split(): + return pickle.loads(data) + + if extension in "pt".split(): + stream = io.BytesIO(data) + return torch.load(stream) + + # if extension in "ten tb".split(): + # from . import tenbin + # return tenbin.decode_buffer(data) + + # if extension in "mp msgpack msg".split(): + # import msgpack + # return msgpack.unpackb(data) + + return None + + +################################################################ +# handle images +################################################################ +imagespecs = { + "l8": ("numpy", "uint8", "l"), + "rgb8": ("numpy", "uint8", "rgb"), + "rgba8": ("numpy", "uint8", "rgba"), + "l": ("numpy", "float", "l"), + "rgb": ("numpy", "float", "rgb"), + "rgba": ("numpy", "float", "rgba"), + "torchl8": ("torch", "uint8", "l"), + "torchrgb8": ("torch", "uint8", "rgb"), + "torchrgba8": ("torch", "uint8", "rgba"), + "torchl": ("torch", "float", "l"), + "torchrgb": ("torch", "float", "rgb"), + "torch": ("torch", "float", "rgb"), + "torchrgba": ("torch", "float", "rgba"), + "pill": ("pil", None, "l"), + "pil": ("pil", None, "rgb"), + "pilrgb": ("pil", None, "rgb"), + "pilrgba": ("pil", None, "rgba"), +} + + +def handle_extension(extensions, f): + """ + Return a decoder handler function for the list of extensions. + + Extensions can be a space separated list of extensions. + Extensions can contain dots, in which case the corresponding number + of extension components must be present in the key given to f. + Comparisons are case insensitive. + Examples: + handle_extension("jpg jpeg", my_decode_jpg) # invoked for any file.jpg + handle_extension("seg.jpg", special_case_jpg) # invoked only for file.seg.jpg + """ + extensions = extensions.lower().split() + + def g(key, data): + extension = key.lower().split(".") + + for target in extensions: + target = target.split(".") + if len(target) > len(extension): + continue + + if extension[-len(target) :] == target: + return f(data) + return None + + return g + + +class ImageHandler: + """ + Decode image data using the given `imagespec`. 
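# Editorial sketch of the two decoder helpers above: ``basichandlers`` dispatches on the
# extension string, while ``handle_extension`` wraps any callable in a case-insensitive
# suffix matcher.
from torch.utils.data.datapipes.utils.decoder import basichandlers, handle_extension

assert basichandlers("cls", b"3") == 3
assert basichandlers("json", b'{"a": 1}') == {"a": 1}

upper = handle_extension("txt", lambda data: data.decode("utf-8").upper())
assert upper("notes.txt", b"hello") == "HELLO"
assert upper("img.png", b"...") is None  # non-matching extension falls through to None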
+ + The `imagespec` specifies whether the image is decoded + to numpy/torch/pi, decoded to uint8/float, and decoded + to l/rgb/rgba: + + - l8: numpy uint8 l + - rgb8: numpy uint8 rgb + - rgba8: numpy uint8 rgba + - l: numpy float l + - rgb: numpy float rgb + - rgba: numpy float rgba + - torchl8: torch uint8 l + - torchrgb8: torch uint8 rgb + - torchrgba8: torch uint8 rgba + - torchl: torch float l + - torchrgb: torch float rgb + - torch: torch float rgb + - torchrgba: torch float rgba + - pill: pil None l + - pil: pil None rgb + - pilrgb: pil None rgb + - pilrgba: pil None rgba + """ + + def __init__(self, imagespec): + assert imagespec in list( + imagespecs.keys() + ), f"unknown image specification: {imagespec}" + self.imagespec = imagespec.lower() + + def __call__(self, extension, data): + if extension.lower() not in "jpg jpeg png ppm pgm pbm pnm".split(): + return None + + try: + import numpy as np + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Package `numpy` is required to be installed for default image decoder." + "Please use `pip install numpy` to install the package" + ) from e + + try: + import PIL.Image + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Package `PIL` is required to be installed for default image decoder." + "Please use `pip install Pillow` to install the package" + ) from e + + imagespec = self.imagespec + atype, etype, mode = imagespecs[imagespec] + + with io.BytesIO(data) as stream: + img = PIL.Image.open(stream) + img.load() + img = img.convert(mode.upper()) + if atype == "pil": + return img + elif atype == "numpy": + result = np.asarray(img) + assert ( + result.dtype == np.uint8 + ), f"numpy image array should be type uint8, but got {result.dtype}" + if etype == "uint8": + return result + else: + return result.astype("f") / 255.0 + elif atype == "torch": + result = np.asarray(img) + assert ( + result.dtype == np.uint8 + ), f"numpy image array should be type uint8, but got {result.dtype}" + + if etype == "uint8": + result = np.array(result.transpose(2, 0, 1)) + return torch.tensor(result) + else: + result = np.array(result.transpose(2, 0, 1)) + return torch.tensor(result) / 255.0 + return None + + +def imagehandler(imagespec): + return ImageHandler(imagespec) + + +################################################################ +# torch video +################################################################ +def videohandler(extension, data): + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + try: + import torchvision.io + except ImportError as e: + raise ModuleNotFoundError( + "Package `torchvision` is required to be installed for default video file loader." + "Please use `pip install torchvision`" + "to install the package" + ) from e + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchvision.io.read_video(fname) + + +################################################################ +# torchaudio +################################################################ +def audiohandler(extension, data): + if extension not in ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma"]: + return None + + try: + import torchaudio # type: ignore[import] + except ImportError as e: + raise ModuleNotFoundError( + "Package `torchaudio` is required to be installed for default audio file loader." 
+ "Please use `pip install torchaudio`" + "to install the package" + ) from e + + with tempfile.TemporaryDirectory() as dirname: + fname = os.path.join(dirname, f"file.{extension}") + with open(fname, "wb") as stream: + stream.write(data) + return torchaudio.load(fname) + + +################################################################ +# mat +################################################################ +class MatHandler: + def __init__(self, **loadmat_kwargs) -> None: + try: + import scipy.io as sio + except ImportError as e: + raise ModuleNotFoundError( + "Package `scipy` is required to be installed for mat file." + "Please use `pip install scipy`" + "to install the package" + ) from e + self.sio = sio + self.loadmat_kwargs = loadmat_kwargs + + def __call__(self, extension, data): + if extension != "mat": + return None + with io.BytesIO(data) as stream: + return self.sio.loadmat(stream, **self.loadmat_kwargs) + + +def mathandler(**loadmat_kwargs): + return MatHandler(**loadmat_kwargs) + + +################################################################ +# a sample decoder +################################################################ +# Extract extension from pathname +def extension_extract_fn(pathname): + ext = os.path.splitext(pathname)[1] + # Remove dot + if ext: + ext = ext[1:] + return ext + + +class Decoder: + """ + Decode key/data sets using a list of handlers. + + For each key/data item, this iterates through the list of + handlers until some handler returns something other than None. + """ + + def __init__(self, *handler, key_fn=extension_extract_fn): + self.handlers = list(handler) if handler else [] + self.key_fn = key_fn + + # Insert new handler from the beginning of handlers list to make sure the new + # handler having the highest priority + def add_handler(self, *handler): + if not handler: + return + self.handlers = list(handler) + self.handlers + + @staticmethod + def _is_stream_handle(data): + obj_to_check = data.file_obj if isinstance(data, StreamWrapper) else data + return isinstance(obj_to_check, (io.BufferedIOBase, io.RawIOBase)) + + def decode1(self, key, data): + if not data: + return data + + # if data is a stream handle, we need to read all the content before decoding + if Decoder._is_stream_handle(data): + ds = data + # The behavior of .read can differ between streams (e.g. HTTPResponse), hence this is used instead + data = b"".join(data) + ds.close() + + for f in self.handlers: + result = f(key, data) + if result is not None: + return result + return data + + def decode(self, data): + result = {} + # single data tuple(pathname, data stream) + if isinstance(data, tuple): + data = [data] + + if data is not None: + for k, v in data: + # TODO: xinyu, figure out why Nvidia do this? 
+ if k[0] == "_": + if isinstance(v, bytes): + v = v.decode("utf-8") + result[k] = v + continue + result[k] = self.decode1(self.key_fn(k), v) + return result + + def __call__(self, data): + return self.decode(data) diff --git a/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/snapshot.py b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/snapshot.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6f98248d395af516f69de313105a161fc9cff6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/datapipes/utils/snapshot.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +from torch.utils.data.datapipes._hook_iterator import _SnapshotState +from torch.utils.data.datapipes.datapipe import IterDataPipe +from torch.utils.data.graph_settings import apply_random_seed + + +# TODO: Caveats +# 1. Caller (either the ReadingService or DataLoader) must pass in the initial RNG +# 2. `in_batch_shuffle` and `bucketbatch` are not compatible with this because they currently +# lack the option to `set_seed`. +def _simple_graph_snapshot_restoration( + datapipe: IterDataPipe, n_iterations: int, rng=None +) -> None: + r""" + Fast-forward the given DataPipe and its parents by ``n_iterations``, re-doing computations to restore a snapshot. + + For instance, applying this function to the final DataPipe of a graph will restore the snapshot + (via fast-forward) every DataPipe within the graph. + + After you deserialize a DataPipe, you can use its `_number_of_samples_yielded` attribute as the input + to this function to forward the DataPipe. + + A DataPipe cannot be restored twice in a row unless there is an iteration started between the restoration + attempts. + + Note: + This is the simplest but least efficient way to fast-forward a DataPipe. Usage of other fast-forwarding + methods (custom ones if necessary) are recommended. + + Args: + datapipe: IterDataPipe to be fast-forwarded + n_iterations: number of iterations to fast-forward + rng: ``Optional[torch.Generator]``. If not ``None``, this RNG will be used for shuffling. The generator + should be in its `initial` state as it was first passed into ``DataLoader`` or ``ReadingService``. + """ + if datapipe._snapshot_state == _SnapshotState.Restored: + raise RuntimeError( + "Snapshot restoration cannot be applied. You can only restore simple snapshot to the graph " + "if your graph has not been restored." + ) + + # For this snapshot restoration function, we want the DataPipe to be at its initial state prior to + # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`, + # the first reset will not actually reset. + datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`. + apply_random_seed(datapipe, rng) + + remainder = n_iterations + it = iter(datapipe) # This always reset the DataPipe if it hasn't already. + while remainder > 0: + try: + next(it) + remainder -= 1 + except StopIteration as e: + raise RuntimeError( + f"Fast-forward {datapipe} by {n_iterations} iterations " + "exceeds the number of samples available." + ) from e + datapipe._fast_forward_iterator = it + # While the DataPipe has `_fast_forward_iterator`, `next()` will get result from there instead of elsewhere. 
+ + # This will prevent the DataPipe from resetting in the `iter()` call + # If another DataPipe is consuming it, it won't have to start over again + datapipe._snapshot_state = _SnapshotState.Restored diff --git a/phivenv/Lib/site-packages/torch/utils/data/dataset.py b/phivenv/Lib/site-packages/torch/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bb8feb8a8b660a31a4ae4420aa80df443eb30f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/dataset.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +import bisect +import itertools +import math +import warnings +from collections.abc import Sequence + +# UP006 wants 'Iterable' to be imported from collections.abc but it needs to +# stay from typing for now due to BC concerns. In particular several internal +# targets fail to typecheck with: +# TypeError: Cannot create a consistent method resolution order (MRO) for +# bases Iterable, Generic +from typing import cast, Generic, Iterable, Optional, TypeVar, Union # noqa: UP035 +from typing_extensions import deprecated + +# No 'default_generator' in torch/__init__.pyi +from torch import default_generator, Generator, randperm, Tensor + + +__all__ = [ + "Dataset", + "IterableDataset", + "TensorDataset", + "StackDataset", + "ConcatDataset", + "ChainDataset", + "Subset", + "random_split", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_dict = dict[str, _T_co] +_T_tuple = tuple[_T_co, ...] +_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict) + + +class Dataset(Generic[_T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. Subclasses could also + optionally implement :meth:`__getitems__`, for speedup batched samples + loading. This method accepts list of indices of samples of batch and returns + list of samples. + + .. note:: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> _T_co: + raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") + + # def __getitems__(self, indices: List) -> List[_T_co]: + # Not implemented to prevent false-positives in fetcher check in + # torch.utils.data._utils.fetch._MapDatasetFetcher + + def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": + return ConcatDataset([self, other]) + + # No `def __len__(self)` default? + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # in pytorch/torch/utils/data/sampler.py + + +class IterableDataset(Dataset[_T_co], Iterable[_T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. 
+ + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> # xdoctest: +SKIP("Fails on MacOS12") + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... worker_info = torch.utils.data.get_worker_info() + ... if worker_info is None: # single-process data loading, return the full iterator + ... iter_start = self.start + ... iter_end = self.end + ... else: # in a worker process + ... # split workload + ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... iter_start = self.start + worker_id * per_worker + ... iter_end = min(iter_start + per_worker, self.end) + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [tensor([3]), tensor([4]), tensor([5]), tensor([6])] + + >>> # xdoctest: +REQUIRES(POSIX) + >>> # Multi-process loading with two worker processes + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + >>> # With even more workers + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + >>> + >>> # Directly doing multi-process loading yields duplicate data + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [3, 3, 4, 4, 5, 5, 6, 6] + + >>> # Define a `worker_init_fn` that configures each dataset copy differently + >>> def worker_init_fn(worker_id): + ... 
worker_info = torch.utils.data.get_worker_info() + ... dataset = worker_info.dataset # the dataset copy in this worker process + ... overall_start = dataset.start + ... overall_end = dataset.end + ... # configure the dataset to only process the split workload + ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... dataset.start = overall_start + worker_id * per_worker + ... dataset.end = min(dataset.start + per_worker, overall_end) + ... + + >>> # Mult-process loading with the custom `worker_init_fn` + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) + [3, 4, 5, 6] + """ + + def __add__(self, other: Dataset[_T_co]): + return ChainDataset([self, other]) + + # No `def __len__(self)` default? Subclasses raise `TypeError` when needed. + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + + +class TensorDataset(Dataset[tuple[Tensor, ...]]): + r"""Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Args: + *tensors (Tensor): tensors that have the same size of the first dimension. + """ + + tensors: tuple[Tensor, ...] + + def __init__(self, *tensors: Tensor) -> None: + assert all( + tensors[0].size(0) == tensor.size(0) for tensor in tensors + ), "Size mismatch between tensors" + self.tensors = tensors + + def __getitem__(self, index): + return tuple(tensor[index] for tensor in self.tensors) + + def __len__(self): + return self.tensors[0].size(0) + + +class StackDataset(Dataset[_T_stack]): + r"""Dataset as a stacking of multiple datasets. + + This class is useful to assemble different parts of complex input data, given as datasets. + + Example: + >>> # xdoctest: +SKIP + >>> images = ImageDataset() + >>> texts = TextDataset() + >>> tuple_stack = StackDataset(images, texts) + >>> tuple_stack[0] == (images[0], texts[0]) + >>> dict_stack = StackDataset(image=images, text=texts) + >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + + Args: + *args (Dataset): Datasets for stacking returned as tuple. + **kwargs (Dataset): Datasets for stacking returned as dict. + """ + + datasets: Union[tuple, dict] + + def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: + if args: + if kwargs: + raise ValueError( + "Supported either ``tuple``- (via ``args``) or" + "``dict``- (via ``kwargs``) like input/output, but both types are given." + ) + self._length = len(args[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = args + elif kwargs: + tmp = list(kwargs.values()) + self._length = len(tmp[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = kwargs + else: + raise ValueError("At least one dataset should be passed") + + def __getitem__(self, index): + if isinstance(self.datasets, dict): + return {k: dataset[index] for k, dataset in self.datasets.items()} + return tuple(dataset[index] for dataset in self.datasets) + + def __getitems__(self, indices: list): + # add batched sampling support when parent datasets supports it. 
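# Editorial sketch: TensorDataset (defined above) indexes every tensor along dim 0 in
# lockstep, so each sample is a tuple of per-tensor slices.
import torch
from torch.utils.data import TensorDataset

xs = torch.arange(6).reshape(3, 2)  # 3 samples of shape (2,)
ys = torch.tensor([10, 11, 12])
ds = TensorDataset(xs, ys)
assert len(ds) == 3
x1, y1 = ds[1]
assert torch.equal(x1, torch.tensor([2, 3])) and y1.item() == 11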
+ if isinstance(self.datasets, dict): + dict_batch: list[_T_dict] = [{} for _ in indices] + for k, dataset in self.datasets.items(): + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, d_sample in zip(items, dict_batch): + d_sample[k] = data + else: + for idx, d_sample in zip(indices, dict_batch): + d_sample[k] = dataset[idx] + return dict_batch + + # tuple data + list_batch: list[list] = [[] for _ in indices] + for dataset in self.datasets: + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, t_sample in zip(items, list_batch): + t_sample.append(data) + else: + for idx, t_sample in zip(indices, list_batch): + t_sample.append(dataset[idx]) + tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] + return tuple_batch + + def __len__(self): + return self._length + + +class ConcatDataset(Dataset[_T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + + datasets: list[Dataset[_T_co]] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = list(datasets) + assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) + def cummulative_sizes(self): + return self.cumulative_sizes + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. 
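To make the ``ConcatDataset`` index arithmetic above concrete, a small sketch with toy datasets (a flat index is mapped to a member dataset via ``bisect_right`` on ``cumulative_sizes``):

>>> import torch
>>> from torch.utils.data import ConcatDataset, TensorDataset
>>> a = TensorDataset(torch.tensor([10, 11, 12]))
>>> b = TensorDataset(torch.tensor([20, 21]))
>>> cat = ConcatDataset([a, b])
>>> cat.cumulative_sizes
[3, 5]
>>> cat[3]  # index 3 lands in the second dataset at local index 0
(tensor(20),)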
+ + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + yield from d + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + total += len(d) # type: ignore[arg-type] + return total + + +class Subset(Dataset[_T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + dataset: Dataset[_T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.dataset[[self.indices[i] for i in idx]] + return self.dataset[self.indices[idx]] + + def __getitems__(self, indices: list[int]) -> list[_T_co]: + # add batched sampling support when parent dataset supports it. + # see torch.utils.data._utils.fetch._MapDatasetFetcher + if callable(getattr(self.dataset, "__getitems__", None)): + return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] + else: + return [self.dataset[self.indices[idx]] for idx in indices] + + def __len__(self): + return len(self.indices) + + +def random_split( + dataset: Dataset[_T], + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> list[Subset[_T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + + If a list of fractions that sum up to 1 is given, + the lengths will be computed automatically as + floor(frac * len(dataset)) for each fraction provided. + + After computing the lengths, if there are any remainders, 1 count will be + distributed in round-robin fashion to the lengths + until there are no remainders left. + + Optionally fix the generator for reproducible results, e.g.: + + Example: + >>> # xdoctest: +SKIP + >>> generator1 = torch.Generator().manual_seed(42) + >>> generator2 = torch.Generator().manual_seed(42) + >>> random_split(range(10), [3, 7], generator=generator1) + >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths or fractions of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths: list[int] = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int( + math.floor(len(dataset) * frac) # type: ignore[arg-type] + ) + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset." 
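For the fractional-lengths branch of ``random_split`` described above, a sketch of the floor-then-round-robin computation (length-5 dataset, fractions summing to 1, so the floors [2, 2] leave one remainder item that is assigned to the first split):

>>> import torch
>>> from torch.utils.data import random_split
>>> splits = random_split(range(5), [0.5, 0.5], generator=torch.Generator().manual_seed(0))
>>> [len(s) for s in splits]
[3, 2]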
+ ) + + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] + lengths = cast(Sequence[int], lengths) + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(itertools.accumulate(lengths), lengths) + ] diff --git a/phivenv/Lib/site-packages/torch/utils/data/distributed.py b/phivenv/Lib/site-packages/torch/utils/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..944b57d1ead59126be5dff0a0d633b4aaa4358a3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/distributed.py @@ -0,0 +1,150 @@ +import math +from collections.abc import Iterator +from typing import Optional, TypeVar + +import torch +import torch.distributed as dist +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import Sampler + + +__all__ = ["DistributedSampler"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class DistributedSampler(Sampler[_T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... 
train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
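A small sketch of the per-rank subsampling implemented above (dataset of length 10, 3 replicas, ``drop_last=False``, so ``total_size`` is padded to 12 and every rank receives ``ceil(10 / 3) = 4`` indices):

>>> from torch.utils.data.distributed import DistributedSampler
>>> list(DistributedSampler(range(10), num_replicas=3, rank=0, shuffle=False))
[0, 3, 6, 9]
>>> list(DistributedSampler(range(10), num_replicas=3, rank=2, shuffle=False))  # padding wraps around
[2, 5, 8, 1]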
+ """ + self.epoch = epoch diff --git a/phivenv/Lib/site-packages/torch/utils/data/graph.py b/phivenv/Lib/site-packages/torch/utils/data/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..6728b10f39f02dd625166bce0cceadddcdd923c2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/graph.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +import io +import pickle +import warnings +from collections.abc import Collection +from typing import Optional, Union + +from torch.utils._import_utils import dill_available +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +__all__ = ["traverse", "traverse_dps"] + +DataPipe = Union[IterDataPipe, MapDataPipe] +DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]] + + +def _stub_unpickler(): + return "STUB" + + +# TODO(VitalyFedyunin): Make sure it works without dill module installed +def _list_connected_datapipes( + scan_obj: DataPipe, only_datapipe: bool, cache: set[int] +) -> list[DataPipe]: + f = io.BytesIO() + p = pickle.Pickler( + f + ) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is + if dill_available(): + from dill import Pickler as dill_Pickler + + d = dill_Pickler(f) + else: + d = None + + captured_connections = [] + + def getstate_hook(ori_state): + state = None + if isinstance(ori_state, dict): + state = {} + for k, v in ori_state.items(): + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state[k] = v + elif isinstance(ori_state, (tuple, list)): + state = [] # type: ignore[assignment] + for v in ori_state: + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state.append(v) # type: ignore[attr-defined] + elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)): + state = ori_state # type: ignore[assignment] + return state + + def reduce_hook(obj): + if obj == scan_obj or id(obj) in cache: + raise NotImplementedError + else: + captured_connections.append(obj) + # Adding id to remove duplicate DataPipe serialized at the same level + cache.add(id(obj)) + return _stub_unpickler, () + + datapipe_classes: tuple[type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment] + + try: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(reduce_hook) + if only_datapipe: + cls.set_getstate_hook(getstate_hook) + try: + p.dump(scan_obj) + except (pickle.PickleError, AttributeError, TypeError): + if dill_available(): + d.dump(scan_obj) + else: + raise + finally: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(None) + if only_datapipe: + cls.set_getstate_hook(None) + if dill_available(): + from dill import extend as dill_extend + + dill_extend(False) # Undo change to dispatch table + return captured_connections + + +def traverse_dps(datapipe: DataPipe) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + This only looks into the attribute from each DataPipe that is either a + DataPipe and a Python collection object such as ``list``, ``tuple``, + ``set`` and ``dict``. 
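Roughly, for a two-stage pipeline the graph returned by ``traverse_dps`` nests one level per upstream DataPipe. A sketch, assuming the stock ``IterableWrapper`` datapipe is available (it is not defined in this file):

>>> from torch.utils.data.datapipes.iter import IterableWrapper
>>> src = IterableWrapper([1, 2, 3])
>>> dp = src.map(str)
>>> graph = traverse_dps(dp)
>>> # roughly: {id(dp): (dp, {id(src): (src, {})})}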
+ + Args: + datapipe: the end DataPipe of the graph + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe=True, cache=cache) + + +def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + [Deprecated] + When ``only_dataPipe`` is specified as ``True``, it would only look into the + attribute from each DataPipe that is either a DataPipe and a Python collection object + such as ``list``, ``tuple``, ``set`` and ``dict``. + + Note: + This function is deprecated. Please use `traverse_dps` instead. + + Args: + datapipe: the end DataPipe of the graph + only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed. + This argument is deprecating and will be removed after the next release. + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + msg = ( + "`traverse` function and will be removed after 1.13. " + "Please use `traverse_dps` instead." + ) + if not only_datapipe: + msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`." + warnings.warn(msg, FutureWarning) + if only_datapipe is None: + only_datapipe = False + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe, cache) + + +# Add cache here to prevent infinite recursion on DataPipe +def _traverse_helper( + datapipe: DataPipe, only_datapipe: bool, cache: set[int] +) -> DataPipeGraph: + if not isinstance(datapipe, (IterDataPipe, MapDataPipe)): + raise RuntimeError( + f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found" + ) + + dp_id = id(datapipe) + if dp_id in cache: + return {} + cache.add(dp_id) + # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths + items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy()) + d: DataPipeGraph = {dp_id: (datapipe, {})} + for item in items: + # Using cache.copy() here is to prevent recursion on a single path rather than global graph + # Single DataPipe can present multiple times in different paths in graph + d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy())) + return d diff --git a/phivenv/Lib/site-packages/torch/utils/data/graph_settings.py b/phivenv/Lib/site-packages/torch/utils/data/graph_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3a59620f28cb5b7e320b5d5fdc2fe221713be8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/graph_settings.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +import inspect +import warnings +from typing import Any, Optional +from typing_extensions import deprecated + +import torch +from torch.utils.data.datapipes.iter.sharding import ( + _ShardingIterDataPipe, + SHARDING_PRIORITIES, +) +from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps + + +__all__ = [ + "apply_random_seed", + "apply_sharding", + "apply_shuffle_seed", + "apply_shuffle_settings", + "get_all_graph_pipes", +] + + +def get_all_graph_pipes(graph: DataPipeGraph) -> list[DataPipe]: + return _get_all_graph_pipes_helper(graph, set()) + + +def _get_all_graph_pipes_helper( + graph: DataPipeGraph, id_cache: set[int] +) -> list[DataPipe]: + results: list[DataPipe] = 
[] + for dp_id, (datapipe, sub_graph) in graph.items(): + if dp_id in id_cache: + continue + id_cache.add(dp_id) + results.append(datapipe) + results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache)) + return results + + +def _is_sharding_datapipe(datapipe: DataPipe) -> bool: + return isinstance(datapipe, _ShardingIterDataPipe) or ( + hasattr(datapipe, "apply_sharding") + and inspect.ismethod(datapipe.apply_sharding) + ) + + +def apply_sharding( + datapipe: DataPipe, + num_of_instances: int, + instance_id: int, + sharding_group=SHARDING_PRIORITIES.DEFAULT, +) -> DataPipe: + r""" + Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``. + + RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch. + """ + graph = traverse_dps(datapipe) + + def _helper(graph, prev_applied=None): + for dp, sub_graph in graph.values(): + applied = None + if _is_sharding_datapipe(dp): + if prev_applied is not None: + raise RuntimeError( + "Sharding twice on a single pipeline is likely unintended and will cause data loss. " + f"Sharding already applied to {prev_applied} while trying to apply to {dp}" + ) + # For BC, only provide sharding_group if accepted + sig = inspect.signature(dp.apply_sharding) + if len(sig.parameters) < 3: + dp.apply_sharding(num_of_instances, instance_id) + else: + dp.apply_sharding( + num_of_instances, instance_id, sharding_group=sharding_group + ) + applied = dp + if applied is None: + applied = prev_applied + _helper(sub_graph, applied) + + _helper(graph) + + return datapipe + + +def _is_shuffle_datapipe(datapipe: DataPipe) -> bool: + return ( + hasattr(datapipe, "set_shuffle") + and hasattr(datapipe, "set_seed") + and inspect.ismethod(datapipe.set_shuffle) + and inspect.ismethod(datapipe.set_seed) + ) + + +def apply_shuffle_settings( + datapipe: DataPipe, shuffle: Optional[bool] = None +) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find and set shuffle attribute. + + Apply the method to each `DataPipe` that has APIs of ``set_shuffle`` + and ``set_seed``. + + Args: + datapipe: DataPipe that needs to set shuffle attribute + shuffle: Shuffle option (default: ``None`` and no-op to the graph) + """ + if shuffle is None: + return datapipe + + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)] + if not shufflers and shuffle: + warnings.warn( + "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. " + "Be aware that the default buffer size might not be sufficient for your task." + ) + datapipe = datapipe.shuffle() + shufflers = [ + datapipe, + ] + + for shuffler in shufflers: + shuffler.set_shuffle(shuffle) + + return datapipe + + +@deprecated( + "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. " + "Please use `apply_random_seed` instead.", + category=FutureWarning, +) +def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: + return apply_random_seed(datapipe, rng) + + +def _is_random_datapipe(datapipe: DataPipe) -> bool: + return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed) + + +def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``. + + Then set the random seed based on the provided RNG to those ``DataPipe``. 
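A sketch of ``apply_sharding`` on a pipeline that contains a ``sharding_filter`` (assuming the stock ``IterableWrapper`` datapipe; with two instances, instance 0 keeps every other element):

>>> from torch.utils.data.datapipes.iter import IterableWrapper
>>> from torch.utils.data.graph_settings import apply_sharding
>>> dp = IterableWrapper(range(8)).sharding_filter()
>>> dp = apply_sharding(dp, num_of_instances=2, instance_id=0)
>>> list(dp)
[0, 2, 4, 6]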
+ + Args: + datapipe: DataPipe that needs to set randomness + rng: Random number generator to generate random seeds + """ + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once. + # And, `id` is used in case of unhashable DataPipe + cache = set() + random_datapipes = [] + for pipe in all_pipes: + if id(pipe) in cache: + continue + if _is_random_datapipe(pipe): + random_datapipes.append(pipe) + cache.add(id(pipe)) + + for pipe in random_datapipes: + random_seed = int( + torch.empty((), dtype=torch.int64).random_(generator=rng).item() + ) + pipe.set_seed(random_seed) + + return datapipe diff --git a/phivenv/Lib/site-packages/torch/utils/data/sampler.py b/phivenv/Lib/site-packages/torch/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a6ceecc638baef386fa7c4d13c7bef809cb937 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/data/sampler.py @@ -0,0 +1,348 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable, Iterator, Sequence, Sized +from typing import Generic, Optional, TypeVar, Union + +import torch + + +__all__ = [ + "BatchSampler", + "RandomSampler", + "Sampler", + "SequentialSampler", + "SubsetRandomSampler", + "WeightedRandomSampler", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class Sampler(Generic[_T_co]): + r"""Base class for all Samplers. + + Every Sampler subclass has to provide an :meth:`__iter__` method, providing a + way to iterate over indices or lists of indices (batches) of dataset elements, + and may provide a :meth:`__len__` method that returns the length of the returned iterators. + + Args: + data_source (Dataset): This argument is not used and will be removed in 2.2.0. + You may still have custom implementation that utilizes it. + + Example: + >>> # xdoctest: +SKIP + >>> class AccedingSequenceLengthSampler(Sampler[int]): + >>> def __init__(self, data: List[str]) -> None: + >>> self.data = data + >>> + >>> def __len__(self) -> int: + >>> return len(self.data) + >>> + >>> def __iter__(self) -> Iterator[int]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> yield from torch.argsort(sizes).tolist() + >>> + >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): + >>> def __init__(self, data: List[str], batch_size: int) -> None: + >>> self.data = data + >>> self.batch_size = batch_size + >>> + >>> def __len__(self) -> int: + >>> return (len(self.data) + self.batch_size - 1) // self.batch_size + >>> + >>> def __iter__(self) -> Iterator[List[int]]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): + >>> yield batch.tolist() + + .. note:: The :meth:`__len__` method isn't strictly required by + :class:`~torch.utils.data.DataLoader`, but is expected in any + calculation involving the length of a :class:`~torch.utils.data.DataLoader`. + """ + + def __init__(self, data_source: Optional[Sized] = None) -> None: + if data_source is not None: + import warnings + + warnings.warn( + "`data_source` argument is not used and will be removed in 2.2.0." + "You may still have custom implementation that utilizes it." 
+ ) + + def __iter__(self) -> Iterator[_T_co]: + raise NotImplementedError + + # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # + # Many times we have an abstract class representing a collection/iterable of + # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally + # implementing a `__len__` method. In such cases, we must make sure to not + # provide a default implementation, because both straightforward default + # implementations have their issues: + # + # + `return NotImplemented`: + # Calling `len(subclass_instance)` raises: + # TypeError: 'NotImplementedType' object cannot be interpreted as an integer + # + # + `raise NotImplementedError`: + # This prevents triggering some fallback behavior. E.g., the built-in + # `list(X)` tries to call `len(X)` first, and executes a different code + # path if the method is not found or `NotImplemented` is returned, while + # raising a `NotImplementedError` will propagate and make the call fail + # where it could have used `__iter__` to complete the call. + # + # Thus, the only two sensible things to do are + # + # + **not** provide a default `__len__`. + # + # + raise a `TypeError` instead, which is what Python uses when users call + # a method that is not defined on an object. + # (@ssnl verifies that this works on at least Python 3.7.) + + +class SequentialSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. + + Args: + data_source (Dataset): dataset to sample from + """ + + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + + def __iter__(self) -> Iterator[int]: + return iter(range(len(self.data_source))) + + def __len__(self) -> int: + return len(self.data_source) + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. 
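A minimal usage sketch for ``RandomSampler`` (drawing 5 indices with replacement from a 10-element source; the exact order depends on the generator seed):

>>> import torch
>>> from torch.utils.data import RandomSampler
>>> gen = torch.Generator().manual_seed(0)
>>> sampler = RandomSampler(range(10), replacement=True, num_samples=5, generator=gen)
>>> len(list(sampler))
5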
+ """ + + data_source: Sized + replacement: bool + + def __init__( + self, + data_source: Sized, + replacement: bool = False, + num_samples: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError( + f"replacement should be a boolean value, but got replacement={self.replacement}" + ) + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" + ) + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint( + high=n, size=(32,), dtype=torch.int64, generator=generator + ).tolist() + yield from torch.randint( + high=n, + size=(self.num_samples % 32,), + dtype=torch.int64, + generator=generator, + ).tolist() + else: + for _ in range(self.num_samples // n): + yield from torch.randperm(n, generator=generator).tolist() + yield from torch.randperm(n, generator=generator).tolist()[ + : self.num_samples % n + ] + + def __len__(self) -> int: + return self.num_samples + + +class SubsetRandomSampler(Sampler[int]): + r"""Samples elements randomly from a given list of indices, without replacement. + + Args: + indices (sequence): a sequence of indices + generator (Generator): Generator used in sampling. + """ + + indices: Sequence[int] + + def __init__(self, indices: Sequence[int], generator=None) -> None: + self.indices = indices + self.generator = generator + + def __iter__(self) -> Iterator[int]: + for i in torch.randperm(len(self.indices), generator=self.generator).tolist(): + yield self.indices[i] + + def __len__(self) -> int: + return len(self.indices) + + +class WeightedRandomSampler(Sampler[int]): + r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). + + Args: + weights (sequence) : a sequence of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + generator (Generator): Generator used in sampling. 
+ + Example: + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) + [4, 4, 1, 4, 5] + >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) + [0, 1, 4, 3, 2] + """ + + weights: torch.Tensor + num_samples: int + replacement: bool + + def __init__( + self, + weights: Sequence[float], + num_samples: int, + replacement: bool = True, + generator=None, + ) -> None: + if ( + not isinstance(num_samples, int) + or isinstance(num_samples, bool) + or num_samples <= 0 + ): + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={num_samples}" + ) + if not isinstance(replacement, bool): + raise ValueError( + f"replacement should be a boolean value, but got replacement={replacement}" + ) + + weights_tensor = torch.as_tensor(weights, dtype=torch.double) + if len(weights_tensor.shape) != 1: + raise ValueError( + "weights should be a 1d sequence but given " + f"weights have shape {tuple(weights_tensor.shape)}" + ) + + self.weights = weights_tensor + self.num_samples = num_samples + self.replacement = replacement + self.generator = generator + + def __iter__(self) -> Iterator[int]: + rand_tensor = torch.multinomial( + self.weights, self.num_samples, self.replacement, generator=self.generator + ) + yield from iter(rand_tensor.tolist()) + + def __len__(self) -> int: + return self.num_samples + + +class BatchSampler(Sampler[list[int]]): + r"""Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + ) -> None: + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + f"batch_size should be a positive integer value, but got batch_size={batch_size}" + ) + if not isinstance(drop_last, bool): + raise ValueError( + f"drop_last should be a boolean value, but got drop_last={drop_last}" + ) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self) -> Iterator[list[int]]: + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + sampler_iter = iter(self.sampler) + if self.drop_last: + # Create multiple references to the same iterator + args = [sampler_iter] * self.batch_size + for batch_droplast in zip(*args): + yield [*batch_droplast] + else: + batch = [*itertools.islice(sampler_iter, self.batch_size)] + while batch: + yield batch + batch = [*itertools.islice(sampler_iter, self.batch_size)] + + def __len__(self) -> int: + # Can only be called if self.sampler has __len__ implemented + # We cannot enforce this condition, so we turn off typechecking for the + # implementation below. 
+ # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore[arg-type] + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/utils/deterministic.py b/phivenv/Lib/site-packages/torch/utils/deterministic.py new file mode 100644 index 0000000000000000000000000000000000000000..7782b6dd2e1f4d6193fd2154912f6ef80f756118 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/deterministic.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import sys +import types + +import torch + + +class _Deterministic(types.ModuleType): + @property + def fill_uninitialized_memory(self): + """ + Whether to fill uninitialized memory with a known value when + :meth:`torch.use_deterministic_algorithms()` is set to ``True``. + """ + return torch._C._get_deterministic_fill_uninitialized_memory() + + @fill_uninitialized_memory.setter + def fill_uninitialized_memory(self, mode): + return torch._C._set_deterministic_fill_uninitialized_memory(mode) + + +sys.modules[__name__].__class__ = _Deterministic diff --git a/phivenv/Lib/site-packages/torch/utils/dlpack.py b/phivenv/Lib/site-packages/torch/utils/dlpack.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ab6fe208f63baa75eaf5725b76dda614f48d84 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/dlpack.py @@ -0,0 +1,127 @@ +from typing import Any + +import torch +import enum + +from torch._C import _from_dlpack +from torch._C import _to_dlpack as to_dlpack + +__all__ = [ + "DLDeviceType", + "from_dlpack", + "to_dlpack", +] + + +class DLDeviceType(enum.IntEnum): + # Enums as in DLPack specification (aten/src/ATen/dlpack.h) + kDLCPU = 1, + kDLGPU = 2, + kDLCPUPinned = 3, + kDLOpenCL = 4, + kDLVulkan = 7, + kDLMetal = 8, + kDLVPI = 9, + kDLROCM = 10, + kDLExtDev = 12, + kDLOneAPI = 14, + + +torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule + +Returns an opaque object (a "DLPack capsule") representing the tensor. + +.. note:: + ``to_dlpack`` is a legacy DLPack interface. The capsule it returns + cannot be used for anything in Python other than use it as input to + ``from_dlpack``. The more idiomatic use of DLPack is to call + ``from_dlpack`` directly on the tensor object - this works when that + object has a ``__dlpack__`` method, which PyTorch and most other + libraries indeed have now. + +.. warning:: + Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``. + Behavior when a capsule is consumed multiple times is undefined. + +Args: + tensor: a tensor to be exported + +The DLPack capsule shares the tensor's memory. +""") + + +# TODO: add a typing.Protocol to be able to tell Mypy that only objects with +# __dlpack__ and __dlpack_device__ methods are accepted. +def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': + """from_dlpack(ext_tensor) -> Tensor + + Converts a tensor from an external library into a ``torch.Tensor``. + + The returned PyTorch tensor will share the memory with the input tensor + (which may have come from another library). Note that in-place operations + will therefore also affect the data of the input tensor. This may lead to + unexpected issues (e.g., other libraries may have read-only flags or + immutable data structures), so the user should only do this if they know + for sure that this is fine. 
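Returning to the ``torch.utils.deterministic`` module defined above: it is installed as a module whose ``fill_uninitialized_memory`` attribute is a property, so the flag can be toggled and read directly. A minimal sketch:

>>> import torch.utils.deterministic
>>> torch.utils.deterministic.fill_uninitialized_memory = False
>>> torch.utils.deterministic.fill_uninitialized_memory
False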
+ + Args: + ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule): + The tensor or DLPack capsule to convert. + + If ``ext_tensor`` is a tensor (or ndarray) object, it must support + the ``__dlpack__`` protocol (i.e., have a ``ext_tensor.__dlpack__`` + method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is + an opaque ``PyCapsule`` instance, typically produced by a + ``to_dlpack`` function or method. + + Examples:: + + >>> import torch.utils.dlpack + >>> t = torch.arange(4) + + # Convert a tensor directly (supported in PyTorch >= 1.10) + >>> t2 = torch.from_dlpack(t) + >>> t2[:2] = -1 # show that memory is shared + >>> t2 + tensor([-1, -1, 2, 3]) + >>> t + tensor([-1, -1, 2, 3]) + + # The old-style DLPack usage, with an intermediate capsule object + >>> capsule = torch.utils.dlpack.to_dlpack(t) + >>> capsule + + >>> t3 = torch.from_dlpack(capsule) + >>> t3 + tensor([-1, -1, 2, 3]) + >>> t3[0] = -9 # now we're sharing memory between 3 tensors + >>> t3 + tensor([-9, -1, 2, 3]) + >>> t2 + tensor([-9, -1, 2, 3]) + >>> t + tensor([-9, -1, 2, 3]) + + """ + if hasattr(ext_tensor, '__dlpack__'): + device = ext_tensor.__dlpack_device__() + # device is either CUDA or ROCm, we need to pass the current + # stream + if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + stream = torch.cuda.current_stream(f'cuda:{device[1]}') + # cuda_stream is the pointer to the stream and it is a public + # attribute, but it is not documented + # The array API specify that the default legacy stream must be passed + # with a value of 1 for CUDA + # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none + is_cuda = device[0] == DLDeviceType.kDLGPU + # Since pytorch is not using PTDS by default, lets directly pass + # the legacy stream + stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream + dlpack = ext_tensor.__dlpack__(stream=stream_ptr) + else: + dlpack = ext_tensor.__dlpack__() + else: + # Old versions just call the converter + dlpack = ext_tensor + return _from_dlpack(dlpack) diff --git a/phivenv/Lib/site-packages/torch/utils/file_baton.py b/phivenv/Lib/site-packages/torch/utils/file_baton.py new file mode 100644 index 0000000000000000000000000000000000000000..cccc1482e1a46f908d4edf7048367d2c79853d2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/file_baton.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import os +import time +import warnings + + +class FileBaton: + """A primitive, file-based synchronization utility.""" + + def __init__(self, lock_file_path, wait_seconds=0.1, warn_after_seconds=None): + """ + Create a new :class:`FileBaton`. + + Args: + lock_file_path: The path to the file used for locking. + wait_seconds: The seconds to periodically sleep (spin) when + calling ``wait()``. + warn_after_seconds: The seconds to wait before showing + lock file path to warn existing lock file. + """ + self.lock_file_path = lock_file_path + self.wait_seconds = wait_seconds + self.fd = None + self.warn_after_seconds = warn_after_seconds + + def try_acquire(self): + """ + Try to atomically create a file under exclusive access. + + Returns: + True if the file could be created, else False. + """ + try: + self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL) + return True + except FileExistsError: + return False + + def wait(self): + """ + Periodically sleeps for a certain amount until the baton is released. 
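A sketch of the intended ``FileBaton`` protocol (the lock path is illustrative): one process creates the lock file atomically and does the work, while any other process spins in ``wait()`` until that file disappears.

>>> baton = FileBaton('/tmp/example.lock')
>>> if baton.try_acquire():
...     try:
...         pass  # do the work that must happen exactly once
...     finally:
...         baton.release()
... else:
...     baton.wait()  # another process holds the baton; sleep until its lock file is removed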
+ + The amount of time slept depends on the ``wait_seconds`` parameter + passed to the constructor. + """ + has_warned = False + + start_time = time.time() + while os.path.exists(self.lock_file_path): + time.sleep(self.wait_seconds) + + if self.warn_after_seconds is not None: + if time.time() - start_time > self.warn_after_seconds and not has_warned: + warnings.warn(f'Waited on lock file "{self.lock_file_path}" for ' + f'{self.warn_after_seconds} seconds.') + has_warned = True + + def release(self): + """Release the baton and removes its file.""" + if self.fd is not None: + os.close(self.fd) + + os.remove(self.lock_file_path) diff --git a/phivenv/Lib/site-packages/torch/utils/flop_counter.py b/phivenv/Lib/site-packages/torch/utils/flop_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..2952228cd8ad4f6fdca58142d85771c6b9d92afa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/flop_counter.py @@ -0,0 +1,792 @@ +# mypy: allow-untyped-defs +import torch +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from .module_tracker import ModuleTracker +from typing import Any, Optional, Union, TypeVar, Callable +from collections.abc import Iterator +from typing_extensions import ParamSpec +from collections import defaultdict +from torch.utils._python_dispatch import TorchDispatchMode +from math import prod +from functools import wraps +import warnings + +__all__ = ["FlopCounterMode", "register_flop_formula"] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +aten = torch.ops.aten + +def get_shape(i): + if isinstance(i, torch.Tensor): + return i.shape + return i + +flop_registry: dict[Any, Any] = {} + +def shape_wrapper(f): + @wraps(f) + def nf(*args, out_val=None, **kwargs): + args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val)) + return f(*args, out_shape=out_shape, **kwargs) + return nf + +def register_flop_formula(targets, get_raw=False) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]: + if not get_raw: + flop_formula = shape_wrapper(flop_formula) + + def register(target): + if not isinstance(target, torch._ops.OpOverloadPacket): + raise ValueError( + f"register_flop_formula(targets): expected each target to be " + f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got " + f"{target} which is of type {type(target)}") + if target in flop_registry: + raise RuntimeError(f"duplicate registrations for {target}") + flop_registry[target] = flop_formula + + # To handle allowing multiple aten_ops at once + torch.utils._pytree.tree_map_(register, targets) + + return flop_formula + + return register_fun + +@register_flop_formula(aten.mm) +def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int: + """Count flops for matmul.""" + # Inputs should be a list of length 2. + # Inputs contains the shapes of two matrices. + m, k = a_shape + k2, n = b_shape + assert k == k2 + # NB(chilli): Should be 2 * k - 1 technically for FLOPs. + return m * n * 2 * k + +@register_flop_formula(aten.addmm) +def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: + """Count flops for addmm.""" + return mm_flop(a_shape, b_shape) + +@register_flop_formula(aten.bmm) +def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int: + """Count flops for the bmm operation.""" + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor. 
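As a quick sanity check of the matmul formula above (``2 * m * k * n`` for an ``(m, k) @ (k, n)`` product):

>>> mm_flop((4, 8), (8, 16))
1024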
+ b, m, k = a_shape + b2, k2, n = b_shape + assert b == b2 + assert k == k2 + # NB(chilli): Should be 2 * k - 1 technically for FLOPs. + flop = b * m * n * 2 * k + return flop + +@register_flop_formula(aten.baddbmm) +def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int: + """Count flops for the baddbmm operation.""" + # Inputs should be a list of length 3. + # Inputs contains the shapes of three tensors. + return bmm_flop(a_shape, b_shape) + +@register_flop_formula(aten._scaled_mm) +def _scaled_mm_flop( + a_shape, + b_shape, + scale_a_shape, + scale_b_shape, + bias_shape=None, + scale_result_shape=None, + out_dtype=None, + use_fast_accum=False, + out_shape=None, + **kwargs, +) -> int: + """Count flops for _scaled_mm.""" + return mm_flop(a_shape, b_shape) + + +def conv_flop_count( + x_shape: list[int], + w_shape: list[int], + out_shape: list[int], + transposed: bool = False, +) -> int: + """Count flops for convolution. + + Note only multiplication is + counted. Computation for bias are ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + c_out, c_in, *filter_size = w_shape + + """ + General idea here is that for a regular conv, for each point in the output + spatial dimension we convolve the filter with something (hence + `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by + 1. batch_size, 2. the cross product of input and weight channels. + + For the transpose, it's not each point in the *output* spatial dimension but + each point in the *input* spatial dimension. + """ + # NB(chilli): I don't think this properly accounts for padding :think: + # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs. + flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2 + return flop + +@register_flop_formula([aten.convolution, aten._convolution]) +def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: + """Count flops for convolution.""" + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +@register_flop_formula(aten.convolution_backward) +def conv_backward_flop( + grad_out_shape, + x_shape, + w_shape, + _bias, + _stride, + _padding, + _dilation, + transposed, + _output_padding, + _groups, + output_mask, + out_shape) -> int: + + def t(shape): + return [shape[1], shape[0]] + list(shape[2:]) + flop_count = 0 + + """ + Let's say we have a regular 1D conv + {A, B, C} [inp] + {i, j} [weight] + => (conv) + {Ai + Bj, Bi + Cj} [out] + + And as a reminder, the transposed conv of the above is + => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out] + + For the backwards of conv, we now have + {D, E} [grad_out] + {A, B, C} [inp] + {i, j} [weight] + + # grad_inp as conv_transpose(grad_out, weight) + Let's first compute grad_inp. To do so, we can simply look at all the + multiplications that each element of inp is involved in. 
For example, A is + only involved in the first element of the output (and thus only depends upon + D in grad_out), and C is only involved in the last element of the output + (and thus only depends upon E in grad_out) + + {Di, Dj + Ei, Ej} [grad_inp] + + Note that this corresponds to the below conv_transpose. This gives us the + output_mask[0] branch, which is grad_inp. + + {D, E} [inp (grad_out)] + {i, j} [weight] + => (conv_transpose) + {Di, Dj + Ei, Ej} [out (grad_inp)] + + I leave the fact that grad_inp for a transposed conv is just conv(grad_out, + weight) as an exercise for the reader. + + # grad_weight as conv(inp, grad_out) + To compute grad_weight, we again look at the terms in the output, which as + a reminder is: + => {Ai + Bj, Bi + Cj} [out] + => {D, E} [grad_out] + If we manually compute the gradient for the weights, we see it's + {AD + BE, BD + CE} [grad_weight] + + This corresponds to the below conv + {A, B, C} [inp] + {D, E} [weight (grad_out)] + => (conv) + {AD + BE, BD + CE} [out (grad_weight)] + + # grad_weight of transposed conv as conv(grad_out, inp) + As a reminder, the terms of the output of a transposed conv are: + => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out] + => {D, E, F, G} [grad_out] + + Manually computing the gradient for the weights, we see it's + {AD + BE + CF, AE + BF + CG} [grad_weight] + + This corresponds to the below conv + {D, E, F, G} [inp (grad_out)] + {A, B, C} [weight (inp)] + => (conv) + {AD + BE + CF, AE + BF + CG} [out (grad_weight)] + + For the full backwards formula, there are also some details involving + transpose of the batch/channel dimensions and groups, but I skip those for + the sake of brevity (and they're pretty similar to matmul backwards) + + Check [conv backwards decomposition as conv forwards] + """ + # grad_inp as conv_transpose(grad_out, weight) + if output_mask[0]: + grad_input_shape = get_shape(out_shape[0]) + flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed) + + if output_mask[1]: + grad_weight_shape = get_shape(out_shape[1]) + if transposed: + # grad_weight of transposed conv as conv(grad_out, inp) + flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False) + else: + # grad_weight as conv(inp, grad_out) + flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False) + + return flop_count + +def sdpa_flop_count(query_shape, key_shape, value_shape): + """ + Count flops for self-attention. 
+ + NB: We can assume that value_shape == key_shape + """ + b, h, s_q, d_q = query_shape + _b2, _h2, s_k, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2 + total_flops = 0 + # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k] + total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k)) + # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v] + total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v)) + return total_flops + + +@register_flop_formula([aten._scaled_dot_product_efficient_attention, + aten._scaled_dot_product_flash_attention, + aten._scaled_dot_product_cudnn_attention]) +def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: + """Count flops for self-attention.""" + # NB: We aren't accounting for causal attention here + return sdpa_flop_count(query_shape, key_shape, value_shape) + + +def _offsets_to_lengths(offsets, max_len): + """ + If the offsets tensor is fake, then we don't know the actual lengths. + In that case, we can just assume the worst case; each batch has max length. + """ + from torch._subclasses.fake_tensor import FakeTensor + from torch._subclasses.functional_tensor import FunctionalTensor + if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta": + return offsets.diff().tolist() + return [max_len] * (offsets.size(0) - 1) + + +def _unpack_flash_attention_nested_shapes( + *, + query, + key, + value, + grad_out=None, + cum_seq_q, + cum_seq_k, + max_q, + max_k, +) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]: + """ + Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for + NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for + each batch element. + + In the case that this isn't a NestedTensor kernel, then it just yields the original shapes. + """ + if cum_seq_q is not None: + # This means we should be dealing with a Nested Jagged Tensor query. + # The inputs will have shape (sum(sequence len), heads, dimension) + # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension) + # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension) + # So the flops calculation in this case is an overestimate of the actual flops. 
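A worked instance of the two batched matmuls counted by ``sdpa_flop_count`` above (``b=1, h=2, s_q=s_k=4, d_q=d_v=8``, giving 512 flops for the score matmul and 512 for the value matmul):

>>> sdpa_flop_count((1, 2, 4, 8), (1, 2, 4, 8), (1, 2, 4, 8))
1024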
+ assert len(key.shape) == 3 + assert len(value.shape) == 3 + assert grad_out is None or grad_out.shape == query.shape + _, h_q, d_q = query.shape + _, h_k, d_k = key.shape + _, h_v, d_v = value.shape + assert cum_seq_q is not None + assert cum_seq_k is not None + assert cum_seq_q.shape == cum_seq_k.shape + seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q) + seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k) + for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths): + new_query_shape = (1, h_q, seq_q_len, d_q) + new_key_shape = (1, h_k, seq_k_len, d_k) + new_value_shape = (1, h_v, seq_k_len, d_v) + new_grad_out_shape = new_query_shape if grad_out is not None else None + yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape + return + + yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None + + +def _unpack_efficient_attention_nested_shapes( + *, + query, + key, + value, + grad_out=None, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, +) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]: + """ + Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for + NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for + each batch element. + + In the case that this isn't a NestedTensor kernel, then it just yields the original shapes. + """ + if cu_seqlens_q is not None: + # Unlike flash_attention_forward, we get a 4D tensor instead of a 3D tensor for efficient attention. + # + # This means we should be dealing with a Nested Jagged Tensor query. + # The inputs will have shape (sum(sequence len), heads, dimension) + # In comparison, non-Nested inputs have shape (batch, heads, sequence len, dimension) + # To deal with this, we convert to a shape of (batch, heads, max_seq_len, dimension) + # So the flops calculation in this case is an overestimate of the actual flops. 
+ assert len(key.shape) == 4 + assert len(value.shape) == 4 + assert grad_out is None or grad_out.shape == query.shape + _, _, h_q, d_q = query.shape + _, _, h_k, d_k = key.shape + _, _, h_v, d_v = value.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert cu_seqlens_q.shape == cu_seqlens_k.shape + seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q) + seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k) + for len_q, len_k in zip(seqlens_q, seqlens_k): + new_query_shape = (1, h_q, len_q, d_q) + new_key_shape = (1, h_k, len_k, d_k) + new_value_shape = (1, h_v, len_k, d_v) + new_grad_out_shape = new_query_shape if grad_out is not None else None + yield new_query_shape, new_key_shape, new_value_shape, new_grad_out_shape + return + + yield query.shape, key.shape, value.shape, grad_out.shape if grad_out is not None else None + + +@register_flop_formula(aten._flash_attention_forward, get_raw=True) +def _flash_attention_forward_flop( + query, + key, + value, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + *args, + out_shape=None, + **kwargs +) -> int: + """Count flops for self-attention.""" + # NB: We aren't accounting for causal attention here + # in case this is a nested tensor, we unpack the individual batch elements + # and then sum the flops per batch element + sizes = _unpack_flash_attention_nested_shapes( + query=query, + key=key, + value=value, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + ) + return sum( + sdpa_flop_count(query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, _ in sizes + ) + + +@register_flop_formula(aten._efficient_attention_forward, get_raw=True) +def _efficient_attention_forward_flop( + query, + key, + value, + bias, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + *args, + **kwargs +) -> int: + """Count flops for self-attention.""" + # NB: We aren't accounting for causal attention here + # in case this is a nested tensor, we unpack the individual batch elements + # and then sum the flops per batch element + sizes = _unpack_efficient_attention_nested_shapes( + query=query, + key=key, + value=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + ) + return sum( + sdpa_flop_count(query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, _ in sizes + ) + + +def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape): + total_flops = 0 + b, h, s_q, d_q = query_shape + _b2, _h2, s_k, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2 + assert d_v == _d4 and s_k == _s3 and s_q == _s4 + total_flops = 0 + # Step 1: We recompute the scores matrix. + # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k] + total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k)) + + # Step 2: We propagate the gradients through the score @ v operation. 
+ # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k] + total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k)) + # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v] + total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v)) + + # Step 3: We propagate th gradients through the k @ v operation + # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q] + total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q)) + # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k] + total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k)) + return total_flops + + +@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, + aten._scaled_dot_product_flash_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward]) +def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int: + """Count flops for self-attention backward.""" + return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) + +@register_flop_formula(aten._flash_attention_backward, get_raw=True) +def _flash_attention_backward_flop( + grad_out, + query, + key, + value, + out, # named _out_shape to avoid kwarg collision with out_shape created in wrapper + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + *args, + **kwargs, +) -> int: + # in case this is a nested tensor, we unpack the individual batch elements + # and then sum the flops per batch element + shapes = _unpack_flash_attention_nested_shapes( + query=query, + key=key, + value=value, + grad_out=grad_out, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + ) + return sum( + sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, grad_out_shape in shapes + ) + + +@register_flop_formula(aten._efficient_attention_backward, get_raw=True) +def _efficient_attention_backward_flop( + grad_out, + query, + key, + value, + bias, + out, # named _out to avoid kwarg collision with out created in wrapper + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + *args, + **kwargs, +) -> int: + # in case this is a nested tensor, we unpack the individual batch elements + # and then sum the flops per batch element + shapes = _unpack_efficient_attention_nested_shapes( + query=query, + key=key, + value=value, + grad_out=grad_out, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + ) + return sum( + sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, grad_out_shape in shapes + ) + + +flop_registry = { + aten.mm: mm_flop, + aten.addmm: addmm_flop, + aten.bmm: bmm_flop, + aten.baddbmm: baddbmm_flop, + aten._scaled_mm: _scaled_mm_flop, + aten.convolution: conv_flop, + aten._convolution: conv_flop, + aten.convolution_backward: conv_backward_flop, + aten._scaled_dot_product_efficient_attention: sdpa_flop, + aten._scaled_dot_product_flash_attention: sdpa_flop, + aten._scaled_dot_product_cudnn_attention: sdpa_flop, + aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop, + aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop, + aten._scaled_dot_product_cudnn_attention_backward: sdpa_backward_flop, + aten._flash_attention_forward: _flash_attention_forward_flop, + aten._efficient_attention_forward: 
_efficient_attention_forward_flop, + aten._flash_attention_backward: _flash_attention_backward_flop, + aten._efficient_attention_backward: _efficient_attention_backward_flop, +} + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +# Define the suffixes for different orders of magnitude +suffixes = ["", "K", "M", "B", "T"] +# Thanks BingChat! +def get_suffix_str(number): + # Find the index of the appropriate suffix based on the number of digits + # with some additional overflow. + # i.e. 1.01B should be displayed as 1001M, not 1.001B + index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3)) + return suffixes[index] + +def convert_num_with_suffix(number, suffix): + index = suffixes.index(suffix) + # Divide the number by 1000^index and format it to two decimal places + value = f"{number / 1000 ** index:.3f}" + # Return the value and the suffix as a string + return value + suffixes[index] + +def convert_to_percent_str(num, denom): + if denom == 0: + return "0%" + return f"{num / denom:.2%}" + +def _pytreeify_preserve_structure(f): + @wraps(f) + def nf(args): + flat_args, spec = tree_flatten(args) + out = f(*flat_args) + return tree_unflatten(out, spec) + + return nf + + +class FlopCounterMode: + """ + ``FlopCounterMode`` is a context manager that counts the number of flops within its context. + + It does this using a ``TorchDispatchMode``. + + It also supports hierarchical output by passing a module (or list of + modules) to FlopCounterMode on construction. If you do not need hierarchical + output, you do not need to use it with a module. + + Example usage + + .. code-block:: python + + mod = ... + with FlopCounterMode(mod) as flop_counter: + mod.sum().backward() + + """ + + def __init__( + self, + mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None, + depth: int = 2, + display: bool = True, + custom_mapping: Optional[dict[Any, Any]] = None): + super().__init__() + self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int)) + self.depth = depth + self.display = display + self.mode: Optional[_FlopCounterMode] = None + if custom_mapping is None: + custom_mapping = {} + if mods is not None: + warnings.warn("mods argument is not needed anymore, you can stop passing it", stacklevel=2) + self.flop_registry = { + **flop_registry, + **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()} + } + self.mod_tracker = ModuleTracker() + + def get_total_flops(self) -> int: + return sum(self.flop_counts['Global'].values()) + + def get_flop_counts(self) -> dict[str, dict[Any, int]]: + """Return the flop counts as a dictionary of dictionaries. + + The outer + dictionary is keyed by module name, and the inner dictionary is keyed by + operation name. + + Returns: + Dict[str, Dict[Any, int]]: The flop counts as a dictionary. 
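+
+        Example (illustrative sketch; the inner keys are ATen operator packets, so the
+        exact repr may differ)::
+
+            >>> flop_counter = FlopCounterMode(display=False)
+            >>> with flop_counter:
+            ...     _ = torch.mm(torch.randn(8, 8), torch.randn(8, 8))
+            >>> flop_counter.get_flop_counts()["Global"]  # doctest: +SKIP
+            {aten.mm: 1024}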
+ """ + return {k: dict(v) for k, v in self.flop_counts.items()} + + def get_table(self, depth=None): + if depth is None: + depth = self.depth + if depth is None: + depth = 999999 + + import tabulate + tabulate.PRESERVE_WHITESPACE = True + header = ["Module", "FLOP", "% Total"] + values = [] + global_flops = self.get_total_flops() + global_suffix = get_suffix_str(global_flops) + is_global_subsumed = False + + def process_mod(mod_name, depth): + nonlocal is_global_subsumed + + total_flops = sum(self.flop_counts[mod_name].values()) + + is_global_subsumed |= total_flops >= global_flops + + padding = " " * depth + values = [] + values.append([ + padding + mod_name, + convert_num_with_suffix(total_flops, global_suffix), + convert_to_percent_str(total_flops, global_flops) + ]) + for k, v in self.flop_counts[mod_name].items(): + values.append([ + padding + " - " + str(k), + convert_num_with_suffix(v, global_suffix), + convert_to_percent_str(v, global_flops) + ]) + return values + + for mod in sorted(self.flop_counts.keys()): + if mod == 'Global': + continue + mod_depth = mod.count(".") + 1 + if mod_depth > depth: + continue + + cur_values = process_mod(mod, mod_depth - 1) + values.extend(cur_values) + + # We do a bit of messing around here to only output the "Global" value + # if there are any FLOPs in there that aren't already fully contained by + # a module. + if 'Global' in self.flop_counts and not is_global_subsumed: + for value in values: + value[0] = " " + value[0] + + values = process_mod('Global', 0) + values + + if len(values) == 0: + values = [["Global", "0", "0%"]] + + return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right")) + + # NB: This context manager is NOT reentrant + def __enter__(self): + self.flop_counts.clear() + self.mod_tracker.__enter__() + self.mode = _FlopCounterMode(self) + self.mode.__enter__() + return self + + def __exit__(self, *args): + assert self.mode is not None + b = self.mode.__exit__(*args) + self.mode = None # break cycles + self.mod_tracker.__exit__() + if self.display: + print(self.get_table(self.depth)) + return b + + def _count_flops(self, func_packet, out, args, kwargs): + if func_packet in self.flop_registry: + flop_count_func = self.flop_registry[func_packet] + flop_count = flop_count_func(*args, **kwargs, out_val=out) # type: ignore[operator] + for par in set(self.mod_tracker.parents): + self.flop_counts[par][func_packet] += flop_count + + return out + + +class _FlopCounterMode(TorchDispatchMode): + def __init__(self, counter: FlopCounterMode): + self.counter = counter + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + + # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT + if func in {torch.ops.aten.is_contiguous.default, + torch.ops.aten.is_contiguous.memory_format, + torch.ops.aten.is_strides_like_format.default, + torch.ops.aten.is_non_overlapping_and_dense.default, + torch.ops.aten.size.default, + torch.ops.aten.sym_size.default, + torch.ops.aten.stride.default, + torch.ops.aten.sym_stride.default, + torch.ops.aten.storage_offset.default, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.numel.default, + torch.ops.aten.sym_numel.default, + torch.ops.aten.dim.default, + torch.ops.prim.layout.default}: + + return NotImplemented + + # If we don't have func in flop_registry, see if it can decompose + if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default: + with self: + r = func.decompose(*args, 
**kwargs) + if r is not NotImplemented: + return r + + # no further decomposition; execute & count flops + out = func(*args, **kwargs) + return self.counter._count_flops(func._overloadpacket, out, args, kwargs) diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/__init__.py b/phivenv/Lib/site-packages/torch/utils/hipify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77ac4a0d10ed25a4dc7328e910d03520c7d389bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hipify/__init__.py @@ -0,0 +1 @@ +from .version import __version__ diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba6fbdc23f6242317e7998829557c2d83c3c6913 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c31765d5b8e8f66ff372a39fa688dc66a6276719 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d74c1df381d59c734359bc278d7cdbf6e871a04 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/hipify_python.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/version.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eff57c21e4c19047bc7b933f3d85b63b39458f1b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/hipify/__pycache__/version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/constants.py b/phivenv/Lib/site-packages/torch/utils/hipify/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..3da34e800ed7bfed6190276b07ea4c7f28833224 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hipify/constants.py @@ -0,0 +1,62 @@ +"""Constants for annotations in the mapping. + +The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py. +They are based on +https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h +and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported +mapping. 
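+
+As an illustrative example of how these annotations are used, an entry in
+cuda_to_hip_mappings.py has the shape
+
+    ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME))
+
+i.e. the CUDA name maps to its HIP replacement plus a conversion-type annotation and an
+API annotation, optionally followed by HIP_UNSUPPORTED when no HIP equivalent exists yet.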
+""" + +CONV_VERSION = 0, +CONV_INIT = 1 +CONV_DEVICE = 2 +CONV_MEM = 3 +CONV_KERN = 4 +CONV_COORD_FUNC = 5 +CONV_MATH_FUNC = 6 +CONV_DEVICE_FUNC = 7 +CONV_SPECIAL_FUNC = 8 +CONV_STREAM = 9 +CONV_EVENT = 10 +CONV_OCCUPANCY = 11 +CONV_CONTEXT = 12 +CONV_PEER = 13 +CONV_MODULE = 14 +CONV_CACHE = 15 +CONV_EXEC = 16 +CONV_ERROR = 17 +CONV_DEF = 18 +CONV_TEX = 19 +CONV_GL = 20 +CONV_GRAPHICS = 21 +CONV_SURFACE = 22 +CONV_JIT = 23 +CONV_D3D9 = 24 +CONV_D3D10 = 25 +CONV_D3D11 = 26 +CONV_VDPAU = 27 +CONV_EGL = 28 +CONV_THREAD = 29 +CONV_OTHER = 30 +CONV_INCLUDE = 31 +CONV_INCLUDE_CUDA_MAIN_H = 32 +CONV_TYPE = 33 +CONV_LITERAL = 34 +CONV_NUMERIC_LITERAL = 35 +CONV_LAST = 36 + +API_DRIVER = 37 +API_RUNTIME = 38 +API_BLAS = 39 +API_SPECIAL = 40 +API_RAND = 41 +API_LAST = 42 +API_FFT = 43 +API_RTC = 44 +API_ROCTX = 45 + +HIP_UNSUPPORTED = 46 +API_PYTORCH = 1337 +API_CAFFE2 = 1338 +API_C10 = 1339 +API_ROCMSMI = 1340 diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py b/phivenv/Lib/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..0686d441a045f28da62d00cfee428fe1ae1e7ec5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py @@ -0,0 +1,8821 @@ +import collections +import os + +from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, + API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, + API_SPECIAL, API_ROCMSMI, CONV_CACHE, CONV_CONTEXT, CONV_D3D9, + CONV_D3D10, CONV_D3D11, CONV_DEF, CONV_DEVICE, + CONV_DEVICE_FUNC, CONV_EGL, CONV_ERROR, CONV_EVENT, + CONV_EXEC, CONV_GL, CONV_GRAPHICS, CONV_INCLUDE, + CONV_INCLUDE_CUDA_MAIN_H, CONV_INIT, CONV_JIT, + CONV_MATH_FUNC, CONV_MEM, CONV_MODULE, + CONV_NUMERIC_LITERAL, CONV_OCCUPANCY, CONV_OTHER, + CONV_PEER, CONV_SPECIAL_FUNC, CONV_STREAM, + CONV_SURFACE, CONV_TEX, CONV_THREAD, CONV_TYPE, + CONV_VDPAU, CONV_VERSION, HIP_UNSUPPORTED) + +""" Mapping of CUDA functions, include files, constants, and types to ROCm/HIP equivalents +This closely follows the implementation in hipify-clang +https://github.com/ROCm/hip/blob/59071b895ed1c86d9698b4c859cefcdd5acda06f/hipify-clang/src/CUDA2HipMap.cpp +and its structure. +There are different maps for fundamental names, include files, identifies, sparse, and +PyTorch specific translations. +Each of the entries in these maps translates a CUDA string to a tuple containing the +ROCm/HIP string, a type and API annotation and - optionally - an annotation if it is not +supported in ROCm/HIP yet. +""" + +_IS_FBCODE = os.environ.get("IS_FBCODE", "0") == "1" + +# FBCODE compiles against rccl sources instead of an installed rccl package. +# The header location is src/rccl.h versus rccl/rccl.h, respectively. +_RCCL_HEADER = "" if _IS_FBCODE else "" + +# List of math functions that should be replaced inside device code only. 
+MATH_TRANSPILATIONS = collections.OrderedDict( + [ + ("std::max", ("::max")), + ("std::min", ("::min")), + ("std::ceil", ("::ceil")), + ("std::floor", ("::floor")), + ("std::exp", ("::exp")), + ("std::log", ("::log")), + ("std::pow", ("::pow")), + ("std::fabs", ("::fabs")), + ("std::fmod", ("::fmod")), + ("std::remainder", ("::remainder")), + ("std::frexp", ("::frexp")), + ] +) + +CUDA_TYPE_NAME_MAP = collections.OrderedDict( + [ + ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), + ("cudaError_t", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ("cudaError", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ARRAY3D_DESCRIPTOR", + ("HIP_ARRAY3D_DESCRIPTOR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_ARRAY_DESCRIPTOR", ("HIP_ARRAY_DESCRIPTOR", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY2D", ("hip_Memcpy2D", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY3D", ("HIP_MEMCPY3D", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_MEMCPY3D_PEER", + ("HIP_MEMCPY3D_PEER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_POINTER_ATTRIBUTE_P2P_TOKENS", + ( + "HIP_POINTER_ATTRIBUTE_P2P_TOKENS", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_RESOURCE_DESC", + ("HIP_RESOURCE_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_RESOURCE_VIEW_DESC", + ("HIP_RESOURCE_VIEW_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUipcEventHandle", + ("hipIpcEventHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUipcMemHandle", ("hipIpcMemHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUaddress_mode", ("hipAddress_mode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUarray_cubemap_face", + ("hipArray_cubemap_face", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUarray_format", ("hipArray_format", CONV_TYPE, API_DRIVER)), + ("CUcomputemode", ("hipComputemode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmem_advise", ("hipMemAdvise", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUmem_range_attribute", + ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUctx_flags", ("hipCctx_flags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUdevice", ("hipDevice_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute_enum", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUpointer_attribute", ("hipPointer_attribute", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", ("HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_BUFFER_ID", ("HIP_POINTER_ATTRIBUTE_BUFFER_ID", CONV_TYPE, API_DRIVER)), + ("CUdeviceptr", ("hipDeviceptr_t", CONV_TYPE, API_DRIVER)), + ("CUarray_st", ("hipArray", CONV_TYPE, API_DRIVER)), + ("CUarray", ("hipArray *", CONV_TYPE, API_DRIVER)), + ("CUdevprop_st", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUdevprop", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUfunction", ("hipFunction_t", CONV_TYPE, API_DRIVER)), + ( + "CUgraphicsResource", + ("hipGraphicsResource_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmipmappedArray", + ("hipMipmappedArray_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute_enum", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"CUgraphicsMapResourceFlags_enum", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags_enum", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags_enum", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUfunc_cache_enum", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUfunc_cache", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUipcMem_flags", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUipcMem_flags_enum", + ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_cacheMode", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_cacheMode_enum", + ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_fallback", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_fallback_enum", + ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_option", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_option_enum", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_target", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjit_target_enum", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjitInputType", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjitInputType_enum", + ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ( + "CUmemAttach_flags", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmemAttach_flags_enum", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUresourcetype_enum", + ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUresourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUresourceViewFormat_enum", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUsharedconfig", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUsharedconfig_enum", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUcontext", ("hipCtx_t", CONV_TYPE, API_DRIVER)), + ("CUmodule", ("hipModule_t", CONV_TYPE, API_DRIVER)), + ("CUstream", ("hipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstream_st", ("ihipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstreamCallback", ("hipStreamCallback_t", CONV_TYPE, API_DRIVER)), + ("CUsurfObject", ("hipSurfaceObject", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUsurfref", + ("hipSurfaceReference_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUtexObject", ("hipTextureObject_t", CONV_TYPE, API_DRIVER)), + ("CUtexref", ("textureReference", CONV_TYPE, API_DRIVER)), + ("CUstream_flags", ("hipStreamFlags", CONV_TYPE, API_DRIVER)), + ( + "CUstreamWaitValue_flags", + ("hipStreamWaitValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamWriteValue_flags", + ("hipStreamWriteValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamBatchMemOpType", + 
("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUdevice_P2PAttribute", + ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), + ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("CUGLmap_flags", ("hipGLMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUd3d9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9map_flags", + ("hipD3D9MapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9register_flags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10map_flags", + ("hipD3D10MapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10register_flags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection_st", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType_t", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamCallback_t", ("hipStreamCallback_t", CONV_TYPE, API_RUNTIME)), + ("cudaArray", ("hipArray", CONV_MEM, API_RUNTIME)), + ("cudaArray_t", ("hipArray_t", CONV_MEM, API_RUNTIME)), + ("cudaArray_const_t", ("hipArray_const_t", CONV_MEM, API_RUNTIME)), + ("cudaMipmappedArray_t", ("hipMipmappedArray_t", CONV_MEM, API_RUNTIME)), + ( + "cudaMipmappedArray_const_t", + ("hipMipmappedArray_const_t", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayDefault", ("hipArrayDefault", CONV_MEM, API_RUNTIME)), + ("cudaArrayLayered", ("hipArrayLayered", CONV_MEM, API_RUNTIME)), + ( + "cudaArraySurfaceLoadStore", + ("hipArraySurfaceLoadStore", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayCubemap", ("hipArrayCubemap", CONV_MEM, API_RUNTIME)), + ("cudaArrayTextureGather", ("hipArrayTextureGather", CONV_MEM, API_RUNTIME)), + ("cudaMemoryAdvise", ("hipMemoryAdvise", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeAttribute", + ("hipMemRangeAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyKind", ("hipMemcpyKind", CONV_MEM, API_RUNTIME)), + ("cudaMemoryType", ("hipMemoryType", CONV_MEM, API_RUNTIME)), + ("cudaExtent", ("hipExtent", CONV_MEM, API_RUNTIME)), + ("cudaPitchedPtr", ("hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("cudaPos", ("hipPos", CONV_MEM, API_RUNTIME)), + ("cudaEvent_t", ("hipEvent_t", CONV_TYPE, API_RUNTIME)), + ("cudaStream_t", ("hipStream_t", CONV_TYPE, API_RUNTIME)), + ("cudaPointerAttributes", ("hipPointerAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceAttr", ("hipDeviceAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceProp", ("hipDeviceProp_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceP2PAttr", + ("hipDeviceP2PAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeMode", + ("hipComputeMode", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), 
+ ("cudaFuncCache", ("hipFuncCache_t", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSharedMemConfig", ("hipSharedMemConfig", CONV_TYPE, API_RUNTIME)), + ("cudaLimit", ("hipLimit_t", CONV_TYPE, API_RUNTIME)), + ("cudaOutputMode", ("hipOutputMode", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaTextureReadMode", ("hipTextureReadMode", CONV_TEX, API_RUNTIME)), + ("cudaTextureFilterMode", ("hipTextureFilterMode", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatKind", ("hipChannelFormatKind", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatDesc", ("hipChannelFormatDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceDesc", ("hipResourceDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewDesc", ("hipResourceViewDesc", CONV_TEX, API_RUNTIME)), + ("cudaTextureDesc", ("hipTextureDesc", CONV_TEX, API_RUNTIME)), + ( + "surfaceReference", + ("hipSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureObject_t", ("hipTextureObject_t", CONV_TEX, API_RUNTIME)), + ("cudaResourceType", ("hipResourceType", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_RUNTIME)), + ("cudaTextureAddressMode", ("hipTextureAddressMode", CONV_TEX, API_RUNTIME)), + ( + "cudaSurfaceBoundaryMode", + ("hipSurfaceBoundaryMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSurfaceFormatMode", + ("hipSurfaceFormatMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureType1D", ("hipTextureType1D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType2D", ("hipTextureType2D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType3D", ("hipTextureType3D", CONV_TEX, API_RUNTIME)), + ("cudaTextureTypeCubemap", ("hipTextureTypeCubemap", CONV_TEX, API_RUNTIME)), + ( + "cudaTextureType1DLayered", + ("hipTextureType1DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureType2DLayered", + ("hipTextureType2DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureTypeCubemapLayered", + ("hipTextureTypeCubemapLayered", CONV_TEX, API_RUNTIME), + ), + ("cudaIpcEventHandle_t", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcEventHandle_st", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_t", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_st", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphicsCubeFace", + ("hipGraphicsCubeFace", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlags", + ("hipGraphicsMapFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceList", + ("hipGLDeviceList", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaGLMapFlags", ("hipGLMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaD3D9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapFlags", + ("hipD3D10MapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceList", + 
("hipd3d11DeviceList", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cublasHandle_t", ("hipblasHandle_t", CONV_TYPE, API_BLAS)), + ("cublasOperation_t", ("hipblasOperation_t", CONV_TYPE, API_BLAS)), + ("cublasStatus_t", ("hipblasStatus_t", CONV_TYPE, API_BLAS)), + ("cublasFillMode_t", ("hipblasFillMode_t", CONV_TYPE, API_BLAS)), + ("cublasDiagType_t", ("hipblasDiagType_t", CONV_TYPE, API_BLAS)), + ("cublasSideMode_t", ("hipblasSideMode_t", CONV_TYPE, API_BLAS)), + ("cublasPointerMode_t", ("hipblasPointerMode_t", CONV_TYPE, API_BLAS)), + ("cublasGemmAlgo_t", ("hipblasGemmAlgo_t", CONV_TYPE, API_BLAS)), + ( + "cublasAtomicsMode_t", + ("hipblasAtomicsMode_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDataType_t", + ("hipblasDatatype_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ("curandStatus", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandStatus_t", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandRngType", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandRngType_t", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandGenerator_st", ("hiprandGenerator_st", CONV_TYPE, API_RAND)), + ("curandGenerator_t", ("hiprandGenerator_t", CONV_TYPE, API_RAND)), + ( + "curandDirectionVectorSet", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDirectionVectorSet_t", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandOrdering", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandOrdering_t", + ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_st", + ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_t", + ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_st", + ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_t", + ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_st", + ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_t", + ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_st", + ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_t", + ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDiscreteDistribution_st", + ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND), + ), + ( + "curandDiscreteDistribution_t", + ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND), + ), + ("curandMethod", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ("curandMethod_t", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandDirectionVectors32_t", + ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND), + ), + ( + "curandDirectionVectors64_t", + ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, 
HIP_UNSUPPORTED), + ), + ("curandStateMtgp32_t", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ("curandStateMtgp32", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ( + "curandStateScrambledSobol64_t", + ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateSobol64_t", + ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateScrambledSobol32_t", + ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateSobol32_t", ("hiprandStateSobol32_t", CONV_TYPE, API_RAND)), + ("curandStateMRG32k3a_t", ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND)), + ( + "curandStatePhilox4_32_10_t", + ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND), + ), + ("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)), + ("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), + ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), + ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), + ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), + ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), + ] +) + +CUDA_INCLUDE_MAP = collections.OrderedDict( + [ + # since pytorch uses "\b{pattern}\b" as the actual re pattern, + # patterns listed here have to begin and end with alnum chars + ( + "include " to differentiate + ("", (_RCCL_HEADER, CONV_INCLUDE, API_RUNTIME)), + ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), + ("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)), + ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_raking_layout.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/cub.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/config.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_ptx.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_type.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_run_length_encode.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_load.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_store.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_select.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("nvtx3/nvtx3.hpp", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), + ("nvml.h", ("rocm_smi/rocm_smi.h", CONV_INCLUDE, API_ROCMSMI)), + ] +) + +CUDA_IDENTIFIER_MAP = collections.OrderedDict( + [ + ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), + ( + "CUDA_ERROR_INVALID_CONTEXT", + ("hipErrorInvalidContext", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", + ("hipErrorContextAlreadyCurrent", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_ARRAY_IS_MAPPED", + ("hipErrorArrayIsMapped", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_ALREADY_MAPPED", ("hipErrorAlreadyMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_ALREADY_ACQUIRED", + ("hipErrorAlreadyAcquired", 
CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_NOT_MAPPED", ("hipErrorNotMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_NOT_MAPPED_AS_ARRAY", + ("hipErrorNotMappedAsArray", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_NOT_MAPPED_AS_POINTER", + ("hipErrorNotMappedAsPointer", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_IN_USE", + ("hipErrorContextAlreadyInUse", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_INVALID_SOURCE", ("hipErrorInvalidSource", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_FILE_NOT_FOUND", ("hipErrorFileNotFound", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_NOT_FOUND", ("hipErrorNotFound", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING", + ( + "hipErrorLaunchIncompatibleTexturing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE", + ("hipErrorPrimaryContextActive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_CONTEXT_IS_DESTROYED", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_PERMITTED", + ("hipErrorNotPermitted", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_SUPPORTED", + ("hipErrorNotSupported", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMissingConfiguration", + ("hipErrorMissingConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorPriorLaunchFailure", + ("hipErrorPriorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDeviceFunction", + ("hipErrorInvalidDeviceFunction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidConfiguration", + ("hipErrorInvalidConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPitchValue", + ("hipErrorInvalidPitchValue", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidSymbol", + ("hipErrorInvalidSymbol", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidHostPointer", + ("hipErrorInvalidHostPointer", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDevicePointer", + ("hipErrorInvalidDevicePointer", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidTexture", + ("hipErrorInvalidTexture", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidTextureBinding", + ("hipErrorInvalidTextureBinding", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidChannelDescriptor", + ( + "hipErrorInvalidChannelDescriptor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorInvalidMemcpyDirection", + ("hipErrorInvalidMemcpyDirection", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAddressOfConstant", + ("hipErrorAddressOfConstant", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureFetchFailed", + ("hipErrorTextureFetchFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureNotBound", + ("hipErrorTextureNotBound", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSynchronizationError", + ("hipErrorSynchronizationError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidFilterSetting", + ("hipErrorInvalidFilterSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidNormSetting", + ("hipErrorInvalidNormSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMixedDeviceExecution", + ("hipErrorMixedDeviceExecution", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotYetImplemented", + 
("hipErrorNotYetImplemented", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMemoryValueTooLarge", + ("hipErrorMemoryValueTooLarge", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInsufficientDriver", + ("hipErrorInsufficientDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSetOnActiveProcess", + ("hipErrorSetOnActiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorContextIsDestroyed", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidSurface", + ("hipErrorInvalidSurface", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateVariableName", + ("hipErrorDuplicateVariableName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateTextureName", + ("hipErrorDuplicateTextureName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateSurfaceName", + ("hipErrorDuplicateSurfaceName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDevicesUnavailable", + ("hipErrorDevicesUnavailable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIncompatibleDriverContext", + ( + "hipErrorIncompatibleDriverContext", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorDeviceAlreadyInUse", + ("hipErrorDeviceAlreadyInUse", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchMaxDepthExceeded", + ("hipErrorLaunchMaxDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedTex", + ("hipErrorLaunchFileScopedTex", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedSurf", + ("hipErrorLaunchFileScopedSurf", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSyncDepthExceeded", + ("hipErrorSyncDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchPendingCountExceeded", + ( + "hipErrorLaunchPendingCountExceeded", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorNotPermitted", + ("hipErrorNotPermitted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotSupported", + ("hipErrorNotSupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorStartupFailure", + ("hipErrorStartupFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorApiFailureBase", + ("hipErrorApiFailureBase", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_SUCCESS", ("hipSuccess", CONV_TYPE, API_DRIVER)), + ("cudaSuccess", ("hipSuccess", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_VALUE", ("hipErrorInvalidValue", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidValue", ("hipErrorInvalidValue", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_OUT_OF_MEMORY", + ("hipErrorMemoryAllocation", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorMemoryAllocation", + ("hipErrorMemoryAllocation", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_NOT_INITIALIZED", + ("hipErrorNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInitializationError", + ("hipErrorInitializationError", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_DEINITIALIZED", ("hipErrorDeinitialized", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorCudartUnloading", + ("hipErrorDeinitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_DISABLED", + ("hipErrorProfilerDisabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerDisabled", + ("hipErrorProfilerDisabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_NOT_INITIALIZED", + 
("hipErrorProfilerNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerNotInitialized", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STARTED", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStarted", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STOPPED", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStopped", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_NO_DEVICE", ("hipErrorNoDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorNoDevice", ("hipErrorNoDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_DEVICE", ("hipErrorInvalidDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidDevice", ("hipErrorInvalidDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_IMAGE", ("hipErrorInvalidImage", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorInvalidKernelImage", + ("hipErrorInvalidImage", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_MAP_FAILED", ("hipErrorMapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorMapBufferObjectFailed", + ("hipErrorMapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_UNMAP_FAILED", ("hipErrorUnmapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorUnmapBufferObjectFailed", + ("hipErrorUnmapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NO_BINARY_FOR_GPU", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorNoKernelImageForDevice", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ECC_UNCORRECTABLE", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorECCUncorrectable", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNSUPPORTED_LIMIT", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorUnsupportedLimit", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_UNSUPPORTED", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessUnsupported", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PTX", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidPtx", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_GRAPHICS_CONTEXT", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidGraphicsContext", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NVLINK_UNCORRECTABLE", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNvlinkUncorrectable", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND", + ("hipErrorSharedObjectSymbolNotFound", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectSymbolNotFound", + ( + "hipErrorSharedObjectSymbolNotFound", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_INIT_FAILED", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectInitFailed", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, 
API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_OPERATING_SYSTEM", + ("hipErrorOperatingSystem", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorOperatingSystem", + ("hipErrorOperatingSystem", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_HANDLE", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidResourceHandle", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_NOT_READY", ("hipErrorNotReady", CONV_TYPE, API_DRIVER)), + ("cudaErrorNotReady", ("hipErrorNotReady", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_ILLEGAL_ADDRESS", + ("hipErrorIllegalAddress", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorIllegalAddress", + ("hipErrorIllegalAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorLaunchOutOfResources", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_LAUNCH_TIMEOUT", ("hipErrorLaunchTimeOut", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorLaunchTimeout", + ("hipErrorLaunchTimeOut", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessAlreadyEnabled", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessNotEnabled", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_ASSERT", + ("hipErrorAssert", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAssert", + ("hipErrorAssert", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_TOO_MANY_PEERS", + ("hipErrorTooManyPeers", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTooManyPeers", + ("hipErrorTooManyPeers", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryAlreadyRegistered", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryNotRegistered", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HARDWARE_STACK_ERROR", + ("hipErrorHardwareStackError", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorHardwareStackError", + ("hipErrorHardwareStackError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ILLEGAL_INSTRUCTION", + ("hipErrorIllegalInstruction", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIllegalInstruction", + ("hipErrorIllegalInstruction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_MISALIGNED_ADDRESS", + ("hipErrorMisalignedAddress", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMisalignedAddress", + ("hipErrorMisalignedAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_ADDRESS_SPACE", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidAddressSpace", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PC", + ("hipErrorInvalidPc", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"cudaErrorInvalidPc", + ("hipErrorInvalidPc", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_FAILED", + ("hipErrorLaunchFailure", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFailure", + ("hipErrorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNKNOWN", + ("hipErrorUnknown", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaErrorUnknown", ("hipErrorUnknown", CONV_TYPE, API_RUNTIME)), + ( + "CU_TR_ADDRESS_MODE_WRAP", + ("HIP_TR_ADDRESS_MODE_WRAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_CLAMP", + ("HIP_TR_ADDRESS_MODE_CLAMP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_MIRROR", + ("HIP_TR_ADDRESS_MODE_MIRROR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_BORDER", + ("HIP_TR_ADDRESS_MODE_BORDER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_X", + ("HIP_CUBEMAP_FACE_POSITIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_X", + ("HIP_CUBEMAP_FACE_NEGATIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Y", + ("HIP_CUBEMAP_FACE_POSITIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Y", + ("HIP_CUBEMAP_FACE_NEGATIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Z", + ("HIP_CUBEMAP_FACE_POSITIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Z", + ("HIP_CUBEMAP_FACE_NEGATIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT8", + ("HIP_AD_FORMAT_UNSIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT16", + ("HIP_AD_FORMAT_UNSIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT32", + ("HIP_AD_FORMAT_UNSIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT8", + ("HIP_AD_FORMAT_SIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT16", + ("HIP_AD_FORMAT_SIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT32", + ("HIP_AD_FORMAT_SIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ("CU_AD_FORMAT_HALF", ("HIP_AD_FORMAT_HALF", CONV_TYPE, API_DRIVER)), + ("CU_AD_FORMAT_FLOAT", ("HIP_AD_FORMAT_FLOAT", CONV_TYPE, API_DRIVER)), + ( + "CU_COMPUTEMODE_DEFAULT", + ("hipComputeModeDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE", + ("hipComputeModeExclusive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_PROHIBITED", + ("hipComputeModeProhibited", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE_PROCESS", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_READ_MOSTLY", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_READ_MOSTLY", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_PREFERRED_LOCATION", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_SET_ACCESSED_BY", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_ACCESSED_BY", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_DRIVER, 
HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_CTX_SCHED_AUTO", + ("HIP_CTX_SCHED_AUTO", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_SPIN", + ("HIP_CTX_SCHED_SPIN", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_YIELD", + ("HIP_CTX_SCHED_YIELD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_BLOCKING_SYNC", + ("HIP_CTX_SCHED_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_BLOCKING_SYNC", + ("HIP_CTX_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_MASK", + ("HIP_CTX_SCHED_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_MAP_HOST", + ("HIP_CTX_MAP_HOST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_LMEM_RESIZE_TO_MAX", + ("HIP_CTX_LMEM_RESIZE_TO_MAX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_FLAGS_MASK", + ("HIP_CTX_FLAGS_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_POINTER", + ("HIP_LAUNCH_PARAM_BUFFER_POINTER", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_SIZE", + ("HIP_LAUNCH_PARAM_BUFFER_SIZE", CONV_TYPE, API_DRIVER), + ), + ("CU_LAUNCH_PARAM_END", ("HIP_LAUNCH_PARAM_END", CONV_TYPE, API_DRIVER)), + ( + "CU_IPC_HANDLE_SIZE", + ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_DEVICEMAP", + ("HIP_MEMHOSTALLOC_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_PORTABLE", + ("HIP_MEMHOSTALLOC_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_WRITECOMBINED", + ("HIP_MEMHOSTALLOC_WRITECOMBINED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_DEVICEMAP", + ("HIP_MEMHOSTREGISTER_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_IOMEMORY", + ("HIP_MEMHOSTREGISTER_IOMEMORY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_PORTABLE", + ("HIP_MEMHOSTREGISTER_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PARAM_TR_DEFAULT", + ("HIP_PARAM_TR_DEFAULT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_LEGACY", + ("HIP_STREAM_LEGACY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_PER_THREAD", + ("HIP_STREAM_PER_THREAD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSA_OVERRIDE_FORMAT", + ("HIP_TRSA_OVERRIDE_FORMAT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_NORMALIZED_COORDINATES", + ("HIP_TRSF_NORMALIZED_COORDINATES", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_READ_AS_INTEGER", + ("HIP_TRSF_READ_AS_INTEGER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TRSF_SRGB", ("HIP_TRSF_SRGB", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_ARRAY3D_2DARRAY", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_CUBEMAP", + ("HIP_ARRAY3D_CUBEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"CUDA_ARRAY3D_DEPTH_TEXTURE", + ("HIP_ARRAY3D_DEPTH_TEXTURE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_LAYERED", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_SURFACE_LDST", + ("HIP_ARRAY3D_SURFACE_LDST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_TEXTURE_GATHER", + ("HIP_ARRAY3D_TEXTURE_GATHER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipDeviceAttributeMaxThreadsPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY", + ( + "hipDeviceAttributeTotalConstantMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + ("hipDeviceAttributeWarpSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_PITCH", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CLOCK_RATE", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GPU_OVERLAP", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT", + ( + "hipDeviceAttributeMultiprocessorCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_INTEGRATED", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_MODE", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH", + ( + 
"hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ECC_ENABLED", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", + ("hipDeviceAttributePciBusId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_TCC_DRIVER", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE", + ( + "hipDeviceAttributeMemoryClockRate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + 
API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER", + ( + "hipDeviceAttributeCanTex2DGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT", + ( + 
"hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY", + ("hipDeviceAttributeManagedMemory", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID", + ( + "hipDeviceAttributeMultiGpuBoardGroupId", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED", + ( + 
"hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX", + ("hipDeviceAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_CONTEXT", + ("hipPointerAttributeContext", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_MEMORY_TYPE", + ("hipPointerAttributeMemoryType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_DEVICE_POINTER", + ( + "hipPointerAttributeDevicePointer", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_POINTER_ATTRIBUTE_HOST_POINTER", + ("hipPointerAttributeHostPointer", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_P2P_TOKENS", + ("hipPointerAttributeP2pTokens", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_SYNC_MEMOPS", + ("hipPointerAttributeSyncMemops", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_BUFFER_ID", + ("hipPointerAttributeBufferId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_IS_MANAGED", + ("hipPointerAttributeIsManaged", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipFuncAttributeMaxThreadsPerBlocks", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + ("hipFuncAttributeSharedSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + ("hipFuncAttributeConstSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + ("hipFuncAttributeLocalSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_NUM_REGS", + ("hipFuncAttributeNumRegs", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_PTX_VERSION", + ("hipFuncAttributePtxVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_BINARY_VERSION", + ("hipFuncAttributeBinaryVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_CACHE_MODE_CA", + ("hipFuncAttributeCacheModeCA", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX", + ("hipFuncAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE", + ("hipGraphicsMapFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY", + ("hipGraphicsMapFlagsReadOnly", CONV_TYPE, API_DRIVER, 
HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ("hipGraphicsMapFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_NONE", + ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_OCCUPANCY_DEFAULT", + ("hipOccupancyDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_CACHE_PREFER_NONE", + ("hipFuncCachePreferNone", CONV_CACHE, API_DRIVER), + ), + ( + "CU_FUNC_CACHE_PREFER_SHARED", + ("hipFuncCachePreferShared", CONV_CACHE, API_DRIVER), + ), + ("CU_FUNC_CACHE_PREFER_L1", ("hipFuncCachePreferL1", CONV_CACHE, API_DRIVER)), + ( + "CU_FUNC_CACHE_PREFER_EQUAL", + ("hipFuncCachePreferEqual", CONV_CACHE, API_DRIVER), + ), + ( + "CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_IPC_HANDLE_SIZE", ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER)), + ( + "CU_JIT_CACHE_OPTION_NONE", + ("hipJitCacheModeOptionNone", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CG", + ("hipJitCacheModeOptionCG", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CA", + ("hipJitCacheModeOptionCA", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_PTX", + ("hipJitFallbackPreferPtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_BINARY", + ("hipJitFallbackPreferBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_JIT_MAX_REGISTERS", ("hipJitOptionMaxRegisters", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_THREADS_PER_BLOCK", + ("hipJitOptionThreadsPerBlock", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_WALL_TIME", ("hipJitOptionWallTime", CONV_JIT, API_DRIVER)), + ("CU_JIT_INFO_LOG_BUFFER", ("hipJitOptionInfoLogBuffer", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionInfoLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER", + ("hipJitOptionErrorLogBuffer", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionErrorLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_OPTIMIZATION_LEVEL", + ("hipJitOptionOptimizationLevel", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_TARGET_FROM_CUCONTEXT", + ("hipJitOptionTargetFromContext", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_TARGET", ("hipJitOptionTarget", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_FALLBACK_STRATEGY", + ("hipJitOptionFallbackStrategy", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_GENERATE_DEBUG_INFO", + ("hipJitOptionGenerateDebugInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_LOG_VERBOSE", ("hipJitOptionLogVerbose", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_GENERATE_LINE_INFO", + ("hipJitOptionGenerateLineInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_CACHE_MODE", 
("hipJitOptionCacheMode", CONV_JIT, API_DRIVER)), + ("CU_JIT_NEW_SM3X_OPT", ("hipJitOptionSm3xOpt", CONV_JIT, API_DRIVER)), + ("CU_JIT_FAST_COMPILE", ("hipJitOptionFastCompile", CONV_JIT, API_DRIVER)), + ("CU_JIT_NUM_OPTIONS", ("hipJitOptionNumOptions", CONV_JIT, API_DRIVER)), + ( + "CU_TARGET_COMPUTE_10", + ("hipJitTargetCompute10", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_11", + ("hipJitTargetCompute11", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_12", + ("hipJitTargetCompute12", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_13", + ("hipJitTargetCompute13", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_20", + ("hipJitTargetCompute20", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_21", + ("hipJitTargetCompute21", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_30", + ("hipJitTargetCompute30", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_32", + ("hipJitTargetCompute32", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_35", + ("hipJitTargetCompute35", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_37", + ("hipJitTargetCompute37", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_50", + ("hipJitTargetCompute50", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_52", + ("hipJitTargetCompute52", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_53", + ("hipJitTargetCompute53", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_60", + ("hipJitTargetCompute60", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_61", + ("hipJitTargetCompute61", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_62", + ("hipJitTargetCompute62", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_CUBIN", + ("hipJitInputTypeBin", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_PTX", + ("hipJitInputTypePtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_FATBINARY", + ("hipJitInputTypeFatBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_OBJECT", + ("hipJitInputTypeObject", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_LIBRARY", + ("hipJitInputTypeLibrary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_NUM_INPUT_TYPES", + ("hipJitInputTypeNumInputTypes", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_PRINTF_FIFO_SIZE", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_MALLOC_HEAP_SIZE", + ("hipLimitMallocHeapSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_GLOBAL", + ("hipMemAttachGlobal", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_HOST", + ("hipMemAttachHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_SINGLE", + ("hipMemAttachSingle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_HOST", + 
("hipMemTypeHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_DEVICE", + ("hipMemTypeDevice", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_ARRAY", + ("hipMemTypeArray", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_UNIFIED", + ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_ARRAY", + ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_MIPMAPPED_ARRAY", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_LINEAR", + ("hipResourceTypeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_PITCH2D", + ("hipResourceTypePitch2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_RES_VIEW_FORMAT_NONE", ("hipResViewFormatNone", CONV_TEX, API_DRIVER)), + ( + "CU_RES_VIEW_FORMAT_UINT_1X8", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X8", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X8", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X8", + ("hipResViewFormatSignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X8", + ("hipResViewFormatSignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X8", + ("hipResViewFormatSignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X16", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X16", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X16", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X16", + ("hipResViewFormatSignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X16", + ("hipResViewFormatSignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X16", + ("hipResViewFormatSignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X32", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X32", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X32", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X32", + ("hipResViewFormatSignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X32", + ("hipResViewFormatSignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X32", + ("hipResViewFormatSignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X16", + ("hipResViewFormatHalf1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X16", + ("hipResViewFormatHalf2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X16", + ("hipResViewFormatHalf4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X32", + ("hipResViewFormatFloat1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X32", + ("hipResViewFormatFloat2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X32", + ("hipResViewFormatFloat4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC3", + 
("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_DRIVER), + ), + ("CU_STREAM_DEFAULT", ("hipStreamDefault", CONV_TYPE, API_DRIVER)), + ("CU_STREAM_NON_BLOCKING", ("hipStreamNonBlocking", CONV_TYPE, API_DRIVER)), + ( + "CU_STREAM_WAIT_VALUE_GEQ", + ("hipStreamWaitValueGeq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_EQ", + ("hipStreamWaitValueEq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_AND", + ("hipStreamWaitValueAnd", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_FLUSH", + ("hipStreamWaitValueFlush", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_DEFAULT", + ("hipStreamWriteValueDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER", + ( + "hipStreamWriteValueNoMemoryBarrier", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_STREAM_MEM_OP_WAIT_VALUE_32", + ("hipStreamBatchMemOpWaitValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_WRITE_VALUE_32", + ("hipStreamBatchMemOpWriteValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES", + ( + "hipStreamBatchMemOpFlushRemoteWrites", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGetErrorName", + ("hipGetErrorName", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGetErrorString", + ("hipDrvGetErrorString", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuInit", ("hipInit", CONV_INIT, API_DRIVER)), + ("cuDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_DRIVER)), + ("cuCtxCreate", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxCreate_v2", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy_v2", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetApiVersion", ("hipCtxGetApiVersion", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCacheConfig", ("hipCtxGetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCurrent", ("hipCtxGetCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetDevice", ("hipCtxGetDevice", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetFlags", ("hipCtxGetFlags", CONV_CONTEXT, API_DRIVER)), + ("cuDeviceGetUuid", ("hipDeviceGetUuid", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxGetLimit", + ("hipCtxGetLimit", CONV_CONTEXT, API_DRIVER, 
HIP_UNSUPPORTED), + ), + ( + "cuCtxGetSharedMemConfig", + ("hipCtxGetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuCtxGetStreamPriorityRange", + ("hipCtxGetStreamPriorityRange", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuCtxPopCurrent_v2", ("hipCtxPopCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxPushCurrent_v2", ("hipCtxPushCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCacheConfig", ("hipCtxSetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCurrent", ("hipCtxSetCurrent", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxSetLimit", + ("hipCtxSetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxSetSharedMemConfig", + ("hipCtxSetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ("cuCtxSynchronize", ("hipCtxSynchronize", CONV_CONTEXT, API_DRIVER)), + ("cuCtxAttach", ("hipCtxAttach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxDetach", ("hipCtxDetach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxEnablePeerAccess", ("hipCtxEnablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuCtxDisablePeerAccess", ("hipCtxDisablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_DRIVER)), + ( + "cuDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_PEER, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuDevicePrimaryCtxGetState", + ("hipDevicePrimaryCtxGetState", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRelease", + ("hipDevicePrimaryCtxRelease", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxReset", + ("hipDevicePrimaryCtxReset", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRetain", + ("hipDevicePrimaryCtxRetain", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxSetFlags", + ("hipDevicePrimaryCtxSetFlags", CONV_CONTEXT, API_DRIVER), + ), + ("cuDeviceGet", ("hipDeviceGet", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetName", ("hipDeviceGetName", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetCount", ("hipGetDeviceCount", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetByPCIBusId", ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceTotalMem_v2", ("hipDeviceTotalMem", CONV_DEVICE, API_DRIVER)), + ( + "cuDeviceComputeCapability", + ("hipDeviceComputeCapability", CONV_DEVICE, API_DRIVER), + ), + ("cuDeviceGetProperties", ("hipGetDeviceProperties", CONV_DEVICE, API_DRIVER)), + ("cuLinkAddData", ("hipLinkAddData", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkAddFile", ("hipLinkAddFile", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLinkComplete", + ("hipLinkComplete", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLinkCreate", ("hipLinkCreate", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkDestroy", ("hipLinkDestroy", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuModuleGetFunction", ("hipModuleGetFunction", CONV_MODULE, API_DRIVER)), + ("cuModuleGetGlobal_v2", ("hipModuleGetGlobal", CONV_MODULE, API_DRIVER)), + ( + "cuModuleGetSurfRef", + ("hipModuleGetSurfRef", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleGetTexRef", ("hipModuleGetTexRef", CONV_MODULE, API_DRIVER)), + ("cuModuleLoad", ("hipModuleLoad", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadData", ("hipModuleLoadData", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadDataEx", ("hipModuleLoadDataEx", CONV_MODULE, API_DRIVER)), + ( + "cuModuleLoadFatBinary", + ("hipModuleLoadFatBinary", CONV_MODULE, 
API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleUnload", ("hipModuleUnload", CONV_MODULE, API_DRIVER)), + ( + "CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("CU_EVENT_DEFAULT", ("hipEventDefault", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_BLOCKING_SYNC", ("hipEventBlockingSync", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_DISABLE_TIMING", ("hipEventDisableTiming", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_INTERPROCESS", ("hipEventInterprocess", CONV_EVENT, API_DRIVER)), + ("cuEventCreate", ("hipEventCreate", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy_v2", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_DRIVER)), + ("cuEventQuery", ("hipEventQuery", CONV_EVENT, API_DRIVER)), + ("cuEventRecord", ("hipEventRecord", CONV_EVENT, API_DRIVER)), + ("cuEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_DRIVER)), + ("cuFuncSetAttribute", ("hipFuncSetAttribute", CONV_EVENT, API_DRIVER)), + ( + "cuFuncGetAttribute", + ("hipFuncGetAttribute", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunchKernel", ("hipModuleLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetBlockShape", + ("hipFuncSetBlockShape", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaLaunchKernel", ("hipLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedSize", + ("hipFuncSetSharedSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunch", ("hipLaunch", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLaunchGrid", ("hipLaunchGrid", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLaunchGridAsync", + ("hipLaunchGridAsync", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetf", ("hipParamSetf", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuParamSeti", ("hipParamSeti", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetv", ("hipParamSetv", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_DRIVER, + ), + ), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuOccupancyMaxPotentialBlockSize", + ("hipModuleOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_DRIVER), + ), + ( + "cuOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipModuleOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_DRIVER)), + ( + "cuStreamAttachMemAsync", + ("hipStreamAttachMemAsync", 
CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreate", + ("hipStreamCreate__", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamDestroy_v2", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_DRIVER)), + ( + "cuStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamQuery", ("hipStreamQuery", CONV_STREAM, API_DRIVER)), + ("cuStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_DRIVER)), + ("cuStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_DRIVER)), + ( + "cuStreamWaitValue32", + ("hipStreamWaitValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamWriteValue32", + ("hipStreamWriteValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamBatchMemOp", + ("hipStreamBatchMemOp", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArray3DCreate", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ( + "cuArray3DGetDescriptor", + ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArrayCreate", ("hipArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuArrayDestroy", ("hipArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuArrayGetDescriptor", + ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcCloseMemHandle", + ("hipIpcCloseMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetEventHandle", + ("hipIpcGetEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetMemHandle", + ("hipIpcGetMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenEventHandle", + ("hipIpcOpenEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenMemHandle", + ("hipIpcOpenMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAlloc_v2", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost", ("hipMemAllocHost", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemAllocManaged", + ("hipMemAllocManaged", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemAllocPitch", + ("hipMemAllocPitch__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy", ("hipMemcpy__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpy2D", ("hipMemcpy2D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy2DAsync", + ("hipMemcpy2DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy2DUnaligned", + ("hipMemcpy2DUnaligned", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy3D", ("hipMemcpy3D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy3DAsync", + ("hipMemcpy3DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeer", + ("hipMemcpy3DPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyAsync", ("hipMemcpyAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoA", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoD", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoH", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyAtoHAsync", + ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyDtoA", ("hipMemcpyDtoA", CONV_MEM, 
API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyDtoD_v2", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync_v2", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH_v2", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync_v2", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyHtoAAsync", + ("hipMemcpyHtoAAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyHtoD_v2", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync_v2", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ( + "cuMemcpyPeerAsync", + ("hipMemcpyPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyPeer", ("hipMemcpyPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemFree", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFree_v2", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFreeHost", ("hipHostFree", CONV_MEM, API_DRIVER)), + ( + "cuMemGetAddressRange", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemGetInfo_v2", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostAlloc", ("hipHostMalloc", CONV_MEM, API_DRIVER)), + ( + "cuMemHostGetDevicePointer", + ("hipMemHostGetDevicePointer", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemHostGetFlags", + ("hipMemHostGetFlags", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemHostRegister_v2", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemHostUnregister", ("hipHostUnregister", CONV_MEM, API_DRIVER)), + ("cuMemsetD16_v2", ("hipMemsetD16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD16Async", + ("hipMemsetD16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D16_v2", ("hipMemsetD2D16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D16Async", + ("hipMemsetD2D16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D32_v2", ("hipMemsetD2D32", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D32Async", + ("hipMemsetD2D32Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D8_v2", ("hipMemsetD2D8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D8Async", + ("hipMemsetD2D8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD32_v2", ("hipMemset", CONV_MEM, API_DRIVER)), + ("cuMemsetD32Async", ("hipMemsetAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD8_v2", ("hipMemsetD8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD8Async", + ("hipMemsetD8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayCreate", + ("hipMipmappedArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayDestroy", + ("hipMipmappedArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayGetLevel", + ("hipMipmappedArrayGetLevel", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemPrefetchAsync", + ("hipMemPrefetchAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAdvise", ("hipMemAdvise", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerGetAttribute", + ("hipPointerGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemGetAddressRange_v2", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER), + ), + ( + "cuPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, 
API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerSetAttribute", + ("hipPointerSetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TR_FILTER_MODE_POINT", ("hipFilterModePoint", CONV_TEX, API_DRIVER)), + ( + "CU_TR_FILTER_MODE_LINEAR", + ("hipFilterModeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddress", + ("hipTexRefGetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddressMode", + ("hipTexRefGetAddressMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetArray", + ("hipTexRefGetArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetBorderColor", + ("hipTexRefGetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFilterMode", + ("hipTexRefGetFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFlags", + ("hipTexRefGetFlags", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFormat", + ("hipTexRefGetFormat", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMaxAnisotropy", + ("hipTexRefGetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapFilterMode", + ("hipTexRefGetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelBias", + ("hipTexRefGetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelClamp", + ("hipTexRefGetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmappedArray", + ("hipTexRefGetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress", + ("hipTexRefSetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress2D", + ("hipTexRefSetAddress2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetAddressMode", ("hipTexRefSetAddressMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetArray", ("hipTexRefSetArray", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetBorderColor", + ("hipTexRefSetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetFilterMode", ("hipTexRefSetFilterMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFlags", ("hipTexRefSetFlags", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFormat", ("hipTexRefSetFormat", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetMaxAnisotropy", + ("hipTexRefSetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapFilterMode", + ("hipTexRefSetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelBias", + ("hipTexRefSetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelClamp", + ("hipTexRefSetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmappedArray", + ("hipTexRefSetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefCreate", ("hipTexRefCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuTexRefDestroy", + ("hipTexRefDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefGetArray", + ("hipSurfRefGetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefSetArray", + ("hipSurfRefSetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectCreate", + ("hipTexObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectDestroy", + ("hipTexObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceDesc", + ("hipTexObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"cuTexObjectGetResourceViewDesc", + ("hipTexObjectGetResourceViewDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetTextureDesc", + ("hipTexObjectGetTextureDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectCreate", + ("hipSurfObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectDestroy", + ("hipSurfObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectGetResourceDesc", + ("hipSurfObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuProfilerStart", ("hipProfilerStart", CONV_OTHER, API_DRIVER)), + ("cuProfilerStop", ("hipProfilerStop", CONV_OTHER, API_DRIVER)), + ( + "CU_GL_DEVICE_LIST_ALL", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_CURRENT_FRAME", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_NEXT_FRAME", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuGLGetDevices", ("hipGLGetDevices", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuWGLGetDevice", ("hipWGLGetDevice", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CU_GL_MAP_RESOURCE_FLAGS_NONE", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuGLCtxCreate", ("hipGLCtxCreate", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("cuGLInit", ("hipGLInit", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGLMapBufferObject", + ("hipGLMapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObject", + ("hipGLUnmapBufferObject", 
CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_ALL", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_DEVICE_LIST_NEXT_FRAME", + ("HIP_D3D9_DEVICE_LIST_NEXT_FRAME", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreate", + ("hipD3D9CtxCreate", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreateOnDevice", + ("hipD3D9CtxCreateOnDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D9RegisterResource", + ("hipGraphicsD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_NONE", + ("HIP_D3D9_MAPRESOURCE_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_REGISTER_FLAGS_NONE", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_REGISTER_FLAGS_ARRAY", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPointer", + ("hipD3D9ResourceGetMappedPointer", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_ALL", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10GetDevice", + ("hipD3D10GetDevice", 
CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_NONE", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_REGISTER_FLAGS_NONE", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_REGISTER_FLAGS_ARRAY", + ("HIP_D3D10_REGISTER_FLAGS_ARRAY", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreate", + ("hipD3D10CtxCreate", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreateOnDevice", + ("hipD3D10CtxCreateOnDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedArray", + ("hipD3D10ResourceGetMappedArray", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPitch", + ("hipD3D10ResourceGetMappedPitch", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD310ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_ALL", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D11_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11CtxCreate", + ("hipD3D11CtxCreate", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11CtxCreateOnDevice", + ("hipD3D11CtxCreateOnDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDirect3DDevice", + ("hipD3D11GetDirect3DDevice", CONV_D3D11, API_DRIVER, 
HIP_UNSUPPORTED), + ), + ( + "cuGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuVDPAUCtxCreate", + ("hipVDPAUCtxCreate", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerAcquireFrame", + ("hipEGLStreamConsumerAcquireFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuEGLStreamConsumerDisconnect", + ("hipEGLStreamConsumerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerReleaseFrame", + ("hipEGLStreamConsumerReleaseFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerPresentFrame", + ("hipEGLStreamProducerPresentFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32F", ("HIP_R_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64F", ("HIP_R_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16F", ("HIP_R_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8I", ("HIP_R_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32F", ("HIP_C_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64F", ("HIP_C_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16F", ("HIP_C_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8I", ("HIP_C_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8U", ("HIP_R_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8U", ("HIP_C_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32I", ("HIP_R_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32I", ("HIP_C_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32U", ("HIP_R_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32U", ("HIP_C_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4I", ("HIP_R_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4I", ("HIP_C_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4U", ("HIP_R_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4U", ("HIP_C_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16I", ("HIP_R_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16I", ("HIP_C_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16U", ("HIP_R_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16U", ("HIP_C_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64I", ("HIP_R_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64I", ("HIP_C_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64U", ("HIP_R_64U", CONV_TYPE, API_RUNTIME)), + 
("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)), + ( + "MAJOR_VERSION", + ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "MINOR_VERSION", + ("hipLibraryMinorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "PATCH_LEVEL", + ("hipLibraryPatchVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachGlobal", + ("hipMemAttachGlobal", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachHost", + ("hipMemAttachHost", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachSingle", + ("hipMemAttachSingle", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDefault", + ("hipOccupancyDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDisableCachingOverride", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaGetLastError", ("hipGetLastError", CONV_ERROR, API_RUNTIME)), + ("cudaPeekAtLastError", ("hipPeekAtLastError", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorName", ("hipGetErrorName", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorString", ("hipGetErrorString", CONV_ERROR, API_RUNTIME)), + ("cudaMemcpy3DParms", ("hipMemcpy3DParms", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DPeerParms", + ("hipMemcpy3DPeerParms", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy", ("hipMemcpy", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToArray", ("hipMemcpyToArray", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbol", ("hipMemcpyToSymbol", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbolAsync", ("hipMemcpyToSymbolAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyAsync", ("hipMemcpyAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2D", ("hipMemcpy2D", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DAsync", ("hipMemcpy2DAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DToArray", ("hipMemcpy2DToArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy2DArrayToArray", + ("hipMemcpy2DArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArray", + ("hipMemcpy2DFromArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArrayAsync", + ("hipMemcpy2DFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DToArrayAsync", + ("hipMemcpy2DToArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy3D", ("hipMemcpy3D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DAsync", + ("hipMemcpy3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeer", + ("hipMemcpy3DPeer", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyArrayToArray", + ("hipMemcpyArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyFromArrayAsync", + ("hipMemcpyFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyFromSymbol", ("hipMemcpyFromSymbol", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyFromSymbolAsync", + ("hipMemcpyFromSymbolAsync", CONV_MEM, API_RUNTIME), + ), + ("cudaMemAdvise", ("hipMemAdvise", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + 
"cudaMemAdviseSetReadMostly", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetReadMostly", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetPreferredLocation", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseUnsetPreferredLocation", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseSetAccessedBy", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetAccessedBy", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeReadMostly", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributePreferredLocation", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemRangeAttributeAccessedBy", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeLastPrefetchLocation", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaMemcpyHostToHost", ("hipMemcpyHostToHost", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyHostToDevice", ("hipMemcpyHostToDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyDeviceToHost", ("hipMemcpyDeviceToHost", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyDeviceToDevice", + ("hipMemcpyDeviceToDevice", CONV_MEM, API_RUNTIME), + ), + ("cudaMemcpyDefault", ("hipMemcpyDefault", CONV_MEM, API_RUNTIME)), + ("cudaMemset", ("hipMemset", CONV_MEM, API_RUNTIME)), + ("cudaMemsetAsync", ("hipMemsetAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemset2D", ("hipMemset2D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemset2DAsync", + ("hipMemset2DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemset3D", ("hipMemset3D", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemset3DAsync", + ("hipMemset3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_RUNTIME)), + ("cudaDeviceGetDefaultMemPool", ("hipDeviceGetDefaultMemPool", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessDesc", ("hipMemAccessDesc", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessFlagsProtReadWrite", ("hipMemAccessFlagsProtReadWrite", CONV_MEM, API_RUNTIME)), + ("cudaMemLocationTypeDevice", ("hipMemLocationTypeDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReleaseThreshold", ("hipMemPoolAttrReleaseThreshold", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemCurrent", ("hipMemPoolAttrReservedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemHigh", ("hipMemPoolAttrReservedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)), + 
("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)), + ("cudaMemPool_t", ("hipMemPool_t", CONV_MEM, API_RUNTIME)), + ( + "cudaArrayGetInfo", + ("hipArrayGetInfo", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFreeMipmappedArray", + ("hipFreeMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetMipmappedArrayLevel", + ("hipGetMipmappedArrayLevel", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolAddress", + ("hipGetSymbolAddress", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolSize", + ("hipGetSymbolSize", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemPrefetchAsync", + ("hipMemPrefetchAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocHost", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMallocArray", ("hipMallocArray", CONV_MEM, API_RUNTIME)), + ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3D", ("hipMalloc3D", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3DArray", ("hipMalloc3DArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMallocManaged", + ("hipMallocManaged", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMallocMipmappedArray", + ("hipMallocMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocPitch", ("hipMallocPitch", CONV_MEM, API_RUNTIME)), + ("cudaFreeHost", ("hipHostFree", CONV_MEM, API_RUNTIME)), + ("cudaFreeArray", ("hipFreeArray", CONV_MEM, API_RUNTIME)), + ("cudaFree", ("hipFree", CONV_MEM, API_RUNTIME)), + ("cudaHostRegister", ("hipHostRegister", CONV_MEM, API_RUNTIME)), + ("cudaHostUnregister", ("hipHostUnregister", CONV_MEM, API_RUNTIME)), + ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), + ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), + ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostAllocWriteCombined", + ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), + ), + ("cudaHostGetFlags", ("hipHostGetFlags", CONV_MEM, API_RUNTIME)), + ("cudaHostRegisterDefault", ("hipHostRegisterDefault", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterPortable", + ("hipHostRegisterPortable", CONV_MEM, API_RUNTIME), + ), + ("cudaHostRegisterMapped", ("hipHostRegisterMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterIoMemory", + ("hipHostRegisterIoMemory", CONV_MEM, API_RUNTIME), + ), + # ("warpSize", ("hipWarpSize", CONV_SPECIAL_FUNC, API_RUNTIME), (HIP actually uses warpSize...)), + ("cudaEventCreate", ("hipEventCreate", CONV_EVENT, API_RUNTIME)), + ( + "cudaEventCreateWithFlags", + ("hipEventCreateWithFlags", CONV_EVENT, API_RUNTIME), + ), + ("cudaEventDestroy", ("hipEventDestroy", CONV_EVENT, API_RUNTIME)), + ("cudaEventRecord", ("hipEventRecord", CONV_EVENT, API_RUNTIME)), + ("cudaEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_RUNTIME)), + ("cudaEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_RUNTIME)), + ("cudaEventQuery", ("hipEventQuery", CONV_EVENT, API_RUNTIME)), + ("cudaEventDefault", ("hipEventDefault", 
CONV_EVENT, API_RUNTIME)), + ("cudaEventBlockingSync", ("hipEventBlockingSync", CONV_EVENT, API_RUNTIME)), + ("cudaEventDisableTiming", ("hipEventDisableTiming", CONV_EVENT, API_RUNTIME)), + ("cudaEventInterprocess", ("hipEventInterprocess", CONV_EVENT, API_RUNTIME)), + ("cudaStreamCreate", ("hipStreamCreate", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamCreateWithFlags", + ("hipStreamCreateWithFlags", CONV_STREAM, API_RUNTIME), + ), + ( + "cudaStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_RUNTIME)), + ("cudaStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_RUNTIME)), + ("cudaStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_RUNTIME)), + ("cudaStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_RUNTIME)), + ("cudaStreamQuery", ("hipStreamQuery", CONV_STREAM, API_RUNTIME)), + ("cudaStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCpuDeviceId", ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME)), + ("cudaStreamDefault", ("hipStreamDefault", CONV_TYPE, API_RUNTIME)), + ("cudaStreamNonBlocking", ("hipStreamNonBlocking", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo", ("hipStreamGetCaptureInfo", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo_v2", ("hipStreamGetCaptureInfo_v2", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatus", ("hipStreamCaptureStatus", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatusActive", ("hipStreamCaptureStatusActive", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)), + ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagAutoFreeOnLaunch", ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDestroy", ("hipGraphDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecDestroy", ("hipGraphExecDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphGetNodes", ("hipGraphGetNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotPrint", ("hipGraphDebugDotPrint", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), + ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), + ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), + ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectCreate", ("hipUserObjectCreate", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectNoDestructorSync", ("hipUserObjectNoDestructorSync", CONV_TYPE, API_RUNTIME)), + ("cudaThreadExchangeStreamCaptureMode", 
("hipThreadExchangeStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamIsCapturing", ("hipStreamIsCapturing", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSynchronize", ("hipDeviceSynchronize", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceReset", ("hipDeviceReset", CONV_DEVICE, API_RUNTIME)), + ("cudaSetDevice", ("hipSetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDevice", ("hipGetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDeviceCount", ("hipGetDeviceCount", CONV_DEVICE, API_RUNTIME)), + ("cudaChooseDevice", ("hipChooseDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaThreadExit", ("hipDeviceReset", CONV_THREAD, API_RUNTIME)), + ( + "cudaThreadGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadGetLimit", + ("hipThreadGetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaThreadSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadSetLimit", + ("hipThreadSetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaThreadSynchronize", ("hipDeviceSynchronize", CONV_THREAD, API_RUNTIME)), + ("cudaDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDevAttrMaxThreadsPerBlock", + ("hipDeviceAttributeMaxThreadsPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimX", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimY", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimZ", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimX", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimY", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimZ", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlock", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlockOptin", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTotalConstantMemory", + ("hipDeviceAttributeTotalConstantMemory", CONV_TYPE, API_RUNTIME), + ), + ("cudaDevAttrWarpSize", ("hipDeviceAttributeWarpSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrMaxPitch", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMaxRegistersPerBlock", + ("hipDeviceAttributeMaxRegistersPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrClockRate", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTextureAlignment", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGpuOverlap", + ("hipDeviceAttributeGpuOverlap", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMultiProcessorCount", + ("hipDeviceAttributeMultiprocessorCount", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrKernelExecTimeout", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIntegrated", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrCanMapHostMemory", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeMode", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_RUNTIME), + ), + ( + 
"cudaDevAttrMaxTexture1DWidth", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DWidth", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DHeight", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidth", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeight", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepth", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredHeight", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSurfaceAlignment", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentKernels", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrEccEnabled", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDevAttrPciBusId", ("hipDeviceAttributePciBusId", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrPciDeviceId", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTccDriver", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMemoryClockRate", + ("hipDeviceAttributeMemoryClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrGlobalMemoryBusWidth", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrL2CacheSize", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxThreadsPerMultiProcessor", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrAsyncEngineCount", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrUnifiedAddressing", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherWidth", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherHeight", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidthAlt", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeightAlt", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + 
CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepthAlt", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPciDomainId", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrTexturePitchAlignment", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapWidth", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DWidth", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DWidth", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DHeight", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DWidth", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DHeight", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DDepth", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredHeight", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLinearWidth", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearWidth", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearHeight", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearPitch", + 
( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedHeight", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeCapabilityMajor", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrComputeCapabilityMinor", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrStreamPrioritiesSupported", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGlobalL1CacheSupported", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrLocalL1CacheSupported", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSharedMemoryPerMultiprocessor", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + ), + ), + ( + "cudaDevAttrMaxRegistersPerMultiprocessor", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrManagedMemory", + ( + "hipDeviceAttributeManagedMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIsMultiGpuBoard", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMultiGpuBoardGroupID", + ( + "hipDeviceAttributeMultiGpuBoardGroupID", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrHostNativeAtomicSupported", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSingleToDoublePrecisionPerfRatio", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPageableMemoryAccess", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentManagedAccess", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputePreemptionSupported", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrCanUseHostPointerForRegisteredMem", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_RUNTIME), + ), + ( + "cudaHostGetDevicePointer", + ("hipHostGetDevicePointer", CONV_MEM, API_RUNTIME), + ), + ( + "cudaGetDeviceProperties", + ("hipGetDeviceProperties", CONV_DEVICE, API_RUNTIME), + ), + ("cudaDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDeviceGetByPCIBusId", + ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetStreamPriorityRange", + ( + "hipDeviceGetStreamPriorityRange", + CONV_DEVICE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaSetValidDevices", 
+ ("hipSetValidDevices", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevP2PAttrPerformanceRank", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrAccessSupported", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrNativeAtomicSupported", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeDefault", + ("hipComputeModeDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusive", + ("hipComputeModeExclusive", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeProhibited", + ("hipComputeModeProhibited", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusiveProcess", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetDeviceFlags", + ("hipGetDeviceFlags", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSetDeviceFlags", ("hipSetDeviceFlags", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceScheduleAuto", ("hipDeviceScheduleAuto", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleSpin", ("hipDeviceScheduleSpin", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleYield", ("hipDeviceScheduleYield", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleMask", + ("hipDeviceScheduleMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMapHost", ("hipDeviceMapHost", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceLmemResizeToMax", + ("hipDeviceLmemResizeToMax", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMask", ("hipDeviceMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaDeviceSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaDeviceGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributeMaxDynamicSharedMemorySize", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributePreferredSharedMemoryCarveout", + ("hipFuncAttributePreferredSharedMemoryCarveout", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncSetAttribute", + ("hipFuncSetAttribute", CONV_EXEC, API_RUNTIME), + ), + ("cudaFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferNone", + ("hipFuncCachePreferNone", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncCachePreferShared", + ("hipFuncCachePreferShared", CONV_CACHE, API_RUNTIME), + ), + ("cudaFuncCachePreferL1", ("hipFuncCachePreferL1", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferEqual", + ("hipFuncCachePreferEqual", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncGetAttributes", + ("hipFuncGetAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetParameterBuffer", + ("hipGetParameterBuffer", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForDevice", + ("hipSetDoubleForDevice", CONV_EXEC, 
API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForHost", + ("hipSetDoubleForHost", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaConfigureCall", + ("hipConfigureCall", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLaunch", ("hipLaunch", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaLaunchCooperativeKernel", + ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME), + ), + ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaSetupArgument", + ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_RUNTIME)), + ( + "cudaRuntimeGetVersion", + ("hipRuntimeGetVersion", CONV_VERSION, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyMaxPotentialBlockSize", + ("hipOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_RUNTIME), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_RUNTIME, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMem", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMem", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_RUNTIME)), + ( + "cudaDeviceDisablePeerAccess", + ("hipDeviceDisablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ( + "cudaDeviceEnablePeerAccess", + ("hipDeviceEnablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ("cudaMemcpyPeerAsync", ("hipMemcpyPeerAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyPeer", ("hipMemcpyPeer", CONV_MEM, API_RUNTIME)), + ( + "cudaIpcMemLazyEnablePeerAccess", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceSetSharedMemConfig", + ("hipDeviceSetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetSharedMemConfig", + ("hipDeviceGetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeDefault", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeFourByte", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeEightByte", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaLimitStackSize", + ("hipLimitStackSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitPrintfFifoSize", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLimitMallocHeapSize", ("hipLimitMallocHeapSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaLimitDevRuntimeSyncDepth", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitDevRuntimePendingLaunchCount", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceGetLimit", ("hipDeviceGetLimit", CONV_DEVICE, API_RUNTIME)), + ( + "cudaProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_RUNTIME, 
HIP_UNSUPPORTED), + ), + ("cudaProfilerStart", ("hipProfilerStart", CONV_OTHER, API_RUNTIME)), + ("cudaProfilerStop", ("hipProfilerStop", CONV_OTHER, API_RUNTIME)), + ( + "cudaKeyValuePair", + ("hipKeyValuePair", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCSV", ("hipCSV", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaReadModeElementType", ("hipReadModeElementType", CONV_TEX, API_RUNTIME)), + ( + "cudaReadModeNormalizedFloat", + ("hipReadModeNormalizedFloat", CONV_TEX, API_RUNTIME), + ), + ("cudaFilterModePoint", ("hipFilterModePoint", CONV_TEX, API_RUNTIME)), + ("cudaFilterModeLinear", ("hipFilterModeLinear", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture", ("hipBindTexture", CONV_TEX, API_RUNTIME)), + ("cudaUnbindTexture", ("hipUnbindTexture", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture2D", ("hipBindTexture2D", CONV_TEX, API_RUNTIME)), + ("cudaBindTextureToArray", ("hipBindTextureToArray", CONV_TEX, API_RUNTIME)), + ( + "cudaBindTextureToMipmappedArray", + ("hipBindTextureToMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureAlignmentOffset", + ("hipGetTextureAlignmentOffset", CONV_TEX, API_RUNTIME), + ), + ("cudaGetTextureReference", ("hipGetTextureReference", CONV_TEX, API_RUNTIME)), + ( + "cudaChannelFormatKindSigned", + ("hipChannelFormatKindSigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindUnsigned", + ("hipChannelFormatKindUnsigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindFloat", + ("hipChannelFormatKindFloat", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindNone", + ("hipChannelFormatKindNone", CONV_TEX, API_RUNTIME), + ), + ("cudaCreateChannelDesc", ("hipCreateChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaGetChannelDesc", ("hipGetChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypeArray", ("hipResourceTypeArray", CONV_TEX, API_RUNTIME)), + ( + "cudaResourceTypeMipmappedArray", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ("cudaResourceTypeLinear", ("hipResourceTypeLinear", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypePitch2D", ("hipResourceTypePitch2D", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatNone", ("hipResViewFormatNone", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedChar1", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar2", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar4", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar1", + ("hipResViewFormatSignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar2", + ("hipResViewFormatSignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar4", + ("hipResViewFormatSignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort1", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort2", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort4", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort1", + ("hipResViewFormatSignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort2", + ("hipResViewFormatSignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort4", + ("hipResViewFormatSignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt1", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_RUNTIME), + 
), + ( + "cudaResViewFormatUnsignedInt2", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt4", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt1", + ("hipResViewFormatSignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt2", + ("hipResViewFormatSignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt4", + ("hipResViewFormatSignedInt4", CONV_TEX, API_RUNTIME), + ), + ("cudaResViewFormatHalf1", ("hipResViewFormatHalf1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf2", ("hipResViewFormatHalf2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf4", ("hipResViewFormatHalf4", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat1", ("hipResViewFormatFloat1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat2", ("hipResViewFormatFloat2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat4", ("hipResViewFormatFloat4", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedBlockCompressed1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_RUNTIME), + ), + ("cudaAddressModeWrap", ("hipAddressModeWrap", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeClamp", ("hipAddressModeClamp", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeMirror", ("hipAddressModeMirror", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeBorder", ("hipAddressModeBorder", CONV_TEX, API_RUNTIME)), + ("cudaCreateTextureObject", ("hipCreateTextureObject", CONV_TEX, API_RUNTIME)), + ( + "cudaDestroyTextureObject", + ("hipDestroyTextureObject", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceDesc", + ("hipGetTextureObjectResourceDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceViewDesc", + ("hipGetTextureObjectResourceViewDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectTextureDesc", + ("hipGetTextureObjectTextureDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaBindSurfaceToArray", + ("hipBindSurfaceToArray", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceReference", + ("hipGetSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeZero", + ("hipBoundaryModeZero", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeClamp", + ("hipBoundaryModeClamp", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeTrap", + 
("hipBoundaryModeTrap", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeForced", + ("hipFormatModeForced", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeAuto", + ("hipFormatModeAuto", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaCreateSurfaceObject", + ("hipCreateSurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDestroySurfaceObject", + ("hipDestroySurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceObjectResourceDesc", + ( + "hipGetSurfaceObjectResourceDesc", + CONV_SURFACE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaIpcCloseMemHandle", ("hipIpcCloseMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetEventHandle", ("hipIpcGetEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetMemHandle", ("hipIpcGetMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenEventHandle", ("hipIpcOpenEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenMemHandle", ("hipIpcOpenMemHandle", CONV_DEVICE, API_RUNTIME)), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveX", + ( + "hipGraphicsCubeFacePositiveX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeX", + ( + "hipGraphicsCubeFaceNegativeX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveY", + ( + "hipGraphicsCubeFacePositiveY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeY", + ( + "hipGraphicsCubeFaceNegativeY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveZ", + ( + "hipGraphicsCubeFacePositiveZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeZ", + ( + "hipGraphicsCubeFaceNegativeZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsNone", + ("hipGraphicsMapFlagsNone", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlagsReadOnly", + ( + "hipGraphicsMapFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + 
"cudaGraphicsMapFlagsWriteDiscard", + ( + "hipGraphicsMapFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsNone", + ( + "hipGraphicsRegisterFlagsNone", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsReadOnly", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsWriteDiscard", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsSurfaceLoadStore", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsTextureGather", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLDeviceListAll", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListCurrentFrame", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListNextFrame", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsNone", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsReadOnly", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapFlagsWriteDiscard", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapBufferObject", + ("hipGLMapBufferObject__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetGLDevice", + ("hipGLSetGLDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListAll", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListCurrentFrame", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9DeviceListNextFrame", + ( + "HIP_D3D9_DEVICE_LIST_NEXT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDirect3DDevice", + 
("hipD3D9GetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9SetDirect3DDevice", + ("hipD3D9SetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D9RegisterResource", + ( + "hipGraphicsD3D9RegisterResource", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlagsNone", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_NONE", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsReadOnly", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsWriteDiscard", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9RegisterFlagsNone", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlagsArray", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPointer", + ( + "hipD3D9ResourceGetMappedPointer", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListAll", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListCurrentFrame", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10DeviceListNextFrame", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsNone", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsReadOnly", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsWriteDiscard", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10RegisterFlagsNone", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + 
( + "cudaD3D10RegisterFlagsArray", + ( + "HIP_D3D10_REGISTER_FLAGS_ARRAY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetMappedArray", + ( + "hipD3D10ResourceGetMappedArray", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPitch", + ( + "hipD3D10ResourceGetMappedPitch", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10SetDirect3DDevice", + ("hipD3D10SetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListAll", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListCurrentFrame", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11DeviceListNextFrame", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaVDPAUSetVDPAUDevice", + ("hipVDPAUSetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerAcquireFrame", + ( + "hipEGLStreamConsumerAcquireFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerConnectWithFlags", + ( + 
"hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerReleaseFrame", + ( + "hipEGLStreamConsumerReleaseFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerPresentFrame", + ( + "hipEGLStreamProducerPresentFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cublasInit", ("hipblasInit", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasShutdown", + ("hipblasShutdown", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVersion", + ("hipblasGetVersion", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetError", + ("hipblasGetError", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAlloc", ("hipblasAlloc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasFree", ("hipblasFree", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSetKernelStream", + ("hipblasSetKernelStream", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetAtomicsMode", + ("hipblasGetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetAtomicsMode", + ("hipblasSetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetMathMode", + ("hipblasGetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMathMode", + ("hipblasSetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_OP_N", ("HIPBLAS_OP_N", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_OP_T", + ("HIPBLAS_OP_T", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_OP_C", + ("HIPBLAS_OP_C", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_SUCCESS", + ("HIPBLAS_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_INITIALIZED", + ("HIPBLAS_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ALLOC_FAILED", + ("HIPBLAS_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INVALID_VALUE", + ("HIPBLAS_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_MAPPING_ERROR", + ("HIPBLAS_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_EXECUTION_FAILED", + ("HIPBLAS_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INTERNAL_ERROR", + ("HIPBLAS_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_SUPPORTED", + ("HIPBLAS_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ARCH_MISMATCH", + ("HIPBLAS_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPBLAS_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPBLAS_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_DIAG_NON_UNIT", + ("HIPBLAS_DIAG_NON_UNIT", 
CONV_NUMERIC_LITERAL, API_BLAS), + ), + ("CUBLAS_DIAG_UNIT", ("HIPBLAS_DIAG_UNIT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_LEFT", ("HIPBLAS_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_RIGHT", ("HIPBLAS_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_POINTER_MODE_HOST", + ("HIPBLAS_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_POINTER_MODE_DEVICE", + ("HIPBLAS_POINTER_MODE_DEVICE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_ATOMICS_NOT_ALLOWED", + ( + "HIPBLAS_ATOMICS_NOT_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_ATOMICS_ALLOWED", + ( + "HIPBLAS_ATOMICS_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_FLOAT", + ( + "HIPBLAS_DATA_FLOAT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_DOUBLE", + ( + "HIPBLAS_DATA_DOUBLE", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_HALF", + ("HIPBLAS_DATA_HALF", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "CUBLAS_DATA_INT8", + ("HIPBLAS_DATA_INT8", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_GEMM_DEFAULT", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_GEMM_DEFAULT_TENSOR_OP", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("cublasCreate", ("hipblasCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy", ("hipblasDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetVector", ("hipblasSetVector", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetVector", ("hipblasGetVector", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSetVectorAsync", + ("hipblasSetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVectorAsync", + ("hipblasGetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetMatrix", ("hipblasSetMatrix", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetMatrix", ("hipblasGetMatrix", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetMatrixAsync", + ("hipblasGetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMatrixAsync", + ("hipblasSetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasXerbla", ("hipblasXerbla", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSnrm2", ("hipblasSnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2", ("hipblasDnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasScnrm2", ("hipblasScnrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDznrm2", ("hipblasDznrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasNrm2Ex", + ("hipblasNrm2Ex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdot", ("hipblasSdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSdotBatched", + ("hipblasSdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDdot", ("hipblasDdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDdotBatched", + ("hipblasDdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCdotu", ("hipblasCdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasCdotc", ("hipblasCdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotu", ("hipblasZdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotc", ("hipblasZdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasSscal", ("hipblasSscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSscalBatched", + ("hipblasSscalBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDscal", ("hipblasDscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDscalBatched", + ("hipblasDscalBatched", 
CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCscal", ("hipblasCscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsscal", ("hipblasCsscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZscal", ("hipblasZscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdscal", ("hipblasZdscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy", ("hipblasSaxpy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSaxpyBatched", + ("hipblasSaxpyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDaxpy", ("hipblasDaxpy", CONV_MATH_FUNC, API_BLAS)), + ("cublasCaxpy", ("hipblasCaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZaxpy", ("hipblasZaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasScopy", ("hipblasScopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScopyBatched", + ("hipblasScopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDcopy", ("hipblasDcopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDcopyBatched", + ("hipblasDcopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCcopy", ("hipblasCcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZcopy", ("hipblasZcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSswap", ("hipblasSswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap", ("hipblasDswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasCswap", ("hipblasCswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZswap", ("hipblasZswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamax", ("hipblasIsamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax", ("hipblasIdamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamax", ("hipblasIcamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamax", ("hipblasIzamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamin", ("hipblasIsamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin", ("hipblasIdamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamin", ("hipblasIcamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamin", ("hipblasIzamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSasum", ("hipblasSasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSasumBatched", + ("hipblasSasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDasum", ("hipblasDasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDasumBatched", + ("hipblasDasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScasum", ("hipblasScasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDzasum", ("hipblasDzasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrot", ("hipblasSrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot", ("hipblasDrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot", ("hipblasCrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsrot", ("hipblasCsrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrot", ("hipblasZrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdrot", ("hipblasZdrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotg", ("hipblasSrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotg", ("hipblasDrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrotg", ("hipblasCrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrotg", ("hipblasZrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotm", ("hipblasSrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotm", ("hipblasDrotm", CONV_MATH_FUNC, 
API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotmg", ("hipblasSrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotmg", ("hipblasDrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgemv", ("hipblasSgemv", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSgemvBatched", + ("hipblasSgemvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDgemv", ("hipblasDgemv", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemv", ("hipblasCgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgemv", ("hipblasZgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgbmv", ("hipblasSgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDgbmv", ("hipblasDgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgbmv", ("hipblasCgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgbmv", ("hipblasZgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrmv", ("hipblasStrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmv", ("hipblasDtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmv", ("hipblasCtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmv", ("hipblasZtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbmv", ("hipblasStbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbmv", ("hipblasDtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbmv", ("hipblasCtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbmv", ("hipblasZtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpmv", ("hipblasStpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpmv", ("hipblasDtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpmv", ("hipblasCtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpmv", ("hipblasZtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsv", ("hipblasStrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrsv", ("hipblasDtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrsv", ("hipblasCtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsv", ("hipblasZtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpsv", ("hipblasStpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpsv", ("hipblasDtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpsv", ("hipblasCtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpsv", ("hipblasZtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbsv", ("hipblasStbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbsv", ("hipblasDtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbsv", ("hipblasCtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbsv", ("hipblasZtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymv", ("hipblasSsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymv", ("hipblasDsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymv", ("hipblasCsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymv", ("hipblasZsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemv", ("hipblasChemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemv", ("hipblasZhemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsbmv", ("hipblasSsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsbmv", ("hipblasDsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChbmv", ("hipblasChbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + 
("cublasZhbmv", ("hipblasZhbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspmv", ("hipblasSspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspmv", ("hipblasDspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpmv", ("hipblasChpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpmv", ("hipblasZhpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSger", ("hipblasSger", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger", ("hipblasDger", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeru", ("hipblasCgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgerc", ("hipblasCgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeru", ("hipblasZgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgerc", ("hipblasZgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr", ("hipblasSsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasDsyr", ("hipblasDsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasCher", ("hipblasCher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher", ("hipblasZher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr", ("hipblasSspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr", ("hipblasDspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr", ("hipblasChpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr", ("hipblasZhpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2", ("hipblasSsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2", ("hipblasDsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2", ("hipblasCher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2", ("hipblasZher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr2", ("hipblasSspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr2", ("hipblasDspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr2", ("hipblasChpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr2", ("hipblasZhpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgemmBatched", + ("hipblasSgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgemmBatched", + ("hipblasDgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasHgemmBatched", + ("hipblasHgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmStridedBatched", + ("hipblasSgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasDgemmStridedBatched", + ("hipblasDgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasHgemmStridedBatched", + ("hipblasHgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasCgemmBatched", + ("hipblasCgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mBatched", + ("hipblasCgemm3mBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemmBatched", + ("hipblasZgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmStridedBatched", + ( + "hipblasCgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasCgemm3mStridedBatched", + ( + "hipblasCgemm3mStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasZgemmStridedBatched", + ( + "hipblasZgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasHgemmStridedBatched", + ( + "hipblasHgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cublasSgemm", ("hipblasSgemm", 
CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm", ("hipblasDgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemm", ("hipblasCgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasZgemm", ("hipblasZgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasHgemm", ("hipblasHgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasSsyrk", ("hipblasSsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrk", ("hipblasDsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrk", ("hipblasCsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrk", ("hipblasZsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherk", ("hipblasCherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherk", ("hipblasZherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2k", ("hipblasSsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2k", ("hipblasDsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr2k", ("hipblasCsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr2k", ("hipblasZyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyrkx", ("hipblasSsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrkx", ("hipblasDsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrkx", ("hipblasCsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrkx", ("hipblasZsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2k", ("hipblasCher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2k", ("hipblasZher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherkx", ("hipblasCherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherkx", ("hipblasZherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymm", ("hipblasSsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymm", ("hipblasDsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymm", ("hipblasCsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymm", ("hipblasZsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemm", ("hipblasChemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemm", ("hipblasZhemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsm", ("hipblasStrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDtrsm", ("hipblasDtrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCtrsm", ("hipblasCtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsm", ("hipblasZtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasStrmm", ("hipblasStrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmm", ("hipblasDtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmm", ("hipblasCtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmm", ("hipblasZtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgeam", ("hipblasSgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgeam", ("hipblasDgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeam", ("hipblasCgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeam", ("hipblasZgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + 
"cublasSgetrfBatched", + ("hipblasSgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrfBatched", + ("hipblasDgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrfBatched", + ("hipblasCgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrfBatched", + ("hipblasZgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetriBatched", + ("hipblasSgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetriBatched", + ("hipblasDgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetriBatched", + ("hipblasCgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetriBatched", + ("hipblasZgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetrsBatched", + ("hipblasSgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrsBatched", + ("hipblasDgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrsBatched", + ("hipblasCgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrsBatched", + ("hipblasZgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSmatinvBatched", + ("hipblasSmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDmatinvBatched", + ("hipblasDmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCmatinvBatched", + ("hipblasCmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZmatinvBatched", + ("hipblasZmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgeqrfBatched", + ("hipblasSgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgeqrfBatched", + ("hipblasDgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgeqrfBatched", + ("hipblasCgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeqrfBatched", + ("hipblasZgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgelsBatched", + ("hipblasSgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgelsBatched", + ("hipblasDgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgelsBatched", + ("hipblasCgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgelsBatched", + ("hipblasZgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdgmm", ("hipblasSdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDdgmm", ("hipblasDdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCdgmm", ("hipblasCdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdgmm", ("hipblasZdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpttr", ("hipblasStpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpttr", ("hipblasDtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpttr", ("hipblasCtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpttr", ("hipblasZtpttr", CONV_MATH_FUNC, API_BLAS, 
HIP_UNSUPPORTED)), + ("cublasStrttp", ("hipblasStrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrttp", ("hipblasDtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrttp", ("hipblasCtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrttp", ("hipblasZtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCreate_v2", ("hipblasCreate_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy_v2", ("hipblasDestroy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetVersion_v2", + ("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream_v2", ("hipblasGetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetPointerMode", + ("hipblasGetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode", + ("hipblasSetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasGetPointerMode_v2", + ("hipblasGetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode_v2", + ("hipblasSetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ("cublasSgemv_v2", ("hipblasSgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemv_v2", ("hipblasDgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemv_v2", + ("hipblasCgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemv_v2", + ("hipblasZgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgbmv_v2", + ("hipblasSgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgbmv_v2", + ("hipblasDgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgbmv_v2", + ("hipblasCgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgbmv_v2", + ("hipblasZgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmv_v2", + ("hipblasStrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmv_v2", + ("hipblasDtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmv_v2", + ("hipblasCtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmv_v2", + ("hipblasZtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbmv_v2", + ("hipblasStbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbmv_v2", + ("hipblasDtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbmv_v2", + ("hipblasCtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbmv_v2", + ("hipblasZtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpmv_v2", + ("hipblasStpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpmv_v2", + ("hipblasDtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpmv_v2", + ("hipblasCtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpmv_v2", + ("hipblasZtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsv_v2", + ("hipblasStrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsv_v2", + ("hipblasDtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsv_v2", + ("hipblasCtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsv_v2", + ("hipblasZtrsv_v2", CONV_MATH_FUNC, API_BLAS, 
HIP_UNSUPPORTED), + ), + ( + "cublasStpsv_v2", + ("hipblasStpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpsv_v2", + ("hipblasDtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpsv_v2", + ("hipblasCtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpsv_v2", + ("hipblasZtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbsv_v2", + ("hipblasStbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbsv_v2", + ("hipblasDtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbsv_v2", + ("hipblasCtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbsv_v2", + ("hipblasZtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymv_v2", + ("hipblasSsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymv_v2", + ("hipblasDsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymv_v2", + ("hipblasCsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymv_v2", + ("hipblasZsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemv_v2", + ("hipblasChemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemv_v2", + ("hipblasZhemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsbmv_v2", + ("hipblasSsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsbmv_v2", + ("hipblasDsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChbmv_v2", + ("hipblasChbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhbmv_v2", + ("hipblasZhbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspmv_v2", + ("hipblasSspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspmv_v2", + ("hipblasDspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpmv_v2", + ("hipblasChpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpmv_v2", + ("hipblasZhpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSger_v2", ("hipblasSger_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger_v2", ("hipblasDger_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgeru_v2", + ("hipblasCgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgerc_v2", + ("hipblasCergc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeru_v2", + ("hipblasZgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgerc_v2", + ("hipblasZgerc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSsyr_v2", ("hipblasSsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr_v2", ("hipblasDsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr_v2", ("hipblasCsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr_v2", ("hipblasZsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher_v2", ("hipblasCher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher_v2", ("hipblasZher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr_v2", ("hipblasSspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr_v2", ("hipblasDspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr_v2", ("hipblasChpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr_v2", ("hipblasZhpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSsyr2_v2", + ("hipblasSsyr2_v2", CONV_MATH_FUNC, API_BLAS, 
HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2_v2", + ("hipblasDsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2_v2", + ("hipblasCsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2_v2", + ("hipblasZsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2_v2", + ("hipblasCher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2_v2", + ("hipblasZher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspr2_v2", + ("hipblasSspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspr2_v2", + ("hipblasDspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpr2_v2", + ("hipblasChpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpr2_v2", + ("hipblasZhpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSgemm_v2", ("hipblasSgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm_v2", ("hipblasDgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemm_v2", + ("hipblasCgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3m", + ("hipblasCgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mEx", + ("hipblasCgemm3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm_v2", + ("hipblasZgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm3m", + ("hipblasZgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmEx", + ("hipblasSgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasGemmEx", ("hipblasGemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasGemmBatchedEx", + ("hipblasGemmBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGemmStridedBatchedEx", + ("hipblasGemmStridedBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmEx", + ("hipblasCgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasUint8gemmBias", + ("hipblasUint8gemmBias", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyrk_v2", + ("hipblasSsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyrk_v2", + ("hipblasDsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk_v2", + ("hipblasCsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyrk_v2", + ("hipblasZsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrkEx", + ("hipblasCsyrkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk3mEx", + ("hipblasCsyrk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk_v2", + ("hipblasCherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherkEx", + ("hipblasCherkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk3mEx", + ("hipblasCherk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZherk_v2", + ("hipblasZherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyr2k_v2", + ("hipblasSsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2k_v2", + ("hipblasDsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2k_v2", + ("hipblasCsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2k_v2", + ("hipblasZsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2k_v2", + ("hipblasCher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2k_v2", + 
("hipblasZher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymm_v2", + ("hipblasSsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymm_v2", + ("hipblasDsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymm_v2", + ("hipblasCsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymm_v2", + ("hipblasZsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemm_v2", + ("hipblasChemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemm_v2", + ("hipblasZhemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsm_v2", + ("hipblasStrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsm_v2", + ("hipblasDtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsm_v2", + ("hipblasCtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsm_v2", + ("hipblasZtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmm_v2", + ("hipblasStrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmm_v2", + ("hipblasDtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmm_v2", + ("hipblasCtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmm_v2", + ("hipblasZtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSnrm2_v2", ("hipblasSnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2_v2", ("hipblasDnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScnrm2_v2", + ("hipblasScnrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDznrm2_v2", + ("hipblasDznrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDotEx", ("hipblasDotEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDotcEx", ("hipblasDotcEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSdot_v2", ("hipblasSdot_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDdot_v2", ("hipblasDdot_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCdotu_v2", + ("hipblasCdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCdotc_v2", + ("hipblasCdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotu_v2", + ("hipblasZdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotc_v2", + ("hipblasZdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScalEx", ("hipblasScalEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSscal_v2", ("hipblasSscal_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDscal_v2", ("hipblasDscal_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCscal_v2", + ("hipblasCscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsscal_v2", + ("hipblasCsscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZscal_v2", + ("hipblasZcsal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdscal_v2", + ("hipblasZdscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAxpyEx", ("hipblasAxpyEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy_v2", ("hipblasSaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDaxpy_v2", ("hipblasDaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCaxpy_v2", + ("hipblasCaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZaxpy_v2", + ("hipblasZaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScopy_v2", ("hipblasScopy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDcopy_v2", ("hipblasDcopy_v2", CONV_MATH_FUNC, 
API_BLAS)), + ( + "cublasCcopy_v2", + ("hipblasCcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZcopy_v2", + ("hipblasZcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSswap_v2", ("hipblasSswap_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap_v2", ("hipblasDswap_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCswap_v2", + ("hipblasCswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZswap_v2", + ("hipblasZswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamax_v2", ("hipblasIsamax_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax_v2", ("hipblasIdamax_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamax_v2", + ("hipblasIcamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamax_v2", + ("hipblasIzamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamin_v2", ("hipblasIsamin_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin_v2", ("hipblasIdamin_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamin_v2", + ("hipblasIcamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamin_v2", + ("hipblasIzamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSasum_v2", ("hipblasSasum_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDasum_v2", ("hipblasDasum_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScasum_v2", + ("hipblasScasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDzasum_v2", + ("hipblasDzasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSrot_v2", ("hipblasSrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot_v2", ("hipblasDrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot_v2", ("hipblasCrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasCsrot_v2", + ("hipblasCsrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasZrot_v2", ("hipblasZrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasZdrot_v2", + ("hipblasZdrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotg_v2", + ("hipblasSrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotg_v2", + ("hipblasDrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCrotg_v2", + ("hipblasCrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZrotg_v2", + ("hipblasZrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotm_v2", + ("hipblasSrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotm_v2", + ("hipblasDrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotmg_v2", + ("hipblasSrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotmg_v2", + ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasComputeType_t", + ("hipblasComputeType_t", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32I", + ("HIPBLAS_COMPUTE_32I", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F", + ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F_FAST_TF32", + ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_64F", + ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS) + ), + ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_BIAS", 
("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutSetAttribute", ("hipblasLtMatrixLayoutSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT", ("HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", ("HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtCreate", ("hipblasLtCreate", 
CONV_MATH_FUNC, API_BLAS)), + ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)), + ( + "CURAND_STATUS_SUCCESS", + ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_VERSION_MISMATCH", + ("HIPRAND_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_NOT_INITIALIZED", + ("HIPRAND_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ALLOCATION_FAILED", + ("HIPRAND_STATUS_ALLOCATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_TYPE_ERROR", + ("HIPRAND_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_OUT_OF_RANGE", + ("HIPRAND_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_LENGTH_NOT_MULTIPLE", + ("HIPRAND_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED", + ( + "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED", + CONV_NUMERIC_LITERAL, + API_RAND, + ), + ), + ( + "CURAND_STATUS_LAUNCH_FAILURE", + ("HIPRAND_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_PREEXISTING_FAILURE", + ("HIPRAND_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INITIALIZATION_FAILED", + ("HIPRAND_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ARCH_MISMATCH", + ("HIPRAND_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INTERNAL_ERROR", + ("HIPRAND_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ("CURAND_RNG_TEST", ("HIPRAND_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND)), + ( + "mtgp32dc_params_fast_11213", + ("mtgp32dc_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_DEFAULT", + ("HIPRAND_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_XORWOW", + ("HIPRAND_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MRG32K3A", + ("HIPRAND_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MTGP32", + ("HIPRAND_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MT19937", + ("HIPRAND_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_PHILOX4_32_10", + ("HIPRAND_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_DEFAULT", + ("HIPRAND_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL32", + ("HIPRAND_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + 
"CURAND_RNG_QUASI_SOBOL64", + ("HIPRAND_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "curand_ORDERING_PSEUDO_BEST", + ( + "HIPRAND_ORDERING_PSEUDO_BEST", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_DEFAULT", + ( + "HIPRAND_ORDERING_PSEUDO_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_SEEDED", + ( + "HIPRAND_ORDERING_PSEUDO_SEEDED", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_QUASI_DEFAULT", + ( + "HIPRAND_ORDERING_QUASI_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_CHOOSE_BEST", + ("HIPRAND_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_ITR", + ("HIPRAND_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_KNUTH", + ("HIPRAND_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_HITR", + ("HIPRAND_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_M1", ("HIPRAND_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ("curand_M2", ("HIPRAND_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ( + "curand_BINARY_SEARCH", + ("HIPRAND_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DISCRETE_GAUSS", + ("HIPRAND_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_REJECTION", + ("HIPRAND_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEVICE_API", + ("HIPRAND_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_FAST_REJECTION", + ("HIPRAND_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_3RD", + ("HIPRAND_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEFINITION", + ("HIPRAND_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_POISSON", + ("HIPRAND_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curandCreateGenerator", ("hiprandCreateGenerator", CONV_MATH_FUNC, API_RAND)), + ( + "curandCreateGeneratorHost", + ("hiprandCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandCreatePoissonDistribution", + ("hiprandCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyDistribution", + ("hiprandDestroyDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyGenerator", + ("hiprandDestroyGenerator", CONV_MATH_FUNC, API_RAND), + ), + ("curandGenerate", ("hiprandGenerate", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateLogNormal", + ("hiprandGenerateLogNormal", CONV_MATH_FUNC, API_RAND), + ), + ( + 
"curandGenerateLogNormalDouble", + ("hiprandGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLongLong", + ("hiprandGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curandGenerateNormal", ("hiprandGenerateNormal", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateNormalDouble", + ("hiprandGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ("curandGeneratePoisson", ("hiprandGeneratePoisson", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateSeeds", ("hiprandGenerateSeeds", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateUniform", ("hiprandGenerateUniform", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateUniformDouble", + ("hiprandGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGetDirectionVectors32", + ("hiprandGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetDirectionVectors64", + ("hiprandGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetProperty", + ("hiprandGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetScrambleConstants32", + ( + "hiprandGetScrambleConstants32", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curandGetScrambleConstants64", + ( + "hiprandGetScrambleConstants64", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ("curandGetVersion", ("hiprandGetVersion", CONV_MATH_FUNC, API_RAND)), + ( + "curandSetGeneratorOffset", + ("hiprandSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetGeneratorOrdering", + ("hiprandSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandSetPseudoRandomGeneratorSeed", + ("hiprandSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetQuasiRandomGeneratorDimensions", + ("hiprandSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), + ), + ("curandSetStream", ("hiprandSetStream", CONV_MATH_FUNC, API_RAND)), + ("curand", ("hiprand", CONV_DEVICE_FUNC, API_RAND)), + ("curand4", ("hiprand4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_init", ("hiprand_init", CONV_DEVICE_FUNC, API_RAND)), + ("curand_log_normal", ("hiprand_log_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal_double", + ("hiprand_log_normal_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal2", ("hiprand_log_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal2_double", + ("hiprand_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal4", ("hiprand_log_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal4_double", + ("hiprand_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_mtgp32_single", + ("hiprand_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_mtgp32_single_specific", + ( + "hiprand_mtgp32_single_specific", + CONV_DEVICE_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_mtgp32_specific", + ("hiprand_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_normal", ("hiprand_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curandMakeMTGP32Constants", + ("hiprandMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curandMakeMTGP32KernelState", + ("hiprandMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal_double", ("hiprand_normal_double", CONV_DEVICE_FUNC, API_RAND)), + ("curand_normal2", ("hiprand_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal2_double", + ("hiprand_normal2_double", CONV_DEVICE_FUNC, 
API_RAND), + ), + ("curand_normal4", ("hiprand_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal4_double", + ("hiprand_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform", ("hiprand_uniform", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform_double", + ("hiprand_uniform_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_uniform2_double", + ("hiprand_uniform2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform4", ("hiprand_uniform4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform4_double", + ("hiprand_uniform4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_discrete", ("hiprand_discrete", CONV_DEVICE_FUNC, API_RAND)), + ("curand_discrete4", ("hiprand_discrete4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson", ("hiprand_poisson", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson4", ("hiprand_poisson4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_Philox4x32_10", + ("hiprand_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("mtgp32_kernel_params", ("mtgp32_kernel_params_t", CONV_MATH_FUNC, API_RAND)), + ("CUFFT_FORWARD", ("HIPFFT_FORWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUFFT_INVERSE", ("HIPFFT_BACKWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUFFT_COMPATIBILITY_DEFAULT", + ( + "HIPFFT_COMPATIBILITY_DEFAULT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cuComplex", ("hipComplex", CONV_TYPE, API_BLAS)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_BLAS)), + ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)), + ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)), + ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_PLAN", ("HIPFFT_INVALID_PLAN", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_ALLOC_FAILED", ("HIPFFT_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_TYPE", ("HIPFFT_INVALID_TYPE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_INVALID_VALUE", + ("HIPFFT_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INTERNAL_ERROR", + ("HIPFFT_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_EXEC_FAILED", ("HIPFFT_EXEC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_SETUP_FAILED", ("HIPFFT_SETUP_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_SIZE", ("HIPFFT_INVALID_SIZE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_UNALIGNED_DATA", + ("HIPFFT_UNALIGNED_DATA", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INCOMPLETE_PARAMETER_LIST", + ("HIPFFT_INCOMPLETE_PARAMETER_LIST", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INVALID_DEVICE", + ("HIPFFT_INVALID_DEVICE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_PARSE_ERROR", ("HIPFFT_PARSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_NO_WORKSPACE", ("HIPFFT_NO_WORKSPACE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_NOT_IMPLEMENTED", + ("HIPFFT_NOT_IMPLEMENTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_LICENSE_ERROR", + ("HIPFFT_LICENSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_NOT_SUPPORTED", + ("HIPFFT_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("cufftType_t", ("hipfftType_t", CONV_TYPE, API_FFT)), + ("cufftType", ("hipfftType", CONV_TYPE, API_FFT)), + ("CUFFT_R2C", ("HIPFFT_R2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2R", ("HIPFFT_C2R", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2C", ("HIPFFT_C2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_D2Z", ("HIPFFT_D2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2D", ("HIPFFT_Z2D", 
CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2Z", ("HIPFFT_Z2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "cufftCompatibility_t", + ("hipfftCompatibility_t", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "cufftCompatibility", + ("hipfftCompatibility", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_COMPATIBILITY_FFTW_PADDING", + ( + "HIPFFT_COMPATIBILITY_FFTW_PADDING", + CONV_NUMERIC_LITERAL, + API_FFT, + HIP_UNSUPPORTED, + ), + ), + ("cufftReal", ("hipfftReal", CONV_TYPE, API_FFT)), + ("cufftDoubleReal", ("hipfftDoubleReal", CONV_TYPE, API_FFT)), + ("cufftComplex", ("hipfftComplex", CONV_TYPE, API_FFT)), + ("cufftDoubleComplex", ("hipfftDoubleComplex", CONV_TYPE, API_FFT)), + ("cufftHandle", ("hipfftHandle", CONV_TYPE, API_FFT)), + ("cufftPlan1d", ("hipfftPlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan2d", ("hipfftPlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan3d", ("hipfftPlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlanMany", ("hipfftPlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan1d", ("hipfftMakePlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan2d", ("hipfftMakePlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan3d", ("hipfftMakePlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany", ("hipfftMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany64", ("hipfftMakePlanMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany64", ("hipfftGetSizeMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate1d", ("hipfftEstimate1d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate2d", ("hipfftEstimate2d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate3d", ("hipfftEstimate3d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimateMany", ("hipfftEstimateMany", CONV_MATH_FUNC, API_FFT)), + ("cufftCreate", ("hipfftCreate", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize1d", ("hipfftGetSize1d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize2d", ("hipfftGetSize2d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize3d", ("hipfftGetSize3d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany", ("hipfftGetSizeMany", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize", ("hipfftGetSize", CONV_MATH_FUNC, API_FFT)), + ("cufftSetWorkArea", ("hipfftSetWorkArea", CONV_MATH_FUNC, API_FFT)), + ( + "cufftSetAutoAllocation", + ("hipfftSetAutoAllocation", CONV_MATH_FUNC, API_FFT), + ), + ("cufftXtExec", ("hipfftXtExec", CONV_MATH_FUNC, API_FFT)), + ("cufftXtMakePlanMany", ("hipfftXtMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2C", ("hipfftExecC2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecR2C", ("hipfftExecR2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2R", ("hipfftExecC2R", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2Z", ("hipfftExecZ2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecD2Z", ("hipfftExecD2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2D", ("hipfftExecZ2D", CONV_MATH_FUNC, API_FFT)), + ("cufftSetStream", ("hipfftSetStream", CONV_MATH_FUNC, API_FFT)), + ("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)), + ("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)), + ( + "cufftGetProperty", + ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED), + ), + ("nvrtcResult", ("hiprtcResult", CONV_TYPE, API_RTC)), + ("NVRTC_SUCCESS", ("HIPRTC_SUCCESS", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_OUT_OF_MEMORY", + ("HIPRTC_ERROR_OUT_OF_MEMORY", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_PROGRAM_CREATION_FAILURE", + ("HIPRTC_ERROR_PROGRAM_CREATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_INPUT", + ("HIPRTC_ERROR_INVALID_INPUT", CONV_TYPE, API_RTC), + ), + ( + 
"NVRTC_ERROR_INVALID_PROGRAM", + ("HIPRTC_ERROR_INVALID_PROGRAM", CONV_TYPE, API_RTC), + ), + ("NVRTC_ERROR_COMPILATION", ("HIPRTC_ERROR_COMPILATION", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE", + ("HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", + ("HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID", + ("HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INTERNAL_ERROR", + ("HIPRTC_ERROR_INTERNAL_ERROR", CONV_TYPE, API_RTC), + ), + ("nvrtcGetErrorString", ("hiprtcGetErrorString", CONV_JIT, API_RTC)), + ("nvrtcVersion", ("hiprtcVersion", CONV_JIT, API_RTC)), + ("nvrtcProgram", ("hiprtcProgram", CONV_TYPE, API_RTC)), + ("nvrtcAddNameExpression", ("hiprtcAddNameExpression", CONV_JIT, API_RTC)), + ("nvrtcCompileProgram", ("hiprtcCompileProgram", CONV_JIT, API_RTC)), + ("nvrtcCreateProgram", ("hiprtcCreateProgram", CONV_JIT, API_RTC)), + ("nvrtcDestroyProgram", ("hiprtcDestroyProgram", CONV_JIT, API_RTC)), + ("nvrtcGetLoweredName", ("hiprtcGetLoweredName", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLog", ("hiprtcGetProgramLog", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLogSize", ("hiprtcGetProgramLogSize", CONV_JIT, API_RTC)), + ("nvrtcGetPTX", ("hiprtcGetCode", CONV_JIT, API_RTC)), + ("nvrtcGetPTXSize", ("hiprtcGetCodeSize", CONV_JIT, API_RTC)), + ("thrust::cuda", ("thrust::hip", CONV_MATH_FUNC, API_BLAS)), + ( + "cudaCpuDeviceId", + ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + # The caffe2 directory does a string match; pytorch does a word-boundary match. + # Patterns such as 'cub::' will not match for pytorch. + # We list all current uses of cub symbols for this reason. 
+ ("cub::", ("hipcub::", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMax", ("hipcub::ArgMax", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMin", ("hipcub::ArgMin", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_SCAN_WARP_SCANS", ("hipcub::BLOCK_SCAN_WARP_SCANS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_REDUCE_WARP_REDUCTIONS", ("hipcub::BLOCK_REDUCE_WARP_REDUCTIONS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_WARP_TRANSPOSE", ("hipcub::BLOCK_STORE_WARP_TRANSPOSE", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_LOAD_DIRECT", ("hipcub::BLOCK_LOAD_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_DIRECT", ("hipcub::BLOCK_STORE_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ( + "cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", + ("hipcub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", CONV_SPECIAL_FUNC, API_RUNTIME) + ), + ("cub::BlockReduce", ("hipcub::BlockReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockScan", ("hipcub::BlockScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockLoad", ("hipcub::BlockLoad", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockStore", ("hipcub::BlockStore", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRakingLayout", ("hipcub::BlockRakingLayout", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRadixSort", ("hipcub::BlockRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Uninitialized", ("hipcub::Uninitialized", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::RowMajorTid", ("hipcub::RowMajorTid", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CachingDeviceAllocator", ("hipcub::CachingDeviceAllocator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CountingInputIterator", ("hipcub::CountingInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRadixSort", ("hipcub::DeviceRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceReduce", ("hipcub::DeviceReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRunLengthEncode", ("hipcub::DeviceRunLengthEncode", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceScan", ("hipcub::DeviceScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedRadixSort", ("hipcub::DeviceSegmentedRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedReduce", ("hipcub::DeviceSegmentedReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSelect", ("hipcub::DeviceSelect", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::FpLimits", ("hipcub::FpLimits", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::KeyValuePair", ("hipcub::KeyValuePair", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Min", ("hipcub::Min", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Sum", ("hipcub::Sum", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Log2", ("hipcub::Log2", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::LaneId", ("hipcub::LaneId", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpMask", ("hipcub::WarpMask", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleIndex", ("hipcub::ShuffleIndex", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleDown", ("hipcub::ShuffleDown", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgIndexInputIterator", ("hipcub::ArgIndexInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::TransformInputIterator", ("hipcub::TransformInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpReduce", ("hipcub::WarpReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CTA_SYNC", ("hipcub::CTA_SYNC", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("nvtxMark", ("roctxMark", CONV_OTHER, API_ROCTX)), + ("nvtxMarkA", ("roctxMarkA", CONV_OTHER, 
API_ROCTX)), + ("nvtxRangePushA", ("roctxRangePushA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)), + ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)), + ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)), + ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_STATUS_OK", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_ERROR_INSUFFICIENT_SIZE", ("RSMI_STATUS_INSUFFICIENT_SIZE", CONV_OTHER, API_ROCMSMI)), + ("nvmlDevice_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PStatus_t", ("bool", CONV_OTHER, API_ROCMSMI)), + ("nvmlProcessInfo_t", ("rsmi_process_info_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PCapsIndex_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ] +) + +CUDA_SPECIAL_MAP = collections.OrderedDict( + [ + # SPARSE + ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cuComplex", ("hipComplex", CONV_TYPE, API_SPECIAL)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_SPECIAL)), + ( + "CUSPARSE_POINTER_MODE_HOST", + ("HIPSPARSE_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPECIAL)), + ( + "cusparseCreateMatDescr", + ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseDestroyMatDescr", + ("hipsparseDestroyMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDiagType_t", ("hipsparseDiagType_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_UNIT", ("HIPSPARSE_DIAG_TYPE_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_NON_UNIT", ("HIPSPARSE_DIAG_TYPE_NON_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatDiagType", ("hipsparseSetMatDiagType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseFillMode_t", ("hipsparseFillMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_UPPER", ("HIPSPARSE_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_LOWER", ("HIPSPARSE_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatFillMode", ("hipsparseSetMatFillMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDirection_t", ("hipsparseDirection_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIRECTION_ROW", ("HIPSPARSE_DIRECTION_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIRECTION_COLUMN", ("HIPSPARSE_DIRECTION_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSolvePolicy_t", ("hipsparseSolvePolicy_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_NO_LEVEL", ("HIPSPARSE_SOLVE_POLICY_NO_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_USE_LEVEL", ("HIPSPARSE_SOLVE_POLICY_USE_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseCreateBsrsv2Info", ("hipsparseCreateBsrsv2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateBsrsm2Info", ("hipsparseCreateBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsv2Info", ("hipsparseDestroyBsrsv2Info", 
CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsm2Info", ("hipsparseDestroyBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmm", ("hipsparseSbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmm", ("hipsparseDbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmm", ("hipsparseCbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmm", ("hipsparseZbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmv", ("hipsparseSbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmv", ("hipsparseDbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmv", ("hipsparseCbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmv", ("hipsparseZbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_bufferSize", ("hipsparseSbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_bufferSize", ("hipsparseDbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_bufferSize", ("hipsparseCbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_bufferSize", ("hipsparseZbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_analysis", ("hipsparseSbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_analysis", ("hipsparseDbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_analysis", ("hipsparseCbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_analysis", ("hipsparseZbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_solve", ("hipsparseSbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_solve", ("hipsparseDbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_solve", ("hipsparseCbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_solve", ("hipsparseZbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_bufferSize", ("hipsparseSbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_bufferSize", ("hipsparseDbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_bufferSize", ("hipsparseCbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_bufferSize", ("hipsparseZbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_analysis", ("hipsparseSbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_analysis", ("hipsparseDbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_analysis", ("hipsparseCbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_analysis", ("hipsparseZbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_solve", ("hipsparseSbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_solve", ("hipsparseDbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_solve", ("hipsparseCbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_solve", ("hipsparseZbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrmm2", ("hipsparseCcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrmm2", ("hipsparseZcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm", ("hipsparseScsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm", ("hipsparseDcsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcsrsort_bufferSizeExt", + ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreateCsrgemm2Info", ("hipsparseCreateCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ( + 
"cusparseDestroyCsrgemm2Info", + ("hipsparseDestroyCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseXcsrgemm2Nnz", ("hipsparseXcsrgemm2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2_bufferSizeExt", ("hipsparseDcsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2_bufferSizeExt", ("hipsparseScsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2", ("hipsparseDcsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2", ("hipsparseScsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSetPointerMode", ("hipsparseSetPointerMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrgeam2Nnz", ("hipsparseXcsrgeam2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2_bufferSizeExt", ("hipsparseScsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2_bufferSizeExt", ("hipsparseDcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2_bufferSizeExt", ("hipsparseCcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2_bufferSizeExt", ("hipsparseZcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2", ("hipsparseScsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2", ("hipsparseDcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2", ("hipsparseCcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2", ("hipsparseZcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsm2_zeroPivot", ("hipsparseXbsrsm2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsv2_zeroPivot", ("hipsparseXbsrsv2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcoosort_bufferSizeExt", + ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseXcoosortByRow", + ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseCreateIdentityPermutation", + ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseSetMatIndexBase", + ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV", ("hipsparseSpMV", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV_bufferSize", ("hipsparseSpMV_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM", ("hipsparseSpMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM_bufferSize", ("hipsparseSpMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnMat", ("hipsparseCreateDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetStridedBatch", ("hipsparseCsrSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnVec", ("hipsparseCreateDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnMat", ("hipsparseDestroyDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnVec", ("hipsparseDestroyDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroySpMat", ("hipsparseDestroySpMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_destroyDescr", ("hipsparseSpGEMM_destroyDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCoo", ("hipsparseCreateCoo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, 
API_SPECIAL)), + ("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_bufferSize", ("hipsparseSDDMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_preprocess", ("hipsparseSDDMM_preprocess", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM", ("hipsparseSDDMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetPointers", ("hipsparseCsrSetPointers", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMVAlg_t", ("hipsparseSpMVAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMMAlg_t", ("hipsparseSpMMAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseIndexType_t", ("hipsparseIndexType_t", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseMatDescr", ("hipsparseMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnMatDescr", ("hipsparseDnMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnVecDescr", ("hipsparseDnVecDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpMatDescr", ("hipsparseSpMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpGEMMDescr", ("hipsparseSpGEMMDescr", CONV_TYPE, API_SPECIAL)), + ("cusparseDnMatDescr_t", ("hipsparseDnMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseDnVecDescr_t", ("hipsparseDnVecDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMatDescr_t", ("hipsparseSpMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SDDMM_ALG_DEFAULT", ("HIPSPARSE_SDDMM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUSPARSE_STATUS_SUCCESS", + ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_NOT_INITIALIZED", + ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ALLOC_FAILED", + 
("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INVALID_VALUE", + ("HIPSPARSE_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MAPPING_ERROR", + ("HIPSPARSE_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_EXECUTION_FAILED", + ("HIPSPARSE_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INTERNAL_ERROR", + ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + ( + "HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_STATUS_ARCH_MISMATCH", + ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ZERO_PIVOT", + ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_TRANSPOSE", + ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_NON_TRANSPOSE", + ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + ( + "HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_INDEX_BASE_ZERO", + ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_INDEX_BASE_ONE", + ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_MATRIX_TYPE_GENERAL", + ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + # SparseLt + ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)), + ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)), + 
("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseGetErrorString", ("hipsparseGetErrorString", CONV_MATH_FUNC, API_SPECIAL)), + # SOLVER + ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUBLAS_OP_T", + ("HIPSOLVER_OP_T", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_OP_C", + ("HIPSOLVER_OP_C", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasFillMode_t", ("hipsolverFillMode_t", CONV_TYPE, API_SPECIAL)), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPSOLVER_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPSOLVER_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasSideMode_t", ("hipsolverSideMode_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_SIDE_LEFT", ("HIPSOLVER_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUBLAS_SIDE_RIGHT", ("HIPSOLVER_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("cusolverEigMode_t", ("hipsolverEigMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_VECTOR", ("HIPSOLVER_EIG_MODE_VECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_NOVECTOR", ("HIPSOLVER_EIG_MODE_NOVECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("syevjInfo_t", ("hipsolverSyevjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateSyevjInfo", ("hipsolverDnCreateSyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXsyevjSetSortEig", ("hipsolverDnXsyevjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroySyevjInfo", ("hipsolverDnDestroySyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("gesvdjInfo_t", ("hipsolverGesvdjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateGesvdjInfo", ("hipsolverDnCreateGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXgesvdjSetSortEig", ("hipsolverDnXgesvdjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroyGesvdjInfo", ("hipsolverDnDestroyGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("cusolverDnHandle_t", 
("hipsolverDnHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreate", ("hipsolverDnCreate", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnSetStream", ("hipsolverDnSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroy", ("hipsolverDnDestroy", CONV_MATH_FUNC, API_SPECIAL)), + + # from aten/src/ATen/native/hip/linalg/HIPSolver.cpp + ('cusolverDnParams_t', ('hipsolverDnParams_t', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf', ('hipsolverDnCgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf_bufferSize', ('hipsolverDnCgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd', ('hipsolverDnCgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd_bufferSize', ('hipsolverDnCgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj', ('hipsolverDnCgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched', ('hipsolverDnCgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched_bufferSize', ('hipsolverDnCgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj_bufferSize', ('hipsolverDnCgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf', ('hipsolverDnCgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf_bufferSize', ('hipsolverDnCgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrs', ('hipsolverDnCgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd', ('hipsolverDnCheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd_bufferSize', ('hipsolverDnCheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj', ('hipsolverDnCheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched', ('hipsolverDnCheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched_bufferSize', ('hipsolverDnCheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj_bufferSize', ('hipsolverDnCheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf', ('hipsolverDnCpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrfBatched', ('hipsolverDnCpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf_bufferSize', ('hipsolverDnCpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrs', ('hipsolverDnCpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrsBatched', ('hipsolverDnCpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr', ('hipsolverDnCungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr_bufferSize', ('hipsolverDnCungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr', ('hipsolverDnCunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr_bufferSize', ('hipsolverDnCunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf', ('hipsolverDnDgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf_bufferSize', ('hipsolverDnDgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd', ('hipsolverDnDgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd_bufferSize', ('hipsolverDnDgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj', ('hipsolverDnDgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched', ('hipsolverDnDgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched_bufferSize', ('hipsolverDnDgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj_bufferSize', ('hipsolverDnDgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrf', ('hipsolverDnDgetrf', CONV_MATH_FUNC, API_SPECIAL)), + 
('cusolverDnDgetrf_bufferSize', ('hipsolverDnDgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrs', ('hipsolverDnDgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr', ('hipsolverDnDorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr_bufferSize', ('hipsolverDnDorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr', ('hipsolverDnDormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr_bufferSize', ('hipsolverDnDormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf', ('hipsolverDnDpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrfBatched', ('hipsolverDnDpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf_bufferSize', ('hipsolverDnDpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrs', ('hipsolverDnDpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrsBatched', ('hipsolverDnDpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd', ('hipsolverDnDsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd_bufferSize', ('hipsolverDnDsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj', ('hipsolverDnDsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched', ('hipsolverDnDsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched_bufferSize', ('hipsolverDnDsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj_bufferSize', ('hipsolverDnDsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf', ('hipsolverDnSgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf_bufferSize', ('hipsolverDnSgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd', ('hipsolverDnSgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd_bufferSize', ('hipsolverDnSgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj', ('hipsolverDnSgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched', ('hipsolverDnSgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched_bufferSize', ('hipsolverDnSgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj_bufferSize', ('hipsolverDnSgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf', ('hipsolverDnSgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf_bufferSize', ('hipsolverDnSgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrs', ('hipsolverDnSgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr', ('hipsolverDnSorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr_bufferSize', ('hipsolverDnSorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr', ('hipsolverDnSormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr_bufferSize', ('hipsolverDnSormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf', ('hipsolverDnSpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrfBatched', ('hipsolverDnSpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf_bufferSize', ('hipsolverDnSpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrs', ('hipsolverDnSpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrsBatched', ('hipsolverDnSpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd', ('hipsolverDnSsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd_bufferSize', ('hipsolverDnSsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj', ('hipsolverDnSsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched', 
('hipsolverDnSsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched_bufferSize', ('hipsolverDnSsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj_bufferSize', ('hipsolverDnSsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf', ('hipsolverDnXgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf_bufferSize', ('hipsolverDnXgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf', ('hipsolverDnXpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf_bufferSize', ('hipsolverDnXpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrs', ('hipsolverDnXpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd', ('hipsolverDnXsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd_bufferSize', ('hipsolverDnXsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf', ('hipsolverDnZgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf_bufferSize', ('hipsolverDnZgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd', ('hipsolverDnZgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd_bufferSize', ('hipsolverDnZgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj', ('hipsolverDnZgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched', ('hipsolverDnZgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched_bufferSize', ('hipsolverDnZgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj_bufferSize', ('hipsolverDnZgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf', ('hipsolverDnZgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf_bufferSize', ('hipsolverDnZgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrs', ('hipsolverDnZgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd', ('hipsolverDnZheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd_bufferSize', ('hipsolverDnZheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj', ('hipsolverDnZheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched', ('hipsolverDnZheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched_bufferSize', ('hipsolverDnZheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj_bufferSize', ('hipsolverDnZheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf', ('hipsolverDnZpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrfBatched', ('hipsolverDnZpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf_bufferSize', ('hipsolverDnZpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrs', ('hipsolverDnZpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrsBatched', ('hipsolverDnZpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr', ('hipsolverDnZungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr_bufferSize', ('hipsolverDnZungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr', ('hipsolverDnZunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr_bufferSize', ('hipsolverDnZunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + + # sytrf + ('cusolverDnDsytrf_bufferSize', ('hipsolverDnDsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf_bufferSize', ('hipsolverDnSsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf_bufferSize', ('hipsolverDnZsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf_bufferSize', ('hipsolverDnCsytrf_bufferSize', 
CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsytrf', ('hipsolverDnDsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf', ('hipsolverDnSsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf', ('hipsolverDnZsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf', ('hipsolverDnCsytrf', CONV_MATH_FUNC, API_SPECIAL)), + + # gesdva strided + ( + 'cusolverDnSgesvdaStridedBatched_bufferSize', + ('hipsolverDnSgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnDgesvdaStridedBatched_bufferSize', + ('hipsolverDnDgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnCgesvdaStridedBatched_bufferSize', + ('hipsolverDnCgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnZgesvdaStridedBatched_bufferSize', + ('hipsolverDnZgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ('cusolverDnSgesvdaStridedBatched', ('hipsolverDnSgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdaStridedBatched', ('hipsolverDnDgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdaStridedBatched', ('hipsolverDnCgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdaStridedBatched', ('hipsolverDnZgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + + # gesvdj SetXXX + ('cusolverDnXgesvdjSetTolerance', ('hipsolverDnXgesvdjSetTolerance', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgesvdjSetMaxSweeps', ('hipsolverDnXgesvdjSetMaxSweeps', CONV_MATH_FUNC, API_SPECIAL)), + ] +) + +PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("USE_CUDA", ("USE_ROCM", API_PYTORCH)), + ("TORCH_CUDA_CPP_API", ("TORCH_HIP_CPP_API", API_PYTORCH)), + ("TORCH_CUDA_CU_API", ("TORCH_HIP_API", API_PYTORCH)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("cudaHostAllocator", ("hipHostAllocator", API_PYTORCH)), + ("cudaDeviceAllocator", ("hipDeviceAllocator", API_PYTORCH)), + ("define MAX_NUM_BLOCKS 200", ("define MAX_NUM_BLOCKS 64", API_PYTORCH)), + ("cuda::CUDAGuard", ("hip::HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAGuard", ("HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAGuard", + ("hip::OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("OptionalCUDAGuard", ("OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::CUDAStreamGuard", + ("hip::HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("CUDAStreamGuard", ("HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAStreamGuard", + ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "OptionalCUDAStreamGuard", + ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::CUDAMultiStreamGuard", + ("hip::HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "CUDAMultiStreamGuard", + ("HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + # Only get needs to be transformed this way; all the other ones can go + # straight to the normal versions hip::HIPCachingAllocator + ( + "cuda::CUDACachingAllocator::get", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "CUDACachingAllocator::get", + ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDACachingAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + 
API_PYTORCH, + ), + ), + ( + "cuda::CUDAAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDAAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getStreamFromPool", + ("hip::getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromPool", ("getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getStreamFromExternal", + ("hip::getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromExternal", ("getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getDefaultCUDAStream", + ("getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getCurrentCUDAStream", + ("hip::getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getCurrentCUDAStream", + ("getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::setCurrentCUDAStream", + ("hip::setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "setCurrentCUDAStream", + ("setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "ATen/cudnn/Handle.h", + ("ATen/miopen/Handle.h", API_PYTORCH), + ), + # TODO: Undo this special-case; see the header for motivation behind this + # hack. It's VERY important this is only applied to PyTorch HIPify. + ( + "c10/cuda/CUDAGuard.h", + ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDACachingAllocator.h", + ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAStream.h", + ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), + ), + ("gloo/cuda.h", ("gloo/hip.h", API_PYTORCH)), + ( + "gloo/cuda_allreduce_halving_doubling.h", + ("gloo/hip_allreduce_halving_doubling.h", API_PYTORCH), + ), + ( + "gloo/cuda_allreduce_halving_doubling_pipelined.h", + ("gloo/hip_allreduce_halving_doubling_pipelined.h", API_PYTORCH), + ), + ("gloo/cuda_allreduce_ring.h", ("gloo/hip_allreduce_ring.h", API_PYTORCH)), + ("gloo/cuda_allreduce_ring_chunked.h", ("gloo/hip_allreduce_ring_chunked.h", API_PYTORCH)), + ( + "gloo/cuda_broadcast_one_to_all.h", + ("gloo/hip_broadcast_one_to_all.h", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceHalvingDoublingPipelined", + ("gloo::HipAllreduceHalvingDoublingPipelined", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceRingChunked", + ("gloo::HipAllreduceRingChunked", API_PYTORCH), + ), + ("gloo::CudaBroadcastOneToAll", ("gloo::HipBroadcastOneToAll", API_PYTORCH)), + ("gloo::CudaHostWorkspace", ("gloo::HipHostWorkspace", API_PYTORCH)), + ("gloo::CudaDeviceWorkspace", ("gloo::HipDeviceWorkspace", API_PYTORCH)), + ("CUDNN_RNN_RELU", ("miopenRNNRELU", API_PYTORCH)), + ("CUDNN_RNN_TANH", ("miopenRNNTANH", API_PYTORCH)), + ("CUDNN_LSTM", ("miopenLSTM", API_PYTORCH)), + ("CUDNN_GRU", ("miopenGRU", API_PYTORCH)), + ("cudnnRNNMode_t", ("miopenRNNMode_t", API_PYTORCH)), + ("magma_queue_create_from_cuda", ("magma_queue_create_from_hip", API_PYTORCH)), + ] +) + +CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", 
API_CAFFE2)), + ("PYTORCH_CUDA_ALLOC_CONF", ("PYTORCH_CUDA_ALLOC_CONF", API_CAFFE2)), + ("cuda_stream", ("hip_stream", API_CAFFE2)), + # if the header is a native hip folder (under hip directory), + # there is no need to add a hip path to it; the trie in hipify script + # takes this mapping order to forbid further replacement + ("/hip/", ("/hip/", API_CAFFE2)), + ("/context_gpu", ("/hip/context_gpu", API_CAFFE2)), + ("/common_gpu", ("/hip/common_gpu", API_CAFFE2)), + ("/cuda_nccl_gpu", ("/hip/hip_nccl_gpu", API_CAFFE2)), + ("/mixed_utils", ("/hip/mixed_utils", API_CAFFE2)), + ("/operator_fallback_gpu", ("/hip/operator_fallback_gpu", API_CAFFE2)), + ( + "/spatial_batch_norm_op_impl", + ("/hip/spatial_batch_norm_op_impl", API_CAFFE2), + ), + ( + "/recurrent_network_executor_gpu", + ("/hip/recurrent_network_executor_gpu", API_CAFFE2), + ), + ( + "/generate_proposals_op_util_nms_gpu", + ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2), + ), + ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)), + ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)), + ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)), + ("/top_k_radix_selection", ("/hip/top_k_radix_selection", API_CAFFE2)), + ("/GpuAtomics", ("/hip/GpuAtomics", API_CAFFE2)), + ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)), + ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)), + ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)), + ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)), + ("/sgd/adagrad_fused_op_gpu.cuh", ("/sgd/hip/adagrad_fused_op_gpu.cuh", API_CAFFE2)), + ("/operators/segment_reduction_op_gpu.cuh", ("/operators/hip/segment_reduction_op_gpu.cuh", API_CAFFE2)), + ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)), + ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)), + ("REGISTER_CUDA_OPERATOR", ("REGISTER_HIP_OPERATOR", API_CAFFE2)), + ("CUDA_1D_KERNEL_LOOP", ("HIP_1D_KERNEL_LOOP", API_CAFFE2)), + ("CUDAContext", ("HIPContext", API_CAFFE2)), + ("CAFFE_CUDA_NUM_THREADS", ("CAFFE_HIP_NUM_THREADS", API_CAFFE2)), + ("HasCudaGPU", ("HasHipGPU", API_CAFFE2)), + ("__expf", ("expf", API_CAFFE2)), + ("CUBLAS_ENFORCE", ("HIPBLAS_ENFORCE", API_CAFFE2)), + ("CUBLAS_CHECK", ("HIPBLAS_CHECK", API_CAFFE2)), + ("cublas_handle", ("hipblas_handle", API_CAFFE2)), + ("CURAND_ENFORCE", ("HIPRAND_ENFORCE", API_CAFFE2)), + ("CURAND_CHECK", ("HIPRAND_CHECK", API_CAFFE2)), + ("curandGenerateUniform", ("hiprandGenerateUniform", API_CAFFE2)), + ("curand_generator", ("hiprand_generator", API_CAFFE2)), + ("CaffeCudaGetDevice", ("CaffeHipGetDevice", API_CAFFE2)), + # do not rename CUDA_KERNEL_ASSERT, lazyInitCUDA in caffe2 sources + # the ordered dict guarantees this pattern will match first, before "CUDA" + ("CUDA_KERNEL_ASSERT", ("CUDA_KERNEL_ASSERT", API_CAFFE2)), + ("lazyInitCUDA", ("lazyInitCUDA", API_CAFFE2)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_CAFFE2)), + ("CUDA", ("HIP", API_CAFFE2)), + ("Cuda", ("Hip", API_CAFFE2)), + ("cuda_", ("hip_", API_CAFFE2)), + ("_cuda", ("_hip", API_CAFFE2)), + ("CUDNN", ("MIOPEN", API_CAFFE2)), + ("CuDNN", ("MIOPEN", API_CAFFE2)), + ("cudnn", ("miopen", API_CAFFE2)), + ("namespace cuda", ("namespace hip", API_CAFFE2)), + ("cuda::CUDAGuard", ("hip::HIPGuard", API_CAFFE2)), + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuard", API_CAFFE2)), + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuard", API_CAFFE2)), + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuard", API_CAFFE2)), + 
("c10/cuda/CUDAGuard.h", ("c10/hip/HIPGuard.h", API_CAFFE2)), + ("gloo/cuda", ("gloo/hip", API_CAFFE2)), + ] +) + +# We must tread very carefully here. Blanket conversions like are done +# in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, +# because a regex for CUDA will also match a filename like CUDAGuard.h, +# but the HIPIFY script doesn't presently move the file and so the substitution +# will be invalid. Instead, we specifically list out every identifier +# and file from c10/cuda which may be used externally, and do substitutions this +# way. +# +# NB: if you want a transformation to ONLY apply to the c10/ directory, +# put it as API_CAFFE2 +C10_MAPPINGS = collections.OrderedDict( + [ + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("CUDA_LAUNCH_BLOCKING=1", ("AMD_SERIALIZE_KERNEL=3", API_C10)), + ("CUDA_LAUNCH_BLOCKING", ("AMD_SERIALIZE_KERNEL", API_C10)), + ("cuda::compat::", ("hip::compat::", API_C10)), + ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)), + ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)), + ("c10/cuda/CUDADeviceAssertionHost.h", ("c10/hip/HIPDeviceAssertionHost.h", API_C10)), + ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)), + ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), + ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), + ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), + ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), + ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), + ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), + ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)), + ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)), + ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)), + ( + "c10/cuda/impl/cuda_cmake_macros.h", + ("c10/hip/impl/hip_cmake_macros.h", API_C10), + ), + ("C10_CUDA_CHECK", ("C10_HIP_CHECK", API_C10)), + ("C10_CUDA_CHECK_WARN", ("C10_HIP_CHECK_WARN", API_C10)), + ("C10_CUDA_ERROR_HANDLED", ("C10_HIP_ERROR_HANDLED", API_C10)), + ("C10_CUDA_IGNORE_ERROR", ("C10_HIP_IGNORE_ERROR", API_C10)), + ("C10_CUDA_CLEAR_ERROR", ("C10_HIP_CLEAR_ERROR", API_C10)), + ("c10::cuda", ("c10::hip", API_C10)), + ("cuda::CUDAStream", ("hip::HIPStream", API_C10)), + ("CUDAStream", ("HIPStream", API_C10)), + # This substitution is not permissible, because there's another copy of this + # function in torch/cuda.h + # ("cuda::device_count", ("hip::device_count", API_C10)), + ("cuda::current_device", ("hip::current_device", API_C10)), + ("cuda::set_device", ("hip::set_device", API_C10)), + ("cuda::device_synchronize", ("hip::device_synchronize", API_C10)), + ("cuda::getStreamFromPool", ("hip::getStreamFromPool", API_C10)), + ("getStreamFromPool", ("getStreamFromPool", API_C10)), + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStream", API_C10)), + ("getDefaultCUDAStream", ("getDefaultHIPStream", API_C10)), + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStream", API_C10)), + ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)), + ("cuda::get_cuda_check_prefix", ("hip::get_cuda_check_prefix", API_C10)), + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)), + ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), + ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", 
API_C10)), + ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)), + ("cuda::CUDAAllocatorConfig", ("hip::HIPAllocatorConfig", API_C10)), + ("CUDAAllocatorConfig", ("HIPAllocatorConfig", API_C10)), + ("pinned_use_cuda_host_register", ("pinned_use_hip_host_register", API_C10)), + ("c10::cuda::CUDAAllocator", ("c10::hip::HIPAllocator", API_C10)), + ("cuda::CUDAAllocator", ("hip::HIPAllocator", API_C10)), + ("CUDAStreamCaptureModeGuard", ("HIPStreamCaptureModeGuard", API_C10)), + ("cuda::CUDAStreamCaptureModeGuard", ("cuda::HIPStreamCaptureModeGuard", API_C10)), + ("CUDAAllocator", ("HIPAllocator", API_C10)), + ("C10_CUDA_KERNEL_LAUNCH_CHECK", ("C10_HIP_KERNEL_LAUNCH_CHECK", API_C10)), + ("CUDAKernelLaunchRegistry", ("HIPKernelLaunchRegistry", API_C10)), + ("c10::cuda::get_cuda_check_suffix", ("c10::hip::get_hip_check_suffix", API_C10)), + ] +) + +# NB: C10 mappings are more specific than Caffe2 mappings, so run them +# first +CUDA_TO_HIP_MAPPINGS = [ + CUDA_IDENTIFIER_MAP, + CUDA_TYPE_NAME_MAP, + CUDA_INCLUDE_MAP, + CUDA_SPECIAL_MAP, + C10_MAPPINGS, + PYTORCH_SPECIFIC_MAPPINGS, + CAFFE2_SPECIFIC_MAPPINGS, +] diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/hipify_python.py b/phivenv/Lib/site-packages/torch/utils/hipify/hipify_python.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba06346ae8dc1ac1ae3106586ad6009e812d531 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hipify/hipify_python.py @@ -0,0 +1,1176 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" The Python Hipify script. +## +# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved. +# 2017-2018 Advanced Micro Devices, Inc. and +# Facebook Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +import argparse +import fnmatch +import re +import shutil +import sys +import os + +from . 
import constants +from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS +from .cuda_to_hip_mappings import MATH_TRANSPILATIONS + +from typing import Optional +from collections.abc import Iterator +from collections.abc import Mapping, Iterable +from enum import Enum +import functools +import hashlib + +class CurrentState(Enum): + INITIALIZED = 1 + DONE = 2 + +class HipifyResult: + def __init__(self, current_state, hipified_path): + self.current_state = current_state + self.hipified_path = hipified_path + self.status = "" + + def __str__(self): + return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}") + +HipifyFinalResult = dict[str, HipifyResult] +HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n" +HIPIFY_FINAL_RESULT: HipifyFinalResult = {} + +# Hardcode the PyTorch template map +"""This dictionary provides the mapping from PyTorch kernel template types +to their actual types.""" +PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} + +__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter', + 'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group', + 'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared', + 'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file', + 'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', + 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify'] + + +class InputError(Exception): + # Exception raised for errors in the input. + + def __init__(self, message): + super().__init__(message) + self.message = message + + def __str__(self): + return f"Input error: {self.message}" + + +def openf(filename, mode): + return open(filename, mode, errors='ignore') + + +# Color coding for printing +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +# To the programmer, the output of hipify most likely are intermediates. +# This class allows users of hipify to ask for a cleanup by running the +# hipify and compilation in a with instantiating this context manager class +# with keep_intermediates=False. +# The main usecase is the cpp_extensions, specifically the load method. +# It is a good idea to keep intermediates (in case of errors or to +# not recompile unchanged files), but in cases where you don't want to +# keep them (e.g. in the CI), this can be used to remove files. 
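# Illustrative usage sketch (not part of the original file; the project path is
# hypothetical): per the comment above, a caller such as
# torch.utils.cpp_extension.load wraps hipification in the context manager
# defined below so that intermediates are removed once it exits.
def _example_generated_file_cleanup():
    with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
        hipify(project_directory="/path/to/my_extension",  # hypothetical path
               is_pytorch_extension=True,
               clean_ctx=clean_ctx)
        # ... a real caller would also compile the hipified sources here ...
    # On exit, every file opened through clean_ctx.open() and every directory
    # created through clean_ctx.makedirs() during hipification is deleted.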
+class GeneratedFileCleaner: + """Context Manager to clean up generated files""" + def __init__(self, keep_intermediates=False): + self.keep_intermediates = keep_intermediates + self.files_to_clean = set() + self.dirs_to_clean = [] + + def __enter__(self): + return self + + def open(self, fn, *args, **kwargs): + if not os.path.exists(fn): + self.files_to_clean.add(os.path.abspath(fn)) + return open(fn, *args, **kwargs) + + def makedirs(self, dn, exist_ok=False): + parent, n = os.path.split(dn) + if not n: + parent, n = os.path.split(parent) + if parent and n and not os.path.exists(parent): + self.makedirs(parent, exist_ok=True) + if not os.path.isdir(dn) or not exist_ok: + os.mkdir(dn) + self.dirs_to_clean.append(os.path.abspath(dn)) + + def __exit__(self, type, value, traceback): + if not self.keep_intermediates: + for f in self.files_to_clean: + os.unlink(f) + for d in self.dirs_to_clean[::-1]: + os.rmdir(d) + +# Follow UNIX convention for paths to use '/' instead of '\\' on Windows +def _to_unix_path(path: str) -> str: + return path.replace(os.sep, '/') + +def match_extensions(filename: str, extensions: Iterable) -> bool: + """Helper method to see if filename ends with certain extension""" + return any(filename.endswith(e) for e in extensions) + + +def _fnmatch(filepath, patterns): + return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) + + +def matched_files_iter( + root_path: str, + includes: Iterable = (), + ignores: Iterable = (), + extensions: Iterable = (), + out_of_place_only: bool = False, + is_pytorch_extension: bool = False) -> Iterator[str]: + + exact_matches = set(includes) + + # This is a very rough heuristic; really, we want to avoid scanning + # any file which is not checked into source control, but this script + # needs to work even if you're in a Git or Hg checkout, so easier to + # just block the biggest time sinks that won't matter in the + # end. 
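    # Note: os.walk() below is invoked with topdown=True, so removing entries
    # from `dirs` prunes those subtrees from the traversal entirely; appending
    # "third_party/nvfuser" puts that single subtree back on the walk.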
+ for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True): + rel_dirpath = os.path.relpath(abs_dirpath, root_path) + if rel_dirpath == '.': + # Blah blah blah O(n) blah blah + if ".git" in dirs: + dirs.remove(".git") + if "build" in dirs: + dirs.remove("build") + if "third_party" in dirs: + dirs.remove("third_party") + dirs.append("third_party/nvfuser") + for filename in filenames: + filepath = _to_unix_path(os.path.join(abs_dirpath, filename)) + rel_filepath = _to_unix_path(os.path.join(rel_dirpath, filename)) + # We respect extensions, UNLESS you wrote the entire + # filename verbatim, in which case we always accept it + if ( + _fnmatch(filepath, includes) + and (not _fnmatch(filepath, ignores)) + and (match_extensions(filepath, extensions) or filepath in exact_matches) + ): + if not is_pytorch_extension: # for pytorch extensions, consider all files + if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath): + continue + if out_of_place_only and not is_out_of_place(rel_filepath): + continue + yield filepath + + +def preprocess_file_and_save_result( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> None: + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path) + HIPIFY_FINAL_RESULT[fin_path] = hipify_result + result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats, + hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + # Show what happened + if show_progress and "ignored" not in result.status: + print( + fin_path, "->", + result.hipified_path, result.status, flush=True) + + HIPIFY_FINAL_RESULT[fin_path] = result + + +def compute_stats(stats): + unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]} + + # Print the number of unsupported calls + print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}") + + # Print the list of unsupported calls + print(", ".join(unsupported_calls)) + + # Print the number of kernel launches + print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}") + + +def add_dim3(kernel_string, cuda_kernel): + '''adds dim3() to the second and third arguments in the kernel launch''' + count = 0 + closure = 0 + kernel_string = kernel_string.replace("<<<", "").replace(">>>", "") + arg_locs: list[dict[str, int]] = [{} for _ in range(2)] + arg_locs[count]['start'] = 0 + for ind, c in enumerate(kernel_string): + if count > 1: + break + if c == "(": + closure += 1 + elif c == ")": + closure -= 1 + if (c == "," or ind == len(kernel_string) - 1) and closure == 0: + arg_locs[count]['end'] = ind + (c != ",") + count += 1 + if count < 2: + arg_locs[count]['start'] = ind + 1 + + first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1] + second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']] + + first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ") + second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ") + + first_arg_dim3 = f"dim3({first_arg_clean})" + second_arg_dim3 = f"dim3({second_arg_clean})" + + first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3) + 
second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3) + cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3) + return cuda_kernel + + +RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+') + + +def processKernelLaunches(string, stats): + """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" + # Concat the namespace with the kernel names. (Find cleaner way of doing this later). + string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string) + + def grab_method_and_template(in_kernel): + # The positions for relevant kernel components. + pos = { + "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]}, + "kernel_name": {"start": -1, "end": -1}, + "template": {"start": -1, "end": -1} + } + + # Count for balancing template + count = {"<>": 0} + + # Status for whether we are parsing a certain item. + START = 0 + AT_TEMPLATE = 1 + AFTER_TEMPLATE = 2 + AT_KERNEL_NAME = 3 + + status = START + + # Parse the string character by character + for i in range(pos["kernel_launch"]["start"] - 1, -1, -1): + char = string[i] + + # Handle Templating Arguments + if status in (START, AT_TEMPLATE): + if char == ">": + if status == START: + status = AT_TEMPLATE + pos["template"]["end"] = i + count["<>"] += 1 + + if char == "<": + count["<>"] -= 1 + if count["<>"] == 0 and (status == AT_TEMPLATE): + pos["template"]["start"] = i + status = AFTER_TEMPLATE + + # Handle Kernel Name + if status != AT_TEMPLATE: + if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}: + if status != AT_KERNEL_NAME: + status = AT_KERNEL_NAME + pos["kernel_name"]["end"] = i + + # Case: Kernel name starts the string. + if i == 0: + pos["kernel_name"]["start"] = 0 + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + else: + # Potential ending point if we're already traversing a kernel's name. + if status == AT_KERNEL_NAME: + pos["kernel_name"]["start"] = i + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + def find_kernel_bounds(string): + """Finds the starting and ending points for all kernel launches in the string.""" + kernel_end = 0 + kernel_positions = [] + + # Continue until we cannot find any more kernels anymore. + while string.find("<<<", kernel_end) != -1: + # Get kernel starting position (starting from the previous ending point) + kernel_start = string.find("<<<", kernel_end) + + # Get kernel ending position (adjust end point past the >>>) + kernel_end = string.find(">>>", kernel_start) + 3 + if kernel_end <= 0: + raise InputError("no kernel end found") + + # Add to list of traversed kernels + kernel_positions.append({"start": kernel_start, "end": kernel_end, + "group": string[kernel_start: kernel_end]}) + + return kernel_positions + + # Replace comments and string literals from the code so that find_kernel_bounds does not + # wrongly capture kernels in comments and string literals. + # This function replaces them with "x" to keep positions. 
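    # Illustrative example (assumed input, not taken from the source):
    #   mask_comments('run(); // my_kernel<<<grid, block>>>(x);')
    # keeps everything up to and including the first '/' and overwrites each
    # character from the second '/' onward with 'x', so the '<<<' / '>>>'
    # inside the comment can no longer be mistaken for a kernel launch by
    # find_kernel_bounds, while character positions stay aligned.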
+ def mask_comments(string): + in_comment = '' + prev_c = '' + new_string = '' + for c in string: + if in_comment == '': + # Outside comments + if c == '/' and prev_c == '/': + in_comment = '//' + elif c == '*' and prev_c == '/': + in_comment = '/*' + elif c == '"' and prev_c != '\\' and prev_c != "'": + in_comment = '"' + elif in_comment == '//': + # In // xxx + if c == '\r' or c == '\n': + in_comment = '' + elif in_comment == '/*': + # In /* xxx */ + if c == '/' and prev_c == '*': + in_comment = '' + elif in_comment == '"': + # In "" + if c == '"' and prev_c != '\\': + in_comment = '' + prev_c = c + if in_comment == '': + new_string += c + else: + new_string += 'x' + return new_string + + # Grab positional ranges of all kernel launches + get_kernel_positions = list(find_kernel_bounds(mask_comments(string))) + output_string = string + + # Replace each CUDA kernel with a HIP kernel. + for kernel in get_kernel_positions: + # Get kernel components + params = grab_method_and_template(kernel) + + # Find parenthesis after kernel launch + parenthesis = string.find("(", kernel["end"]) + + # Extract cuda kernel + cuda_kernel = string[params[0]["start"]:parenthesis + 1] + kernel_string = string[kernel['start']:kernel['end']] + end_param_index = 0 if params[1]['end'] == -1 else 1 + kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1] + cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel) + # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size) + num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")"))) + + hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace( + ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace( + ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")") + + # Replace cuda kernel with hip kernel + output_string = output_string.replace(cuda_kernel, hip_kernel) + + # Update the statistics + stats["kernel_launches"].append(hip_kernel) + + return output_string + + +def find_closure_group(input_string, start, group): + """Generalization for finding a balancing closure group + + if group = ["(", ")"], then finds the first balanced parentheses. + if group = ["{", "}"], then finds the first balanced bracket. + + Given an input string, a starting position in the input string, and the group type, + find_closure_group returns the positions of group[0] and group[1] as a tuple. 
+ + Example: + >>> find_closure_group("(hi)", 0, ["(", ")"]) + (0, 3) + """ + + inside_parenthesis = False + parens = 0 + pos = start + p_start, p_end = -1, -1 + + while pos < len(input_string): + if input_string[pos] == group[0]: + if inside_parenthesis is False: + inside_parenthesis = True + parens = 1 + p_start = pos + else: + parens += 1 + elif input_string[pos] == group[1] and inside_parenthesis: + parens -= 1 + + if parens == 0: + p_end = pos + return p_start, p_end + + pos += 1 + return None, None + + +def find_bracket_group(input_string, start): + """Finds the first balanced parantheses.""" + return find_closure_group(input_string, start, group=["{", "}"]) + + +def find_parentheses_group(input_string, start): + """Finds the first balanced bracket.""" + return find_closure_group(input_string, start, group=["(", ")"]) + + +RE_ASSERT = re.compile(r"\bassert[ ]*\(") + + +def replace_math_functions(input_string): + """FIXME: Temporarily replace std:: invocations of math functions + with non-std:: versions to prevent linker errors NOTE: This + can lead to correctness issues when running tests, since the + correct version of the math function (exp/expf) might not get + called. Plan is to remove this function once HIP supports + std:: math function calls inside device code + + """ + output_string = input_string + for func in MATH_TRANSPILATIONS: + output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(') + + return output_string + + +RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()") + + +def hip_header_magic(input_string): + """If the file makes kernel builtin calls and does not include the cuda_runtime.h header, + then automatically add an #include to match the "magic" includes provided by NVCC. + TODO: + Update logic to ignore cases where the cuda_runtime.h is included by another file. + """ + + # Copy the input. + output_string = input_string + + # Check if one of the following headers is already included. + headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"] + if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers): + return output_string + + # Rough logic to detect if we're inside device code + hasDeviceLogic: int + hasDeviceLogic = "hipLaunchKernelGGL" in output_string + hasDeviceLogic += "__global__" in output_string + hasDeviceLogic += "__shared__" in output_string + hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None + + # If device logic found, provide the necessary header. + if hasDeviceLogic: + output_string = '#include "hip/hip_runtime.h"\n' + input_string + + return output_string + + +RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;") + + +def replace_extern_shared(input_string): + """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead. + https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__ + Example: + "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)" + "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)" + """ + output_string = input_string + output_string = RE_EXTERN_SHARED.sub( + lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string) + + return output_string + + +def get_hip_file_path(rel_filepath, is_pytorch_extension=False): + """ + Returns the new name of the hipified file + """ + # At the moment, some PyTorch source files are HIPified in place. 
The predicate + # is_out_of_place tells us if this is the case or not. + assert not os.path.isabs(rel_filepath) + if not is_pytorch_extension and not is_out_of_place(rel_filepath): + return rel_filepath + + dirpath, filename = os.path.split(rel_filepath) + root, ext = os.path.splitext(filename) + + # Here's the plan: + # + # In general, we need to disambiguate the HIPified filename so that + # it gets a different name from the original filename, so + # that we don't overwrite the original file + # + # There's a lot of different naming conventions across PyTorch + # and Caffe2, but the general recipe is to convert occurrences + # of cuda/gpu to hip, and add hip if there are no occurrences + # of cuda/gpu anywhere. + # + # Concretely, we do the following: + # + # - If there is a directory component named "cuda", replace + # it with "hip", AND + # + # - If the file name contains "CUDA", replace it with "HIP", AND + # + # - ALWAYS replace '.cu' with '.hip', because those files + # contain CUDA kernels that needs to be hipified and processed with + # hip compiler + # + # - If we are not hipifying a PyTorch extension, and the parent + # directory name did not change as a result of the above + # transformations, insert "hip" in the file path + # as the direct parent folder of the file + # + # - If we are hipifying a PyTorch extension, and the parent directory + # name as well as the filename (incl. extension) did not change as + # a result of the above transformations, insert "_hip" in the filename + # + # This isn't set in stone; we might adjust this to support other + # naming conventions. + + if ext == '.cu': + ext = '.hip' + + orig_filename = filename + orig_dirpath = dirpath + + dirpath = dirpath.replace('cuda', 'hip') + dirpath = dirpath.replace('CUDA', 'HIP') + dirpath = dirpath.replace('THC', 'THH') + + root = root.replace('cuda', 'hip') + root = root.replace('CUDA', 'HIP') + # Special case to handle caffe2/core/THCCachingAllocator + if dirpath != "caffe2/core": + root = root.replace('THC', 'THH') + + if not is_pytorch_extension and dirpath == orig_dirpath: + dirpath = os.path.join(dirpath, 'hip') + + if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename: + root = root + "_hip" + + return os.path.join(dirpath, root + ext) + + +def is_out_of_place(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("torch/"): + return False + if rel_filepath.startswith("third_party/nvfuser/"): + return False + if rel_filepath.startswith("tools/autograd/templates/"): + return False + return True + + +# Keep this synchronized with includes/ignores in build_amd.py +def is_pytorch_file(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("aten/"): + if rel_filepath.startswith("aten/src/ATen/core/"): + return False + return True + if rel_filepath.startswith("torch/"): + return True + if rel_filepath.startswith("third_party/nvfuser/"): + return True + if rel_filepath.startswith("tools/autograd/templates/"): + return True + return False + + +def is_cusparse_file(rel_filepath): + if is_pytorch_file(rel_filepath): + return "sparse" in rel_filepath.lower() + return False + + +def is_special_file(rel_filepath): + if is_pytorch_file(rel_filepath): + if "sparse" in rel_filepath.lower(): + return True + elif "linalg" in rel_filepath.lower(): + if "batchlinearalgebralibblas" in rel_filepath.lower(): + return False # don't use "special" mappings for this specific linalg cublas file + return True + return False + +def 
is_caffe2_gpu_file(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("c10/cuda"): + return True + filename = os.path.basename(rel_filepath) + _, ext = os.path.splitext(filename) + return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) + +class TrieNode: + """A Trie node whose children are represented as a directory of char: TrieNode. + A special char '' represents end of word + """ + + def __init__(self): + self.children = {} + +class Trie: + """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern. + The corresponding Regex should match much faster than a simple Regex union.""" + + def __init__(self): + """Initialize the trie with an empty root node.""" + self.root = TrieNode() + self._hash = hashlib.md5(usedforsecurity=False) + self._digest = self._hash.digest() + + def add(self, word): + """Add a word to the Trie. """ + self._hash.update(word.encode()) + self._digest = self._hash.digest() + node = self.root + + for char in word: + node.children.setdefault(char, TrieNode()) + node = node.children[char] + node.children[''] = True # Mark the end of the word + + def dump(self): + """Return the root node of Trie. """ + return self.root + + def quote(self, char): + """ Escape a char for regex. """ + return re.escape(char) + + def search(self, word): + """Search whether word is present in the Trie. + Returns True if yes, else return False""" + node = self.root + for char in word: + if char in node.children: + node = node.children[char] + else: + return False + + # make sure to check the end-of-word marker present + return '' in node.children + + @functools.lru_cache # noqa: B019 + def _pattern(self, root, digest): + """Convert a Trie into a regular expression pattern + + Memoized on the hash digest of the trie, which is built incrementally + during add(). + """ + node = root + + if "" in node.children and len(node.children.keys()) == 1: + return None + + alt = [] # store alternative patterns + cc = [] # store char to char classes + q = 0 # for node representing the end of word + for char in sorted(node.children.keys()): + if isinstance(node.children[char], TrieNode): + try: + recurse = self._pattern(node.children[char], self._digest) + alt.append(self.quote(char) + recurse) + except Exception: + cc.append(self.quote(char)) + else: + q = 1 + cconly = not len(alt) > 0 + + if len(cc) > 0: + if len(cc) == 1: + alt.append(cc[0]) + else: + alt.append('[' + ''.join(cc) + ']') + + if len(alt) == 1: + result = alt[0] + else: + result = "(?:" + "|".join(alt) + ")" + + if q: + if cconly: + result += "?" + else: + result = f"(?:{result})?" + return result + + def pattern(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + + def export_to_regex(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + +CAFFE2_TRIE = Trie() +CAFFE2_MAP = {} +PYTORCH_TRIE = Trie() +PYTORCH_MAP: dict[str, object] = {} + +# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip. +# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance. +# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex. +# In the case of SPARSE, we must use the hip types for complex instead of the roc types, +# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority. 
+# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place. +# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices. +# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling. +PYTORCH_SPECIAL_MAP = {} + +for mapping in CUDA_TO_HIP_MAPPINGS: + assert isinstance(mapping, Mapping) + for src, value in mapping.items(): + dst = value[0] + meta_data = value[1:] + if constants.API_CAFFE2 not in meta_data: + PYTORCH_TRIE.add(src) + # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL + # do not overwrite PYTORCH_MAP, store dst separately + if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""): + PYTORCH_SPECIAL_MAP[src] = dst + else: + PYTORCH_MAP[src] = dst + if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data: + CAFFE2_TRIE.add(src) + CAFFE2_MAP[src] = dst +RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex()) +RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)') + +RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"') +RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>') +RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"') +RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh + +""" +Returns a HipifyResult object with the following details: + "hipified_path" : absolute path of hipified source file + "status" : "ok" if hipified file was written out + "skipped" if an identical hipified file already existed or hipified file couldn't be written out + "ignored" if the source file was a hipified file itself or not meant to be hipified + "current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified + CurrentState.DONE if source file is done with hipification process +""" + + +def preprocessor( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> HipifyResult: + """ Executes the CUDA -> HIP conversion on the specified file. 
""" + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + filepath = _to_unix_path(filepath) + hipify_result = HIPIFY_FINAL_RESULT[fin_path] + if filepath not in all_files: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, not to be hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + rel_filepath = _to_unix_path(os.path.relpath(filepath, output_directory)) + + with open(fin_path, encoding='utf-8') as fin: + if fin.readline() == HIPIFY_C_BREADCRUMB: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, input is hipified output]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + fin.seek(0) + output_source = fin.read() + + orig_output_source = output_source + + # get_hip_file_path needs a relative path to work correctly + fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension))) + if not os.path.exists(os.path.dirname(fout_path)): + clean_ctx.makedirs(os.path.dirname(fout_path)) + + # unsupported_calls statistics reporting is broken atm + def pt_repl(m): + return PYTORCH_MAP[m.group(0)] + + def pt_special_repl(m): + # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings + return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m)) + + + if is_pytorch_extension: + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + if is_special_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source) + elif is_pytorch_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + def c2_repl(m): + return CAFFE2_MAP[m.group(0)] + output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) + + # Header rewrites + def mk_repl(templ, include_current_dir=True): + def repl(m): + f = m.group(1) + filename = os.path.basename(f) + if ( + f.startswith(("ATen/cuda", + "ATen/native/cuda", + "ATen/native/nested/cuda", + "ATen/native/quantized/cuda", + "ATen/native/sparse/cuda", + "ATen/native/transformers/cuda", + "THC/")) or + (f.startswith("THC") and not f.startswith("THCP")) + ): + return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) + # if filename is one of the files being hipified for this extension + if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)): + header_dir = None + header_filepath = None + # If include_current_dir True, look first in same dir as the including source file + if include_current_dir: + header_dir_to_check = os.path.dirname(fin_path) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If not found, look in include dirs one by one and first match wins + if header_filepath is None: + for header_include_dir in header_include_dirs: + header_dir_to_check = os.path.join(output_directory, header_include_dir) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If header file not found, keep as is + if header_filepath is None: + return m.group(0) + # Hipify header file first if needed + if header_filepath not in HIPIFY_FINAL_RESULT: + preprocess_file_and_save_result(output_directory, + header_filepath, + all_files, header_include_dirs, stats, hip_clang_launch, 
+ is_pytorch_extension, clean_ctx, show_progress) + elif header_filepath in HIPIFY_FINAL_RESULT: + header_result = HIPIFY_FINAL_RESULT[header_filepath] + if header_result.current_state == CurrentState.INITIALIZED: + # get_hip_file_path needs a relative path to work correctly + header_rel_path = os.path.relpath(header_filepath, output_directory) + header_fout_path = os.path.abspath(os.path.join(output_directory, + get_hip_file_path(header_rel_path, is_pytorch_extension))) + header_result.hipified_path = header_fout_path + HIPIFY_FINAL_RESULT[header_filepath] = header_result + return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None + else header_filepath, header_dir)) + hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path + return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None + else header_filepath, header_dir))) + + return m.group(0) + return repl + output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source) + output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source) + output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source) + + # CMakeLists.txt rewrites + if filepath.endswith('CMakeLists.txt'): + output_source = output_source.replace('CUDA', 'HIP') + output_source = output_source.replace('THC', 'THH') + output_source = RE_CU_SUFFIX.sub('.hip', output_source) + + # Perform Kernel Launch Replacements + if not hip_clang_launch: + output_source = processKernelLaunches(output_source, stats) + + # Replace std:: with non-std:: versions + if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath: + output_source = replace_math_functions(output_source) + + # Include header if device code is contained. + output_source = hip_header_magic(output_source) + + # Replace the extern __shared__ + # NOTE: No longer needed after transition from hcc to hipclang. 
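    # The commented-out call below is retained for reference: hip-clang accepts
    # `extern __shared__` declarations directly, so the HIP_DYNAMIC_SHARED
    # rewrite performed by replace_extern_shared() is no longer applied here.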
+ # output_source = replace_extern_shared(output_source) + + # Don't write out identical hipified files for extensions if dirpath has not changed + if ( + is_pytorch_extension + and orig_output_source == output_source + and os.path.dirname(fin_path) == os.path.dirname(fout_path) + ): + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no changes]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + # Add hipify breadcrumb for C-style files to avoid re-hipification + if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")): + output_source = HIPIFY_C_BREADCRUMB + output_source + + do_write = True + if os.path.exists(fout_path): + with open(fout_path, encoding='utf-8') as fout_old: + do_write = fout_old.read() != output_source + if do_write: + try: + with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout: + fout.write(output_source) + hipify_result.hipified_path = fout_path + hipify_result.status = "[ok]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + except OSError as e: + print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}', + file=sys.stderr) + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no permissions]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + else: + hipify_result.hipified_path = fout_path + hipify_result.status = "[skipped, already hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + +def file_specific_replacement(filepath, search_string, replace_string, strict=False): + with openf(filepath, "r+") as f: + contents = f.read() + if strict: + contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents) + else: + contents = contents.replace(search_string, replace_string) + f.seek(0) + f.write(contents) + f.truncate() + + +def file_add_header(filepath, header): + with openf(filepath, "r+") as f: + contents = f.read() + if header[0] != "<" and header[-1] != ">": + header = f'"{header}"' + contents = (f'#include {header} \n') + contents + f.seek(0) + f.write(contents) + f.truncate() + + +def fix_static_global_kernels(in_txt): + """Static global kernels in HIP results in a compilation error.""" + in_txt = in_txt.replace(" __global__ static", "__global__") + return in_txt + + +RE_INCLUDE = re.compile(r"#include .*\n") + + +def extract_arguments(start, string): + """ Return the list of arguments in the upcoming function parameter closure. 
+ Example: + string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))' + arguments (output): + '[{'start': 1, 'end': 7}, + {'start': 8, 'end': 16}, + {'start': 17, 'end': 19}, + {'start': 20, 'end': 53}]' + """ + + arguments = [] + closures = { + "<": 0, + "(": 0 + } + current_position = start + argument_start_pos = current_position + 1 + + # Search for final parenthesis + while current_position < len(string): + if string[current_position] == "(": + closures["("] += 1 + elif string[current_position] == ")": + closures["("] -= 1 + elif string[current_position] == "<": + closures["<"] += 1 + elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0: + closures["<"] -= 1 + + # Finished all arguments + if closures["("] == 0 and closures["<"] == 0: + # Add final argument + arguments.append({"start": argument_start_pos, "end": current_position}) + break + + # Finished current argument + if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",": + arguments.append({"start": argument_start_pos, "end": current_position}) + argument_start_pos = current_position + 1 + + current_position += 1 + + return arguments + + +def str2bool(v): + """ArgumentParser doesn't support type=bool. Thus, this helper method will convert + from possible string types to True / False.""" + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def hipify( + project_directory: str, + show_detailed: bool = False, + extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), + header_extensions: Iterable = (".cuh", ".h", ".hpp"), + output_directory: str = "", + header_include_dirs: Iterable = (), + includes: Iterable = ('*',), + extra_files: Iterable = (), + out_of_place_only: bool = False, + ignores: Iterable = (), + show_progress: bool = True, + hip_clang_launch: bool = False, + is_pytorch_extension: bool = False, + hipify_extra_files_only: bool = False, + clean_ctx: Optional[GeneratedFileCleaner] = None +) -> HipifyFinalResult: + if project_directory == "": + project_directory = os.getcwd() + + # Verify the project directory exists. + if not os.path.exists(project_directory): + print("The project folder specified does not exist.") + sys.exit(1) + + # If no output directory, provide a default one. + if not output_directory: + project_directory.rstrip("/") + output_directory = project_directory + "_amd" + + if project_directory != output_directory: + includes = [include.replace(project_directory, output_directory) for include in includes] + ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores] + + # Copy from project directory to output directory if not done already. 
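    # Note: when the output directory already exists, the copy below is skipped
    # and previously generated files are reused; preprocessor() only rewrites a
    # hipified file when its newly generated content actually differs.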
+ if not os.path.exists(output_directory): + shutil.copytree(project_directory, output_directory) + + includes = list(map(_to_unix_path, includes)) + ignores = list(map(_to_unix_path, ignores)) + + all_files = list(matched_files_iter(output_directory, includes=includes, + ignores=ignores, extensions=extensions, + out_of_place_only=out_of_place_only, + is_pytorch_extension=is_pytorch_extension)) + all_files_set = set(all_files) + for f in extra_files: + if not os.path.isabs(f): + f = os.path.join(output_directory, f) + if f not in all_files_set: + all_files.append(f) + + # List all files in header_include_paths to ensure they are hipified + from pathlib import Path + for header_include_dir in header_include_dirs: + if os.path.isabs(header_include_dir): + header_include_dir_path = Path(header_include_dir) + else: + header_include_dir_path = Path(os.path.join(output_directory, header_include_dir)) + all_files.extend( + str(path) for path in header_include_dir_path.rglob('*') if path.is_file() + and _fnmatch(str(path), includes) + and (not _fnmatch(str(path), ignores)) + and match_extensions(path.name, header_extensions) + ) + + if clean_ctx is None: + clean_ctx = GeneratedFileCleaner(keep_intermediates=True) + + # Preprocessing statistics. + stats: dict[str, list] = {"unsupported_calls": [], "kernel_launches": []} + + for filepath in (all_files if not hipify_extra_files_only else extra_files): + preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs, + stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) + + # Show detailed summary + if show_detailed: + compute_stats(stats) + + return HIPIFY_FINAL_RESULT diff --git a/phivenv/Lib/site-packages/torch/utils/hipify/version.py b/phivenv/Lib/site-packages/torch/utils/hipify/version.py new file mode 100644 index 0000000000000000000000000000000000000000..608f35d6f6b03ca23f46fbd6500fc32f694a858f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hipify/version.py @@ -0,0 +1 @@ +__version__ = '1.0.0' diff --git a/phivenv/Lib/site-packages/torch/utils/hooks.py b/phivenv/Lib/site-packages/torch/utils/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..c99b729b367b78da9ea4d31c9fc00bd1d81e26e2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/hooks.py @@ -0,0 +1,256 @@ +# mypy: allow-untyped-defs +import torch +from collections import OrderedDict +import weakref +import warnings +from typing import Any + +__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"] + +class RemovableHandle: + r""" + A handle which provides the capability to remove a hook. + + Args: + hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``. + extra_dict (Union[dict, List[dict]]): An additional dictionary or list of + dictionaries whose keys will be deleted when the same keys are + removed from ``hooks_dict``. 
+ """ + + id: int + next_id: int = 0 + + def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None: + self.hooks_dict_ref = weakref.ref(hooks_dict) + self.id = RemovableHandle.next_id + RemovableHandle.next_id += 1 + + self.extra_dict_ref: tuple = () + if isinstance(extra_dict, dict): + self.extra_dict_ref = (weakref.ref(extra_dict),) + elif isinstance(extra_dict, list): + self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict) + + def remove(self) -> None: + hooks_dict = self.hooks_dict_ref() + if hooks_dict is not None and self.id in hooks_dict: + del hooks_dict[self.id] + + for ref in self.extra_dict_ref: + extra_dict = ref() + if extra_dict is not None and self.id in extra_dict: + del extra_dict[self.id] + + def __getstate__(self): + if self.extra_dict_ref is None: + return (self.hooks_dict_ref(), self.id) + else: + return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref)) + + def __setstate__(self, state) -> None: + if state[0] is None: + # create a dead reference + self.hooks_dict_ref = weakref.ref(OrderedDict()) + else: + self.hooks_dict_ref = weakref.ref(state[0]) + self.id = state[1] + RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1) + + if len(state) < 3 or state[2] is None: + self.extra_dict_ref = () + else: + self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2]) + + def __enter__(self) -> "RemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() + + +def unserializable_hook(f): + """ + Mark a function as an unserializable hook with this decorator. + + This suppresses warnings that would otherwise arise if you attempt + to serialize a tensor that has a hook. + """ + f.__torch_unserializable__ = True + return f + + +def warn_if_has_hooks(tensor): + if tensor._backward_hooks: + for k in tensor._backward_hooks: + hook = tensor._backward_hooks[k] + if not hasattr(hook, "__torch_unserializable__"): + warnings.warn(f"backward hook {repr(hook)} on tensor will not be " + "serialized. If this is expected, you can " + "decorate the function with @torch.utils.hooks.unserializable_hook " + "to suppress this warning") + +class BackwardHook: + """ + A wrapper class to implement nn.Module backward hooks. + + It handles: + - Ignoring non-Tensor inputs and replacing them by None before calling the user hook + - Generating the proper Node to capture a set of Tensor's gradients + - Linking the gradients captures for the outputs with the gradients captured for the input + - Calling the user hook once both output and input gradients are available + """ + + def __init__(self, module, user_hooks, user_pre_hooks): + self.user_hooks = user_hooks + self.user_pre_hooks = user_pre_hooks + self.module = module + + self.grad_outputs = None + self.n_outputs = -1 + self.output_tensors_index = None + self.n_inputs = -1 + self.input_tensors_index = None + + def _pack_with_none(self, indices, values, size): + res = [None] * size + for idx, val in zip(indices, values): + res[idx] = val + + return tuple(res) + + def _unpack_none(self, indices, values): + res = [values[idx] for idx in indices] + + return tuple(res) + + def _set_user_hook(self, grad_fn): + def hook(grad_input, _): + if self.grad_outputs is None: + # This happens because the gradient in your nn.Module flows to + # the Module's input without " passing through the Module's + # output, e.g. when you're doing double backward. 
+ return + res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs) + + for hook in self.user_hooks: + out = hook(self.module, res, self.grad_outputs) + + if out is None: + continue + + if len(out) != len(res): + raise RuntimeError("Backward hook returned an invalid number of grad_input, " + f"got {len(out)}, but expected {len(res)}") + + res = out + + self.grad_outputs = None + + return self._unpack_none(self.input_tensors_index, res) + + grad_fn.register_hook(hook) + + def _apply_on_tensors(self, fn, args): + # Can be used to apply the given function to the tensors contained in the + # args. Will return updated args and the tensors indices + tensors_idx = [] + tensors = [] + + requires_grad = False + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor): + tensors_idx.append(i) + tensors.append(arg) + requires_grad |= arg.requires_grad + + if not (requires_grad and torch.is_grad_enabled()): + return args, None + + new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors) + if len(new_tensors) == 0: + raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.") + + grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"] + if len(grad_fns) == 0: + raise RuntimeError("Error while setting up backward hooks. Please open " + "an issue with a code sample to reproduce this.") + + fn(grad_fns[0]) + + arg_list = list(args) + for idx, val in zip(tensors_idx, new_tensors): + arg_list[idx] = val + + if type(args) is tuple: + out = tuple(arg_list) + else: + out = type(args)(*arg_list) + return out, tensors_idx + + def setup_input_hook(self, args): + def fn(grad_fn): + self._set_user_hook(grad_fn) + + res, input_idx = self._apply_on_tensors(fn, args) + self.n_inputs = len(args) + self.input_tensors_index = input_idx + return res + + def setup_output_hook(self, args): + def fn(grad_fn): + def hook(_, grad_output): + self.grad_outputs = self._pack_with_none(self.output_tensors_index, + grad_output, + self.n_outputs) + + if self.user_pre_hooks: + expected_len = len(self.grad_outputs) + for user_pre_hook in self.user_pre_hooks: + hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs) + if hook_grad_outputs is None: + continue + + actual_len = len(hook_grad_outputs) + if actual_len != expected_len: + raise RuntimeError("Backward pre hook returned an invalid number of grad_output, " + f"got {actual_len}, but expected {expected_len}") + self.grad_outputs = hook_grad_outputs + + # We need to be able to clear self.grad_outputs but also return it + local_grad_outputs = self.grad_outputs + + # Special case if no input required gradients, this hook should call the user + # hook directly + if self.input_tensors_index is None: + warnings.warn("Full backward hook is firing when gradients are computed " + "with respect to module outputs since no inputs require gradients. 
See " + "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " # noqa: B950 + "for more details.", + stacklevel=5) + grad_inputs = self._pack_with_none([], [], self.n_inputs) + for user_hook in self.user_hooks: + res = user_hook(self.module, grad_inputs, self.grad_outputs) + if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)): + raise RuntimeError("Backward hook for Modules where no input requires " + "gradient should always return None or None for all gradients.") + self.grad_outputs = None + + if local_grad_outputs is not None: + assert self.output_tensors_index is not None # mypy + return tuple(local_grad_outputs[i] for i in self.output_tensors_index) + + grad_fn.register_hook(hook) + + is_tuple = True + if not isinstance(args, tuple): + args = (args,) + is_tuple = False + + res, output_idx = self._apply_on_tensors(fn, args) + self.n_outputs = len(args) + self.output_tensors_index = output_idx + + if not is_tuple: + res = res[0] + return res diff --git a/phivenv/Lib/site-packages/torch/utils/jit/__init__.py b/phivenv/Lib/site-packages/torch/utils/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5b12b18bff478f2e86221ab9a28f7f42e796a59 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b32ad8d25573317f7e89fa272234ce8c07b32825 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/jit/__pycache__/log_extract.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/jit/log_extract.py b/phivenv/Lib/site-packages/torch/utils/jit/log_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bac39de8bc7515290f4c19d02c2335167b2a02 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/jit/log_extract.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager +from typing import Any, cast +import random +import torch +import time +from torch.utils.benchmark import Timer + +def extract_ir(filename: str) -> list[str]: + BEGIN = "" + END = "" + pfx = None + graphs = [] + with open(filename) as f: + split_strs = f.read().split(BEGIN) + for i, split_str in enumerate(split_strs): + if i == 0: + continue + end_loc = split_str.find(END) + if end_loc == -1: + continue + s = split_str[:end_loc] + pfx = split_strs[i - 1].splitlines()[-1] + lines = [x[len(pfx):] for x in s.splitlines(keepends=True)] + graphs.append(''.join(lines)) + + return graphs + + +def make_tensor_from_type(inp_type: torch._C.TensorType): + size = inp_type.sizes() + stride = inp_type.strides() + device = inp_type.device() + dtype = inp_type.dtype() + assert size is not None + assert stride is not None + assert device is not None + assert dtype is not None + return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype) + +def load_graph_and_inputs(ir: str) -> tuple[Any, list[Any]]: + graph = torch._C.parse_ir(ir, 
parse_tensor_constants=True) + graph.makeMultiOutputIntoTuple() + inputs = [] + for inp in graph.inputs(): + if isinstance(inp.type(), torch._C.FloatType): + inputs.append(random.uniform(.1, 100)) + elif isinstance(inp.type(), torch._C.IntType): + inputs.append(random.randint(1, 100)) + elif isinstance(inp.type(), torch._C.TensorType): + tensorType = cast(torch._C.TensorType, inp.type()) + inputs.append(make_tensor_from_type(tensorType)) + elif isinstance(inp.type(), torch._C.BoolType): + inputs.append(random.randint(0, 1) == 1) + else: + raise NotImplementedError(f"A default value is not implemented for type {inp.type()}") + + func = torch._C._create_function_from_graph("forward", graph) + torch._C._jit_pass_erase_shape_information(func.graph) + return (func, inputs) + +def time_cuda(fn, inputs, test_runs): + t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs" : inputs}) + times = t.blocked_autorange() + return times.median * 1000 # time in ms + +def time_cpu(fn, inputs, test_runs): + s = time.perf_counter() + for _ in range(test_runs): + fn(*inputs) + e = time.perf_counter() + return (e - s) / test_runs * 1000 # time in ms + +def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float: + graph, _ = load_graph_and_inputs(ir) + for _ in range(warmup_runs): + graph(*inputs) + + is_cpu = None + for input in inputs: + if isinstance(input, torch.Tensor): + is_cpu = input.device.type == "cpu" + break + assert is_cpu is not None + + out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs) + return out + +@contextmanager +def no_fuser(*args, **kwargs): + old_optimize = torch._C._get_graph_executor_optimize(False) + try: + yield + finally: + torch._C._get_graph_executor_optimize(old_optimize) + +def run_baseline_no_fusion(ir, inputs) -> float: + with no_fuser(): + return run_test(ir, inputs) + + +def run_nnc(ir, inputs, dynamic) -> float: + try: + strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)] + old_strat = torch.jit.set_fusion_strategy(strat) + with torch.jit.fuser("fuser1"): + return run_test(ir, inputs) + finally: + torch.jit.set_fusion_strategy(old_strat) + +def run_nvfuser(ir, inputs) -> float: + with torch.jit.fuser("fuser2"): + return run_test(ir, inputs) diff --git a/phivenv/Lib/site-packages/torch/utils/mkldnn.py b/phivenv/Lib/site-packages/torch/utils/mkldnn.py new file mode 100644 index 0000000000000000000000000000000000000000..389bc2eeeb93f12140ebca278e4d24563165fb2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/mkldnn.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs +import torch + + +class MkldnnLinear(torch.jit.ScriptModule): + def __init__(self, dense_module, dtype): + super().__init__() + self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) + if dense_module.bias is not None: + # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy, + # we use fp32 dtype. 
+ self.register_buffer('bias', dense_module.bias.to_mkldnn()) + else: + # TODO: Remove this once ScriptModule supports registering None buffer + self.register_buffer( + 'bias', + torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn()) + + @torch.jit.script_method + def __getstate__(self): + return (self.weight.to_dense(), self.bias.to_dense(), self.training) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = state[0].to_mkldnn() + self.bias = state[1].to_mkldnn() + self.training = state[2] + + @torch.jit.script_method + def forward(self, x): + x_mkldnn = x if x.is_mkldnn else x.to_mkldnn() + y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias) + y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense() + return y + + +class _MkldnnConvNd(torch.jit.ScriptModule): + """Common base of MkldnnConv1d and MkldnnConv2d.""" + + __constants__ = ['stride', 'padding', 'dilation', 'groups'] + + def __init__(self, dense_module): + super().__init__() + + self.stride = dense_module.stride + self.padding = dense_module.padding + self.dilation = dense_module.dilation + self.groups = dense_module.groups + + if dense_module.bias is not None: + self.register_buffer('bias', dense_module.bias.to_mkldnn()) + else: + # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy, + # we use fp32 dtype. + # TODO: Remove this once ScriptModule supports registering None buffer + self.register_buffer( + 'bias', + torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn()) + + @torch.jit.script_method + def __getstate__(self): + return (self.weight.to_dense(), self.bias.to_dense(), self.training) + + @torch.jit.script_method + def forward(self, x): + return torch.mkldnn_convolution( + x, + self.weight, + self.bias, + self.padding, + self.stride, + self.dilation, + self.groups) + + +class MkldnnConv1d(_MkldnnConvNd): + def __init__(self, dense_module, dtype): + super().__init__(dense_module) + + self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = state[0].to_mkldnn() + self.bias = state[1].to_mkldnn() + self.training = state[2] + + +class MkldnnConv2d(_MkldnnConvNd): + def __init__(self, dense_module, dtype): + super().__init__(dense_module) + + self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight( + dense_module.weight.to_mkldnn(dtype), + self.padding, + self.stride, + self.dilation, + self.groups)) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight( + state[0].to_mkldnn(), + self.padding, + self.stride, + self.dilation, + self.groups) + self.bias = state[1].to_mkldnn() + self.training = state[2] + +class MkldnnConv3d(_MkldnnConvNd): + def __init__(self, dense_module, dtype): + super().__init__(dense_module) + + self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight( + dense_module.weight.to_mkldnn(dtype), + self.padding, + self.stride, + self.dilation, + self.groups)) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight( + state[0].to_mkldnn(), + self.padding, + self.stride, + self.dilation, + self.groups) + self.bias = state[1].to_mkldnn() + self.training = state[2] + + +class MkldnnBatchNorm(torch.jit.ScriptModule): + __constants__ = ['exponential_average_factor', 'eps'] + + def __init__(self, dense_module): + super().__init__() + + assert not dense_module.training + 
assert dense_module.track_running_stats + assert dense_module.affine + + if dense_module.momentum is None: + self.exponential_average_factor = 0.0 + else: + self.exponential_average_factor = dense_module.momentum + self.eps = dense_module.eps + + self.register_buffer('weight', dense_module.weight.to_mkldnn()) + self.register_buffer('bias', dense_module.bias.to_mkldnn()) + self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn()) + self.register_buffer('running_var', dense_module.running_var.to_mkldnn()) + + @torch.jit.script_method + def __getstate__(self): + weight = self.weight.to_dense() + bias = self.bias.to_dense() + running_mean = self.running_mean.to_dense() + running_var = self.running_var.to_dense() + return (weight, bias, running_mean, running_var, self.training) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = state[0].to_mkldnn() + self.bias = state[1].to_mkldnn() + self.running_mean = state[2].to_mkldnn() + self.running_var = state[3].to_mkldnn() + self.training = state[4] + + @torch.jit.script_method + def forward(self, x): + return torch.batch_norm( + x, + self.weight, + self.bias, + self.running_mean, + self.running_var, + False, # training + self.exponential_average_factor, + self.eps, + False, # cuda_enabled + ) + +class MkldnnPrelu(torch.jit.ScriptModule): + def __init__(self, dense_module, dtype): + super().__init__() + self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) + + @torch.jit.script_method + def __getstate__(self): + return (self.weight.to_dense(), self.training) + + @torch.jit.script_method + def __setstate__(self, state): + self.weight = state[0].to_mkldnn() + self.training = state[1] + + @torch.jit.script_method + def forward(self, x): + x_mkldnn = x if x.is_mkldnn else x.to_mkldnn() + y_mkldnn = torch.prelu(x_mkldnn, self.weight) + y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense() + return y + +def to_mkldnn(module, dtype=torch.float): + assert dtype in [torch.float, torch.bfloat16, torch.half], \ + "MKLDNN only support float, bfloat16, and half path now" + + def m_fn(m, d): + if isinstance(m, torch.nn.Linear): + return MkldnnLinear(m, d) + elif isinstance(m, torch.nn.Conv1d): + return MkldnnConv1d(m, d) + elif isinstance(m, torch.nn.Conv2d): + return MkldnnConv2d(m, d) + elif isinstance(m, torch.nn.Conv3d): + return MkldnnConv3d(m, d) + elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)): + # For batchnorm bf16 path, OneDNN requires weight and bias need fp32 dtype. + # so it doesn't need dtype argument. 
+ return MkldnnBatchNorm(m) + elif isinstance(m, torch.nn.PReLU): + return MkldnnPrelu(m, d) + else: + return m + + def m_fn_rec(m, d): + new_m = m_fn(m, d) + for name, sub_m in m.named_children(): + setattr(new_m, name, m_fn_rec(sub_m, d)) + return new_m + + return m_fn_rec(module, dtype) diff --git a/phivenv/Lib/site-packages/torch/utils/mobile_optimizer.py b/phivenv/Lib/site-packages/torch/utils/mobile_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca480d858980770a29a52afa727503138491793 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/mobile_optimizer.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +"""This module contains utility method for mobile model optimization and lint.""" + +import torch +from enum import Enum +from torch._C import _MobileOptimizerType as MobileOptimizerType +from typing import Optional, AnyStr + +class LintCode(Enum): + BUNDLED_INPUT = 1 + REQUIRES_GRAD = 2 + DROPOUT = 3 + BATCHNORM = 4 + +def optimize_for_mobile( + script_module: torch.jit.ScriptModule, + optimization_blocklist: Optional[set[MobileOptimizerType]] = None, + preserved_methods: Optional[list[AnyStr]] = None, + backend: str = 'CPU') -> torch.jit.RecursiveScriptModule: + """ + Optimize a torch script module for mobile deployment. + + Args: + script_module: An instance of torch script module with type of ScriptModule. + optimization_blocklist: A set with type of MobileOptimizerType. When set is not passed, + optimization method will run all the optimizer pass; otherwise, optimizer + method will run the optimization pass that is not included inside optimization_blocklist. + preserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked + backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal'). + Returns: + A new optimized torch script module + """ + if not isinstance(script_module, torch.jit.ScriptModule): + raise TypeError( + f'Got {type(script_module)}, but ScriptModule is expected.') + + if optimization_blocklist is None: + optimization_blocklist = set() + + if preserved_methods is None: + preserved_methods = [] + + # Convert potential byte arrays into strings (if there is any) to pass type checking + # Here we use a new name as assigning it back to preserved_methods will invoke + # mypy errors (i.e. 
List[AnyStr] = List[str]) + preserved_methods_str: list[str] = [str(method) for method in preserved_methods] + + bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str) + if all(hasattr(script_module, method) for method in bundled_inputs_attributes): + preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_attributes)) + + non_exist_methods = [method for method in preserved_methods_str if not hasattr(script_module, method)] + if non_exist_methods: + raise AttributeError( + f"The following methods to preserve do not exist in script_module: {', '.join(non_exist_methods)}") + + backend = backend.lower() + if backend == 'cpu': + optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile( + script_module._c, + optimization_blocklist, + preserved_methods_str) + elif backend == 'vulkan': + optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile( + script_module._c, + optimization_blocklist, + preserved_methods_str) + elif backend == 'metal': + optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str) + else: + raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'") + + return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module) + + +def generate_mobile_module_lints(script_module: torch.jit.ScriptModule): + """ + Generate a list of lints for a given torch script module. + + Args: + script_module: An instance of torch script module with type of ScriptModule. + + Returns: + lint_map: A list of dictionary that contains modules lints + """ + if not isinstance(script_module, torch.jit.ScriptModule): + raise TypeError( + f'Got {type(script_module)}, but ScriptModule is expected.') + + lint_list = [] + + if not hasattr(script_module, "_generate_bundled_inputs_for_forward"): + lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input for forward, please add bundled inputs " + "before saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."}) + + for name, param in script_module.named_parameters(): + if param.requires_grad: + lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": f"Param {name} requires grad, " + "please set torch.no_grad() to reduce memory usage and improve computation speed during " + "inference phase."}) + + op_names = torch.jit.export_opnames(script_module) + for op_name in op_names: + if "dropout" in op_name: + lint_list.append({"name": LintCode.DROPOUT.name, + "message": f"Operator {op_name} exists, remember to call eval() before " + "saving the module.and call torch.utils.mobile_optimizer.optimize_for_mobile to drop dropout " + "operator."}) + if "batch_norm" in op_name: + lint_list.append({"name": LintCode.BATCHNORM.name, + "message": f"Operator {op_name} exists, remember to call eval() before " + "saving the module and call torch.utils.mobile_optimizer.optimize_for_mobile to drop batch_norm " + "operator."}) + + return lint_list + +def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: list[str]) -> list[str]: + + bundled_inputs_attributes = [] + # Has bundled inputs for forward + if hasattr(script_module, 'get_all_bundled_inputs'): + bundled_inputs_attributes.append('get_all_bundled_inputs') + bundled_inputs_attributes.append('get_num_bundled_inputs') + + # Bundled inputs in module after the change that introduced bundled inputs for multiple functions + if hasattr(script_module, 
'get_bundled_inputs_functions_and_info'): + bundled_inputs_attributes.append('get_bundled_inputs_functions_and_info') + all_info = script_module.get_bundled_inputs_functions_and_info() + for function_name in all_info: + if function_name not in preserved_methods: + bundled_inputs_attributes.append(function_name) + bundled_inputs_attributes.append("get_all_bundled_inputs_for_" + function_name) + bundled_inputs_attributes.append("_bundled_inputs_deflated_" + function_name) + + return bundled_inputs_attributes diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/__init__.py b/phivenv/Lib/site-packages/torch/utils/model_dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55cbe0606bc5954ea4a21ff31fa893c0d3b1ebf9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_dump/__init__.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" +model_dump: a one-stop shop for TorchScript model inspection. + +The goal of this tool is to provide a simple way to extract lots of +useful information from a TorchScript model and make it easy for humans +to consume. It (mostly) replaces zipinfo, common uses of show_pickle, +and various ad-hoc analysis notebooks. + +The tool extracts information from the model and serializes it as JSON. +That JSON can then be rendered by an HTML+JS page, either by +loading the JSON over HTTP or producing a fully self-contained page +with all of the code and data burned-in. +""" + +# Maintainer notes follow. +""" +The implementation strategy has tension between 3 goals: +- Small file size. +- Fully self-contained. +- Easy, modern JS environment. +Using Preact and HTM achieves 1 and 2 with a decent result for 3. +However, the models I tested with result in ~1MB JSON output, +so even using something heavier like full React might be tolerable +if the build process can be worked out. + +One principle I have followed that I think is very beneficial +is to keep the JSON data as close as possible to the model +and do most of the rendering logic on the client. +This makes for easier development (just refresh, usually), +allows for more laziness and dynamism, and lets us add more +views of the same data without bloating the HTML file. + +Currently, this code doesn't actually load the model or even +depend on any part of PyTorch. I don't know if that's an important +feature to maintain, but it's probably worth preserving the ability +to run at least basic analysis on models that cannot be loaded. + +I think the easiest way to develop this code is to cd into model_dump and +run "python -m http.server", then load http://localhost:8000/skeleton.html +in the browser. In another terminal, run +"python -m torch.utils.model_dump --style=json FILE > \ + torch/utils/model_dump/model_info.json" +every time you update the Python code or model. +When you update JS, just refresh. + +Possible improvements: + - Fix various TODO comments in this file and the JS. + - Make the HTML much less janky, especially the auxiliary data panel. + - Make the auxiliary data panel start small, expand when + data is available, and have a button to clear/contract. + - Clean up the JS. There's a lot of copypasta because + I don't really know how to use Preact. + - Make the HTML render and work nicely inside a Jupyter notebook. + - Add the ability for JS to choose the URL to load the JSON based + on the page URL (query or hash). That way we could publish the + inlined skeleton once and have it load various JSON blobs. 
+ - Add a button to expand all expandable sections so ctrl-F works well. + - Add hyperlinking from data to code, and code to code. + - Add hyperlinking from debug info to Diffusion. + - Make small tensor contents available. + - Do something nice for quantized models + (they probably don't work at all right now). +""" + +import argparse +import io +import json +import os +import pickle +import pprint +import re +import sys +import urllib.parse +import zipfile +from pathlib import Path +import warnings + +import torch.utils.show_pickle + + +DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024 + +__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton', + 'burn_in_info', 'get_info_and_burn_skeleton'] + +def get_storage_info(storage): + assert isinstance(storage, torch.utils.show_pickle.FakeObject) + assert storage.module == "pers" + assert storage.name == "obj" + assert storage.state is None + assert isinstance(storage.args, tuple) + assert len(storage.args) == 1 + sa = storage.args[0] + assert isinstance(sa, tuple) + assert len(sa) == 5 + assert sa[0] == "storage" + assert isinstance(sa[1], torch.utils.show_pickle.FakeClass) + assert sa[1].module == "torch" + assert sa[1].name.endswith("Storage") + storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:]) + return storage_info + + +def hierarchical_pickle(data): + if isinstance(data, (bool, int, float, str, type(None))): + return data + if isinstance(data, list): + return [hierarchical_pickle(d) for d in data] + if isinstance(data, tuple): + return { + "__tuple_values__": hierarchical_pickle(list(data)), + } + if isinstance(data, dict): + return { + "__is_dict__": True, + "keys": hierarchical_pickle(list(data.keys())), + "values": hierarchical_pickle(list(data.values())), + } + if isinstance(data, torch.utils.show_pickle.FakeObject): + typename = f"{data.module}.{data.name}" + if ( + typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.')) + ): + assert data.args == () + return { + "__module_type__": typename, + "state": hierarchical_pickle(data.state), + } + if typename == "torch._utils._rebuild_tensor_v2": + assert data.state is None + storage, offset, size, stride, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]} + if typename == "torch._utils._rebuild_qtensor": + assert data.state is None + storage, offset, size, stride, quantizer, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + assert isinstance(quantizer, tuple) + assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass) + assert quantizer[0].module == "torch" + if quantizer[0].name == "per_tensor_affine": + assert len(quantizer) == 3 + assert isinstance(quantizer[1], float) + assert isinstance(quantizer[2], int) + quantizer_extra = list(quantizer[1:3]) + else: + quantizer_extra = [] + quantizer_json = [quantizer[0].name] + quantizer_extra + return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]} + if typename == "torch.jit._pickle.restore_type_tag": + assert data.state is None + obj, typ = data.args + assert isinstance(typ, str) + return hierarchical_pickle(obj) + if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename): + assert data.state is None + ls, = data.args + assert isinstance(ls, list) + return hierarchical_pickle(ls) + if typename == "torch.device": + assert data.state is None + name, = data.args + assert isinstance(name, 
str) + # Just forget that it was a device and return the name. + return name + if typename == "builtin.UnicodeDecodeError": + assert data.state is None + msg, = data.args + assert isinstance(msg, str) + # Hack: Pretend this is a module so we don't need custom serialization. + # Hack: Wrap the message in a tuple so it looks like a nice state object. + # TODO: Undo at least that second hack. We should support string states. + return { + "__module_type__": typename, + "state": hierarchical_pickle((msg,)), + } + raise Exception(f"Can't prepare fake object of type for JS: {typename}") # noqa: TRY002 + raise Exception(f"Can't prepare data of type for JS: {type(data)}") # noqa: TRY002 + + +def get_model_info( + path_or_file, + title=None, + extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT): + """Get JSON-friendly information about a model. + + The result is suitable for being saved as model_info.json, + or passed to burn_in_info. + """ + + if isinstance(path_or_file, os.PathLike): + default_title = os.fspath(path_or_file) + file_size = path_or_file.stat().st_size # type: ignore[attr-defined] + elif isinstance(path_or_file, str): + default_title = path_or_file + file_size = Path(path_or_file).stat().st_size + else: + default_title = "buffer" + path_or_file.seek(0, io.SEEK_END) + file_size = path_or_file.tell() + path_or_file.seek(0) + + title = title or default_title + + with zipfile.ZipFile(path_or_file) as zf: + path_prefix = None + zip_files = [] + for zi in zf.infolist(): + prefix = re.sub("/.*", "", zi.filename) + if path_prefix is None: + path_prefix = prefix + elif prefix != path_prefix: + raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002 + zip_files.append(dict( + filename=zi.filename, + compression=zi.compress_type, + compressed_size=zi.compress_size, + file_size=zi.file_size, + )) + + assert path_prefix is not None + version = zf.read(path_prefix + "/version").decode("utf-8").strip() + + def get_pickle(name): + assert path_prefix is not None + with zf.open(path_prefix + f"/{name}.pkl") as handle: + raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + return hierarchical_pickle(raw) + + model_data = get_pickle("data") + constants = get_pickle("constants") + + # Intern strings that are likely to be re-used. + # Pickle automatically detects shared structure, + # so re-used strings are stored efficiently. + # However, JSON has no way of representing this, + # so we have to do it manually. + interned_strings : dict[str, int] = {} + + def ist(s): + if s not in interned_strings: + interned_strings[s] = len(interned_strings) + return interned_strings[s] + + code_files = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".py"): + continue + with zf.open(zi) as handle: + raw_code = handle.read() + with zf.open(zi.filename + ".debug_pkl") as handle: + raw_debug = handle.read() + + # Parse debug info and add begin/end markers if not present + # to ensure that we cover the entire source code. 
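+            # Each debug entry maps a byte offset in the extracted .py file back to a
+            # source range, roughly (offset, ((source_text, source_file, line), range_start, range_end), ...).
+            # In the newer 'FORMAT_WITH_STRING_TABLE' layout the strings are stored as
+            # indexes into a shared text_table and are expanded by parse_new_format below.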
+ debug_info_t = pickle.loads(raw_debug) + text_table = None + + if (len(debug_info_t) == 3 and + isinstance(debug_info_t[0], str) and + debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'): + _, text_table, content = debug_info_t + + def parse_new_format(line): + # (0, (('', '', 0), 0, 0)) + num, ((text_indexes, fname_idx, offset), start, end), tag = line + text = ''.join(text_table[x] for x in text_indexes) # type: ignore[index] + fname = text_table[fname_idx] # type: ignore[index] + return num, ((text, fname, offset), start, end), tag + + debug_info_t = map(parse_new_format, content) + + debug_info = list(debug_info_t) + if not debug_info: + debug_info.append((0, (('', '', 0), 0, 0))) + if debug_info[-1][0] != len(raw_code): + debug_info.append((len(raw_code), (('', '', 0), 0, 0))) + + code_parts = [] + for di, di_next in zip(debug_info, debug_info[1:]): + start, source_range, *_ = di + end = di_next[0] + assert end > start + source, s_start, s_end = source_range + s_text, s_file, s_line = source + # TODO: Handle this case better. TorchScript ranges are in bytes, + # but JS doesn't really handle byte strings. + # if bytes and chars are not equivalent for this string, + # zero out the ranges so we don't highlight the wrong thing. + if len(s_text) != len(s_text.encode("utf-8")): + s_start = 0 + s_end = 0 + text = raw_code[start:end] + code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end]) + code_files[zi.filename] = code_parts + + extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json") + extra_files_jsons = {} + for zi in zf.infolist(): + if not extra_files_json_pattern.fullmatch(zi.filename): + continue + if zi.file_size > extra_file_size_limit: + continue + with zf.open(zi) as handle: + try: + json_content = json.load(handle) + extra_files_jsons[zi.filename] = json_content + except json.JSONDecodeError: + extra_files_jsons[zi.filename] = "INVALID JSON" + + always_render_pickles = { + "bytecode.pkl", + } + extra_pickles = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".pkl"): + continue + with zf.open(zi) as handle: + # TODO: handle errors here and just ignore the file? + # NOTE: For a lot of these files (like bytecode), + # we could get away with just unpickling, but this should be safer. + obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + buf = io.StringIO() + pprint.pprint(obj, buf) + contents = buf.getvalue() + # Checked the rendered length instead of the file size + # because pickles with shared structure can explode in size during rendering. + if os.path.basename(zi.filename) not in always_render_pickles and \ + len(contents) > extra_file_size_limit: + continue + extra_pickles[zi.filename] = contents + + return {"model": dict( + title=title, + file_size=file_size, + version=version, + zip_files=zip_files, + interned_strings=list(interned_strings), + code_files=code_files, + model_data=model_data, + constants=constants, + extra_files_jsons=extra_files_jsons, + extra_pickles=extra_pickles, + )} + + +def get_inline_skeleton(): + """Get a fully-inlined skeleton of the frontend. + + The returned HTML page has no external network dependencies for code. + It can load model_info.json over HTTP, or be passed to burn_in_info. 
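+
+    A minimal sketch of the typical flow (the model path is illustrative;
+    get_info_and_burn_skeleton below wraps exactly these calls):
+
+        skeleton = get_inline_skeleton()
+        info = get_model_info("model.pt")
+        page = burn_in_info(skeleton, info)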
+ """ + + import importlib.resources + + skeleton = importlib.resources.read_text(__package__, "skeleton.html") + js_code = importlib.resources.read_text(__package__, "code.js") + for js_module in ["preact", "htm"]: + js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") + js_url = "data:application/javascript," + urllib.parse.quote(js_lib) + js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url) + skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code) + return skeleton + + +def burn_in_info(skeleton, info): + """Burn model info into the HTML skeleton. + + The result will render the hard-coded model info and + have no external network dependencies for code or data. + """ + + # Note that Python's json serializer does not escape slashes in strings. + # Since we're inlining this JSON directly into a script tag, a string + # containing "" would end the script prematurely and + # mess up our page. Unconditionally escape fixes that. + return skeleton.replace( + "BURNED_IN_MODEL_INFO = null", + "BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/")) + + +def get_info_and_burn_skeleton(path_or_bytesio, **kwargs): + model_info = get_model_info(path_or_bytesio, **kwargs) + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, model_info) + return page + + +def main(argv, *, stdout=None): + warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.") + parser = argparse.ArgumentParser() + parser.add_argument("--style", choices=["json", "html"]) + parser.add_argument("--title") + parser.add_argument("model") + args = parser.parse_args(argv[1:]) + + info = get_model_info(args.model, title=args.title) + + output = stdout or sys.stdout + + if args.style == "json": + output.write(json.dumps(info, sort_keys=True) + "\n") + elif args.style == "html": + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, info) + output.write(page) + else: + raise Exception("Invalid style") # noqa: TRY002 diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/__main__.py b/phivenv/Lib/site-packages/torch/utils/model_dump/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..07956f5654ef85908fe5207fdcbd8bcde6166676 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_dump/__main__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +import sys +from . 
import main + +sys.exit(main(sys.argv)) diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f85ae8fcef3ec093b3c56f121b34371e6282cfd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d365502b6a1f473cf595de4d6c5e2d72b5a7bc8f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/model_dump/__pycache__/__main__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/code.js b/phivenv/Lib/site-packages/torch/utils/model_dump/code.js new file mode 100644 index 0000000000000000000000000000000000000000..d0eab53d72fbc382ef06cfceb420765c997b6f1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_dump/code.js @@ -0,0 +1,689 @@ +import { h, Component, render } from 'https://unpkg.com/preact?module'; +import htm from 'https://unpkg.com/htm?module'; + +const html = htm.bind(h); + +const BURNED_IN_MODEL_INFO = null; + +// https://stackoverflow.com/a/20732091 +function humanFileSize(size) { + if (size == 0) { return "0 B"; } + var i = Math.floor( Math.log(size) / Math.log(1024) ); + return (size / Math.pow(1024, i)).toFixed(2) * 1 + ' ' + ['B', 'kB', 'MB', 'GB', 'TB'][i]; +} + +function caret(down) { + return down ? "\u25BE" : "\u25B8"; +} + +class Blamer { + constructor() { + this.blame_on_click = false; + this.aux_content_pane = null; + } + + setAuxContentPane(pane) { + this.aux_content_pane = pane; + } + + readyBlame() { + this.blame_on_click = true; + } + + maybeBlame(arg) { + if (!this.blame_on_click) { + return; + } + this.blame_on_click = false; + if (!this.aux_content_pane) { + return; + } + this.aux_content_pane.doBlame(arg); + } +} + +let blame = new Blamer(); + +class Hider extends Component { + constructor() { + super(); + this.state = { shown: null }; + } + + componentDidMount() { + this.setState({ shown: this.props.shown === "true" }); + } + + render({name, children}, {shown}) { + let my_caret = html` this.click()} >${caret(shown)}`; + return html`
+

${my_caret} ${name}

+
${shown ? this.props.children : []}
`; + } + + click() { + this.setState({shown: !this.state.shown}); + } +} + +function ModelSizeSection({model: {file_size, zip_files}}) { + let store_size = 0; + let compr_size = 0; + for (const zi of zip_files) { + if (zi.compression === 0) { + // TODO: Maybe check that compressed_size === file_size. + store_size += zi.compressed_size; + } else { + compr_size += zi.compressed_size; + } + } + let zip_overhead = file_size - store_size - compr_size; + // TODO: Better formatting. Right-align this. + return html` + <${Hider} name="Model Size" shown=true> +
.
+      Model size: ${file_size} (${humanFileSize(file_size)})
+      Stored files: ${store_size} (${humanFileSize(store_size)})
+      Compressed files: ${compr_size} (${humanFileSize(compr_size)})
+      Zip overhead: ${zip_overhead} (${humanFileSize(zip_overhead)})
+    
`; +} + +function StructuredDataSection({name, data, shown}) { + return html` + <${Hider} name=${name} shown=${shown}> +
+ <${StructuredData} data=${data} indent="" prefix=""/> +
`; +} + +class StructuredData extends Component { + constructor() { + super(); + this.state = { shown: false }; + + this.INLINE_TYPES = new Set(["boolean", "number", "string"]) + this.IGNORED_STATE_KEYS = new Set(["training", "_is_full_backward_hook"]) + } + + click() { + this.setState({shown: !this.state.shown}); + } + + expando(data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + return false; + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__tuple_values__) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__is_dict__) { + // TODO: Maybe show simple (empty?) dicts on one line. + return true; + } + if (data.__module_type__) { + return true; + } + if (data.__tensor_v2__) { + return false; + } + if (data.__qtensor__) { + return false; + } + throw new Error("Can't handle data type.", data); + } + + renderHeadline(data) { + if (data === null) { + return "None"; + } + if (typeof(data) == "boolean") { + const sd = String(data); + return sd.charAt(0).toUpperCase() + sd.slice(1); + } + if (typeof(data) == "number") { + return JSON.stringify(data); + } + if (typeof(data) == "string") { + return JSON.stringify(data); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + return "list(["; + } + if (data.__tuple_values__) { + return "tuple(("; + } + if (data.__is_dict__) { + return "dict({"; + } + if (data.__module_type__) { + return data.__module_type__ + "()"; + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return this.renderTensor( + "tensor", dtype, key, device, numel, offset, size, stride, grad, []); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + let extra_parts = []; + if (quantizer[0] == "per_tensor_affine") { + extra_parts.push(`scale=${quantizer[1]}`); + extra_parts.push(`zero_point=${quantizer[2]}`); + } else { + extra_parts.push(`quantizer=${quantizer[0]}`); + } + return this.renderTensor( + "qtensor", dtype, key, device, numel, offset, size, stride, grad, extra_parts); + } + throw new Error("Can't handle data type.", data); + } + + renderTensor( + prefix, + dtype, + storage_key, + device, + storage_numel, + offset, + size, + stride, + grad, + extra_parts) { + let parts = [ + "(" + size.join(",") + ")", + dtype, + ]; + parts.push(...extra_parts); + if (device != "cpu") { + parts.push(device); + } + if (grad) { + parts.push("grad"); + } + // TODO: Check stride and indicate if the tensor is channels-last or non-contiguous + // TODO: Check size, stride, offset, and numel and indicate if + // the tensor doesn't use all data in storage. + // TODO: Maybe show key? + void(offset); + void(stride); + void(storage_key); + void(storage_numel); + return prefix + "(" + parts.join(", ") + ")"; + } + + renderBody(indent, data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + throw "Should not reach here." + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.length; idx++) { + // Does it make sense to put explicit index numbers here? + parts.push(html`
<${StructuredData} prefix=${idx + ": "} indent=${new_indent} data=${data[idx]} />`); + } + return parts; + } + if (data.__tuple_values__) { + // Handled the same as lists. + return this.renderBody(indent, data.__tuple_values__); + } + if (data.__is_dict__) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.keys.length; idx++) { + if (typeof(data.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else { + parts.push(html`
<${StructuredData} prefix=${data.keys[idx] + ": "} indent=${new_indent} data=${data.values[idx]} />`); + } + } + return parts; + } + if (data.__module_type__) { + const mstate = data.state; + if (mstate === null || typeof(mstate) != "object") { + throw new Error("Bad module state"); + } + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + if (mstate.__is_dict__) { + // TODO: Less copy/paste between this and normal dicts. + for (let idx = 0; idx < mstate.keys.length; idx++) { + if (typeof(mstate.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else if (this.IGNORED_STATE_KEYS.has(mstate.keys[idx])) { + // Do nothing. + } else { + parts.push(html`
<${StructuredData} prefix=${mstate.keys[idx] + ": "} indent=${new_indent} data=${mstate.values[idx]} />`); + } + } + } else if (mstate.__tuple_values__) { + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else if (mstate.__module_type__) { + // We normally wouldn't have the state of a module be another module, + // but we use "modules" to encode special values (like Unicode decode + // errors) that might be valid states. Just go with it. + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else { + throw new Error("Bad module state"); + } + return parts; + } + if (data.__tensor_v2__) { + throw "Should not reach here." + } + if (data.__qtensor__) { + throw "Should not reach here." + } + throw new Error("Can't handle data type.", data); + } + + render({data, indent, prefix}, {shown}) { + const exp = this.expando(data) ? html` this.click()} >${caret(shown)} ` : ""; + const headline = this.renderHeadline(data); + const body = shown ? this.renderBody(indent, data) : ""; + return html`${indent}${exp}${prefix}${headline}${body}`; + } +} + +function ZipContentsSection({model: {zip_files}}) { + // TODO: Add human-readable sizes? + // TODO: Add sorting options? + // TODO: Add hierarchical collapsible tree? + return html` + <${Hider} name="Zip Contents" shown=false> + + + + + + + + + + + ${zip_files.map(zf => html` + + + + + `)} + +
ModeSizeCompressedName
${{0: "store", 8: "deflate"}[zf.compression] || zf.compression}${zf.file_size}${zf.compressed_size}${zf.filename}
`; +} + +function CodeSection({model: {code_files}}) { + return html` + <${Hider} name="Code" shown=false> +
+ ${Object.entries(code_files).map(([fn, code]) => html`<${OneCodeSection} + filename=${fn} code=${code} />`)} +
`; +} + +class OneCodeSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, code}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${code.map(c => this.renderBlock(c))}
+ `; + } + + renderBlock([text, ist_file, line, ist_s_text, s_start, s_end]) { + return html` blame.maybeBlame({ist_file, line, ist_s_text, s_start, s_end})} + >${text}`; + } +} + +function ExtraJsonSection({files}) { + return html` + <${Hider} name="Extra files (JSON)" shown=false> +
+

Use "Log Raw Model Info" for hierarchical view in browser console.

+ ${Object.entries(files).map(([fn, json]) => html`<${OneJsonSection} + filename=${fn} json=${json} />`)} +
`; +} + +class OneJsonSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, json}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${JSON.stringify(json, null, 2)}
+ `; + } +} + +function ExtraPicklesSection({files}) { + return html` + <${Hider} name="Extra Pickles" shown=false> +
+ ${Object.entries(files).map(([fn, content]) => html`<${OnePickleSection} + filename=${fn} content=${content} />`)} +
`; +} + +class OnePickleSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, content}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${content}
+ `; + } +} + +function assertStorageAreEqual(key, lhs, rhs) { + if (lhs.length !== rhs.length || + !lhs.every((val, idx) => val === rhs[idx])) { + throw new Error("Storage mismatch for key '" + key + "'"); + } +} + +function computeTensorMemory(numel, dtype) { + const sizes = { + "Byte": 1, + "Char": 1, + "Short": 2, + "Int": 4, + "Long": 8, + "Half": 2, + "Float": 4, + "Double": 8, + "ComplexHalf": 4, + "ComplexFloat": 8, + "ComplexDouble": 16, + "Bool": 1, + "QInt8": 1, + "QUInt8": 1, + "QInt32": 4, + "BFloat16": 2, + }; + let dtsize = sizes[dtype]; + if (!dtsize) { + throw new Error("Unrecognized dtype: " + dtype); + } + return numel * dtsize; +} + +// TODO: Maybe track by dtype as well. +// TODO: Maybe distinguish between visible size and storage size. +function getTensorStorages(data) { + if (data === null) { + return new Map(); + } + if (typeof(data) == "boolean") { + return new Map(); + } + if (typeof(data) == "number") { + return new Map(); + } + if (typeof(data) == "string") { + return new Map(); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let result = new Map(); + for (const item of data) { + const tensors = getTensorStorages(item); + for (const [key, storage] of tensors.entries()) { + if (!result.has(key)) { + result.set(key, storage); + } else { + const old_storage = result.get(key); + assertStorageAreEqual(key, old_storage, storage); + } + } + } + return result; + } + if (data.__tuple_values__) { + return getTensorStorages(data.__tuple_values__); + } + if (data.__is_dict__) { + return getTensorStorages(data.values); + } + if (data.__module_type__) { + return getTensorStorages(data.state); + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + throw new Error("Can't handle data type.", data); +} + +function getTensorMemoryByDevice(pickles) { + let all_tensors = []; + for (const [name, pickle] of pickles) { + const tensors = getTensorStorages(pickle); + all_tensors.push(...tensors.values()); + } + let result = {}; + for (const storage of all_tensors.values()) { + const [dtype, key, device, numel] = storage; + const size = computeTensorMemory(numel, dtype); + result[device] = (result[device] || 0) + size; + } + return result; +} + +// Make this a separate component so it is rendered lazily. +class OpenTensorMemorySection extends Component { + render({model: {model_data, constants}}) { + let sizes = getTensorMemoryByDevice(new Map([ + ["data", model_data], + ["constants", constants], + ])); + return html` + + + + + + + + + + ${Object.entries(sizes).map(([dev, size]) => html` + + + + `)} + +
DeviceBytesHuman
${dev}${size}${humanFileSize(size)}
`; + } +} + +function TensorMemorySection({model}) { + return html` + <${Hider} name="Tensor Memory" shown=false> + <${OpenTensorMemorySection} model=${model} />`; +} + +class AuxContentPane extends Component { + constructor() { + super(); + this.state = { + blame_info: null, + }; + } + + doBlame(arg) { + this.setState({...this.state, blame_info: arg}); + } + + render({model: {interned_strings}}, {blame_info}) { + let blame_content = ""; + if (blame_info) { + const {ist_file, line, ist_s_text, s_start, s_end} = blame_info; + let s_text = interned_strings[ist_s_text]; + if (s_start != 0 || s_end != s_text.length) { + let prefix = s_text.slice(0, s_start); + let main = s_text.slice(s_start, s_end); + let suffix = s_text.slice(s_end); + s_text = html`${prefix}${main}${suffix}`; + } + blame_content = html` +

${interned_strings[ist_file]}:${line}

+
${s_start}:${s_end}
+
${s_text}

+ `; + } + return html` + +
+ ${blame_content} + `; + } +} + +class App extends Component { + constructor() { + super(); + this.state = { + err: false, + model: null, + }; + } + + componentDidMount() { + const app = this; + if (BURNED_IN_MODEL_INFO !== null) { + app.setState({model: BURNED_IN_MODEL_INFO}); + } else { + fetch("./model_info.json").then(function(response) { + if (!response.ok) { + throw new Error("Response not ok."); + } + return response.json(); + }).then(function(body) { + app.setState({model: body}); + }).catch(function(error) { + console.log("Top-level error: ", error); + }); + } + } + + componentDidCatch(error) { + void(error); + this.setState({...this.state, err: true}); + } + + render(_, {err}) { + if (this.state.model === null) { + return html`

Loading...

`; + } + + const model = this.state.model.model; + + let error_msg = ""; + if (err) { + error_msg = html`

An error occurred. Check console

`; + } + + return html` + ${error_msg} +
+

TorchScript Model (version ${model.version}): ${model.title}

+ + <${ModelSizeSection} model=${model}/> + <${StructuredDataSection} name="Model Data" data=${model.model_data} shown=true/> + <${StructuredDataSection} name="Constants" data=${model.constants} shown=false/> + <${ZipContentsSection} model=${model}/> + <${CodeSection} model=${model}/> + <${ExtraJsonSection} files=${model.extra_files_jsons}/> + <${ExtraPicklesSection} files=${model.extra_pickles}/> + <${TensorMemorySection} model=${model}/> +
+
+ <${AuxContentPane} + err=${this.state.error} + model=${model} + ref=${(p) => blame.setAuxContentPane(p)}/> +
+ `; + } +} + +render(h(App), document.body); diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/htm.mjs b/phivenv/Lib/site-packages/torch/utils/model_dump/htm.mjs new file mode 100644 index 0000000000000000000000000000000000000000..71d0b610fc29e5ad513d2235cdd80deb59a163a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_dump/htm.mjs @@ -0,0 +1,2 @@ +// HTM, Apache License +var n=function(t,s,r,e){var u;s[0]=0;for(var h=1;h=5&&((e||!n&&5===r)&&(h.push(r,0,e,s),r=6),n&&(h.push(r,n,0,s),r=6)),e=""},a=0;a"===t?(r=1,e=""):e=t+e[0]:u?t===u?u="":e+=t:'"'===t||"'"===t?u=t:">"===t?(p(),r=1):r&&("="===t?(r=5,s=e,e=""):"/"===t&&(r<5||">"===n[a][l+1])?(p(),3===r&&(h=h[0]),r=h,(h=h[0]).push(2,0,r),r=0):" "===t||"\t"===t||"\n"===t||"\r"===t?(p(),r=2):e+=t),3===r&&"!--"===e&&(r=4,h=h[0])}return p(),h}(s)),r),arguments,[])).length>1?r:r[0]} diff --git a/phivenv/Lib/site-packages/torch/utils/model_dump/preact.mjs b/phivenv/Lib/site-packages/torch/utils/model_dump/preact.mjs new file mode 100644 index 0000000000000000000000000000000000000000..bd32598c6cb5fd76eb61a2217010fc9abb4d3693 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_dump/preact.mjs @@ -0,0 +1,2 @@ +// Preact, MIT License +var n,l,u,i,t,o,r={},f=[],e=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i;function c(e,n){for(var t in n)e[t]=n[t];return e}function s(e){var n=e.parentNode;n&&n.removeChild(e)}function a(e,n,t){var _,l,o,r=arguments,i={};for(o in n)"key"==o?_=n[o]:"ref"==o?l=n[o]:i[o]=n[o];if(arguments.length>3)for(t=[t],o=3;o0?v(m.type,m.props,m.key,null,m.__v):m)){if(m.__=t,m.__b=t.__b+1,null===(h=P[p])||h&&m.key==h.key&&m.type===h.type)P[p]=void 0;else for(a=0;a3)for(t=[t],o=3;o + + + TorchScript Model + + + + + + + + diff --git a/phivenv/Lib/site-packages/torch/utils/model_zoo.py b/phivenv/Lib/site-packages/torch/utils/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..c109354498e6c1a55de73b3c2246c2a2cbb2114b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/model_zoo.py @@ -0,0 +1,2 @@ +# torchvision imports tqdm from here. +from torch.hub import tqdm, load_state_dict_from_url as load_url # noqa: F401 diff --git a/phivenv/Lib/site-packages/torch/utils/module_tracker.py b/phivenv/Lib/site-packages/torch/utils/module_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..590c97706a243f71b32ab7c3d88bd9dec9e6d82d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/module_tracker.py @@ -0,0 +1,159 @@ +# mypy: allow-untyped-defs +import logging +import weakref +from typing import TYPE_CHECKING + +import torch +from torch.autograd.graph import register_multi_grad_hook +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from torch.utils._pytree import tree_flatten + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +logger = logging.getLogger(__name__) + + +__all__ = ["ModuleTracker"] + + +class ModuleTracker: + """ + ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. 
+ + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = torch.nn.Linear(2, 2) + + with ModuleTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return torch.mm(m1, m2.t()) + bias + torch.nn.functional.linear = my_linear + + mod(torch.rand(2, 2)) + + """ + + parents: set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self) -> None: + self.parents = {"Global"} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._hooks: list[RemovableHandle] = [] + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + torch.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return torch._C._current_graph_task_id() != -1 + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents: + logger.info( + "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s", + name, + "backward" if is_bw else "forward", + ) + self.parents.add(name) + + return fn + + def _get_pop_fn(self, name, is_bw): + def fn(*args): + if name in self.parents: + self.parents.remove(name) + else: + logger.info( + "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. 
%s during %s", + name, + "backward" if is_bw else "forward", + ) + + return fn + + def _fw_pre_hook(self, mod, input): + name = self._get_mod_name(mod) + self._get_append_fn(name, False)() + + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if tensors: + self._hooks.append( + register_multi_grad_hook(tensors, self._get_pop_fn(name, True)) + ) + + def _fw_post_hook(self, mod, input, output): + name = self._get_mod_name(mod) + self._get_pop_fn(name, False)() + + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if tensors: + self._hooks.append( + register_multi_grad_hook(tensors, self._get_append_fn(name, True)) + ) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook(self._fw_post_hook) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() + for hook in self._hooks: + hook.remove() + self._hooks.clear() diff --git a/phivenv/Lib/site-packages/torch/utils/serialization/__init__.py b/phivenv/Lib/site-packages/torch/utils/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c9d7d19b523b1f63c01abc4c1687847959bebc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/serialization/__init__.py @@ -0,0 +1 @@ +from . import config diff --git a/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0656203fafb6ea57c49cdbf420516305315dcbc9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c4275df02908a8050df2f14e1b811d927a1a0c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/serialization/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/serialization/config.py b/phivenv/Lib/site-packages/torch/utils/serialization/config.py new file mode 100644 index 0000000000000000000000000000000000000000..36cffa203cf73732f4c522e84b94de58730d7c0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/serialization/config.py @@ -0,0 +1,25 @@ +import sys +from typing import Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from torch.serialization import LoadEndianness as _LoadEndianess + +from torch.utils._config_module import install_config_module as _install_config_module + + +class load: + mmap: bool = False + endianness: _Optional["_LoadEndianess"] = None + # MAP_PRIVATE = 2 + mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2 + calculate_storage_offsets: bool = False + + +class save: + compute_crc32: bool = True + use_pinned_memory_for_d2h: bool = False + storage_alignment: int = 64 + + +_install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/utils/show_pickle.py b/phivenv/Lib/site-packages/torch/utils/show_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e86a2cdf8715d8c5184e9725371cc886100d9d --- 
/dev/null +++ b/phivenv/Lib/site-packages/torch/utils/show_pickle.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +import sys +import pickle +import struct +import pprint +import zipfile +import fnmatch +from typing import Any, IO + +__all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"] + +class FakeObject: + def __init__(self, module, name, args): + self.module = module + self.name = name + self.args = args + # NOTE: We don't distinguish between state never set and state set to None. + self.state = None + + def __repr__(self): + state_str = "" if self.state is None else f"(state={self.state!r})" + return f"{self.module}.{self.name}{self.args!r}{state_str}" + + def __setstate__(self, state): + self.state = state + + @staticmethod + def pp_format(printer, obj, stream, indent, allowance, context, level): + if not obj.args and obj.state is None: + stream.write(repr(obj)) + return + if obj.state is None: + stream.write(f"{obj.module}.{obj.name}") + printer._format(obj.args, stream, indent + 1, allowance + 1, context, level) + return + if not obj.args: + stream.write(f"{obj.module}.{obj.name}()(state=\n") + indent += printer._indent_per_level + stream.write(" " * indent) + printer._format(obj.state, stream, indent, allowance + 1, context, level + 1) + stream.write(")") + return + raise Exception("Need to implement") # noqa: TRY002 + + +class FakeClass: + def __init__(self, module, name): + self.module = module + self.name = name + self.__new__ = self.fake_new # type: ignore[assignment] + + def __repr__(self): + return f"{self.module}.{self.name}" + + def __call__(self, *args): + return FakeObject(self.module, self.name, args) + + def fake_new(self, *args): + return FakeObject(self.module, self.name, args[1:]) + + +class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined] + def __init__( + self, + file, + *, + catch_invalid_utf8=False, + **kwargs): + super().__init__(file, **kwargs) + self.catch_invalid_utf8 = catch_invalid_utf8 + + def find_class(self, module, name): + return FakeClass(module, name) + + def persistent_load(self, pid): + return FakeObject("pers", "obj", (pid,)) + + dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined] + + # Custom objects in TorchScript are able to return invalid UTF-8 strings + # from their pickle (__getstate__) functions. Install a custom loader + # for strings that catches the decode exception and replaces it with + # a sentinel object. + def load_binunicode(self): + strlen, = struct.unpack(" sys.maxsize: + raise Exception("String too long.") # noqa: TRY002 + str_bytes = self.read(strlen) # type: ignore[attr-defined] + obj: Any + try: + obj = str(str_bytes, "utf-8", "surrogatepass") + except UnicodeDecodeError as exn: + if not self.catch_invalid_utf8: + raise + obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),)) + self.append(obj) # type: ignore[attr-defined] + dispatch[pickle.BINUNICODE[0]] = load_binunicode # type: ignore[assignment] + + @classmethod + def dump(cls, in_stream, out_stream): + value = cls(in_stream).load() + pprint.pprint(value, stream=out_stream) + return value + + +def main(argv, output_stream=None): + if len(argv) != 2: + # Don't spam stderr if not using stdout. 
+ if output_stream is not None: + raise Exception("Pass argv of length 2.") # noqa: TRY002 + sys.stderr.write("usage: show_pickle PICKLE_FILE\n") + sys.stderr.write(" PICKLE_FILE can be any of:\n") + sys.stderr.write(" path to a pickle file\n") + sys.stderr.write(" file.zip@member.pkl\n") + sys.stderr.write(" file.zip@*/pattern.*\n") + sys.stderr.write(" (shell glob pattern for members)\n") + sys.stderr.write(" (only first match will be shown)\n") + return 2 + + fname = argv[1] + handle: IO[bytes] + if "@" not in fname: + with open(fname, "rb") as handle: + DumpUnpickler.dump(handle, output_stream) + else: + zfname, mname = fname.split("@", 1) + with zipfile.ZipFile(zfname) as zf: + if "*" not in mname: + with zf.open(mname) as handle: + DumpUnpickler.dump(handle, output_stream) + else: + found = False + for info in zf.infolist(): + if fnmatch.fnmatch(info.filename, mname): + with zf.open(info) as handle: + DumpUnpickler.dump(handle, output_stream) + found = True + break + if not found: + raise Exception(f"Could not find member matching {mname} in {zfname}") # noqa: TRY002 + + +if __name__ == "__main__": + # This hack works on every version of Python I've tested. + # I've tested on the following versions: + # 3.7.4 + if True: + pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore[attr-defined] + + sys.exit(main(sys.argv)) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__init__.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..445544590a7776b87b5148a54ee9ffb54714f57c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/__init__.py @@ -0,0 +1,19 @@ +import tensorboard +from torch._vendor.packaging.version import Version + +if not hasattr(tensorboard, "__version__") or Version( + tensorboard.__version__ +) < Version("1.15"): + raise ImportError("TensorBoard logging requires TensorBoard version 1.15 or above") + +del Version +del tensorboard + +from .writer import FileWriter, SummaryWriter +from tensorboard.summary.writer.record_writer import RecordWriter + +__all__ = [ + "FileWriter", + "RecordWriter", + "SummaryWriter", +] diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e46932434c722eacc7b662ac523ebc6883d48df4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..febc7c952c864bc176a6fb08c9faa0601866dcc1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_convert_np.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03959591c7eaadf30026301192f3fe73691c354a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_embedding.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64d44577d67e63dffdde21f693156d9f00f1df14 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb37bb6a8d46271946e9919fb578412de160af52 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_proto_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c864fe9ebf8d1ecf3c8177f290aea7dca254fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..263f94962adfe16fc064dfdb60287ab1b0ffcfe0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29abf359fbb6d1ea08e8907fcf70407def17066 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/summary.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33a101a25e12b38c1e6d4d294f257b63911ca8b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/tensorboard/__pycache__/writer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_convert_np.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_convert_np.py new file mode 100644 index 0000000000000000000000000000000000000000..63c83a0ca8131b53caac37b031b11351eecccd9e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_convert_np.py @@ -0,0 +1,33 @@ +"""This module converts objects into numpy array.""" + +import numpy as np + +import torch + + +def make_np(x: torch.Tensor) -> np.ndarray: + """ + Convert an object into numpy array. + + Args: + x: An instance of torch tensor + + Returns: + numpy.array: Numpy array + """ + if isinstance(x, np.ndarray): + return x + if np.isscalar(x): + return np.array([x]) + if isinstance(x, torch.Tensor): + return _prepare_pytorch(x) + raise NotImplementedError( + f"Got {type(x)}, but numpy array or torch tensor are expected." 
+ ) + + +def _prepare_pytorch(x: torch.Tensor) -> np.ndarray: + if x.dtype == torch.bfloat16: + x = x.to(torch.float16) + x = x.detach().cpu().numpy() + return x diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_embedding.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..2b11478b29377346cbca34da88ef629c7884bb2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_embedding.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +import math +import numpy as np +from ._convert_np import make_np +from ._utils import make_grid +from tensorboard.compat import tf +from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo + + +_HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join") + + +def _gfile_join(a, b): + # The join API is different between tensorboard's TF stub and TF: + # https://github.com/tensorflow/tensorboard/issues/6080 + # We need to try both because `tf` may point to either the stub or the real TF. + if _HAS_GFILE_JOIN: + return tf.io.gfile.join(a, b) + else: + fs = tf.io.gfile.get_filesystem(a) + return fs.join(a, b) + + +def make_tsv(metadata, save_path, metadata_header=None): + if not metadata_header: + metadata = [str(x) for x in metadata] + else: + assert len(metadata_header) == len( + metadata[0] + ), "len of header must be equal to the number of columns in metadata" + metadata = ["\t".join(str(e) for e in l) for l in [metadata_header] + metadata] + + metadata_bytes = tf.compat.as_bytes("\n".join(metadata) + "\n") + with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f: + f.write(metadata_bytes) + + +# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared +def make_sprite(label_img, save_path): + from PIL import Image + from io import BytesIO + + # this ensures the sprite image has correct dimension as described in + # https://www.tensorflow.org/get_started/embedding_viz + nrow = int(math.ceil((label_img.size(0)) ** 0.5)) + arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow) + + # augment images so that #images equals nrow*nrow + arranged_augment_square_HWC = np.zeros( + (arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3) + ) + arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc + arranged_augment_square_HWC[: arranged_img_HWC.shape[0], :, :] = arranged_img_HWC + im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255))) + + with BytesIO() as buf: + im.save(buf, format="PNG") + im_bytes = buf.getvalue() + + with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f: + f.write(im_bytes) + + +def get_embedding_info(metadata, label_img, subdir, global_step, tag): + info = EmbeddingInfo() + info.tensor_name = f"{tag}:{str(global_step).zfill(5)}" + info.tensor_path = _gfile_join(subdir, "tensors.tsv") + if metadata is not None: + info.metadata_path = _gfile_join(subdir, "metadata.tsv") + if label_img is not None: + info.sprite.image_path = _gfile_join(subdir, "sprite.png") + info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)]) + return info + + +def write_pbtxt(save_path, contents): + config_path = _gfile_join(save_path, "projector_config.pbtxt") + with tf.io.gfile.GFile(config_path, "wb") as f: + f.write(tf.compat.as_bytes(contents)) + + +def make_mat(matlist, save_path): + with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f: + for x in matlist: + x = [str(i.item()) for i in x] + 
f.write(tf.compat.as_bytes("\t".join(x) + "\n")) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_onnx_graph.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_onnx_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1bab6edb8ee6ffceb7a85de26014e2eccb44d017 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_onnx_graph.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +from tensorboard.compat.proto.graph_pb2 import GraphDef +from tensorboard.compat.proto.node_def_pb2 import NodeDef +from tensorboard.compat.proto.versions_pb2 import VersionDef +from tensorboard.compat.proto.attr_value_pb2 import AttrValue +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + + +def load_onnx_graph(fname): + import onnx + + m = onnx.load(fname) # type: ignore[attr-defined] + g = m.graph + return parse(g) + + +def parse(graph): + nodes = [] + import itertools + + nodes_proto = list(itertools.chain(graph.input, graph.output)) + + for node in nodes_proto: + print(node.name) + shapeproto = TensorShapeProto( + dim=[ + TensorShapeProto.Dim(size=d.dim_value) + for d in node.type.tensor_type.shape.dim + ] + ) + nodes.append( + NodeDef( + name=node.name.encode(encoding="utf_8"), + op="Variable", + input=[], + attr={ + "dtype": AttrValue(type=node.type.tensor_type.elem_type), + "shape": AttrValue(shape=shapeproto), + }, + ) + ) + + for node in graph.node: + _attr = [" = ".join([str(f[1]) for f in s.ListFields()]) for s in node.attribute] + attr = ", ".join(_attr).encode(encoding="utf_8") + print(node.output[0]) + nodes.append( + NodeDef( + name=node.output[0].encode(encoding="utf_8"), + op=node.op_type, + input=node.input, + attr={"parameters": AttrValue(s=attr)}, + ) + ) + + # two pass token replacement, appends opname to object id + mapping = {} + for node in nodes: + mapping[node.name] = node.op + "_" + node.name + + return GraphDef(node=nodes, versions=VersionDef(producer=22)) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_proto_graph.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_proto_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..45af87e8528f2083c8c7c275c282444730626ce5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_proto_graph.py @@ -0,0 +1,54 @@ +# mypy: allow-untyped-defs +from typing import Optional +from tensorboard.compat.proto.node_def_pb2 import NodeDef +from tensorboard.compat.proto.attr_value_pb2 import AttrValue +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto + + +def attr_value_proto(dtype, shape, s): + """Create a dict of objects matching a NodeDef's attr field. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto + specifically designed for a NodeDef. The values have been reverse engineered from + standard TensorBoard logged data. + """ + attr = {} + if s is not None: + attr["attr"] = AttrValue(s=s.encode(encoding="utf_8")) + if shape is not None: + shapeproto = tensor_shape_proto(shape) + attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto])) + return attr + + +def tensor_shape_proto(outputsize): + """Create an object matching a tensor_shape field. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto . 
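+
+    Example (illustrative sketch; the sizes below are arbitrary)::
+
+        # An output size of [3, 224, 224] becomes three Dim entries in the proto.
+        shape_proto = tensor_shape_proto([3, 224, 224])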
+ """ + return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize]) + + +def node_proto( + name, + op="UnSpecified", + input=None, + dtype=None, + shape: Optional[tuple] = None, + outputsize=None, + attributes="", +): + """Create an object matching a NodeDef. + + Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto . + """ + if input is None: + input = [] + if not isinstance(input, list): + input = [input] + return NodeDef( + name=name.encode(encoding="utf_8"), + op=op, + input=input, + attr=attr_value_proto(dtype, outputsize, attributes), + ) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_pytorch_graph.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_pytorch_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0d1456fa7a5455a383ecf53af196cd2f6e4ebb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_pytorch_graph.py @@ -0,0 +1,376 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict +import contextlib +from typing import Any + +from tensorboard.compat.proto.config_pb2 import RunMetadata +from tensorboard.compat.proto.graph_pb2 import GraphDef +from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats +from tensorboard.compat.proto.versions_pb2 import VersionDef + +import torch +from ._proto_graph import node_proto + +methods_OP = [ + "attributeNames", + "hasMultipleOutputs", + "hasUses", + "inputs", + "kind", + "outputs", + "outputsSize", + "scopeName", +] +# Some additional methods to explure for methods_IO are +# +# 'unique' (type int) +# 'type' (type >) +# +# But the below are sufficient for now. +methods_IO = ["node", "offset", "debugName"] + +GETATTR_KIND = "prim::GetAttr" +CLASSTYPE_KIND = "ClassType" + + +class NodeBase: + def __init__( + self, + debugName=None, + inputs=None, + scope=None, + tensor_size=None, + op_type="UnSpecified", + attributes="", + ): + # TODO; Specify a __slots__ for this class or potentially + # used namedtuple instead + self.debugName = debugName + self.inputs = inputs + self.tensor_size = tensor_size + self.kind = op_type + self.attributes = attributes + self.scope = scope + + def __repr__(self): + repr = [] + repr.append(str(type(self))) + repr.extend( + m + ": " + str(getattr(self, m)) + str(type(getattr(self, m))) + for m in dir(self) + if "__" not in m + ) + return "\n".join(repr) + "\n\n" + + +class NodePy(NodeBase): + def __init__(self, node_cpp, valid_methods): + super().__init__(node_cpp) + valid_methods = valid_methods[:] + self.inputs = [] + + for m in valid_methods: + if m == "inputs" or m == "outputs": + list_of_node = list(getattr(node_cpp, m)()) + io_unique_names = [] + io_tensor_sizes = [] + for n in list_of_node: + io_unique_names.append(n.debugName()) + if n.isCompleteTensor(): + io_tensor_sizes.append(n.type().sizes()) + else: + io_tensor_sizes.append(None) + + setattr(self, m, io_unique_names) + setattr(self, m + "tensor_size", io_tensor_sizes) + + else: + setattr(self, m, getattr(node_cpp, m)()) + + +class NodePyIO(NodePy): + def __init__(self, node_cpp, input_or_output=None): + super().__init__(node_cpp, methods_IO) + try: + tensor_size = node_cpp.type().sizes() + except RuntimeError: + tensor_size = [ + 1, + ] # fail when constant model is used. + self.tensor_size = tensor_size + # Kind attribute string is purely descriptive and will be shown + # in detailed information for the node in TensorBoard's graph plugin. 
+ # + # NodePyOP nodes get this from their kind() method. + self.kind = "Parameter" + if input_or_output: + self.input_or_output = input_or_output + self.kind = "IO Node" + + +class NodePyOP(NodePy): + def __init__(self, node_cpp): + super().__init__(node_cpp, methods_OP) + # Replace single quote which causes strange behavior in TensorBoard + # TODO: See if we can remove this in the future + self.attributes = str( + {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()} + ).replace("'", " ") + self.kind = node_cpp.kind() + + +class GraphPy: + """Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard. + + GraphDef generation operates in two passes: + + In the first pass, all nodes are read and saved to two lists. + One list is for input/output nodes (nodes_io), which only have inbound + or outbound connections, but not both. Another list is for internal + operator nodes (nodes_op). The first pass also saves all scope name + appeared in the nodes in scope_name_appeared list for later processing. + + In the second pass, scope names are fully applied to all nodes. + debugNameToScopedName is a mapping from a node's ID to its fully qualified + scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have + totally correct scope output, so this is nontrivial. The function + populate_namespace_from_OP_to_IO and find_common_root are used to + assign scope name to a node based on the connection between nodes + in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name + and scope_name_appeared. + """ + + def __init__(self): + self.nodes_op = [] + self.nodes_io = OrderedDict() + self.unique_name_to_scoped_name = {} + self.shallowest_scope_name = "default" + self.scope_name_appeared = [] + + def append(self, x): + if isinstance(x, NodePyIO): + self.nodes_io[x.debugName] = x + if isinstance(x, NodePyOP): + self.nodes_op.append(x) + + def printall(self): + print("all nodes") + for node in self.nodes_op: + print(node) + for key in self.nodes_io: + print(self.nodes_io[key]) + + def find_common_root(self): + for fullscope in self.scope_name_appeared: + if fullscope: + self.shallowest_scope_name = fullscope.split("/")[0] + + def populate_namespace_from_OP_to_IO(self): + for node in self.nodes_op: + for node_output, outputSize in zip(node.outputs, node.outputstensor_size): + self.scope_name_appeared.append(node.scopeName) + self.nodes_io[node_output] = NodeBase( + node_output, + node.inputs, + node.scopeName, + outputSize, + op_type=node.kind, + attributes=node.attributes, + ) + + self.find_common_root() + + for node in self.nodes_op: + for input_node_id in node.inputs: + self.unique_name_to_scoped_name[input_node_id] = ( + node.scopeName + "/" + input_node_id + ) + + for key, node in self.nodes_io.items(): + if type(node) == NodeBase: + self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName + if hasattr(node, "input_or_output"): + self.unique_name_to_scoped_name[key] = ( + node.input_or_output + "/" + node.debugName + ) + + if hasattr(node, "scope") and node.scope is not None: + self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName + if node.scope == "" and self.shallowest_scope_name: + self.unique_name_to_scoped_name[node.debugName] = ( + self.shallowest_scope_name + "/" + node.debugName + ) + + # replace name + for key, node in self.nodes_io.items(): + self.nodes_io[key].inputs = [ + self.unique_name_to_scoped_name[node_input_id] + for node_input_id in node.inputs + ] + if node.debugName in 
self.unique_name_to_scoped_name: + self.nodes_io[key].debugName = self.unique_name_to_scoped_name[ + node.debugName + ] + + def to_proto(self): + """Convert graph representation of GraphPy object to TensorBoard required format.""" + # TODO: compute correct memory usage and CPU time once + # PyTorch supports it + nodes = [ + node_proto( + v.debugName, + input=v.inputs, + outputsize=v.tensor_size, + op=v.kind, + attributes=v.attributes, + ) + for v in self.nodes_io.values() + ] + return nodes + + +def parse(graph, trace, args=None, omit_useless_nodes=True): + """Parse an optimized PyTorch model graph and produces a list of nodes and node stats. + + Useful for eventual conversion to TensorBoard protobuf format. + + Args: + graph (PyTorch module): The model graph to be parsed. + trace (PyTorch JIT TracedModule): The model trace to be parsed. + args (tuple): input tensor[s] for the model. + omit_useless_nodes (boolean): Whether to remove nodes from the graph. + """ + nodes_py = GraphPy() + for node in graph.inputs(): + if omit_useless_nodes: + if ( + len(node.uses()) == 0 + ): # number of user of the node (= number of outputs/ fanout) + continue + + if node.type().kind() != CLASSTYPE_KIND: + nodes_py.append(NodePyIO(node, "input")) + + attr_to_scope: dict[Any, str] = {} + for node in graph.nodes(): + if node.kind() == GETATTR_KIND: + attr_name = node.s("name") + attr_key = node.output().debugName() + parent = node.input().node() + if ( + parent.kind() == GETATTR_KIND + ): # If the parent node is not the top-level "self" node + parent_attr_key = parent.output().debugName() + parent_scope = attr_to_scope[parent_attr_key] + attr_scope = parent_scope.split("/")[-1] + attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}" + else: + attr_to_scope[attr_key] = f"__module.{attr_name}" + # We don't need classtype nodes; scope will provide this information + if node.output().type().kind() != CLASSTYPE_KIND: + node_py = NodePyOP(node) + node_py.scopeName = attr_to_scope[attr_key] # type: ignore[attr-defined] + nodes_py.append(node_py) + else: + nodes_py.append(NodePyOP(node)) + + for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops + node_pyio = NodePyIO(node, "output") + node_pyio.debugName = f"output.{i + 1}" + node_pyio.inputs = [node.debugName()] + nodes_py.append(node_pyio) + + def parse_traced_name(module): + if isinstance(module, torch.jit.TracedModule): + module_name = module._name + else: + module_name = getattr(module, "original_name", "Module") + return module_name + + alias_to_name = {} + base_name = parse_traced_name(trace) + for name, module in trace.named_modules(prefix="__module"): + mod_name = parse_traced_name(module) + attr_name = name.split(".")[-1] + alias_to_name[name] = f"{mod_name}[{attr_name}]" + + for node in nodes_py.nodes_op: + module_aliases = node.scopeName.split("/") + replacements = [ + alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1] + for alias in module_aliases + ] + node.scopeName = base_name + if any(replacements): + node.scopeName += "/" + "/".join(replacements) + + nodes_py.populate_namespace_from_OP_to_IO() + return nodes_py.to_proto() + + +def graph(model, args, verbose=False, use_strict_trace=True): + """ + Process a PyTorch model and produces a `GraphDef` proto that can be logged to TensorBoard. + + Args: + model (PyTorch module): The model to be parsed. + args (tuple): input tensor[s] for the model. + verbose (bool): Whether to print out verbose information while + processing. 
+ use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) + """ + with _set_model_to_eval(model): + try: + trace = torch.jit.trace(model, args, strict=use_strict_trace) + graph = trace.graph + torch._C._jit_pass_inline(graph) + except RuntimeError as e: + print(e) + print("Error occurs, No graph saved") + raise e + + if verbose: + print(graph) + list_of_nodes = parse(graph, trace, args) + # We are hardcoding that this was run on CPU even though it might have actually + # run on GPU. Note this is what is shown in TensorBoard and has no bearing + # on actual execution. + # TODO: See if we can extract GPU vs CPU information from the PyTorch model + # and pass it correctly to TensorBoard. + # + # Definition of StepStats and DeviceStepStats can be found at + # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/proto.ts + # and + # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto + stepstats = RunMetadata( + step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]) + ) + return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats + # The producer version has been reverse engineered from standard + # TensorBoard logged data. + + +@contextlib.contextmanager +def _set_model_to_eval(model): + """Context manager to temporarily set the training mode of ``model`` to eval.""" + if not isinstance(model, torch.jit.ScriptFunction): + originally_training = model.training + model.train(False) + try: + yield + finally: + model.train(originally_training) + else: + # Do nothing for ScriptFunction + try: + yield + finally: + pass + + +def _node_get(node: torch._C.Node, key: str): + """Get attributes of a node which is polymorphic over return type.""" + sel = node.kindOf(key) + return getattr(node, sel)(key) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/_utils.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95bad6b604e1f4a81bd6e79251703dbc1b6ba07c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/_utils.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +import numpy as np +import numpy.typing as npt + + +# Functions for converting +def figure_to_image(figures, close=True): + """Render matplotlib figure to numpy format. + + Note that this requires the ``matplotlib`` package. + + Args: + figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures + close (bool): Flag to automatically close the figure + + Returns: + numpy.array: image in [CHW] order + """ + import matplotlib.pyplot as plt + import matplotlib.backends.backend_agg as plt_backend_agg + + def render_to_rgb(figure): + canvas = plt_backend_agg.FigureCanvasAgg(figure) + canvas.draw() + data: npt.NDArray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) + w, h = figure.canvas.get_width_height() + image_hwc = data.reshape([h, w, 4])[:, :, 0:3] + image_chw = np.moveaxis(image_hwc, source=2, destination=0) + if close: + plt.close(figure) + return image_chw + + if isinstance(figures, list): + images = [render_to_rgb(figure) for figure in figures] + return np.stack(images) + else: + image = render_to_rgb(figures) + return image + + +def _prepare_video(V): + """ + Convert a 5D tensor into 4D tensor. 
+ + Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor) + to [time(frame), new_width, new_height, channel] (4D tensor). + + A batch of images are spreaded to a grid, which forms a frame. + e.g. Video with batchsize 16 will have a 4x4 grid. + """ + b, t, c, h, w = V.shape + + if V.dtype == np.uint8: + V = np.float32(V) / 255.0 + + def is_power2(num): + return num != 0 and ((num & (num - 1)) == 0) + + # pad to nearest power of 2, all at once + if not is_power2(V.shape[0]): + len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0]) + V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0) + + n_rows = 2 ** ((b.bit_length() - 1) // 2) + n_cols = V.shape[0] // n_rows + + V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w)) + V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3)) + V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c)) + + return V + + +def make_grid(I, ncols=8): + # I: N1HW or N3HW + assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here" + if I.shape[1] == 1: + I = np.concatenate([I, I, I], 1) + assert I.ndim == 4 and I.shape[1] == 3 + nimg = I.shape[0] + H = I.shape[2] + W = I.shape[3] + ncols = min(nimg, ncols) + nrows = int(np.ceil(float(nimg) / ncols)) + canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype) + i = 0 + for y in range(nrows): + for x in range(ncols): + if i >= nimg: + break + canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i] + i = i + 1 + return canvas + + # if modality == 'IMG': + # if x.dtype == np.uint8: + # x = x.astype(np.float32) / 255.0 + + +def convert_to_HWC(tensor, input_format): # tensor: numpy array + assert len(set(input_format)) == len( + input_format + ), f"You can not use the same dimension shordhand twice. input_format: {input_format}" + assert len(tensor.shape) == len( + input_format + ), f"size of input tensor and input format are different. 
\ + tensor shape: {tensor.shape}, input_format: {input_format}" + input_format = input_format.upper() + + if len(input_format) == 4: + index = [input_format.find(c) for c in "NCHW"] + tensor_NCHW = tensor.transpose(index) + tensor_CHW = make_grid(tensor_NCHW) + return tensor_CHW.transpose(1, 2, 0) + + if len(input_format) == 3: + index = [input_format.find(c) for c in "HWC"] + tensor_HWC = tensor.transpose(index) + if tensor_HWC.shape[2] == 1: + tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2) + return tensor_HWC + + if len(input_format) == 2: + index = [input_format.find(c) for c in "HW"] + tensor = tensor.transpose(index) + tensor = np.stack([tensor, tensor, tensor], 2) + return tensor diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/summary.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..44cbd7428d85aad0b5fcb60b86b7c8946100b5a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/summary.py @@ -0,0 +1,982 @@ +# mypy: allow-untyped-defs +import json +import logging +import os +import struct + +from typing import Any, Optional + +import torch +import numpy as np + +from google.protobuf import struct_pb2 + +from tensorboard.compat.proto.summary_pb2 import ( + HistogramProto, + Summary, + SummaryMetadata, +) +from tensorboard.compat.proto.tensor_pb2 import TensorProto +from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto +from tensorboard.plugins.custom_scalar import layout_pb2 +from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData +from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData + +from ._convert_np import make_np +from ._utils import _prepare_video, convert_to_HWC + +__all__ = [ + "half_to_int", + "int_to_half", + "hparams", + "scalar", + "histogram_raw", + "histogram", + "make_histogram", + "image", + "image_boxes", + "draw_boxes", + "make_image", + "video", + "make_video", + "audio", + "custom_scalars", + "text", + "tensor_proto", + "pr_curve_raw", + "pr_curve", + "compute_curve", + "mesh", +] + +logger = logging.getLogger(__name__) + +def half_to_int(f: float) -> int: + """Casts a half-precision float value into an integer. + + Converts a half precision floating point value, such as `torch.half` or + `torch.bfloat16`, into an integer value which can be written into the + half_val field of a TensorProto for storage. + + To undo the effects of this conversion, use int_to_half(). + + """ + buf = struct.pack("f", f) + return struct.unpack("i", buf)[0] + +def int_to_half(i: int) -> float: + """Casts an integer value to a half-precision float. + + Converts an integer value obtained from half_to_int back into a floating + point value. 
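+
+    Example (illustrative sketch; the value is arbitrary)::
+
+        packed = half_to_int(0.5)          # float stored in an int field
+        assert int_to_half(packed) == 0.5  # round-trips back to the same value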
+ + """ + buf = struct.pack("i", i) + return struct.unpack("f", buf)[0] + +def _tensor_to_half_val(t: torch.Tensor) -> list[int]: + return [half_to_int(x) for x in t.flatten().tolist()] + +def _tensor_to_complex_val(t: torch.Tensor) -> list[float]: + return torch.view_as_real(t).flatten().tolist() + +def _tensor_to_list(t: torch.Tensor) -> list[Any]: + return t.flatten().tolist() + +# type maps: torch.Tensor type -> (protobuf type, protobuf val field) +_TENSOR_TYPE_MAP = { + torch.half: ("DT_HALF", "half_val", _tensor_to_half_val), + torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val), + torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val), + torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list), + torch.float: ("DT_FLOAT", "float_val", _tensor_to_list), + torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list), + torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list), + torch.int8: ("DT_INT8", "int_val", _tensor_to_list), + torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list), + torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list), + torch.int16: ("DT_INT16", "int_val", _tensor_to_list), + torch.short: ("DT_INT16", "int_val", _tensor_to_list), + torch.int: ("DT_INT32", "int_val", _tensor_to_list), + torch.int32: ("DT_INT32", "int_val", _tensor_to_list), + torch.qint32: ("DT_INT32", "int_val", _tensor_to_list), + torch.int64: ("DT_INT64", "int64_val", _tensor_to_list), + torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val), + torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val), + torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val), + torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val), + torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list), + torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val), + torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val), + torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list), + torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list), + torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list), +} + + +def _calc_scale_factor(tensor): + converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor + return 1 if converted.dtype == np.uint8 else 255 + + +def _draw_single_box( + image, + xmin, + ymin, + xmax, + ymax, + display_str, + color="black", + color_text="black", + thickness=2, +): + from PIL import ImageDraw, ImageFont + + font = ImageFont.load_default() + draw = ImageDraw.Draw(image) + (left, right, top, bottom) = (xmin, xmax, ymin, ymax) + draw.line( + [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], + width=thickness, + fill=color, + ) + if display_str: + text_bottom = bottom + # Reverse list and print from bottom to top. + _left, _top, _right, _bottom = font.getbbox(display_str) + text_width, text_height = _right - _left, _bottom - _top + margin = np.ceil(0.05 * text_height) + draw.rectangle( + [ + (left, text_bottom - text_height - 2 * margin), + (left + text_width, text_bottom), + ], + fill=color, + ) + draw.text( + (left + margin, text_bottom - text_height - margin), + display_str, + fill=color_text, + font=font, + ) + return image + + +def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): + """Output three `Summary` protocol buffers needed by hparams plugin. + + `Experiment` keeps the metadata of an experiment, such as the name of the + hyperparameters and the name of the metrics. 
+ `SessionStartInfo` keeps key-value pairs of the hyperparameters + `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS + + Args: + hparam_dict: A dictionary that contains names of the hyperparameters + and their values. + metric_dict: A dictionary that contains names of the metrics + and their values. + hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that + contains names of the hyperparameters and all discrete values they can hold + + Returns: + The `Summary` protobufs for Experiment, SessionStartInfo and + SessionEndInfo + """ + import torch + from tensorboard.plugins.hparams.api_pb2 import ( + DataType, + Experiment, + HParamInfo, + MetricInfo, + MetricName, + Status, + ) + from tensorboard.plugins.hparams.metadata import ( + EXPERIMENT_TAG, + PLUGIN_DATA_VERSION, + PLUGIN_NAME, + SESSION_END_INFO_TAG, + SESSION_START_INFO_TAG, + ) + from tensorboard.plugins.hparams.plugin_data_pb2 import ( + HParamsPluginData, + SessionEndInfo, + SessionStartInfo, + ) + + # TODO: expose other parameters in the future. + # hp = HParamInfo(name='lr',display_name='learning rate', + # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10, + # max_value=100)) + # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy', + # description='', dataset_type=DatasetType.DATASET_VALIDATION) + # exp = Experiment(name='123', description='456', time_created_secs=100.0, + # hparam_infos=[hp], metric_infos=[mt], user='tw') + + if not isinstance(hparam_dict, dict): + logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.") + raise TypeError( + "parameter: hparam_dict should be a dictionary, nothing logged." + ) + if not isinstance(metric_dict, dict): + logger.warning("parameter: metric_dict should be a dictionary, nothing logged.") + raise TypeError( + "parameter: metric_dict should be a dictionary, nothing logged." + ) + + hparam_domain_discrete = hparam_domain_discrete or {} + if not isinstance(hparam_domain_discrete, dict): + raise TypeError( + "parameter: hparam_domain_discrete should be a dictionary, nothing logged." + ) + for k, v in hparam_domain_discrete.items(): + if ( + k not in hparam_dict + or not isinstance(v, list) + or not all(isinstance(d, type(hparam_dict[k])) for d in v) + ): + raise TypeError( + f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]." 
+ ) + hps = [] + + ssi = SessionStartInfo() + for k, v in hparam_dict.items(): + if v is None: + continue + if isinstance(v, (int, float)): + ssi.hparams[k].number_value = v + + if k in hparam_domain_discrete: + domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue( + values=[ + struct_pb2.Value(number_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + type=DataType.Value("DATA_TYPE_FLOAT64"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, str): + ssi.hparams[k].string_value = v + + if k in hparam_domain_discrete: + domain_discrete = struct_pb2.ListValue( + values=[ + struct_pb2.Value(string_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + type=DataType.Value("DATA_TYPE_STRING"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, bool): + ssi.hparams[k].bool_value = v + + if k in hparam_domain_discrete: + domain_discrete = struct_pb2.ListValue( + values=[ + struct_pb2.Value(bool_value=d) + for d in hparam_domain_discrete[k] + ] + ) + else: + domain_discrete = None + + hps.append( + HParamInfo( + name=k, + type=DataType.Value("DATA_TYPE_BOOL"), + domain_discrete=domain_discrete, + ) + ) + continue + + if isinstance(v, torch.Tensor): + v = make_np(v)[0] + ssi.hparams[k].number_value = v + hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64"))) + continue + raise ValueError( + "value should be one of int, float, str, bool, or torch.Tensor" + ) + + content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) + + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] + + exp = Experiment(hparam_infos=hps, metric_infos=mts) + + content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)]) + + sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS")) + content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION) + smd = SummaryMetadata( + plugin_data=SummaryMetadata.PluginData( + plugin_name=PLUGIN_NAME, content=content.SerializeToString() + ) + ) + sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)]) + + return exp, ssi, sei + + +def scalar(name, tensor, collections=None, new_style=False, double_precision=False): + """Output a `Summary` protocol buffer containing a single scalar value. + + The generated Summary has a Tensor.proto containing the input Tensor. + Args: + name: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: A real numeric Tensor containing a single value. + collections: Optional list of graph collections keys. The new summary op is + added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. + new_style: Whether to use new style (tensor field) or old style (simple_value + field). New style could lead to faster data loading. + Returns: + A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf. + Raises: + ValueError: If tensor has the wrong shape or type. 
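+
+    Example (illustrative sketch; the tag and value are arbitrary)::
+
+        s = scalar("train/loss", 0.25)                      # old-style simple_value summary
+        s_new = scalar("train/loss", 0.25, new_style=True)  # new-style tensor-field summary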
+ """ + tensor = make_np(tensor).squeeze() + assert ( + tensor.ndim == 0 + ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions." + # python float is double precision in numpy + scalar = float(tensor) + if new_style: + tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT") + if double_precision: + tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE") + + plugin_data = SummaryMetadata.PluginData(plugin_name="scalars") + smd = SummaryMetadata(plugin_data=plugin_data) + return Summary( + value=[ + Summary.Value( + tag=name, + tensor=tensor_proto, + metadata=smd, + ) + ] + ) + else: + return Summary(value=[Summary.Value(tag=name, simple_value=scalar)]) + + +def tensor_proto(tag, tensor): + """Outputs a `Summary` protocol buffer containing the full tensor. + The generated Summary has a Tensor.proto containing the input Tensor. + Args: + tag: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: Tensor to be converted to protobuf + Returns: + A tensor protobuf in a `Summary` protobuf. + Raises: + ValueError: If tensor is too big to be converted to protobuf, or + tensor data type is not supported + """ + if tensor.numel() * tensor.itemsize >= (1 << 31): + raise ValueError( + "tensor is bigger than protocol buffer's hard limit of 2GB in size" + ) + + if tensor.dtype in _TENSOR_TYPE_MAP: + dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype] + tensor_proto = TensorProto( + **{ + "dtype": dtype, + "tensor_shape": TensorShapeProto( + dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape] + ), + field_name: conversion_fn(tensor), + }, + ) + else: + raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}") + + plugin_data = SummaryMetadata.PluginData(plugin_name="tensor") + smd = SummaryMetadata(plugin_data=plugin_data) + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)]) + + +def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts): + # pylint: disable=line-too-long + """Output a `Summary` protocol buffer with a histogram. + + The generated + [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + has one summary value containing a histogram for `values`. + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + min: A float or int min value + max: A float or int max value + num: Int number of values + sum: Float or int sum of all values + sum_squares: Float or int sum of squares for all values + bucket_limits: A numeric `Tensor` with upper value per bucket + bucket_counts: A numeric `Tensor` with number of values per bucket + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. + """ + hist = HistogramProto( + min=min, + max=max, + num=num, + sum=sum, + sum_squares=sum_squares, + bucket_limit=bucket_limits, + bucket=bucket_counts, + ) + return Summary(value=[Summary.Value(tag=name, histo=hist)]) + + +def histogram(name, values, bins, max_bins=None): + # pylint: disable=line-too-long + """Output a `Summary` protocol buffer with a histogram. + + The generated + [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) + has one summary value containing a histogram for `values`. + This op reports an `InvalidArgument` error if any value is not finite. + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. 
+ values: A real numeric `Tensor`. Any shape. Values to use to + build the histogram. + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. + """ + values = make_np(values) + hist = make_histogram(values.astype(float), bins, max_bins) + return Summary(value=[Summary.Value(tag=name, histo=hist)]) + + +def make_histogram(values, bins, max_bins=None): + """Convert values into a histogram proto using logic from histogram.cc.""" + if values.size == 0: + raise ValueError("The input has no element.") + values = values.reshape(-1) + counts, limits = np.histogram(values, bins=bins) + num_bins = len(counts) + if max_bins is not None and num_bins > max_bins: + subsampling = num_bins // max_bins + subsampling_remainder = num_bins % subsampling + if subsampling_remainder != 0: + counts = np.pad( + counts, + pad_width=[[0, subsampling - subsampling_remainder]], + mode="constant", + constant_values=0, + ) + counts = counts.reshape(-1, subsampling).sum(axis=-1) + new_limits = np.empty((counts.size + 1,), limits.dtype) + new_limits[:-1] = limits[:-1:subsampling] + new_limits[-1] = limits[-1] + limits = new_limits + + # Find the first and the last bin defining the support of the histogram: + + cum_counts = np.cumsum(np.greater(counts, 0)) + start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right") + start = int(start) + end = int(end) + 1 + del cum_counts + + # TensorBoard only includes the right bin limits. To still have the leftmost limit + # included, we include an empty bin left. + # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the + # first nonzero-count bin: + counts = ( + counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]]) + ) + limits = limits[start : end + 1] + + if counts.size == 0 or limits.size == 0: + raise ValueError("The histogram is empty, please file a bug report.") + + sum_sq = values.dot(values) + return HistogramProto( + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limit=limits.tolist(), + bucket=counts.tolist(), + ) + + +def image(tag, tensor, rescale=1, dataformats="NCHW"): + """Output a `Summary` protocol buffer with images. + + The summary has up to `max_images` summary values containing images. The + images are built from `tensor` which must be 3-D with shape `[height, width, + channels]` and where `channels` can be: + * 1: `tensor` is interpreted as Grayscale. + * 3: `tensor` is interpreted as RGB. + * 4: `tensor` is interpreted as RGBA. + The `name` in the outputted Summary.Value protobufs is generated based on the + name, with a suffix depending on the max_outputs setting: + * If `max_outputs` is 1, the summary value tag is '*name*/image'. + * If `max_outputs` is greater than 1, the summary value tags are + generated sequentially as '*name*/image/0', '*name*/image/1', etc. + Args: + tag: A name for the generated node. Will also serve as a series name in + TensorBoard. + tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width, + channels]` where `channels` is 1, 3, or 4. + 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8). + The image() function will scale the image values to [0, 255] by applying + a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values + will be clipped. + Returns: + A scalar `Tensor` of type `string`. The serialized `Summary` protocol + buffer. 
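+
+    Example (illustrative sketch; the random tensor stands in for a real image)::
+
+        img = torch.rand(3, 64, 64)                        # CHW float image in [0, 1]
+        s = image("input/sample", img, dataformats="CHW")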
+ """ + tensor = make_np(tensor) + tensor = convert_to_HWC(tensor, dataformats) + # Do not assume that user passes in values in [0, 255], use data type to detect + scale_factor = _calc_scale_factor(tensor) + tensor = tensor.astype(np.float32) + tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) + image = make_image(tensor, rescale=rescale) + return Summary(value=[Summary.Value(tag=tag, image=image)]) + + +def image_boxes( + tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None +): + """Output a `Summary` protocol buffer with images.""" + tensor_image = make_np(tensor_image) + tensor_image = convert_to_HWC(tensor_image, dataformats) + tensor_boxes = make_np(tensor_boxes) + tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image) + image = make_image( + tensor_image.clip(0, 255).astype(np.uint8), + rescale=rescale, + rois=tensor_boxes, + labels=labels, + ) + return Summary(value=[Summary.Value(tag=tag, image=image)]) + + +def draw_boxes(disp_image, boxes, labels=None): + # xyxy format + num_boxes = boxes.shape[0] + list_gt = range(num_boxes) + for i in list_gt: + disp_image = _draw_single_box( + disp_image, + boxes[i, 0], + boxes[i, 1], + boxes[i, 2], + boxes[i, 3], + display_str=None if labels is None else labels[i], + color="Red", + ) + return disp_image + + +def make_image(tensor, rescale=1, rois=None, labels=None): + """Convert a numpy representation of an image to Image protobuf.""" + from PIL import Image + + height, width, channel = tensor.shape + scaled_height = int(height * rescale) + scaled_width = int(width * rescale) + image = Image.fromarray(tensor) + if rois is not None: + image = draw_boxes(image, rois, labels=labels) + ANTIALIAS = Image.Resampling.LANCZOS + image = image.resize((scaled_width, scaled_height), ANTIALIAS) + import io + + output = io.BytesIO() + image.save(output, format="PNG") + image_string = output.getvalue() + output.close() + return Summary.Image( + height=height, + width=width, + colorspace=channel, + encoded_image_string=image_string, + ) + + +def video(tag, tensor, fps=4): + tensor = make_np(tensor) + tensor = _prepare_video(tensor) + # If user passes in uint8, then we don't need to rescale by 255 + scale_factor = _calc_scale_factor(tensor) + tensor = tensor.astype(np.float32) + tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) + video = make_video(tensor, fps) + return Summary(value=[Summary.Value(tag=tag, image=video)]) + + +def make_video(tensor, fps): + try: + import moviepy # noqa: F401 + except ImportError: + print("add_video needs package moviepy") + return + try: + from moviepy import editor as mpy + except ImportError: + print( + "moviepy is installed, but can't import moviepy.editor.", + "Some packages could be missing [imageio, requests]", + ) + return + import tempfile + + _t, h, w, c = tensor.shape + + # encode sequence of images into gif string + clip = mpy.ImageSequenceClip(list(tensor), fps=fps) + + filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name + try: # newer version of moviepy use logger instead of progress_bar argument. + clip.write_gif(filename, verbose=False, logger=None) + except TypeError: + try: # older version of moviepy does not support progress_bar argument. 
+ clip.write_gif(filename, verbose=False, progress_bar=False) + except TypeError: + clip.write_gif(filename, verbose=False) + + with open(filename, "rb") as f: + tensor_string = f.read() + + try: + os.remove(filename) + except OSError: + logger.warning("The temporary file used by moviepy cannot be deleted.") + + return Summary.Image( + height=h, width=w, colorspace=c, encoded_image_string=tensor_string + ) + + +def audio(tag, tensor, sample_rate=44100): + array = make_np(tensor) + array = array.squeeze() + if abs(array).max() > 1: + print("warning: audio amplitude out of range, auto clipped.") + array = array.clip(-1, 1) + assert array.ndim == 1, "input tensor should be 1 dimensional." + array = (array * np.iinfo(np.int16).max).astype(" 127: # weird, value > 127 breaks protobuf + num_thresholds = 127 + data = np.stack((tp, fp, tn, fn, precision, recall)) + pr_curve_plugin_data = PrCurvePluginData( + version=0, num_thresholds=num_thresholds + ).SerializeToString() + plugin_data = SummaryMetadata.PluginData( + plugin_name="pr_curves", content=pr_curve_plugin_data + ) + smd = SummaryMetadata(plugin_data=plugin_data) + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=data.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + TensorShapeProto.Dim(size=data.shape[0]), + TensorShapeProto.Dim(size=data.shape[1]), + ] + ), + ) + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + + +def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None): + # weird, value > 127 breaks protobuf + num_thresholds = min(num_thresholds, 127) + data = compute_curve( + labels, predictions, num_thresholds=num_thresholds, weights=weights + ) + pr_curve_plugin_data = PrCurvePluginData( + version=0, num_thresholds=num_thresholds + ).SerializeToString() + plugin_data = SummaryMetadata.PluginData( + plugin_name="pr_curves", content=pr_curve_plugin_data + ) + smd = SummaryMetadata(plugin_data=plugin_data) + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=data.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + TensorShapeProto.Dim(size=data.shape[0]), + TensorShapeProto.Dim(size=data.shape[1]), + ] + ), + ) + return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + + +# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py +def compute_curve(labels, predictions, num_thresholds=None, weights=None): + _MINIMUM_COUNT = 1e-7 + + if weights is None: + weights = 1.0 + + # Compute bins of true positives and false positives. + bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) + float_labels = labels.astype(np.float64) + histogram_range = (0, num_thresholds - 1) + tp_buckets, _ = np.histogram( + bucket_indices, + bins=num_thresholds, + range=histogram_range, + weights=float_labels * weights, + ) + fp_buckets, _ = np.histogram( + bucket_indices, + bins=num_thresholds, + range=histogram_range, + weights=(1.0 - float_labels) * weights, + ) + + # Obtain the reverse cumulative sum. + tp = np.cumsum(tp_buckets[::-1])[::-1] + fp = np.cumsum(fp_buckets[::-1])[::-1] + tn = fp[0] - fp + fn = tp[0] - tp + precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp) + recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn) + return np.stack((tp, fp, tn, fn, precision, recall)) + + +def _get_tensor_summary( + name, display_name, description, tensor, content_type, components, json_config +): + """Create a tensor summary with summary metadata. + + Args: + name: Uniquely identifiable name of the summary op. 
Could be replaced by + combination of name and type to make it unique even outside of this + summary. + display_name: Will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + tensor: Tensor to display in summary. + content_type: Type of content inside the Tensor. + components: Bitmask representing present parts (vertices, colors, etc.) that + belong to the summary. + json_config: A string, JSON-serialized dictionary of ThreeJS classes + configuration. + + Returns: + Tensor summary with metadata. + """ + import torch + from tensorboard.plugins.mesh import metadata + + tensor = torch.as_tensor(tensor) + + tensor_metadata = metadata.create_summary_metadata( + name, + display_name, + content_type, + components, + tensor.shape, + description, + json_config=json_config, + ) + + tensor = TensorProto( + dtype="DT_FLOAT", + float_val=tensor.reshape(-1).tolist(), + tensor_shape=TensorShapeProto( + dim=[ + TensorShapeProto.Dim(size=tensor.shape[0]), + TensorShapeProto.Dim(size=tensor.shape[1]), + TensorShapeProto.Dim(size=tensor.shape[2]), + ] + ), + ) + + tensor_summary = Summary.Value( + tag=metadata.get_instance_name(name, content_type), + tensor=tensor, + metadata=tensor_metadata, + ) + + return tensor_summary + + +def _get_json_config(config_dict): + """Parse and returns JSON string from python dictionary.""" + json_config = "{}" + if config_dict is not None: + json_config = json.dumps(config_dict, sort_keys=True) + return json_config + + +# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py +def mesh( + tag, vertices, colors, faces, config_dict, display_name=None, description=None +): + """Output a merged `Summary` protocol buffer with a mesh/point cloud. + + Args: + tag: A name for this summary operation. + vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D + coordinates of vertices. + faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of + vertices within each triangle. + colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each + vertex. + display_name: If set, will be used as the display name in TensorBoard. + Defaults to `name`. + description: A longform readable description of the summary data. Markdown + is supported. + config_dict: Dictionary with ThreeJS classes names and configuration. + + Returns: + Merged summary for mesh/point cloud representation. 
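+
+    Example (illustrative sketch; a single arbitrary triangle, no colors or faces)::
+
+        vertices = torch.as_tensor([[[1.0, 1.0, 1.0], [-1.0, -1.0, 1.0], [1.0, -1.0, -1.0]]])
+        s = mesh("point_cloud", vertices, colors=None, faces=None, config_dict=None)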
+ """ + from tensorboard.plugins.mesh import metadata + from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData + + json_config = _get_json_config(config_dict) + + summaries = [] + tensors = [ + (vertices, MeshPluginData.VERTEX), + (faces, MeshPluginData.FACE), + (colors, MeshPluginData.COLOR), + ] + tensors = [tensor for tensor in tensors if tensor[0] is not None] + components = metadata.get_components_bitmask( + [content_type for (tensor, content_type) in tensors] + ) + + for tensor, content_type in tensors: + summaries.append( + _get_tensor_summary( + tag, + display_name, + description, + tensor, + content_type, + components, + json_config, + ) + ) + + return Summary(value=summaries) diff --git a/phivenv/Lib/site-packages/torch/utils/tensorboard/writer.py b/phivenv/Lib/site-packages/torch/utils/tensorboard/writer.py new file mode 100644 index 0000000000000000000000000000000000000000..762c409a7ee4f00181659caefa6caf4ad2e2607e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/tensorboard/writer.py @@ -0,0 +1,1208 @@ +# mypy: allow-untyped-defs +"""Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization.""" + +import os +import time +from typing import Optional, TYPE_CHECKING, Union + +import torch + +if TYPE_CHECKING: + from matplotlib.figure import Figure +from tensorboard.compat import tf +from tensorboard.compat.proto import event_pb2 +from tensorboard.compat.proto.event_pb2 import Event, SessionLog +from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig +from tensorboard.summary.writer.event_file_writer import EventFileWriter + +from ._convert_np import make_np +from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt +from ._onnx_graph import load_onnx_graph +from ._pytorch_graph import graph +from ._utils import figure_to_image +from .summary import ( + audio, + custom_scalars, + histogram, + histogram_raw, + hparams, + image, + image_boxes, + mesh, + pr_curve, + pr_curve_raw, + scalar, + tensor_proto, + text, + video, +) + +__all__ = ["FileWriter", "SummaryWriter"] + + +class FileWriter: + """Writes protocol buffers to event files to be consumed by TensorBoard. + + The `FileWriter` class provides a mechanism to create an event file in a + given directory and add summaries and events to it. The class updates the + file contents asynchronously. This allows a training program to call methods + to add data to the file directly from the training loop, without slowing down + training. + """ + + def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""): + """Create a `FileWriter` and an event file. + + On construction the writer creates a new event file in `log_dir`. + The other arguments to the constructor control the asynchronous writes to + the event file. + + Args: + log_dir: A string. Directory where event file will be written. + max_queue: Integer. Size of the queue for pending events and + summaries before one of the 'add' calls forces a flush to disk. + Default is ten items. + flush_secs: Number. How often, in seconds, to flush the + pending events and summaries to disk. Default is every two minutes. + filename_suffix: A string. Suffix added to all event filenames + in the log_dir directory. More details on filename construction in + tensorboard.summary.writer.event_file_writer.EventFileWriter. 
+ """ + # Sometimes PosixPath is passed in and we need to coerce it to + # a string in all cases + # TODO: See if we can remove this in the future if we are + # actually the ones passing in a PosixPath + log_dir = str(log_dir) + self.event_writer = EventFileWriter( + log_dir, max_queue, flush_secs, filename_suffix + ) + + def get_logdir(self): + """Return the directory where event file will be written.""" + return self.event_writer.get_logdir() + + def add_event(self, event, step=None, walltime=None): + """Add an event to the event file. + + Args: + event: An `Event` protocol buffer. + step: Number. Optional global step value for training process + to record with the event. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + event.wall_time = time.time() if walltime is None else walltime + if step is not None: + # Make sure step is converted from numpy or other formats + # since protobuf might not convert depending on version + event.step = int(step) + self.event_writer.add_event(event) + + def add_summary(self, summary, global_step=None, walltime=None): + """Add a `Summary` protocol buffer to the event file. + + This method wraps the provided summary in an `Event` protocol buffer + and adds it to the event file. + + Args: + summary: A `Summary` protocol buffer. + global_step: Number. Optional global step value for training process + to record with the summary. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + event = event_pb2.Event(summary=summary) + self.add_event(event, global_step, walltime) + + def add_graph(self, graph_profile, walltime=None): + """Add a `Graph` and step stats protocol buffer to the event file. + + Args: + graph_profile: A `Graph` and step stats protocol buffer. + walltime: float. Optional walltime to override the default (current) + walltime (from time.time()) seconds after epoch + """ + graph = graph_profile[0] + stepstats = graph_profile[1] + event = event_pb2.Event(graph_def=graph.SerializeToString()) + self.add_event(event, None, walltime) + + trm = event_pb2.TaggedRunMetadata( + tag="step1", run_metadata=stepstats.SerializeToString() + ) + event = event_pb2.Event(tagged_run_metadata=trm) + self.add_event(event, None, walltime) + + def add_onnx_graph(self, graph, walltime=None): + """Add a `Graph` protocol buffer to the event file. + + Args: + graph: A `Graph` protocol buffer. + walltime: float. Optional walltime to override the default (current) + _get_file_writerfrom time.time()) + """ + event = event_pb2.Event(graph_def=graph.SerializeToString()) + self.add_event(event, None, walltime) + + def flush(self): + """Flushes the event file to disk. + + Call this method to make sure that all pending events have been written to + disk. + """ + self.event_writer.flush() + + def close(self): + """Flushes the event file to disk and close the file. + + Call this method when you do not need the summary writer anymore. + """ + self.event_writer.close() + + def reopen(self): + """Reopens the EventFileWriter. + + Can be called after `close()` to add more events in the same directory. + The events will go into a new events file. + Does nothing if the EventFileWriter was not closed. + """ + self.event_writer.reopen() + + +class SummaryWriter: + """Writes entries directly to event files in the log_dir to be consumed by TensorBoard. 
+ + The `SummaryWriter` class provides a high-level API to create an event file + in a given directory and add summaries and events to it. The class updates the + file contents asynchronously. This allows a training program to call methods + to add data to the file directly from the training loop, without slowing down + training. + """ + + def __init__( + self, + log_dir=None, + comment="", + purge_step=None, + max_queue=10, + flush_secs=120, + filename_suffix="", + ): + """Create a `SummaryWriter` that will write out events and summaries to the event file. + + Args: + log_dir (str): Save directory location. Default is + runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run. + Use hierarchical folder structure to compare + between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc. + for each new experiment to compare across them. + comment (str): Comment log_dir suffix appended to the default + ``log_dir``. If ``log_dir`` is assigned, this argument has no effect. + purge_step (int): + When logging crashes at step :math:`T+X` and restarts at step :math:`T`, + any events whose global_step larger or equal to :math:`T` will be + purged and hidden from TensorBoard. + Note that crashed and resumed experiments should have the same ``log_dir``. + max_queue (int): Size of the queue for pending events and + summaries before one of the 'add' calls forces a flush to disk. + Default is ten items. + flush_secs (int): How often, in seconds, to flush the + pending events and summaries to disk. Default is every two minutes. + filename_suffix (str): Suffix added to all event filenames in + the log_dir directory. More details on filename construction in + tensorboard.summary.writer.event_file_writer.EventFileWriter. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + + # create a summary writer with automatically generated folder name. + writer = SummaryWriter() + # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/ + + # create a summary writer using the specified folder name. + writer = SummaryWriter("my_experiment") + # folder location: my_experiment + + # create a summary writer with comment appended. + writer = SummaryWriter(comment="LR_0.1_BATCH_16") + # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/ + + """ + torch._C._log_api_usage_once("tensorboard.create.summarywriter") + if not log_dir: + import socket + from datetime import datetime + + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + log_dir = os.path.join( + "runs", current_time + "_" + socket.gethostname() + comment + ) + self.log_dir = log_dir + self.purge_step = purge_step + self.max_queue = max_queue + self.flush_secs = flush_secs + self.filename_suffix = filename_suffix + + # Initialize the file writers, but they can be cleared out on close + # and recreated later as needed. + self.file_writer = self.all_writers = None + self._get_file_writer() + + # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard + v = 1e-12 + buckets = [] + neg_buckets = [] + while v < 1e20: + buckets.append(v) + neg_buckets.append(-v) + v *= 1.1 + self.default_bins = neg_buckets[::-1] + [0] + buckets + + def _get_file_writer(self): + """Return the default FileWriter instance. 
Recreates it if closed.""" + if self.all_writers is None or self.file_writer is None: + self.file_writer = FileWriter( + self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix + ) + self.all_writers = {self.file_writer.get_logdir(): self.file_writer} + if self.purge_step is not None: + most_recent_step = self.purge_step + self.file_writer.add_event( + Event(step=most_recent_step, file_version="brain.Event:2") + ) + self.file_writer.add_event( + Event( + step=most_recent_step, + session_log=SessionLog(status=SessionLog.START), + ) + ) + self.purge_step = None + return self.file_writer + + def get_logdir(self): + """Return the directory where event files will be written.""" + return self.log_dir + + def add_hparams( + self, + hparam_dict, + metric_dict, + hparam_domain_discrete=None, + run_name=None, + global_step=None, + ): + """Add a set of hyperparameters to be compared in TensorBoard. + + Args: + hparam_dict (dict): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. + The type of the value can be one of `bool`, `string`, `float`, + `int`, or `None`. + metric_dict (dict): Each key-value pair in the dictionary is the + name of the metric and it's corresponding value. Note that the key used + here should be unique in the tensorboard record. Otherwise the value + you added by ``add_scalar`` will be displayed in hparam plugin. In most + cases, this is unwanted. + hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that + contains names of the hyperparameters and all discrete values they can hold + run_name (str): Name of the run, to be included as part of the logdir. + If unspecified, will use current timestamp. + global_step (int): Global step value to record + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + with SummaryWriter() as w: + for i in range(5): + w.add_hparams({'lr': 0.1*i, 'bsize': i}, + {'hparam/accuracy': 10*i, 'hparam/loss': 10*i}) + + Expected result: + + .. image:: _static/img/tensorboard/add_hparam.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_hparams") + if type(hparam_dict) is not dict or type(metric_dict) is not dict: + raise TypeError("hparam_dict and metric_dict should be dictionary.") + exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete) + + if not run_name: + run_name = str(time.time()) + logdir = os.path.join(self._get_file_writer().get_logdir(), run_name) + with SummaryWriter(log_dir=logdir) as w_hp: + w_hp.file_writer.add_summary(exp, global_step) + w_hp.file_writer.add_summary(ssi, global_step) + w_hp.file_writer.add_summary(sei, global_step) + for k, v in metric_dict.items(): + w_hp.add_scalar(k, v, global_step) + + def add_scalar( + self, + tag, + scalar_value, + global_step=None, + walltime=None, + new_style=False, + double_precision=False, + ): + """Add scalar data to summary. + + Args: + tag (str): Data identifier + scalar_value (float or string/blobname): Value to save + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + with seconds after epoch of event + new_style (boolean): Whether to use new style (tensor field) or old + style (simple_value field). New style could lead to faster data loading. + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + x = range(100) + for i in x: + writer.add_scalar('y=2x', i * 2, i) + writer.close() + + Expected result: + + .. 
image:: _static/img/tensorboard/add_scalar.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalar") + + summary = scalar( + tag, scalar_value, new_style=new_style, double_precision=double_precision + ) + self._get_file_writer().add_summary(summary, global_step, walltime) + + def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None): + """Add many scalar data to summary. + + Args: + main_tag (str): The parent name for the tags + tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + r = 5 + for i in range(100): + writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r), + 'xcosx':i*np.cos(i/r), + 'tanx': np.tan(i/r)}, i) + writer.close() + # This call adds three values to the same scalar plot with the tag + # 'run_14h' in TensorBoard's scalar section. + + Expected result: + + .. image:: _static/img/tensorboard/add_scalars.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalars") + walltime = time.time() if walltime is None else walltime + fw_logdir = self._get_file_writer().get_logdir() + for tag, scalar_value in tag_scalar_dict.items(): + fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag + assert self.all_writers is not None + if fw_tag in self.all_writers.keys(): + fw = self.all_writers[fw_tag] + else: + fw = FileWriter( + fw_tag, self.max_queue, self.flush_secs, self.filename_suffix + ) + self.all_writers[fw_tag] = fw + fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime) + + def add_tensor( + self, + tag, + tensor, + global_step=None, + walltime=None, + ): + """Add tensor data to summary. + + Args: + tag (str): Data identifier + tensor (torch.Tensor): tensor to save + global_step (int): Global step value to record + Examples:: + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + x = torch.tensor([1,2,3]) + writer.add_scalar('x', x) + writer.close() + + Expected result: + Summary::tensor::float_val [1,2,3] + ::tensor::shape [3] + ::tag 'x' + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_tensor") + + summary = tensor_proto(tag, tensor) + self._get_file_writer().add_summary(summary, global_step, walltime) + + def add_histogram( + self, + tag, + values, + global_step=None, + bins="tensorflow", + walltime=None, + max_bins=None, + ): + """Add histogram to summary. + + Args: + tag (str): Data identifier + values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram + global_step (int): Global step value to record + bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find + other options in: https://numpy.org/doc/stable/reference/generated/numpy.histogram.html + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + writer = SummaryWriter() + for i in range(10): + x = np.random.random(1000) + writer.add_histogram('distribution centers', x + i, i) + writer.close() + + Expected result: + + .. 
image:: _static/img/tensorboard/add_histogram.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram") + if isinstance(bins, str) and bins == "tensorflow": + bins = self.default_bins + self._get_file_writer().add_summary( + histogram(tag, values, bins, max_bins=max_bins), global_step, walltime + ) + + def add_histogram_raw( + self, + tag, + min, + max, + num, + sum, + sum_squares, + bucket_limits, + bucket_counts, + global_step=None, + walltime=None, + ): + """Add histogram with raw data. + + Args: + tag (str): Data identifier + min (float or int): Min value + max (float or int): Max value + num (int): Number of values + sum (float or int): Sum of all values + sum_squares (float or int): Sum of squares for all values + bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket. + The number of elements of it should be the same as `bucket_counts`. + bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + writer = SummaryWriter() + dummy_data = [] + for idx, value in enumerate(range(50)): + dummy_data += [idx + 0.001] * value + + bins = list(range(50+2)) + bins = np.array(bins) + values = np.array(dummy_data).astype(float).reshape(-1) + counts, limits = np.histogram(values, bins=bins) + sum_sq = values.dot(values) + writer.add_histogram_raw( + tag='histogram_with_raw_data', + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limits=limits[1:].tolist(), + bucket_counts=counts.tolist(), + global_step=0) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_histogram_raw.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw") + if len(bucket_limits) != len(bucket_counts): + raise ValueError( + "len(bucket_limits) != len(bucket_counts), see the document." + ) + self._get_file_writer().add_summary( + histogram_raw( + tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts + ), + global_step, + walltime, + ) + + def add_image( + self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW" + ): + """Add image data to summary. + + Note that this requires the ``pillow`` package. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + dataformats (str): Image data format specification of the form + CHW, HWC, HW, WH, etc. + Shape: + img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to + convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job. + Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as + corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``. 
+ + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + img = np.zeros((3, 100, 100)) + img[0] = np.arange(0, 10000).reshape(100, 100) / 10000 + img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 + + img_HWC = np.zeros((100, 100, 3)) + img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 + img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000 + + writer = SummaryWriter() + writer.add_image('my_image', img, 0) + + # If you have non-default dimension setting, set the dataformats argument. + writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC') + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_image.png + :scale: 50 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_image") + self._get_file_writer().add_summary( + image(tag, img_tensor, dataformats=dataformats), global_step, walltime + ) + + def add_images( + self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW" + ): + """Add batched image data to summary. + + Note that this requires the ``pillow`` package. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + dataformats (str): Image data format specification of the form + NCHW, NHWC, CHW, HWC, HW, WH, etc. + Shape: + img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be + accepted. e.g. NCHW or NHWC. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + + img_batch = np.zeros((16, 3, 100, 100)) + for i in range(16): + img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i + img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i + + writer = SummaryWriter() + writer.add_images('my_image_batch', img_batch, 0) + writer.close() + + Expected result: + + .. image:: _static/img/tensorboard/add_images.png + :scale: 30 % + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_images") + self._get_file_writer().add_summary( + image(tag, img_tensor, dataformats=dataformats), global_step, walltime + ) + + def add_image_with_boxes( + self, + tag, + img_tensor, + box_tensor, + global_step=None, + walltime=None, + rescale=1, + dataformats="CHW", + labels=None, + ): + """Add image and draw bounding boxes on the image. + + Args: + tag (str): Data identifier + img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data + box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects) + box should be represented as [x1, y1, x2, y2]. + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + rescale (float): Optional scale override + dataformats (str): Image data format specification of the form + NCHW, NHWC, CHW, HWC, HW, WH, etc. + labels (list of string): The label to be shown for each bounding box. + Shape: + img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument. + e.g. CHW or HWC + + box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): NX4, where N is the number of + boxes and each 4 elements in a row represents (xmin, ymin, xmax, ymax). 
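+
+        Examples (a minimal sketch with a single hypothetical box and label)::
+
+            from torch.utils.tensorboard import SummaryWriter
+            import numpy as np
+
+            img = np.zeros((3, 100, 100))
+            boxes = np.array([[10.0, 10.0, 40.0, 40.0]])  # one (xmin, ymin, xmax, ymax) row
+            writer = SummaryWriter()
+            writer.add_image_with_boxes('img_with_box', img, boxes, global_step=0, labels=['object'])
+            writer.close()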
+ """ + torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes") + if labels is not None: + if isinstance(labels, str): + labels = [labels] + if len(labels) != box_tensor.shape[0]: + labels = None + self._get_file_writer().add_summary( + image_boxes( + tag, + img_tensor, + box_tensor, + rescale=rescale, + dataformats=dataformats, + labels=labels, + ), + global_step, + walltime, + ) + + def add_figure( + self, + tag: str, + figure: Union["Figure", list["Figure"]], + global_step: Optional[int] = None, + close: bool = True, + walltime: Optional[float] = None, + ) -> None: + """Render matplotlib figure into an image and add it to summary. + + Note that this requires the ``matplotlib`` package. + + Args: + tag: Data identifier + figure: Figure or a list of figures + global_step: Global step value to record + close: Flag to automatically close the figure + walltime: Optional override default walltime (time.time()) + seconds after epoch of event + """ + torch._C._log_api_usage_once("tensorboard.logging.add_figure") + if isinstance(figure, list): + self.add_image( + tag, + figure_to_image(figure, close), + global_step, + walltime, + dataformats="NCHW", + ) + else: + self.add_image( + tag, + figure_to_image(figure, close), + global_step, + walltime, + dataformats="CHW", + ) + + def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None): + """Add video data to summary. + + Note that this requires the ``moviepy`` package. + + Args: + tag (str): Data identifier + vid_tensor (torch.Tensor): Video data + global_step (int): Global step value to record + fps (float or int): Frames per second + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Shape: + vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`. + """ + torch._C._log_api_usage_once("tensorboard.logging.add_video") + self._get_file_writer().add_summary( + video(tag, vid_tensor, fps), global_step, walltime + ) + + def add_audio( + self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None + ): + """Add audio data to summary. + + Args: + tag (str): Data identifier + snd_tensor (torch.Tensor): Sound data + global_step (int): Global step value to record + sample_rate (int): sample rate in Hz + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Shape: + snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1]. + """ + torch._C._log_api_usage_once("tensorboard.logging.add_audio") + self._get_file_writer().add_summary( + audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime + ) + + def add_text(self, tag, text_string, global_step=None, walltime=None): + """Add text data to summary. 
+ + Args: + tag (str): Data identifier + text_string (str): String to save + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + Examples:: + + writer.add_text('lstm', 'This is an lstm', 0) + writer.add_text('rnn', 'This is an rnn', 10) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_text") + self._get_file_writer().add_summary( + text(tag, text_string), global_step, walltime + ) + + def add_onnx_graph(self, prototxt): + torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph") + self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt)) + + def add_graph( + self, model, input_to_model=None, verbose=False, use_strict_trace=True + ): + """Add graph data to summary. + + Args: + model (torch.nn.Module): Model to draw. + input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of + variables to be fed. + verbose (bool): Whether to print graph structure in console. + use_strict_trace (bool): Whether to pass keyword argument `strict` to + `torch.jit.trace`. Pass False when you want the tracer to + record your mutable container types (list, dict) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_graph") + # A valid PyTorch model should have a 'forward' method + self._get_file_writer().add_graph( + graph(model, input_to_model, verbose, use_strict_trace) + ) + + @staticmethod + def _encode(rawstr): + # I'd use urllib but, I'm unsure about the differences from python3 to python2, etc. + retval = rawstr + retval = retval.replace("%", f"%{ord('%'):02x}") + retval = retval.replace("/", f"%{ord('/'):02x}") + retval = retval.replace("\\", "%%%02x" % (ord("\\"))) # noqa: UP031 + return retval + + def add_embedding( + self, + mat, + metadata=None, + label_img=None, + global_step=None, + tag="default", + metadata_header=None, + ): + """Add embedding projector data to summary. + + Args: + mat (torch.Tensor or numpy.ndarray): A matrix which each row is the feature vector of the data point + metadata (list): A list of labels, each element will be converted to string + label_img (torch.Tensor): Images correspond to each data point + global_step (int): Global step value to record + tag (str): Name for the embedding + metadata_header (list): A list of headers for multi-column metadata. If given, each metadata must be + a list with values corresponding to headers. + Shape: + mat: :math:`(N, D)`, where N is number of data and D is feature dimension + + label_img: :math:`(N, C, H, W)` + + Examples:: + + import keyword + import torch + meta = [] + while len(meta)<100: + meta = meta+keyword.kwlist # get some strings + meta = meta[:100] + + for i, v in enumerate(meta): + meta[i] = v+str(i) + + label_img = torch.rand(100, 3, 10, 32) + for i in range(100): + label_img[i]*=i/100.0 + + writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img) + writer.add_embedding(torch.randn(100, 5), label_img=label_img) + writer.add_embedding(torch.randn(100, 5), metadata=meta) + + .. note:: + Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for + coloring in the embedding projector. + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_embedding") + mat = make_np(mat) + if global_step is None: + global_step = 0 + # clear pbtxt? + + # Maybe we should encode the tag so slashes don't trip us up? + # I don't think this will mess us up, but better safe than sorry. 
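+        # The embedding is written out as plain files (a TSV of the matrix, plus
+        # optional metadata and sprite image) under <logdir>/<zero-padded step>/<encoded tag>;
+        # the projector_config.pbtxt written to the logdir points the projector plugin at them.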
+ subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}" + save_path = os.path.join(self._get_file_writer().get_logdir(), subdir) + + fs = tf.io.gfile + if fs.exists(save_path): + if fs.isdir(save_path): + print( + "warning: Embedding dir exists, did you set global_step for add_embedding()?" + ) + else: + raise NotADirectoryError( + f"Path: `{save_path}` exists, but is a file. Cannot proceed." + ) + else: + fs.makedirs(save_path) + + if metadata is not None: + assert mat.shape[0] == len( + metadata + ), "#labels should equal with #data points" + make_tsv(metadata, save_path, metadata_header=metadata_header) + + if label_img is not None: + assert ( + mat.shape[0] == label_img.shape[0] + ), "#images should equal with #data points" + make_sprite(label_img, save_path) + + assert ( + mat.ndim == 2 + ), "mat should be 2D, where mat.size(0) is the number of data points" + make_mat(mat, save_path) + + # Filesystem doesn't necessarily have append semantics, so we store an + # internal buffer to append to and re-write whole file after each + # embedding is added + if not hasattr(self, "_projector_config"): + self._projector_config = ProjectorConfig() + embedding_info = get_embedding_info( + metadata, label_img, subdir, global_step, tag + ) + self._projector_config.embeddings.extend([embedding_info]) + + from google.protobuf import text_format + + config_pbtxt = text_format.MessageToString(self._projector_config) + write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt) + + def add_pr_curve( + self, + tag, + labels, + predictions, + global_step=None, + num_thresholds=127, + weights=None, + walltime=None, + ): + """Add precision recall curve. + + Plotting a precision-recall curve lets you understand your model's + performance under different threshold settings. With this function, + you provide the ground truth labeling (T/F) and prediction confidence + (usually the output of your model) for each target. The TensorBoard UI + will let you choose the threshold interactively. + + Args: + tag (str): Data identifier + labels (torch.Tensor, numpy.ndarray, or string/blobname): + Ground truth data. Binary label for each element. + predictions (torch.Tensor, numpy.ndarray, or string/blobname): + The probability that an element be classified as true. + Value should be in [0, 1] + global_step (int): Global step value to record + num_thresholds (int): Number of thresholds used to draw the curve. + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + import numpy as np + labels = np.random.randint(2, size=100) # binary label + predictions = np.random.rand(100) + writer = SummaryWriter() + writer.add_pr_curve('pr_curve', labels, predictions, 0) + writer.close() + + """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve") + labels, predictions = make_np(labels), make_np(predictions) + self._get_file_writer().add_summary( + pr_curve(tag, labels, predictions, num_thresholds, weights), + global_step, + walltime, + ) + + def add_pr_curve_raw( + self, + tag, + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + global_step=None, + num_thresholds=127, + weights=None, + walltime=None, + ): + """Add precision recall curve with raw data. 
+ + Args: + tag (str): Data identifier + true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts + false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts + true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts + false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts + precision (torch.Tensor, numpy.ndarray, or string/blobname): precision + recall (torch.Tensor, numpy.ndarray, or string/blobname): recall + global_step (int): Global step value to record + num_thresholds (int): Number of thresholds used to draw the curve. + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md + """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw") + self._get_file_writer().add_summary( + pr_curve_raw( + tag, + true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, + num_thresholds, + weights, + ), + global_step, + walltime, + ) + + def add_custom_scalars_multilinechart( + self, tags, category="default", title="untitled" + ): + """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*. + + Args: + tags (list): list of tags that have been used in ``add_scalar()`` + + Examples:: + + writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330']) + """ + torch._C._log_api_usage_once( + "tensorboard.logging.add_custom_scalars_multilinechart" + ) + layout = {category: {title: ["Multiline", tags]}} + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_custom_scalars_marginchart( + self, tags, category="default", title="untitled" + ): + """Shorthand for creating marginchart. + + Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*, + which should have exactly 3 elements. + + Args: + tags (list): list of tags that have been used in ``add_scalar()`` + + Examples:: + + writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006']) + """ + torch._C._log_api_usage_once( + "tensorboard.logging.add_custom_scalars_marginchart" + ) + assert len(tags) == 3 + layout = {category: {title: ["Margin", tags]}} + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_custom_scalars(self, layout): + """Create special chart by collecting charts tags in 'scalars'. + + NOTE: This function can only be called once for each SummaryWriter() object. + + Because it only provides metadata to tensorboard, the function can be called before or after the training loop. + + Args: + layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary + {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type + (one of **Multiline** or **Margin**) and the second element should be a list containing the tags + you have used in add_scalar function, which will be collected into the new chart. 
+ + Examples:: + + layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]}, + 'USA':{ 'dow':['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], + 'nasdaq':['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}} + + writer.add_custom_scalars(layout) + """ + torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars") + self._get_file_writer().add_summary(custom_scalars(layout)) + + def add_mesh( + self, + tag, + vertices, + colors=None, + faces=None, + config_dict=None, + global_step=None, + walltime=None, + ): + """Add meshes or 3D point clouds to TensorBoard. + + The visualization is based on Three.js, + so it allows users to interact with the rendered object. Besides the basic definitions + such as vertices, faces, users can further provide camera parameter, lighting condition, etc. + Please check https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene for + advanced usage. + + Args: + tag (str): Data identifier + vertices (torch.Tensor): List of the 3D coordinates of vertices. + colors (torch.Tensor): Colors for each vertex + faces (torch.Tensor): Indices of vertices within each triangle. (Optional) + config_dict: Dictionary with ThreeJS classes names and configuration. + global_step (int): Global step value to record + walltime (float): Optional override default walltime (time.time()) + seconds after epoch of event + + Shape: + vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels) + + colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`. + + faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`. + + Examples:: + + from torch.utils.tensorboard import SummaryWriter + vertices_tensor = torch.as_tensor([ + [1, 1, 1], + [-1, -1, 1], + [1, -1, -1], + [-1, 1, -1], + ], dtype=torch.float).unsqueeze(0) + colors_tensor = torch.as_tensor([ + [255, 0, 0], + [0, 255, 0], + [0, 0, 255], + [255, 0, 255], + ], dtype=torch.int).unsqueeze(0) + faces_tensor = torch.as_tensor([ + [0, 2, 3], + [0, 3, 1], + [0, 1, 2], + [1, 3, 2], + ], dtype=torch.int).unsqueeze(0) + + writer = SummaryWriter() + writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor) + + writer.close() + """ + torch._C._log_api_usage_once("tensorboard.logging.add_mesh") + self._get_file_writer().add_summary( + mesh(tag, vertices, colors, faces, config_dict), global_step, walltime + ) + + def flush(self): + """Flushes the event file to disk. + + Call this method to make sure that all pending events have been written to + disk. 
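+
+        Example (a minimal sketch)::
+
+            writer = SummaryWriter()
+            writer.add_scalar("loss", 0.5, 0)
+            writer.flush()  # ensure the scalar event is on disk before reading the run dir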
+ """ + if self.all_writers is None: + return + for writer in self.all_writers.values(): + writer.flush() + + def close(self): + if self.all_writers is None: + return # ignore double close + for writer in self.all_writers.values(): + writer.flush() + writer.close() + self.file_writer = self.all_writers = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/phivenv/Lib/site-packages/torch/utils/throughput_benchmark.py b/phivenv/Lib/site-packages/torch/utils/throughput_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..4aae52c5f6b1587630cad3c232b73bef53a60b2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/throughput_benchmark.py @@ -0,0 +1,160 @@ +# mypy: allow-untyped-defs + +import torch._C + + +def format_time(time_us=None, time_ms=None, time_s=None): + """Define time formatting.""" + assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1 + + US_IN_SECOND = 1e6 + US_IN_MS = 1e3 + + if time_us is None: + if time_ms is not None: + time_us = time_ms * US_IN_MS + elif time_s is not None: + time_us = time_s * US_IN_SECOND + else: + raise AssertionError("Shouldn't reach here :)") + + if time_us >= US_IN_SECOND: + return f'{time_us / US_IN_SECOND:.3f}s' + if time_us >= US_IN_MS: + return f'{time_us / US_IN_MS:.3f}ms' + return f'{time_us:.3f}us' + + +class ExecutionStats: + def __init__(self, c_stats, benchmark_config): + self._c_stats = c_stats + self.benchmark_config = benchmark_config + + @property + def latency_avg_ms(self): + return self._c_stats.latency_avg_ms + + @property + def num_iters(self): + return self._c_stats.num_iters + + @property + def iters_per_second(self): + """Return total number of iterations per second across all calling threads.""" + return self.num_iters / self.total_time_seconds + + @property + def total_time_seconds(self): + return self.num_iters * ( + self.latency_avg_ms / 1000.0) / self.benchmark_config.num_calling_threads + + def __str__(self): + return '\n'.join([ + "Average latency per example: " + format_time(time_ms=self.latency_avg_ms), + f"Total number of iterations: {self.num_iters}", + f"Total number of iterations per second (across all threads): {self.iters_per_second:.2f}", + "Total time: " + format_time(time_s=self.total_time_seconds) + ]) + + +class ThroughputBenchmark: + """ + This class is a wrapper around a c++ component throughput_benchmark::ThroughputBenchmark. + + This wrapper on the throughput_benchmark::ThroughputBenchmark component is responsible + for executing a PyTorch module (nn.Module or ScriptModule) under an inference + server like load. It can emulate multiple calling threads to a single module + provided. In the future we plan to enhance this component to support inter and + intra-op parallelism as well as multiple models running in a single process. + + Please note that even though nn.Module is supported, it might incur an overhead + from the need to hold GIL every time we execute Python code or pass around + inputs as Python objects. As soon as you have a ScriptModule version of your + model for inference deployment it is better to switch to using it in this + benchmark. + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> from torch.utils import ThroughputBenchmark + >>> bench = ThroughputBenchmark(my_module) + >>> # Pre-populate benchmark's data set with the inputs + >>> for input in inputs: + ... # Both args and kwargs work, same as any PyTorch Module / ScriptModule + ... 
bench.add_input(input[0], x2=input[1]) + >>> # Inputs supplied above are randomly used during the execution + >>> stats = bench.benchmark( + ... num_calling_threads=4, + ... num_warmup_iters = 100, + ... num_iters = 1000, + ... ) + >>> print("Avg latency (ms): {}".format(stats.latency_avg_ms)) + >>> print("Number of iterations: {}".format(stats.num_iters)) + """ + + def __init__(self, module): + if isinstance(module, torch.jit.ScriptModule): + self._benchmark = torch._C.ThroughputBenchmark(module._c) + else: + self._benchmark = torch._C.ThroughputBenchmark(module) + + def run_once(self, *args, **kwargs): + """ + Given input id (input_idx) run benchmark once and return prediction. + + This is useful for testing that benchmark actually runs the module you + want it to run. input_idx here is an index into inputs array populated + by calling add_input() method. + """ + return self._benchmark.run_once(*args, **kwargs) + + def add_input(self, *args, **kwargs): + """ + Store a single input to a module into the benchmark memory and keep it there. + + During the benchmark execution every thread is going to pick up a + random input from the all the inputs ever supplied to the benchmark via + this function. + """ + self._benchmark.add_input(*args, **kwargs) + + def benchmark( + self, + num_calling_threads=1, + num_warmup_iters=10, + num_iters=100, + profiler_output_path=""): + """ + Run a benchmark on the module. + + Args: + num_warmup_iters (int): Warmup iters are used to make sure we run a module + a few times before actually measuring things. This way we avoid cold + caches and any other similar problems. This is the number of warmup + iterations for each of the thread in separate + + num_iters (int): Number of iterations the benchmark should run with. + This number is separate from the warmup iterations. Also the number is + shared across all the threads. Once the num_iters iterations across all + the threads is reached, we will stop execution. Though total number of + iterations might be slightly larger. Which is reported as + stats.num_iters where stats is the result of this function + + profiler_output_path (str): Location to save Autograd Profiler trace. + If not empty, Autograd Profiler will be enabled for the main benchmark + execution (but not the warmup phase). The full trace will be saved + into the file path provided by this argument + + + This function returns BenchmarkExecutionStats object which is defined via pybind11. 
+ It currently has two fields: + - num_iters - number of actual iterations the benchmark have made + - avg_latency_ms - average time it took to infer on one input example in milliseconds + """ + config = torch._C.BenchmarkConfig() + config.num_calling_threads = num_calling_threads + config.num_warmup_iters = num_warmup_iters + config.num_iters = num_iters + config.profiler_output_path = profiler_output_path + c_stats = self._benchmark.benchmark(config) + return ExecutionStats(c_stats, config) diff --git a/phivenv/Lib/site-packages/torch/utils/viz/__init__.py b/phivenv/Lib/site-packages/torch/utils/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfed0fc654df16bcc60485f3a518e90fbeee0d16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-39.pyc b/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7505cf546fe785d3f42dddee8774ed483829a06b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/utils/viz/__pycache__/_cycles.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/utils/viz/_cycles.py b/phivenv/Lib/site-packages/torch/utils/viz/_cycles.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1b7f73e5345fd9cca3711dadde8e53fd6f768e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/viz/_cycles.py @@ -0,0 +1,499 @@ +# mypy: allow-untyped-defs +import gc +import sys +from typing import Any, NamedTuple, Optional +import types +import weakref +import json +from tempfile import NamedTemporaryFile +import torch +from torch.cuda._memory_viz import _frames_fmt, _block_extra +import atexit +import logging +logger = logging.getLogger(__name__) + +def observe_garbage(observer): + enabled = True + + def disable(): + # when GC runs during exit, things like `sys` will already be unloaded + # so we have to disable the callback to avoid hitting errors. + nonlocal enabled + enabled = False + atexit.register(disable) + + def gc_callback(phase, info): + nonlocal enabled + if not enabled: + return + if phase == "start": + gc.set_debug(gc.DEBUG_SAVEALL) + elif phase == "stop": + orig_trace = sys.getprofile() + self_return = [False] + + def do_collect(*args, **kwargs): + nonlocal enabled + if not self_return[0]: + self_return[0] = True + else: + sys.setprofile(orig_trace) + enabled = False + try: + # things in gc.garbage have survived a collection + # so to free them we have to collect a generation greater than them + # but that might _also_ free other stuff and we don't want to miss + # that stuff. So we have to now force gc at the highest level here, + # report all of what we found, _then_ we can free it up. + if info['generation'] != 2: + gc.collect() + observer(gc.garbage) + gc.garbage.clear() + # we have to re-run GC to clean up the cycles + # we saved from before. 
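+                        # Turn DEBUG_SAVEALL back off so the collection below actually
+                        # frees the cycles we just reported instead of saving them again.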
+ gc.set_debug(0) + before = torch.cuda.memory_allocated() + gc.collect() + after = torch.cuda.memory_allocated() + if before != after: + logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after) + finally: + enabled = True + if orig_trace is not None: + return orig_trace(*args, **kwargs) + sys.setprofile(do_collect) + + gc.callbacks.append(gc_callback) + + # provide a way to disarm the callback + def remove(): + gc.callbacks.remove(gc_callback) + return remove + +# Function to visualize cycles adapated from refcycle: +# Copyright 2013 Mark Dickinson +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def _get_cell_type(): + def f(x=None): + return lambda: x + return type(f().__closure__[0]) + +CellType = _get_cell_type() + +def annotated_references(obj): + """ + Return known information about references held by the given object. + + Returns a mapping from referents to lists of descriptions. Note that there + may be more than one edge leading to any particular referent; hence the + need for a list. Descriptions are currently strings. + + """ + references: dict[int, list[str]] = {} + + def add_reference(name, obj): + references.setdefault(id(obj), []).append(name) + + def add_attrs(*attrs): + for attr in attrs: + if hasattr(obj, attr): + add_reference(attr, getattr(obj, attr)) + + def add_cell_references(): + try: + add_attrs("cell_contents") + except ValueError: + # if cell_contents is empty, + # accessing it raises ValueError + # in this case there is no object to + # annotate + pass + + def add_function_references(): + add_attrs("__defaults__", + "__closure__", + "__globals__", + "__code__", + "__name__", + "__module__", + "__doc__" + "__qualname__", + "__annotations__", + "__kwdefaults__") + + + def add_sequence_references(): + for position, item in enumerate(obj): + add_reference(f"[{position}]", item) + + def add_dict_references(): + for key, value in obj.items(): + add_reference("key", key) + add_reference(f"[{repr(key)}]", value) + + def add_set_references(): + for elt in obj: + add_reference("element", elt) + + def add_bound_method_references(): + add_attrs("__self__", "__func__", "im_class") + + def add_weakref_references(): + # For subclasses of weakref, we can't reliably distinguish the + # callback (if any) from other attributes. + if type(obj) is weakref.ref: + referents = gc.get_referents(obj) + if len(referents) == 1: + target = referents[0] + add_reference("__callback__", target) + + + def add_frame_references(): + f_locals = obj.f_locals + add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals") + # Some badly-behaved code replaces the f_locals dict with + # something that doesn't support the full dict interface. So we + # only continue with the annotation if f_locals is a Python dict. 
+ if type(f_locals) is dict: + for name, local in obj.f_locals.items(): + add_reference(f"local {name}", local) + + def add_getset_descriptor_references(): + add_attrs("__objclass__", "__name__", "__doc__") + + type_based_references = { + tuple: add_sequence_references, + list: add_sequence_references, + dict: add_dict_references, + set: add_set_references, + frozenset: add_set_references, + types.FunctionType: add_function_references, + types.FrameType: add_frame_references, + CellType: add_cell_references, + types.MethodType: add_bound_method_references, + weakref.ref: add_weakref_references, + types.GetSetDescriptorType: add_getset_descriptor_references, + } + + for type_ in type(obj).__mro__: + if type_ in type_based_references: + type_based_references[type_]() + + add_attrs("__dict__", "__class__") + if isinstance(obj, type): + add_attrs("__mro__") + + return references + +############################################################################### +# Object annotations. + + +BASE_TYPES = (int, float, complex, type(None), str, bytes) +FRAME_FILENAME_LIMIT = 32 + +def object_annotation(obj): + """ + Return a string to be used for Graphviz nodes. + + The string should be short but as informative as possible. + """ + + def format_sequence(obj): + body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj)) + if len(obj) > 8: + body = f'{body}, ...{len(obj) - 8}' + return body + + # For basic types, use the repr. + if isinstance(obj, BASE_TYPES): + return repr(obj) + if type(obj).__name__ == 'function': + return f"function\n{obj.__name__}" + elif isinstance(obj, types.MethodType): + try: + func_name = obj.__func__.__qualname__ + except AttributeError: + func_name = "" + return f"instancemethod\n{func_name}" + elif isinstance(obj, list): + return f"[{format_sequence(obj)}]" + elif isinstance(obj, tuple): + return f"({format_sequence(obj)})" + elif isinstance(obj, dict): + return f"dict[{len(obj)}]" + elif isinstance(obj, types.ModuleType): + return f"module\n{obj.__name__}" + elif isinstance(obj, type): + return f"type\n{obj.__name__}" + elif isinstance(obj, weakref.ref): + referent = obj() + if referent is None: + return "weakref (dead referent)" + else: + return f"weakref to id 0x{id(referent):x}" + elif isinstance(obj, types.FrameType): + filename = obj.f_code.co_filename + if len(filename) > FRAME_FILENAME_LIMIT: + filename = "..." 
+ filename[-(FRAME_FILENAME_LIMIT - 3):] + return f"frame\n{filename}:{obj.f_lineno}" + else: + return f"object\n{type(obj).__module__}.{type(obj).__name__}" + + + +class Node(NamedTuple): + label: str + context: Optional[str] + root: bool + referrents: list[tuple[str, int]] + +def create_graph(objects, *, context=None, filter=None): + if context is None: + context = cuda_allocation_context() + if filter is None: + filter = is_cuda_tensor + + objects = [obj for obj in objects if not isinstance(obj, weakref.ProxyTypes)] + nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects] + node_referrers: list[list[int]] = [[] for obj in objects] + + id_to_node = {id(obj): i for i, obj in enumerate(objects)} + for obj in objects: + fidx = id_to_node[id(obj)] + f = nodes[fidx] + references = annotated_references(obj) + for referrent in gc.get_referents(obj): + rid = id(referrent) + tidx = id_to_node.get(rid, None) + if tidx is None: + continue + labels = references.get(rid, ["?"]) + node_referrers[tidx].append(fidx) + for label in labels: + f.referrents.append((label, tidx)) + + to_search = [i for i, n in enumerate(nodes) if n.root] + to_keep = set() + while to_search: + idx = to_search.pop() + if idx in to_keep: + continue + to_keep.add(idx) + referrers = node_referrers[idx] + to_search.extend(referrers) + id_to_filtered_id: dict[int, int] = {} + filtered: list[Any] = [] + for i, n in enumerate(nodes): + if i in to_keep: + id_to_filtered_id[i] = len(id_to_filtered_id) + filtered.append(n) + for n in filtered: + n.referrents[:] = [(label, id_to_filtered_id[idx]) + for (label, idx) in n.referrents + if idx in id_to_filtered_id] + return filtered + +def escape(n): + return json.dumps(n) + + +def is_cuda_tensor(obj): + return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor) + +def cuda_allocation_context(): + snapshot = torch.cuda.memory._snapshot() + addr_to_frame = {} + for seg in snapshot['segments']: + addr = seg['address'] + for blk in seg['blocks']: + if blk['state'] == 'active_allocated': + frames, _real_size = _block_extra(blk) + addr_to_frame[addr] = frames + addr += blk['size'] + + def object_context(obj): + if is_cuda_tensor(obj): + addr = obj.untyped_storage().data_ptr() + frames = addr_to_frame.get(addr) + if frames is not None: + return '\n'.join(_frames_fmt(frames, full_filename=True)) + return None + return object_context + +def to_dot(nodes): + lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;'] + for i, n in enumerate(nodes): + lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];') + + for i, f in enumerate(nodes): + for label, j in f.referrents: + lines.append(f'{i} -> {j} [label = {escape(label)}]') + lines.append("}\n") + return '\n'.join(lines) + +_template = """ + + + + + + +
+<!DOCTYPE html>
+<html>
+<body>
+<div id="main"></div>
+<pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
+<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
+<script>
+let dot = $DOT
+let image = Viz(dot, {format: 'svg'});
+document.getElementById('main').innerHTML = image
+$LISTENERS
+</script>
+</body>
+</html>
+ + + + +""" +_listener_template = """ +document.getElementById('node{id}').addEventListener('mouseover', function(event) {{ + document.getElementById("stacktrace").textContent = {stack} +}}) +""" +def to_html(nodes): + listeners = [] + for i, n in enumerate(nodes): + if n.context is None: + continue + s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) + listeners.append(s) + dot = to_dot(nodes) + return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners)) + +def observe_tensor_cycles(callback): + torch.cuda.memory._record_memory_history(max_entries=100000) + + def observer(garbage): + if garbage: + if not any(is_cuda_tensor(obj) for obj in garbage): + logger.info("No CUDA Tensors found in garbage") + return + callback(to_html(create_graph(garbage))) + return observe_garbage(observer) + + +def warn_tensor_cycles(): + """ + Install a warning that reports whenever a cycle that is holding CUDA memory is observed. + + The warning produces an .html file that visualizes the cycle, + and links it to the stack frame that allocted the CUDA tensor. + + Reference cycles are freed by the cycle collector rather than being cleaned up + when the objects in the cycle first become unreachable. If a cycle points to a tensor, + the CUDA memory for that tensor will not be freed until garbage collection runs. + Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as + non-deterministic allocation behavior which is harder to debug. + """ + logger.info("Watching Python reference cycles for CUDA Tensors.") + + def write_and_log(html): + with NamedTemporaryFile('w', suffix='.html', delete=False) as f: + f.write(html) + logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) + return observe_tensor_cycles(write_and_log) diff --git a/phivenv/Lib/site-packages/torch/utils/weak.py b/phivenv/Lib/site-packages/torch/utils/weak.py new file mode 100644 index 0000000000000000000000000000000000000000..35810df723e71558c839127275c0a1cc81989f88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/utils/weak.py @@ -0,0 +1,338 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections.abc as _collections_abc +import weakref + +from _weakrefset import _IterationGuard # type: ignore[attr-defined] +from collections.abc import Mapping, MutableMapping +from weakref import ref + +from torch import Tensor + + +WeakRef = ref + + +__all__ = [ + "TensorWeakRef", + "WeakIdRef", + "WeakIdKeyDictionary", + "WeakTensorKeyDictionary", +] + + +# This file defines a variant of WeakKeyDictionary that overrides the hashing +# behavior of the key to use object identity, rather than the builtin +# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their +# __eq__ implementation return a Tensor (elementwise equality), which means +# you can't use them directly with the WeakKeyDictionary in standard library. +# +# Our implementation strategy is to create a wrapper weak key object, which we +# use as a key in a stock Python dictionary. This is similar to how weakref +# implements WeakKeyDictionary, but instead of using weakref.ref as the +# wrapper, we use a custom wrapper that has different __eq__ and __hash__ +# behavior. Note that we subsequently store this weak key directly in an +# ORDINARY dictionary, since the newly constructed WeakIdKey's only use would +# be a dictionary so it would have no strong references. 
Ensuring that +# only live WeakIdKeys are in the map is handled by putting finalizers on the +# original key object. + + +# It is simpler to implement this with composition, but if we want to +# directly reuse the callback mechanism on weakref, we need the weakref +# and the key to be exactly the same object. Reusing the callback mechanism +# minimizes the divergence between our implementation and Lib/weakref.py +# +# NB: Prefer using this when working with weakrefs of Tensors; e.g., do +# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of +# easy to get wrong cases transparently for you. +class WeakIdRef(weakref.ref): + __slots__ = ["_id"] + + def __init__(self, key, callback=None): + # Unlike stock weakref, which preserves hash semantics of the + # original object but lazily defers hash calls until the first + # time the user attempts to hash the weakref, we can eagerly + # cache the id of the key as we know this is definitely the hash + # method + self._id = id(key) + super().__init__(key, callback) # type: ignore[call-arg] + + def __call__(self): + r = super().__call__() + # Special logic for Tensor PyObject resurrection + if hasattr(r, "_fix_weakref"): + r._fix_weakref() # type: ignore[union-attr] + return r + + def __hash__(self): + return self._id + + def __eq__(self, other): + # An attractive but wrong alternate implementation is to only test if + # the stored _ids match. This can lead to an ABA problem if you have: + # + # a1 = A() + # w1 = WeakIdRef(a1) + # del a1 + # a2 = A() # suppose it gets the same ID as a1 + # w2 = WeakIdRef(a2) + # print(w1 == w2) + # + # This should be False, as a1 and a2 are unrelated (and a1 is + # dead anyway) + a = self() + b = other() + if a is not None and b is not None: + return a is b + return self is other + + +# This is the same as WeakIdRef but equality is checked using hash() rather than id. +# This will be equivalent to the one above except for classes where hash is not their id. +class _WeakHashRef(weakref.ref): + __slots__ = ["_id"] + + def __init__(self, key, callback=None): + # Unlike stock weakref, which preserves hash semantics of the + # original object but lazily defers hash calls until the first + # time the user attempts to hash the weakref, we can eagerly + # cache the id of the key as we know this is definitely the hash + # method + self._id = hash(key) + super().__init__(key, callback) # type: ignore[call-arg] + + def __call__(self): + r = super().__call__() + # Special logic for Tensor PyObject resurrection + if hasattr(r, "_fix_weakref"): + r._fix_weakref() # type: ignore[union-attr] + return r + + def __hash__(self): + return self._id + + def __eq__(self, other): + # Use hash equality to determine ref equality. + # ScriptObject implements __hash__ to return the wrapped IValue's id, so + # this is equivalent to doing an identity comparison. 
+ a = self() + b = other() + if a is not None and b is not None: + return hash(a) == hash(b) + return self is other + + +# This is directly adapted from cpython/Lib/weakref.py +class WeakIdKeyDictionary(MutableMapping): + def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED + self.data = {} + + self.ref_type = ref_type # CHANGED + + def remove(k, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(k) + else: + try: + del self.data[k] + except KeyError: + pass + + self._remove = remove + # A list of dead weakrefs (keys to be removed) + self._pending_removals = [] + self._iterating = set() + self._dirty_len = False + if dict is not None: + self.update(dict) + + def _commit_removals(self): + # NOTE: We don't need to call this method before mutating the dict, + # because a dead weakref never compares equal to a live weakref, + # even if they happened to refer to equal objects. + # However, it means keys may already have been removed. + pop = self._pending_removals.pop + d = self.data + while True: + try: + key = pop() + except IndexError: + return + + try: + del d[key] + except KeyError: + pass + + def _scrub_removals(self): + d = self.data + self._pending_removals = [k for k in self._pending_removals if k in d] + self._dirty_len = False + + def __delitem__(self, key): + self._dirty_len = True + del self.data[self.ref_type(key)] # CHANGED + + def __getitem__(self, key): + return self.data[self.ref_type(key)] # CHANGED + + def __len__(self): + if self._dirty_len and self._pending_removals: + # self._pending_removals may still contain keys which were + # explicitly removed, we have to scrub them (see issue #21173). + self._scrub_removals() + return len(self.data) - len(self._pending_removals) + + def __repr__(self): + return f"<{self.__class__.__name__} at {id(self):#x}>" + + def __setitem__(self, key, value): + self.data[self.ref_type(key, self._remove)] = value # CHANGED + + def copy(self): + new = WeakIdKeyDictionary() + with _IterationGuard(self): + for key, value in self.data.items(): + o = key() + if o is not None: + new[o] = value + return new + + __copy__ = copy + + def __deepcopy__(self, memo): + from copy import deepcopy + + new = self.__class__() + with _IterationGuard(self): + for key, value in self.data.items(): + o = key() + if o is not None: + new[o] = deepcopy(value, memo) + return new + + def get(self, key, default=None): + return self.data.get(self.ref_type(key), default) # CHANGED + + def __contains__(self, key): + try: + wr = self.ref_type(key) # CHANGED + except TypeError: + return False + return wr in self.data + + def items(self): + with _IterationGuard(self): + for wr, value in self.data.items(): + key = wr() + if key is not None: + yield key, value + + def keys(self): + with _IterationGuard(self): + for wr in self.data: + obj = wr() + if obj is not None: + yield obj + + __iter__ = keys + + def values(self): + with _IterationGuard(self): + for wr, value in self.data.items(): + if wr() is not None: + yield value + + def keyrefs(self): + """Return a list of weak references to the keys. + + The references are not guaranteed to be 'live' at the time + they are used, so the result of calling the references needs + to be checked before being used. This can be used to avoid + creating references that will cause the garbage collector to + keep the keys around longer than needed. 
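A minimal usage sketch of the identity-keyed mapping defined here (assuming a standard CPU-only torch install; the metadata string is a placeholder). It shows that lookups go by object identity rather than Tensor.__eq__, and that an entry disappears once its key tensor is garbage collected.

import torch
from torch.utils.weak import WeakTensorKeyDictionary

meta = WeakTensorKeyDictionary()
t = torch.zeros(3)
meta[t] = "per-tensor metadata"      # placeholder payload

assert t in meta                     # identity lookup; Tensor.__eq__ is never consulted
assert meta[t] == "per-tensor metadata"

del t                                # last strong reference dies
print(len(meta))                     # expected 0 on CPython: the finalizer scrubbed the dead key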
+ + """ + return list(self.data) + + def popitem(self): + self._dirty_len = True + while True: + key, value = self.data.popitem() + o = key() + if o is not None: + return o, value + + def pop(self, key, *args): + self._dirty_len = True + return self.data.pop(self.ref_type(key), *args) # CHANGED + + def setdefault(self, key, default=None): + return self.data.setdefault( + self.ref_type(key, self._remove), default + ) # CHANGED + + def update(self, dict=None, **kwargs): # type: ignore[override] + d = self.data + if dict is not None: + if not hasattr(dict, "items"): + dict = type({})(dict) + for key, value in dict.items(): + d[self.ref_type(key, self._remove)] = value # CHANGED + if len(kwargs): + self.update(kwargs) + + def __ior__(self, other): + self.update(other) + return self + + def __or__(self, other): + if isinstance(other, _collections_abc.Mapping): + c = self.copy() + c.update(other) + return c + return NotImplemented + + def __ror__(self, other): + if isinstance(other, _collections_abc.Mapping): + c = self.__class__() + c.update(other) + c.update(self) + return c + return NotImplemented + + # Default Mapping equality will tests keys for equality, but + # we want to test ids for equality + def __eq__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + return {id(k): v for k, v in self.items()} == { + id(k): v for k, v in other.items() + } + + +# Convenience alias +WeakTensorKeyDictionary = WeakIdKeyDictionary + + +class TensorWeakRef: + """Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref.""" + + ref: WeakRef[Tensor] + + def __init__(self, tensor: Tensor): + assert isinstance(tensor, Tensor) + self.ref = weakref.ref(tensor) + + def __call__(self): + out = self.ref() + if out is None: + return out + assert isinstance(out, Tensor) + # TODO, add _fix_weakref type binding + out._fix_weakref() # type: ignore[attr-defined] + return out diff --git a/phivenv/Lib/site-packages/torch/xpu/__init__.py b/phivenv/Lib/site-packages/torch/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79161d6867439171ba78618f0eba550024423354 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/__init__.py @@ -0,0 +1,561 @@ +# mypy: allow-untyped-defs +r""" +This package introduces support for the XPU backend, specifically tailored for +Intel GPU optimization. + +This package is lazily initialized, so you can always import it, and use +:func:`is_available()` to determine if your system supports XPU. 
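A guarded-usage sketch of the lazy-initialization contract described above (the tensor shape is illustrative); importing torch.xpu is always safe, and is_available() gates any actual device work.

import torch

if torch.xpu.is_available():        # never throws; the backend initializes lazily on first use
    x = torch.randn(4, 4, device="xpu")
    y = (x @ x).sum()
    torch.xpu.synchronize()         # wait for queued kernels before reading results on the host
    print(y.item())
else:
    print("XPU not available; falling back to CPU")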
+""" +import threading +import traceback +from functools import lru_cache +from typing import Any, Callable, Optional, Union + +import torch +import torch._C +from torch import device as _device +from torch._utils import _dummy_type, _LazySeedTracker + +from ._utils import _get_device_index +from .streams import Event, Stream + + +_initialized = False +_tls = threading.local() +_initialization_lock = threading.Lock() +_queued_calls: list[ + tuple[Callable[[], None], list[str]] +] = [] # don't invoke these until initialization occurs +_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False) +_device_t = Union[_device, str, int, None] +_lazy_seed_tracker = _LazySeedTracker() +default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment] + + +def _is_compiled() -> bool: + r"""Return true if compile with XPU support.""" + return torch._C._has_xpu + + +if _is_compiled(): + _XpuDeviceProperties = torch._C._XpuDeviceProperties + _exchange_device = torch._C._xpu_exchangeDevice + _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice +else: + # Define dummy if PyTorch was compiled without XPU + _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties") # type: ignore[assignment, misc] + + def _exchange_device(device: int) -> int: + raise NotImplementedError("PyTorch was compiled without XPU support") + + def _maybe_exchange_device(device: int) -> int: + raise NotImplementedError("PyTorch was compiled without XPU support") + + +@lru_cache(maxsize=1) +def device_count() -> int: + r"""Return the number of XPU device available.""" + if not _is_compiled(): + return 0 + return torch._C._xpu_getDeviceCount() + + +def is_available() -> bool: + r"""Return a bool indicating if XPU is currently available.""" + # This function never throws. + return device_count() > 0 + + +def is_bf16_supported(including_emulation: bool = True) -> bool: + r"""Return a bool indicating if the current XPU device supports dtype bfloat16.""" + if not is_available(): + return False + return ( + including_emulation + or torch.xpu.get_device_properties().has_bfloat16_conversions + ) + + +def is_initialized(): + r"""Return whether PyTorch's XPU state has been initialized.""" + return _initialized and not _is_in_bad_fork() + + +def _lazy_call(callable, **kwargs): + if is_initialized(): + callable() + else: + global _lazy_seed_tracker + if kwargs.get("seed_all", False): + _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) + elif kwargs.get("seed", False): + _lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) + else: + # Don't store the actual traceback to avoid memory cycle + _queued_calls.append((callable, traceback.format_stack())) + + +def init(): + r"""Initialize PyTorch's XPU state. + This is a Python API about lazy initialization that avoids initializing + XPU until the first time it is accessed. Does nothing if the XPU state is + already initialized. + """ + _lazy_init() + + +def _lazy_init(): + global _initialized, _queued_calls + if is_initialized() or hasattr(_tls, "is_initializing"): + return + with _initialization_lock: + # This test was was protected via GIL. Double-check whether XPU has + # already been initialized. + if is_initialized(): + return + # Stop promptly upon encountering a bad fork error. + if _is_in_bad_fork(): + raise RuntimeError( + "Cannot re-initialize XPU in forked subprocess. 
To use XPU with " + "multiprocessing, you must use the 'spawn' start method" + ) + if not _is_compiled(): + raise AssertionError("Torch not compiled with XPU enabled") + # This function inits XPU backend and detects bad fork processing. + torch._C._xpu_init() + # Some of the queued calls may reentrantly call _lazy_init(); We need to + # just return without initializing in that case. + _tls.is_initializing = True + + _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls) + + try: + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ( + f"XPU call failed lazily at initialization with error: {str(e)}\n\n" + f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}" + ) + raise Exception(msg) from e # noqa: TRY002 + finally: + delattr(_tls, "is_initializing") + _initialized = True + + +class _DeviceGuard: + def __init__(self, index: int): + self.idx = index + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.xpu._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.xpu._maybe_exchange_device(self.prev_idx) + return False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int or str): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device: Any): + self.idx = _get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch.xpu._exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): + self.idx = torch.xpu._maybe_exchange_device(self.prev_idx) + return False + + +class device_of(device): + r"""Context-manager that changes the current device to that of given object. + + You can use both tensors and storages as arguments. If a given object is + not allocated on a XPU, this is a no-op. + + Args: + obj (Tensor or Storage): object allocated on the selected device. + """ + + def __init__(self, obj): + idx = obj.get_device() if obj.is_xpu else -1 + super().__init__(idx) + + +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int or str): selected device. This function is a + no-op if this argument is negative. + """ + _lazy_init() + device = _get_device_index(device) + if device >= 0: + torch._C._xpu_setDevice(device) + + +def get_device_name(device: Optional[_device_t] = None) -> str: + r"""Get the name of a device. + + Args: + device (torch.device or int or str, optional): device for which to + return the name. This function is a no-op if this argument is a + negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + + Returns: + str: the name of the device + """ + return get_device_properties(device).name + + +@lru_cache(None) +def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: + r"""Get the xpu capability of a device. + + Args: + device (torch.device or int or str, optional): device for which to + return the device capability. This function is a no-op if this + argument is a negative integer. It uses the current device, given by + :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` + (default). 
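A brief sketch of how the device helpers above compose (assuming at least one XPU device is present; index 0 is illustrative):

import torch

if torch.xpu.is_available():
    print(torch.xpu.device_count(), torch.xpu.get_device_name(0))
    with torch.xpu.device(0):                  # temporarily select device 0
        t = torch.ones(2, 2, device="xpu")     # allocated on the selected device
    print(torch.xpu.current_device())          # previous device is restored on exit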
+ + Returns: + Dict[str, Any]: the xpu capability dictionary of the device + """ + props = get_device_properties(device) + # pybind service attributes are no longer needed and their presence breaks + # the further logic related to the serialization of the created dictionary. + # In particular it filters out `` + # to fix Triton tests. + # This field appears after updating pybind to 2.13.6. + return { + prop: getattr(props, prop) + for prop in dir(props) + if not prop.startswith(("__", "_pybind11_")) + } + + +def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties: + r"""Get the properties of a device. + + Args: + device (torch.device or int or str): device for which to return the + properties of the device. + + Returns: + _XpuDeviceProperties: the properties of the device + """ + _lazy_init() + device = _get_device_index(device, optional=True) + return _get_device_properties(device) # type: ignore[name-defined] # noqa: F821 + + +def current_device() -> int: + r"""Return the index of a currently selected device.""" + _lazy_init() + return torch._C._xpu_getDevice() + + +def _get_device(device: Union[int, str, torch.device]) -> torch.device: + r"""Return the torch.device type object from the passed in device. + + Args: + device (torch.device or int or str): selected device. + """ + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("xpu", device) + return device + + +class StreamContext: + r"""Context-manager that selects a given stream. + + All XPU kernels queued within its context will be enqueued on a selected + stream. + + Args: + Stream (Stream): selected stream. This manager is a no-op if it's + ``None``. + .. note:: Streams are per-device. + """ + cur_stream: Optional["torch.xpu.Stream"] + + def __init__(self, stream: Optional["torch.xpu.Stream"]): + self.stream = stream + self.idx = _get_device_index(None, True) + if self.idx is None: + self.idx = -1 + + def __enter__(self): + cur_stream = self.stream + if cur_stream is None or self.idx == -1: + return + self.src_prev_stream = torch.xpu.current_stream(None) + + # If the stream is not on the current device, then set the current stream on the device + if self.src_prev_stream.device != cur_stream.device: + with device(cur_stream.device): + self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device) + torch.xpu.set_stream(cur_stream) + + def __exit__(self, type: Any, value: Any, traceback: Any): + cur_stream = self.stream + if cur_stream is None or self.idx == -1: + return + + # Reset the stream on the original device and destination device + if self.src_prev_stream.device != cur_stream.device: + torch.xpu.set_stream(self.dst_prev_stream) + torch.xpu.set_stream(self.src_prev_stream) + + +def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext: + r"""Wrap around the Context-manager StreamContext that selects a given stream. + + Arguments: + stream (Stream): selected stream. This manager is a no-op if it's ``None``. + """ + return StreamContext(stream) + + +def _set_stream_by_id(stream_id, device_index, device_type): + r"""set stream specified by the stream id, device index and device type + + Args: stream_id (int): not visible to the user, used to assigned to the specific stream. + device_index (int): selected device index. + device_type (int): selected device type. 
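A sketch of routing work onto a side stream with the stream() helper (tensor names and sizes are placeholders):

import torch

if torch.xpu.is_available():
    side = torch.xpu.Stream()
    with torch.xpu.stream(side):                   # kernels below are enqueued on `side`
        a = torch.randn(1024, device="xpu")
        b = a * 2
    torch.xpu.current_stream().wait_stream(side)   # order later work after the side stream
    torch.xpu.synchronize()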
+ """ + torch._C._xpu_setStream( + stream_id=stream_id, + device_index=device_index, + device_type=device_type, + ) + + +def set_stream(stream: Stream): + r"""Set the current stream.This is a wrapper API to set the stream. + Usage of this function is discouraged in favor of the ``stream`` + context manager. + + Args: + stream (Stream): selected stream. This function is a no-op + if this argument is ``None``. + """ + if stream is None: + return + _lazy_init() + _set_stream_by_id( + stream_id=stream.stream_id, + device_index=stream.device_index, + device_type=stream.device_type, + ) + + +def current_stream(device: Optional[_device_t] = None) -> Stream: + r"""Return the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): selected device. Returns + the currently selected :class:`Stream` for the current device, given + by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` + (default). + """ + _lazy_init() + streamdata = torch._C._xpu_getCurrentStream( + _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def get_stream_from_external( + data_ptr: int, device: Optional[_device_t] = None +) -> Stream: + r"""Return a :class:`Stream` from an external SYCL queue. + + This function is used to wrap SYCL queue created in other libraries in order + to facilitate data exchange and multi-library interactions. + + .. note:: This function doesn't manage the queue life-cycle, it is the user + responsibility to keep the referenced queue alive while this returned stream is + being used. The different SYCL queue pointers will result in distinct + :class:`Stream` objects, even if the SYCL queues they dereference are equivalent. + + Args: + data_ptr(int): Integer representation of the `sycl::queue*` value passed externally. + device(torch.device or int, optional): the device where the queue was originally created. + It is the user responsibility to ensure the device is specified correctly. + """ + _lazy_init() + streamdata = torch._C._xpu_getStreamFromExternal( + data_ptr, _get_device_index(device, optional=True) + ) + return Stream( + stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] + ) + + +def synchronize(device: _device_t = None) -> None: + r"""Wait for all kernels in all streams on a XPU device to complete. + + Args: + device (torch.device or int, optional): device for which to synchronize. + It uses the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + _lazy_init() + device = _get_device_index(device, optional=True) + return torch._C._xpu_synchronize(device) + + +def get_arch_list() -> list[str]: + r"""Return list XPU architectures this library was compiled for.""" + if not _is_compiled(): + return [] + arch_flags = torch._C._xpu_getArchFlags() + if arch_flags is None: + return [] + return arch_flags.split() + + +def get_gencode_flags() -> str: + r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with.""" + arch_list = get_arch_list() + if len(arch_list) == 0: + return "" + return f'-device {",".join(arch for arch in arch_list)}' + + +def _get_generator(device: torch.device) -> torch._C.Generator: + r"""Return the XPU Generator object for the given device. + + Args: + device (torch.device): selected device. 
+ """ + idx = device.index + if idx is None: + idx = current_device() + return torch.xpu.default_generators[idx] + + +def _set_rng_state_offset( + offset: int, device: Union[int, str, torch.device] = "xpu" +) -> None: + r"""Set the random number generator state offset of the specified GPU. + + Args: + offset (int): The desired offset + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). + """ + final_device = _get_device(device) + + def cb(): + default_generator = _get_generator(final_device) + default_generator.set_offset(offset) + + _lazy_call(cb) + + +def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: + r"""Return the random number generator state offset of the specified GPU. + + Args: + device (torch.device or int, optional): The device to return the RNG state offset of. + Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). + + .. warning:: + This function eagerly initializes XPU. + """ + _lazy_init() + final_device = _get_device(device) + default_generator = _get_generator(final_device) + return default_generator.get_offset() + + +# import here to avoid circular import +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + mem_get_info, + memory_allocated, + memory_reserved, + memory_stats, + memory_stats_as_nested_dict, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) +from .random import ( + get_rng_state, + get_rng_state_all, + initial_seed, + manual_seed, + manual_seed_all, + seed, + seed_all, + set_rng_state, + set_rng_state_all, +) + + +__all__ = [ + "Event", + "Stream", + "StreamContext", + "current_device", + "current_stream", + "default_generators", + "device", + "device_of", + "device_count", + "empty_cache", + "get_arch_list", + "get_device_capability", + "get_device_name", + "get_device_properties", + "get_gencode_flags", + "get_rng_state", + "get_rng_state_all", + "get_stream_from_external", + "init", + "initial_seed", + "is_available", + "is_bf16_supported", + "is_initialized", + "manual_seed", + "manual_seed_all", + "max_memory_allocated", + "max_memory_reserved", + "mem_get_info", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "seed", + "seed_all", + "set_device", + "set_rng_state", + "set_rng_state_all", + "set_stream", + "stream", + "streams", + "synchronize", +] diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e8e35a3ae37eb65c5ff284db80301adee41054d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..102d91ba5b34e70b115a6affcdf108a8592476f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/_gpu_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..be7acf4cf4d6395228af92acd30f7e58b0c00c89 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/memory.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/memory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f161e878e743ca10fd8c29a65c6a75f0e54659d0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/memory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/random.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a1c251332b0f1c863d58f6beab65aa591ef7167 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/random.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/__pycache__/streams.cpython-39.pyc b/phivenv/Lib/site-packages/torch/xpu/__pycache__/streams.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46914035f56d3ee5611ca42dc6e0cc87bcdb105a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/xpu/__pycache__/streams.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/xpu/_gpu_trace.py b/phivenv/Lib/site-packages/torch/xpu/_gpu_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6e9c2e97571c2a3391e0c149245a9c81885827 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/_gpu_trace.py @@ -0,0 +1,69 @@ +from typing import Callable + +from torch._utils import CallbackRegistry + + +EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event creation") +EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event deletion") +EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( + "XPU event record" +) +EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry("XPU event wait") +MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "XPU memory allocation" +) +MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "XPU memory deallocation" +) +StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "XPU stream creation" +) +DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry( + "XPU device synchronization" +) +StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "XPU stream synchronization" +) +EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( + "XPU event synchronization" +) + + +def register_callback_for_event_creation(cb: Callable[[int], None]) -> None: + EventCreationCallbacks.add_callback(cb) + + +def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None: + EventDeletionCallbacks.add_callback(cb) + + +def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None: + EventRecordCallbacks.add_callback(cb) + + +def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None: + EventWaitCallbacks.add_callback(cb) + + +def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None: + MemoryAllocationCallbacks.add_callback(cb) + + +def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None: + MemoryDeallocationCallbacks.add_callback(cb) + + +def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None: + 
StreamCreationCallbacks.add_callback(cb) + + +def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None: + DeviceSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None: + StreamSynchronizationCallbacks.add_callback(cb) + + +def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None: + EventSynchronizationCallbacks.add_callback(cb) diff --git a/phivenv/Lib/site-packages/torch/xpu/_utils.py b/phivenv/Lib/site-packages/torch/xpu/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..512e8474ee443788db8bc2de79e83a03b9173693 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/_utils.py @@ -0,0 +1,39 @@ +from typing import Any + +import torch + +# The _get_device_index has been moved to torch.utils._get_device_index +from torch._utils import _get_device_index as _torch_get_device_index + + +def _get_device_index( + device: Any, optional: bool = False, allow_cpu: bool = False +) -> int: + r"""Get the device index from :attr:`device`, which can be a torch.device + object, a Python integer, or ``None``. + + If :attr:`device` is a torch.device object, returns the device index if it + is a XPU device. Note that for a XPU device without a specified index, + i.e., ``torch.device('xpu')``, this will return the current default XPU + device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``, + CPU devices will be accepted and ``-1`` will be returned in this case. + + If :attr:`device` is a Python integer, it is returned as is. + + If :attr:`device` is ``None``, this will return the current default XPU + device if :attr:`optional` is ``True``. + """ + if isinstance(device, int): + return device + if isinstance(device, str): + device = torch.device(device) + if isinstance(device, torch.device): + if allow_cpu: + if device.type not in ["xpu", "cpu"]: + raise ValueError(f"Expected a xpu or cpu device, but got: {device}") + elif device.type != "xpu": + raise ValueError(f"Expected a xpu device, but got: {device}") + if not torch.jit.is_scripting(): + if isinstance(device, torch.xpu.device): + return device.idx + return _torch_get_device_index(device, optional, allow_cpu) diff --git a/phivenv/Lib/site-packages/torch/xpu/memory.py b/phivenv/Lib/site-packages/torch/xpu/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..d760ef3dcd32af9c993f97253b5391c21f1b1775 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/memory.py @@ -0,0 +1,208 @@ +import collections +from typing import Any, Union + +import torch +from torch.types import Device + +from . import _get_device_index, is_initialized + + +_device_t = Union[Device, str, int, None] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other XPU application. + + .. note:: + :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU + memory available for PyTorch. However, it may help reduce fragmentation + of XPU memory in certain cases. + """ + if is_initialized(): + torch._C._xpu_emptyCache() + + +def reset_peak_memory_stats(device: _device_t = None) -> None: + r"""Reset the "peak" stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. 
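The reset/peak pair is typically used to bracket a region of interest; a per-iteration sketch (the matmul stands in for a real training step):

import torch

if torch.xpu.is_available():
    for step in range(3):
        torch.xpu.reset_peak_memory_stats()
        x = torch.randn(256, 256, device="xpu")
        y = x @ x                               # placeholder workload
        peak = torch.xpu.max_memory_allocated()
        print(f"step {step}: peak allocated {peak} bytes")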
Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetPeakMemoryStats(device) + + +def reset_accumulated_memory_stats(device: _device_t = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetAccumulatedMemoryStats(device) + + +def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]: + r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._xpu_memoryStats(device) + + +def memory_stats(device: _device_t = None) -> dict[str, Any]: + r"""Return a dictionary of XPU memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool (for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool (for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistics for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + result = [] + + def _recurse_add_to_result(prefix: str, obj: Any) -> None: + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_allocated(device: _device_t = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + + .. 
note:: + This is likely less than the amount shown in `xpu-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: _device_t = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: _device_t = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: _device_t = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +def mem_get_info(device: _device_t = None) -> tuple[int, int]: + r"""Return the global free and total GPU memory for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + + Returns: + int: the memory available on the device in units of bytes. + int: the total memory on the device in units of bytes + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_getMemoryInfo(device) + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "mem_get_info", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] diff --git a/phivenv/Lib/site-packages/torch/xpu/random.py b/phivenv/Lib/site-packages/torch/xpu/random.py new file mode 100644 index 0000000000000000000000000000000000000000..403febf25866a33850ce035b9eee7e01c83766e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/random.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterable +from typing import Union + +import torch +from torch import Tensor + +from . 
import _lazy_call, _lazy_init, current_device, device_count + + +def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor: + r"""Return the random number generator state of the specified GPU as a ByteTensor. + + Args: + device (torch.device or int, optional): The device to return the RNG state of. + Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). + + .. warning:: + This function eagerly initializes XPU. + """ + _lazy_init() + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("xpu", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.xpu.default_generators[idx] + return default_generator.get_state() + + +def get_rng_state_all() -> list[Tensor]: + r"""Return a list of ByteTensor representing the random number states of all devices.""" + results = [get_rng_state(i) for i in range(device_count())] + return results + + +def set_rng_state( + new_state: Tensor, device: Union[int, str, torch.device] = "xpu" +) -> None: + r"""Set the random number generator state of the specified GPU. + + Args: + new_state (torch.ByteTensor): The desired state + device (torch.device or int, optional): The device to set the RNG state. + Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). + """ + with torch._C._DisableFuncTorch(): + new_state_copy = new_state.clone(memory_format=torch.contiguous_format) + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("xpu", device) + + def cb(): + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch.xpu.default_generators[idx] + default_generator.set_state(new_state_copy) + + _lazy_call(cb) + + +def set_rng_state_all(new_states: Iterable[Tensor]) -> None: + r"""Set the random number generator state of all devices. + + Args: + new_states (Iterable of torch.ByteTensor): The desired state for each device. + """ + for i, state in enumerate(new_states): + set_rng_state(state, i) + + +def manual_seed(seed: int) -> None: + r"""Set the seed for generating random numbers for the current GPU. + + It's safe to call this function if XPU is not available; in that case, it is silently ignored. + + Args: + seed (int): The desired seed. + + .. warning:: + If you are working with a multi-GPU model, this function is insufficient + to get determinism. To seed all GPUs, use :func:`manual_seed_all`. + """ + seed = int(seed) + + def cb(): + idx = current_device() + default_generator = torch.xpu.default_generators[idx] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed=True) + + +def manual_seed_all(seed: int) -> None: + r"""Set the seed for generating random numbers on all GPUs. + + It's safe to call this function if XPU is not available; in that case, it is silently ignored. + + Args: + seed (int): The desired seed. + """ + seed = int(seed) + + def cb(): + for i in range(device_count()): + default_generator = torch.xpu.default_generators[i] + default_generator.manual_seed(seed) + + _lazy_call(cb, seed_all=True) + + +def seed() -> None: + r"""Set the seed for generating random numbers to a random number for the current GPU. + + It's safe to call this function if XPU is not available; in that case, it is silently ignored. + + .. warning:: + If you are working with a multi-GPU model, this function will only initialize + the seed on one GPU. To initialize all GPUs, use :func:`seed_all`. 
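A short sketch of checkpointing and restoring the generator state with the functions above (assuming one XPU device; the draws are illustrative):

import torch

if torch.xpu.is_available():
    torch.xpu.manual_seed(1234)
    state = torch.xpu.get_rng_state()          # ByteTensor snapshot of the XPU generator
    a = torch.randn(3, device="xpu")
    torch.xpu.set_rng_state(state)             # rewind the generator
    b = torch.randn(3, device="xpu")
    assert torch.equal(a.cpu(), b.cpu())       # identical draws after restoring the state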
+ """ + + def cb(): + idx = current_device() + default_generator = torch.xpu.default_generators[idx] + default_generator.seed() + + _lazy_call(cb) + + +def seed_all() -> None: + r"""Set the seed for generating random numbers to a random number on all GPUs. + + It's safe to call this function if XPU is not available; in that case, it is silently ignored. + """ + + def cb(): + random_seed = 0 + seeded = False + for i in range(device_count()): + default_generator = torch.xpu.default_generators[i] + if not seeded: + default_generator.seed() + random_seed = default_generator.initial_seed() + seeded = True + else: + default_generator.manual_seed(random_seed) + + _lazy_call(cb) + + +def initial_seed() -> int: + r"""Return the current random seed of the current GPU. + + .. warning:: + This function eagerly initializes XPU. + """ + _lazy_init() + idx = current_device() + default_generator = torch.xpu.default_generators[idx] + return default_generator.initial_seed() + + +__all__ = [ + "get_rng_state", + "get_rng_state_all", + "set_rng_state", + "set_rng_state_all", + "manual_seed", + "manual_seed_all", + "seed", + "seed_all", + "initial_seed", +] diff --git a/phivenv/Lib/site-packages/torch/xpu/streams.py b/phivenv/Lib/site-packages/torch/xpu/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..202de0daf2a90d95f0c37e0e042978840a032eca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/xpu/streams.py @@ -0,0 +1,173 @@ +# mypy: allow-untyped-defs +import ctypes + +import torch +from torch._utils import _dummy_type + + +if not hasattr(torch._C, "_XpuStreamBase"): + # Define dummy base classes + torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase") + torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase") + + +class Stream(torch._C._XpuStreamBase): + r"""Wrapper around a XPU stream. + + A XPU stream is a linear sequence of execution that belongs to a specific + device, independent from other streams. It supports with statement as a + context manager to ensure the operators within the with block are running + on the corresponding stream. + + Args: + device(torch.device or int, optional): a device on which to allocate + the stream. If :attr:`device` is ``None`` (default) or a negative + integer, this will use the current device. + priority(int, optional): priority of the stream, which can be positive, 0, or negative. + A lower number indicates a higher priority. By default, the priority is set to 0. + If the value falls outside of the allowed priority range, it will automatically be + mapped to the nearest valid priority (lowest for large positive numbers or + highest for large negative numbers). + """ + + def __new__(cls, device=None, priority=0, **kwargs): + # setting device manager is expensive, so we avoid it unless necessary + if device is None or ("stream_id" in kwargs and "device_index" in kwargs): + return super().__new__(cls, priority=priority, **kwargs) + else: + with torch.xpu.device(device): + return super().__new__(cls, priority=priority, **kwargs) + + def wait_event(self, event) -> None: + r"""Make all future work submitted to the stream wait for an event. + + Args: + event (torch.xpu.Event): an event to wait for. + """ + event.wait(self) + + def wait_stream(self, stream) -> None: + r"""Synchronize with another stream. + + All future work submitted to this stream will wait until all kernels + submitted to a given stream at the time of call complete. + + Args: + stream (Stream): a stream to synchronize. 
+ """ + self.wait_event(stream.record_event()) + + def record_event(self, event=None): + r"""Record an event. + + Args: + event (torch.xpu.Event, optional): event to record. If not given, a new one + will be allocated. + + Returns: + Recorded event. + """ + if event is None: + event = Event() + event.record(self) + return event + + def query(self) -> bool: + r"""Check if all the work submitted has been completed. + + Returns: + A boolean indicating if all kernels in this stream are completed. + """ + return super().query() + + def synchronize(self) -> None: + r"""Wait for all the kernels in this stream to complete.""" + super().synchronize() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.sycl_queue) + + def __eq__(self, o): + if isinstance(o, Stream): + return super().__eq__(o) + return False + + def __hash__(self): + return hash((self.sycl_queue, self.device)) + + def __repr__(self): + return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})" + + +class Event(torch._C._XpuEventBase): + r"""Wrapper around a XPU event. + + XPU events are synchronization markers that can be used to monitor the + device's progress, and to synchronize XPU streams. + + The underlying XPU events are lazily initialized when the event is first + recorded. After creation, only streams on the same device may record the + event. However, streams on any device can wait on the event. + + Args: + enable_timing (bool, optional): indicates if the event should measure time + (default: ``False``) + """ + + def __new__(cls, enable_timing=False): + return super().__new__(cls, enable_timing=enable_timing) + + def record(self, stream=None) -> None: + r"""Record the event in a given stream. + + Uses ``torch.xpu.current_stream()`` if no stream is specified. The + stream's device must match the event's device. + """ + if stream is None: + stream = torch.xpu.current_stream() + super().record(stream) + + def wait(self, stream=None) -> None: + r"""Make all future work submitted to the given stream wait for this event. + + Use ``torch.xpu.current_stream()`` if no stream is specified. + """ + if stream is None: + stream = torch.xpu.current_stream() + super().wait(stream) + + def query(self) -> bool: + r"""Check if all work currently captured by event has completed. + + Returns: + A boolean indicating if all work currently captured by event has + completed. + """ + return super().query() + + def elapsed_time(self, end_event): + r"""Return the time elapsed. + + Time reported in milliseconds after the event was recorded and + before the end_event was recorded. + """ + return super().elapsed_time(end_event) + + def synchronize(self) -> None: + r"""Wait for the event to complete. + + Waits until the completion of all work currently captured in this event. + This prevents the CPU thread from proceeding until the event completes. + """ + super().synchronize() + + @property + def _as_parameter_(self): + return ctypes.c_void_p(self.sycl_event) + + def __repr__(self): + if self.sycl_event: + return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})" + else: + return "torch.xpu.Event(uninitialized)"
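A closing sketch of the event-based timing pattern these wrappers enable (assuming an available XPU device; the matmul is a placeholder workload):

import torch

if torch.xpu.is_available():
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)
    x = torch.randn(1024, 1024, device="xpu")

    start.record()                              # recorded on the current stream
    y = x @ x
    end.record()
    end.synchronize()                           # wait until the work captured by `end` completes
    print(f"matmul took {start.elapsed_time(end):.3f} ms")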